|
import os.path |
|
import random |
|
import numpy as np |
|
import torch |
|
import re |
|
import torch.utils.data |
|
import json |
|
|
|
import kaldiio |
|
from tqdm import tqdm |
|
from text import text_to_sequence |
|
|
|
class BaseLoader(torch.utils.data.Dataset):
    """Dataset base class pairing kaldi-style mel features with text.

    Subclasses implement ``get_mel_text_pair(utt)``, which builds the sample
    returned by ``__getitem__``.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2text: str):
        """
        :param utts: file path. A list of utts (one per line) for this loader. These are
            the only utts that this loader has access to.
        :param hparams: object providing ``n_mel_channels`` and ``sampling_rate``.
        :param feats_scp: kaldi-style scp file with the mel features, read via kaldiio.
        :param utt2text: file path; each non-blank line is ``<utt> <text>``.

        This loader only deals with text and feats. Files other than `utts` may
        contain more utterances than `utts` lists.
        """
        self.n_mel_channels = hparams.n_mel_channels
        self.sampling_rate = hparams.sampling_rate
        self.utts = self.get_utts(utts)
        self.utt2feat = self.get_utt2feat(feats_scp)
        self.utt2text = self.get_utt2text(utt2text)

    def get_utts(self, utts: str) -> list:
        """Read utterance ids (one per line) and shuffle them deterministically.

        Blank lines are ignored. The shuffle uses a private ``random.Random(1234)``;
        this yields the same order as ``random.seed(1234); random.shuffle(...)``
        without resetting the global ``random`` state for the rest of the program.
        """
        with open(utts, "r", encoding="utf-8") as f:
            utt_list = [line.strip() for line in f]
        # An empty id would only surface later as a KeyError in __getitem__.
        utt_list = [u for u in utt_list if u]
        random.Random(1234).shuffle(utt_list)
        return utt_list

    def get_utt2feat(self, feats_scp: str):
        """Open the feature scp with ``kaldiio.load_scp`` (indexable by utt id)."""
        utt2feat = kaldiio.load_scp(feats_scp)
        print(f"Succeed reading feats from {feats_scp}")
        return utt2feat

    def get_utt2text(self, utt2text: str) -> dict:
        """Parse a ``<utt> <text>`` file into a dict mapping utt id -> text.

        The id and the text may be separated by any run of whitespace (spaces or
        tabs). Whitespace inside the text is preserved. Blank lines are skipped.

        :raises ValueError: if a non-blank line has an utt id but no text.
        """
        res = {}
        with open(utt2text, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                parts = line.strip().split(maxsplit=1)
                if not parts:
                    continue
                if len(parts) < 2:
                    raise ValueError(f"{utt2text}:{lineno}: missing text for utt {parts[0]!r}")
                res[parts[0]] = parts[1]
        return res

    def get_mel_from_kaldi(self, utt):
        """Return the mel of `utt` as a FloatTensor with ``n_mel_channels`` on dim 0.

        Singleton dims are squeezed away. A matrix stored as
        (frames, n_mel_channels) is transposed.

        :raises ValueError: if no dimension equals ``n_mel_channels``.
        """
        feat = torch.FloatTensor(self.utt2feat[utt]).squeeze()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.n_mel_channels not in feat.shape:
            raise ValueError(
                f"utt {utt}: feature shape {tuple(feat.shape)} has no "
                f"dimension of size n_mel_channels={self.n_mel_channels}"
            )
        if feat.shape[0] == self.n_mel_channels:
            return feat
        return feat.T

    def get_text(self, utt):
        """Convert the text of `utt` to an IntTensor of symbol ids via ``text_to_sequence``."""
        text = self.utt2text[utt]
        text_norm = text_to_sequence(text)
        return torch.IntTensor(text_norm)

    def get_mel_text_pair(self, utt):
        """Build the sample for `utt`. Subclasses must override this."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_mel_text_pair")

    def __getitem__(self, index):
        return self.get_mel_text_pair(self.utts[index])

    def __len__(self):
        return len(self.utts)

    def sample_test_batch(self, size):
        """Return `size` distinct samples chosen with numpy's global RNG."""
        idx = np.random.choice(len(self), size=size, replace=False)
        return [self[index] for index in idx]
