yangwang825 commited on
Commit
b7b7a53
1 Parent(s): efdbdac

Create helpers_ecapa.py

Browse files
Files changed (1) hide show
  1. helpers_ecapa.py +719 -0
helpers_ecapa.py ADDED
@@ -0,0 +1,719 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Deltas(torch.nn.Module):
7
+ """Computes delta coefficients (time derivatives).
8
+ Arguments
9
+ ---------
10
+ win_length : int
11
+ Length of the window used to compute the time derivatives.
12
+ Example
13
+ -------
14
+ >>> inputs = torch.randn([10, 101, 20])
15
+ >>> compute_deltas = Deltas(input_size=inputs.size(-1))
16
+ >>> features = compute_deltas(inputs)
17
+ >>> features.shape
18
+ torch.Size([10, 101, 20])
19
+ """
20
+
21
+ def __init__(
22
+ self, input_size, window_length=5,
23
+ ):
24
+ super().__init__()
25
+ self.n = (window_length - 1) // 2
26
+ self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3
27
+
28
+ self.register_buffer(
29
+ "kernel",
30
+ torch.arange(-self.n, self.n + 1, dtype=torch.float32,).repeat(
31
+ input_size, 1, 1
32
+ ),
33
+ )
34
+
35
+ def forward(self, x):
36
+ """Returns the delta coefficients.
37
+ Arguments
38
+ ---------
39
+ x : tensor
40
+ A batch of tensors.
41
+ """
42
+ # Managing multi-channel deltas reshape tensor (batch*channel,time)
43
+ x = x.transpose(1, 2).transpose(2, -1)
44
+ or_shape = x.shape
45
+ if len(or_shape) == 4:
46
+ x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
47
+
48
+ # Padding for time borders
49
+ x = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")
50
+
51
+ # Derivative estimation (with a fixed convolutional kernel)
52
+ delta_coeff = (
53
+ torch.nn.functional.conv1d(
54
+ x, self.kernel.to(x.device), groups=x.shape[1]
55
+ )
56
+ / self.denom
57
+ )
58
+
59
+ # Retrieving the original dimensionality (for multi-channel case)
60
+ if len(or_shape) == 4:
61
+ delta_coeff = delta_coeff.reshape(
62
+ or_shape[0], or_shape[1], or_shape[2], or_shape[3],
63
+ )
64
+ delta_coeff = delta_coeff.transpose(1, -1).transpose(2, -1)
65
+
66
+ return delta_coeff
67
+
68
+
69
+ class Filterbank(torch.nn.Module):
70
+ """computes filter bank (FBANK) features given spectral magnitudes.
71
+ Arguments
72
+ ---------
73
+ n_mels : float
74
+ Number of Mel filters used to average the spectrogram.
75
+ log_mel : bool
76
+ If True, it computes the log of the FBANKs.
77
+ filter_shape : str
78
+ Shape of the filters ('triangular', 'rectangular', 'gaussian').
79
+ f_min : int
80
+ Lowest frequency for the Mel filters.
81
+ f_max : int
82
+ Highest frequency for the Mel filters.
83
+ n_fft : int
84
+ Number of fft points of the STFT. It defines the frequency resolution
85
+ (n_fft should be<= than win_len).
86
+ sample_rate : int
87
+ Sample rate of the input audio signal (e.g, 16000)
88
+ power_spectrogram : float
89
+ Exponent used for spectrogram computation.
90
+ amin : float
91
+ Minimum amplitude (used for numerical stability).
92
+ ref_value : float
93
+ Reference value used for the dB scale.
94
+ top_db : float
95
+ Minimum negative cut-off in decibels.
96
+ freeze : bool
97
+ If False, it the central frequency and the band of each filter are
98
+ added into nn.parameters. If True, the standard frozen features
99
+ are computed.
100
+ param_change_factor: bool
101
+ If freeze=False, this parameter affects the speed at which the filter
102
+ parameters (i.e., central_freqs and bands) can be changed. When high
103
+ (e.g., param_change_factor=1) the filters change a lot during training.
104
+ When low (e.g. param_change_factor=0.1) the filter parameters are more
105
+ stable during training
106
+ param_rand_factor: float
107
+ This parameter can be used to randomly change the filter parameters
108
+ (i.e, central frequencies and bands) during training. It is thus a
109
+ sort of regularization. param_rand_factor=0 does not affect, while
110
+ param_rand_factor=0.15 allows random variations within +-15% of the
111
+ standard values of the filter parameters (e.g., if the central freq
112
+ is 100 Hz, we can randomly change it from 85 Hz to 115 Hz).
113
+ Example
114
+ -------
115
+ >>> import torch
116
+ >>> compute_fbanks = Filterbank()
117
+ >>> inputs = torch.randn([10, 101, 201])
118
+ >>> features = compute_fbanks(inputs)
119
+ >>> features.shape
120
+ torch.Size([10, 101, 40])
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ n_mels=40,
126
+ log_mel=True,
127
+ filter_shape="triangular",
128
+ f_min=0,
129
+ f_max=8000,
130
+ n_fft=400,
131
+ sample_rate=16000,
132
+ power_spectrogram=2,
133
+ amin=1e-10,
134
+ ref_value=1.0,
135
+ top_db=80.0,
136
+ param_change_factor=1.0,
137
+ param_rand_factor=0.0,
138
+ freeze=True,
139
+ ):
140
+ super().__init__()
141
+ self.n_mels = n_mels
142
+ self.log_mel = log_mel
143
+ self.filter_shape = filter_shape
144
+ self.f_min = f_min
145
+ self.f_max = f_max
146
+ self.n_fft = n_fft
147
+ self.sample_rate = sample_rate
148
+ self.power_spectrogram = power_spectrogram
149
+ self.amin = amin
150
+ self.ref_value = ref_value
151
+ self.top_db = top_db
152
+ self.freeze = freeze
153
+ self.n_stft = self.n_fft // 2 + 1
154
+ self.db_multiplier = math.log10(max(self.amin, self.ref_value))
155
+ self.device_inp = torch.device("cpu")
156
+ self.param_change_factor = param_change_factor
157
+ self.param_rand_factor = param_rand_factor
158
+
159
+ if self.power_spectrogram == 2:
160
+ self.multiplier = 10
161
+ else:
162
+ self.multiplier = 20
163
+
164
+ # Make sure f_min < f_max
165
+ if self.f_min >= self.f_max:
166
+ err_msg = "Require f_min: %f < f_max: %f" % (
167
+ self.f_min,
168
+ self.f_max,
169
+ )
170
+ print(err_msg)
171
+
172
+ # Filter definition
173
+ mel = torch.linspace(
174
+ self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2
175
+ )
176
+ hz = self._to_hz(mel)
177
+
178
+ # Computation of the filter bands
179
+ band = hz[1:] - hz[:-1]
180
+ self.band = band[:-1]
181
+ self.f_central = hz[1:-1]
182
+
183
+ # Adding the central frequency and the band to the list of nn param
184
+ if not self.freeze:
185
+ self.f_central = torch.nn.Parameter(
186
+ self.f_central / (self.sample_rate * self.param_change_factor)
187
+ )
188
+ self.band = torch.nn.Parameter(
189
+ self.band / (self.sample_rate * self.param_change_factor)
190
+ )
191
+
192
+ # Frequency axis
193
+ all_freqs = torch.linspace(0, self.sample_rate // 2, self.n_stft)
194
+
195
+ # Replicating for all the filters
196
+ self.all_freqs_mat = all_freqs.repeat(self.f_central.shape[0], 1)
197
+
198
+ def forward(self, spectrogram):
199
+ """Returns the FBANks.
200
+ Arguments
201
+ ---------
202
+ x : tensor
203
+ A batch of spectrogram tensors.
204
+ """
205
+ # Computing central frequency and bandwidth of each filter
206
+ f_central_mat = self.f_central.repeat(
207
+ self.all_freqs_mat.shape[1], 1
208
+ ).transpose(0, 1)
209
+ band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose(
210
+ 0, 1
211
+ )
212
+
213
+ # Uncomment to print filter parameters
214
+ # print(self.f_central*self.sample_rate * self.param_change_factor)
215
+ # print(self.band*self.sample_rate* self.param_change_factor)
216
+
217
+ # Creation of the multiplication matrix. It is used to create
218
+ # the filters that average the computed spectrogram.
219
+ if not self.freeze:
220
+ f_central_mat = f_central_mat * (
221
+ self.sample_rate
222
+ * self.param_change_factor
223
+ * self.param_change_factor
224
+ )
225
+ band_mat = band_mat * (
226
+ self.sample_rate
227
+ * self.param_change_factor
228
+ * self.param_change_factor
229
+ )
230
+
231
+ # Regularization with random changes of filter central frequency and band
232
+ elif self.param_rand_factor != 0 and self.training:
233
+ rand_change = (
234
+ 1.0
235
+ + torch.rand(2) * 2 * self.param_rand_factor
236
+ - self.param_rand_factor
237
+ )
238
+ f_central_mat = f_central_mat * rand_change[0]
239
+ band_mat = band_mat * rand_change[1]
240
+
241
+ fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(
242
+ spectrogram.device
243
+ )
244
+
245
+ sp_shape = spectrogram.shape
246
+
247
+ # Managing multi-channels case (batch, time, channels)
248
+ if len(sp_shape) == 4:
249
+ spectrogram = spectrogram.permute(0, 3, 1, 2)
250
+ spectrogram = spectrogram.reshape(
251
+ sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2]
252
+ )
253
+
254
+ # FBANK computation
255
+ fbanks = torch.matmul(spectrogram, fbank_matrix)
256
+ if self.log_mel:
257
+ fbanks = self._amplitude_to_DB(fbanks)
258
+
259
+ # Reshaping in the case of multi-channel inputs
260
+ if len(sp_shape) == 4:
261
+ fb_shape = fbanks.shape
262
+ fbanks = fbanks.reshape(
263
+ sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]
264
+ )
265
+ fbanks = fbanks.permute(0, 2, 3, 1)
266
+
267
+ return fbanks
268
+
269
+ @staticmethod
270
+ def _to_mel(hz):
271
+ """Returns mel-frequency value corresponding to the input
272
+ frequency value in Hz.
273
+ Arguments
274
+ ---------
275
+ x : float
276
+ The frequency point in Hz.
277
+ """
278
+ return 2595 * math.log10(1 + hz / 700)
279
+
280
+ @staticmethod
281
+ def _to_hz(mel):
282
+ """Returns hz-frequency value corresponding to the input
283
+ mel-frequency value.
284
+ Arguments
285
+ ---------
286
+ x : float
287
+ The frequency point in the mel-scale.
288
+ """
289
+ return 700 * (10 ** (mel / 2595) - 1)
290
+
291
+ def _triangular_filters(self, all_freqs, f_central, band):
292
+ """Returns fbank matrix using triangular filters.
293
+ Arguments
294
+ ---------
295
+ all_freqs : Tensor
296
+ Tensor gathering all the frequency points.
297
+ f_central : Tensor
298
+ Tensor gathering central frequencies of each filter.
299
+ band : Tensor
300
+ Tensor gathering the bands of each filter.
301
+ """
302
+
303
+ # Computing the slops of the filters
304
+ slope = (all_freqs - f_central) / band
305
+ left_side = slope + 1.0
306
+ right_side = -slope + 1.0
307
+
308
+ # Adding zeros for negative values
309
+ zero = torch.zeros(1, device=self.device_inp)
310
+ fbank_matrix = torch.max(
311
+ zero, torch.min(left_side, right_side)
312
+ ).transpose(0, 1)
313
+
314
+ return fbank_matrix
315
+
316
+ def _rectangular_filters(self, all_freqs, f_central, band):
317
+ """Returns fbank matrix using rectangular filters.
318
+ Arguments
319
+ ---------
320
+ all_freqs : Tensor
321
+ Tensor gathering all the frequency points.
322
+ f_central : Tensor
323
+ Tensor gathering central frequencies of each filter.
324
+ band : Tensor
325
+ Tensor gathering the bands of each filter.
326
+ """
327
+
328
+ # cut-off frequencies of the filters
329
+ low_hz = f_central - band
330
+ high_hz = f_central + band
331
+
332
+ # Left/right parts of the filter
333
+ left_side = right_size = all_freqs.ge(low_hz)
334
+ right_size = all_freqs.le(high_hz)
335
+
336
+ fbank_matrix = (left_side * right_size).float().transpose(0, 1)
337
+
338
+ return fbank_matrix
339
+
340
+ def _gaussian_filters(
341
+ self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)
342
+ ):
343
+ """Returns fbank matrix using gaussian filters.
344
+ Arguments
345
+ ---------
346
+ all_freqs : Tensor
347
+ Tensor gathering all the frequency points.
348
+ f_central : Tensor
349
+ Tensor gathering central frequencies of each filter.
350
+ band : Tensor
351
+ Tensor gathering the bands of each filter.
352
+ smooth_factor: Tensor
353
+ Smoothing factor of the gaussian filter. It can be used to employ
354
+ sharper or flatter filters.
355
+ """
356
+ fbank_matrix = torch.exp(
357
+ -0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2
358
+ ).transpose(0, 1)
359
+
360
+ return fbank_matrix
361
+
362
+ def _create_fbank_matrix(self, f_central_mat, band_mat):
363
+ """Returns fbank matrix to use for averaging the spectrum with
364
+ the set of filter-banks.
365
+ Arguments
366
+ ---------
367
+ f_central : Tensor
368
+ Tensor gathering central frequencies of each filter.
369
+ band : Tensor
370
+ Tensor gathering the bands of each filter.
371
+ smooth_factor: Tensor
372
+ Smoothing factor of the gaussian filter. It can be used to employ
373
+ sharper or flatter filters.
374
+ """
375
+ if self.filter_shape == "triangular":
376
+ fbank_matrix = self._triangular_filters(
377
+ self.all_freqs_mat, f_central_mat, band_mat
378
+ )
379
+
380
+ elif self.filter_shape == "rectangular":
381
+ fbank_matrix = self._rectangular_filters(
382
+ self.all_freqs_mat, f_central_mat, band_mat
383
+ )
384
+
385
+ else:
386
+ fbank_matrix = self._gaussian_filters(
387
+ self.all_freqs_mat, f_central_mat, band_mat
388
+ )
389
+
390
+ return fbank_matrix
391
+
392
+ def _amplitude_to_DB(self, x):
393
+ """Converts linear-FBANKs to log-FBANKs.
394
+ Arguments
395
+ ---------
396
+ x : Tensor
397
+ A batch of linear FBANK tensors.
398
+ """
399
+
400
+ x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
401
+ x_db -= self.multiplier * self.db_multiplier
402
+
403
+ # Setting up dB max. It is the max over time and frequency,
404
+ # Hence, of a whole sequence (sequence-dependent)
405
+ new_x_db_max = x_db.amax(dim=(-2, -1)) - self.top_db
406
+
407
+ # Clipping to dB max. The view is necessary as only a scalar is obtained
408
+ # per sequence.
409
+ x_db = torch.max(x_db, new_x_db_max.view(x_db.shape[0], 1, 1))
410
+
411
+ return x_db
412
+
413
+
414
+ class STFT(torch.nn.Module):
415
+ """computes the Short-Term Fourier Transform (STFT).
416
+ This class computes the Short-Term Fourier Transform of an audio signal.
417
+ It supports multi-channel audio inputs (batch, time, channels).
418
+ Arguments
419
+ ---------
420
+ sample_rate : int
421
+ Sample rate of the input audio signal (e.g 16000).
422
+ win_length : float
423
+ Length (in ms) of the sliding window used to compute the STFT.
424
+ hop_length : float
425
+ Length (in ms) of the hope of the sliding window used to compute
426
+ the STFT.
427
+ n_fft : int
428
+ Number of fft point of the STFT. It defines the frequency resolution
429
+ (n_fft should be <= than win_len).
430
+ window_fn : function
431
+ A function that takes an integer (number of samples) and outputs a
432
+ tensor to be multiplied with each window before fft.
433
+ normalized_stft : bool
434
+ If True, the function returns the normalized STFT results,
435
+ i.e., multiplied by win_length^-0.5 (default is False).
436
+ center : bool
437
+ If True (default), the input will be padded on both sides so that the
438
+ t-th frame is centered at time t×hop_length. Otherwise, the t-th frame
439
+ begins at time t×hop_length.
440
+ pad_mode : str
441
+ It can be 'constant','reflect','replicate', 'circular', 'reflect'
442
+ (default). 'constant' pads the input tensor boundaries with a
443
+ constant value. 'reflect' pads the input tensor using the reflection
444
+ of the input boundary. 'replicate' pads the input tensor using
445
+ replication of the input boundary. 'circular' pads using circular
446
+ replication.
447
+ onesided : True
448
+ If True (default) only returns nfft/2 values. Note that the other
449
+ samples are redundant due to the Fourier transform conjugate symmetry.
450
+ Example
451
+ -------
452
+ >>> import torch
453
+ >>> compute_STFT = STFT(
454
+ ... sample_rate=16000, win_length=25, hop_length=10, n_fft=400
455
+ ... )
456
+ >>> inputs = torch.randn([10, 16000])
457
+ >>> features = compute_STFT(inputs)
458
+ >>> features.shape
459
+ torch.Size([10, 101, 201, 2])
460
+ """
461
+
462
+ def __init__(
463
+ self,
464
+ sample_rate,
465
+ win_length=25,
466
+ hop_length=10,
467
+ n_fft=400,
468
+ window_fn=torch.hamming_window,
469
+ normalized_stft=False,
470
+ center=True,
471
+ pad_mode="constant",
472
+ onesided=True,
473
+ ):
474
+ super().__init__()
475
+ self.sample_rate = sample_rate
476
+ self.win_length = win_length
477
+ self.hop_length = hop_length
478
+ self.n_fft = n_fft
479
+ self.normalized_stft = normalized_stft
480
+ self.center = center
481
+ self.pad_mode = pad_mode
482
+ self.onesided = onesided
483
+
484
+ # Convert win_length and hop_length from ms to samples
485
+ self.win_length = int(
486
+ round((self.sample_rate / 1000.0) * self.win_length)
487
+ )
488
+ self.hop_length = int(
489
+ round((self.sample_rate / 1000.0) * self.hop_length)
490
+ )
491
+
492
+ self.window = window_fn(self.win_length)
493
+
494
+ def forward(self, x):
495
+ """Returns the STFT generated from the input waveforms.
496
+ Arguments
497
+ ---------
498
+ x : tensor
499
+ A batch of audio signals to transform.
500
+ """
501
+
502
+ # Managing multi-channel stft
503
+ or_shape = x.shape
504
+ if len(or_shape) == 3:
505
+ x = x.transpose(1, 2)
506
+ x = x.reshape(or_shape[0] * or_shape[2], or_shape[1])
507
+
508
+ stft = torch.stft(
509
+ x,
510
+ self.n_fft,
511
+ self.hop_length,
512
+ self.win_length,
513
+ self.window.to(x.device),
514
+ self.center,
515
+ self.pad_mode,
516
+ self.normalized_stft,
517
+ self.onesided,
518
+ return_complex=True,
519
+ )
520
+
521
+ stft = torch.view_as_real(stft)
522
+
523
+ # Retrieving the original dimensionality (batch,time, channels)
524
+ if len(or_shape) == 3:
525
+ stft = stft.reshape(
526
+ or_shape[0],
527
+ or_shape[2],
528
+ stft.shape[1],
529
+ stft.shape[2],
530
+ stft.shape[3],
531
+ )
532
+ stft = stft.permute(0, 3, 2, 4, 1)
533
+ else:
534
+ # (batch, time, channels)
535
+ stft = stft.transpose(2, 1)
536
+
537
+ return stft
538
+
539
+
540
+ def spectral_magnitude(
541
+ stft, power: int = 1, log: bool = False, eps: float = 1e-14
542
+ ):
543
+ """Returns the magnitude of a complex spectrogram.
544
+ Arguments
545
+ ---------
546
+ stft : torch.Tensor
547
+ A tensor, output from the stft function.
548
+ power : int
549
+ What power to use in computing the magnitude.
550
+ Use power=1 for the power spectrogram.
551
+ Use power=0.5 for the magnitude spectrogram.
552
+ log : bool
553
+ Whether to apply log to the spectral features.
554
+ Example
555
+ -------
556
+ >>> a = torch.Tensor([[3, 4]])
557
+ >>> spectral_magnitude(a, power=0.5)
558
+ tensor([5.])
559
+ """
560
+ spectr = stft.pow(2).sum(-1)
561
+
562
+ # Add eps avoids NaN when spectr is zero
563
+ if power < 1:
564
+ spectr = spectr + eps
565
+ spectr = spectr.pow(power)
566
+
567
+ if log:
568
+ return torch.log(spectr + eps)
569
+ return spectr
570
+
571
+
572
+ class ContextWindow(torch.nn.Module):
573
+ """Computes the context window.
574
+ This class applies a context window by gathering multiple time steps
575
+ in a single feature vector. The operation is performed with a
576
+ convolutional layer based on a fixed kernel designed for that.
577
+ Arguments
578
+ ---------
579
+ left_frames : int
580
+ Number of left frames (i.e, past frames) to collect.
581
+ right_frames : int
582
+ Number of right frames (i.e, future frames) to collect.
583
+ Example
584
+ -------
585
+ >>> import torch
586
+ >>> compute_cw = ContextWindow(left_frames=5, right_frames=5)
587
+ >>> inputs = torch.randn([10, 101, 20])
588
+ >>> features = compute_cw(inputs)
589
+ >>> features.shape
590
+ torch.Size([10, 101, 220])
591
+ """
592
+
593
+ def __init__(
594
+ self, left_frames=0, right_frames=0,
595
+ ):
596
+ super().__init__()
597
+ self.left_frames = left_frames
598
+ self.right_frames = right_frames
599
+ self.context_len = self.left_frames + self.right_frames + 1
600
+ self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1
601
+
602
+ # Kernel definition
603
+ self.kernel = torch.eye(self.context_len, self.kernel_len)
604
+
605
+ if self.right_frames > self.left_frames:
606
+ lag = self.right_frames - self.left_frames
607
+ self.kernel = torch.roll(self.kernel, lag, 1)
608
+
609
+ self.first_call = True
610
+
611
+ def forward(self, x):
612
+ """Returns the tensor with the surrounding context.
613
+ Arguments
614
+ ---------
615
+ x : tensor
616
+ A batch of tensors.
617
+ """
618
+
619
+ x = x.transpose(1, 2)
620
+
621
+ if self.first_call is True:
622
+ self.first_call = False
623
+ self.kernel = (
624
+ self.kernel.repeat(x.shape[1], 1, 1)
625
+ .view(x.shape[1] * self.context_len, self.kernel_len,)
626
+ .unsqueeze(1)
627
+ )
628
+
629
+ # Managing multi-channel case
630
+ or_shape = x.shape
631
+ if len(or_shape) == 4:
632
+ x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
633
+
634
+ # Compute context (using the estimated convolutional kernel)
635
+ cw_x = torch.nn.functional.conv1d(
636
+ x,
637
+ self.kernel.to(x.device),
638
+ groups=x.shape[1],
639
+ padding=max(self.left_frames, self.right_frames),
640
+ )
641
+
642
+ # Retrieving the original dimensionality (for multi-channel case)
643
+ if len(or_shape) == 4:
644
+ cw_x = cw_x.reshape(
645
+ or_shape[0], cw_x.shape[1], or_shape[2], cw_x.shape[-1]
646
+ )
647
+
648
+ cw_x = cw_x.transpose(1, 2)
649
+
650
+ return cw_x
651
+
652
+
653
+ class Fbank(torch.nn.Module):
654
+
655
+ def __init__(
656
+ self,
657
+ deltas=False,
658
+ context=False,
659
+ requires_grad=False,
660
+ sample_rate=16000,
661
+ f_min=0,
662
+ f_max=None,
663
+ n_fft=400,
664
+ n_mels=40,
665
+ filter_shape="triangular",
666
+ param_change_factor=1.0,
667
+ param_rand_factor=0.0,
668
+ left_frames=5,
669
+ right_frames=5,
670
+ win_length=25,
671
+ hop_length=10,
672
+ ):
673
+ super().__init__()
674
+ self.deltas = deltas
675
+ self.context = context
676
+ self.requires_grad = requires_grad
677
+
678
+ if f_max is None:
679
+ f_max = sample_rate / 2
680
+
681
+ self.compute_STFT = STFT(
682
+ sample_rate=sample_rate,
683
+ n_fft=n_fft,
684
+ win_length=win_length,
685
+ hop_length=hop_length,
686
+ )
687
+ self.compute_fbanks = Filterbank(
688
+ sample_rate=sample_rate,
689
+ n_fft=n_fft,
690
+ n_mels=n_mels,
691
+ f_min=f_min,
692
+ f_max=f_max,
693
+ freeze=not requires_grad,
694
+ filter_shape=filter_shape,
695
+ param_change_factor=param_change_factor,
696
+ param_rand_factor=param_rand_factor,
697
+ )
698
+ self.compute_deltas = Deltas(input_size=n_mels)
699
+ self.context_window = ContextWindow(
700
+ left_frames=left_frames, right_frames=right_frames,
701
+ )
702
+
703
+ def forward(self, wav):
704
+ """Returns a set of features generated from the input waveforms.
705
+ Arguments
706
+ ---------
707
+ wav : tensor
708
+ A batch of audio signals to transform to features.
709
+ """
710
+ STFT = self.compute_STFT(wav)
711
+ mag = spectral_magnitude(STFT)
712
+ fbanks = self.compute_fbanks(mag)
713
+ if self.deltas:
714
+ delta1 = self.compute_deltas(fbanks)
715
+ delta2 = self.compute_deltas(delta1)
716
+ fbanks = torch.cat([fbanks, delta1, delta2], dim=2)
717
+ if self.context:
718
+ fbanks = self.context_window(fbanks)
719
+ return fbanks