import logging

import torch
from trainer.io import load_fsspec

from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

logger = logging.getLogger(__name__)


class HifiDecoder(torch.nn.Module):
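    """HiFi-GAN based decoder that turns GPT latent frames into waveforms.

    The latents are linearly interpolated from the GPT frame rate (one frame
    per `ar_mel_length_compression` samples at `input_sample_rate`) to the
    vocoder frame rate (one frame per `output_hop_length` samples at
    `output_sample_rate`), then vocoded by a `HifiganGenerator` conditioned on
    a speaker d-vector. A `ResNetSpeakerEncoder` is bundled for computing that
    d-vector from reference audio.
    """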
    def __init__(
        self,
        input_sample_rate=22050,
        output_sample_rate=24000,
        output_hop_length=256,
        ar_mel_length_compression=1024,
        decoder_input_dim=1024,
        resblock_type_decoder="1",
        resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_kernel_sizes_decoder=[3, 7, 11],
        upsample_rates_decoder=[8, 8, 2, 2],
        upsample_initial_channel_decoder=512,
        upsample_kernel_sizes_decoder=[16, 16, 4, 4],
        d_vector_dim=512,
        cond_d_vector_in_each_upsampling_layer=True,
        speaker_encoder_audio_config={
            "fft_size": 512,
            "win_length": 400,
            "hop_length": 160,
            "sample_rate": 16000,
            "preemphasis": 0.97,
            "num_mels": 64,
        },
    ):
        super().__init__()
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length
        self.ar_mel_length_compression = ar_mel_length_compression
        self.speaker_encoder_audio_config = speaker_encoder_audio_config
        self.waveform_decoder = HifiganGenerator(
            decoder_input_dim,
            1,
            resblock_type_decoder,
            resblock_dilation_sizes_decoder,
            resblock_kernel_sizes_decoder,
            upsample_kernel_sizes_decoder,
            upsample_initial_channel_decoder,
            upsample_rates_decoder,
            inference_padding=0,
            cond_channels=d_vector_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
            cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
        )
        self.speaker_encoder = ResNetSpeakerEncoder(
            input_dim=64,
            proj_dim=512,
            log_input=True,
            use_torch_spec=True,
            audio_config=speaker_encoder_audio_config,
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, latents, g=None):
        """Decode GPT latents into a waveform.

        Args:
            latents (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor (speaker d-vector).

        Returns:
            Tensor: output waveform.

        Shapes:
            latents: [B, T, C]
            g: [B, C_cond, 1]
            Tensor: [B, 1, T_wav]
        """
        z = torch.nn.functional.interpolate(
            latents.transpose(1, 2),
            scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
            mode="linear",
        ).squeeze(1)
        # upsample to the right sample rate
        if self.output_sample_rate != self.input_sample_rate:
            z = torch.nn.functional.interpolate(
                z,
                scale_factor=[self.output_sample_rate / self.input_sample_rate],
                mode="linear",
            ).squeeze(0)
        o = self.waveform_decoder(z, g=g)
        return o

    def inference(self, c, g):
        """Inference entry point; equivalent to :meth:`forward`.

        Args:
            c (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor (speaker d-vector).

        Returns:
            Tensor: output waveform.

        Shapes:
            c: [B, T, C]
            Tensor: [B, 1, T_wav]
        """
        return self.forward(c, g=g)

    def load_checkpoint(self, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
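        """Load only the `waveform_decoder` and `speaker_encoder` weights from a full model checkpoint."""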
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        # keep only the waveform decoder and speaker encoder weights
        state = state["model"]
        states_keys = list(state.keys())
        for key in states_keys:
            if "waveform_decoder." not in key and "speaker_encoder." not in key:
                del state[key]

        self.load_state_dict(state)
        if eval:
            self.eval()
            assert not self.training
            self.waveform_decoder.remove_weight_norm()
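

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module).
# It assumes the default constructor arguments above and untrained weights;
# the random tensors stand in for GPT latents and a speaker d-vector. A batch
# of two is used because `forward` calls `.squeeze(0)` after resampling,
# which would drop a batch dimension of size one.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = HifiDecoder()

    latents = torch.randn(2, 100, 1024)  # [B, T, C]: 100 GPT latent frames each
    d_vector = torch.randn(2, 512, 1)  # speaker embedding, [B, d_vector_dim, 1]

    with torch.no_grad():
        wav = decoder.inference(latents, d_vector)

    # 100 frames * (1024 / 256) * (24000 / 22050), floored, * 256 samples/frame
    # = 435 * 256 = 111360 samples (~4.6 s at 24 kHz).
    print(wav.shape)  # torch.Size([2, 1, 111360])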