|
|
|
|
|
class SpkIDLoader(BaseLoader):
    """Loader that pairs mel features with a speaker ID read from a JSON map.

    NOTE(review): this class appears out of sync with ``BaseLoader`` in this file:
      * ``super().__init__`` below passes six arguments, but
        ``BaseLoader.__init__`` accepts only (utts, hparams, feats_scp, utt2text),
        so constructing this class raises ``TypeError``.
      * ``get_mel_text_pair`` calls ``self.get_dur_from_kaldi``, which is not
        defined on ``BaseLoader``, on this class, or elsewhere in this file.
    Confirm whether this loader is still used before relying on it.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str):
        """
        :param utt2spk: json file path (utt name -> spk id)

        This loader loads speaker as a speaker ID for embedding table
        """
        # NOTE(review): argument list does not match BaseLoader.__init__ (see class docstring).
        super(SpkIDLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk)

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Load the utt -> speaker-id mapping from the JSON file at `utt2spk`."""
        with open(utt2spk, 'r') as f:
            res = json.load(f)
        return res

    def get_mel_text_pair(self, utt: str) -> dict:
        """Build the sample for `utt` as a dict with keys "utt", "mel" and "spk_ids".

        The speaker id is returned exactly as stored in the JSON (no int cast).
        """
        spkid = self.utt2spk[utt]
        # NOTE(review): phn_ids and dur are computed but not returned; they only feed
        # the frame-length check below. get_dur_from_kaldi is not defined in this file.
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)

        # mel has n_mel_channels on dim 0, so mel.shape[1] is the frame count.
        # Presumably dur holds per-token frame counts that must sum to it — TODO confirm.
        assert sum(dur) == mel.shape[1], f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"
        res = {
            "utt": utt,
            "mel": mel,
            "spk_ids": spkid
        }
        return res

    # Same behavior as BaseLoader.__getitem__.
    def __getitem__(self, index):
        res = self.get_mel_text_pair(self.utts[index])
        return res

    # Same behavior as BaseLoader.__len__.
    def __len__(self):
        return len(self.utts)
|
|
|
|
|
class SpkIDLoaderWithEmo(BaseLoader):
    """Loader returning text, mel, integer speaker ID and integer emotion ID per utterance.

    ``__getitem__`` and ``__len__`` are inherited from ``BaseLoader``.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2text: str, utt2spk: str, utt2emo: str):
        """
        :param utt2spk: json file path (utt name -> spk id)
        :param utt2emo: json file path (utt name -> emotion id)

        This loader loads speaker and emotion as integer IDs for embedding tables.
        """
        super(SpkIDLoaderWithEmo, self).__init__(utts, hparams, feats_scp, utt2text)
        self.utt2spk = self.get_utt2spk(utt2spk)
        self.utt2emo = self.get_utt2emo(utt2emo)

    @staticmethod
    def _load_json(path: str) -> dict:
        """Load a JSON object from `path`, decoded as UTF-8 (the JSON standard encoding)."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Load the utt -> speaker-id mapping from the JSON file at `utt2spk`."""
        return self._load_json(utt2spk)

    def get_utt2emo(self, utt2emo: str) -> dict:
        """Load the utt -> emotion-id mapping from the JSON file at `utt2emo`."""
        return self._load_json(utt2emo)

    def get_mel_text_pair(self, utt):
        """Build the sample dict for `utt`.

        :returns: dict with "utt", "text" (IntTensor of symbol ids), "mel"
            (n_mel_channels on dim 0), "spk_ids" (int) and "emo_ids" (int).
        """
        # JSON values may be stored as strings or numbers; normalize to int.
        spkid = int(self.utt2spk[utt])
        emoid = int(self.utt2emo[utt])
        text = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        return {
            "utt": utt,
            "text": text,
            "mel": mel,
            "spk_ids": spkid,
            "emo_ids": emoid,
        }
