# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import Iterable

import numpy as np
import torch
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import ConcatDataset, Dataset

from utils.data_utils import *


class VocoderDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to load the valid (True) or train (False) metadata file
        """
        assert isinstance(dataset, str)

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        self.data_root = processed_data_dir
        self.cfg = cfg

        # Audio, frame labels, and one-hot codes are mutually exclusive
        # targets: only the first enabled one is indexed.
        if cfg.preprocess.use_audio:
            self.utt2audio_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2audio_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.audio_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_label:
            self.utt2label_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2label_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.label_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_one_hot:
            self.utt2one_hot_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2one_hot_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.one_hot_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_mel:
            self.utt2mel_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2mel_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.mel_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_frame_pitch:
            self.utt2frame_pitch_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2frame_pitch_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.pitch_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_uv:
            self.utt2uv_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2uv_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.uv_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_amplitude_phase:
            self.utt2logamp_path = {}
            self.utt2pha_path = {}
            self.utt2rea_path = {}
            self.utt2imag_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2logamp_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.log_amplitude_dir,
                    uid + ".npy",
                )
                self.utt2pha_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.phase_dir,
                    uid + ".npy",
                )
                self.utt2rea_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.real_dir,
                    uid + ".npy",
                )
                self.utt2imag_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.imaginary_dir,
                    uid + ".npy",
                )

    def __getitem__(self, index):
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)
"{}_{}".format(dataset, uid) single_feature = dict() if self.cfg.preprocess.use_mel: mel = np.load(self.utt2mel_path[utt]) assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] if "target_len" not in single_feature.keys(): single_feature["target_len"] = mel.shape[1] single_feature["mel"] = mel if self.cfg.preprocess.use_frame_pitch: frame_pitch = np.load(self.utt2frame_pitch_path[utt]) if "target_len" not in single_feature.keys(): single_feature["target_len"] = len(frame_pitch) aligned_frame_pitch = align_length( frame_pitch, single_feature["target_len"] ) single_feature["frame_pitch"] = aligned_frame_pitch if self.cfg.preprocess.use_audio: audio = np.load(self.utt2audio_path[utt]) single_feature["audio"] = audio return single_feature def get_metadata(self): with open(self.metafile_path, "r", encoding="utf-8") as f: metadata = json.load(f) return metadata def get_dataset_name(self): return self.metadata[0]["Dataset"] def __len__(self): return len(self.metadata) class VocoderConcatDataset(ConcatDataset): def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False): """Concatenate a series of datasets with their random inference audio merged.""" super().__init__(datasets) self.cfg = self.datasets[0].cfg self.metadata = [] # Merge metadata for dataset in self.datasets: self.metadata += dataset.metadata # Merge random inference features if full_audio_inference: self.eval_audios = [] self.eval_dataset_names = [] if self.cfg.preprocess.use_mel: self.eval_mels = [] if self.cfg.preprocess.use_frame_pitch: self.eval_pitchs = [] for dataset in self.datasets: self.eval_audios.append(dataset.eval_audio) self.eval_dataset_names.append(dataset.get_dataset_name()) if self.cfg.preprocess.use_mel: self.eval_mels.append(dataset.eval_mel) if self.cfg.preprocess.use_frame_pitch: self.eval_pitchs.append(dataset.eval_pitch) class VocoderCollator(object): """Zero-pads model inputs and targets based on number of frames per step""" def __init__(self, cfg): self.cfg = cfg def __call__(self, batch): packed_batch_features = dict() # mel: [b, n_mels, frame] # frame_pitch: [b, frame] # audios: [b, frame * hop_size] for key in batch[0].keys(): if key == "target_len": packed_batch_features["target_len"] = torch.LongTensor( [b["target_len"] for b in batch] ) masks = [ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch ] packed_batch_features["mask"] = pad_sequence( masks, batch_first=True, padding_value=0 ) elif key == "mel": values = [torch.from_numpy(b[key]).T for b in batch] packed_batch_features[key] = pad_sequence( values, batch_first=True, padding_value=0 ) else: values = [torch.from_numpy(b[key]) for b in batch] packed_batch_features[key] = pad_sequence( values, batch_first=True, padding_value=0 ) return packed_batch_features