# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from utils.data_utils import *
from models.tts.base.tts_dataset import (
    TTSDataset,
    TTSCollator,
    TTSTestDataset,
    TTSTestCollator,
)

from torch.utils.data.sampler import (
    BatchSampler,
    RandomSampler,
    SequentialSampler,
)

from utils.tokenizer import tokenize_audio


class VALLEDataset(TTSDataset):
    """Train/valid dataset for VALL-E.

    Extends ``TTSDataset`` features with pre-extracted acoustic (codec)
    tokens loaded from ``.npy`` files under
    ``<processed_dir>/<dataset>/<acoustic_token_dir>/<uid>.npy``.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use train or valid dataset
        """
        super().__init__(cfg, dataset, is_valid=is_valid)

        assert isinstance(dataset, str)
        # This dataset only works with acoustic tokens enabled.
        assert cfg.preprocess.use_acoustic_token

        if cfg.preprocess.use_acoustic_token:
            # Map "<Dataset>_<Uid>" -> path of that utterance's token file.
            self.utt2acousticToken_path = {}
            for utt_info in self.metadata:
                dataset_name = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset_name, uid)

                self.utt2acousticToken_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset_name,
                    cfg.preprocess.acoustic_token_dir,  # code
                    uid + ".npy",
                )

        # Despite the name, these are the per-utterance "Duration" values.
        self.all_num_frames = [utt_info["Duration"] for utt_info in self.metadata]
        # Durations in ascending order, and the metadata indices in that same
        # order (Python's sort is stable, so ties keep metadata order).
        self.num_frame_sorted = np.array(sorted(self.all_num_frames))
        self.num_frame_indices = np.array(
            sorted(
                range(len(self.all_num_frames)),
                key=lambda k: self.all_num_frames[k],
            )
        )

    def __len__(self):
        return super().__len__()

    def get_metadata(self):
        """Load the metadata JSON, keeping only utterances whose duration is
        strictly between ``min_duration`` and ``max_duration``."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        min_duration = self.cfg.preprocess.min_duration
        max_duration = self.cfg.preprocess.max_duration
        return [
            utt_info
            for utt_info in metadata
            if min_duration < utt_info["Duration"] < max_duration
        ]

    def get_dur(self, idx):
        """Return the ``Duration`` field of utterance ``idx``."""
        utt_info = self.metadata[idx]
        return utt_info["Duration"]

    def __getitem__(self, index):
        """Return the base TTS features for ``index`` plus its acoustic tokens.

        ``target_len`` is set from the token array's first dimension only if
        the base class did not already provide it.
        """
        single_feature = super().__getitem__(index)

        utt_info = self.metadata[index]
        utt = "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

        # acoustic token
        if self.cfg.preprocess.use_acoustic_token:
            acoustic_token = np.load(self.utt2acousticToken_path[utt])
            if "target_len" not in single_feature:
                single_feature["target_len"] = acoustic_token.shape[0]
            single_feature["acoustic_token"] = acoustic_token  # [T, 8]

        return single_feature

    def get_num_frames(self, index):
        """Estimate the number of codec frames of utterance ``index`` as
        ``int(Duration * (sample_rate // codec_hop_size))``."""
        utt_info = self.metadata[index]
        # NOTE(review): ``//`` floors the frame rate, so the estimate is exact
        # only when codec_hop_size divides sample_rate — confirm for the codec.
        return int(
            utt_info["Duration"]
            * (self.cfg.preprocess.sample_rate // self.cfg.preprocess.codec_hop_size)
        )


class VALLECollator(TTSCollator):
    """Collator for ``VALLEDataset``; delegates entirely to ``TTSCollator``."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, batch):
        parsed_batch_features = super().__call__(batch)
        return parsed_batch_features


class VALLETestDataset(TTSTestDataset):
    """Inference dataset for VALL-E.

    At construction time, each utterance's prompt audio is tokenized with
    ``self.audio_tokenizer`` (presumably set up by ``TTSTestDataset`` — verify),
    and the tokens are cached in memory.
    """

    def __init__(self, args, cfg):
        super().__init__(args, cfg)

        # prepare data
        assert cfg.preprocess.use_acoustic_token
        if cfg.preprocess.use_acoustic_token:
            self.utt2acousticToken = {}
            for utt_info in self.metadata:
                utt = "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

                # Extract acoustic tokens from the prompt audio. The metadata
                # key is spelled "Audio_pormpt_path" and must match the file.
                audio_file = utt_info["Audio_pormpt_path"]
                encoded_frames = tokenize_audio(self.audio_tokenizer, audio_file)
                audio_prompt_token = (
                    encoded_frames[0][0].transpose(2, 1).squeeze(0).cpu().numpy()
                )
                self.utt2acousticToken[utt] = audio_prompt_token

    def __getitem__(self, index):
        """Return the cached prompt tokens and, optionally, phone sequences."""
        utt_info = self.metadata[index]
        utt = "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

        single_feature = dict()

        # acoustic token
        if self.cfg.preprocess.use_acoustic_token:
            acoustic_token = self.utt2acousticToken[utt]
            # The dict is freshly created, so target_len is always set here.
            single_feature["target_len"] = acoustic_token.shape[0]
            single_feature["acoustic_token"] = acoustic_token  # [T, 8]

        # phone sequence todo
        if self.cfg.preprocess.use_phone:
            single_feature["phone_seq"] = np.array(self.utt2seq[utt])
            single_feature["phone_len"] = len(self.utt2seq[utt])
            single_feature["pmt_phone_seq"] = np.array(self.utt2pmtseq[utt])
            single_feature["pmt_phone_len"] = len(self.utt2pmtseq[utt])

        return single_feature

    def get_metadata(self):
        """Load the full metadata JSON (no duration filtering at test time)."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata

    def __len__(self):
        return len(self.metadata)


class VALLETestCollator(TTSTestCollator):
    """Pads and stacks ``VALLETestDataset`` items into batch tensors."""

    # Length keys are packed into a LongTensor of lengths. When a mask key is
    # given, a padded [B, T_max, 1] long mask (1 = valid, 0 = pad) is also
    # stored under that key. audio_len gets lengths only, no mask.
    _LEN_KEY_TO_MASK_KEY = {
        "target_len": "mask",
        "phone_len": "phn_mask",
        "pmt_phone_len": "pmt_phone_len_mask",
        "audio_len": None,
    }

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        """Collate a list of feature dicts (keys taken from ``batch[0]``).

        Length keys become LongTensors (plus masks, see above). Every other
        key is a numpy array, converted to a tensor and zero-padded along the
        first dimension with ``batch_first=True``.
        """
        packed_batch_features = dict()

        for key in batch[0].keys():
            if key in self._LEN_KEY_TO_MASK_KEY:
                lengths = [b[key] for b in batch]
                packed_batch_features[key] = torch.LongTensor(lengths)

                mask_key = self._LEN_KEY_TO_MASK_KEY[key]
                if mask_key is not None:
                    masks = [
                        torch.ones((length, 1), dtype=torch.long)
                        for length in lengths
                    ]
                    packed_batch_features[mask_key] = pad_sequence(
                        masks, batch_first=True, padding_value=0
                    )
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

        return packed_batch_features


def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    """Return True if ``batch`` must be closed before adding another sample.

    Args:
        batch: the indices collected so far.
        num_tokens: padded token count of the batch if the next sample
            were added.
        max_tokens: token budget per batch; ``None`` means no limit.
        max_sentences: max samples per batch; ``None`` means no limit.
    """
    if len(batch) == 0:
        return False
    if max_sentences is not None and len(batch) == max_sentences:
        return True
    if max_tokens is not None and num_tokens > max_tokens:
        return True
    return False


def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Return mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None, meaning no token limit).
        max_sentences (int, optional): max number of sentences in each batch
            (default: None, meaning no sentence limit).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).

    Returns:
        List[List[int]]: the batches, in the order of ``indices``. Only the
        final batch may violate ``required_batch_size_multiple``.

    Raises:
        AssertionError: if a single sample exceeds ``max_tokens``.
    """
    bsz_mult = required_batch_size_multiple

    sample_len = 0  # longest sample in the current batch (incl. candidate)
    sample_lens = []
    batch = []
    batches = []
    for idx in indices:
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)
        assert max_tokens is None or sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )

        # Padded cost of the batch if idx were appended to it.
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            # Emit the largest prefix that is a multiple of bsz_mult; the
            # remainder carries over into the next batch.
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)
    return batches


class VariableSampler(BatchSampler):
    """Batch sampler that yields a precomputed list of variable-size batches.

    ``sampler`` here is the list of batches (each a list of dataset indices),
    e.g. the output of ``batch_by_size``.
    """

    def __init__(self, sampler, drop_last: bool, use_random_sampler=False):
        self.data_list = sampler
        if use_random_sampler:
            self.sampler = RandomSampler(sampler)
        else:
            self.sampler = SequentialSampler(sampler)

        super().__init__(self.sampler, 1, drop_last)

    def __iter__(self):
        # NOTE(review): self.sampler is not consulted here, so batches are
        # always produced in data_list order regardless of use_random_sampler.
        # Confirm whether shuffling is expected to happen in the caller.
        yield from self.data_list

    def __len__(self):
        # batch_size is fixed at 1, so this equals len(self.data_list).
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size