import os from dataclasses import dataclass, field from itertools import chain from pathlib import Path from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.distributed as dist import torchaudio from coqpit import Coqpit from librosa.filters import mel as librosa_mel_fn from torch import nn from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler from trainer.torch import DistributedSampler, DistributedSamplerWrapper from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.models.base_tts import BaseTTSE2E from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy from TTS.utils.audio.processor import AudioProcessor from TTS.utils.io import load_fsspec from TTS.vocoder.layers.losses import MultiScaleSTFTLoss from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results def id_to_torch(aux_id, cuda=False): if aux_id is not None: aux_id = np.asarray(aux_id) aux_id = torch.from_numpy(aux_id) if cuda: return aux_id.cuda() return aux_id def embedding_to_torch(d_vector, cuda=False): if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).float() d_vector = d_vector.squeeze().unsqueeze(0) if cuda: return d_vector.cuda() return d_vector def numpy_to_torch(np_array, dtype, cuda=False): if np_array is None: return None tensor = torch.as_tensor(np_array, dtype=dtype) if cuda: return tensor.cuda() return tensor def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: batch_size = lengths.shape[0] max_len = torch.max(lengths).item() ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) return mask def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: out_list = torch.jit.annotate(List[torch.Tensor], []) for batch in input_ele: if len(batch.shape) == 1: one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0) else: one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0) out_list.append(one_batch_padded) out_padded = torch.stack(out_list) return out_padded def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01): classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.normal_(mean, std) def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: return torch.ceil(lens / stride).int() def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." return torch.randn(shape) * np.sqrt(2 / shape[1]) # pylint: disable=redefined-outer-name def calc_same_padding(kernel_size: int) -> Tuple[int, int]: pad = kernel_size // 2 return (pad, pad - (kernel_size + 1) % 2) hann_window = {} mel_basis = {} @torch.no_grad() def weights_reset(m: nn.Module): # check if the current module has reset_parameters and if it is reset the weight reset_parameters = getattr(m, "reset_parameters", None) if callable(reset_parameters): m.reset_parameters() def get_module_weights_sum(mdl: nn.Module): dict_sums = {} for name, w in mdl.named_parameters(): if "weight" in name: value = w.data.sum().item() dict_sums[name] = value return dict_sums def load_audio(file_path: str): """Load the audio file normalized in [-1, 1] Return Shapes: - x: :math:`[1, T]` """ x, sr = torchaudio.load( file_path, ) assert (x > 1).sum() + (x < -1).sum() == 0 return x, sr def _amp_to_db(x, C=1, clip_val=1e-5): return torch.log(torch.clamp(x, min=clip_val) * C) def _db_to_amp(x, C=1): return torch.exp(x) / C def amp_to_db(magnitudes): output = _amp_to_db(magnitudes) return output def db_to_amp(magnitudes): output = _db_to_amp(magnitudes) return output def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) if torch.min(y) < -1.0: print("min value is ", torch.min(y)) if torch.max(y) > 1.0: print("max value is ", torch.max(y)) global hann_window # pylint: disable=global-statement dtype_device = str(y.dtype) + "_" + str(y.device) wnsize_dtype_device = str(win_length) + "_" + dtype_device if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) y = torch.nn.functional.pad( y.unsqueeze(1), (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode="reflect", ) y = y.squeeze(1) spec = torch.stft( y, n_fft, hop_length=hop_length, win_length=win_length, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=False, ) return spec def wav_to_spec(y, n_fft, hop_length, win_length, center=False): """ Args Shapes: - y : :math:`[B, 1, T]` Return Shapes: - spec : :math:`[B,C,T]` """ spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return spec def wav_to_energy(y, n_fft, hop_length, win_length, center=False): spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return torch.norm(spec, dim=1, keepdim=True) def name_mel_basis(spec, n_fft, fmax): n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" return n_fft_len def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): """ Args Shapes: - spec : :math:`[B,C,T]` Return Shapes: - mel : :math:`[B,C,T]` """ global mel_basis # pylint: disable=global-statement mel_basis_key = name_mel_basis(spec, n_fft, fmax) # pylint: disable=too-many-function-args if mel_basis_key not in mel_basis: # pylint: disable=missing-kwoa mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) mel = torch.matmul(mel_basis[mel_basis_key], spec) mel = amp_to_db(mel) return mel def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): """ Args Shapes: - y : :math:`[B, 1, T_y]` Return Shapes: - spec : :math:`[B,C,T_spec]` """ y = y.squeeze(1) if torch.min(y) < -1.0: print("min value is ", torch.min(y)) if torch.max(y) > 1.0: print("max value is ", torch.max(y)) global mel_basis, hann_window # pylint: disable=global-statement mel_basis_key = name_mel_basis(y, n_fft, fmax) wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) if mel_basis_key not in mel_basis: # pylint: disable=missing-kwoa mel = librosa_mel_fn( sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax ) # pylint: disable=too-many-function-args mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) y = torch.nn.functional.pad( y.unsqueeze(1), (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode="reflect", ) y = y.squeeze(1) spec = torch.stft( y, n_fft, hop_length=hop_length, win_length=win_length, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=False, ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.matmul(mel_basis[mel_basis_key], spec) spec = amp_to_db(spec) return spec ############################## # DATASET ############################## def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): """Create balancer weight for torch WeightedSampler""" attr_names_samples = np.array([item[attr_name] for item in items]) unique_attr_names = np.unique(attr_names_samples).tolist() attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) weight_attr = 1.0 / attr_count dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) if multi_dict is not None: multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) dataset_samples_weight *= multiplier_samples return ( torch.from_numpy(dataset_samples_weight).float(), unique_attr_names, np.unique(dataset_samples_weight).tolist(), ) class ForwardTTSE2eF0Dataset(F0Dataset): """Override F0Dataset to avoid slow computing of pitches""" def __init__( self, ap, samples: Union[List[List], List[Dict]], verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, ): super().__init__( samples=samples, ap=ap, verbose=verbose, cache_path=cache_path, precompute_num_workers=precompute_num_workers, normalize_f0=normalize_f0, ) def _compute_and_save_pitch(self, wav_file, pitch_file=None): wav, _ = load_audio(wav_file) f0 = compute_f0( x=wav.numpy()[0], sample_rate=self.ap.sample_rate, hop_length=self.ap.hop_length, pitch_fmax=self.ap.pitch_fmax, pitch_fmin=self.ap.pitch_fmin, win_length=self.ap.win_length, ) # skip the last F0 value to align with the spectrogram if wav.shape[1] % self.ap.hop_length != 0: f0 = f0[:-1] if pitch_file: np.save(pitch_file, f0) return f0 def compute_or_load(self, wav_file, audio_name): """ compute pitch and return a numpy array of pitch values """ pitch_file = self.create_pitch_file_path(audio_name, self.cache_path) if not os.path.exists(pitch_file): pitch = self._compute_and_save_pitch(wav_file=wav_file, pitch_file=pitch_file) else: pitch = np.load(pitch_file) return pitch.astype(np.float32) class ForwardTTSE2eDataset(TTSDataset): def __init__(self, *args, **kwargs): # don't init the default F0Dataset in TTSDataset compute_f0 = kwargs.pop("compute_f0", False) kwargs["compute_f0"] = False self.attn_prior_cache_path = kwargs.pop("attn_prior_cache_path") super().__init__(*args, **kwargs) self.compute_f0 = compute_f0 self.pad_id = self.tokenizer.characters.pad_id self.ap = kwargs["ap"] if self.compute_f0: self.f0_dataset = ForwardTTSE2eF0Dataset( ap=self.ap, samples=self.samples, cache_path=kwargs["f0_cache_path"], precompute_num_workers=kwargs["precompute_num_workers"], ) if self.attn_prior_cache_path is not None: os.makedirs(self.attn_prior_cache_path, exist_ok=True) def __getitem__(self, idx): item = self.samples[idx] rel_wav_path = Path(item["audio_file"]).relative_to(item["root_path"]).with_suffix("") rel_wav_path = str(rel_wav_path).replace("/", "_") raw_text = item["text"] wav, _ = load_audio(item["audio_file"]) wav_filename = os.path.basename(item["audio_file"]) try: token_ids = self.get_token_ids(idx, item["text"]) except: print(idx, item) # pylint: disable=raise-missing-from raise OSError f0 = None if self.compute_f0: f0 = self.get_f0(idx)["f0"] # after phonemization the text length may change # this is a shameful 🤭 hack to prevent longer phonemes # TODO: find a better fix if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: self.rescue_item_idx += 1 return self.__getitem__(self.rescue_item_idx) attn_prior = None if self.attn_prior_cache_path is not None: attn_prior = self.load_or_compute_attn_prior(token_ids, wav, rel_wav_path) return { "raw_text": raw_text, "token_ids": token_ids, "token_len": len(token_ids), "wav": wav, "pitch": f0, "wav_file": wav_filename, "speaker_name": item["speaker_name"], "language_name": item["language"], "attn_prior": attn_prior, "audio_unique_name": item["audio_unique_name"], } def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): """Load or compute and save the attention prior.""" attn_prior_file = os.path.join(self.attn_prior_cache_path, f"{rel_wav_path}.npy") # pylint: disable=no-else-return if os.path.exists(attn_prior_file): return np.load(attn_prior_file) else: token_len = len(token_ids) mel_len = wav.shape[1] // self.ap.hop_length attn_prior = compute_attn_prior(token_len, mel_len) np.save(attn_prior_file, attn_prior) return attn_prior @property def lengths(self): lens = [] for item in self.samples: _, wav_file, *_ = _parse_sample(item) audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio lens.append(audio_len) return lens def collate_fn(self, batch): """ Return Shapes: - tokens: :math:`[B, T]` - token_lens :math:`[B]` - token_rel_lens :math:`[B]` - pitch :math:`[B, T]` - waveform: :math:`[B, 1, T]` - waveform_lens: :math:`[B]` - waveform_rel_lens: :math:`[B]` - speaker_names: :math:`[B]` - language_names: :math:`[B]` - audiofile_paths: :math:`[B]` - raw_texts: :math:`[B]` - attn_prior: :math:`[[T_token, T_mel]]` """ B = len(batch) batch = {k: [dic[k] for dic in batch] for k in batch[0]} max_text_len = max([len(x) for x in batch["token_ids"]]) token_lens = torch.LongTensor(batch["token_len"]) token_rel_lens = token_lens / token_lens.max() wav_lens = [w.shape[1] for w in batch["wav"]] wav_lens = torch.LongTensor(wav_lens) wav_lens_max = torch.max(wav_lens) wav_rel_lens = wav_lens / wav_lens_max pitch_padded = None if self.compute_f0: pitch_lens = [p.shape[0] for p in batch["pitch"]] pitch_lens = torch.LongTensor(pitch_lens) pitch_lens_max = torch.max(pitch_lens) pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max) pitch_padded = pitch_padded.zero_() + self.pad_id token_padded = torch.LongTensor(B, max_text_len) wav_padded = torch.FloatTensor(B, 1, wav_lens_max) token_padded = token_padded.zero_() + self.pad_id wav_padded = wav_padded.zero_() + self.pad_id for i in range(B): token_ids = batch["token_ids"][i] token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) wav = batch["wav"][i] wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) if self.compute_f0: pitch = batch["pitch"][i] pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) return { "text_input": token_padded, "text_lengths": token_lens, "text_rel_lens": token_rel_lens, "pitch": pitch_padded, "waveform": wav_padded, # (B x T) "waveform_lens": wav_lens, # (B) "waveform_rel_lens": wav_rel_lens, "speaker_names": batch["speaker_name"], "language_names": batch["language_name"], "audio_unique_names": batch["audio_unique_name"], "audio_files": batch["wav_file"], "raw_text": batch["raw_text"], "attn_priors": batch["attn_prior"] if batch["attn_prior"][0] is not None else None, } ############################## # CONFIG DEFINITIONS ############################## @dataclass class VocoderConfig(Coqpit): resblock_type_decoder: str = "1" resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) upsample_initial_channel_decoder: int = 512 upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) use_spectral_norm_discriminator: bool = False upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) pretrained_model_path: Optional[str] = None @dataclass class DelightfulTtsAudioConfig(Coqpit): sample_rate: int = 22050 hop_length: int = 256 win_length: int = 1024 fft_size: int = 1024 mel_fmin: float = 0.0 mel_fmax: float = 8000 num_mels: int = 100 pitch_fmax: float = 640.0 pitch_fmin: float = 1.0 resample: bool = False preemphasis: float = 0.0 ref_level_db: int = 20 do_sound_norm: bool = False log_func: str = "np.log10" do_trim_silence: bool = True trim_db: int = 45 do_rms_norm: bool = False db_level: float = None power: float = 1.5 griffin_lim_iters: int = 60 spec_gain: int = 20 do_amp_to_db_linear: bool = True do_amp_to_db_mel: bool = True min_level_db: int = -100 max_norm: float = 4.0 @dataclass class DelightfulTtsArgs(Coqpit): num_chars: int = 100 spec_segment_size: int = 32 n_hidden_conformer_encoder: int = 512 n_layers_conformer_encoder: int = 6 n_heads_conformer_encoder: int = 8 dropout_conformer_encoder: float = 0.1 kernel_size_conv_mod_conformer_encoder: int = 7 kernel_size_depthwise_conformer_encoder: int = 7 lrelu_slope: float = 0.3 n_hidden_conformer_decoder: int = 512 n_layers_conformer_decoder: int = 6 n_heads_conformer_decoder: int = 8 dropout_conformer_decoder: float = 0.1 kernel_size_conv_mod_conformer_decoder: int = 11 kernel_size_depthwise_conformer_decoder: int = 11 bottleneck_size_p_reference_encoder: int = 4 bottleneck_size_u_reference_encoder: int = 512 ref_enc_filters_reference_encoder = [32, 32, 64, 64, 128, 128] ref_enc_size_reference_encoder: int = 3 ref_enc_strides_reference_encoder = [1, 2, 1, 2, 1] ref_enc_pad_reference_encoder = [1, 1] ref_enc_gru_size_reference_encoder: int = 32 ref_attention_dropout_reference_encoder: float = 0.2 token_num_reference_encoder: int = 32 predictor_kernel_size_reference_encoder: int = 5 n_hidden_variance_adaptor: int = 512 kernel_size_variance_adaptor: int = 5 dropout_variance_adaptor: float = 0.5 n_bins_variance_adaptor: int = 256 emb_kernel_size_variance_adaptor: int = 3 use_speaker_embedding: bool = False num_speakers: int = 0 speakers_file: str = None d_vector_file: str = None speaker_embedding_channels: int = 384 use_d_vector_file: bool = False d_vector_dim: int = 0 freeze_vocoder: bool = False freeze_text_encoder: bool = False freeze_duration_predictor: bool = False freeze_pitch_predictor: bool = False freeze_energy_predictor: bool = False freeze_basis_vectors_predictor: bool = False freeze_decoder: bool = False length_scale: float = 1.0 ############################## # MODEL DEFINITION ############################## class DelightfulTTS(BaseTTSE2E): """ Paper:: https://arxiv.org/pdf/2110.12612.pdf Paper Abstract:: This paper describes the Microsoft end-to-end neural text to speech (TTS) system: DelightfulTTS for Blizzard Challenge 2021. The goal of this challenge is to synthesize natural and high-quality speech from text, and we approach this goal in two perspectives: The first is to directly model and generate waveform in 48 kHz sampling rate, which brings higher perception quality than previous systems with 16 kHz or 24 kHz sampling rate; The second is to model the variation information in speech through a systematic design, which improves the prosody and naturalness. Specifically, for 48 kHz modeling, we predict 16 kHz mel-spectrogram in acoustic model, and propose a vocoder called HiFiNet to directly generate 48 kHz waveform from predicted 16 kHz mel-spectrogram, which can better trade off training efficiency, modelling stability and voice quality. We model variation information systematically from both explicit (speaker ID, language ID, pitch and duration) and implicit (utterance-level and phoneme-level prosody) perspectives: 1) For speaker and language ID, we use lookup embedding in training and inference; 2) For pitch and duration, we extract the values from paired text-speech data in training and use two predictors to predict the values in inference; 3) For utterance-level and phoneme-level prosody, we use two reference encoders to extract the values in training, and use two separate predictors to predict the values in inference. Additionally, we introduce an improved Conformer block to better model the local and global dependency in acoustic model. For task SH1, DelightfulTTS achieves 4.17 mean score in MOS test and 4.35 in SMOS test, which indicates the effectiveness of our proposed system Model training:: text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg spec --------^ Examples: >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig >>> config = ForwardTTSE2eConfig() >>> model = ForwardTTSE2e(config) """ # pylint: disable=dangerous-default-value def __init__( self, config: Coqpit, ap, tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, ): super().__init__(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) self.ap = ap self._set_model_args(config) self.init_multispeaker(config) self.binary_loss_weight = None self.args.out_channels = self.config.audio.num_mels self.args.num_mels = self.config.audio.num_mels self.acoustic_model = AcousticModel(args=self.args, tokenizer=tokenizer, speaker_manager=speaker_manager) self.waveform_decoder = HifiganGenerator( self.config.audio.num_mels, 1, self.config.vocoder.resblock_type_decoder, self.config.vocoder.resblock_dilation_sizes_decoder, self.config.vocoder.resblock_kernel_sizes_decoder, self.config.vocoder.upsample_kernel_sizes_decoder, self.config.vocoder.upsample_initial_channel_decoder, self.config.vocoder.upsample_rates_decoder, inference_padding=0, # cond_channels=self.embedded_speaker_dim, conv_pre_weight_norm=False, conv_post_weight_norm=False, conv_post_bias=False, ) if self.config.init_discriminator: self.disc = VitsDiscriminator( use_spectral_norm=self.config.vocoder.use_spectral_norm_discriminator, periods=self.config.vocoder.periods_discriminator, ) @property def device(self): return next(self.parameters()).device @property def energy_scaler(self): return self.acoustic_model.energy_scaler @property def length_scale(self): return self.acoustic_model.length_scale @length_scale.setter def length_scale(self, value): self.acoustic_model.length_scale = value @property def pitch_mean(self): return self.acoustic_model.pitch_mean @pitch_mean.setter def pitch_mean(self, value): self.acoustic_model.pitch_mean = value @property def pitch_std(self): return self.acoustic_model.pitch_std @pitch_std.setter def pitch_std(self, value): self.acoustic_model.pitch_std = value @property def mel_basis(self): return build_mel_basis( sample_rate=self.ap.sample_rate, fft_size=self.ap.fft_size, num_mels=self.ap.num_mels, mel_fmax=self.ap.mel_fmax, mel_fmin=self.ap.mel_fmin, ) # pylint: disable=function-redefined def init_for_training(self) -> None: self.train_disc = ( # pylint: disable=attribute-defined-outside-init self.config.steps_to_start_discriminator <= 0 ) # pylint: disable=attribute-defined-outside-init self.update_energy_scaler = True # pylint: disable=attribute-defined-outside-init def init_multispeaker(self, config: Coqpit): """Init for multi-speaker training. Args: config (Coqpit): Model configuration. """ self.embedded_speaker_dim = 0 self.num_speakers = self.args.num_speakers self.audio_transform = None if self.speaker_manager: self.num_speakers = self.speaker_manager.num_speakers self.args.num_speakers = self.speaker_manager.num_speakers if self.args.use_speaker_embedding: self._init_speaker_embedding() if self.args.use_d_vector_file: self._init_d_vector() def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: print(" > initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.args.embedded_speaker_dim = self.args.speaker_embedding_channels def _init_d_vector(self): # pylint: disable=attribute-defined-outside-init if hasattr(self, "emb_g"): raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") self.embedded_speaker_dim = self.args.d_vector_dim self.args.embedded_speaker_dim = self.args.d_vector_dim def _freeze_layers(self): if self.args.freeze_vocoder: for param in self.vocoder.paramseters(): param.requires_grad = False if self.args.freeze_text_encoder: for param in self.text_encoder.parameters(): param.requires_grad = False if self.args.freeze_duration_predictor: for param in self.durarion_predictor.parameters(): param.requires_grad = False if self.args.freeze_pitch_predictor: for param in self.pitch_predictor.parameters(): param.requires_grad = False if self.args.freeze_energy_predictor: for param in self.energy_predictor.parameters(): param.requires_grad = False if self.args.freeze_decoder: for param in self.decoder.parameters(): param.requires_grad = False def forward( self, x: torch.LongTensor, x_lengths: torch.LongTensor, spec_lengths: torch.LongTensor, spec: torch.FloatTensor, waveform: torch.FloatTensor, pitch: torch.FloatTensor = None, energy: torch.FloatTensor = None, attn_priors: torch.FloatTensor = None, d_vectors: torch.FloatTensor = None, speaker_idx: torch.LongTensor = None, ) -> Dict: """Model's forward pass. Args: x (torch.LongTensor): Input character sequences. x_lengths (torch.LongTensor): Input sequence lengths. spec_lengths (torch.LongTensor): Spectrogram sequnce lengths. Defaults to None. spec (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. waveform (torch.FloatTensor): Waveform. Defaults to None. pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. energy (torch.FloatTensor): Spectral energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. attn_priors (torch.FloatTentrasor): Attention priors for the aligner network. Defaults to None. aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. Shapes: - x: :math:`[B, T_max]` - x_lengths: :math:`[B]` - spec_lengths: :math:`[B]` - spec: :math:`[B, T_max2, C_spec]` - waveform: :math:`[B, 1, T_max2 * hop_length]` - g: :math:`[B, C]` - pitch: :math:`[B, 1, T_max2]` - energy: :math:`[B, 1, T_max2]` """ encoder_outputs = self.acoustic_model( tokens=x, src_lens=x_lengths, mel_lens=spec_lengths, mels=spec, pitches=pitch, energies=energy, attn_priors=attn_priors, d_vectors=d_vectors, speaker_idx=speaker_idx, ) # use mel-spec from the decoder vocoder_input = encoder_outputs["model_outputs"] # [B, T_max2, C_mel] vocoder_input_slices, slice_ids = rand_segments( x=vocoder_input.transpose(1, 2), x_lengths=spec_lengths, segment_size=self.args.spec_segment_size, let_short_samples=True, pad_short=True, ) if encoder_outputs["spk_emb"] is not None: g = encoder_outputs["spk_emb"].unsqueeze(-1) else: g = None vocoder_output = self.waveform_decoder(x=vocoder_input_slices.detach(), g=g) wav_seg = segment( waveform, slice_ids * self.ap.hop_length, self.args.spec_segment_size * self.ap.hop_length, pad_short=True, ) model_outputs = {**encoder_outputs} model_outputs["acoustic_model_outputs"] = encoder_outputs["model_outputs"] model_outputs["model_outputs"] = vocoder_output model_outputs["waveform_seg"] = wav_seg model_outputs["slice_ids"] = slice_ids return model_outputs @torch.no_grad() def inference( self, x, aux_input={"d_vectors": None, "speaker_ids": None}, pitch_transform=None, energy_transform=None ): encoder_outputs = self.acoustic_model.inference( tokens=x, d_vectors=aux_input["d_vectors"], speaker_idx=aux_input["speaker_ids"], pitch_transform=pitch_transform, energy_transform=energy_transform, p_control=None, d_control=None, ) vocoder_input = encoder_outputs["model_outputs"].transpose(1, 2) # [B, T_max2, C_mel] -> [B, C_mel, T_max2] if encoder_outputs["spk_emb"] is not None: g = encoder_outputs["spk_emb"].unsqueeze(-1) else: g = None vocoder_output = self.waveform_decoder(x=vocoder_input, g=g) model_outputs = {**encoder_outputs} model_outputs["model_outputs"] = vocoder_output return model_outputs @torch.no_grad() def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): encoder_outputs = self.acoustic_model.inference( tokens=x, d_vectors=aux_input["d_vectors"], speaker_idx=aux_input["speaker_ids"], ) model_outputs = {**encoder_outputs} return model_outputs def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): if optimizer_idx == 0: tokens = batch["text_input"] token_lenghts = batch["text_lengths"] mel = batch["mel_input"] mel_lens = batch["mel_lengths"] waveform = batch["waveform"] # [B, T, C] -> [B, C, T] pitch = batch["pitch"] d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] attn_priors = batch["attn_priors"] energy = batch["energy"] # generator pass outputs = self.forward( x=tokens, x_lengths=token_lenghts, spec_lengths=mel_lens, spec=mel, waveform=waveform, pitch=pitch, energy=energy, attn_priors=attn_priors, d_vectors=d_vectors, speaker_idx=speaker_ids, ) # cache tensors for the generator pass self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init if self.train_disc: # compute scores and features scores_d_fake, _, scores_d_real, _ = self.disc( outputs["model_outputs"].detach(), outputs["waveform_seg"] ) # compute loss with autocast(enabled=False): # use float32 for the criterion loss_dict = criterion[optimizer_idx]( scores_disc_fake=scores_d_fake, scores_disc_real=scores_d_real, ) return outputs, loss_dict return None, None if optimizer_idx == 1: mel = batch["mel_input"] # compute melspec segment with autocast(enabled=False): mel_slice = segment( mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True ) mel_slice_hat = wav_to_mel( y=self.model_outputs_cache["model_outputs"].float(), n_fft=self.ap.fft_size, sample_rate=self.ap.sample_rate, num_mels=self.ap.num_mels, hop_length=self.ap.hop_length, win_length=self.ap.win_length, fmin=self.ap.mel_fmin, fmax=self.ap.mel_fmax, center=False, ) scores_d_fake = None feats_d_fake = None feats_d_real = None if self.train_disc: # compute discriminator scores and features scores_d_fake, feats_d_fake, _, feats_d_real = self.disc( self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] ) # compute losses with autocast(enabled=True): # use float32 for the criterion loss_dict = criterion[optimizer_idx]( mel_output=self.model_outputs_cache["acoustic_model_outputs"].transpose(1, 2), mel_target=batch["mel_input"], mel_lens=batch["mel_lengths"], dur_output=self.model_outputs_cache["dr_log_pred"], dur_target=self.model_outputs_cache["dr_log_target"].detach(), pitch_output=self.model_outputs_cache["pitch_pred"], pitch_target=self.model_outputs_cache["pitch_target"], energy_output=self.model_outputs_cache["energy_pred"], energy_target=self.model_outputs_cache["energy_target"], src_lens=batch["text_lengths"], waveform=self.model_outputs_cache["waveform_seg"], waveform_hat=self.model_outputs_cache["model_outputs"], p_prosody_ref=self.model_outputs_cache["p_prosody_ref"], p_prosody_pred=self.model_outputs_cache["p_prosody_pred"], u_prosody_ref=self.model_outputs_cache["u_prosody_ref"], u_prosody_pred=self.model_outputs_cache["u_prosody_pred"], aligner_logprob=self.model_outputs_cache["aligner_logprob"], aligner_hard=self.model_outputs_cache["aligner_mas"], aligner_soft=self.model_outputs_cache["aligner_soft"], binary_loss_weight=self.binary_loss_weight, feats_fake=feats_d_fake, feats_real=feats_d_real, scores_fake=scores_d_fake, spec_slice=mel_slice, spec_slice_hat=mel_slice_hat, skip_disc=not self.train_disc, ) loss_dict["avg_text_length"] = batch["text_lengths"].float().mean() loss_dict["avg_mel_length"] = batch["mel_lengths"].float().mean() loss_dict["avg_text_batch_occupancy"] = ( batch["text_lengths"].float() / batch["text_lengths"].float().max() ).mean() loss_dict["avg_mel_batch_occupancy"] = ( batch["mel_lengths"].float() / batch["mel_lengths"].float().max() ).mean() return self.model_outputs_cache, loss_dict raise ValueError(" [!] Unexpected `optimizer_idx`.") def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) def _log(self, batch, outputs, name_prefix="train"): figures, audios = {}, {} # encoder outputs model_outputs = outputs[1]["acoustic_model_outputs"] alignments = outputs[1]["alignments"] mel_input = batch["mel_input"] pred_spec = model_outputs[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy() figures = { "prediction": plot_spectrogram(pred_spec, None, output_fig=False), "ground_truth": plot_spectrogram(gt_spec.T, None, output_fig=False), "alignment": plot_alignment(align_img, output_fig=False), } # plot pitch figures pitch_avg = abs(outputs[1]["pitch_target"][0, 0].data.cpu().numpy()) pitch_avg_hat = abs(outputs[1]["pitch_pred"][0, 0].data.cpu().numpy()) chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) pitch_figures = { "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), } figures.update(pitch_figures) # plot energy figures energy_avg = abs(outputs[1]["energy_target"][0, 0].data.cpu().numpy()) energy_avg_hat = abs(outputs[1]["energy_pred"][0, 0].data.cpu().numpy()) chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) energy_figures = { "energy_ground_truth": plot_avg_pitch(energy_avg, chars, output_fig=False), "energy_avg_predicted": plot_avg_pitch(energy_avg_hat, chars, output_fig=False), } figures.update(energy_figures) # plot the attention mask computed from the predicted durations alignments_hat = outputs[1]["alignments_dp"][0].data.cpu().numpy() figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) # Sample audio encoder_audio = mel_to_wav_numpy( mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.mel_basis, **self.config.audio ) audios[f"{name_prefix}/encoder_audio"] = encoder_audio # vocoder outputs y_hat = outputs[1]["model_outputs"] y = outputs[1]["waveform_seg"] vocoder_figures = plot_results(y_hat=y_hat, y=y, ap=self.ap, name_prefix=name_prefix) figures.update(vocoder_figures) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() audios[f"{name_prefix}/vocoder_audio"] = sample_voice return figures, audios def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ): # pylint: disable=no-self-use, unused-argument """Create visualizations and waveform examples. For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to be projected onto Tensorboard. Args: batch (Dict): Model inputs used at the previous training step. outputs (Dict): Model outputs generated at the previous training step. Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") logger.train_figures(steps, figures) logger.train_audios(steps, audios, self.ap.sample_rate) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") logger.eval_figures(steps, figures) logger.eval_audios(steps, audios, self.ap.sample_rate) def get_aux_input_from_test_sentences(self, sentence_info): if hasattr(self.config, "model_args"): config = self.config.model_args else: config = self.config # extract speaker and language info text, speaker_name, style_wav = None, None, None if isinstance(sentence_info, list): if len(sentence_info) == 1: text = sentence_info[0] elif len(sentence_info) == 2: text, speaker_name = sentence_info elif len(sentence_info) == 3: text, speaker_name, style_wav = sentence_info else: text = sentence_info # get speaker id/d_vector speaker_id, d_vector = None, None if hasattr(self, "speaker_manager"): if config.use_d_vector_file: if speaker_name is None: d_vector = self.speaker_manager.get_random_embedding() else: d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) elif config.use_speaker_embedding: if speaker_name is None: speaker_id = self.speaker_manager.get_random_id() else: speaker_id = self.speaker_manager.name_to_id[speaker_name] return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} def plot_outputs(self, text, wav, alignment, outputs): figures = {} pitch_avg_pred = outputs["pitch"].cpu() energy_avg_pred = outputs["energy"].cpu() spec = wav_to_mel( y=torch.from_numpy(wav[None, :]), n_fft=self.ap.fft_size, sample_rate=self.ap.sample_rate, num_mels=self.ap.num_mels, hop_length=self.ap.hop_length, win_length=self.ap.win_length, fmin=self.ap.mel_fmin, fmax=self.ap.mel_fmax, center=False, )[0].transpose(0, 1) pitch = compute_f0( x=wav[0], sample_rate=self.ap.sample_rate, hop_length=self.ap.hop_length, pitch_fmax=self.ap.pitch_fmax, ) input_text = self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(text, language="en")) input_text = input_text.replace("", "_") durations = outputs["durations"] pitch_avg = average_over_durations(torch.from_numpy(pitch)[None, None, :], durations.cpu()) # [1, 1, n_frames] pitch_avg_pred_denorm = (pitch_avg_pred * self.pitch_std) + self.pitch_mean figures["alignment"] = plot_alignment(alignment.transpose(1, 2), output_fig=False) figures["spectrogram"] = plot_spectrogram(spec) figures["pitch_from_wav"] = plot_pitch(pitch, spec) figures["pitch_avg_from_wav"] = plot_avg_pitch(pitch_avg.squeeze(), input_text) figures["pitch_avg_pred"] = plot_avg_pitch(pitch_avg_pred_denorm.squeeze(), input_text) figures["energy_avg_pred"] = plot_avg_pitch(energy_avg_pred.squeeze(), input_text) return figures def synthesize( self, text: str, speaker_id: str = None, d_vector: torch.tensor = None, pitch_transform=None, **kwargs, ): # pylint: disable=unused-argument # TODO: add cloning support with ref_waveform is_cuda = next(self.parameters()).is_cuda # convert text to sequence of token IDs text_inputs = np.asarray( self.tokenizer.text_to_ids(text, language=None), dtype=np.int32, ) # set speaker inputs _speaker_id = None if speaker_id is not None and self.args.use_speaker_embedding: if isinstance(speaker_id, str) and self.args.use_speaker_embedding: # get the speaker id for the speaker embedding layer _speaker_id = self.speaker_manager.name_to_id[speaker_id] _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda) if speaker_id is not None and self.args.use_d_vector_file: # get the average d_vector for the speaker d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False) d_vector = embedding_to_torch(d_vector, cuda=is_cuda) text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) text_inputs = text_inputs.unsqueeze(0) # synthesize voice outputs = self.inference( text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": _speaker_id}, pitch_transform=pitch_transform, # energy_transform=energy_transform ) # collect outputs wav = outputs["model_outputs"][0].data.cpu().numpy() alignments = outputs["alignments"] return_dict = { "wav": wav, "alignments": alignments, "text_inputs": text_inputs, "outputs": outputs, } return return_dict def synthesize_with_gl(self, text: str, speaker_id, d_vector): is_cuda = next(self.parameters()).is_cuda # convert text to sequence of token IDs text_inputs = np.asarray( self.tokenizer.text_to_ids(text, language=None), dtype=np.int32, ) # pass tensors to backend if speaker_id is not None: speaker_id = id_to_torch(speaker_id, cuda=is_cuda) if d_vector is not None: d_vector = embedding_to_torch(d_vector, cuda=is_cuda) text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) text_inputs = text_inputs.unsqueeze(0) # synthesize voice outputs = self.inference_spec_decoder( x=text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}, ) # collect outputs S = outputs["model_outputs"].cpu().numpy()[0].T S = db_to_amp_numpy(x=S, gain=1, base=None) wav = mel_to_wav_numpy(mel=S, mel_basis=self.mel_basis, **self.config.audio) alignments = outputs["alignments"] return_dict = { "wav": wav[None, :], "alignments": alignments, "text_inputs": text_inputs, "outputs": outputs, } return return_dict @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences for idx, s_info in enumerate(test_sentences): aux_inputs = self.get_aux_input_from_test_sentences(s_info) outputs = self.synthesize( aux_inputs["text"], config=self.config, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], ) outputs_gl = self.synthesize_with_gl( aux_inputs["text"], speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], ) # speaker_name = self.speaker_manager.speaker_names[aux_inputs["speaker_id"]] test_audios["{}-audio".format(idx)] = outputs["wav"].T test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) return {"figures": test_figures, "audios": test_audios} def test_log( self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument ) -> None: logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) logger.test_figures(steps, outputs["figures"]) def format_batch(self, batch: Dict) -> Dict: """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" speaker_ids = None d_vectors = None # get numerical speaker ids from speaker names if self.speaker_manager is not None and self.speaker_manager.speaker_names and self.args.use_speaker_embedding: speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] if speaker_ids is not None: speaker_ids = torch.LongTensor(speaker_ids) batch["speaker_ids"] = speaker_ids # get d_vectors from audio file names if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: d_vector_mapping = self.speaker_manager.embeddings d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] d_vectors = torch.FloatTensor(d_vectors) batch["d_vectors"] = d_vectors batch["speaker_ids"] = speaker_ids return batch def format_batch_on_device(self, batch): """Compute spectrograms on the device.""" ac = self.ap # compute spectrograms batch["mel_input"] = wav_to_mel( batch["waveform"], hop_length=ac.hop_length, win_length=ac.win_length, n_fft=ac.fft_size, num_mels=ac.num_mels, sample_rate=ac.sample_rate, fmin=ac.mel_fmin, fmax=ac.mel_fmax, center=False, ) # TODO: Align pitch properly # assert ( # batch["pitch"].shape[2] == batch["mel_input"].shape[2] # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}" batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] if batch["pitch"] is not None else None batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() # zero the padding frames batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) # format attn priors as we now the max mel length # TODO: fix 1 diff b/w mel_lengths and attn_priors if self.config.use_attn_priors: attn_priors_np = batch["attn_priors"] batch["attn_priors"] = torch.zeros( batch["mel_input"].shape[0], batch["mel_lengths"].max(), batch["text_lengths"].max(), device=batch["mel_input"].device, ) for i in range(batch["mel_input"].shape[0]): batch["attn_priors"][i, : attn_priors_np[i].shape[0], : attn_priors_np[i].shape[1]] = torch.from_numpy( attn_priors_np[i] ) batch["energy"] = None batch["energy"] = wav_to_energy( # [B, 1, T_max2] batch["waveform"], hop_length=ac.hop_length, win_length=ac.win_length, n_fft=ac.fft_size, center=False, ) batch["energy"] = self.energy_scaler(batch["energy"]) return batch def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): weights = None data_items = dataset.samples if getattr(config, "use_weighted_sampler", False): for attr_name, alpha in config.weighted_sampler_attrs.items(): print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) print(multi_dict) weights, attr_names, attr_weights = get_attribute_balancer_weights( attr_name=attr_name, items=data_items, multi_dict=multi_dict ) weights = weights * alpha print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") if weights is not None: sampler = WeightedRandomSampler(weights, len(weights)) else: sampler = None # sampler for DDP if sampler is None: sampler = DistributedSampler(dataset) if num_gpus > 1 else None else: # If a sampler is already defined use this sampler and DDP sampler together sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler return sampler def get_data_loader( self, config: Coqpit, assets: Dict, is_eval: bool, samples: Union[List[Dict], List[List]], verbose: bool, num_gpus: int, rank: int = None, ) -> "DataLoader": if is_eval and not config.run_eval: loader = None else: # init dataloader dataset = ForwardTTSE2eDataset( samples=samples, ap=self.ap, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, min_text_len=config.min_text_len, max_text_len=config.max_text_len, min_audio_len=config.min_audio_len, max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, compute_f0=config.compute_f0, f0_cache_path=config.f0_cache_path, attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, verbose=verbose, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, ) # wait all the DDP process to be ready if num_gpus > 1: dist.barrier() # sort input sequences ascendingly by length dataset.preprocess_samples() # get samplers sampler = self.get_sampler(config, dataset, num_gpus) loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, shuffle=False, # shuffle is done in the dataset. drop_last=False, # setting this False might cause issues in AMP training. sampler=sampler, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=True, ) # get pitch mean and std self.pitch_mean = dataset.f0_dataset.mean self.pitch_std = dataset.f0_dataset.std return loader def get_criterion(self): return [VitsDiscriminatorLoss(self.config), DelightfulTTSLoss(self.config)] def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters. It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. Returns: List: optimizers. """ optimizer_disc = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc ) gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) optimizer_gen = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters ) return [optimizer_disc, optimizer_gen] def get_lr(self) -> List: """Set the initial learning rates for each optimizer. Returns: List: learning rates for each optimizer. """ return [self.config.lr_disc, self.config.lr_gen] def get_scheduler(self, optimizer) -> List: """Set the schedulers for each optimizer. Args: optimizer (List[`torch.optim.Optimizer`]): List of optimizers. Returns: List: Schedulers, one for each optimizer. """ scheduler_D = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) return [scheduler_D, scheduler_G] def on_epoch_end(self, trainer): # pylint: disable=unused-argument # stop updating mean and var # TODO: do the same for F0 self.energy_scaler.eval() @staticmethod def init_from_config( config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False ): # pylint: disable=unused-argument """Initiate model from config Args: config (ForwardTTSE2eConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. """ tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config.model_args, samples) ap = AudioProcessor.init_from_config(config=config) return DelightfulTTS(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager, ap=ap) def load_checkpoint(self, config, checkpoint_path, eval=False): """Load model from a checkpoint created by the 👟""" # pylint: disable=unused-argument, redefined-builtin state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training def get_state_dict(self): """Custom state dict of the model with all the necessary components for inference.""" save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict} if hasattr(self, "emb_g"): save_state["speaker_ids"] = self.speaker_manager.speaker_names if self.args.use_d_vector_file: # TODO: implement saving of d_vectors ... return save_state def save(self, config, checkpoint_path): """Save model to a file.""" save_state = self.get_state_dict(config, checkpoint_path) # pylint: disable=too-many-function-args save_state["pitch_mean"] = self.pitch_mean save_state["pitch_std"] = self.pitch_std torch.save(save_state, checkpoint_path) def on_train_step_start(self, trainer) -> None: """Enable the discriminator training based on `steps_to_start_discriminator` Args: trainer (Trainer): Trainer object. """ self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 self.train_disc = ( # pylint: disable=attribute-defined-outside-init trainer.total_steps_done >= self.config.steps_to_start_discriminator ) class DelightfulTTSLoss(nn.Module): def __init__(self, config): super().__init__() self.mse_loss = nn.MSELoss() self.mae_loss = nn.L1Loss() self.forward_sum_loss = ForwardSumLoss() self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) self.mel_loss_alpha = config.mel_loss_alpha self.aligner_loss_alpha = config.aligner_loss_alpha self.pitch_loss_alpha = config.pitch_loss_alpha self.energy_loss_alpha = config.energy_loss_alpha self.u_prosody_loss_alpha = config.u_prosody_loss_alpha self.p_prosody_loss_alpha = config.p_prosody_loss_alpha self.dur_loss_alpha = config.dur_loss_alpha self.char_dur_loss_alpha = config.char_dur_loss_alpha self.binary_alignment_loss_alpha = config.binary_align_loss_alpha self.vocoder_mel_loss_alpha = config.vocoder_mel_loss_alpha self.feat_loss_alpha = config.feat_loss_alpha self.gen_loss_alpha = config.gen_loss_alpha self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha @staticmethod def _binary_alignment_loss(alignment_hard, alignment_soft): """Binary loss that forces soft alignments to match the hard alignments as explained in `https://arxiv.org/pdf/2108.10447.pdf`. """ log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() return -log_sum / alignment_hard.sum() @staticmethod def feature_loss(feats_real, feats_generated): loss = 0 for dr, dg in zip(feats_real, feats_generated): for rl, gl in zip(dr, dg): rl = rl.float().detach() gl = gl.float() loss += torch.mean(torch.abs(rl - gl)) return loss * 2 @staticmethod def generator_loss(scores_fake): loss = 0 gen_losses = [] for dg in scores_fake: dg = dg.float() l = torch.mean((1 - dg) ** 2) gen_losses.append(l) loss += l return loss, gen_losses def forward( self, mel_output, mel_target, mel_lens, dur_output, dur_target, pitch_output, pitch_target, energy_output, energy_target, src_lens, waveform, waveform_hat, p_prosody_ref, p_prosody_pred, u_prosody_ref, u_prosody_pred, aligner_logprob, aligner_hard, aligner_soft, binary_loss_weight=None, feats_fake=None, feats_real=None, scores_fake=None, spec_slice=None, spec_slice_hat=None, skip_disc=False, ): """ Shapes: - mel_output: :math:`(B, C_mel, T_mel)` - mel_target: :math:`(B, C_mel, T_mel)` - mel_lens: :math:`(B)` - dur_output: :math:`(B, T_src)` - dur_target: :math:`(B, T_src)` - pitch_output: :math:`(B, 1, T_src)` - pitch_target: :math:`(B, 1, T_src)` - energy_output: :math:`(B, 1, T_src)` - energy_target: :math:`(B, 1, T_src)` - src_lens: :math:`(B)` - waveform: :math:`(B, 1, T_wav)` - waveform_hat: :math:`(B, 1, T_wav)` - p_prosody_ref: :math:`(B, T_src, 4)` - p_prosody_pred: :math:`(B, T_src, 4)` - u_prosody_ref: :math:`(B, 1, 256) - u_prosody_pred: :math:`(B, 1, 256) - aligner_logprob: :math:`(B, 1, T_mel, T_src)` - aligner_hard: :math:`(B, T_mel, T_src)` - aligner_soft: :math:`(B, T_mel, T_src)` - spec_slice: :math:`(B, C_mel, T_mel)` - spec_slice_hat: :math:`(B, C_mel, T_mel)` """ loss_dict = {} src_mask = sequence_mask(src_lens).to(mel_output.device) # (B, T_src) mel_mask = sequence_mask(mel_lens).to(mel_output.device) # (B, T_mel) dur_target.requires_grad = False mel_target.requires_grad = False pitch_target.requires_grad = False masked_mel_predictions = mel_output.masked_select(mel_mask[:, None]) mel_targets = mel_target.masked_select(mel_mask[:, None]) mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) p_prosody_ref = p_prosody_ref.detach() p_prosody_loss = 0.5 * self.mae_loss( p_prosody_ref.masked_select(src_mask.unsqueeze(-1)), p_prosody_pred.masked_select(src_mask.unsqueeze(-1)), ) u_prosody_ref = u_prosody_ref.detach() u_prosody_loss = 0.5 * self.mae_loss(u_prosody_ref, u_prosody_pred) duration_loss = self.mse_loss(dur_output, dur_target) pitch_output = pitch_output.masked_select(src_mask[:, None]) pitch_target = pitch_target.masked_select(src_mask[:, None]) pitch_loss = self.mse_loss(pitch_output, pitch_target) energy_output = energy_output.masked_select(src_mask[:, None]) energy_target = energy_target.masked_select(src_mask[:, None]) energy_loss = self.mse_loss(energy_output, energy_target) forward_sum_loss = self.forward_sum_loss(aligner_logprob, src_lens, mel_lens) total_loss = ( (mel_loss * self.mel_loss_alpha) + (duration_loss * self.dur_loss_alpha) + (u_prosody_loss * self.u_prosody_loss_alpha) + (p_prosody_loss * self.p_prosody_loss_alpha) + (pitch_loss * self.pitch_loss_alpha) + (energy_loss * self.energy_loss_alpha) + (forward_sum_loss * self.aligner_loss_alpha) ) if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight if binary_loss_weight: loss_dict["loss_binary_alignment"] = ( self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight ) else: loss_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss loss_dict["loss_aligner"] = self.aligner_loss_alpha * forward_sum_loss loss_dict["loss_mel"] = self.mel_loss_alpha * mel_loss loss_dict["loss_duration"] = self.dur_loss_alpha * duration_loss loss_dict["loss_u_prosody"] = self.u_prosody_loss_alpha * u_prosody_loss loss_dict["loss_p_prosody"] = self.p_prosody_loss_alpha * p_prosody_loss loss_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss loss_dict["loss_energy"] = self.energy_loss_alpha * energy_loss loss_dict["loss"] = total_loss # vocoder losses if not skip_disc: loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha loss_dict["vocoder_loss_feat"] = loss_feat loss_dict["vocoder_loss_gen"] = loss_gen loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.vocoder_mel_loss_alpha loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha loss_dict["vocoder_loss_mel"] = loss_mel loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_stft_sc + loss_stft_mg return loss_dict