Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from torch.nn.utils.parametrizations import weight_norm | |
| from TTS.utils.io import load_fsspec | |
| from TTS.vocoder.layers.melgan import ResidualStack | |
| class MelganGenerator(nn.Module): | |
| def __init__( | |
| self, | |
| in_channels=80, | |
| out_channels=1, | |
| proj_kernel=7, | |
| base_channels=512, | |
| upsample_factors=(8, 8, 2, 2), | |
| res_kernel=3, | |
| num_res_blocks=3, | |
| ): | |
| super().__init__() | |
| # assert model parameters | |
| assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." | |
| # setup additional model parameters | |
| base_padding = (proj_kernel - 1) // 2 | |
| act_slope = 0.2 | |
| self.inference_padding = 2 | |
| # initial layer | |
| layers = [] | |
| layers += [ | |
| nn.ReflectionPad1d(base_padding), | |
| weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)), | |
| ] | |
| # upsampling layers and residual stacks | |
| for idx, upsample_factor in enumerate(upsample_factors): | |
| layer_in_channels = base_channels // (2**idx) | |
| layer_out_channels = base_channels // (2 ** (idx + 1)) | |
| layer_filter_size = upsample_factor * 2 | |
| layer_stride = upsample_factor | |
| layer_output_padding = upsample_factor % 2 | |
| layer_padding = upsample_factor // 2 + layer_output_padding | |
| layers += [ | |
| nn.LeakyReLU(act_slope), | |
| weight_norm( | |
| nn.ConvTranspose1d( | |
| layer_in_channels, | |
| layer_out_channels, | |
| layer_filter_size, | |
| stride=layer_stride, | |
| padding=layer_padding, | |
| output_padding=layer_output_padding, | |
| bias=True, | |
| ) | |
| ), | |
| ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel), | |
| ] | |
| layers += [nn.LeakyReLU(act_slope)] | |
| # final layer | |
| layers += [ | |
| nn.ReflectionPad1d(base_padding), | |
| weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)), | |
| nn.Tanh(), | |
| ] | |
| self.layers = nn.Sequential(*layers) | |
| def forward(self, c): | |
| return self.layers(c) | |
| def inference(self, c): | |
| c = c.to(self.layers[1].weight.device) | |
| c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") | |
| return self.layers(c) | |
| def remove_weight_norm(self): | |
| for _, layer in enumerate(self.layers): | |
| if len(layer.state_dict()) != 0: | |
| try: | |
| nn.utils.parametrize.remove_parametrizations(layer, "weight") | |
| except ValueError: | |
| layer.remove_weight_norm() | |
| def load_checkpoint( | |
| self, config, checkpoint_path, eval=False, cache=False | |
| ): # pylint: disable=unused-argument, redefined-builtin | |
| state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) | |
| self.load_state_dict(state["model"]) | |
| if eval: | |
| self.eval() | |
| assert not self.training | |
| self.remove_weight_norm() | |