victan commited on
Commit
563d86f
1 Parent(s): f4c6ff7

Upload seamless_communication/models/pretssel/ecapa_tdnn.py with huggingface_hub

Browse files
seamless_communication/models/pretssel/ecapa_tdnn.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairseq2.nn.padding import PaddingMask, to_padding_mask
12
+ from torch import Tensor
13
+ from torch.nn import Conv1d, LayerNorm, Module, ModuleList, ReLU, Sigmoid, Tanh, init
14
+
15
+
16
+ class ECAPA_TDNN(Module):
17
+ """
18
+ Represents the ECAPA-TDNN model described in paper:
19
+ :cite:t`https://doi.org/10.48550/arxiv.2005.07143`.
20
+
21
+ Arguments
22
+ ---------
23
+ :param channels:
24
+ Output channels for TDNN/SERes2Net layer.
25
+ :param kernel_sizes:
26
+ List of kernel sizes for each layer.
27
+ :param dilations:
28
+ List of dilations for kernels in each layer.
29
+ :param groups:
30
+ List of groups for kernels in each layer.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ channels: List[int],
36
+ kernel_sizes: List[int],
37
+ dilations: List[int],
38
+ attention_channels: int,
39
+ res2net_scale: int,
40
+ se_channels: int,
41
+ global_context: bool,
42
+ groups: List[int],
43
+ embed_dim: int,
44
+ input_dim: int,
45
+ ):
46
+ super().__init__()
47
+ assert len(channels) == len(kernel_sizes) == len(dilations)
48
+ self.channels = channels
49
+ self.embed_dim = embed_dim
50
+ self.blocks = ModuleList()
51
+
52
+ self.blocks.append(
53
+ TDNNBlock(
54
+ input_dim,
55
+ channels[0],
56
+ kernel_sizes[0],
57
+ dilations[0],
58
+ groups[0],
59
+ )
60
+ )
61
+
62
+ # SE-Res2Net layers
63
+ for i in range(1, len(channels) - 1):
64
+ self.blocks.append(
65
+ SERes2NetBlock(
66
+ channels[i - 1],
67
+ channels[i],
68
+ res2net_scale=res2net_scale,
69
+ se_channels=se_channels,
70
+ kernel_size=kernel_sizes[i],
71
+ dilation=dilations[i],
72
+ groups=groups[i],
73
+ )
74
+ )
75
+
76
+ # Multi-layer feature aggregation
77
+ self.mfa = TDNNBlock(
78
+ channels[-1],
79
+ channels[-1],
80
+ kernel_sizes[-1],
81
+ dilations[-1],
82
+ groups=groups[-1],
83
+ )
84
+
85
+ # Attentive Statistical Pooling
86
+ self.asp = AttentiveStatisticsPooling(
87
+ channels[-1],
88
+ attention_channels=attention_channels,
89
+ global_context=global_context,
90
+ )
91
+ self.asp_norm = LayerNorm(channels[-1] * 2, eps=1e-12)
92
+
93
+ # Final linear transformation
94
+ self.fc = Conv1d(
95
+ in_channels=channels[-1] * 2,
96
+ out_channels=embed_dim,
97
+ kernel_size=1,
98
+ )
99
+
100
+ self.reset_parameters()
101
+
102
+ def reset_parameters(self) -> None:
103
+ """Reset the parameters and buffers of the module."""
104
+
105
+ def encoder_init(m: Module) -> None:
106
+ if isinstance(m, Conv1d):
107
+ init.xavier_uniform_(m.weight, init.calculate_gain("relu"))
108
+
109
+ self.apply(encoder_init)
110
+
111
+ def forward(
112
+ self,
113
+ x: Tensor,
114
+ padding_mask: Optional[PaddingMask] = None,
115
+ ) -> Tensor:
116
+ """Returns the embedding vector.
117
+
118
+ Arguments
119
+ ---------
120
+ x : torch.Tensor
121
+ Tensor of shape (batch, time, channel).
122
+ """
123
+ # Minimize transpose for efficiency
124
+ x = x.transpose(1, 2)
125
+
126
+ xl = []
127
+ for layer in self.blocks:
128
+ x = layer(x, padding_mask=padding_mask)
129
+ xl.append(x)
130
+
131
+ # Multi-layer feature aggregation
132
+ x = torch.cat(xl[1:], dim=1)
133
+ x = self.mfa(x)
134
+
135
+ # Attentive Statistical Pooling
136
+ x = self.asp(x, padding_mask=padding_mask)
137
+ x = self.asp_norm(x.transpose(1, 2)).transpose(1, 2)
138
+
139
+ # Final linear transformation
140
+ x = self.fc(x)
141
+
142
+ x = x.transpose(1, 2).squeeze(1) # B x C
143
+ return F.normalize(x, dim=-1)
144
+
145
+
146
+ class TDNNBlock(Module):
147
+ """An implementation of TDNN.
148
+
149
+ Arguments
150
+ ----------
151
+ :param in_channels : int
152
+ Number of input channels.
153
+ :param out_channels : int
154
+ The number of output channels.
155
+ :param kernel_size : int
156
+ The kernel size of the TDNN blocks.
157
+ :param dilation : int
158
+ The dilation of the TDNN block.
159
+ :param groups: int
160
+ The groups size of the TDNN blocks.
161
+
162
+ Example
163
+ -------
164
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
165
+ >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
166
+ >>> out_tensor = layer(inp_tensor).transpose(1, 2)
167
+ >>> out_tensor.shape
168
+ torch.Size([8, 120, 64])
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ in_channels: int,
174
+ out_channels: int,
175
+ kernel_size: int,
176
+ dilation: int,
177
+ groups: int = 1,
178
+ ):
179
+ super().__init__()
180
+ self.conv = Conv1d(
181
+ in_channels=in_channels,
182
+ out_channels=out_channels,
183
+ kernel_size=kernel_size,
184
+ dilation=dilation,
185
+ padding=dilation * (kernel_size - 1) // 2,
186
+ groups=groups,
187
+ )
188
+ self.activation = ReLU()
189
+ self.norm = LayerNorm(out_channels, eps=1e-12)
190
+
191
+ def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
192
+ """Processes the input tensor x and returns an output tensor."""
193
+ x = self.activation(self.conv(x))
194
+
195
+ return self.norm(x.transpose(1, 2)).transpose(1, 2) # type: ignore[no-any-return]
196
+
197
+
198
+ class Res2NetBlock(Module):
199
+ """An implementation of Res2NetBlock w/ dilation.
200
+
201
+ Arguments
202
+ ---------
203
+ :param in_channels : int
204
+ The number of channels expected in the input.
205
+ :param out_channels : int
206
+ The number of output channels.
207
+ :param scale : int
208
+ The scale of the Res2Net block.
209
+ :param kernel_size: int
210
+ The kernel size of the Res2Net block.
211
+ :param dilation : int
212
+ The dilation of the Res2Net block.
213
+
214
+ Example
215
+ -------
216
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
217
+ >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
218
+ >>> out_tensor = layer(inp_tensor).transpose(1, 2)
219
+ >>> out_tensor.shape
220
+ torch.Size([8, 120, 64])
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ in_channels: int,
226
+ out_channels: int,
227
+ scale: int = 8,
228
+ kernel_size: int = 3,
229
+ dilation: int = 1,
230
+ ):
231
+ super().__init__()
232
+ assert in_channels % scale == 0
233
+ assert out_channels % scale == 0
234
+
235
+ in_channel = in_channels // scale
236
+ hidden_channel = out_channels // scale
237
+ self.blocks = ModuleList(
238
+ [
239
+ TDNNBlock(
240
+ in_channel,
241
+ hidden_channel,
242
+ kernel_size=kernel_size,
243
+ dilation=dilation,
244
+ )
245
+ for i in range(scale - 1)
246
+ ]
247
+ )
248
+ self.scale = scale
249
+
250
+ def forward(self, x: Tensor) -> Tensor:
251
+ """Processes the input tensor x and returns an output tensor."""
252
+ y = []
253
+ for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
254
+ if i == 0:
255
+ y_i = x_i
256
+ elif i == 1:
257
+ y_i = self.blocks[i - 1](x_i)
258
+ else:
259
+ y_i = self.blocks[i - 1](x_i + y_i)
260
+ y.append(y_i)
261
+
262
+ y_tensor = torch.cat(y, dim=1)
263
+ return y_tensor
264
+
265
+
266
+ class SEBlock(Module):
267
+ """An implementation of squeeze-and-excitation block.
268
+
269
+ Arguments
270
+ ---------
271
+ in_channels : int
272
+ The number of input channels.
273
+ se_channels : int
274
+ The number of output channels after squeeze.
275
+ out_channels : int
276
+ The number of output channels.
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ in_channels: int,
282
+ se_channels: int,
283
+ out_channels: int,
284
+ ):
285
+ super().__init__()
286
+
287
+ self.conv1 = Conv1d(
288
+ in_channels=in_channels, out_channels=se_channels, kernel_size=1
289
+ )
290
+ self.relu = ReLU(inplace=True)
291
+ self.conv2 = Conv1d(
292
+ in_channels=se_channels, out_channels=out_channels, kernel_size=1
293
+ )
294
+ self.sigmoid = Sigmoid()
295
+
296
+ def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
297
+ """Processes the input tensor x and returns an output tensor."""
298
+ if padding_mask is not None:
299
+ mask = padding_mask.materialize().unsqueeze(1)
300
+ s = (x * mask).sum(dim=2, keepdim=True) / padding_mask.seq_lens[
301
+ :, None, None
302
+ ]
303
+ else:
304
+ s = x.mean(dim=2, keepdim=True)
305
+
306
+ s = self.relu(self.conv1(s))
307
+ s = self.sigmoid(self.conv2(s))
308
+
309
+ return s * x
310
+
311
+
312
+ class AttentiveStatisticsPooling(Module):
313
+ """This class implements an attentive statistic pooling layer for each channel.
314
+ It returns the concatenated mean and std of the input tensor.
315
+
316
+ Arguments
317
+ ---------
318
+ channels: int
319
+ The number of input channels.
320
+ attention_channels: int
321
+ The number of attention channels.
322
+ """
323
+
324
+ def __init__(
325
+ self, channels: int, attention_channels: int = 128, global_context: bool = True
326
+ ):
327
+ super().__init__()
328
+
329
+ self.eps = 1e-12
330
+ self.global_context = global_context
331
+ if global_context:
332
+ self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
333
+ else:
334
+ self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
335
+
336
+ self.tanh = Tanh()
337
+ self.conv = Conv1d(
338
+ in_channels=attention_channels, out_channels=channels, kernel_size=1
339
+ )
340
+
341
+ def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
342
+ """Calculates mean and std for a batch (input tensor).
343
+
344
+ Arguments
345
+ ---------
346
+ x : torch.Tensor
347
+ Tensor of shape [N, C, L].
348
+ """
349
+ L = x.shape[-1]
350
+
351
+ def _compute_statistics(
352
+ x: Tensor, m: Tensor, dim: int = 2, eps: float = self.eps
353
+ ) -> Tuple[Tensor, Tensor]:
354
+ mean = (m * x).sum(dim)
355
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
356
+ return mean, std
357
+
358
+ # if lengths is None:
359
+ # lengths = [x.shape[0]]
360
+
361
+ # Make binary mask of shape [N, 1, L]
362
+ # mask = to_padding_mask(lengths, max(lengths))
363
+ if padding_mask is not None:
364
+ mask = padding_mask.materialize()
365
+ else:
366
+ mask = to_padding_mask(torch.IntTensor([L]), L).repeat(x.shape[0], 1).to(x)
367
+ mask = mask.unsqueeze(1)
368
+
369
+ # Expand the temporal context of the pooling layer by allowing the
370
+ # self-attention to look at global properties of the utterance.
371
+ if self.global_context:
372
+ # torch.std is unstable for backward computation
373
+ # https://github.com/pytorch/pytorch/issues/4320
374
+ total = mask.sum(dim=2, keepdim=True).to(x)
375
+ mean, std = _compute_statistics(x, mask / total)
376
+ mean = mean.unsqueeze(2).repeat(1, 1, L)
377
+ std = std.unsqueeze(2).repeat(1, 1, L)
378
+ attn = torch.cat([x, mean, std], dim=1)
379
+ else:
380
+ attn = x
381
+
382
+ # Apply layers
383
+ attn = self.conv(self.tanh(self.tdnn(attn)))
384
+
385
+ # Filter out zero-paddings
386
+ attn = attn.masked_fill(mask == 0, float("-inf"))
387
+
388
+ attn = F.softmax(attn, dim=2)
389
+ mean, std = _compute_statistics(x, attn)
390
+ # Append mean and std of the batch
391
+ pooled_stats = torch.cat((mean, std), dim=1)
392
+ pooled_stats = pooled_stats.unsqueeze(2)
393
+
394
+ return pooled_stats
395
+
396
+
397
+ class SERes2NetBlock(Module):
398
+ """An implementation of building block in ECAPA-TDNN, i.e.,
399
+ TDNN-Res2Net-TDNN-SEBlock.
400
+
401
+ Arguments
402
+ ----------
403
+ out_channels: int
404
+ The number of output channels.
405
+ res2net_scale: int
406
+ The scale of the Res2Net block.
407
+ kernel_size: int
408
+ The kernel size of the TDNN blocks.
409
+ dilation: int
410
+ The dilation of the Res2Net block.
411
+ groups: int
412
+ Number of blocked connections from input channels to output channels.
413
+
414
+ Example
415
+ -------
416
+ >>> x = torch.rand(8, 120, 64).transpose(1, 2)
417
+ >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
418
+ >>> out = conv(x).transpose(1, 2)
419
+ >>> out.shape
420
+ torch.Size([8, 120, 64])
421
+ """
422
+
423
+ def __init__(
424
+ self,
425
+ in_channels: int,
426
+ out_channels: int,
427
+ res2net_scale: int = 8,
428
+ se_channels: int = 128,
429
+ kernel_size: int = 1,
430
+ dilation: int = 1,
431
+ groups: int = 1,
432
+ ):
433
+ super().__init__()
434
+ self.out_channels = out_channels
435
+ self.tdnn1 = TDNNBlock(
436
+ in_channels,
437
+ out_channels,
438
+ kernel_size=1,
439
+ dilation=1,
440
+ groups=groups,
441
+ )
442
+ self.res2net_block = Res2NetBlock(
443
+ out_channels,
444
+ out_channels,
445
+ res2net_scale,
446
+ kernel_size,
447
+ dilation,
448
+ )
449
+ self.tdnn2 = TDNNBlock(
450
+ out_channels,
451
+ out_channels,
452
+ kernel_size=1,
453
+ dilation=1,
454
+ groups=groups,
455
+ )
456
+ self.se_block = SEBlock(out_channels, se_channels, out_channels)
457
+
458
+ self.shortcut = None
459
+ if in_channels != out_channels:
460
+ self.shortcut = Conv1d(
461
+ in_channels=in_channels,
462
+ out_channels=out_channels,
463
+ kernel_size=1,
464
+ )
465
+
466
+ def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
467
+ """Processes the input tensor x and returns an output tensor."""
468
+ residual = x
469
+ if self.shortcut:
470
+ residual = self.shortcut(x)
471
+
472
+ x = self.tdnn1(x)
473
+ x = self.res2net_block(x)
474
+ x = self.tdnn2(x)
475
+ x = self.se_block(x, padding_mask=padding_mask)
476
+
477
+ return x + residual