TheComputerMan committed
Commit 520a0ed
1 Parent(s): 3edffb9

Upload MultiLayeredConv1d.py

Files changed (1):
  1. MultiLayeredConv1d.py +87 -0

MultiLayeredConv1d.py ADDED
@@ -0,0 +1,87 @@
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021

"""
Layer modules for the FFT block in FastSpeech (Feed-Forward Transformer).
"""

import torch


class MultiLayeredConv1d(torch.nn.Module):
    """
    Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed
    to replace the position-wise feed-forward network
    in a Transformer block, as introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """
        Initialize MultiLayeredConv1d module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.
        """
        super().__init__()
        self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
        self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        # Conv1d expects (B, C, T), so transpose around each convolution.
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
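
# Shape sketch (illustrative values, not from this repository): with
# in_chans=384, hidden_chans=1536 and kernel_size=3, an input of shape
# (B, T, in_chans) = (4, 100, 384) is mapped back to (4, 100, 384); the
# hidden width 1536 only appears between w_1 and w_2, and the odd kernel
# with padding (kernel_size - 1) // 2 keeps the sequence length T fixed.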

class Conv1dLinear(torch.nn.Module):
    """
    Conv1D + Linear for Transformer block.

    A variant of MultiLayeredConv1d that replaces the second conv layer with a linear layer.
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """
        Initialize Conv1dLinear module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.
        """
        super().__init__()
        self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
        self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        # Conv1d expects (B, C, T); the Linear layer acts on the last dim directly.
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x))
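
For reference, a minimal sketch of how these modules might be driven (the batch size, sequence length, and channel widths below are arbitrary assumptions, not values taken from this repository). Both modules map (B, T, in_chans) back to (B, T, in_chans) and differ only in how the second projection is applied:

import torch

# Illustrative sizes only.
batch, time, in_chans, hidden_chans = 4, 100, 384, 1536

ffn_conv = MultiLayeredConv1d(in_chans, hidden_chans, kernel_size=3, dropout_rate=0.1)
ffn_mixed = Conv1dLinear(in_chans, hidden_chans, kernel_size=3, dropout_rate=0.1)

x = torch.randn(batch, time, in_chans)   # (B, T, in_chans)
print(ffn_conv(x).shape)                 # torch.Size([4, 100, 384])
print(ffn_mixed(x).shape)                # torch.Size([4, 100, 384])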