TheComputerMan committed
Commit af68200
1 Parent(s): f6fe944

Upload ResidualStack.py

Files changed (1)
ResidualStack.py +51 -0
ResidualStack.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright 2019 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+ # Adapted by Florian Lux 2021
+
+
+ import torch
+
+
+ class ResidualStack(torch.nn.Module):
+
+     def __init__(self, kernel_size=3, channels=32, dilation=1, bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2},
+                  pad="ReflectionPad1d", pad_params={}, ):
+         """
+         Initialize ResidualStack module.
+
+         Args:
+             kernel_size (int): Kernel size of the dilated convolution layer.
+             channels (int): Number of channels of the convolution layers.
+             dilation (int): Dilation factor.
+             bias (bool): Whether to add a bias parameter in the convolution layers.
+             nonlinear_activation (str): Activation function module name.
+             nonlinear_activation_params (dict): Hyperparameters for the activation function.
+             pad (str): Padding function module name applied before the dilated convolution layer.
+             pad_params (dict): Hyperparameters for the padding function.
+
+         """
+         super(ResidualStack, self).__init__()
+
+         # define the residual stack part
+         assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+         self.stack = torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                                          getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
+                                          torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
+                                          getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                                          torch.nn.Conv1d(channels, channels, 1, bias=bias), )
+
+         # define an extra 1x1 convolution for the skip connection
+         self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
+
+     def forward(self, c):
+         """
+         Calculate forward propagation.
+
+         Args:
+             c (Tensor): Input tensor (B, channels, T).
+
+         Returns:
+             Tensor: Output tensor (B, channels, T).
+
+         """
+         return self.stack(c) + self.skip_layer(c)