# EmotionalIntensityControl / data_loader.py
# Repository page header (scrape residue): AdalAbilbekov -- "First commit" (ae8e1dd)
import os.path
import random
import numpy as np
import torch
import re
import torch.utils.data
import json
import kaldiio
from tqdm import tqdm
from text import text_to_sequence
class BaseLoader(torch.utils.data.Dataset):
    """Base dataset giving access to mel features and tokenised text per utterance.

    Subclasses are expected to implement ``get_mel_text_pair(utt)``;
    ``__getitem__`` delegates to it.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2text:str):
        """
        :param utts: file path. A list of utts for this loader (one per line). These are the
            only utts that this loader has access to. The other files may cover more utts.
        :param hparams: object exposing ``n_mel_channels`` and ``sampling_rate``.
        :param feats_scp: path handed to ``kaldiio.load_scp`` (utt -> feature matrix).
        :param utt2text: file path; each line is ``<utt> <text>``.
        """
        self.n_mel_channels = hparams.n_mel_channels
        self.sampling_rate = hparams.sampling_rate
        self.utts = self.get_utts(utts)
        self.utt2feat = self.get_utt2feat(feats_scp)
        self.utt2text = self.get_utt2text(utt2text)

    def get_utts(self, utts: str) -> list:
        """Read utterance IDs (one per line) and return them in a fixed shuffled order.

        Blank lines are ignored. The shuffle uses a private ``random.Random(1234)``,
        which yields the same permutation as ``random.seed(1234); random.shuffle(...)``
        without resetting the global random state as a side effect.
        """
        with open(utts, 'r', encoding='utf-8') as f:
            ids = [line.strip() for line in f]
        ids = [utt for utt in ids if utt]
        random.Random(1234).shuffle(ids)
        return ids

    def get_utt2feat(self, feats_scp: str):
        """Open the feature scp file and return the utt -> feature lookup object."""
        utt2feat = kaldiio.load_scp(feats_scp)  # lazy load mode
        print(f"Succeed reading feats from {feats_scp}")
        return utt2feat

    def get_utt2text(self, utt2text: str):
        """Parse ``<utt> <text>`` lines into a dict mapping utt -> text.

        The utt ID is separated from the text by any run of whitespace (space or
        tab). Blank lines are skipped.

        :raises ValueError: if a non-blank line contains an utt ID but no text.
        """
        mapping = {}
        with open(utt2text, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                fields = line.strip().split(None, 1)
                if not fields:
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{utt2text}:{lineno}: utterance '{fields[0]}' has no text")
                mapping[fields[0]] = fields[1]
        return mapping

    def get_mel_from_kaldi(self, utt):
        """Return the features of ``utt`` as a float tensor with ``n_mel_channels`` on axis 0.

        The stored matrix may have the mel axis first or last; it is transposed
        when the first axis is not ``n_mel_channels``.
        """
        feat = torch.FloatTensor(self.utt2feat[utt]).squeeze()
        if feat.dim() < 2:
            # squeeze() removed a length-1 axis (e.g. a single-frame utterance);
            # restore a 2-D layout so callers can still read feat.shape[1].
            feat = feat.reshape(-1, 1)
        assert self.n_mel_channels in feat.shape, \
            f"utt {utt}: expected one axis of size {self.n_mel_channels}, got shape {tuple(feat.shape)}"
        if feat.shape[0] == self.n_mel_channels:
            return feat
        return feat.T

    def get_text(self, utt):
        """Convert the text of ``utt`` into an ``IntTensor`` via ``text_to_sequence``."""
        text = self.utt2text[utt]
        text_norm = text_to_sequence(text)
        return torch.IntTensor(text_norm)

    def get_mel_text_pair(self, utt):
        """Build the sample for ``utt``. Must be provided by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_mel_text_pair(utt)")

    def __getitem__(self, index):
        return self.get_mel_text_pair(self.utts[index])

    def __len__(self):
        return len(self.utts)

    def sample_test_batch(self, size):
        """Return ``size`` distinct samples drawn at random (uses numpy's global RNG).

        ``np.random.choice`` raises ``ValueError`` when ``size`` exceeds ``len(self)``.
        """
        idx = np.random.choice(len(self), size=size, replace=False)
        return [self[int(i)] for i in idx]
