# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

import random
import numpy as np

import torch
import torchaudio as ta

from text import text_to_sequence, cmudict
from text.symbols import symbols
from utils import parse_filelist, intersperse
from model.utils import fix_len_compatibility
from params import seed as random_seed

import sys
sys.path.insert(0, 'hifi-gan')
from meldataset import mel_spectrogram


class TextMelDataset(torch.utils.data.Dataset):
    """Single-speaker dataset yielding (phoneme id sequence, mel-spectrogram) pairs."""

    def __init__(self, filelist_path, cmudict_path, add_blank=True,
                 n_fft=1024, n_mels=80, sample_rate=22050,
                 hop_length=256, win_length=1024, f_min=0., f_max=8000):
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.cmudict = cmudict.CMUDict(cmudict_path)
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        random.seed(random_seed)
        random.shuffle(self.filepaths_and_text)

    def get_pair(self, filepath_and_text):
        filepath, text = filepath_and_text[0], filepath_and_text[1]
        text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)
        return (text, mel)

    def get_mel(self, filepath):
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate
        mel = mel_spectrogram(audio, self.n_fft, self.n_mels, self.sample_rate,
                              self.hop_length, self.win_length, self.f_min, self.f_max,
                              center=False).squeeze()
        return mel

    def get_text(self, text, add_blank=True):
        text_norm = text_to_sequence(text, dictionary=self.cmudict)
        if self.add_blank:
            # add a blank token, whose id number is len(symbols)
            text_norm = intersperse(text_norm, len(symbols))
        text_norm = torch.IntTensor(text_norm)
        return text_norm

    def __getitem__(self, index):
        text, mel = self.get_pair(self.filepaths_and_text[index])
        item = {'y': mel, 'x': text}
        return item

    def __len__(self):
        return len(self.filepaths_and_text)

    def sample_test_batch(self, size):
        idx = np.random.choice(range(len(self)), size=size, replace=False)
        test_batch = []
        for index in idx:
            test_batch.append(self.__getitem__(index))
        return test_batch


class TextMelBatchCollate(object):
    """Zero-pads texts and mel-spectrograms in a batch to common lengths."""

    def __call__(self, batch):
        B = len(batch)
        y_max_length = max([item['y'].shape[-1] for item in batch])
        y_max_length = fix_len_compatibility(y_max_length)
        x_max_length = max([item['x'].shape[-1] for item in batch])
        n_feats = batch[0]['y'].shape[-2]

        y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
        x = torch.zeros((B, x_max_length), dtype=torch.long)
        y_lengths, x_lengths = [], []

        for i, item in enumerate(batch):
            y_, x_ = item['y'], item['x']
            y_lengths.append(y_.shape[-1])
            x_lengths.append(x_.shape[-1])
            y[i, :, :y_.shape[-1]] = y_
            x[i, :x_.shape[-1]] = x_

        y_lengths = torch.LongTensor(y_lengths)
        x_lengths = torch.LongTensor(x_lengths)
        return {'x': x, 'x_lengths': x_lengths, 'y': y, 'y_lengths': y_lengths}


class TextMelSpeakerDataset(torch.utils.data.Dataset):
    """Multi-speaker dataset: each filelist line is 'filepath|text|speaker_id'."""

    def __init__(self, filelist_path, cmudict_path, add_blank=True,
                 n_fft=1024, n_mels=80, sample_rate=22050,
                 hop_length=256, win_length=1024, f_min=0., f_max=8000):
        super().__init__()
        self.filelist = parse_filelist(filelist_path, split_char='|')
        self.cmudict = cmudict.CMUDict(cmudict_path)
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        self.add_blank = add_blank
        random.seed(random_seed)
        random.shuffle(self.filelist)

    def get_triplet(self, line):
        filepath, text, speaker = line[0], line[1], line[2]
        text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)
        speaker = self.get_speaker(speaker)
        return (text, mel, speaker)

    def get_mel(self, filepath):
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate
        mel = mel_spectrogram(audio, self.n_fft, self.n_mels, self.sample_rate,
                              self.hop_length, self.win_length, self.f_min, self.f_max,
                              center=False).squeeze()
        return mel

    def get_text(self, text, add_blank=True):
        text_norm = text_to_sequence(text, dictionary=self.cmudict)
        if self.add_blank:
            # add a blank token, whose id number is len(symbols)
            text_norm = intersperse(text_norm, len(symbols))
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def get_speaker(self, speaker):
        speaker = torch.LongTensor([int(speaker)])
        return speaker

    def __getitem__(self, index):
        text, mel, speaker = self.get_triplet(self.filelist[index])
        item = {'y': mel, 'x': text, 'spk': speaker}
        return item

    def __len__(self):
        return len(self.filelist)

    def sample_test_batch(self, size):
        idx = np.random.choice(range(len(self)), size=size, replace=False)
        test_batch = []
        for index in idx:
            test_batch.append(self.__getitem__(index))
        return test_batch


class TextMelSpeakerBatchCollate(object):
    """Same padding as TextMelBatchCollate, but also stacks speaker ids."""

    def __call__(self, batch):
        B = len(batch)
        y_max_length = max([item['y'].shape[-1] for item in batch])
        y_max_length = fix_len_compatibility(y_max_length)
        x_max_length = max([item['x'].shape[-1] for item in batch])
        n_feats = batch[0]['y'].shape[-2]

        y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
        x = torch.zeros((B, x_max_length), dtype=torch.long)
        y_lengths, x_lengths = [], []
        spk = []

        for i, item in enumerate(batch):
            y_, x_, spk_ = item['y'], item['x'], item['spk']
            y_lengths.append(y_.shape[-1])
            x_lengths.append(x_.shape[-1])
            y[i, :, :y_.shape[-1]] = y_
            x[i, :x_.shape[-1]] = x_
            spk.append(spk_)

        y_lengths = torch.LongTensor(y_lengths)
        x_lengths = torch.LongTensor(x_lengths)
        spk = torch.cat(spk, dim=0)
        return {'x': x, 'x_lengths': x_lengths, 'y': y, 'y_lengths': y_lengths, 'spk': spk}
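

# Minimal usage sketch (illustrative only, not part of the original module):
# wiring the single-speaker dataset and its collate function into a standard
# torch DataLoader. The filelist and CMU dictionary paths are placeholders and
# assume the usual Grad-TTS resource layout; adjust them to your setup.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # Placeholder paths -- assumptions for illustration.
    dataset = TextMelDataset('resources/filelists/train.txt',
                             'resources/cmu_dictionary')
    loader = DataLoader(dataset, batch_size=16,
                        collate_fn=TextMelBatchCollate(),
                        drop_last=True, num_workers=4, shuffle=False)

    batch = next(iter(loader))
    # x: padded phoneme ids (B, max_text_len), y: padded mels (B, n_mels, max_mel_len)
    print(batch['x'].shape, batch['x_lengths'])
    print(batch['y'].shape, batch['y_lengths'])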