akhaliq3
spaces demo
2b7bf83
raw
history blame
No virus
5 kB
# -*- coding: utf-8 -*-
# Copyright 2020 MINH ANH (@dathudeptrai)
# MIT License (https://opensource.org/licenses/MIT)
"""Tensorflow MelGAN modules complatible with pytorch."""
import tensorflow as tf
import numpy as np
from parallel_wavegan.layers.tf_layers import TFConvTranspose1d
from parallel_wavegan.layers.tf_layers import TFReflectionPad1d
from parallel_wavegan.layers.tf_layers import TFResidualStack
class TFMelGANGenerator(tf.keras.layers.Layer):
    """Tensorflow MelGAN generator module."""

    # Channel size baked into the ``input_signature`` of ``call``. Inputs whose
    # last dimension differs are rejected by ``tf.function`` at call time, so
    # ``__init__`` checks ``in_channels`` against it up front.
    # TODO(kan-bayashi): Fix hard coded dimension
    _SIGNATURE_IN_CHANNELS = 80

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        kernel_size=7,
        channels=512,
        bias=True,
        upsample_scales=None,
        stack_kernel_size=3,
        stacks=3,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params=None,
        pad="ReflectionPad1d",
        pad_params=None,
        use_final_nonlinear_activation=True,
        use_weight_norm=True,
        use_causal_conv=False,
    ):
        """Initialize TFMelGANGenerator module.

        Args:
            in_channels (int): Number of input channels. Must equal 80, the
                channel size hard-coded in the traced signature of ``call``.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of initial and final conv layer.
            channels (int): Initial number of channels for conv layer.
            bias (bool): Whether to add bias parameter in convolution layers.
            upsample_scales (list): List of upsampling scales
                (default: ``[8, 8, 2, 2]``).
            stack_kernel_size (int): Kernel size of dilated conv layers in
                residual stack.
            stacks (int): Number of stacks in a single residual stack.
            nonlinear_activation (str): Name of a ``tf.keras.layers`` activation
                layer class.
            nonlinear_activation_params (dict): Keyword arguments for the
                activation layer (default: ``{"alpha": 0.2}``).
            pad (str): Padding function module name before dilated convolution
                layer. Only ``"ReflectionPad1d"`` is supported.
            pad_params (dict): Hyperparameters for padding function. Accepted
                for parity with the PyTorch version; it is not forwarded to the
                padding layers here.
            use_final_nonlinear_activation (bool): Whether to append a ``tanh``
                activation after the final conv layer.
            use_weight_norm (bool): No effect but keep it as is to be the same
                as pytorch version.
            use_causal_conv (bool): Whether to use causal convolution. Only
                ``False`` is supported.

        Raises:
            AssertionError: If ``use_causal_conv`` is set, ``pad`` is not
                ``"ReflectionPad1d"``, or ``channels`` is incompatible with
                ``upsample_scales``.
            ValueError: If ``in_channels`` differs from the channel size of the
                traced ``call`` signature.
        """
        super().__init__()

        # Resolve defaults here so every instance gets its own containers
        # instead of sharing one mutable default object.
        if upsample_scales is None:
            upsample_scales = [8, 8, 2, 2]
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {"alpha": 0.2}

        # check hyper parameters is valid
        assert not use_causal_conv, "Not supported yet."
        assert channels >= np.prod(upsample_scales)
        # Every upsampling stage halves the channel count, so it must divide.
        assert channels % (2 ** len(upsample_scales)) == 0
        assert pad == "ReflectionPad1d", f"Not supported (pad={pad})."
        if in_channels != self._SIGNATURE_IN_CHANNELS:
            # Fail at construction instead of on every call with an opaque
            # tf.function signature mismatch.
            raise ValueError(
                f"in_channels={in_channels} is not supported; the traced call "
                f"signature expects {self._SIGNATURE_IN_CHANNELS} input channels."
            )

        # Initial layer: reflection padding followed by a (kernel_size x 1)
        # Conv2D. The input carries a dummy width axis (see ``call``), so 2D
        # convs with width-1 kernels act as 1D convs over time.
        layers = [
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                padding="valid",
                use_bias=bias,
            ),
        ]

        for i, upsample_scale in enumerate(upsample_scales):
            stage_channels = channels // (2 ** (i + 1))
            # Upsampling layer: activation, then transposed conv whose stride
            # equals the upsampling scale.
            # NOTE(review): `bias` is not forwarded to TFConvTranspose1d, unlike
            # the other conv layers; confirm its default matches the PyTorch side.
            layers += [
                getattr(tf.keras.layers, nonlinear_activation)(
                    **nonlinear_activation_params
                ),
                TFConvTranspose1d(
                    channels=stage_channels,
                    kernel_size=upsample_scale * 2,
                    stride=upsample_scale,
                    padding="same",
                ),
            ]
            # Residual stacks with dilation growing as stack_kernel_size ** j.
            layers += [
                TFResidualStack(
                    kernel_size=stack_kernel_size,
                    channels=stage_channels,
                    dilation=stack_kernel_size ** j,
                    bias=bias,
                    nonlinear_activation=nonlinear_activation,
                    nonlinear_activation_params=nonlinear_activation_params,
                    padding="same",
                )
                for j in range(stacks)
            ]

        # Final layer: activation, reflection padding, projection conv.
        layers += [
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(
                filters=out_channels, kernel_size=(kernel_size, 1), use_bias=bias
            ),
        ]
        if use_final_nonlinear_activation:
            layers.append(tf.keras.layers.Activation("tanh"))

        self.melgan = tf.keras.models.Sequential(layers)

    @tf.function(
        input_signature=[
            tf.TensorSpec(
                shape=[None, None, _SIGNATURE_IN_CHANNELS], dtype=tf.float32
            )
        ]
    )
    def call(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, T, in_channels).

        Returns:
            Tensor: Output tensor (B, T * prod(upsample_scales), out_channels).
        """
        # Insert a dummy width axis at position 2 so the Conv2D layers can run
        # over it; it is dropped again after the network.
        c = tf.expand_dims(c, 2)
        c = self.melgan(c)
        return c[:, :, 0, :]