class SpkIDLoader(BaseLoader):
    """Loader that attaches a speaker ID, read from a JSON mapping, to each mel sample.

    NOTE(review): this class appears to target an older BaseLoader interface.
    It forwards six positional arguments to ``BaseLoader.__init__`` and calls
    ``self.get_dur_from_kaldi``, whereas the BaseLoader visible in this file
    takes ``(utts, hparams, feats_scp, utt2text)`` and defines no
    ``get_dur_from_kaldi``. Confirm against the intended base class before use.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str):
        """
        :param utt2spk: json file path (utt name -> spk id)
        This loader loads speaker as a speaker ID for embedding table.
        The remaining arguments are forwarded unchanged to the base class.
        """
        # NOTE(review): six arguments are passed here; see the class docstring.
        super(SpkIDLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk)

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Load and return the JSON content of the ``utt2spk`` file."""
        with open(utt2spk, 'r') as f:
            res = json.load(f)
        return res

    def get_mel_text_pair(self, utt):
        """Return a dict with keys ``utt``, ``mel`` and ``spk_ids`` for one utterance.

        Phone IDs and durations are fetched only to check that the summed
        durations equal the number of mel frames (``mel.shape[1]``); they are
        not part of the returned dict.
        """
        # Look up every per-utterance resource by its utterance ID.
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        # NOTE(review): get_dur_from_kaldi is not defined in the visible BaseLoader.
        dur = self.get_dur_from_kaldi(utt)
        assert sum(dur) == mel.shape[1], f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"
        res = {
            "utt": utt,
            "mel": mel,
            "spk_ids": spkid
        }
        return res

    def __getitem__(self, index):
        """Return the sample for the utterance stored at position ``index``."""
        res = self.get_mel_text_pair(self.utts[index])
        return res

    def __len__(self):
        """Number of utterances available to this loader."""
        return len(self.utts)
class SpkIDLoaderWithEmo(BaseLoader):
    """Loader yielding text, mel, an integer speaker ID and an integer emotion ID."""

    def __init__(self, utts: str, hparams, feats_scp: str, utt2text:str, utt2spk: str, utt2emo: str):
        """
        :param utts: forwarded to ``BaseLoader.__init__``
        :param hparams: forwarded to ``BaseLoader.__init__``
        :param feats_scp: forwarded to ``BaseLoader.__init__``
        :param utt2text: forwarded to ``BaseLoader.__init__``
        :param utt2spk: json file path (utt name -> spk id)
        :param utt2emo: json file path (utt name -> emotion id)
        """
        super().__init__(utts, hparams, feats_scp, utt2text)
        self.utt2spk = self.get_utt2spk(utt2spk)
        self.utt2emo = self.get_utt2emo(utt2emo)

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Return the parsed JSON content of the speaker mapping file."""
        with open(utt2spk, 'r') as handle:
            return json.load(handle)

    def get_utt2emo(self, utt2emo: str) -> dict:
        """Return the parsed JSON content of the emotion mapping file."""
        with open(utt2emo, 'r') as handle:
            return json.load(handle)

    def get_mel_text_pair(self, utt):
        """Assemble the sample dict for ``utt``.

        Speaker and emotion entries are cast to ``int``; text and mel come from
        the base-class helpers ``get_text`` and ``get_mel_from_kaldi``.
        """
        # Lookups happen in this order: speaker, emotion, text, mel.
        speaker = int(self.utt2spk[utt])
        emotion = int(self.utt2emo[utt])
        tokens = self.get_text(utt)
        mel_spec = self.get_mel_from_kaldi(utt)
        return {
            "utt": utt,
            "text": tokens,
            "mel": mel_spec,
            "spk_ids": speaker,
            "emo_ids": emotion,
        }

    def __getitem__(self, index):
        """Return the sample of the ``index``-th utterance."""
        utt = self.utts[index]
        return self.get_mel_text_pair(utt)

    def __len__(self):
        """Number of utterances available to this loader."""
        n_utts = len(self.utts)
        return n_utts
class SpkIDLoaderWithPE(SpkIDLoader):
    """SpkIDLoader extended with a per-utterance variance matrix loaded from an scp file.

    NOTE(review): construction goes through ``SpkIDLoader.__init__`` and samples
    rely on ``self.get_dur_from_kaldi``; neither matches the BaseLoader visible
    in this file (4-argument constructor, no ``get_dur_from_kaldi``). Confirm
    the intended base class before use.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str, var_scp: str):
        """
        This loader loads speaker ID together with variance (4-dim pitch, 1-dim energy)

        :param var_scp: path handed to ``kaldiio.load_scp`` for the variance features.
        All other arguments are forwarded unchanged to ``SpkIDLoader.__init__``.
        """
        super(SpkIDLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration, utt2spk)
        self.utt2var = self.get_utt2var(var_scp)

    def get_utt2var(self, utt2var: str) -> dict:
        """Open the variance scp file with ``kaldiio.load_scp`` and return the lookup object."""
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt):
        """Return the variance features of ``utt`` as a float tensor with the size-5 axis first.

        The stored matrix is squeezed, must contain an axis of size 5, and is
        transposed when that axis is not axis 0.
        """
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_mel_text_pair(self, utt):
        """Return a dict with ``utt``, ``phn_ids``, ``mel``, ``dur``, ``spk_ids`` and ``var``.

        Asserts that the summed durations, the mel frame count and the variance
        frame count (axis 1 of each tensor) all agree.
        """
        # Look up every per-utterance resource by its utterance ID.
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        # NOTE(review): get_dur_from_kaldi is not defined in the visible BaseLoader.
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)
        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"
        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var
        }
        return res
