# Copyright 2021 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""StyleMelGAN Modules."""

import copy
import logging
import math

import numpy as np
import torch
import torch.nn.functional as F

from parallel_wavegan.layers import PQMF
from parallel_wavegan.layers import TADEResBlock
from parallel_wavegan.models import MelGANDiscriminator as BaseDiscriminator
from parallel_wavegan.utils import read_hdf5


class StyleMelGANGenerator(torch.nn.Module):
    """Style MelGAN generator module."""

    def __init__(
        self,
        in_channels=128,
        aux_channels=80,
        channels=64,
        out_channels=1,
        kernel_size=9,
        dilation=2,
        bias=True,
        noise_upsample_scales=[11, 2, 2, 2],
        noise_upsample_activation="LeakyReLU",
        noise_upsample_activation_params={"negative_slope": 0.2},
        upsample_scales=[2, 2, 2, 2, 2, 2, 2, 2, 1],
        upsample_mode="nearest",
        gated_function="softmax",
        use_weight_norm=True,
    ):
        """Initialize Style MelGAN generator.

        Args:
            in_channels (int): Number of input noise channels.
            aux_channels (int): Number of auxiliary input channels.
            channels (int): Number of channels for conv layers.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of conv layers.
            dilation (int): Dilation factor for conv layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            noise_upsample_scales (list): List of noise upsampling scales.
            noise_upsample_activation (str): Activation function module name for noise upsampling.
            noise_upsample_activation_params (dict): Hyperparameters for the above activation function.
            upsample_scales (list): List of upsampling scales.
            upsample_mode (str): Upsampling mode in TADE layer.
            gated_function (str): Gated function in TADEResBlock ("softmax" or "sigmoid").
            use_weight_norm (bool): Whether to use weight norm. If set to true,
                it will be applied to all of the conv layers.

        """
        super().__init__()
        self.in_channels = in_channels

        noise_upsample = []
        in_chs = in_channels
        for noise_upsample_scale in noise_upsample_scales:
            # NOTE(kan-bayashi): How should we design noise upsampling part?
            noise_upsample += [
                torch.nn.ConvTranspose1d(
                    in_chs,
                    channels,
                    noise_upsample_scale * 2,
                    stride=noise_upsample_scale,
                    padding=noise_upsample_scale // 2 + noise_upsample_scale % 2,
                    output_padding=noise_upsample_scale % 2,
                    bias=bias,
                )
            ]
            noise_upsample += [
                getattr(torch.nn, noise_upsample_activation)(
                    **noise_upsample_activation_params
                )
            ]
            in_chs = channels
        self.noise_upsample = torch.nn.Sequential(*noise_upsample)
        self.noise_upsample_factor = np.prod(noise_upsample_scales)

        self.blocks = torch.nn.ModuleList()
        aux_chs = aux_channels
        for upsample_scale in upsample_scales:
            self.blocks += [
                TADEResBlock(
                    in_channels=channels,
                    aux_channels=aux_chs,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    bias=bias,
                    upsample_factor=upsample_scale,
                    upsample_mode=upsample_mode,
                    gated_function=gated_function,
                ),
            ]
            aux_chs = channels
        self.upsample_factor = np.prod(upsample_scales)

        self.output_conv = torch.nn.Sequential(
            torch.nn.Conv1d(
                channels,
                out_channels,
                kernel_size,
                1,
                bias=bias,
                padding=(kernel_size - 1) // 2,
            ),
            torch.nn.Tanh(),
        )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()
""" if z is None: z = torch.randn(c.size(0), self.in_channels, 1).to( device=c.device, dtype=c.dtype, ) x = self.noise_upsample(z) for block in self.blocks: x, c = block(x, c) x = self.output_conv(x) return x def remove_weight_norm(self): """Remove weight normalization module from all of the layers.""" def _remove_weight_norm(m): try: logging.debug(f"Weight norm is removed from {m}.") torch.nn.utils.remove_weight_norm(m) except ValueError: # this module didn't have weight norm return self.apply(_remove_weight_norm) def apply_weight_norm(self): """Apply weight normalization module from all of the layers.""" def _apply_weight_norm(m): if isinstance(m, torch.nn.Conv1d) or isinstance( m, torch.nn.ConvTranspose1d ): torch.nn.utils.weight_norm(m) logging.debug(f"Weight norm is applied to {m}.") self.apply(_apply_weight_norm) def reset_parameters(self): """Reset parameters.""" def _reset_parameters(m): if isinstance(m, torch.nn.Conv1d) or isinstance( m, torch.nn.ConvTranspose1d ): m.weight.data.normal_(0.0, 0.02) logging.debug(f"Reset parameters in {m}.") self.apply(_reset_parameters) def register_stats(self, stats): """Register stats for de-normalization as buffer. Args: stats (str): Path of statistics file (".npy" or ".h5"). """ assert stats.endswith(".h5") or stats.endswith(".npy") if stats.endswith(".h5"): mean = read_hdf5(stats, "mean").reshape(-1) scale = read_hdf5(stats, "scale").reshape(-1) else: mean = np.load(stats)[0].reshape(-1) scale = np.load(stats)[1].reshape(-1) self.register_buffer("mean", torch.from_numpy(mean).float()) self.register_buffer("scale", torch.from_numpy(scale).float()) logging.info("Successfully registered stats as buffer.") def inference(self, c, normalize_before=False): """Perform inference. Args: c (Union[Tensor, ndarray]): Input tensor (T, in_channels). normalize_before (bool): Whether to perform normalization. Returns: Tensor: Output tensor (T ** prod(upsample_scales), out_channels). """ if not isinstance(c, torch.Tensor): c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device) if normalize_before: c = (c - self.mean) / self.scale c = c.transpose(1, 0).unsqueeze(0) # prepare noise input noise_size = ( 1, self.in_channels, math.ceil(c.size(2) / self.noise_upsample_factor), ) noise = torch.randn(*noise_size, dtype=torch.float).to( next(self.parameters()).device ) x = self.noise_upsample(noise) # NOTE(kan-bayashi): To remove pop noise at the end of audio, perform padding # for feature sequence and after generation cut the generated audio. This # requires additional computation but it can prevent pop noise. total_length = c.size(2) * self.upsample_factor c = F.pad(c, (0, x.size(2) - c.size(2)), "replicate") # This version causes pop noise. # x = x[:, :, :c.size(2)] for block in self.blocks: x, c = block(x, c) x = self.output_conv(x)[..., :total_length] return x.squeeze(0).transpose(1, 0) class StyleMelGANDiscriminator(torch.nn.Module): """Style MelGAN disciminator module.""" def __init__( self, repeats=2, window_sizes=[512, 1024, 2048, 4096], pqmf_params=[ [1, None, None, None], [2, 62, 0.26700, 9.0], [4, 62, 0.14200, 9.0], [8, 62, 0.07949, 9.0], ], discriminator_params={ "out_channels": 1, "kernel_sizes": [5, 3], "channels": 16, "max_downsample_channels": 512, "bias": True, "downsample_scales": [4, 4, 4, 1], "nonlinear_activation": "LeakyReLU", "nonlinear_activation_params": {"negative_slope": 0.2}, "pad": "ReflectionPad1d", "pad_params": {}, }, use_weight_norm=True, ): """Initilize Style MelGAN discriminator. 

class StyleMelGANDiscriminator(torch.nn.Module):
    """Style MelGAN discriminator module."""

    def __init__(
        self,
        repeats=2,
        window_sizes=[512, 1024, 2048, 4096],
        pqmf_params=[
            [1, None, None, None],
            [2, 62, 0.26700, 9.0],
            [4, 62, 0.14200, 9.0],
            [8, 62, 0.07949, 9.0],
        ],
        discriminator_params={
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 16,
            "max_downsample_channels": 512,
            "bias": True,
            "downsample_scales": [4, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.2},
            "pad": "ReflectionPad1d",
            "pad_params": {},
        },
        use_weight_norm=True,
    ):
        """Initialize Style MelGAN discriminator.

        Args:
            repeats (int): Number of repetitions to apply RWD.
            window_sizes (list): List of random window sizes.
            pqmf_params (list): List of lists of parameters for PQMF modules.
            discriminator_params (dict): Parameters for base discriminator module.
            use_weight_norm (bool): Whether to apply weight normalization.

        """
        super().__init__()

        # window size check
        assert len(window_sizes) == len(pqmf_params)
        sizes = [ws // p[0] for ws, p in zip(window_sizes, pqmf_params)]
        assert len(window_sizes) == sum([sizes[0] == size for size in sizes])

        self.repeats = repeats
        self.window_sizes = window_sizes
        self.pqmfs = torch.nn.ModuleList()
        self.discriminators = torch.nn.ModuleList()
        for pqmf_param in pqmf_params:
            d_params = copy.deepcopy(discriminator_params)
            d_params["in_channels"] = pqmf_param[0]
            if pqmf_param[0] == 1:
                self.pqmfs += [torch.nn.Identity()]
            else:
                self.pqmfs += [PQMF(*pqmf_param)]
            self.discriminators += [BaseDiscriminator(**d_params)]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            List: List of discriminator outputs, #items in the list will be
                equal to repeats * #discriminators.

        """
        outs = []
        for _ in range(self.repeats):
            outs += self._forward(x)
        return outs

    def _forward(self, x):
        outs = []
        for idx, (ws, pqmf, disc) in enumerate(
            zip(self.window_sizes, self.pqmfs, self.discriminators)
        ):
            # NOTE(kan-bayashi): Is it ok to apply different window for real and fake samples?
            start_idx = np.random.randint(x.size(-1) - ws)
            x_ = x[:, :, start_idx : start_idx + ws]
            if idx == 0:
                x_ = pqmf(x_)
            else:
                x_ = pqmf.analysis(x_)
            outs += [disc(x_)]
        return outs

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(
                m, torch.nn.ConvTranspose1d
            ):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Reset parameters."""

        def _reset_parameters(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(
                m, torch.nn.ConvTranspose1d
            ):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)
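
# --- Example (not part of the original module) ----------------------------------
# A minimal sketch of how the random window discriminator is called, assuming the
# default settings above; the helper name and the dummy waveform are hypothetical.
# Each repeat draws a fresh random window per sub-discriminator, so the returned
# list has repeats * len(window_sizes) entries (2 * 4 = 8 with the defaults), each
# holding the per-layer outputs of one MelGANDiscriminator.
def _example_discriminator_usage():
    """Score a batch of (fake) waveforms (illustration only)."""
    discriminator = StyleMelGANDiscriminator()
    # The input must be longer than the largest analysis window (4096 samples).
    fake = torch.randn(2, 1, 8192)  # (B, 1, T)
    outs = discriminator(fake)
    assert len(outs) == 2 * 4
    return outs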