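"""Dataset and collator for NS2 (NaturalSpeech 2-style) training.

NS2Dataset loads per-utterance codec codes, pitch, durations, and phone
sequences, and splits each utterance into a target segment and a reference
(prompt) segment. NS2Collator pads a list of such samples into batched
tensors.
"""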
import os
import json
import pickle
import random

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from utils.data_utils import *
from processors.acoustic_extractor import cal_normalized_mel, load_normalized
from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestDataset,
    BaseTestCollator,
)
from text import text_to_sequence
from text.cmudict import valid_symbols


class NS2Dataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        assert isinstance(dataset, str)

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)

        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        self.cfg = cfg

        assert not cfg.preprocess.use_mel
        if cfg.preprocess.use_mel:
            self.utt2melspec_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2melspec_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.melspec_dir,
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_code
        if cfg.preprocess.use_code:
            self.utt2code_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2code_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.code_dir,
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_spkid
        if cfg.preprocess.use_spkid:
            self.utt2spkid = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2spkid[utt] = utt_info["speaker"]

        assert cfg.preprocess.use_pitch
        if cfg.preprocess.use_pitch:
            self.utt2pitch_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2pitch_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.pitch_dir,
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_duration
        if cfg.preprocess.use_duration:
            self.utt2duration_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2duration_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.duration_dir,
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_phone
        if cfg.preprocess.use_phone:
            self.utt2phone = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2phone[utt] = utt_info["phones"]

        assert cfg.preprocess.use_len
        if cfg.preprocess.use_len:
            self.utt2len = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2len[utt] = utt_info["num_frames"]

        if cfg.preprocess.use_cross_reference:
            self.spkid2utt = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)
                spkid = utt_info["speaker"]
                if spkid not in self.spkid2utt:
                    self.spkid2utt[spkid] = []
                self.spkid2utt[spkid].append(utt)

        self.phone2id, self.id2phone = self.get_phone_map()

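        # Cache utterance lengths (in frames) and a length-sorted index order
        # so that a length-bucketed batch sampler (e.g. batch_by_size below)
        # can group similarly sized utterances together.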
        self.all_num_frames = []
        for i in range(len(self.metadata)):
            self.all_num_frames.append(self.metadata[i]["num_frames"])
        self.num_frame_sorted = np.array(sorted(self.all_num_frames))
        self.num_frame_indices = np.array(
            sorted(
                range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
            )
        )

    def __len__(self):
        return len(self.metadata)

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        print("metadata len: ", len(metadata))

        return metadata

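    # Build the phone vocabulary: CMUdict phones plus silence/pause tokens
    # ("sp", "spn", "sil") and sequence boundary markers ("<s>", "</s>").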
    def get_phone_map(self):
        symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        phone2id = {s: i for i, s in enumerate(symbols)}
        id2phone = {i: s for s, i in phone2id.items()}
        return phone2id, id2phone

    def __getitem__(self, index):
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

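        # Fast path: a single per-utterance .pkl bundles code, pitch, and
        # duration, avoiding several small .npy reads per item.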
        if self.cfg.preprocess.read_metadata:
            metadata_uid_path = os.path.join(
                self.cfg.preprocess.processed_dir,
                self.cfg.preprocess.metadata_dir,
                dataset,
                uid + ".pkl",
            )
            with open(metadata_uid_path, "rb") as f:
                metadata_uid = pickle.load(f)

            code = metadata_uid["code"]
            frame_nums = code.shape[1]
            pitch = metadata_uid["pitch"]
            duration = metadata_uid["duration"]
        else:
            code = np.load(self.utt2code_path[utt])
            frame_nums = code.shape[1]
            pitch = np.load(self.utt2pitch_path[utt])
            duration = np.load(self.utt2duration_path[utt])

        phone_id = np.array(
            [
                *map(
                    self.phone2id.get,
                    self.utt2phone[utt].replace("{", "").replace("}", "").split(),
                )
            ]
        )

        code, pitch, duration, phone_id, frame_nums = self.align_length(
            code, pitch, duration, phone_id, frame_nums
        )

        spkid = self.utt2spkid[utt]

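        # Split each utterance into a training target and a same-utterance
        # reference (prompt) segment.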
        out = self.get_target_and_reference(
            code, pitch, duration, phone_id, frame_nums
        )
        code, ref_code = out["code"], out["ref_code"]
        pitch, ref_pitch = out["pitch"], out["ref_pitch"]
        duration, ref_duration = out["duration"], out["ref_duration"]
        phone_id, ref_phone_id = out["phone_id"], out["ref_phone_id"]
        frame_nums, ref_frame_nums = out["frame_nums"], out["ref_frame_nums"]

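        # Expand phone-level ids to frame level by repeating each phone id
        # for its duration (in frames).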
        assert len(phone_id) == len(duration)
        phone_id_frame = []
        for i in range(len(phone_id)):
            phone_id_frame.extend([phone_id[i] for _ in range(duration[i])])
        phone_id_frame = np.array(phone_id_frame)

        assert len(ref_phone_id) == len(ref_duration)
        ref_phone_id_frame = []
        for i in range(len(ref_phone_id)):
            ref_phone_id_frame.extend(
                [ref_phone_id[i] for _ in range(ref_duration[i])]
            )
        ref_phone_id_frame = np.array(ref_phone_id_frame)

        single_feature.update(
            {
                "code": code,
                "frame_nums": frame_nums,
                "pitch": pitch,
                "duration": duration,
                "phone_id": phone_id,
                "phone_id_frame": phone_id_frame,
                "ref_code": ref_code,
                "ref_frame_nums": ref_frame_nums,
                "ref_pitch": ref_pitch,
                "ref_duration": ref_duration,
                "ref_phone_id": ref_phone_id,
                "ref_phone_id_frame": ref_phone_id_frame,
                "spkid": spkid,
            }
        )

        return single_feature

    def get_num_frames(self, index):
        utt_info = self.metadata[index]
        return utt_info["num_frames"]

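    # Feature extractors can disagree by a few frames; trim the code and
    # pitch tracks and the final phone duration so that code frames, pitch
    # frames, and the duration sum all match.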
    def align_length(self, code, pitch, duration, phone_id, frame_nums):
        code_len = code.shape[1]
        pitch_len = len(pitch)
        dur_sum = sum(duration)
        min_len = min(code_len, dur_sum)
        code = code[:, :min_len]
        if pitch_len >= min_len:
            pitch = pitch[:min_len]
        else:
            pitch = np.pad(pitch, (0, min_len - pitch_len), mode="edge")
        frame_nums = min_len
        if dur_sum > min_len:
            assert (duration[-1] - (dur_sum - min_len)) >= 0
            duration[-1] = duration[-1] - (dur_sum - min_len)
            assert duration[-1] >= 0

        return code, pitch, duration, phone_id, frame_nums

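    # Clip a random span covering roughly 10%-50% of the phones out of the
    # utterance to serve as the reference (prompt); the remainder becomes
    # the target.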
    def get_target_and_reference(self, code, pitch, duration, phone_id, frame_nums):
        phone_nums = len(phone_id)
        clip_phone_nums = np.random.randint(
            int(phone_nums * 0.1), int(phone_nums * 0.5) + 1
        )
        clip_phone_nums = max(clip_phone_nums, 1)
        assert clip_phone_nums < phone_nums and clip_phone_nums >= 1
        if self.cfg.preprocess.clip_mode == "mid":
            start_idx = np.random.randint(0, phone_nums - clip_phone_nums)
        elif self.cfg.preprocess.clip_mode == "start":
            if duration[0] == 0 and clip_phone_nums == 1:
                start_idx = 1
            else:
                start_idx = 0
        else:
            assert self.cfg.preprocess.clip_mode in ["mid", "start"]
        end_idx = start_idx + clip_phone_nums
        start_frames = sum(duration[:start_idx])
        end_frames = sum(duration[:end_idx])

        new_code = np.concatenate(
            (code[:, :start_frames], code[:, end_frames:]), axis=1
        )
        ref_code = code[:, start_frames:end_frames]

        new_pitch = np.append(pitch[:start_frames], pitch[end_frames:])
        ref_pitch = pitch[start_frames:end_frames]

        new_duration = np.append(duration[:start_idx], duration[end_idx:])
        ref_duration = duration[start_idx:end_idx]

        new_phone_id = np.append(phone_id[:start_idx], phone_id[end_idx:])
        ref_phone_id = phone_id[start_idx:end_idx]

        new_frame_nums = frame_nums - (end_frames - start_frames)
        ref_frame_nums = end_frames - start_frames

        return {
            "code": new_code,
            "ref_code": ref_code,
            "pitch": new_pitch,
            "ref_pitch": ref_pitch,
            "duration": new_duration,
            "ref_duration": ref_duration,
            "phone_id": new_phone_id,
            "ref_phone_id": ref_phone_id,
            "frame_nums": new_frame_nums,
            "ref_frame_nums": ref_frame_nums,
        }


class NS2Collator(BaseOfflineCollator):
    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)

    def __call__(self, batch):
        packed_batch_features = dict()

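        # Padded keys and shapes (B = batch, T = frames, N = phones,
        # T' = reference frames, C = code channels, e.g. RVQ codebooks):
        #   code: (B, C, T)           mask: (B, T)
        #   phone_id: (B, N)          phone_mask: (B, N)
        #   phone_id_frame: (B, T)
        #   pitch: (B, T)             duration: (B, N)
        #   ref_code: (B, C, T')      ref_mask: (B, T')
        #   frame_nums: (B,)          ref_frame_nums: (B,)
        # Remaining keys (spkid, ref_pitch, ref_duration, ref_phone_id,
        # ref_phone_id_frame) fall through to the else branch and are dropped.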
        for key in batch[0].keys():
            if key == "phone_id":
                phone_ids = [torch.LongTensor(b["phone_id"]) for b in batch]
                phone_masks = [torch.ones(len(b["phone_id"])) for b in batch]
                packed_batch_features["phone_id"] = pad_sequence(
                    phone_ids,
                    batch_first=True,
                    padding_value=0,
                )
                packed_batch_features["phone_mask"] = pad_sequence(
                    phone_masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "phone_id_frame":
                phone_id_frames = [torch.LongTensor(b["phone_id_frame"]) for b in batch]
                masks = [torch.ones(len(b["phone_id_frame"])) for b in batch]
                packed_batch_features["phone_id_frame"] = pad_sequence(
                    phone_id_frames,
                    batch_first=True,
                    padding_value=0,
                )
                packed_batch_features["mask"] = pad_sequence(
                    masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "ref_code":
                ref_codes = [
                    torch.from_numpy(b["ref_code"]).transpose(0, 1) for b in batch
                ]
                ref_masks = [torch.ones(max(b["ref_code"].shape[1], 1)) for b in batch]
                packed_batch_features["ref_code"] = pad_sequence(
                    ref_codes,
                    batch_first=True,
                    padding_value=0,
                ).transpose(1, 2)
                packed_batch_features["ref_mask"] = pad_sequence(
                    ref_masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "code":
                codes = [torch.from_numpy(b["code"]).transpose(0, 1) for b in batch]
                masks = [torch.ones(max(b["code"].shape[1], 1)) for b in batch]
                packed_batch_features["code"] = pad_sequence(
                    codes,
                    batch_first=True,
                    padding_value=0,
                ).transpose(1, 2)
                packed_batch_features["mask"] = pad_sequence(
                    masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "pitch":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=50.0
                )
            elif key == "duration":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
            elif key == "frame_nums":
                packed_batch_features["frame_nums"] = torch.LongTensor(
                    [b["frame_nums"] for b in batch]
                )
            elif key == "ref_frame_nums":
                packed_batch_features["ref_frame_nums"] = torch.LongTensor(
                    [b["ref_frame_nums"] for b in batch]
                )
            else:
                pass

        return packed_batch_features


def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    # Both limits default to None in batch_by_size, so guard the comparisons.
    if len(batch) == 0:
        return False
    if max_sentences is not None and len(batch) == max_sentences:
        return True
    if max_tokens is not None and num_tokens > max_tokens:
        return True
    return False


def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Group indices into mini-batches, bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens
            at a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
    """
    bsz_mult = required_batch_size_multiple

    sample_len = 0
    sample_lens = []
    batch = []
    batches = []
    for idx in indices:
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert (
            max_tokens is None or sample_len <= max_tokens
        ), "sentence at index {} of size {} exceeds max_tokens limit of {}!".format(
            idx, sample_len, max_tokens
        )
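        # Cost of adding this sample: batches are padded to their longest
        # member, so the token count is (batch size) x (longest sample).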
        num_tokens = (len(batch) + 1) * sample_len

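        # When the batch is full, emit the largest prefix whose size is a
        # multiple of bsz_mult (or the whole batch if it is smaller) and
        # carry the remainder over into the next batch.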
        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
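
# A minimal, self-contained sketch of how batch_by_size buckets indices. The
# lengths below are made up for illustration; in the real pipeline, indices
# and lengths would come from NS2Dataset (num_frame_indices / get_num_frames).
if __name__ == "__main__":
    toy_lengths = [120, 80, 200, 50, 150, 90, 60, 170, 110, 40]
    # Sort indices by length so each batch contains similarly sized samples.
    sorted_indices = sorted(range(len(toy_lengths)), key=lambda k: toy_lengths[k])
    toy_batches = batch_by_size(
        sorted_indices,
        num_tokens_fn=lambda k: toy_lengths[k],
        max_tokens=400,
        max_sentences=4,
    )
    # Both caps are respected: prints [[9, 3, 6, 1], [5, 8, 0], [4, 7], [2]]
    print(toy_batches)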