|
import os.path |
|
import random |
|
import numpy as np |
|
import torch |
|
import re |
|
import torch.utils.data |
|
import json |
|
|
|
import kaldiio |
|
from tqdm import tqdm |
|
|
|
|
|
class BaseCollate:
    """Shared batching logic for text/mel samples.

    Subclasses call :meth:`collate_text_mel` to pad the common fields, then
    reorder their own per-sample fields with the permutation it returns.
    """

    def __init__(self, n_frames_per_step=1):
        # The frame axis of padded outputs is rounded up to a multiple of this.
        self.n_frames_per_step = n_frames_per_step

    def _round_up_frames(self, num_frames):
        """Return ``num_frames`` rounded up to a multiple of ``n_frames_per_step``."""
        remainder = num_frames % self.n_frames_per_step
        if remainder:
            num_frames += self.n_frames_per_step - remainder
        return num_frames

    def collate_text_mel(self, batch: [dict]):
        """Pad the text and mel fields of a batch, ordered by decreasing text length.

        :param batch: list of dicts. Each dict is read for ``'utt'``, ``'text'``
            (a tensor whose first dimension is its length) and ``'mel'`` (a 2-D
            tensor whose dim 0 is shared across the batch and whose dim 1 is
            the number of frames).
        :return: ``(res, ids_sorted_decreasing)``.

            ``res`` has the keys ``"utt"``, ``"text_padded"``,
            ``"input_lengths"``, ``"mel_padded"`` and ``"output_lengths"``. All
            rows follow decreasing text length. ``"mel_padded"`` is
            zero-padded along the frame axis to a multiple of
            ``n_frames_per_step``.

            ``ids_sorted_decreasing`` gives, for each output row, its index in
            ``batch``. Callers use it to reorder any extra fields the same way.
        """
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x['text']) for x in batch]),
            dim=0, descending=True)
        # Plain ints for indexing. This avoids indexing a numpy array with a
        # torch tensor, which collapses to a scalar for single-sample batches.
        order = ids_sorted_decreasing.tolist()

        # Text: right-pad with zeros up to the longest sequence.
        max_input_len = int(input_lengths[0])
        text_padded = torch.zeros(len(batch), max_input_len, dtype=torch.long)
        for row, idx in enumerate(order):
            text = batch[idx]['text']
            text_padded[row, :text.size(0)] = text

        # Mel: zero-pad the frame axis up to a multiple of n_frames_per_step.
        num_mels = batch[0]['mel'].size(0)
        max_target_len = self._round_up_frames(max(x['mel'].size(1) for x in batch))
        mel_padded = torch.zeros(len(batch), num_mels, max_target_len, dtype=torch.float32)
        output_lengths = torch.zeros(len(batch), dtype=torch.long)
        for row, idx in enumerate(order):
            mel = batch[idx]['mel']
            mel_padded[row, :, :mel.size(1)] = mel
            output_lengths[row] = mel.size(1)

        res = {
            "utt": [batch[idx]['utt'] for idx in order],
            "text_padded": text_padded,
            "input_lengths": input_lengths,
            "mel_padded": mel_padded,
            "output_lengths": output_lengths,
        }
        return res, ids_sorted_decreasing
|
|
|
|
|
class SpkIDCollate(BaseCollate):
    """Collate text/mel batches together with integer speaker ids."""

    def __call__(self, batch, *args, **kwargs):
        """Return the padded base batch extended with a ``"spk_ids"`` tensor.

        Each sample dict is read for ``"spk_ids"``. The resulting LongTensor
        is reordered with the permutation produced by ``collate_text_mel``, so
        its rows line up with the other padded fields.
        """
        collated, order = self.collate_text_mel(batch)
        speaker_ids = torch.LongTensor([sample["spk_ids"] for sample in batch])
        collated["spk_ids"] = speaker_ids[order]
        return collated
