yangwang825 committed
Commit 92a6cad
1 Parent(s): b3c7a44

Create modeling_xvector.py

Files changed (1)
  1. modeling_xvector.py +557 -0
modeling_xvector.py ADDED
@@ -0,0 +1,557 @@
+ import math
+ import torch
+ import typing as tp
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.utils import ModelOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ from .helpers_xvector import Fbank
+ from .configuration_xvector import XvectorConfig
+
+
+ class InputNormalization(nn.Module):
+
+     spk_dict_mean: tp.Dict[int, torch.Tensor]
+     spk_dict_std: tp.Dict[int, torch.Tensor]
+     spk_dict_count: tp.Dict[int, int]
+
+     def __init__(
+         self,
+         mean_norm=True,
+         std_norm=True,
+         norm_type="global",
+         avg_factor=None,
+         requires_grad=False,
+         update_until_epoch=3,
+     ):
+         super().__init__()
+         self.mean_norm = mean_norm
+         self.std_norm = std_norm
+         self.norm_type = norm_type
+         self.avg_factor = avg_factor
+         self.requires_grad = requires_grad
+         self.glob_mean = torch.tensor([0])
+         self.glob_std = torch.tensor([0])
+         self.spk_dict_mean = {}
+         self.spk_dict_std = {}
+         self.spk_dict_count = {}
+         self.weight = 1.0
+         self.count = 0
+         self.eps = 1e-10
+         self.update_until_epoch = update_until_epoch
+
+     def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
+         """Performs mean/variance normalization of the input tensor.
+
+         Arguments
+         ---------
+         input_values : tensor
+             A batch of feature tensors.
+         lengths : tensor
+             The relative length of each sentence (e.g., [0.7, 0.9, 1.0]),
+             used to avoid computing stats on zero-padded steps.
+         spk_ids : tensor
+             The id of each speaker (e.g., [0 10 6]), used to perform
+             per-speaker normalization when norm_type='speaker'.
+         epoch : int
+             The current training epoch; global statistics are only
+             updated until `update_until_epoch`.
+         """
+         x = input_values
+         N_batches = x.shape[0]
+
+         current_means = []
+         current_stds = []
+
+         for snt_id in range(N_batches):
+             # Avoiding padded time steps
+             # lengths = torch.sum(attention_mask, dim=1)
+             # relative_lengths = lengths / torch.max(lengths)
+             # actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
+             actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
+
+             # computing statistics
+             current_mean, current_std = self._compute_current_stats(
+                 x[snt_id, 0:actual_size, ...]
+             )
+
+             current_means.append(current_mean)
+             current_stds.append(current_std)
+
+             if self.norm_type == "sentence":
+                 x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
+
+             if self.norm_type == "speaker":
+                 spk_id = int(spk_ids[snt_id][0])
+
+                 if self.training:
+                     if spk_id not in self.spk_dict_mean:
+                         # Initialization of the dictionary
+                         self.spk_dict_mean[spk_id] = current_mean
+                         self.spk_dict_std[spk_id] = current_std
+                         self.spk_dict_count[spk_id] = 1
+
+                     else:
+                         self.spk_dict_count[spk_id] = (
+                             self.spk_dict_count[spk_id] + 1
+                         )
+
+                         if self.avg_factor is None:
+                             self.weight = 1 / self.spk_dict_count[spk_id]
+                         else:
+                             self.weight = self.avg_factor
+
+                         self.spk_dict_mean[spk_id] = (
+                             (1 - self.weight) * self.spk_dict_mean[spk_id]
+                             + self.weight * current_mean
+                         )
+                         self.spk_dict_std[spk_id] = (
+                             (1 - self.weight) * self.spk_dict_std[spk_id]
+                             + self.weight * current_std
+                         )
+
+                         self.spk_dict_mean[spk_id].detach()
+                         self.spk_dict_std[spk_id].detach()
+
+                     speaker_mean = self.spk_dict_mean[spk_id].data
+                     speaker_std = self.spk_dict_std[spk_id].data
+                 else:
+                     if spk_id in self.spk_dict_mean:
+                         speaker_mean = self.spk_dict_mean[spk_id].data
+                         speaker_std = self.spk_dict_std[spk_id].data
+                     else:
+                         speaker_mean = current_mean.data
+                         speaker_std = current_std.data
+
+                 x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
+
+         if self.norm_type == "batch" or self.norm_type == "global":
+             current_mean = torch.mean(torch.stack(current_means), dim=0)
+             current_std = torch.mean(torch.stack(current_stds), dim=0)
+
+             if self.norm_type == "batch":
+                 x = (x - current_mean.data) / (current_std.data)
+
+             if self.norm_type == "global":
+                 if self.training:
+                     if self.count == 0:
+                         self.glob_mean = current_mean
+                         self.glob_std = current_std
+
+                     elif epoch < self.update_until_epoch:
+                         if self.avg_factor is None:
+                             self.weight = 1 / (self.count + 1)
+                         else:
+                             self.weight = self.avg_factor
+
+                         self.glob_mean = (
+                             1 - self.weight
+                         ) * self.glob_mean + self.weight * current_mean
+
+                         self.glob_std = (
+                             1 - self.weight
+                         ) * self.glob_std + self.weight * current_std
+
+                     self.glob_mean.detach()
+                     self.glob_std.detach()
+
+                     self.count = self.count + 1
+
+                 x = (x - self.glob_mean.data) / (self.glob_std.data)
+
+         return x
+
+     def _compute_current_stats(self, x):
+         """Computes the mean and std of the input tensor.
+
+         Arguments
+         ---------
+         x : tensor
+             A batch of tensors.
+         """
+         # Compute current mean
+         if self.mean_norm:
+             current_mean = torch.mean(x, dim=0).detach().data
+         else:
+             current_mean = torch.tensor([0.0], device=x.device)
+
+         # Compute current std
+         if self.std_norm:
+             current_std = torch.std(x, dim=0).detach().data
+         else:
+             current_std = torch.tensor([1.0], device=x.device)
+
+         # Improving numerical stability of std
+         current_std = torch.max(
+             current_std, self.eps * torch.ones_like(current_std)
+         )
+
+         return current_mean, current_std
+
+     def _statistics_dict(self):
+         """Fills the dictionary containing the normalization statistics."""
+         state = {}
+         state["count"] = self.count
+         state["glob_mean"] = self.glob_mean
+         state["glob_std"] = self.glob_std
+         state["spk_dict_mean"] = self.spk_dict_mean
+         state["spk_dict_std"] = self.spk_dict_std
+         state["spk_dict_count"] = self.spk_dict_count
+
+         return state
+
+     def _load_statistics_dict(self, state):
+         """Loads the dictionary containing the statistics.
+
+         Arguments
+         ---------
+         state : dict
+             A dictionary containing the normalization statistics.
+         """
+         self.count = state["count"]
+         if isinstance(state["glob_mean"], int):
+             self.glob_mean = state["glob_mean"]
+             self.glob_std = state["glob_std"]
+         else:
+             self.glob_mean = state["glob_mean"]  # .to(self.device_inp)
+             self.glob_std = state["glob_std"]  # .to(self.device_inp)
+
+         # Loading the spk_dict_mean in the right device
+         self.spk_dict_mean = {}
+         for spk in state["spk_dict_mean"]:
+             self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(
+                 self.device_inp
+             )
+
+         # Loading the spk_dict_std in the right device
+         self.spk_dict_std = {}
+         for spk in state["spk_dict_std"]:
+             self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(
+                 self.device_inp
+             )
+
+         self.spk_dict_count = state["spk_dict_count"]
+
+         return state
+
+     def to(self, device):
+         """Puts the needed tensors in the right device."""
+         self = super(InputNormalization, self).to(device)
+         self.glob_mean = self.glob_mean.to(device)
+         self.glob_std = self.glob_std.to(device)
+         for spk in self.spk_dict_mean:
+             self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
+             self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
+         return self
+
+
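
A minimal usage sketch of InputNormalization (illustrative only; shapes and values are assumptions, not taken from the commit): features are passed as (batch, time, n_mels) together with relative lengths in [0, 1].

import torch

norm = InputNormalization(norm_type="global")    # running global mean/variance
feats = torch.randn(3, 200, 40)                  # (batch, time, n_mels), assumed layout
rel_lens = torch.tensor([0.7, 0.9, 1.0])         # fraction of non-padded frames per utterance
normed = norm(feats, lengths=rel_lens, epoch=0)  # same shape as feats
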
+ class TdnnLayer(nn.Module):
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         dilation=1,
+         stride=1,
+         padding=0,
+         padding_mode="reflect",
+         activation=torch.nn.LeakyReLU,
+     ):
+         super(TdnnLayer, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.dilation = dilation
+         self.stride = stride
+         self.padding = padding
+         self.padding_mode = padding_mode
+         self.activation = activation
+
+         self.conv = nn.Conv1d(
+             self.in_channels,
+             self.out_channels,
+             self.kernel_size,
+             dilation=self.dilation,
+             padding=self.padding
+         )
+
+         # Set affine=False to be compatible with the original Kaldi version
+         # self.ln = nn.LayerNorm(out_channels, elementwise_affine=False)
+         self.norm = nn.BatchNorm1d(out_channels, affine=False)
+
+     def forward(self, x):
+         x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
+         out = self.conv(x)
+         out = self.activation()(out)
+         out = self.norm(out)
+         return out
+
+     def _manage_padding(
+         self, x, kernel_size: int, dilation: int, stride: int,
+     ):
+         # Reference length used to compute the padding; for stride == 1 the
+         # result depends only on kernel_size and dilation, so in_channels is
+         # a safe stand-in for the actual time length.
+         L_in = self.in_channels
+
+         # Time padding
+         padding = get_padding_elem(L_in, stride, kernel_size, dilation)
+
+         # Applying padding
+         x = F.pad(x, padding, mode=self.padding_mode)
+
+         return x
+
+
+ def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
+     """This function computes the number of elements to add for zero-padding.
+
+     Arguments
+     ---------
+     L_in : int
+     stride : int
+     kernel_size : int
+     dilation : int
+     """
+     if stride > 1:
+         padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
+
+     else:
+         L_out = (
+             math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
+         )
+         padding = [
+             math.floor((L_in - L_out) / 2),
+             math.floor((L_in - L_out) / 2),
+         ]
+     return padding
+
+
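
For stride 1, the padding returned by get_padding_elem depends only on kernel_size and dilation, so TdnnLayer keeps the time dimension unchanged; a quick sanity check with hypothetical shapes:

import torch

layer = TdnnLayer(in_channels=40, out_channels=512, kernel_size=5, dilation=1)
x = torch.randn(8, 40, 200)   # (batch, channels, time), the layout nn.Conv1d expects
y = layer(x)                  # reflect-padded by 2 on each side, then 5-tap convolution
print(y.shape)                # torch.Size([8, 512, 200])
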
+ class StatisticsPooling(nn.Module):
+
+     def __init__(self, return_mean=True, return_std=True):
+         super().__init__()
+
+         # Small value for GaussNoise
+         self.eps = 1e-5
+         self.return_mean = return_mean
+         self.return_std = return_std
+         if not (self.return_mean or self.return_std):
+             raise ValueError(
+                 "both statistics are set to False; \n"
+                 "consider enabling mean and/or std statistic pooling"
+             )
+
+     def forward(self, input_values, lengths=None):
+         """Calculates mean and std for a batch (input tensor).
+
+         Arguments
+         ---------
+         input_values : torch.Tensor
+             It represents a tensor for a mini-batch.
+         """
+         x = input_values
+         if lengths is None:
+             if self.return_mean:
+                 mean = x.mean(dim=1)
+             if self.return_std:
+                 std = x.std(dim=1)
+         else:
+             mean = []
+             std = []
+             for snt_id in range(x.shape[0]):
+                 # Avoiding padded time steps
+                 # lengths = torch.sum(attention_mask, dim=1)
+                 # relative_lengths = lengths / torch.max(lengths)
+                 # actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
+                 actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
+
+                 # computing statistics
+                 if self.return_mean:
+                     mean.append(
+                         torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
+                     )
+                 if self.return_std:
+                     std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
+             if self.return_mean:
+                 mean = torch.stack(mean)
+             if self.return_std:
+                 std = torch.stack(std)
+
+         if self.return_mean:
+             gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
+             mean += gnoise
+         if self.return_std:
+             std = std + self.eps
+
+         # Append mean and std of the batch
+         if self.return_mean and self.return_std:
+             pooled_stats = torch.cat((mean, std), dim=1)
+             pooled_stats = pooled_stats.unsqueeze(1)
+         elif self.return_mean:
+             pooled_stats = mean.unsqueeze(1)
+         elif self.return_std:
+             pooled_stats = std.unsqueeze(1)
+
+         return pooled_stats
+
+     def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
+         """Returns a tensor of epsilon Gaussian noise.
+
+         Arguments
+         ---------
+         shape_of_tensor : tensor
+             It represents the size of tensor for generating Gaussian noise.
+         """
+         gnoise = torch.randn(shape_of_tensor, device=device)
+         gnoise -= torch.min(gnoise)
+         gnoise /= torch.max(gnoise)
+         gnoise = self.eps * ((1 - 9) * gnoise + 9)
+
+         return gnoise
+
+
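
StatisticsPooling concatenates the per-utterance mean and standard deviation over the valid frames, so the feature dimension doubles and a singleton time axis is kept; a short sketch with assumed sizes:

import torch

pool = StatisticsPooling()
h = torch.randn(3, 200, 1500)             # (batch, time, features) from the last TDNN block
rel_lens = torch.tensor([0.7, 0.9, 1.0])  # relative lengths, as in InputNormalization
stats = pool(h, lengths=rel_lens)
print(stats.shape)                        # torch.Size([3, 1, 3000]) = mean (1500) ++ std (1500)
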
+ class XvectorEmbedder(nn.Module):
+
+     def __init__(
+         self,
+         in_channels=40,
+         activation=torch.nn.LeakyReLU,
+         tdnn_blocks=5,
+         tdnn_channels=[512, 512, 512, 512, 1500],
+         tdnn_kernel_sizes=[5, 3, 3, 1, 1],
+         tdnn_dilations=[1, 2, 3, 1, 1],
+         hidden_size=512,
+     ) -> None:
+         super(XvectorEmbedder, self).__init__()
+         self.activation = activation
+         self.blocks = nn.ModuleList()
+         for block_index in range(tdnn_blocks):
+             out_channels = tdnn_channels[block_index]
+             tdnn = TdnnLayer(
+                 in_channels,
+                 out_channels,
+                 kernel_size=tdnn_kernel_sizes[block_index],
+                 dilation=tdnn_dilations[block_index],
+                 activation=activation,
+             )
+             self.blocks.append(tdnn)
+             in_channels = tdnn_channels[block_index]
+         self.pooler = StatisticsPooling()
+         self.fc = nn.Linear(2 * out_channels, hidden_size)
+
+     def forward(self, input_values, lengths=None):
+         x = input_values
+         x = x.permute(0, 2, 1)  # (B, T, F) -> (B, F, T)
+         for block in self.blocks:
+             x = block(x)
+         last_hidden_state = x.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)
+         pooler_output = self.pooler(last_hidden_state, lengths)
+         pooler_output = self.fc(pooler_output.squeeze(1))
+         return ModelOutput(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooler_output
+         )
+
+
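
Putting the blocks together, the embedder maps filterbank features to a fixed-size x-vector via statistics pooling and the final linear layer; an illustrative shape walk-through (sizes assumed):

import torch

embedder = XvectorEmbedder(in_channels=40, hidden_size=512)
feats = torch.randn(4, 300, 40)            # (batch, time, n_mels)
rel_lens = torch.ones(4)                   # no padding in this toy batch
out = embedder(feats, lengths=rel_lens)
print(out.last_hidden_state.shape)         # torch.Size([4, 300, 1500])
print(out.pooler_output.shape)             # torch.Size([4, 512])  <- the x-vector
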
+ class CosineSimilarityHead(torch.nn.Module):
+     """
+     This class implements cosine similarity on top of the extracted features.
+     """
+     def __init__(
+         self,
+         in_channels,
+         lin_blocks=0,
+         hidden_size=192,
+         num_classes=1211,
+     ):
+         super().__init__()
+         self.blocks = nn.ModuleList()
+
+         for block_index in range(lin_blocks):
+             self.blocks.extend(
+                 [
+                     nn.BatchNorm1d(num_features=in_channels),
+                     nn.Linear(in_features=in_channels, out_features=hidden_size),
+                 ]
+             )
+             in_channels = hidden_size
+
+         # Final Layer
+         self.weight = nn.Parameter(
+             torch.FloatTensor(num_classes, in_channels)
+         )
+         nn.init.xavier_uniform_(self.weight)
+
+     def forward(self, x):
+         """Returns the cosine similarity between the normalized features
+         and the class weight vectors.
+
+         Arguments
+         ---------
+         x : torch.Tensor
+             Torch tensor.
+         """
+         for layer in self.blocks:
+             x = layer(x)
+
+         # Both features and weights need to be normalized
+         x = F.linear(F.normalize(x), F.normalize(self.weight))
+         return x
+
+
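
The head outputs cosine similarities in [-1, 1] rather than probabilities; for training, such logits are usually rescaled before a cross-entropy loss (the scale factor below is a common but assumed choice, not taken from this repository):

import torch
import torch.nn.functional as F

head = CosineSimilarityHead(in_channels=512, num_classes=1211)
emb = torch.randn(4, 512)                      # x-vectors from XvectorEmbedder
cosine = head(emb)                             # (4, 1211), one similarity per speaker class
labels = torch.randint(0, 1211, (4,))
loss = F.cross_entropy(30.0 * cosine, labels)  # scale of 30 is illustrative
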
+ class XvectorPreTrainedModel(PreTrainedModel):
+
+     config_class = XvectorConfig
+     base_model_prefix = "xvector"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, nn.Conv1d):
+             nn.init.kaiming_normal_(module.weight.data)
+
+         if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
+             module.bias.data.zero_()
+
+
+ class XvectorModel(XvectorPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.compute_features = Fbank(
+             n_mels=config.n_mels,
+             sample_rate=config.sample_rate,
+             win_length=config.win_length,
+             hop_length=config.hop_length,
+         )
+         self.mean_var_norm = InputNormalization(
+             mean_norm=config.mean_norm,
+             std_norm=config.std_norm,
+             norm_type=config.norm_type
+         )
+         self.embedding_model = XvectorEmbedder(
+             in_channels=config.n_mels,
+             activation=nn.LeakyReLU,
+             tdnn_blocks=config.tdnn_blocks,
+             tdnn_channels=config.tdnn_channels,
+             tdnn_kernel_sizes=config.tdnn_kernel_sizes,
+             tdnn_dilations=config.tdnn_dilations,
+             hidden_size=config.hidden_size,
+         )
+
+     def forward(self, input_values, lengths=None):
+         x = input_values
+         # if attention_mask is None:
+         #     attention_mask = torch.ones_like(input_values, device=x.device)
+         x = self.compute_features(x)
+         x = self.mean_var_norm(x, lengths)
+         output = self.embedding_model(x, lengths)
+         return output
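
An end-to-end sketch of how the model might be called on raw waveforms, assuming XvectorConfig provides typical defaults (16 kHz audio, 40 mel bins) and that Fbank maps (batch, samples) to (batch, frames, n_mels); none of these defaults are confirmed by this file:

import torch

config = XvectorConfig()                    # assumed default hyper-parameters
model = XvectorModel(config)

waveforms = torch.randn(2, 32000)           # ~2 s of 16 kHz audio per utterance (assumed)
rel_lens = torch.tensor([1.0, 0.8])         # relative lengths, 1.0 = no padding
with torch.no_grad():
    out = model(waveforms, lengths=rel_lens)
xvectors = out.pooler_output                # (2, config.hidden_size) speaker embeddings
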