Spaces:
Running
Running
File size: 3,853 Bytes
26925fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# -*- coding: utf-8 -*-
# Copyright 2020 MINH ANH (@dathudeptrai)
# MIT License (https://opensource.org/licenses/MIT)
"""Tensorflow Layer modules complatible with pytorch."""
import tensorflow as tf
class TFReflectionPad1d(tf.keras.layers.Layer):
    """Tensorflow ReflectionPad1d module.

    Pads axis 1 of a 4-D input on both sides; the other three axes are left
    untouched.
    """

    def __init__(self, padding_size, padding_mode="REFLECT"):
        """Initialize TFReflectionPad1d module.

        Args:
            padding_size (int): Number of elements added to each side of axis 1.
            padding_mode (str): Mode forwarded to ``tf.pad``: "CONSTANT",
                "REFLECT" or "SYMMETRIC". Defaults to "REFLECT", which was
                previously the hard-coded behavior.

        """
        super(TFReflectionPad1d, self).__init__()
        self.padding_size = padding_size
        self.padding_mode = padding_mode

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        # One [before, after] pair per axis; only axis 1 is padded.
        paddings = [
            [0, 0],
            [self.padding_size, self.padding_size],
            [0, 0],
            [0, 0],
        ]
        return tf.pad(x, paddings, self.padding_mode)
class TFConvTranspose1d(tf.keras.layers.Layer):
    """Tensorflow ConvTranspose1d module.

    Wraps a ``Conv2DTranspose`` whose kernel and stride are both 1 along the
    second spatial axis, so only the first spatial axis is upsampled.
    """

    def __init__(self, channels, kernel_size, stride, padding):
        """Initialize TFConvTranspose1d module.

        Args:
            channels (int): Number of output channels (filters).
            kernel_size (int): Kernel size along the first spatial axis.
            stride (int): Stride width along the first spatial axis.
            padding (str): Padding type ("same" or "valid").

        """
        super().__init__()
        # Size-1 second component keeps the 2-D op one-dimensional in effect.
        kernel_2d = (kernel_size, 1)
        stride_2d = (stride, 1)
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            channels,
            kernel_2d,
            strides=stride_2d,
            padding=padding,
        )

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        return self.conv1d_transpose(x)
class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module.

    Computes ``shortcut(x) + block(x)`` where ``block`` is
    activation -> reflection pad -> dilated conv -> activation -> 1x1 conv
    and ``shortcut`` is a 1x1 conv.
    """

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size of the dilated convolution. Must be
                odd so the padded convolution preserves the time length.
            channels (int): Number of channels.
            dilation (int): Dilation size.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Name of a ``tf.keras.layers`` activation
                layer class.
            nonlinear_activation_params (dict): Hyperparameters for activation
                function.
            padding (str): Accepted for interface compatibility but not used;
                the dilated convolution always uses "valid" padding after
                explicit reflection padding.

        Raises:
            ValueError: If ``kernel_size`` is even.

        """
        super(TFResidualStack, self).__init__()
        if (kernel_size - 1) % 2 != 0:
            raise ValueError(
                f"kernel_size must be odd to preserve the time length, got {kernel_size}."
            )

        activation_cls = getattr(tf.keras.layers, nonlinear_activation)
        # A "valid" dilated conv shortens axis 1 by dilation * (kernel_size - 1);
        # pad half of that on each side so the output lines up with the
        # shortcut branch. For kernel_size == 3 this equals ``dilation``.
        pad_size = (kernel_size - 1) // 2 * dilation
        self.block = [
            activation_cls(**nonlinear_activation_params),
            TFReflectionPad1d(pad_size),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            activation_cls(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias),
        ]
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        # Layers return new tensors, so no explicit copy of ``x`` is needed.
        residual = x
        for layer in self.block:
            residual = layer(residual)
        shortcut = self.shortcut(x)
        return shortcut + residual
|