class XvectorLoader(BaseLoader):
    """Loader that represents the speaker of each utterance by a speaker-level x-vector.

    NOTE(review): this class forwards six positional arguments to
    ``BaseLoader.__init__`` and calls ``self.get_dur_from_kaldi``; the
    BaseLoader visible in this file takes ``(utts, hparams, feats_scp,
    utt2text)`` and defines no ``get_dur_from_kaldi``. Confirm the intended
    base class before use.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str):
        """
        :param utt2spk_name: like kaldi-style utt2spk
        :param spk_xvector_scp: kaldi-style speaker-level xvector.scp
        """
        # NOTE(review): six arguments are passed here; see the class docstring.
        super(XvectorLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)

    def get_utt2spk(self, utt2spk):
        """Parse whitespace-separated lines into a dict: first field -> second field."""
        res = dict()
        with open(utt2spk, 'r') as f:
            for l in f.readlines():
                res[l.split()[0]] = l.split()[1]
        return res

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        """Open the x-vector scp file with ``kaldiio.load_scp`` and return the lookup object."""
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_xvector(self, utt):
        """Return the squeezed float tensor x-vector of the speaker that ``utt`` maps to."""
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt):
        """Return a dict with ``utt``, ``phn_ids``, ``mel``, ``dur`` and ``xvector``.

        Asserts that the summed durations equal the mel frame count (``mel.shape[1]``).
        """
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        # NOTE(review): get_dur_from_kaldi is not defined in the visible BaseLoader.
        dur = self.get_dur_from_kaldi(utt)
        xvector = self.get_xvector(utt)
        assert sum(dur) == mel.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"
        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "xvector": xvector,
        }
        return res
class XvectorLoaderWithPE(BaseLoader):
    """Loader combining speaker-level x-vectors with a per-utterance variance matrix.

    NOTE(review): this class forwards six positional arguments to
    ``BaseLoader.__init__`` and calls ``self.get_dur_from_kaldi``; the
    BaseLoader visible in this file takes ``(utts, hparams, feats_scp,
    utt2text)`` and defines no ``get_dur_from_kaldi``. Confirm the intended
    base class before use.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str, var_scp: str):
        """
        :param utt2spk_name: text file parsed by ``get_utt2spk`` (first field -> second field)
        :param spk_xvector_scp: path handed to ``kaldiio.load_scp`` for speaker x-vectors
        :param var_scp: path handed to ``kaldiio.load_scp`` for the variance features
        """
        # NOTE(review): six arguments are passed here; see the class docstring.
        super(XvectorLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)
        self.utt2var = self.get_utt2var(var_scp)

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        """Open the x-vector scp file with ``kaldiio.load_scp`` and return the lookup object."""
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_utt2spk(self, utt2spk):
        """Parse whitespace-separated lines into a dict: first field -> second field."""
        res = dict()
        with open(utt2spk, 'r') as f:
            for l in f.readlines():
                res[l.split()[0]] = l.split()[1]
        return res

    def get_utt2var(self, utt2var: str) -> dict:
        """Open the variance scp file with ``kaldiio.load_scp`` and return the lookup object."""
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt):
        """Return the variance features of ``utt`` as a float tensor with the size-5 axis first.

        The stored matrix is squeezed, must contain an axis of size 5, and is
        transposed when that axis is not axis 0.
        """
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_xvector(self, utt):
        """Return the squeezed float tensor x-vector of the speaker that ``utt`` maps to."""
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt):
        """Return a dict with ``utt``, ``phn_ids``, ``mel``, ``dur``, ``spk_ids``, ``var``, ``xvector``.

        ``spk_ids`` holds the raw second field from the utt2spk file (a string as
        parsed by ``get_utt2spk``). Asserts that the summed durations, the mel
        frame count and the variance frame count (axis 1) all agree.
        """
        # Look up every per-utterance resource by its utterance ID.
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        # NOTE(review): get_dur_from_kaldi is not defined in the visible BaseLoader.
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)
        xvector = self.get_xvector(utt)
        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"
        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var,
            "xvector": xvector
        }
        return res