yangwang825 committed on
Commit
5920d4c
1 Parent(s): b7b7a53

Create modeling_ecapa.py

Files changed (1)
  1. modeling_ecapa.py +858 -0
modeling_ecapa.py ADDED
@@ -0,0 +1,858 @@
+ import math
+ import torch
+ import typing as tp
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.utils import ModelOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ from .helpers_ecapa import Fbank
+ from .configuration_ecapa import EcapaConfig
+
+
+ class InputNormalization(nn.Module):
+
+     spk_dict_mean: tp.Dict[int, torch.Tensor]
+     spk_dict_std: tp.Dict[int, torch.Tensor]
+     spk_dict_count: tp.Dict[int, int]
+
+     def __init__(
+         self,
+         mean_norm=True,
+         std_norm=True,
+         norm_type="global",
+         avg_factor=None,
+         requires_grad=False,
+         update_until_epoch=3,
+     ):
+         super().__init__()
+         self.mean_norm = mean_norm
+         self.std_norm = std_norm
+         self.norm_type = norm_type
+         self.avg_factor = avg_factor
+         self.requires_grad = requires_grad
+         self.glob_mean = torch.tensor([0])
+         self.glob_std = torch.tensor([0])
+         self.spk_dict_mean = {}
+         self.spk_dict_std = {}
+         self.spk_dict_count = {}
+         self.weight = 1.0
+         self.count = 0
+         self.eps = 1e-10
+         self.update_until_epoch = update_until_epoch
+
+     def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
+         """Normalizes the input tensor and returns it.
+         Arguments
+         ---------
+         input_values : tensor
+             A batch of tensors.
+         lengths : tensor
+             A batch of tensors containing the relative length of each
+             sentence (e.g., [0.7, 0.9, 1.0]). It is used to avoid
+             computing stats on zero-padded steps.
+         spk_ids : tensor containing the ids of each speaker (e.g., [0 10 6]).
+             It is used to perform per-speaker normalization when
+             norm_type='speaker'.
+         """
+         x = input_values
+         N_batches = x.shape[0]
+
+         if lengths is None:
+             # Assume no zero-padding when relative lengths are not provided
+             lengths = torch.ones(N_batches, device=x.device)
+
+         current_means = []
+         current_stds = []
+
+         for snt_id in range(N_batches):
+             # Avoiding padded time steps
+             # lengths = torch.sum(attention_mask, dim=1)
+             # relative_lengths = lengths / torch.max(lengths)
+             # actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
+             actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
+
+             # computing statistics
+             current_mean, current_std = self._compute_current_stats(
+                 x[snt_id, 0:actual_size, ...]
+             )
+
+             current_means.append(current_mean)
+             current_stds.append(current_std)
+
+             if self.norm_type == "sentence":
+                 x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
+
+             if self.norm_type == "speaker":
+                 spk_id = int(spk_ids[snt_id][0])
+
+                 if self.training:
+                     if spk_id not in self.spk_dict_mean:
+                         # Initialization of the dictionary
+                         self.spk_dict_mean[spk_id] = current_mean
+                         self.spk_dict_std[spk_id] = current_std
+                         self.spk_dict_count[spk_id] = 1
+
+                     else:
+                         self.spk_dict_count[spk_id] = (
+                             self.spk_dict_count[spk_id] + 1
+                         )
+
+                         if self.avg_factor is None:
+                             self.weight = 1 / self.spk_dict_count[spk_id]
+                         else:
+                             self.weight = self.avg_factor
+
+                         self.spk_dict_mean[spk_id] = (
+                             (1 - self.weight) * self.spk_dict_mean[spk_id]
+                             + self.weight * current_mean
+                         )
+                         self.spk_dict_std[spk_id] = (
+                             (1 - self.weight) * self.spk_dict_std[spk_id]
+                             + self.weight * current_std
+                         )
+
+                         self.spk_dict_mean[spk_id].detach()
+                         self.spk_dict_std[spk_id].detach()
+
+                     speaker_mean = self.spk_dict_mean[spk_id].data
+                     speaker_std = self.spk_dict_std[spk_id].data
+                 else:
+                     if spk_id in self.spk_dict_mean:
+                         speaker_mean = self.spk_dict_mean[spk_id].data
+                         speaker_std = self.spk_dict_std[spk_id].data
+                     else:
+                         speaker_mean = current_mean.data
+                         speaker_std = current_std.data
+
+                 x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
+
+         if self.norm_type == "batch" or self.norm_type == "global":
+             current_mean = torch.mean(torch.stack(current_means), dim=0)
+             current_std = torch.mean(torch.stack(current_stds), dim=0)
+
+             if self.norm_type == "batch":
+                 x = (x - current_mean.data) / (current_std.data)
+
+             if self.norm_type == "global":
+                 if self.training:
+                     if self.count == 0:
+                         self.glob_mean = current_mean
+                         self.glob_std = current_std
+
+                     elif epoch < self.update_until_epoch:
+                         if self.avg_factor is None:
+                             self.weight = 1 / (self.count + 1)
+                         else:
+                             self.weight = self.avg_factor
+
+                         self.glob_mean = (
+                             1 - self.weight
+                         ) * self.glob_mean + self.weight * current_mean
+
+                         self.glob_std = (
+                             1 - self.weight
+                         ) * self.glob_std + self.weight * current_std
+
+                     self.glob_mean.detach()
+                     self.glob_std.detach()
+
+                     self.count = self.count + 1
+
+                 x = (x - self.glob_mean.data) / (self.glob_std.data)
+
+         return x
+
+     def _compute_current_stats(self, x):
+         """Computes the current mean and std of the input tensor.
+         Arguments
+         ---------
+         x : tensor
+             A batch of tensors.
+         """
+         # Compute current mean
+         if self.mean_norm:
+             current_mean = torch.mean(x, dim=0).detach().data
+         else:
+             current_mean = torch.tensor([0.0], device=x.device)
+
+         # Compute current std
+         if self.std_norm:
+             current_std = torch.std(x, dim=0).detach().data
+         else:
+             current_std = torch.tensor([1.0], device=x.device)
+
+         # Improving numerical stability of std
+         current_std = torch.max(
+             current_std, self.eps * torch.ones_like(current_std)
+         )
+
+         return current_mean, current_std
+
+     def _statistics_dict(self):
+         """Fills the dictionary containing the normalization statistics."""
+         state = {}
+         state["count"] = self.count
+         state["glob_mean"] = self.glob_mean
+         state["glob_std"] = self.glob_std
+         state["spk_dict_mean"] = self.spk_dict_mean
+         state["spk_dict_std"] = self.spk_dict_std
+         state["spk_dict_count"] = self.spk_dict_count
+
+         return state
+
+     def _load_statistics_dict(self, state):
+         """Loads the dictionary containing the statistics.
+         Arguments
+         ---------
+         state : dict
+             A dictionary containing the normalization statistics.
+         """
+         self.count = state["count"]
+         if isinstance(state["glob_mean"], int):
+             self.glob_mean = state["glob_mean"]
+             self.glob_std = state["glob_std"]
+         else:
+             self.glob_mean = state["glob_mean"]  # .to(self.device_inp)
+             self.glob_std = state["glob_std"]  # .to(self.device_inp)
+
+         # Loading the spk_dict_mean in the right device
+         self.spk_dict_mean = {}
+         for spk in state["spk_dict_mean"]:
+             self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(
+                 self.device_inp
+             )
+
+         # Loading the spk_dict_std in the right device
+         self.spk_dict_std = {}
+         for spk in state["spk_dict_std"]:
+             self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(
+                 self.device_inp
+             )
+
+         self.spk_dict_count = state["spk_dict_count"]
+
+         return state
+
+     def to(self, device):
+         """Puts the needed tensors in the right device."""
+         self = super(InputNormalization, self).to(device)
+         self.glob_mean = self.glob_mean.to(device)
+         self.glob_std = self.glob_std.to(device)
+         for spk in self.spk_dict_mean:
+             self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
+             self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
+         return self
+
+
+ class TdnnLayer(nn.Module):
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         dilation=1,
+         stride=1,
+         groups=1,
+         padding=0,
+         padding_mode="reflect",
+         activation=torch.nn.LeakyReLU,
+     ):
+         super(TdnnLayer, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.dilation = dilation
+         self.stride = stride
+         self.groups = groups
+         self.padding = padding
+         self.padding_mode = padding_mode
+         self.activation = activation()
+
+         self.conv = nn.Conv1d(
+             self.in_channels,
+             self.out_channels,
+             self.kernel_size,
+             dilation=self.dilation,
+             padding=self.padding,
+             groups=self.groups
+         )
+
+         # Set Affine=false to be compatible with the original kaldi version
+         # self.ln = nn.LayerNorm(out_channels, elementwise_affine=False)
+         self.norm = nn.BatchNorm1d(out_channels, affine=False)
+
+     def forward(self, x):
+         x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
+         out = self.conv(x)
+         out = self.activation(out)
+         out = self.norm(out)
+         return out
+
+     def _manage_padding(
+         self, x, kernel_size: int, dilation: int, stride: int,
+     ):
+         # Detecting input shape
+         L_in = self.in_channels
+
+         # Time padding
+         padding = get_padding_elem(L_in, stride, kernel_size, dilation)
+
+         # Applying padding
+         x = F.pad(x, padding, mode=self.padding_mode)
+
+         return x
+
+
+ def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
+     """This function computes the number of elements to add for zero-padding.
+     Arguments
+     ---------
+     L_in : int
+     stride: int
+     kernel_size : int
+     dilation : int
+     """
+     if stride > 1:
+         padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
+
+     else:
+         L_out = (
+             math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
+         )
+         padding = [
+             math.floor((L_in - L_out) / 2),
+             math.floor((L_in - L_out) / 2),
+         ]
+     return padding
+
+
+ class Res2NetBlock(torch.nn.Module):
+     """An implementation of Res2NetBlock w/ dilation.
+     Arguments
+     ---------
+     in_channels : int
+         The number of channels expected in the input.
+     out_channels : int
+         The number of output channels.
+     scale : int
+         The scale of the Res2Net block.
+     kernel_size: int
+         The kernel size of the Res2Net block.
+     dilation : int
+         The dilation of the Res2Net block.
+     Example
+     -------
+     >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+     >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
+     >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+     >>> out_tensor.shape
+     torch.Size([8, 120, 64])
+     """
+
+     def __init__(
+         self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
+     ):
+         super(Res2NetBlock, self).__init__()
+         assert in_channels % scale == 0
+         assert out_channels % scale == 0
+
+         in_channel = in_channels // scale
+         hidden_channel = out_channels // scale
+
+         self.blocks = nn.ModuleList(
+             [
+                 TdnnLayer(
+                     in_channel,
+                     hidden_channel,
+                     kernel_size=kernel_size,
+                     dilation=dilation,
+                 )
+                 for _ in range(scale - 1)
+             ]
+         )
+         self.scale = scale
+
+     def forward(self, x):
+         """Processes the input tensor x and returns an output tensor."""
+         y = []
+         for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
+             if i == 0:
+                 y_i = x_i
+             elif i == 1:
+                 y_i = self.blocks[i - 1](x_i)
+             else:
+                 y_i = self.blocks[i - 1](x_i + y_i)
+             y.append(y_i)
+         y = torch.cat(y, dim=1)
+         return y
+
+
+ class SEBlock(nn.Module):
+     """An implementation of squeeze-and-excitation block.
+     Arguments
+     ---------
+     in_channels : int
+         The number of input channels.
+     se_channels : int
+         The number of output channels after squeeze.
+     out_channels : int
+         The number of output channels.
+     Example
+     -------
+     >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+     >>> se_layer = SEBlock(64, 16, 64)
+     >>> lengths = torch.rand((8,))
+     >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
+     >>> out_tensor.shape
+     torch.Size([8, 120, 64])
+     """
+
+     def __init__(self, in_channels, se_channels, out_channels):
+         super(SEBlock, self).__init__()
+
+         self.conv1 = nn.Conv1d(
+             in_channels=in_channels, out_channels=se_channels, kernel_size=1
+         )
+         self.relu = torch.nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv1d(
+             in_channels=se_channels, out_channels=out_channels, kernel_size=1
+         )
+         self.sigmoid = torch.nn.Sigmoid()
+
+     def forward(self, x, lengths=None):
+         """Processes the input tensor x and returns an output tensor."""
+         L = x.shape[-1]
+         if lengths is not None:
+             mask = length_to_mask(lengths * L, max_len=L, device=x.device)
+             mask = mask.unsqueeze(1)
+             total = mask.sum(dim=2, keepdim=True)
+             s = (x * mask).sum(dim=2, keepdim=True) / total
+         else:
+             s = x.mean(dim=2, keepdim=True)
+
+         s = self.relu(self.conv1(s))
+         s = self.sigmoid(self.conv2(s))
+
+         return s * x
+
+
+ def length_to_mask(length, max_len=None, dtype=None, device=None):
+     """Creates a binary mask for each sequence.
+     Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
+     Arguments
+     ---------
+     length : torch.LongTensor
+         Containing the length of each sequence in the batch. Must be 1D.
+     max_len : int
+         Max length for the mask, also the size of the second dimension.
+     dtype : torch.dtype, default: None
+         The dtype of the generated mask.
+     device: torch.device, default: None
+         The device to put the mask variable.
+     Returns
+     -------
+     mask : tensor
+         The binary mask.
+     Example
+     -------
+     >>> length=torch.Tensor([1,2,3])
+     >>> mask=length_to_mask(length)
+     >>> mask
+     tensor([[1., 0., 0.],
+             [1., 1., 0.],
+             [1., 1., 1.]])
+     """
+     assert len(length.shape) == 1
+
+     if max_len is None:
+         max_len = length.max().long().item()  # using arange to generate mask
+     mask = torch.arange(
+         max_len, device=length.device, dtype=length.dtype
+     ).expand(len(length), max_len) < length.unsqueeze(1)
+
+     if dtype is None:
+         dtype = length.dtype
+
+     if device is None:
+         device = length.device
+
+     mask = torch.as_tensor(mask, dtype=dtype, device=device)
+     return mask
+
+
+ class AttentiveStatisticsPooling(nn.Module):
+     """This class implements an attentive statistic pooling layer for each channel.
+     It returns the concatenated mean and std of the input tensor.
+     Arguments
+     ---------
+     channels: int
+         The number of input channels.
+     attention_channels: int
+         The number of attention channels.
+     Example
+     -------
+     >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+     >>> asp_layer = AttentiveStatisticsPooling(64)
+     >>> lengths = torch.rand((8,))
+     >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
+     >>> out_tensor.shape
+     torch.Size([8, 1, 128])
+     """
+
+     def __init__(self, channels, attention_channels=128, global_context=True):
+         super().__init__()
+
+         self.eps = 1e-12
+         self.global_context = global_context
+         if global_context:
+             self.tdnn = TdnnLayer(channels * 3, attention_channels, 1, 1)
+         else:
+             self.tdnn = TdnnLayer(channels, attention_channels, 1, 1)
+         self.tanh = nn.Tanh()
+         self.conv = nn.Conv1d(
+             in_channels=attention_channels, out_channels=channels, kernel_size=1
+         )
+
+     def forward(self, x, lengths=None):
+         """Calculates mean and std for a batch (input tensor).
+         Arguments
+         ---------
+         x : torch.Tensor
+             Tensor of shape [N, C, L].
+         """
+         L = x.shape[-1]
+
+         def _compute_statistics(x, m, dim=2, eps=self.eps):
+             mean = (m * x).sum(dim)
+             std = torch.sqrt(
+                 (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
+             )
+             return mean, std
+
+         if lengths is None:
+             lengths = torch.ones(x.shape[0], device=x.device)
+
+         # Make binary mask of shape [N, 1, L]
+         mask = length_to_mask(lengths * L, max_len=L, device=x.device)
+         mask = mask.unsqueeze(1)
+
+         # Expand the temporal context of the pooling layer by allowing the
+         # self-attention to look at global properties of the utterance.
+         if self.global_context:
+             # torch.std is unstable for backward computation
+             # https://github.com/pytorch/pytorch/issues/4320
+             total = mask.sum(dim=2, keepdim=True).float()
+             mean, std = _compute_statistics(x, mask / total)
+             mean = mean.unsqueeze(2).repeat(1, 1, L)
+             std = std.unsqueeze(2).repeat(1, 1, L)
+             attn = torch.cat([x, mean, std], dim=1)
+         else:
+             attn = x
+
+         # Apply layers
+         attn = self.conv(self.tanh(self.tdnn(attn)))
+
+         # Filter out zero-paddings
+         attn = attn.masked_fill(mask == 0, float("-inf"))
+
+         attn = F.softmax(attn, dim=2)
+         mean, std = _compute_statistics(x, attn)
+         # Append mean and std of the batch
+         pooled_stats = torch.cat((mean, std), dim=1)
+         pooled_stats = pooled_stats.unsqueeze(2)
+
+         return pooled_stats
+
+
+
+ class SERes2NetBlock(nn.Module):
+     """An implementation of building block in ECAPA-TDNN, i.e.,
+     TDNN-Res2Net-TDNN-SEBlock.
+     Arguments
+     ---------
+     in_channels: int
+         The number of input channels.
+     out_channels: int
+         The number of output channels.
+     res2net_scale: int
+         The scale of the Res2Net block.
+     kernel_size: int
+         The kernel size of the TDNN blocks.
+     dilation: int
+         The dilation of the Res2Net block.
+     activation : torch class
+         A class for constructing the activation layers.
+     groups: int
+         Number of blocked connections from input channels to output channels.
+     Example
+     -------
+     >>> x = torch.rand(8, 120, 64).transpose(1, 2)
+     >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
+     >>> out = conv(x).transpose(1, 2)
+     >>> out.shape
+     torch.Size([8, 120, 64])
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         res2net_scale=8,
+         se_channels=128,
+         kernel_size=1,
+         dilation=1,
+         activation=torch.nn.ReLU,
+         groups=1,
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.tdnn1 = TdnnLayer(
+             in_channels,
+             out_channels,
+             kernel_size=1,
+             dilation=1,
+             activation=activation,
+             groups=groups,
+         )
+         self.res2net_block = Res2NetBlock(
+             out_channels, out_channels, res2net_scale, kernel_size, dilation
+         )
+         self.tdnn2 = TdnnLayer(
+             out_channels,
+             out_channels,
+             kernel_size=1,
+             dilation=1,
+             activation=activation,
+             groups=groups,
+         )
+         self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+         self.shortcut = None
+         if in_channels != out_channels:
+             self.shortcut = nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+             )
+
+     def forward(self, x, lengths=None):
+         """Processes the input tensor x and returns an output tensor."""
+         residual = x
+         if self.shortcut:
+             residual = self.shortcut(x)
+
+         x = self.tdnn1(x)
+         x = self.res2net_block(x)
+         x = self.tdnn2(x)
+         x = self.se_block(x, lengths)
+
+         return x + residual
+
+
+ class EcapaEmbedder(nn.Module):
+
+     def __init__(
+         self,
+         in_channels=80,
+         hidden_size=192,
+         activation=torch.nn.ReLU,
+         channels=[512, 512, 512, 512, 1536],
+         kernel_sizes=[5, 3, 3, 3, 1],
+         dilations=[1, 2, 3, 4, 1],
+         attention_channels=128,
+         res2net_scale=8,
+         se_channels=128,
+         global_context=True,
+         groups=[1, 1, 1, 1, 1],
+     ) -> None:
+         super(EcapaEmbedder, self).__init__()
+         self.channels = channels
+         self.blocks = nn.ModuleList()
+
+         # The initial TDNN layer
+         self.blocks.append(
+             TdnnLayer(
+                 in_channels,
+                 channels[0],
+                 kernel_sizes[0],
+                 dilations[0],
+                 activation=activation,
+                 groups=groups[0],
+             )
+         )
+
+         # SE-Res2Net layers
+         for i in range(1, len(channels) - 1):
+             self.blocks.append(
+                 SERes2NetBlock(
+                     channels[i - 1],
+                     channels[i],
+                     res2net_scale=res2net_scale,
+                     se_channels=se_channels,
+                     kernel_size=kernel_sizes[i],
+                     dilation=dilations[i],
+                     activation=activation,
+                     groups=groups[i],
+                 )
+             )
+
+         # Multi-layer feature aggregation
+         self.mfa = TdnnLayer(
+             channels[-2] * (len(channels) - 2),
+             channels[-1],
+             kernel_sizes[-1],
+             dilations[-1],
+             activation=activation,
+             groups=groups[-1],
+         )
+
+         # Attentive Statistical Pooling
+         self.asp = AttentiveStatisticsPooling(
+             channels[-1],
+             attention_channels=attention_channels,
+             global_context=global_context,
+         )
+         self.asp_bn = nn.BatchNorm1d(channels[-1] * 2)
+
+         # Final linear transformation
+         self.fc = nn.Conv1d(
+             in_channels=channels[-1] * 2,
+             out_channels=hidden_size,
+             kernel_size=1,
+         )
+
+     def forward(self, input_values, lengths=None):
+         # Minimize transpose for efficiency
+         x = input_values.transpose(1, 2)
+         # lengths = torch.sum(attention_mask, dim=1)
+         # lengths = lengths / torch.max(lengths)
+
+         xl = []
+         for layer in self.blocks:
+             try:
+                 x = layer(x, lengths)
+             except TypeError:
+                 x = layer(x)
+             xl.append(x)
+
+         # Multi-layer feature aggregation
+         x = torch.cat(xl[1:], dim=1)
+         x = self.mfa(x)
+
+         # Attentive Statistical Pooling
+         x = self.asp(x, lengths)
+         x = self.asp_bn(x)
+
+         # Final linear transformation
+         x = self.fc(x)
+
+         pooler_output = x.transpose(1, 2)
+         pooler_output = pooler_output.squeeze(1)
+         return ModelOutput(
+             # last_hidden_state=last_hidden_state,
+             pooler_output=pooler_output
+         )
+
+
+ class CosineSimilarityHead(torch.nn.Module):
+     """
+     This class implements a cosine-similarity classification head on top of
+     the speaker embedding features.
+     """
+     def __init__(
+         self,
+         in_channels,
+         lin_blocks=0,
+         hidden_size=192,
+         num_classes=1211,
+     ):
+         super().__init__()
+         self.blocks = nn.ModuleList()
+
+         for block_index in range(lin_blocks):
+             self.blocks.extend(
+                 [
+                     nn.BatchNorm1d(num_features=in_channels),
+                     nn.Linear(in_features=in_channels, out_features=hidden_size),
+                 ]
+             )
+             in_channels = hidden_size
+
+         # Final Layer
+         self.weight = nn.Parameter(
+             torch.FloatTensor(num_classes, in_channels)
+         )
+         nn.init.xavier_uniform_(self.weight)
+
+     def forward(self, x):
+         """Returns the cosine similarity between the embeddings and each speaker class weight.
+         Arguments
+         ---------
+         x : torch.Tensor
+             Torch tensor.
+         """
+         for layer in self.blocks:
+             x = layer(x)
+
+         # Both the embeddings and the class weights need to be L2-normalized
+         x = F.linear(F.normalize(x), F.normalize(self.weight))
+         return x
+
+
+ class EcapaPreTrainedModel(PreTrainedModel):
+
+     config_class = EcapaConfig
+     base_model_prefix = "ecapa"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, nn.Conv1d):
+             nn.init.kaiming_normal_(module.weight.data)
+
+         if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
+             module.bias.data.zero_()
+
+
+ class EcapaModel(EcapaPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.compute_features = Fbank(
+             n_mels=config.n_mels,
+             sample_rate=config.sample_rate,
+             win_length=config.win_length,
+             hop_length=config.hop_length,
+         )
+         self.mean_var_norm = InputNormalization(
+             mean_norm=config.mean_norm,
+             std_norm=config.std_norm,
+             norm_type=config.norm_type
+         )
+         self.embedding_model = EcapaEmbedder(
+             in_channels=config.n_mels,
+             channels=config.channels,
+             kernel_sizes=config.kernel_sizes,
+             dilations=config.dilations,
+             attention_channels=config.attention_channels,
+             res2net_scale=config.res2net_scale,
+             se_channels=config.se_channels,
+             global_context=config.global_context,
+             groups=config.groups,
+             hidden_size=config.hidden_size
+         )
+
+     def forward(self, input_values, lengths=None):
+         x = input_values
+         # if attention_mask is None:
+         #     attention_mask = torch.ones_like(input_values, device=x.device)
+         x = self.compute_features(x)
+         x = self.mean_var_norm(x, lengths)
+         output = self.embedding_model(x, lengths)
+         return ModelOutput(
+             pooler_output=output.pooler_output,
+         )
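
For reference, a minimal usage sketch of the classes added in this commit (not part of the committed file). It assumes the companion modules referenced above, configuration_ecapa.py (EcapaConfig) and helpers_ecapa.py (Fbank), live in the same repository, that EcapaConfig() provides workable defaults for the fields read in EcapaModel.__init__ (n_mels, sample_rate, channels, hidden_size, ...), and that Fbank accepts a batch of raw waveforms shaped (batch, samples); the "ecapa" package name below is hypothetical, since modeling_ecapa.py uses relative imports and must be imported as part of a package (or loaded through trust_remote_code).

import torch

# Hypothetical package layout; adjust to wherever these files are installed.
from ecapa.configuration_ecapa import EcapaConfig
from ecapa.modeling_ecapa import EcapaModel, CosineSimilarityHead

config = EcapaConfig()            # assumed defaults, e.g. n_mels=80, hidden_size=192
model = EcapaModel(config)
model.eval()                      # disable BatchNorm/global-stat updates for inference

waveforms = torch.randn(2, 16000)            # two ~1 s utterances (sample_rate assumed 16 kHz)
lengths = torch.tensor([1.0, 0.8])           # relative lengths, as in the InputNormalization docstring

with torch.no_grad():
    outputs = model(input_values=waveforms, lengths=lengths)

embeddings = outputs.pooler_output           # speaker embeddings, shape (batch, config.hidden_size)

# Optional: score the embeddings against speaker classes with the cosine head.
head = CosineSimilarityHead(in_channels=config.hidden_size, num_classes=1211)
logits = head(embeddings)                    # cosine similarities, shape (batch, num_classes)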