victan committed
Commit c9852e4
1 Parent(s): e659968

Upload seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py with huggingface_hub
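
The commit message indicates the file was pushed with huggingface_hub. For reference, a single-file upload of this kind can be produced with HfApi.upload_file roughly as sketched below; the local path and repo_id are illustrative placeholders, since the target repository is not identified on this page.

from huggingface_hub import HfApi

api = HfApi()

# Illustrative sketch only: the local path and repo_id are placeholders,
# not values taken from this commit page.
api.upload_file(
    path_or_fileobj="monotonic_decoder_layer.py",  # local copy of the file
    path_in_repo="seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py",
    repo_id="victan/<repo-name>",  # placeholder; the target repo is not shown here
    commit_message=(
        "Upload seamless_communication/models/monotonic_decoder/"
        "monotonic_decoder_layer.py with huggingface_hub"
    ),
)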

seamless_communication/models/monotonic_decoder/monotonic_decoder_layer.py ADDED
@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple, final
+
+from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.transformer import (
+    AttentionMask,
+    FeedForwardNetwork,
+    MultiheadAttention,
+    create_standard_layer_norm,
+)
+from fairseq2.typing import DataType, Device, finaloverride
+from torch import Tensor
+from torch.nn import Dropout, Module
+
+from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer
+
+
+@final
+class MonotonicTransformerDecoderLayer(Module):
+    """Represents a Monotonic Transformer decoder layer."""
+
+    self_attn: MultiheadAttention
+    self_attn_dropout: Optional[Dropout]
+    self_attn_layer_norm: LayerNorm
+    encoder_decoder_attn: MultiheadAttention
+    encoder_decoder_attn_dropout: Optional[Dropout]
+    encoder_decoder_attn_layer_norm: LayerNorm
+    p_choose_layer: PChooseLayer
+    ffn: FeedForwardNetwork
+    ffn_dropout: Optional[Dropout]
+    ffn_layer_norm: LayerNorm
+
+    def __init__(
+        self,
+        self_attn: MultiheadAttention,
+        encoder_decoder_attn: MultiheadAttention,
+        p_choose_layer: PChooseLayer,
+        ffn: FeedForwardNetwork,
+        *,
+        dropout_p: float = 0.1,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param self_attn:
+            The self attention layer.
+        :param encoder_decoder_attn:
+            The encoder-decoder attention layer.
+        :param ffn:
+            The feed-forward network.
+        :param dropout_p:
+            The dropout probability on outputs of the attention layers and the
+            feed-forward network.
+        """
+        super().__init__()
+
+        self.model_dim = self_attn.model_dim
+
+        self_attn_layer_norm = create_standard_layer_norm(
+            self.model_dim, device=device, dtype=dtype
+        )
+
+        self.self_attn_layer_norm = self_attn_layer_norm
+
+        self.self_attn = self_attn
+
+        if dropout_p > 0.0:
+            self.self_attn_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("self_attn_dropout", None)
+
+        encoder_decoder_attn_layer_norm = create_standard_layer_norm(
+            self.model_dim, device=device, dtype=dtype
+        )
+
+        self.encoder_decoder_attn_layer_norm = encoder_decoder_attn_layer_norm
+
+        self.encoder_decoder_attn = encoder_decoder_attn
+
+        if dropout_p > 0.0:
+            self.encoder_decoder_attn_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("encoder_decoder_attn_dropout", None)
+
+        self.p_choose_layer = p_choose_layer
+
+        ffn_layer_norm = create_standard_layer_norm(
+            self.model_dim, device=device, dtype=dtype
+        )
+
+        self.ffn_layer_norm = ffn_layer_norm
+
+        self.ffn = ffn
+
+        if dropout_p > 0.0:
+            self.ffn_dropout = Dropout(dropout_p)
+        else:
+            self.register_module("ffn_dropout", None)
+
+    @finaloverride
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask] = None,
+        encoder_output: Optional[Tensor] = None,
+        encoder_padding_mask: Optional[PaddingMask] = None,
+        *,
+        state_bag: Optional[IncrementalStateBag] = None,
+    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
+        seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask, state_bag)
+
+        seqs, p_choose = self._forward_encoder_decoder_attn(
+            seqs, padding_mask, encoder_output, encoder_padding_mask
+        )
+
+        seqs = self._forward_ffn(seqs)
+
+        return seqs, padding_mask, p_choose
+
+    def _forward_self_attn(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        self_attn_mask: Optional[AttentionMask],
+        state_bag: Optional[IncrementalStateBag],
+    ) -> Tensor:
+        residual = seqs
+
+        seqs = self.self_attn_layer_norm(seqs)
+
+        seqs = self.self_attn(
+            seqs,
+            padding_mask,
+            keys=seqs,
+            key_padding_mask=padding_mask,
+            values=seqs,
+            attn_mask=self_attn_mask,
+            state_bag=state_bag,
+        )
+
+        if self.self_attn_dropout is not None:
+            seqs = self.self_attn_dropout(seqs)
+
+        seqs = seqs + residual
+
+        return seqs
+
+    def _forward_encoder_decoder_attn(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        encoder_output: Optional[Tensor],
+        encoder_padding_mask: Optional[PaddingMask],
+    ) -> Tuple[Tensor, Tensor]:
+        if encoder_output is None:
+            raise ValueError(
+                "`encoder_output` must not be `None` for encoder-decoder attention."
+            )
+
+        residual = seqs
+
+        seqs = self.encoder_decoder_attn_layer_norm(seqs)
+
+        p_choose = self.p_choose_layer(seqs, encoder_output)
+
+        seqs = self.encoder_decoder_attn(
+            seqs,
+            padding_mask,
+            encoder_output,
+            encoder_padding_mask,
+            encoder_output,
+        )
+
+        if self.encoder_decoder_attn_dropout is not None:
+            seqs = self.encoder_decoder_attn_dropout(seqs)
+
+        seqs = seqs + residual
+
+        return seqs, p_choose
+
+    def _forward_ffn(self, seqs: Tensor) -> Tensor:
+        residual = seqs
+
+        seqs = self.ffn_layer_norm(seqs)
+
+        seqs = self.ffn(seqs)
+
+        if self.ffn_dropout is not None:
+            seqs = self.ffn_dropout(seqs)
+
+        seqs = seqs + residual
+
+        return seqs
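
For context, the sketch below exercises the layer's constructor and forward contract as defined in the file above. It is a minimal illustration, not part of the commit: in the real model the sub-modules are fairseq2's MultiheadAttention and FeedForwardNetwork plus this package's PChooseLayer, whose constructors are not shown in this file, so tiny stand-in modules are used here that only mimic the calls the layer actually makes (a model_dim attribute, attention/FFN calls returning a tensor shaped like seqs, and a p_choose call returning a probability tensor).

import torch
from torch import Tensor
from torch.nn import Module

from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
    MonotonicTransformerDecoderLayer,
)

MODEL_DIM = 16


class _FakeAttention(Module):
    """Stand-in exposing what the layer relies on: a `model_dim` attribute
    and a call that returns a tensor shaped like `seqs`."""

    def __init__(self) -> None:
        super().__init__()
        self.model_dim = MODEL_DIM

    def forward(self, seqs: Tensor, padding_mask, *args, **kwargs) -> Tensor:
        return seqs  # identity; a real MultiheadAttention returns attended seqs


class _FakePChoose(Module):
    """Stand-in for PChooseLayer: returns a toy probability tensor with one
    value per (target step, source step) pair; the real layer's output layout
    may differ."""

    def forward(self, seqs: Tensor, encoder_output: Tensor) -> Tensor:
        batch, tgt_len, _ = seqs.shape
        src_len = encoder_output.size(1)
        return torch.rand(batch, tgt_len, src_len)


class _FakeFFN(Module):
    def forward(self, seqs: Tensor) -> Tensor:
        return seqs  # identity; a real FeedForwardNetwork transforms seqs


layer = MonotonicTransformerDecoderLayer(
    _FakeAttention(),  # self-attention
    _FakeAttention(),  # encoder-decoder attention
    _FakePChoose(),    # monotonic p_choose probabilities
    _FakeFFN(),        # feed-forward network
    dropout_p=0.0,
)

seqs = torch.randn(2, 5, MODEL_DIM)            # (batch, tgt_len, model_dim)
encoder_output = torch.randn(2, 7, MODEL_DIM)  # (batch, src_len, model_dim)

# forward returns the updated sequences, the (unchanged) padding mask, and
# the p_choose tensor produced during encoder-decoder attention.
seqs, padding_mask, p_choose = layer(
    seqs,
    padding_mask=None,
    encoder_output=encoder_output,
    encoder_padding_mask=None,
)

print(seqs.shape, p_choose.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 5, 7])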