yuancwang
init
b725c5a
raw
history blame
No virus
18.6 kB
import random
import torch
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *
from processors.acoustic_extractor import cal_normalized_mel
from processors.acoustic_extractor import load_normalized
from models.base.base_dataset import (
BaseCollator,
BaseDataset,
BaseTestDataset,
BaseTestCollator,
)
from text import text_to_sequence
from text.cmudict import valid_symbols
from tqdm import tqdm
import pickle
class NS2Dataset(torch.utils.data.Dataset):
    """Per-utterance samples split into a target segment and a reference clip.

    Utterances are listed in a JSON metadata file (``train_file`` or
    ``valid_file`` under ``<processed_dir>/<dataset>``). Each entry is read
    for the keys ``Dataset``, ``Uid``, ``speaker``, ``phones`` and
    ``num_frames``. Features (codec codes, pitch, durations) are loaded
    either from one ``.pkl`` per utterance (``cfg.preprocess.read_metadata``)
    or from one ``.npy`` file per feature.

    The feature configuration is hard-wired by asserts in ``__init__``:
    ``use_mel`` must be false, and ``use_code``, ``use_spkid``,
    ``use_pitch``, ``use_duration``, ``use_phone`` and ``use_len`` must be
    true.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        assert isinstance(dataset, str)
        # e.g. <processed_dir>/<dataset>/<train_file | valid_file>
        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()
        self.cfg = cfg

        processed_dir = cfg.preprocess.processed_dir

        # Only this feature set is supported; the asserts make that explicit.
        assert cfg.preprocess.use_mel == False
        if cfg.preprocess.use_mel:
            self.utt2melspec_path = self._build_feature_paths(
                processed_dir, cfg.preprocess.melspec_dir
            )
        assert cfg.preprocess.use_code == True
        if cfg.preprocess.use_code:
            self.utt2code_path = self._build_feature_paths(
                processed_dir, cfg.preprocess.code_dir
            )
        assert cfg.preprocess.use_spkid == True
        if cfg.preprocess.use_spkid:
            self.utt2spkid = self._build_utt_map("speaker")
        assert cfg.preprocess.use_pitch == True
        if cfg.preprocess.use_pitch:
            self.utt2pitch_path = self._build_feature_paths(
                processed_dir, cfg.preprocess.pitch_dir
            )
        assert cfg.preprocess.use_duration == True
        if cfg.preprocess.use_duration:
            self.utt2duration_path = self._build_feature_paths(
                processed_dir, cfg.preprocess.duration_dir
            )
        assert cfg.preprocess.use_phone == True
        if cfg.preprocess.use_phone:
            self.utt2phone = self._build_utt_map("phones")
        assert cfg.preprocess.use_len == True
        if cfg.preprocess.use_len:
            self.utt2len = self._build_utt_map("num_frames")

        # Utterance keys grouped by speaker ("for cross reference"); this
        # mapping is not read anywhere inside this class.
        if cfg.preprocess.use_cross_reference:
            self.spkid2utt = {}
            for utt_info in self.metadata:
                self.spkid2utt.setdefault(utt_info["speaker"], []).append(
                    self._utt_key(utt_info)
                )

        self.phone2id, self.id2phone = self.get_phone_map()

        # Frame counts in metadata order, the sorted counts, and the indices
        # that produce that order (Python's sort is stable, so ties keep
        # metadata order).
        self.all_num_frames = [utt_info["num_frames"] for utt_info in self.metadata]
        self.num_frame_sorted = np.array(sorted(self.all_num_frames))
        self.num_frame_indices = np.array(
            sorted(
                range(len(self.all_num_frames)),
                key=self.all_num_frames.__getitem__,
            )
        )

    @staticmethod
    def _utt_key(utt_info):
        """Return the ``"<Dataset>_<Uid>"`` key used by every per-utterance map."""
        return "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

    def _build_feature_paths(self, processed_dir, feature_dir):
        """Map each utterance key to ``<processed_dir>/<Dataset>/<feature_dir>/<speaker>/<Uid>.npy``."""
        return {
            self._utt_key(utt_info): os.path.join(
                processed_dir,
                utt_info["Dataset"],
                feature_dir,
                utt_info["speaker"],
                utt_info["Uid"] + ".npy",
            )
            for utt_info in self.metadata
        }

    def _build_utt_map(self, field):
        """Map each utterance key to ``utt_info[field]`` from the metadata."""
        return {self._utt_key(utt_info): utt_info[field] for utt_info in self.metadata}

    def __len__(self):
        return len(self.metadata)

    def get_dataset_name(self):
        """Return the ``Dataset`` field of the first metadata entry."""
        return self.metadata[0]["Dataset"]

    def get_metadata(self):
        """Load and return the JSON metadata list from ``self.metafile_path``."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        print("metadata len: ", len(metadata))
        return metadata

    def get_phone_map(self):
        """Return ``(phone2id, id2phone)``; ids are positions in the symbol list.

        The symbol list is ``valid_symbols`` followed by the silence/spoken-noise
        tokens ``sp``, ``spn``, ``sil`` and the boundary tokens ``<s>``, ``</s>``.
        """
        symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        phone2id = {s: i for i, s in enumerate(symbols)}
        id2phone = {i: s for s, i in phone2id.items()}
        return phone2id, id2phone

    def _phones_to_ids(self, phones):
        """Convert a whitespace-separated phone string into an array of ids.

        Curly braces are stripped before splitting.

        Raises:
            KeyError: a phone is missing from ``self.phone2id``. Failing here
                with the phone's name is clearer than emitting ``None`` ids
                that only break later during collation.
        """
        tokens = phones.replace("{", "").replace("}", "").split()
        try:
            return np.array([self.phone2id[p] for p in tokens])
        except KeyError as err:
            raise KeyError(
                "unknown phone {!r} in {!r}".format(err.args[0], phones)
            ) from err

    @staticmethod
    def _expand_to_frames(ids, durations):
        """Repeat ``ids[i]`` ``durations[i]`` times, giving one id per frame."""
        return np.repeat(ids, durations)

    def _load_features(self, dataset, uid, utt):
        """Return ``(code, pitch, duration)`` for one utterance."""
        if self.cfg.preprocess.read_metadata:
            metadata_uid_path = os.path.join(
                self.cfg.preprocess.processed_dir,
                self.cfg.preprocess.metadata_dir,
                dataset,
                uid + ".pkl",
            )
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # metadata_dir at files produced by a trusted preprocessing run.
            with open(metadata_uid_path, "rb") as f:
                metadata_uid = pickle.load(f)
            return metadata_uid["code"], metadata_uid["pitch"], metadata_uid["duration"]
        return (
            np.load(self.utt2code_path[utt]),
            np.load(self.utt2pitch_path[utt]),
            np.load(self.utt2duration_path[utt]),
        )

    def __getitem__(self, index):
        """Load utterance ``index`` and split it into target and reference parts.

        Returns a dict with ``code``, ``frame_nums``, ``pitch``, ``duration``,
        ``phone_id`` and ``phone_id_frame`` for the target, the same keys
        prefixed with ``ref_`` for the reference clip, and ``spkid``. The key
        order matters to ``NS2Collator``, which writes ``mask`` for both
        ``code`` and ``phone_id_frame``.
        """
        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = self._utt_key(utt_info)

        code, pitch, duration = self._load_features(dataset, uid, utt)
        # code is 2-D with time on axis 1.
        frame_nums = code.shape[1]
        phone_id = self._phones_to_ids(self.utt2phone[utt])

        code, pitch, duration, phone_id, frame_nums = self.align_length(
            code, pitch, duration, phone_id, frame_nums
        )
        spkid = self.utt2spkid[utt]

        out = self.get_target_and_reference(code, pitch, duration, phone_id, frame_nums)

        # One duration per phone is required to expand phones to frames.
        assert len(out["phone_id"]) == len(out["duration"])
        assert len(out["ref_phone_id"]) == len(out["ref_duration"])

        return {
            "code": out["code"],
            "frame_nums": out["frame_nums"],
            "pitch": out["pitch"],
            "duration": out["duration"],
            "phone_id": out["phone_id"],
            "phone_id_frame": self._expand_to_frames(out["phone_id"], out["duration"]),
            "ref_code": out["ref_code"],
            "ref_frame_nums": out["ref_frame_nums"],
            "ref_pitch": out["ref_pitch"],
            "ref_duration": out["ref_duration"],
            "ref_phone_id": out["ref_phone_id"],
            "ref_phone_id_frame": self._expand_to_frames(
                out["ref_phone_id"], out["ref_duration"]
            ),
            "spkid": spkid,
        }

    def get_num_frames(self, index):
        """Return the ``num_frames`` metadata field for utterance ``index``."""
        utt_info = self.metadata[index]
        return utt_info["num_frames"]

    def align_length(self, code, pitch, duration, phone_id, frame_nums):
        """Make code, pitch and durations cover the same number of frames.

        The common length is ``min(code frames, sum(duration))``.
        ``code`` is truncated to it. ``pitch`` is truncated, or right-padded by
        repeating its last value. Any excess total duration is removed from
        the last phone only, so that phone must be long enough to absorb it
        (asserted). The incoming ``frame_nums`` is ignored and replaced by
        the common length. ``phone_id`` is returned unchanged, and the
        caller's ``duration`` is not modified.
        """
        code_len = code.shape[1]
        pitch_len = len(pitch)
        dur_sum = sum(duration)
        min_len = min(code_len, dur_sum)

        code = code[:, :min_len]
        if pitch_len >= min_len:
            pitch = pitch[:min_len]
        else:
            pitch = np.pad(pitch, (0, min_len - pitch_len), mode="edge")
        frame_nums = min_len

        if dur_sum > min_len:
            overshoot = dur_sum - min_len
            # Copy so the trim below does not leak into the caller's array.
            duration = duration.copy()
            assert (duration[-1] - overshoot) >= 0
            duration[-1] = duration[-1] - overshoot
        return code, pitch, duration, phone_id, frame_nums

    def get_target_and_reference(self, code, pitch, duration, phone_id, frame_nums):
        """Cut a random run of phones out of the utterance as the reference.

        The clip length is drawn uniformly from
        ``[int(0.1 * N), int(0.5 * N)]`` phones (N = number of phones), with a
        minimum of 1, and must be shorter than the utterance.

        ``cfg.preprocess.clip_mode`` selects the start:
          * ``"mid"``: uniform in ``[0, N - clip)``. The exclusive upper
            bound means the clip never includes the last phone.
          * ``"start"``: phone 0. If a one-phone clip would land on a
            zero-duration first phone, it starts at phone 1 instead.

        The target is everything outside the clip. Frame-level features are
        cut at the corresponding frame boundaries.

        Raises:
            ValueError: ``clip_mode`` is neither ``"mid"`` nor ``"start"``.
        """
        phone_nums = len(phone_id)
        clip_phone_nums = np.random.randint(
            int(phone_nums * 0.1), int(phone_nums * 0.5) + 1
        )
        clip_phone_nums = max(clip_phone_nums, 1)
        assert clip_phone_nums < phone_nums and clip_phone_nums >= 1

        clip_mode = self.cfg.preprocess.clip_mode
        if clip_mode == "mid":
            start_idx = np.random.randint(0, phone_nums - clip_phone_nums)
        elif clip_mode == "start":
            start_idx = 1 if (duration[0] == 0 and clip_phone_nums == 1) else 0
        else:
            # An explicit error instead of an assert: under `python -O` the
            # assert vanished and start_idx was left unbound.
            raise ValueError(
                "clip_mode must be 'mid' or 'start', got {!r}".format(clip_mode)
            )
        end_idx = start_idx + clip_phone_nums
        start_frames = sum(duration[:start_idx])
        end_frames = sum(duration[:end_idx])

        new_code = np.concatenate(
            (code[:, :start_frames], code[:, end_frames:]), axis=1
        )
        ref_code = code[:, start_frames:end_frames]
        new_pitch = np.append(pitch[:start_frames], pitch[end_frames:])
        ref_pitch = pitch[start_frames:end_frames]
        new_duration = np.append(duration[:start_idx], duration[end_idx:])
        ref_duration = duration[start_idx:end_idx]
        new_phone_id = np.append(phone_id[:start_idx], phone_id[end_idx:])
        ref_phone_id = phone_id[start_idx:end_idx]
        new_frame_nums = frame_nums - (end_frames - start_frames)
        ref_frame_nums = end_frames - start_frames
        return {
            "code": new_code,
            "ref_code": ref_code,
            "pitch": new_pitch,
            "ref_pitch": ref_pitch,
            "duration": new_duration,
            "ref_duration": ref_duration,
            "phone_id": new_phone_id,
            "ref_phone_id": ref_phone_id,
            "frame_nums": new_frame_nums,
            "ref_frame_nums": ref_frame_nums,
        }
