Ada312 committed
Commit 56f35be · verified · 1 Parent(s): 10d809f

Upload 2 files

Files changed (2):
  1. complexnn.py +431 -0
  2. conv_stft.py +164 -0
complexnn.py ADDED
@@ -0,0 +1,431 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ def get_casual_padding1d():
+     pass
+
+
+ def get_casual_padding2d():
+     pass
+
+
+ class cPReLU(nn.Module):
+     # PReLU applied separately to the real and imaginary halves of a packed tensor
+
+     def __init__(self, complex_axis=1):
+         super(cPReLU, self).__init__()
+         self.r_prelu = nn.PReLU()
+         self.i_prelu = nn.PReLU()
+         self.complex_axis = complex_axis
+
+     def forward(self, inputs):
+         real, imag = torch.chunk(inputs, 2, self.complex_axis)
+         real = self.r_prelu(real)
+         imag = self.i_prelu(imag)
+         return torch.cat([real, imag], self.complex_axis)
+
+
+ class NavieComplexLSTM(nn.Module):
+     # A naive complex LSTM: two real-valued LSTMs emulate one complex-valued one.
+
+     def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
+         super(NavieComplexLSTM, self).__init__()
+
+         self.input_dim = input_size // 2
+         self.rnn_units = hidden_size // 2
+         self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
+                                  batch_first=False)
+         self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
+                                  batch_first=False)
+         bidirectional = 2 if bidirectional else 1
+         if projection_dim is not None:
+             self.projection_dim = projection_dim // 2
+             self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
+             self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
+         else:
+             self.projection_dim = None
+
+     def forward(self, inputs):
+         if isinstance(inputs, list):
+             real, imag = inputs
+         elif isinstance(inputs, torch.Tensor):
+             real, imag = torch.chunk(inputs, 2, -1)
+         # complex product with weights W = Wr + i*Wi applied to x = xr + i*xi:
+         # real part = Wr*xr - Wi*xi, imaginary part = Wi*xr + Wr*xi
+         r2r_out = self.real_lstm(real)[0]
+         r2i_out = self.imag_lstm(real)[0]
+         i2r_out = self.real_lstm(imag)[0]
+         i2i_out = self.imag_lstm(imag)[0]
+         real_out = r2r_out - i2i_out
+         imag_out = i2r_out + r2i_out
+         if self.projection_dim is not None:
+             real_out = self.r_trans(real_out)
+             imag_out = self.i_trans(imag_out)
+         return [real_out, imag_out]
+
+     def flatten_parameters(self):
+         self.imag_lstm.flatten_parameters()
+         self.real_lstm.flatten_parameters()
+
+
+ def complex_cat(inputs, axis):
+     # concatenate packed complex tensors so all real parts precede all imaginary parts
+     real, imag = [], []
+     for data in inputs:
+         r, i = torch.chunk(data, 2, axis)
+         real.append(r)
+         imag.append(i)
+     real = torch.cat(real, axis)
+     imag = torch.cat(imag, axis)
+     outputs = torch.cat([real, imag], axis)
+     return outputs
+
+
+ class ComplexConv2d(nn.Module):
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size=(1, 1),
+         stride=(1, 1),
+         padding=(0, 0),
+         dilation=1,
+         groups=1,
+         causal=True,
+         complex_axis=1,
+     ):
+         '''
+         in_channels: real+imag
+         out_channels: real+imag
+         kernel_size: for input [B, C, D, T], kernel size in [D, T]
+         padding: for input [B, C, D, T], padding in [D, T]
+         causal: if True, pad only the left side of the time dimension;
+                 otherwise pad both sides
+         '''
+         super(ComplexConv2d, self).__init__()
+         self.in_channels = in_channels // 2
+         self.out_channels = out_channels // 2
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.padding = padding
+         self.causal = causal
+         self.groups = groups
+         self.dilation = dilation
+         self.complex_axis = complex_axis
+         # the frequency axis is padded inside Conv2d; the time axis is padded
+         # manually in forward() so causal padding is possible
+         self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
+                                    padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
+         self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
+                                    padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
+
+         nn.init.normal_(self.real_conv.weight.data, std=0.05)
+         nn.init.normal_(self.imag_conv.weight.data, std=0.05)
+         nn.init.constant_(self.real_conv.bias, 0.)
+         nn.init.constant_(self.imag_conv.bias, 0.)
+
+     def forward(self, inputs):
+         if self.padding[1] != 0 and self.causal:
+             inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])  # causal: left-pad time only
+         else:
+             inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])
+
+         if self.complex_axis == 0:
+             # real/imag are stacked along the batch axis: run both convs on the
+             # full tensor, then split each result into its real/imag responses
+             real = self.real_conv(inputs)
+             imag = self.imag_conv(inputs)
+             real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
+             real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
+
+         else:
+             if isinstance(inputs, torch.Tensor):
+                 real, imag = torch.chunk(inputs, 2, self.complex_axis)
+
+             real2real = self.real_conv(real)
+             imag2imag = self.imag_conv(imag)
+
+             real2imag = self.imag_conv(real)
+             imag2real = self.real_conv(imag)
+
+         real = real2real - imag2imag
+         imag = real2imag + imag2real
+         out = torch.cat([real, imag], self.complex_axis)
+
+         return out
+
+
+ class ComplexConvTranspose2d(nn.Module):
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size=(1, 1),
+         stride=(1, 1),
+         padding=(0, 0),
+         output_padding=(0, 0),
+         causal=False,
+         complex_axis=1,
+         groups=1
+     ):
+         '''
+         in_channels: real+imag
+         out_channels: real+imag
+         (causal is accepted for API symmetry with ComplexConv2d but is unused here)
+         '''
+         super(ComplexConvTranspose2d, self).__init__()
+         self.in_channels = in_channels // 2
+         self.out_channels = out_channels // 2
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.padding = padding
+         self.output_padding = output_padding
+         self.groups = groups
+
+         self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
+                                             padding=self.padding, output_padding=output_padding, groups=self.groups)
+         self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
+                                             padding=self.padding, output_padding=output_padding, groups=self.groups)
+         self.complex_axis = complex_axis
+
+         nn.init.normal_(self.real_conv.weight, std=0.05)
+         nn.init.normal_(self.imag_conv.weight, std=0.05)
+         nn.init.constant_(self.real_conv.bias, 0.)
+         nn.init.constant_(self.imag_conv.bias, 0.)
+
+     def forward(self, inputs):
+
+         if isinstance(inputs, torch.Tensor):
+             real, imag = torch.chunk(inputs, 2, self.complex_axis)
+         elif isinstance(inputs, (tuple, list)):
+             real = inputs[0]
+             imag = inputs[1]
+         if self.complex_axis == 0:
+             real = self.real_conv(inputs)
+             imag = self.imag_conv(inputs)
+             real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
+             real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
+
+         else:
+             real2real = self.real_conv(real)
+             imag2imag = self.imag_conv(imag)
+
+             real2imag = self.imag_conv(real)
+             imag2real = self.real_conv(imag)
+
+         real = real2real - imag2imag
+         imag = real2imag + imag2real
+         out = torch.cat([real, imag], self.complex_axis)
+
+         return out
+
+
225
+ # Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
226
+ # from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55
227
+
228
+ class ComplexBatchNorm(torch.nn.Module):
229
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
230
+ track_running_stats=True, complex_axis=1):
231
+ super(ComplexBatchNorm, self).__init__()
232
+ self.num_features = num_features // 2
233
+ self.eps = eps
234
+ self.momentum = momentum
235
+ self.affine = affine
236
+ self.track_running_stats = track_running_stats
237
+
238
+ self.complex_axis = complex_axis
239
+
240
+ if self.affine:
241
+ self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
242
+ self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
243
+ self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
244
+ self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
245
+ self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
246
+ else:
247
+ self.register_parameter('Wrr', None)
248
+ self.register_parameter('Wri', None)
249
+ self.register_parameter('Wii', None)
250
+ self.register_parameter('Br', None)
251
+ self.register_parameter('Bi', None)
252
+
253
+ if self.track_running_stats:
254
+ self.register_buffer('RMr', torch.zeros(self.num_features))
255
+ self.register_buffer('RMi', torch.zeros(self.num_features))
256
+ self.register_buffer('RVrr', torch.ones(self.num_features))
257
+ self.register_buffer('RVri', torch.zeros(self.num_features))
258
+ self.register_buffer('RVii', torch.ones(self.num_features))
259
+ self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
260
+ else:
261
+ self.register_parameter('RMr', None)
262
+ self.register_parameter('RMi', None)
263
+ self.register_parameter('RVrr', None)
264
+ self.register_parameter('RVri', None)
265
+ self.register_parameter('RVii', None)
266
+ self.register_parameter('num_batches_tracked', None)
267
+ self.reset_parameters()
268
+
269
+ def reset_running_stats(self):
270
+ if self.track_running_stats:
271
+ self.RMr.zero_()
272
+ self.RMi.zero_()
273
+ self.RVrr.fill_(1)
274
+ self.RVri.zero_()
275
+ self.RVii.fill_(1)
276
+ self.num_batches_tracked.zero_()
277
+
278
+ def reset_parameters(self):
279
+ self.reset_running_stats()
280
+ if self.affine:
281
+ self.Br.data.zero_()
282
+ self.Bi.data.zero_()
283
+ self.Wrr.data.fill_(1)
284
+ self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
285
+ self.Wii.data.fill_(1)
286
+
287
+ def _check_input_dim(self, xr, xi):
288
+ assert (xr.shape == xi.shape)
289
+ assert (xr.size(1) == self.num_features)
290
+
291
+ def forward(self, inputs):
292
+ # self._check_input_dim(xr, xi)
293
+
294
+ xr, xi = torch.chunk(inputs, 2, axis=self.complex_axis)
295
+ exponential_average_factor = 0.0
296
+
297
+ if self.training and self.track_running_stats:
298
+ self.num_batches_tracked += 1
299
+ if self.momentum is None: # use cumulative moving average
300
+ exponential_average_factor = 1.0 / self.num_batches_tracked.item()
301
+ else: # use exponential moving average
302
+ exponential_average_factor = self.momentum
303
+
304
+ #
305
+ # NOTE: The precise meaning of the "training flag" is:
306
+ # True: Normalize using batch statistics, update running statistics
307
+ # if they are being collected.
308
+ # False: Normalize using running statistics, ignore batch statistics.
309
+ #
310
+ training = self.training or not self.track_running_stats
311
+ redux = [i for i in reversed(range(xr.dim())) if i != 1]
312
+ vdim = [1] * xr.dim()
313
+ vdim[1] = xr.size(1)
314
+
315
+ #
316
+ # Mean M Computation and Centering
317
+ #
318
+ # Includes running mean update if training and running.
319
+ #
320
+ if training:
321
+ Mr, Mi = xr, xi
322
+ for d in redux:
323
+ Mr = Mr.mean(d, keepdim=True)
324
+ Mi = Mi.mean(d, keepdim=True)
325
+ if self.track_running_stats:
326
+ self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
327
+ self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
328
+ else:
329
+ Mr = self.RMr.view(vdim)
330
+ Mi = self.RMi.view(vdim)
331
+ xr, xi = xr - Mr, xi - Mi
332
+
333
+ #
334
+ # Variance Matrix V Computation
335
+ #
336
+ # Includes epsilon numerical stabilizer/Tikhonov regularizer.
337
+ # Includes running variance update if training and running.
338
+ #
339
+ if training:
340
+ Vrr = xr * xr
341
+ Vri = xr * xi
342
+ Vii = xi * xi
343
+ for d in redux:
344
+ Vrr = Vrr.mean(d, keepdim=True)
345
+ Vri = Vri.mean(d, keepdim=True)
346
+ Vii = Vii.mean(d, keepdim=True)
347
+ if self.track_running_stats:
348
+ self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
349
+ self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
350
+ self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
351
+ else:
352
+ Vrr = self.RVrr.view(vdim)
353
+ Vri = self.RVri.view(vdim)
354
+ Vii = self.RVii.view(vdim)
355
+ Vrr = Vrr + self.eps
356
+ Vri = Vri
357
+ Vii = Vii + self.eps
358
+
359
+ #
360
+ # Matrix Inverse Square Root U = V^-0.5
361
+ #
362
+ # sqrt of a 2x2 matrix,
363
+ # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
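+         # Recap of the closed form used below (a standard 2x2 identity, stated
+         # here for reference): for symmetric positive-definite
+         #     V = [[Vrr, Vri], [Vri, Vii]],
+         # let tau = trace(V), delta = det(V), s = sqrt(delta), t = sqrt(tau + 2s).
+         # Then sqrt(V) = (V + s*I) / t, and inverting that 2x2 matrix gives
+         #     V^(-1/2) = 1/(s*t) * [[Vii + s, -Vri], [-Vri, Vrr + s]],
+         # which is exactly the Urr/Uri/Uii computed below.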
+         tau = Vrr + Vii
+         delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)  # det(V) = Vrr*Vii - Vri^2
+         s = delta.sqrt()
+         t = (tau + 2 * s).sqrt()
+
+         # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
+         rst = (s * t).reciprocal()
+         Urr = (s + Vii) * rst
+         Uii = (s + Vrr) * rst
+         Uri = (- Vri) * rst
+
+         # Optionally left-multiply U by affine weights W to produce combined
+         # weights Z, left-multiply the inputs by Z, then optionally bias them.
+         #
+         # y = Zx + B
+         # y = WUx + B
+         # y = [Wrr Wri][Urr Uri] [xr] + [Br]
+         #     [Wir Wii][Uir Uii] [xi]   [Bi]
+         if self.affine:
+             Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
+             Zrr = (Wrr * Urr) + (Wri * Uri)
+             Zri = (Wrr * Uri) + (Wri * Uii)
+             Zir = (Wri * Urr) + (Wii * Uri)
+             Zii = (Wri * Uri) + (Wii * Uii)
+         else:
+             Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
+
+         yr = (Zrr * xr) + (Zri * xi)
+         yi = (Zir * xr) + (Zii * xi)
+
+         if self.affine:
+             yr = yr + self.Br.view(vdim)
+             yi = yi + self.Bi.view(vdim)
+
+         outputs = torch.cat([yr, yi], self.complex_axis)
+         return outputs
+
+     def extra_repr(self):
+         return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
+                'track_running_stats={track_running_stats}'.format(**self.__dict__)
+
+
+ if __name__ == '__main__':
+     # quick parity check against a reference implementation; requires the
+     # external dc_crn7 module providing the same layer classes
+     import dc_crn7
+
+     torch.manual_seed(20)
+     onet1 = dc_crn7.ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
+     onet2 = dc_crn7.ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
+     inputs = torch.randn([1, 12, 12, 10])
+     nnet1 = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
+     nnet2 = ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
+     print(torch.mean(nnet1(inputs) - onet1(inputs)))
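
The layers above all share one packing convention: real and imaginary parts live in a single tensor and are split along `complex_axis`. A minimal shape-check sketch of that convention (hypothetical usage, assuming complexnn.py is importable; the sizes are illustrative only):

    import torch
    from complexnn import cPReLU, NavieComplexLSTM, ComplexConv2d, ComplexBatchNorm

    x = torch.randn(1, 12, 12, 10)  # [B, C, D, T]: 6 real + 6 imag channels
    conv = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
    bn = ComplexBatchNorm(12)
    act = cPReLU(complex_axis=1)
    y = act(bn(conv(x)))
    print(y.shape)  # expected torch.Size([1, 12, 14, 10]); padding[0]=2 grows D

    seq = torch.randn(10, 1, 128)   # [T, B, F]: 64 real + 64 imag features
    lstm = NavieComplexLSTM(128, 256)
    real_out, imag_out = lstm(seq)  # each expected [10, 1, 128]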
conv_stft.py ADDED
@@ -0,0 +1,164 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import torch.nn.functional as F
+ from scipy.signal import get_window
+
+
+ def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
+     # build (i)STFT kernels: each row of the Fourier basis becomes one conv filter
+     if win_type == 'None' or win_type is None:
+         window = np.ones(win_len)
+     else:
+         window = get_window(win_type, win_len, fftbins=True)  # **0.5
+
+     N = fft_len
+     fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
+     real_kernel = np.real(fourier_basis)
+     imag_kernel = np.imag(fourier_basis)
+     kernel = np.concatenate([real_kernel, imag_kernel], 1).T
+
+     if invers:
+         # the pseudo-inverse of the analysis basis gives the synthesis basis
+         kernel = np.linalg.pinv(kernel).T
+
+     kernel = kernel * window
+     kernel = kernel[:, None, :]
+     return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))
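+
+ # Shape note (follows from the code above): the returned kernel has shape
+ # [fft_len + 2, 1, win_len] -- fft_len//2 + 1 real rows stacked over the same
+ # number of imaginary rows -- so F.conv1d maps [B, 1, T] to [B, fft_len + 2, T'].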
+
+
+ class ConvSTFT(nn.Module):
+
+     def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
+         super(ConvSTFT, self).__init__()
+
+         if fft_len is None:
+             self.fft_len = int(2 ** np.ceil(np.log2(win_len)))  # next power of two
+         else:
+             self.fft_len = fft_len
+
+         kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
+         # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
+         self.register_buffer('weight', kernel)
+         self.feature_type = feature_type
+         self.stride = win_inc
+         self.win_len = win_len
+         self.dim = self.fft_len
+
+     def forward(self, inputs):
+         if inputs.dim() == 2:
+             inputs = torch.unsqueeze(inputs, 1)
+         inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
+         outputs = F.conv1d(inputs, self.weight, stride=self.stride)
+
+         if self.feature_type == 'complex':
+             return outputs
+         else:
+             dim = self.dim // 2 + 1
+             real = outputs[:, :dim, :]
+             imag = outputs[:, dim:, :]
+             mags = torch.sqrt(real ** 2 + imag ** 2)
+             phase = torch.atan2(imag, real)
+             return mags, phase
+
+
+ class ConviSTFT(nn.Module):
+
+     def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
+         super(ConviSTFT, self).__init__()
+         if fft_len is None:
+             self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
+         else:
+             self.fft_len = fft_len
+         kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
+         # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
+         self.register_buffer('weight', kernel)
+         self.feature_type = feature_type
+         self.win_type = win_type
+         self.win_len = win_len
+         self.stride = win_inc
+         self.dim = self.fft_len
+         self.register_buffer('window', window)
+         self.register_buffer('enframe', torch.eye(win_len)[:, None, :])
+
+     def forward(self, inputs, phase=None):
+         """
+         inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
+         phase  : [B, N//2+1, T] (if not None)
+         """
+
+         if phase is not None:
+             real = inputs * torch.cos(phase)
+             imag = inputs * torch.sin(phase)
+             inputs = torch.cat([real, imag], 1)
+         outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
+
+         # this is from torch-stft: https://github.com/pseeth/torch-stft
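+         # The correction below overlap-adds the squared analysis window at each
+         # hop (enframe is an identity basis, so conv_transpose1d performs the
+         # overlap-add), giving the sum-of-squared-windows envelope that scaled
+         # every output sample; dividing by it (plus an epsilon) undoes the
+         # windowing, analogous to librosa's window_sumsquare correction.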
+         t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
+         coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
+         outputs = outputs / (coff + 1e-8)
+         # outputs = torch.where(coff == 0, outputs, outputs/coff)
+         # trim the win_len - stride samples of padding added by ConvSTFT
+         outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
+
+         return outputs
+
+
+ def test_fft():
+     torch.manual_seed(20)
+     win_len = 320
+     win_inc = 160
+     fft_len = 512
+     inputs = torch.randn([1, 1, 16000 * 4])
+     fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real')
+     import librosa
+
+     outputs1 = fft(inputs)[0]
+     outputs1 = outputs1.numpy()[0]
+     np_inputs = inputs.numpy().reshape([-1])
+     librosa_stft = librosa.stft(np_inputs, win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=False)
+     librosa_mags = np.abs(librosa_stft)
+     # ConvSTFT left-pads by win_len - win_inc (one hop here), so its frame k+1
+     # lines up with librosa's frame k; trim accordingly before comparing
+     print(np.mean((outputs1[:, 1:1 + librosa_mags.shape[1]] - librosa_mags) ** 2))
+
+
+ def test_ifft1():
+     import soundfile as sf
+     N = 400
+     inc = 100
+     fft_len = 512
+     torch.manual_seed(N)
+     data = np.random.randn(16000 * 8)[None, None, :]
+     # data = sf.read('../ori.wav')[0]
+     inputs = data.reshape([1, 1, -1])
+     fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
+     ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
+     inputs = torch.from_numpy(inputs.astype(np.float32))
+     outputs1 = fft(inputs)
+     print(outputs1.shape)
+     outputs2 = ifft(outputs1)
+     sf.write('conv_stft.wav', outputs2.numpy()[0, 0, :], 16000)
+     print('wav MSE', torch.mean(torch.abs(inputs[..., :outputs2.size(2)] - outputs2) ** 2))
+
+
+ def test_ifft2():
+     N = 400
+     inc = 100
+     fft_len = 512
+     np.random.seed(20)
+     torch.manual_seed(20)
+     t = np.random.randn(16000 * 4) * 0.001
+     t = np.clip(t, -1, 1)
+     # input = torch.randn([1,16000*4])
+     input = torch.from_numpy(t[None, None, :].astype(np.float32))
+
+     fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
+     ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
+
+     out1 = fft(input)
+     output = ifft(out1)
+     print('random MSE', torch.mean(torch.abs(input - output) ** 2))
+     import soundfile as sf
+     sf.write('zero.wav', output[0, 0].numpy(), 16000)
+
+
+ if __name__ == '__main__':
+     # test_fft()
+     test_ifft1()
+     # test_ifft2()
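
For reference, a round-trip sketch of the two modules together (hypothetical usage, assuming conv_stft.py is importable; with feature_type='complex' the packed spectrum feeds straight into ConviSTFT):

    import torch
    from conv_stft import ConvSTFT, ConviSTFT

    win_len, win_inc, fft_len = 400, 100, 512
    stft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='complex')
    istft = ConviSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='complex')

    x = torch.randn(1, 1, 16000)  # 1 s of audio at 16 kHz
    spec = stft(x)                # [1, fft_len + 2, T]: real rows over imag rows
    y = istft(spec)               # same length as x once the padding is trimmed
    print(torch.mean((x - y) ** 2))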