victan committed on
Commit 02db026
1 Parent(s): 5c71b2a

Upload seamless_communication/models/aligner/model.py with huggingface_hub

seamless_communication/models/aligner/model.py ADDED
@@ -0,0 +1,304 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ from typing import Any, List, Tuple, Union
+
+ import numpy as np
+ import numpy.typing as npt
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from fairseq2.data import CString
+ from fairseq2.nn.embedding import StandardEmbedding
+ from fairseq2.nn.padding import to_padding_mask
+ from fairseq2.typing import DataType
+ from torch import Tensor
+ from torch.nn import Module
+
+ from seamless_communication.models.unity.char_tokenizer import CharTokenizer
+ from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
+
+
+ class UnitY2AlignmentFrontend(Module):
+     def __init__(
+         self,
+         embed_text: StandardEmbedding,
+         embed_unit: StandardEmbedding,
+         text_tokenizer: CharTokenizer,
+         unit_tokenizer: UnitTokenizer,
+     ):
+         super().__init__()
+         self.embed_text = embed_text
+         self.embed_unit = embed_unit
+         self.text_tokenizer = text_tokenizer
+         self.unit_tokenizer = unit_tokenizer
+         unit_tokenizer.is_nar_decoder = True
+
+         self.encode_text = self.text_tokenizer.create_raw_encoder()
+         # The text decoder can be used to map aligned characters to words.
+         self.decode_text = self.text_tokenizer.create_decoder()
+         self.encode_unit = self.unit_tokenizer.create_encoder(lang="eng")
+
+     def tokenize_text(
+         self, text: str, return_tokens: bool = False, add_trailing_silence: bool = False
+     ) -> Tensor:
+         tokenized = self.encode_text(text)
+         if add_trailing_silence:
+             tokenized = torch.cat([tokenized, tokenized[0:1]])
+
+         return tokenized
+
+     def tokenize_text_to_tokens(
+         self, text: str, add_trailing_silence: bool = False
+     ) -> List[Union[CString, str]]:
+         tokenized = self.encode_text.encode_as_tokens(text)
+         if add_trailing_silence:
+             tokenized = tokenized + [tokenized[0]]
+
+         return tokenized
+
+     def tokenize_unit(self, units: Union[str, Tensor]) -> Tensor:
+         if isinstance(units, str):
+             units = torch.tensor([int(u) for u in units.split(" ")])
+         return self.encode_unit(units)
+
+     def forward(self, text: Tensor, unit: Tensor) -> Tuple[Any, Any]:
+         embs_unit = self.embed_unit(unit)
+         embs_text = self.embed_text(text)
+         return embs_text, embs_unit
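+
+ # Usage note: `tokenize_text` and `tokenize_unit` above produce the integer
+ # sequences that `UnitY2AlignmentFrontend.forward` embeds; `add_trailing_silence`
+ # appends a copy of the first token, which presumably encodes silence in the
+ # character vocabulary.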
+
+
+ class Permute12(nn.Module):
+     def forward(self, x: Tensor) -> Tensor:
+         return x.transpose(1, 2)
+
+
+ class UnitY2AlignmentEncoder(Module):
+     """
+     UnitY2 Aligner component
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         feat_dim: int,
+         text_layers: int,
+         feat_layers: int,
+         dropout: float,
+         temperature: float,
+         reduction_factor: int,
+         dtype: DataType,
+     ):
+         super().__init__()
+         self.temperature = temperature
+         self.reduction_factor = reduction_factor  # for unit
+
+         layers: List[Module] = [Permute12()]
+         for i in range(text_layers):
+             if i < text_layers - 1:
+                 layers.append(
+                     nn.Conv1d(
+                         embed_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                     )
+                 )
+                 layers.append(nn.ReLU())
+                 layers.append(nn.Dropout(p=dropout))
+             else:
+                 layers.append(
+                     nn.Conv1d(
+                         embed_dim, embed_dim, kernel_size=1, padding=0, dtype=dtype
+                     )
+                 )
+                 layers.append(nn.Dropout(p=dropout))
+                 layers.append(Permute12())
+         self.t_conv = nn.Sequential(*layers)
+
+         layers = [Permute12()]
+         input_dim = feat_dim
+         for i in range(feat_layers):
+             if i < feat_layers - 1:
+                 layers.append(
+                     nn.Conv1d(
+                         input_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                     )
+                 )
+                 layers.append(nn.ReLU())
+                 layers.append(nn.Dropout(p=dropout))
+             else:
+                 layers.append(
+                     nn.Conv1d(
+                         input_dim,
+                         embed_dim,
+                         kernel_size=1,
+                         padding=0,
+                         stride=reduction_factor,
+                         dtype=dtype,
+                     )
+                 )
+                 layers.append(nn.Dropout(p=dropout))
+                 layers.append(Permute12())
+             input_dim = embed_dim
+         self.f_conv = nn.Sequential(*layers)
+
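+     # `t_conv` and `f_conv` play symmetric roles: `t_conv` maps character
+     # embeddings and `f_conv` maps unit embeddings into a shared space, and
+     # the last `f_conv` layer strides by `reduction_factor` to downsample
+     # the unit sequence before pairwise distances are computed in `forward`.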
+     def forward(
+         self,
+         text_emb: Tensor,
+         feat_emb: Tensor,
+         text_lengths: Tensor,
+         feat_lengths: Tensor,
+     ) -> Tuple[Tensor, Tensor]:
+         """Compute alignment between sequences of text and feature embeddings
+
+         Args:
+             text_emb (Tensor): Batched text embedding (B, T_text, C).
+             feat_emb (Tensor): Batched acoustic feature (B, T_feat, feat_dim).
+             text_lengths (Tensor): Source text length (B,).
+             feat_lengths (Tensor): Target feature length (B,).
+
+         Returns:
+             Tensor: Log probability of attention matrix (B, T_feat, T_text).
+             Tensor: Unit durations of every text token (B, T_text).
+
+         """
+         _feat_lengths = feat_lengths.clone()
+         if self.reduction_factor > 1:
+             feat_lengths = torch.ceil(feat_lengths / self.reduction_factor).long()
+
+         text_emb = self.t_conv(text_emb)
+         feat_emb = self.f_conv(feat_emb)
+
+         dist = feat_emb.unsqueeze(2) - text_emb.unsqueeze(1)
+         dist = torch.norm(dist, p=2, dim=3)
+         score = -self.temperature * dist
+
+         padding_mask = ~(to_padding_mask(text_lengths, max(text_lengths)))
+         padding_mask = padding_mask.unsqueeze(-2)
+         score = score.masked_fill(padding_mask, -np.inf)
+
+         attn_lprob = F.log_softmax(score, dim=-1)
+
+         attn_hard_dur = viterbi_decode(attn_lprob, text_lengths, feat_lengths)
+
+         if self.reduction_factor > 1:
+             attn_hard_dur = self.postprocess_alignment(
+                 attn_hard_dur, text_lengths, _feat_lengths
+             )
+
+         return attn_lprob, attn_hard_dur
+
+     def postprocess_alignment(
+         self, attn_hard_dur: Tensor, text_lengths: Tensor, feat_lengths: Tensor
+     ) -> Tensor:
+         attn_hard_dur = attn_hard_dur * self.reduction_factor
+         B, T = attn_hard_dur.size()  # B x T_text
+         dur_cumsum = torch.cumsum(attn_hard_dur, dim=1)
+         for b in range(B):
+             for t in range(text_lengths[b]):
+                 # truncate the rightmost frames that overshoot feat_lengths[b]
+                 if dur_cumsum[b, t] >= feat_lengths[b]:
+                     if t == 0:
+                         attn_hard_dur[b, t] = feat_lengths[b]
+                     else:
+                         attn_hard_dur[b, t] = feat_lengths[b] - dur_cumsum[b, t - 1]
+                     if t < text_lengths[b] - 1:
+                         attn_hard_dur[b, t + 1 :] = 0
+                     break
+         return attn_hard_dur
+
+
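+ # Monotonic alignment search (MAS), following Glow-TTS
+ # (https://arxiv.org/abs/2005.11129): Q[i, j] holds the best accumulated log
+ # probability of any monotonic path that is on text token i at feature frame
+ # j, and the backtrace A assigns each frame to exactly one token.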
+ def _monotonic_alignment_search(
+     attn_lprob: npt.NDArray[np.float64],
+ ) -> npt.NDArray[np.float64]:
+     # https://arxiv.org/abs/2005.11129
+     T_feat = attn_lprob.shape[0]
+     T_text = attn_lprob.shape[1]
+     Q = np.full((T_text, T_feat), fill_value=-np.inf)
+
+     log_prob = attn_lprob.transpose(1, 0)  # -> (T_text, T_feat)
+     # 1. Q <- init first row for all j: token 0 absorbs frames 0..j
+     for j in range(T_feat):
+         Q[0, j] = log_prob[0, : j + 1].sum()
+
+     # 2. Recursion: at frame j, either stay on token i or advance from i - 1
+     for j in range(1, T_feat):
+         for i in range(1, min(j + 1, T_text)):
+             Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j]
+
+     # 3. Backtrace from the last frame, which must end on the last token
+     A = np.full((T_feat,), fill_value=T_text - 1)
+     for j in range(T_feat - 2, -1, -1):  # T_feat-2, ..., 0
+         # 'i' in {A[j+1]-1, A[j+1]}
+         i_a = A[j + 1] - 1
+         i_b = A[j + 1]
+         if i_b == 0:
+             argmax_i = 0
+         elif Q[i_a, j] >= Q[i_b, j]:
+             argmax_i = i_a
+         else:
+             argmax_i = i_b
+         A[j] = argmax_i
+     return A
+
+
+ def viterbi_decode(
+     attn_lprob: Tensor, text_lengths: Tensor, feat_lengths: Tensor
+ ) -> Tensor:
+     """Extract token durations from an attention probability matrix
+
+     Args:
+         attn_lprob (Tensor): Batched log probability of attention
+             matrix (B, T_feat, T_text).
+         text_lengths (Tensor): Text length tensor (B,).
+         feat_lengths (Tensor): Feature length tensor (B,).
+
+     Returns:
+         Tensor: Batched token durations extracted from `attn_lprob` (B, T_text).
+
+     """
+     B = attn_lprob.size(0)
+     T_text = attn_lprob.size(2)
+     device = attn_lprob.device
+
+     durations = torch.zeros((B, T_text), device=device, dtype=torch.long)
+     for b in range(B):
+         assert feat_lengths[b] > 0
+         assert text_lengths[b] > 0
+         cur_log_p_attn = attn_lprob[b, : feat_lengths[b], : text_lengths[b]]
+         viterbi = _monotonic_alignment_search(
+             cur_log_p_attn.float().detach().cpu().numpy()
+         )
+         _durations = np.bincount(viterbi)
+         durations[b, : len(_durations)] = torch.from_numpy(_durations).to(device)
+
+     return durations
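+
+ # Example: np.bincount on a backtrace of [0, 0, 1, 2, 2] over five frames
+ # yields durations [2, 1, 2]: two frames on token 0, one frame on token 1,
+ # and two frames on token 2.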
+
+
+ class UnitY2AlignmentModel(Module):
+     alignment_encoder: UnitY2AlignmentEncoder
+     alignment_frontend: UnitY2AlignmentFrontend
+
+     def __init__(
+         self,
+         alignment_frontend: UnitY2AlignmentFrontend,
+         alignment_encoder: UnitY2AlignmentEncoder,
+     ):
+         super().__init__()
+         self.alignment_frontend = alignment_frontend
+         self.alignment_encoder = alignment_encoder
+
+     def forward(self, input_text: Tensor, input_unit: Tensor) -> Tuple[Tensor, Tensor]:
+         assert input_text.ndim == 2
+         assert input_unit.ndim == 2
+         embs_text, embs_unit = self.alignment_frontend(input_text, input_unit)
+         attn_lprob, attn_hard_dur = self.alignment_encoder(
+             embs_text,
+             embs_unit,
+             torch.tensor([embs_text.size(1)]).to(embs_text).int(),
+             torch.tensor([embs_unit.size(1)]).to(embs_unit).int(),
+         )
+
+         return attn_lprob, attn_hard_dur
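
For orientation, here is a minimal usage sketch of the uploaded module (not part of the commit). It assumes `frontend` and `encoder` are a `UnitY2AlignmentFrontend` and a `UnitY2AlignmentEncoder` already built elsewhere with real tokenizers and pretrained weights; this diff does not construct them, and the sample text and units below are made up.

import torch

from seamless_communication.models.aligner.model import UnitY2AlignmentModel

# `frontend` and `encoder` are assumed prebuilt; see the class definitions
# in the diff above for their constructor arguments.
model = UnitY2AlignmentModel(frontend, encoder)
model.eval()

# Batch of one: characters of the text and its discrete acoustic units.
text_tokens = frontend.tokenize_text(
    "hello world", add_trailing_silence=True
).unsqueeze(0)  # (1, T_text)
unit_tokens = frontend.tokenize_unit("12 34 34 56").unsqueeze(0)  # (1, T_unit)

with torch.inference_mode():
    attn_lprob, durations = model(text_tokens, unit_tokens)

# durations[0, i] is the number of unit frames aligned to character i;
# attn_lprob has shape (1, T_feat, T_text).
print(durations)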