yangwang825 committed
Commit a09ce64
1 Parent(s): ca3e491

Create modeling_svector.py

Files changed (1):
  1. modeling_svector.py +548 -0

modeling_svector.py ADDED

import math
import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .helpers_svector import Fbank
from .configuration_svector import SvectorConfig


class InputNormalization(nn.Module):

    spk_dict_mean: tp.Dict[int, torch.Tensor]
    spk_dict_std: tp.Dict[int, torch.Tensor]
    spk_dict_count: tp.Dict[int, int]

    def __init__(
        self,
        mean_norm=True,
        std_norm=True,
        norm_type="global",
        avg_factor=None,
        requires_grad=False,
        update_until_epoch=3,
    ):
        super().__init__()
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.avg_factor = avg_factor
        self.requires_grad = requires_grad
        self.glob_mean = torch.tensor([0])
        self.glob_std = torch.tensor([0])
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10
        self.update_until_epoch = update_until_epoch

    def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
        """Returns the input tensor normalized according to ``norm_type``.

        Arguments
        ---------
        input_values : tensor
            A batch of tensors.
        lengths : tensor
            A batch of tensors containing the relative length of each
            sentence (e.g., [0.7, 0.9, 1.0]). It is used to avoid
            computing stats on zero-padded steps.
        spk_ids : tensor
            Tensor containing the ids of each speaker (e.g., [0 10 6]).
            It is used to perform per-speaker normalization when
            norm_type='speaker'.
        epoch : int
            The current training epoch, used to stop updating the global
            statistics after ``update_until_epoch``.
        """
        x = input_values
        N_batches = x.shape[0]
        self.device_inp = x.device  # remembered for _load_statistics_dict

        # Without lengths, assume no padding (all time steps are valid).
        if lengths is None:
            lengths = torch.ones(N_batches, device=x.device)

        current_means = []
        current_stds = []

        for snt_id in range(N_batches):
            # Avoiding padded time steps
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()

            # computing statistics
            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...]
            )

            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence":
                x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

            if self.norm_type == "speaker":
                spk_id = int(spk_ids[snt_id][0])

                if self.training:
                    if spk_id not in self.spk_dict_mean:
                        # Initialization of the dictionary
                        self.spk_dict_mean[spk_id] = current_mean
                        self.spk_dict_std[spk_id] = current_std
                        self.spk_dict_count[spk_id] = 1
                    else:
                        self.spk_dict_count[spk_id] += 1

                        if self.avg_factor is None:
                            self.weight = 1 / self.spk_dict_count[spk_id]
                        else:
                            self.weight = self.avg_factor

                        # Running average of the per-speaker statistics
                        self.spk_dict_mean[spk_id] = (
                            (1 - self.weight) * self.spk_dict_mean[spk_id]
                            + self.weight * current_mean
                        )
                        self.spk_dict_std[spk_id] = (
                            (1 - self.weight) * self.spk_dict_std[spk_id]
                            + self.weight * current_std
                        )

                    # detach() returns a new tensor, so the result must be
                    # assigned back (calling it alone is a no-op).
                    self.spk_dict_mean[spk_id] = self.spk_dict_mean[spk_id].detach()
                    self.spk_dict_std[spk_id] = self.spk_dict_std[spk_id].detach()

                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                else:
                    if spk_id in self.spk_dict_mean:
                        speaker_mean = self.spk_dict_mean[spk_id].data
                        speaker_std = self.spk_dict_std[spk_id].data
                    else:
                        speaker_mean = current_mean.data
                        speaker_std = current_std.data

                x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

        if self.norm_type == "batch" or self.norm_type == "global":
            current_mean = torch.mean(torch.stack(current_means), dim=0)
            current_std = torch.mean(torch.stack(current_stds), dim=0)

            if self.norm_type == "batch":
                x = (x - current_mean.data) / (current_std.data)

            if self.norm_type == "global":
                if self.training:
                    if self.count == 0:
                        self.glob_mean = current_mean
                        self.glob_std = current_std
                    elif epoch < self.update_until_epoch:
                        if self.avg_factor is None:
                            self.weight = 1 / (self.count + 1)
                        else:
                            self.weight = self.avg_factor

                        self.glob_mean = (
                            1 - self.weight
                        ) * self.glob_mean + self.weight * current_mean

                        self.glob_std = (
                            1 - self.weight
                        ) * self.glob_std + self.weight * current_std

                    # As above, detach() must be assigned back to take effect.
                    self.glob_mean = self.glob_mean.detach()
                    self.glob_std = self.glob_std.detach()

                    self.count = self.count + 1

                x = (x - self.glob_mean.data) / (self.glob_std.data)

        return x

    def _compute_current_stats(self, x):
        """Computes the mean and std of the input tensor.

        Arguments
        ---------
        x : tensor
            A single-sentence tensor, truncated to its valid length.
        """
        # Compute current mean
        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach().data
        else:
            current_mean = torch.tensor([0.0], device=x.device)

        # Compute current std
        if self.std_norm:
            current_std = torch.std(x, dim=0).detach().data
        else:
            current_std = torch.tensor([1.0], device=x.device)

        # Improving numerical stability of std
        current_std = torch.max(
            current_std, self.eps * torch.ones_like(current_std)
        )

        return current_mean, current_std

    def _statistics_dict(self):
        """Fills the dictionary containing the normalization statistics."""
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count

        return state

    def _load_statistics_dict(self, state):
        """Loads the dictionary containing the statistics.

        Arguments
        ---------
        state : dict
            A dictionary containing the normalization statistics.
        """
        self.count = state["count"]
        # glob stats may be plain ints (before any update) or tensors
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]

        # device_inp is set on the first forward pass; fall back to CPU if
        # the statistics are loaded before any data has been seen.
        device = getattr(self, "device_inp", torch.device("cpu"))

        # Loading the spk_dict_mean in the right device
        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(device)

        # Loading the spk_dict_std in the right device
        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(device)

        self.spk_dict_count = state["spk_dict_count"]

        return state

    def to(self, device):
        """Puts the needed tensors in the right device."""
        self = super().to(device)
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
        return self


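# Summary of the norm_type options above (added comments): "sentence"
# normalizes each utterance with its own statistics, "batch" uses the
# current mini-batch statistics, "global" maintains a running average that
# stops updating after `update_until_epoch`, and "speaker" keeps running
# per-speaker statistics keyed by spk_ids.

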
class TdnnLayer(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        padding=0,
        padding_mode="reflect",
        activation=torch.nn.LeakyReLU,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.activation = activation

        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
        )

        # Set affine=False to be compatible with the original Kaldi version
        self.norm = nn.BatchNorm1d(out_channels, affine=False)

    def forward(self, x):
        x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        out = self.conv(x)
        out = self.activation()(out)
        out = self.norm(out)
        return out

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        # Detecting input shape
        L_in = self.in_channels

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Computes the left/right padding needed to preserve the time dimension.

    Arguments
    ---------
    L_in : int
        Input length.
    stride : int
        Convolution stride.
    kernel_size : int
        Convolution kernel size.
    dilation : int
        Convolution dilation.
    """
    if stride > 1:
        padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
    else:
        L_out = (
            math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
        )
        padding = [
            math.floor((L_in - L_out) / 2),
            math.floor((L_in - L_out) / 2),
        ]
    return padding


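# A quick worked example of the padding arithmetic above (added comments,
# for illustration only): with stride=1, kernel_size=5, dilation=1 and
# L_in=100, the un-padded output length would be
#     L_out = floor((100 - 1 * (5 - 1) - 1) / 1) + 1 = 96,
# so get_padding_elem returns [2, 2] and the padded convolution preserves
# the original length of 100.

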
class StatisticsPooling(nn.Module):

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both statistics are set to False;\n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        input_values : torch.Tensor
            It represents a tensor for a mini-batch.
        lengths : torch.Tensor, optional
            Relative lengths used to exclude zero-padded steps.
        """
        x = input_values
        if lengths is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            for snt_id in range(x.shape[0]):
                # Avoiding padded time steps
                actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))

                # computing statistics
                if self.return_mean:
                    mean.append(
                        torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
                    )
                if self.return_std:
                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            # Add a small amount of Gaussian noise to the means
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)

        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon-scaled Gaussian noise.

        Arguments
        ---------
        shape_of_tensor : torch.Size
            The size of the noise tensor to generate.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        # Rescale the [0, 1] noise into the range [eps, 9 * eps]
        gnoise = self.eps * ((1 - 9) * gnoise + 9)

        return gnoise


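# Illustrative shape check (added comments): with return_mean and return_std
# both True, an input of shape [B, T, F] pools to [B, 1, 2 * F] -- the mean
# and std over time are concatenated along the feature dimension.

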
class SvectorEmbedder(nn.Module):

    def __init__(
        self,
        in_channels=40,
        num_heads=8,
        num_layers=5,
        activation=torch.nn.LeakyReLU,
        hidden_size=512,
    ) -> None:
        super().__init__()
        self.tdnn = TdnnLayer(
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=1,
            dilation=1,
            activation=activation,
        )
        # batch_first=True so the encoder attends over time for inputs
        # shaped [B, T, H] (PyTorch's default layout expects [T, B, H]).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )
        self.pooler = StatisticsPooling()
        self.fc = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, input_values, lengths=None):
        """Computes a speaker embedding from input features.

        Arguments
        ---------
        input_values : torch.Tensor
            A feature tensor of shape [B, T, F].
        lengths : torch.Tensor, optional
            Relative lengths used by the pooling layer.
        """
        x = input_values
        x = self.tdnn(x.transpose(1, 2))  # [B, T, F] -> [B, H, T]
        last_hidden_state = self.transformer_encoder(x.transpose(1, 2))  # [B, T, H]
        pooler_output = self.pooler(last_hidden_state, lengths)  # [B, 1, 2H]
        pooler_output = self.fc(pooler_output.squeeze(1))  # [B, H]
        return ModelOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )


class CosineSimilarityHead(torch.nn.Module):
    """
    This class implements cosine similarity on top of features.
    """
    def __init__(
        self,
        in_channels,
        lin_blocks=0,
        hidden_size=192,
        num_classes=1211,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        for block_index in range(lin_blocks):
            self.blocks.extend(
                [
                    nn.BatchNorm1d(num_features=in_channels),
                    nn.Linear(in_features=in_channels, out_features=hidden_size),
                ]
            )
            in_channels = hidden_size

        # Final layer
        self.weight = nn.Parameter(
            torch.FloatTensor(num_classes, in_channels)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Returns the cosine similarity between the input features and each
        class weight vector (one score per class, in [-1, 1]).

        Arguments
        ---------
        x : torch.Tensor
            Torch tensor of shape [B, in_channels].
        """
        for layer in self.blocks:
            x = layer(x)

        # Both the features and the weights are L2-normalized, so the
        # linear product is exactly the cosine similarity.
        x = F.linear(F.normalize(x), F.normalize(self.weight))
        return x


class SvectorPreTrainedModel(PreTrainedModel):

    config_class = SvectorConfig
    base_model_prefix = "svector"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()


class SvectorModel(SvectorPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.compute_features = Fbank(
            n_mels=config.n_mels,
            sample_rate=config.sample_rate,
            win_length=config.win_length,
            hop_length=config.hop_length,
        )
        self.mean_var_norm = InputNormalization(
            mean_norm=config.mean_norm,
            std_norm=config.std_norm,
            norm_type=config.norm_type,
        )
        self.embedding_model = SvectorEmbedder(
            in_channels=config.n_mels,
            activation=nn.LeakyReLU,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            hidden_size=config.hidden_size,
        )

    def forward(self, input_values, lengths=None):
        # waveform -> filterbank features -> normalization -> embedding
        x = input_values
        x = self.compute_features(x)
        x = self.mean_var_norm(x, lengths)
        output = self.embedding_model(x, lengths)
        return output
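
For reference, a minimal usage sketch, not part of the committed file: it assumes SvectorConfig provides defaults for the fields referenced above (n_mels, sample_rate, win_length, hop_length, mean_norm, std_norm, norm_type, num_heads, num_layers, hidden_size) and that Fbank accepts a raw waveform of shape [batch, samples]. The flat imports stand in for the package-relative ones used in the file.

import torch
from configuration_svector import SvectorConfig  # hypothetical flat import
from modeling_svector import SvectorModel        # hypothetical flat import

config = SvectorConfig()              # assumes sensible field defaults
model = SvectorModel(config).eval()

waveform = torch.randn(2, 32000)      # [batch, samples]; ~2 s at 16 kHz
lengths = torch.tensor([1.0, 0.8])    # relative lengths; 0.8 = trailing padding
with torch.no_grad():
    out = model(waveform, lengths)

print(out.last_hidden_state.shape)    # [2, T, hidden_size]
print(out.pooler_output.shape)        # [2, hidden_size]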