|
|
|
|
|
class SpkIDCollateWithEmo(BaseCollate):
    """Collate text/mel batches together with speaker and emotion ids."""

    def __call__(self, batch, *args, **kwargs):
        """Return the padded base batch extended with ``"spk_ids"`` and ``"emo_ids"``.

        Both id lists are read from each sample dict, converted to
        LongTensors, and permuted into the same order as the padded
        text/mel fields.
        """
        collated, order = self.collate_text_mel(batch)

        speaker_ids = torch.LongTensor([sample["spk_ids"] for sample in batch])
        emotion_ids = torch.LongTensor([sample['emo_ids'] for sample in batch])

        collated["spk_ids"] = speaker_ids[order]
        collated["emo_ids"] = emotion_ids[order]
        return collated
|
|
|
|
|
class XvectorCollate(BaseCollate):
    """Collate text/mel batches together with per-utterance x-vectors."""

    def __call__(self, batch, *args, **kwargs):
        """Return the padded base batch extended with a stacked ``"xvector"`` tensor.

        The per-sample ``"xvector"`` tensors (assumed to share one shape) are
        stacked along a new leading batch axis. They are then permuted into
        the same order as the padded text/mel fields.
        """
        collated, order = self.collate_text_mel(batch)
        stacked = torch.stack([sample["xvector"] for sample in batch], dim=0)
        collated["xvector"] = stacked[order]
        return collated
|
|
|
|
|
class SpkIDCollateWithPE(BaseCollate):
    """Collate text/mel batches with speaker ids and a padded ``"var"`` feature."""

    def __call__(self, batch, *args, **kwargs):
        """Return the padded base batch extended with ``"spk_ids"`` and ``"var_padded"``.

        Each sample's ``"var"`` is a 2-D tensor whose dim 0 is shared across
        the batch and whose dim 1 is its length. It is zero-padded along dim 1
        to the longest length in the batch, rounded up to a multiple of
        ``n_frames_per_step``. Rows follow the same order as the other padded
        fields.
        """
        collated, order = self.collate_text_mel(batch)
        speaker_ids = torch.LongTensor([sample["spk_ids"] for sample in batch])

        step = self.n_frames_per_step
        channels = batch[0]["var"].size(0)
        longest = max(sample["var"].size(1) for sample in batch)
        remainder = longest % step
        if remainder:
            longest += step - remainder

        var_padded = torch.zeros(len(batch), channels, longest, dtype=torch.float32)
        for row, idx in enumerate(order.tolist()):
            feature = batch[idx]["var"]
            var_padded[row, :, :feature.size(1)] = feature

        collated["spk_ids"] = speaker_ids[order]
        collated["var_padded"] = var_padded
        return collated
|
|
|
|
|
class XvectorCollateWithPE(BaseCollate):
    """Collate text/mel batches with x-vectors and a padded ``"var"`` feature."""

    def __call__(self, batch, *args, **kwargs):
        """Return the padded base batch extended with ``"xvector"`` and ``"var_padded"``.

        x-vectors are stacked along a new leading batch axis. Each sample's
        2-D ``"var"`` tensor is zero-padded along dim 1 to the longest length
        in the batch, rounded up to a multiple of ``n_frames_per_step``. Both
        outputs are permuted into the same order as the padded text/mel
        fields.
        """
        collated, order = self.collate_text_mel(batch)
        stacked = torch.stack([sample["xvector"] for sample in batch], dim=0)

        step = self.n_frames_per_step
        channels = batch[0]["var"].size(0)
        longest = max(sample["var"].size(1) for sample in batch)
        remainder = longest % step
        if remainder:
            longest += step - remainder

        var_padded = torch.zeros(len(batch), channels, longest, dtype=torch.float32)
        for row, idx in enumerate(order.tolist()):
            feature = batch[idx]["var"]
            var_padded[row, :, :feature.size(1)] = feature

        collated["xvector"] = stacked[order]
        collated["var_padded"] = var_padded
        return collated
|
|