TheComputerMan committed
Commit
9adcb78
1 Parent(s): 1e3e10b

Upload EncoderLayer.py

Files changed (1)
  1. EncoderLayer.py +144 -0
EncoderLayer.py ADDED
@@ -0,0 +1,144 @@
+# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
+#                Northwestern Polytechnical University (Pengcheng Guo)
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+# Adapted by Florian Lux 2021
+
+
+import torch
+from torch import nn
+
+from Layers.LayerNorm import LayerNorm
+
+
+class EncoderLayer(nn.Module):
+    """
+    Encoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to apply layer_norm before the first block.
+        concat_after (bool): Whether to concatenate the attention layer's input and output.
+            If True, an additional linear layer is applied,
+            i.e. x -> x + linear(concat(x, att(x))).
+            If False, no additional linear layer is applied, i.e. x -> x + att(x).
+
+    """
+
+    def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False):
+        super(EncoderLayer, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.conv_module = conv_module
+        self.norm_ff = LayerNorm(size)  # for the FNN module
+        self.norm_mha = LayerNorm(size)  # for the MHA module
+        if feed_forward_macaron is not None:
+            self.norm_ff_macaron = LayerNorm(size)
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+        if self.conv_module is not None:
+            self.norm_conv = LayerNorm(size)  # for the CNN module
+            self.norm_final = LayerNorm(size)  # for the final output of the block
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+
+    def forward(self, x_input, mask, cache=None):
+        """
+        Compute encoded features.
+
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+
+        # whether to use macaron style
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # multi-headed self-attention module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_mha(x)
+
+        if cache is None:
+            x_q = x
+        else:
+            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            x_q = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            mask = None if mask is None else mask[:, -1:, :]
+
+        if pos_emb is not None:
+            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+        else:
+            x_att = self.self_attn(x_q, x, x, mask)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, x_att), dim=-1)
+            x = residual + self.concat_linear(x_concat)
+        else:
+            x = residual + self.dropout(x_att)
+        if not self.normalize_before:
+            x = self.norm_mha(x)
+
+        # convolution module
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+            x = residual + self.dropout(self.conv_module(x))
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask
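
For reference, below is a minimal usage sketch of the layer added in this commit (not part of the committed file). It assumes EncoderLayer.py is importable alongside the repository's Layers package, and it uses two hypothetical stand-in modules, DummySelfAttention and DummyFeedForward, that only mimic the call signatures EncoderLayer expects (self_attn(query, key, value, mask) and feed_forward(x)); in actual use, the repository's MultiHeadedAttention / RelPositionMultiHeadedAttention, PositionwiseFeedForward, and ConvolutionModule instances would be passed instead. Note that supplying a feed_forward_macaron module switches the layer to the Conformer-style macaron structure, where both half feed-forward blocks are scaled by ff_scale = 0.5.

import torch
from torch import nn

from EncoderLayer import EncoderLayer  # the file added in this commit


class DummySelfAttention(nn.Module):
    # Hypothetical placeholder with the (query, key, value, mask) interface;
    # it is not real attention and ignores key, value, and mask.
    def __init__(self, size):
        super().__init__()
        self.proj = nn.Linear(size, size)

    def forward(self, query, key, value, mask=None):
        return self.proj(query)


class DummyFeedForward(nn.Module):
    # Hypothetical placeholder for a position-wise feed-forward block.
    def __init__(self, size, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(size, hidden), nn.ReLU(), nn.Linear(hidden, size))

    def forward(self, x):
        return self.net(x)


size = 64
layer = EncoderLayer(size=size,
                     self_attn=DummySelfAttention(size),
                     feed_forward=DummyFeedForward(size),
                     feed_forward_macaron=DummyFeedForward(size),  # enables macaron style, ff_scale = 0.5
                     conv_module=None,                             # no convolution module in this sketch
                     dropout_rate=0.1)

x = torch.randn(2, 50, size)        # (#batch, time, size)
mask = torch.ones(2, 1, 50).bool()  # passed through to self_attn; ignored by the placeholder above
out, out_mask = layer(x, mask)
print(out.shape)                    # torch.Size([2, 50, 64])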