victan committed
Commit e659968
1 Parent(s): 5ec225a

Upload seamless_communication/models/monotonic_decoder/monotonic_decoder.py with huggingface_hub

seamless_communication/models/monotonic_decoder/monotonic_decoder.py ADDED
@@ -0,0 +1,98 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ from typing import Iterable, List, Optional, Tuple, final
+
+ import torch
+ from fairseq2.nn.incremental_state import IncrementalStateBag
+ from fairseq2.nn.module_list import ModuleList
+ from fairseq2.nn.normalization import LayerNorm
+ from fairseq2.nn.padding import PaddingMask
+ from fairseq2.nn.transformer import (
+     AttentionMaskFactory,
+     CausalAttentionMaskFactory,
+     create_standard_layer_norm,
+ )
+ from fairseq2.typing import DataType, Device, finaloverride
+ from torch import Tensor
+ from torch.nn import Module
+
+ from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
+     MonotonicTransformerDecoderLayer,
+ )
+
+
+ @final
+ class MonotonicTransformerDecoder(Module):
+     """Represents a Monotonic Transformer decoder."""
+
+     model_dim: int
+     self_attn_mask_factory: AttentionMaskFactory
+     layers: ModuleList
+     layer_norm: LayerNorm
+
+     def __init__(
+         self,
+         layers: Iterable[MonotonicTransformerDecoderLayer],
+         *,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param layers:
+             The decoder layers.
+         """
+         super().__init__()
+
+         layer_list = ModuleList(layers)
+
+         if not layer_list:
+             raise ValueError("`layers` must be non-empty.")
+
+         self.model_dim = layer_list[0].model_dim
+
+         self.self_attn_mask_factory = CausalAttentionMaskFactory()
+
+         self.layers = layer_list
+
+         self.layer_norm = create_standard_layer_norm(
+             self.model_dim, device=device, dtype=dtype
+         )
+
+     @finaloverride
+     def forward(
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         encoder_output: Optional[Tensor] = None,
+         encoder_padding_mask: Optional[PaddingMask] = None,
+         *,
+         state_bag: Optional[IncrementalStateBag] = None,
+     ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
+         self_attn_mask = self.self_attn_mask_factory(
+             seqs, keys=seqs, training=self.training, state_bag=state_bag
+         )
+
+         p_choose_list: List[Tensor] = []
+
+         for layer in self.layers.drop_iter():
+             seqs, padding_mask, p_choose = layer(
+                 seqs,
+                 padding_mask,
+                 self_attn_mask,
+                 encoder_output,
+                 encoder_padding_mask,
+                 state_bag=state_bag,
+             )
+             p_choose_list.append(p_choose)
+
+         seqs = self.layer_norm(seqs)
+
+         p_choose = torch.cat(p_choose_list, dim=0)
+
+         p_choose = p_choose.flatten(0, 1)
+
+         return seqs, padding_mask, p_choose
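
For context, a minimal usage sketch of the uploaded module follows. It is not part of the file above: the name `layers` is assumed to be an iterable of MonotonicTransformerDecoderLayer instances built elsewhere (their constructor lives in monotonic_decoder_layer.py and is not shown in this diff), and the tensor shapes follow fairseq2's usual (batch, seq_len, model_dim) layout. It only illustrates the call shape of `forward`.

# Hypothetical usage sketch -- not part of the uploaded file.
# `layers` is assumed to be an iterable of MonotonicTransformerDecoderLayer
# instances constructed elsewhere (e.g. by the seamless_communication builders).
import torch

decoder = MonotonicTransformerDecoder(layers)

batch_size, tgt_len, src_len = 2, 8, 16

# Target-side inputs and encoder states, both of width `model_dim`.
seqs = torch.rand(batch_size, tgt_len, decoder.model_dim)
encoder_output = torch.rand(batch_size, src_len, decoder.model_dim)

# Padding masks are optional; passing None means no positions are padded.
seqs, padding_mask, p_choose = decoder(
    seqs,
    None,
    encoder_output=encoder_output,
    encoder_padding_mask=None,
)

# `p_choose` stacks each layer's monotonic attention probabilities,
# concatenated along dim 0 and then flattened over its first two
# dimensions, exactly as done at the end of `forward` above.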