victan committed on
Commit
d6ab6ec
1 Parent(s): 391c5a0

Upload seamless_communication/models/monotonic_decoder/builder.py with huggingface_hub

seamless_communication/models/monotonic_decoder/builder.py ADDED
@@ -0,0 +1,263 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional
+
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.transformer import (
+    TransformerEmbeddingFrontend,
+    TransformerFrontend,
+)
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
+from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
+from fairseq2.nn.projection import TiedProjection
+from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
+    MultiheadAttention,
+    StandardFeedForwardNetwork,
+    StandardMultiheadAttention,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
+from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
+    MonotonicTransformerDecoder,
+)
+from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
+    MonotonicTransformerDecoderLayer,
+)
+from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer
+
+
+@dataclass
+class MonotonicDecoderConfig:
+    """Holds the configuration of a Monotonic Decoder model."""
+
+    model_dim: int
+    """The dimensionality of the model."""
+
+    max_seq_len: int
+    """The expected maximum sequence length."""
+
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""
+
+    num_decoder_layers: int
+    """The number of Transformer decoder layers."""
+
+    num_decoder_attn_heads: int
+    """The number of attention heads in Transformer decoder layers."""
+
+    ffn_inner_dim: int
+    """The inner dimensionality of Transformer feed-forward networks."""
+
+    dropout_p: float
+    """The dropout probability in Transformer layers."""
+
+    energy_bias_value: float
+    """The value of the energy bias parameter to be added to the
+    monotonic energy in the PChooseLayer."""
+
+    monotonic_temperature: float
+    """The parameter with which to divide the monotonic energy
+    to compute p_choose."""
+
+    num_monotonic_energy_layers: int
+    """The number of layers in the EnergyProjection module."""
+
+    pre_decision_ratio: int
+    """The kernel size and stride of the average pooling
+    in the PChooseLayer."""
+
+
+monotonic_decoder_archs = ArchitectureRegistry[MonotonicDecoderConfig](
+    "monotonic_decoder"
+)
+
+monotonic_decoder_arch = monotonic_decoder_archs.decorator
+
+
+@monotonic_decoder_arch("dense_1b")
+def _dense_1b() -> MonotonicDecoderConfig:
+    return MonotonicDecoderConfig(
+        model_dim=1024,
+        max_seq_len=4096,
+        vocab_info=VocabularyInfo(
+            size=256102, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        ),
+        num_decoder_layers=24,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.1,
+        energy_bias_value=-0.5,
+        monotonic_temperature=0.2,
+        num_monotonic_energy_layers=4,
+        pre_decision_ratio=2,
+    )
+
+
+class MonotonicDecoderBuilder:
+    """Builds modules of a Monotonic Decoder.
+
+    To tweak the architecture, you can derive from this class and override the
+    corresponding methods.
+    """
+
+    config: MonotonicDecoderConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: MonotonicDecoderConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> MonotonicDecoderModel:
+        text_embed = self.build_embedding()
+
+        text_decoder_frontend = self.build_frontend(text_embed)
+
+        text_decoder = self.build_decoder()
+
+        final_proj = TiedProjection(text_embed.weight, bias=None)
+
+        return MonotonicDecoderModel(
+            text_decoder_frontend,
+            text_decoder,
+            final_proj,
+        )
+
+    def build_embedding(self) -> StandardEmbedding:
+        """Build an embedding table."""
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.vocab_info.pad_idx,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_frontend(self, embed: Embedding) -> TransformerFrontend:
+        """Build a Transformer decoder front-end."""
+        pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.max_seq_len,
+            _legacy_pad_idx=1,
+            device=self.device,
+        )
+
+        return TransformerEmbeddingFrontend(
+            embed,
+            pos_encoder,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder(self) -> MonotonicTransformerDecoder:
+        """Build a Transformer decoder."""
+        num_layers = self.config.num_decoder_layers
+
+        layers = [self.build_decoder_layer() for _ in range(num_layers)]
+
+        return MonotonicTransformerDecoder(
+            layers,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_layer(self) -> MonotonicTransformerDecoderLayer:
+        """Build a Transformer decoder layer."""
+        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        p_choose_layer = self.build_p_choose_layer(self.config.num_decoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return MonotonicTransformerDecoderLayer(
+            self_attn,
+            encoder_decoder_attn,
+            p_choose_layer,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_p_choose_layer(self, num_heads: int) -> PChooseLayer:
+        """Build a PChoose layer."""
+        return PChooseLayer(
+            self.config.model_dim,
+            num_heads,
+            self.config.energy_bias_value,
+            self.config.monotonic_temperature,
+            self.config.num_monotonic_energy_layers,
+            self.config.pre_decision_ratio,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_ffn(self) -> FeedForwardNetwork:
+        """Build a Transformer feed-forward network."""
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_monotonic_decoder_model(
+    config: MonotonicDecoderConfig,
+    *,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> MonotonicDecoderModel:
+    """Create a Monotonic Decoder model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+    return MonotonicDecoderBuilder(config, device=device, dtype=dtype).build_model()
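
For reference, a minimal usage sketch of this builder: it looks up the registered "dense_1b" configuration and constructs a randomly initialized model (no pretrained checkpoint is loaded). This assumes fairseq2 and seamless_communication are installed and that ArchitectureRegistry exposes a get_config method, as in recent fairseq2 releases; adjust device and dtype as needed.

    import torch

    from seamless_communication.models.monotonic_decoder.builder import (
        create_monotonic_decoder_model,
        monotonic_decoder_archs,
    )

    # Resolve the registered "dense_1b" architecture into a MonotonicDecoderConfig.
    config = monotonic_decoder_archs.get_config("dense_1b")

    # Build the model with freshly initialized parameters on CPU in float32.
    model = create_monotonic_decoder_model(
        config, device=torch.device("cpu"), dtype=torch.float32
    )
    model.eval()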