# -*- coding: utf-8 -*-

# Copyright 2020 MINH ANH (@dathudeptrai)
#  MIT License (https://opensource.org/licenses/MIT)

"""Tensorflow MelGAN modules compatible with pytorch."""

import numpy as np
import tensorflow as tf

from parallel_wavegan.layers.tf_layers import TFConvTranspose1d
from parallel_wavegan.layers.tf_layers import TFReflectionPad1d
from parallel_wavegan.layers.tf_layers import TFResidualStack


class TFMelGANGenerator(tf.keras.layers.Layer):
    """Tensorflow MelGAN generator module."""

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        kernel_size=7,
        channels=512,
        bias=True,
        upsample_scales=[8, 8, 2, 2],
        stack_kernel_size=3,
        stacks=3,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        pad="ReflectionPad1d",
        pad_params={},
        use_final_nonlinear_activation=True,
        use_weight_norm=True,
        use_causal_conv=False,
    ):
"""Initialize TFMelGANGenerator module. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
kernel_size (int): Kernel size of initial and final conv layer. | |
channels (int): Initial number of channels for conv layer. | |
bias (bool): Whether to add bias parameter in convolution layers. | |
upsample_scales (list): List of upsampling scales. | |
stack_kernel_size (int): Kernel size of dilated conv layers in residual stack. | |
stacks (int): Number of stacks in a single residual stack. | |
nonlinear_activation (str): Activation function module name. | |
nonlinear_activation_params (dict): Hyperparameters for activation function. | |
pad (str): Padding function module name before dilated convolution layer. | |
pad_params (dict): Hyperparameters for padding function. | |
use_final_nonlinear_activation (torch.nn.Module): Activation function for the final layer. | |
use_weight_norm (bool): No effect but keep it as is to be the same as pytorch version. | |
use_causal_conv (bool): Whether to use causal convolution. | |
""" | |
        super(TFMelGANGenerator, self).__init__()

        # check that hyper parameters are valid
        assert not use_causal_conv, "Not supported yet."
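        # channels are halved at every upsampling stage, so they must divide evenly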
        assert channels >= np.prod(upsample_scales)
        assert channels % (2 ** len(upsample_scales)) == 0
        assert pad == "ReflectionPad1d", f"Not supported (pad={pad})."

        # add initial layer
        layers = []
        layers += [
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                padding="valid",
                use_bias=bias,
            ),
        ]

        for i, upsample_scale in enumerate(upsample_scales):
            # add upsampling layer
            layers += [
                getattr(tf.keras.layers, nonlinear_activation)(
                    **nonlinear_activation_params
                ),
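                # kernel size is twice the stride, as in the original MelGAN upsampling blocks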
                TFConvTranspose1d(
                    channels=channels // (2 ** (i + 1)),
                    kernel_size=upsample_scale * 2,
                    stride=upsample_scale,
                    padding="same",
                ),
            ]

            # add residual stack
            for j in range(stacks):
                layers += [
                    TFResidualStack(
                        kernel_size=stack_kernel_size,
                        channels=channels // (2 ** (i + 1)),
                        dilation=stack_kernel_size ** j,
                        bias=bias,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        padding="same",
                    )
                ]

        # add final layer
        layers += [
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(
                filters=out_channels, kernel_size=(kernel_size, 1), use_bias=bias
            ),
        ]
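        # optionally squash the generated waveform into [-1, 1] with tanh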
        if use_final_nonlinear_activation:
            layers += [tf.keras.layers.Activation("tanh")]

        self.melgan = tf.keras.models.Sequential(layers)

    # TODO(kan-bayashi): Fix hard coded dimension
    def call(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, T, in_channels).

        Returns:
            Tensor: Output tensor (B, T * prod(upsample_scales), out_channels).

        """
        c = tf.expand_dims(c, 2)
        c = self.melgan(c)
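        # drop the dummy axis -> (B, T * prod(upsample_scales), out_channels)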
        return c[:, :, 0, :]
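

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # build a generator with the default hyperparameters and push a random
    # mel-spectrogram through it. The shapes below are assumptions for the demo only.
    batch_size, frames, num_mels = 2, 32, 80  # hypothetical input sizes
    generator = TFMelGANGenerator()
    mels = tf.random.normal([batch_size, frames, num_mels])
    audio = generator(mels)
    # the time axis should be upsampled by prod(upsample_scales) = 8 * 8 * 2 * 2 = 256
    print(audio.shape)  # expected: (2, 32 * 256, 1) = (2, 8192, 1)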