|
|
|
|
|
class SpkIDLoaderWithPE(SpkIDLoader):
    """SpkIDLoader variant that also returns variance features (4-dim pitch, 1-dim energy).

    NOTE(review): inherits SpkIDLoader's constructor problem. SpkIDLoader's
    ``super().__init__`` passes arguments that ``BaseLoader.__init__`` in this file
    does not accept, so construction raises ``TypeError``. ``get_mel_text_pair``
    also calls ``get_dur_from_kaldi``, which is not defined in this file.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str, var_scp: str):
        """
        This loader loads speaker ID together with variance (4-dim pitch, 1-dim energy)
        """
        super(SpkIDLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration, utt2spk)
        # Variance matrices are read through kaldiio from the scp at `var_scp`.
        self.utt2var = self.get_utt2var(var_scp)

    def get_utt2var(self, utt2var: str) -> dict:
        """Open the variance scp with ``kaldiio.load_scp`` (indexable by utt id)."""
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt: str):
        """Return the variance of `utt` as a FloatTensor with the size-5 axis on dim 0.

        Singleton dims are squeezed. A matrix whose first dim is not 5 is transposed.
        """
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_mel_text_pair(self, utt: str) -> dict:
        """Build the sample dict: utt, phn_ids, mel, dur, spk_ids (raw JSON value) and var."""
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        # NOTE(review): get_dur_from_kaldi is not defined in this file.
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)

        # Durations, mel frames and variance frames must all describe the same length.
        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var
        }
        return res
|
|
|
|
|
class XvectorLoader(BaseLoader):
    """Loader that returns a speaker-level x-vector together with mel features.

    NOTE(review): ``super().__init__`` passes six arguments, but
    ``BaseLoader.__init__`` in this file accepts only (utts, hparams, feats_scp,
    utt2text), so construction raises ``TypeError``. ``get_mel_text_pair`` also
    calls ``get_dur_from_kaldi``, which is not defined in this file.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str):
        """
        :param utt2spk_name: like kaldi-style utt2spk
        :param spk_xvector_scp: kaldi-style speaker-level xvector.scp
        """
        # NOTE(review): argument list does not match BaseLoader.__init__ (see class docstring).
        super(XvectorLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Map utt id (first column) to speaker name (second column) from a text file.

        A blank or single-token line raises IndexError.
        """
        res = dict()
        with open(utt2spk, 'r') as f:
            for l in f.readlines():
                res[l.split()[0]] = l.split()[1]
        return res

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        """Open the speaker x-vector scp with ``kaldiio.load_scp`` (indexable by speaker)."""
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_xvector(self, utt: str):
        """Return the x-vector of the speaker of `utt` as a squeezed FloatTensor."""
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt: str) -> dict:
        """Build the sample dict: utt, phn_ids, mel, dur and xvector."""
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        # NOTE(review): get_dur_from_kaldi is not defined in this file.
        dur = self.get_dur_from_kaldi(utt)
        xvector = self.get_xvector(utt)

        # mel has n_mel_channels on dim 0, so mel.shape[1] is the frame count.
        assert sum(dur) == mel.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "xvector": xvector,
        }
        return res
|
|
|
|
|
class XvectorLoaderWithPE(BaseLoader):
    """Loader returning x-vector, speaker name and variance (5-row) features with mel.

    NOTE(review): ``super().__init__`` passes six arguments, but
    ``BaseLoader.__init__`` in this file accepts only (utts, hparams, feats_scp,
    utt2text), so construction raises ``TypeError``. ``get_mel_text_pair`` also
    calls ``get_dur_from_kaldi``, which is not defined in this file. Several helpers
    duplicate XvectorLoader / SpkIDLoaderWithPE.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str, var_scp: str):
        # NOTE(review): argument list does not match BaseLoader.__init__ (see class docstring).
        super(XvectorLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)
        self.utt2var = self.get_utt2var(var_scp)

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        """Open the speaker x-vector scp with ``kaldiio.load_scp`` (indexable by speaker)."""
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Map utt id (first column) to speaker name (second column) from a text file.

        A blank or single-token line raises IndexError.
        """
        res = dict()
        with open(utt2spk, 'r') as f:
            for l in f.readlines():
                res[l.split()[0]] = l.split()[1]
        return res

    def get_utt2var(self, utt2var: str) -> dict:
        """Open the variance scp with ``kaldiio.load_scp`` (indexable by utt id)."""
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt: str):
        """Return the variance of `utt` as a FloatTensor with the size-5 axis on dim 0."""
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_xvector(self, utt: str):
        """Return the x-vector of the speaker of `utt` as a squeezed FloatTensor."""
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt: str) -> dict:
        """Build the sample dict: utt, phn_ids, mel, dur, spk_ids, var and xvector.

        "spk_ids" holds the speaker name string from the utt2spk file, not an int.
        """
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        # NOTE(review): get_dur_from_kaldi is not defined in this file.
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)
        xvector = self.get_xvector(utt)

        # Durations, mel frames and variance frames must all describe the same length.
        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var,
            "xvector": xvector
        }
        return res
|
|