TheComputerMan commited on
Commit
daa42a1
1 Parent(s): 8aa300f

Upload PositionwiseFeedForward.py

Browse files
Files changed (1) hide show
  1. PositionwiseFeedForward.py +26 -0
PositionwiseFeedForward.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+
6
+ import torch
7
+
8
+
9
+ class PositionwiseFeedForward(torch.nn.Module):
10
+ """
11
+ Args:
12
+ idim (int): Input dimenstion.
13
+ hidden_units (int): The number of hidden units.
14
+ dropout_rate (float): Dropout rate.
15
+
16
+ """
17
+
18
+ def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
19
+ super(PositionwiseFeedForward, self).__init__()
20
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
21
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
22
+ self.dropout = torch.nn.Dropout(dropout_rate)
23
+ self.activation = activation
24
+
25
+ def forward(self, x):
26
+ return self.w_2(self.dropout(self.activation(self.w_1(x))))