class NS2Collator(BaseCollator):
    """Pad a list of per-utterance feature dicts into batch tensors.

    Produced keys (B = batch size, C = code channels):
      * ``code`` (B, C, T) and ``mask`` (B, T)
      * ``ref_code`` (B, C, T') and ``ref_mask`` (B, T')
      * ``pitch`` (B, T), padded with 50.0
      * ``duration`` (B, N), ``phone_id`` (B, N), ``phone_mask`` (B, N)
      * ``phone_id_frame`` (B, T)
      * ``frame_nums`` (B,), ``ref_frame_nums`` (B,)

    Any other sample keys (e.g. ``ref_pitch``, ``ref_duration``,
    ``ref_phone_id``, ``ref_phone_id_frame``, ``spkid``) are dropped.

    ``mask`` is written by both ``code`` and ``phone_id_frame``. Whichever key
    comes later in the first sample's dict wins; NS2Dataset emits
    ``phone_id_frame`` after ``code``.
    """

    def __init__(self, cfg):
        BaseCollator.__init__(self, cfg)

    def __call__(self, batch):
        def pad(tensors, value=0):
            # Right-pad variable-length sequences along dim 0 into (B, L, ...).
            return pad_sequence(tensors, batch_first=True, padding_value=value)

        # Name of the 0/1 validity mask emitted next to each padded sequence.
        mask_names = {
            "phone_id": "phone_mask",
            "phone_id_frame": "mask",
            "code": "mask",
            "ref_code": "ref_mask",
        }

        packed = {}
        for key in batch[0]:
            if key in ("phone_id", "phone_id_frame"):
                id_seqs = [torch.LongTensor(sample[key]) for sample in batch]
                ones = [torch.ones(len(sample[key])) for sample in batch]
                packed[key] = pad(id_seqs)
                packed[mask_names[key]] = pad(ones)
            elif key in ("code", "ref_code"):
                # Codes are (C, T); put time first for padding, then restore.
                time_major = [
                    torch.from_numpy(sample[key]).transpose(0, 1) for sample in batch
                ]
                # At least one mask slot even for an empty clip.
                ones = [torch.ones(max(sample[key].shape[1], 1)) for sample in batch]
                packed[key] = pad(time_major).transpose(1, 2)
                packed[mask_names[key]] = pad(ones)
            elif key in ("pitch", "duration"):
                fill = 50.0 if key == "pitch" else 0
                packed[key] = pad(
                    [torch.from_numpy(sample[key]) for sample in batch], fill
                )
            elif key in ("frame_nums", "ref_frame_nums"):
                packed[key] = torch.LongTensor([sample[key] for sample in batch])
        return packed
def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
if len(batch) == 0:
return 0
if len(batch) == max_sentences:
return 1
if num_tokens > max_tokens:
return 1
return 0
def batch_by_size(
indices,
num_tokens_fn,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
sequences of different lengths.
Args:
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
"""
bsz_mult = required_batch_size_multiple
sample_len = 0
sample_lens = []
batch = []
batches = []
for i in range(len(indices)):
idx = indices[i]
num_tokens = num_tokens_fn(idx)
sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens)
assert (
sample_len <= max_tokens
), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
idx, sample_len, max_tokens
)
num_tokens = (len(batch) + 1) * sample_len
if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
mod_len = max(
bsz_mult * (len(batch) // bsz_mult),
len(batch) % bsz_mult,
)
batches.append(batch[:mod_len])
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches