Spaces:
Running
Running
# -*- coding: utf-8 -*-
# Copyright 2020 MINH ANH (@dathudeptrai)
# MIT License (https://opensource.org/licenses/MIT)
"""Tensorflow Layer modules compatible with pytorch."""
import tensorflow as tf | |
class TFReflectionPad1d(tf.keras.layers.Layer):
    """Tensorflow ReflectionPad1d module.

    Reflect-pads the time axis (axis 1) of a 4-D tensor laid out as
    (B, T, 1, C).
    """

    def __init__(self, padding_size, **kwargs):
        """Initialize TFReflectionPad1d module.

        Args:
            padding_size (int): Number of frames added on each side of the
                time axis.
            **kwargs: Extra keyword arguments (e.g. ``name``) forwarded to
                ``tf.keras.layers.Layer``; also needed so ``from_config``
                can re-create the layer.

        """
        super(TFReflectionPad1d, self).__init__(**kwargs)
        self.padding_size = padding_size

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        # Only the time axis is padded; batch, the dummy width axis and the
        # channel axis are left untouched. tf.pad's "REFLECT" mode does not
        # repeat the edge sample, so padding_size must be smaller than T.
        return tf.pad(
            x,
            [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]],
            "REFLECT",
        )

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer."""
        config = super(TFReflectionPad1d, self).get_config()
        config.update({"padding_size": self.padding_size})
        return config
class TFConvTranspose1d(tf.keras.layers.Layer):
    """Tensorflow ConvTranspose1d module.

    Emulates a 1-D transposed convolution over the time axis of a
    (B, T, 1, C) tensor using ``Conv2DTranspose`` with kernel size and
    stride fixed to 1 along the dummy width axis.
    """

    def __init__(self, channels, kernel_size, stride, padding, **kwargs):
        """Initialize TFConvTranspose1d module.

        Args:
            channels (int): Number of output channels.
            kernel_size (int): Kernel size along the time axis.
            stride (int): Stride along the time axis.
            padding (str): Padding type ("same" or "valid").
            **kwargs: Extra keyword arguments (e.g. ``name``) forwarded to
                ``tf.keras.layers.Layer``; also needed so ``from_config``
                can re-create the layer.

        """
        super(TFConvTranspose1d, self).__init__(**kwargs)
        # Kept for get_config().
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=channels,
            kernel_size=(kernel_size, 1),
            strides=(stride, 1),
            padding=padding,
        )

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, channels).

        """
        return self.conv1d_transpose(x)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer."""
        config = super(TFConvTranspose1d, self).get_config()
        config.update({
            "channels": self.channels,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "padding": self.padding,
        })
        return config
class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module.

    Computes ``shortcut(x) + block(x)`` on (B, T, 1, C) tensors, where the
    block is: activation -> reflection pad -> dilated conv -> activation ->
    1x1 conv, and the shortcut is a 1x1 conv.
    """

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 **kwargs):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size of the dilated convolution. Must be
                odd so the time length can be preserved.
            channels (int): Number of channels.
            dilation (int): Dilation rate of the dilated convolution.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Name of an activation layer class in
                ``tf.keras.layers`` (e.g. "LeakyReLU").
            nonlinear_activation_params (dict): Keyword arguments passed to
                the activation layer constructor.
            padding (str): Padding type ("same" or "valid"). Currently unused:
                the block always reflection-pads and then applies a "valid"
                convolution. Accepted for interface compatibility.
            **kwargs: Extra keyword arguments (e.g. ``name``) forwarded to
                ``tf.keras.layers.Layer``.

        Raises:
            ValueError: If ``kernel_size`` is even.

        """
        if (kernel_size - 1) % 2 != 0:
            raise ValueError(
                "kernel_size must be odd, got {}.".format(kernel_size))
        super(TFResidualStack, self).__init__(**kwargs)
        # Kept for get_config().
        self.kernel_size = kernel_size
        self.channels = channels
        self.dilation = dilation
        self.use_bias = bias
        self.nonlinear_activation = nonlinear_activation
        self.nonlinear_activation_params = nonlinear_activation_params
        self.padding = padding

        self.block = [
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            # A "valid" conv with dilation d and kernel k shortens the time
            # axis by (k - 1) * d frames, so pad half of that on each side to
            # keep T unchanged (equals `dilation` when kernel_size == 3).
            TFReflectionPad1d((kernel_size - 1) // 2 * dilation),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias),
        ]
        # 1x1 conv on the input so the residual sum has `channels` channels.
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, channels).

        """
        out = x
        for layer in self.block:
            out = layer(out)
        return self.shortcut(x) + out

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer."""
        config = super(TFResidualStack, self).get_config()
        config.update({
            "kernel_size": self.kernel_size,
            "channels": self.channels,
            "dilation": self.dilation,
            "bias": self.use_bias,
            "nonlinear_activation": self.nonlinear_activation,
            # Copy into a plain dict (Keras may wrap dict attributes for
            # tracking).
            "nonlinear_activation_params": dict(self.nonlinear_activation_params),
            "padding": self.padding,
        })
        return config