victan committed on
Commit
abacfaa
1 Parent(s): c9852e4

Upload seamless_communication/models/monotonic_decoder/p_choose.py with huggingface_hub

seamless_communication/models/monotonic_decoder/p_choose.py ADDED
@@ -0,0 +1,148 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ from typing import Optional, final
+
+ import torch
+ from fairseq2.nn.projection import Linear
+ from fairseq2.typing import DataType, Device, finaloverride
+ from torch import Tensor
+ from torch.nn import AvgPool1d, Module, ModuleList, ReLU
+ from torch.nn.parameter import Parameter
+
+
+ class EnergyProjection(Module):
+     def __init__(
+         self,
+         model_dim: int,
+         num_layers: int,
+         bias: bool = True,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         super().__init__()
+
+         if num_layers < 1:
+             raise ValueError(
+                 f"Invalid `num_layers`: {num_layers} for EnergyProjection."
+             )
+
+         self.layers = ModuleList()
+
+         for _ in range(num_layers):
+             self.layers.append(
+                 Linear(model_dim, model_dim, bias, device=device, dtype=dtype)
+             )
+             self.layers.append(ReLU())
+
+     def forward(self, seqs: Tensor) -> Tensor:
+         for layer in self.layers:
+             seqs = layer(seqs)
+         return seqs
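Each Linear in this stack maps model_dim back to model_dim and ReLU is elementwise, so EnergyProjection preserves the shape of its input. A minimal sketch to check this, assuming the module above is importable; the shapes are illustrative, not from the upload:

    import torch

    proj = EnergyProjection(model_dim=8, num_layers=2)
    x = torch.randn(3, 5, 8)           # (N, S, M)
    assert proj(x).shape == (3, 5, 8)  # shape is preserved through every layer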
+
+
+ @final
+ class PChooseLayer(Module):
+     """Represents a PChoose layer."""
+
+     model_dim: int
+     num_heads: int
+     energy_bias: Parameter
+     monotonic_temperature: float
+     q_energy_proj: EnergyProjection
+     k_energy_proj: EnergyProjection
+     keys_pooling: AvgPool1d
+
+     def __init__(
+         self,
+         model_dim: int,
+         num_heads: int,
+         energy_bias_value: float,
+         monotonic_temperature: float,
+         num_monotonic_energy_layers: int,
+         pre_decision_ratio: int,
+         *,
+         bias: bool = True,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param model_dim:
+             The dimensionality of the model.
+         :param num_heads:
+             The number of attention heads.
+         :param energy_bias_value:
+             The initial value of the learnable energy bias. If 0.0, no
+             bias term is used.
+         :param monotonic_temperature:
+             The temperature that divides the monotonic energy before the
+             sigmoid.
+         :param num_monotonic_energy_layers:
+             The number of layers in the query and key energy projections.
+         :param pre_decision_ratio:
+             The kernel size and stride of the average pooling applied to
+             the keys.
+         :param bias:
+             If ``True``, query, key energy projection layers learn an
+             additive bias.
+         """
+         super().__init__()
+
+         self.model_dim = model_dim
+         self.num_heads = num_heads
+
+         if energy_bias_value != 0.0:
+             self.energy_bias = Parameter(
+                 torch.full([1], energy_bias_value, device=device, dtype=dtype)
+             )
+         else:
+             self.register_module("energy_bias", None)
+
+         self.monotonic_temperature = monotonic_temperature
+
+         if num_monotonic_energy_layers <= 0:
+             raise ValueError("Number of monotonic energy layers must be > 0.")
+
+         self.q_energy_proj = EnergyProjection(
+             self.model_dim,
+             num_monotonic_energy_layers,
+             bias,
+             device=device,
+             dtype=dtype,
+         )
+         self.k_energy_proj = EnergyProjection(
+             self.model_dim,
+             num_monotonic_energy_layers,
+             bias,
+             device=device,
+             dtype=dtype,
+         )
+
+         self.keys_pooling = AvgPool1d(
+             kernel_size=pre_decision_ratio,
+             stride=pre_decision_ratio,
+             ceil_mode=True,
+         )
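With ceil_mode=True, the pooled key length works out to S_p = ceil(S_kv / pre_decision_ratio), so a trailing window that is only partially filled still produces one pooled step. A small sketch of that arithmetic; the ratio and sizes are made up for illustration:

    import math

    import torch
    from torch.nn import AvgPool1d

    ratio = 2  # hypothetical pre_decision_ratio
    pool = AvgPool1d(kernel_size=ratio, stride=ratio, ceil_mode=True)
    keys = torch.randn(1, 16, 7)  # (N, M, S_kv) with S_kv = 7
    assert pool(keys).shape[-1] == math.ceil(7 / ratio)  # S_p = 4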
+
+     @finaloverride
+     def forward(self, seqs: Tensor, keys: Tensor) -> Tensor:
+         q = self.q_energy_proj(seqs)
+
+         # (N, S, M) -> (N, H, S, K)
+         q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)
+
+         # (N, S_kv, M) -> (N, M, S_kv) -> (N, M, S_p)
+         pooled_keys = self.keys_pooling(keys.transpose(1, 2))
+
+         # (N, M, S_p) -> (N, S_p, M)
+         pooled_keys = pooled_keys.transpose(1, 2)
+
+         k = self.k_energy_proj(pooled_keys)
+
+         # (N, S_p, M) -> (N, H, S_p, K)
+         k = k.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)
+
+         # (N, H, S, K) @ (N, H, K, S_p) = (N, H, S, S_p)
+         monotonic_energy = torch.matmul(q, k.transpose(-1, -2))
+
+         monotonic_energy = monotonic_energy * (q.size(-1) ** -0.5)
+
+         if self.energy_bias is not None:
+             monotonic_energy += self.energy_bias
+
+         # p_choose: (N, H, S, S_p)
+         p_choose = torch.sigmoid(monotonic_energy / self.monotonic_temperature)
+
+         return p_choose
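End to end, queries of shape (N, S, M) and keys of shape (N, S_kv, M) yield p_choose of shape (N, H, S, ceil(S_kv / pre_decision_ratio)), with every entry in (0, 1) since it comes out of a sigmoid. A usage sketch; the hyperparameters below are illustrative, not the values shipped with the released checkpoints:

    import torch

    layer = PChooseLayer(
        model_dim=16,
        num_heads=4,
        energy_bias_value=-0.5,
        monotonic_temperature=0.2,
        num_monotonic_energy_layers=2,
        pre_decision_ratio=2,
    )
    seqs = torch.randn(2, 5, 16)  # decoder queries: (N, S, M)
    keys = torch.randn(2, 7, 16)  # encoder keys:    (N, S_kv, M)
    p_choose = layer(seqs, keys)
    assert p_choose.shape == (2, 4, 5, 4)  # (N, H, S, S_p), S_p = ceil(7 / 2)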