yangwang825 committed on
Commit 7f6d802
1 Parent(s): 8d61bc5

Create modeling_xvector.py

Files changed (1)
  1. modeling_xvector.py +550 -0
modeling_xvector.py ADDED
import math
import torch
import typing as tp
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .helpers_xvector import Fbank
from .configuration_xvector import XvectorConfig

class InputNormalization(nn.Module):

    spk_dict_mean: tp.Dict[int, torch.Tensor]
    spk_dict_std: tp.Dict[int, torch.Tensor]
    spk_dict_count: tp.Dict[int, int]

    def __init__(
        self,
        mean_norm=True,
        std_norm=True,
        norm_type="global",
        avg_factor=None,
        requires_grad=False,
        update_until_epoch=3,
    ):
        super().__init__()
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.avg_factor = avg_factor
        self.requires_grad = requires_grad
        self.glob_mean = torch.tensor([0])
        self.glob_std = torch.tensor([0])
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10
        self.update_until_epoch = update_until_epoch

    def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
        """Normalizes the input tensor and returns it.

        Arguments
        ---------
        input_values : tensor
            A batch of tensors.
        lengths : tensor
            A batch of tensors containing the relative length of each
            sentence (e.g., [0.7, 0.9, 1.0]). It is used to avoid
            computing stats on zero-padded steps.
        spk_ids : tensor
            Tensor containing the ids of each speaker (e.g., [0 10 6]).
            It is used to perform per-speaker normalization when
            norm_type='speaker'.
        """
        x = input_values
        N_batches = x.shape[0]

        # Treat every utterance as full-length when no lengths are provided.
        if lengths is None:
            lengths = torch.ones(N_batches, device=x.device)

        current_means = []
        current_stds = []

        for snt_id in range(N_batches):
            # Avoid padded time steps when computing statistics.
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()

            # Compute per-utterance statistics.
            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...]
            )

            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence":
                x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

            if self.norm_type == "speaker":
                spk_id = int(spk_ids[snt_id][0])

                if self.training:
                    if spk_id not in self.spk_dict_mean:
                        # Initialization of the dictionary
                        self.spk_dict_mean[spk_id] = current_mean
                        self.spk_dict_std[spk_id] = current_std
                        self.spk_dict_count[spk_id] = 1

                    else:
                        self.spk_dict_count[spk_id] = (
                            self.spk_dict_count[spk_id] + 1
                        )

                        if self.avg_factor is None:
                            self.weight = 1 / self.spk_dict_count[spk_id]
                        else:
                            self.weight = self.avg_factor

                        self.spk_dict_mean[spk_id] = (
                            (1 - self.weight) * self.spk_dict_mean[spk_id]
                            + self.weight * current_mean
                        )
                        self.spk_dict_std[spk_id] = (
                            (1 - self.weight) * self.spk_dict_std[spk_id]
                            + self.weight * current_std
                        )

                        # Detach in place so running stats stay out of the autograd graph.
                        self.spk_dict_mean[spk_id].detach_()
                        self.spk_dict_std[spk_id].detach_()

                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                else:
                    if spk_id in self.spk_dict_mean:
                        speaker_mean = self.spk_dict_mean[spk_id].data
                        speaker_std = self.spk_dict_std[spk_id].data
                    else:
                        speaker_mean = current_mean.data
                        speaker_std = current_std.data

                x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

        if self.norm_type == "batch" or self.norm_type == "global":
            current_mean = torch.mean(torch.stack(current_means), dim=0)
            current_std = torch.mean(torch.stack(current_stds), dim=0)

            if self.norm_type == "batch":
                x = (x - current_mean.data) / (current_std.data)

            if self.norm_type == "global":
                if self.training:
                    if self.count == 0:
                        self.glob_mean = current_mean
                        self.glob_std = current_std

                    elif epoch < self.update_until_epoch:
                        if self.avg_factor is None:
                            self.weight = 1 / (self.count + 1)
                        else:
                            self.weight = self.avg_factor

                        self.glob_mean = (
                            1 - self.weight
                        ) * self.glob_mean + self.weight * current_mean

                        self.glob_std = (
                            1 - self.weight
                        ) * self.glob_std + self.weight * current_std

                    # Detach in place so running stats stay out of the autograd graph.
                    self.glob_mean.detach_()
                    self.glob_std.detach_()

                    self.count = self.count + 1

                x = (x - self.glob_mean.data) / (self.glob_std.data)

        return x

    def _compute_current_stats(self, x):
        """Computes the mean and std of the input tensor.

        Arguments
        ---------
        x : tensor
            A batch of tensors.
        """
        # Compute current mean
        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach().data
        else:
            current_mean = torch.tensor([0.0], device=x.device)

        # Compute current std
        if self.std_norm:
            current_std = torch.std(x, dim=0).detach().data
        else:
            current_std = torch.tensor([1.0], device=x.device)

        # Improving numerical stability of std
        current_std = torch.max(
            current_std, self.eps * torch.ones_like(current_std)
        )

        return current_mean, current_std

    def _statistics_dict(self):
        """Fills the dictionary containing the normalization statistics."""
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count

        return state

    def _load_statistics_dict(self, state):
        """Loads the dictionary containing the statistics.

        Arguments
        ---------
        state : dict
            A dictionary containing the normalization statistics.
        """
        self.count = state["count"]
        if isinstance(state["glob_mean"], int):
            self.glob_mean = state["glob_mean"]
            self.glob_std = state["glob_std"]
        else:
            self.glob_mean = state["glob_mean"]
            self.glob_std = state["glob_std"]

        # Target device for the per-speaker statistics. The original code
        # referenced self.device_inp, which is never defined in this module,
        # so fall back to the device of the global statistics (CPU otherwise).
        device_inp = (
            self.glob_mean.device
            if isinstance(self.glob_mean, torch.Tensor)
            else torch.device("cpu")
        )

        # Loading the spk_dict_mean in the right device
        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(device_inp)

        # Loading the spk_dict_std in the right device
        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(device_inp)

        self.spk_dict_count = state["spk_dict_count"]

        return state

    def to(self, device):
        """Puts the needed tensors in the right device."""
        self = super(InputNormalization, self).to(device)
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
        return self

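# A minimal usage sketch for InputNormalization (illustrative values, not part of
# the original file): features are (batch, time, feat) and lengths are relative.
#
#     norm = InputNormalization(norm_type="global")
#     feats = torch.randn(2, 120, 24)          # e.g. 24-dim fbank features
#     rel_lens = torch.tensor([1.0, 0.7])      # second utterance is 30% padding
#     normalized = norm(feats, rel_lens)       # same shape as feats
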
class TdnnLayer(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        padding=0,
        padding_mode="reflect",
        activation=torch.nn.LeakyReLU,
    ):
        super(TdnnLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.activation = activation

        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
        )

        # Set affine=False to be compatible with the original Kaldi version
        # self.ln = nn.LayerNorm(out_channels, elementwise_affine=False)
        self.norm = nn.BatchNorm1d(out_channels, affine=False)

    def forward(self, x):
        x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        out = self.conv(x)
        out = self.activation()(out)
        out = self.norm(out)
        return out

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        # Reference length used to compute the amount of padding. For stride=1
        # the result depends only on kernel_size and dilation, so using the
        # channel count here still yields "same" padding.
        L_in = self.in_channels

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add for zero-padding.

    Arguments
    ---------
    L_in : int
    stride : int
    kernel_size : int
    dilation : int
    """
    if stride > 1:
        padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]

    else:
        L_out = (
            math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
        )
        padding = [
            math.floor((L_in - L_out) / 2),
            math.floor((L_in - L_out) / 2),
        ]
    return padding

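# For stride=1 the padding collapses to floor(dilation * (kernel_size - 1) / 2) on
# each side, which keeps the time dimension unchanged. A small illustrative check
# (not part of the original file):
#
#     get_padding_elem(100, 1, 5, 1)   # -> [2, 2]
#     get_padding_elem(100, 1, 3, 2)   # -> [2, 2]
#
#     layer = TdnnLayer(in_channels=24, out_channels=64, kernel_size=3, dilation=2)
#     y = layer(torch.randn(4, 24, 100))   # -> shape (4, 64, 100)
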
class StatisticsPooling(nn.Module):

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "Both statistics are set to False.\n"
                "Consider enabling mean and/or std statistics pooling."
            )

    def forward(self, input_values, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        input_values : torch.Tensor
            It represents a tensor for a mini-batch.
        """
        x = input_values
        if lengths is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            for snt_id in range(x.shape[0]):
                # Avoid padded time steps when computing statistics.
                actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))

                if self.return_mean:
                    mean.append(
                        torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
                    )
                if self.return_std:
                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            # Add a small amount of Gaussian noise to avoid degenerate statistics.
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)

        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon Gaussian noise.

        Arguments
        ---------
        shape_of_tensor : tensor
            It represents the size of tensor for generating Gaussian noise.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        # Rescale the normalized noise into the range [eps, 9 * eps].
        gnoise = self.eps * ((1 - 9) * gnoise + 9)

        return gnoise

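# A minimal shape check for StatisticsPooling (illustrative, not part of the
# original file): with both statistics enabled the output is (batch, 1, 2 * feat).
#
#     pool = StatisticsPooling()
#     stats = pool(torch.randn(2, 100, 1500))   # -> shape (2, 1, 3000)
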
class XvectorEmbedder(nn.Module):

    def __init__(
        self,
        in_channels=40,
        activation=torch.nn.LeakyReLU,
        tdnn_blocks=5,
        tdnn_channels=[512, 512, 512, 512, 1500],
        tdnn_kernel_sizes=[5, 3, 3, 1, 1],
        tdnn_dilations=[1, 2, 3, 1, 1],
        hidden_size=512,
    ) -> None:
        super(XvectorEmbedder, self).__init__()
        self.activation = activation
        self.blocks = nn.ModuleList()
        for block_index in range(tdnn_blocks):
            out_channels = tdnn_channels[block_index]
            tdnn = TdnnLayer(
                in_channels,
                out_channels,
                kernel_size=tdnn_kernel_sizes[block_index],
                dilation=tdnn_dilations[block_index],
                activation=activation,
            )
            self.blocks.append(tdnn)
            in_channels = tdnn_channels[block_index]
        self.pooler = StatisticsPooling()
        self.fc = nn.Linear(2 * out_channels, hidden_size)

    def forward(self, input_values, lengths=None):
        x = input_values
        x = x.permute(0, 2, 1)  # (B, T, F) -> (B, F, T)
        for block in self.blocks:
            x = block(x)
        last_hidden_state = x.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)
        pooler_output = self.pooler(last_hidden_state, lengths)
        pooler_output = self.fc(pooler_output.squeeze(1))
        return ModelOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )

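# A minimal usage sketch for XvectorEmbedder (illustrative values, not part of the
# original file): it consumes framewise features and returns a fixed-size embedding.
#
#     embedder = XvectorEmbedder(in_channels=24, hidden_size=512)
#     feats = torch.randn(2, 100, 24)                  # (batch, time, n_mels)
#     out = embedder(feats, torch.tensor([1.0, 1.0]))
#     out.last_hidden_state.shape                      # (2, 100, 1500)
#     out.pooler_output.shape                          # (2, 512)
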
class CosineSimilarityHead(torch.nn.Module):
    """
    Computes cosine similarities between L2-normalized input features and
    L2-normalized class weights, used as logits over speakers.
    """
    def __init__(
        self,
        in_channels,
        lin_blocks=0,
        hidden_size=192,
        num_classes=1211,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        for block_index in range(lin_blocks):
            self.blocks.extend(
                [
                    nn.BatchNorm1d(num_features=in_channels),
                    nn.Linear(in_features=in_channels, out_features=hidden_size),
                ]
            )
            in_channels = hidden_size

        # Final layer
        self.weight = nn.Parameter(
            torch.FloatTensor(num_classes, in_channels)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Returns the cosine similarity scores over speaker classes.

        Arguments
        ---------
        x : torch.Tensor
            Torch tensor.
        """
        for layer in self.blocks:
            x = layer(x)

        # Both the features and the class weights need to be normalized.
        x = F.linear(F.normalize(x), F.normalize(self.weight))
        return x

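# A minimal usage sketch for CosineSimilarityHead (illustrative, not part of the
# original file): scores lie in [-1, 1] and are typically combined with a margin
# loss (e.g. AAM-softmax) during speaker classification training.
#
#     head = CosineSimilarityHead(in_channels=512, num_classes=1211)
#     logits = head(torch.randn(2, 512))   # -> shape (2, 1211), values in [-1, 1]
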
class XvectorPreTrainedModel(PreTrainedModel):

    config_class = XvectorConfig
    base_model_prefix = "xvector"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

class XvectorModel(XvectorPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.compute_features = Fbank(
            n_mels=config.n_mels,
            sample_rate=config.sample_rate,
            win_length=config.win_length,
            hop_length=config.hop_length,
        )
        self.mean_var_norm = InputNormalization(
            mean_norm=config.mean_norm,
            std_norm=config.std_norm,
            norm_type=config.norm_type,
        )
        self.embedding_model = XvectorEmbedder(
            in_channels=config.n_mels,
            activation=nn.LeakyReLU,
            tdnn_blocks=config.tdnn_blocks,
            tdnn_channels=config.tdnn_channels,
            tdnn_kernel_sizes=config.tdnn_kernel_sizes,
            tdnn_dilations=config.tdnn_dilations,
            hidden_size=config.hidden_size,
        )

    def forward(self, input_values, lengths=None):
        # input_values is expected to be a batch of raw waveforms (batch, samples);
        # lengths are relative lengths used to skip zero-padded frames.
        x = input_values
        x = self.compute_features(x)
        x = self.mean_var_norm(x, lengths)
        output = self.embedding_model(x, lengths)
        return output
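
# End-to-end usage sketch (illustrative, not part of the original file). It assumes
# the accompanying configuration_xvector.py / helpers_xvector.py define XvectorConfig
# defaults (n_mels, sample_rate, tdnn_*, hidden_size, ...) and an Fbank extractor:
#
#     config = XvectorConfig()                   # or XvectorConfig.from_pretrained(...)
#     model = XvectorModel(config)
#     waveforms = torch.randn(2, 16000)          # ~1 s of 16 kHz audio per utterance
#     lengths = torch.tensor([1.0, 1.0])         # relative lengths
#     out = model(waveforms, lengths)
#     embeddings = out.pooler_output             # (batch, config.hidden_size)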