# -*- coding: utf-8 -*-

# Copyright 2020 MINH ANH (@dathudeptrai)
#  MIT License (https://opensource.org/licenses/MIT)

"""Tensorflow MelGAN modules complatible with pytorch."""

import tensorflow as tf

import numpy as np

from parallel_wavegan.layers.tf_layers import TFConvTranspose1d
from parallel_wavegan.layers.tf_layers import TFReflectionPad1d
from parallel_wavegan.layers.tf_layers import TFResidualStack


class TFMelGANGenerator(tf.keras.layers.Layer):
    """Tensorflow MelGAN generator module."""

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        kernel_size=7,
        channels=512,
        bias=True,
        upsample_scales=[8, 8, 2, 2],
        stack_kernel_size=3,
        stacks=3,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        pad="ReflectionPad1d",
        pad_params={},
        use_final_nonlinear_activation=True,
        use_weight_norm=True,
        use_causal_conv=False,
    ):
        """Initialize TFMelGANGenerator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of initial and final conv layer.
            channels (int): Initial number of channels for conv layer.
            bias (bool): Whether to add bias parameter in convolution layers.
            upsample_scales (list): List of upsampling scales.
            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
            stacks (int): Number of stacks in a single residual stack.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_final_nonlinear_activation (bool): Whether to apply a final tanh activation.
            use_weight_norm (bool): No effect but keep it as is to be the same as pytorch version.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super(TFMelGANGenerator, self).__init__()

        # check that hyper parameters are valid
        assert not use_causal_conv, "Not supported yet."
        assert channels >= np.prod(upsample_scales)
        assert channels % (2 ** len(upsample_scales)) == 0
        assert pad == "ReflectionPad1d", f"Not supported (pad={pad})."

        # add initial layer
        layers = []
        layers += [
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                padding="valid",
                use_bias=bias,
            ),
        ]

        for i, upsample_scale in enumerate(upsample_scales):
            # add upsampling layer
            layers += [
                getattr(tf.keras.layers, nonlinear_activation)(
                    **nonlinear_activation_params
                ),
                TFConvTranspose1d(
                    channels=channels // (2 ** (i + 1)),
                    kernel_size=upsample_scale * 2,
                    stride=upsample_scale,
                    padding="same",
                ),
            ]

            # add residual stack
            for j in range(stacks):
                layers += [
                    TFResidualStack(
                        kernel_size=stack_kernel_size,
                        channels=channels // (2 ** (i + 1)),
                        dilation=stack_kernel_size ** j,
                        bias=bias,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        padding="same",
                    )
                ]

        # add final layer
        layers += [
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(
                filters=out_channels, kernel_size=(kernel_size, 1), use_bias=bias
            ),
        ]
        if use_final_nonlinear_activation:
            layers += [tf.keras.layers.Activation("tanh")]

        self.melgan = tf.keras.models.Sequential(layers)

    # TODO(kan-bayashi): Fix hard coded dimension
    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32)]
    )
    def call(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, T, in_channels).

        Returns:
            Tensor: Output tensor (B, T * prod(upsample_scales), out_channels).

        """
        c = tf.expand_dims(c, 2)
        c = self.melgan(c)
        return c[:, :, 0, :]
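

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): it assumes the
    # parallel_wavegan tf_layers imported above are available, and simply runs
    # a dummy mel spectrogram through the generator to illustrate the expected
    # input/output shapes.
    generator = TFMelGANGenerator()
    dummy_mels = tf.random.normal([1, 10, 80])  # (B, T, in_channels)
    audio = generator(dummy_mels)
    # With the default upsample_scales=[8, 8, 2, 2], each input frame is
    # upsampled by prod([8, 8, 2, 2]) = 256 samples, so 10 frames yield an
    # output of shape (1, 2560, 1).
    print(audio.shape)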