groadabike committed
Commit de2fbc6
1 Parent(s): 3046f2a

Upload tasnet.py

Files changed (1): tasnet.py +532 -0
tasnet.py ADDED
@@ -0,0 +1,532 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Created on 2018/12
# Author: Kaituo XU
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
# Here is the original license:
# The MIT License (MIT)
#
# Copyright (c) 2018 Kaituo XU
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

EPS = 1e-8


def overlap_and_add(signal, frame_step):
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = Greatest Common Divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes, device=signal.device).unfold(
        0, subframes_per_frame, subframe_step
    )
    frame = frame.long()  # signal may be on GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result

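# Illustrative sketch (added for exposition, not part of the original upload):
# three frames of length 4 taken every 2 samples (50% overlap) reconstruct a
# signal of length 2 * (3 - 1) + 4 = 8, with overlapping samples summed:
#
#   frames = torch.arange(12.0).view(1, 3, 4)  # [batch, frames, frame_length]
#   overlap_and_add(frames, 2)                 # -> tensor of shape [1, 8]
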

class ConvTasNetStereo(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        N=256,
        L=20,
        B=256,
        H=512,
        P=3,
        X=8,
        R=4,
        C=2,
        audio_channels=2,
        samplerate=44100,
        norm_type="gLN",
        causal=False,
        mask_nonlinear="relu",
    ):
        """
        Args:
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            audio_channels: Number of audio channels (e.g. 2 for stereo)
            samplerate: Sample rate of the audio in Hz
            norm_type: BN, gLN or cLN
            causal: causal or non-causal
            mask_nonlinear: non-linear function used to generate the mask
        """
        super(ConvTasNetStereo, self).__init__()
        # Hyper-parameters
        self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = (
            N,
            L,
            B,
            H,
            P,
            X,
            R,
            C,
        )
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        self.audio_channels = audio_channels
        self.samplerate = samplerate
        # Components
        self.encoder = Encoder(L, N, audio_channels)
        self.separator = TemporalConvNet(
            N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear
        )
        self.decoder = Decoder(N, L, audio_channels)
        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def valid_length(self, length):
        return length

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # T changed after conv1d in encoder, fix it here
        T_origin = mixture.size(-1)
        T_conv = est_source.size(-1)
        est_source = F.pad(est_source, (0, T_origin - T_conv))
        return est_source

    def serialize(self):
        """Serialize model and output dictionary.

        Returns:
            dict, serialized model with keys `model_args` and `state_dict`.
        """
        import pytorch_lightning as pl  # Not used in torch.hub

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
            asteroid_version="0.7.0",
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Arguments needed to re-instantiate the model."""
        return dict(
            N=self.N,
            L=self.L,
            B=self.B,
            H=self.H,
            P=self.P,
            X=self.X,
            R=self.R,
            C=self.C,
            audio_channels=self.audio_channels,
            samplerate=self.samplerate,
            norm_type=self.norm_type,
            causal=self.causal,
            mask_nonlinear=self.mask_nonlinear,
        )

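# Usage sketch (illustrative; the repository id below is a placeholder, not
# something defined in this file). PyTorchModelHubMixin provides
# `from_pretrained` / `push_to_hub` for this class:
#
#   model = ConvTasNetStereo.from_pretrained("<user>/<repo>")
#   mix = torch.randn(1, model.audio_channels, model.samplerate)  # 1 s of audio
#   with torch.no_grad():
#       sources = model(mix)  # [1, C, audio_channels, T]
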

class Encoder(nn.Module):
    """Estimation of the nonnegative mixture weight by a 1-D conv layer."""

    def __init__(self, L, N, audio_channels):
        super(Encoder, self).__init__()
        # Hyper-parameters
        self.L, self.N = L, N
        # Components
        # 50% overlap
        self.conv1d_U = nn.Conv1d(
            audio_channels, N, kernel_size=L, stride=L // 2, bias=False
        )

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], where K = (T - L) / (L / 2) + 1 = 2T / L - 1
        """
        mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
        return mixture_w

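# Shape sketch (illustrative numbers, not from the original file): with the
# defaults L = 20 and stride L // 2 = 10, one second of 44.1 kHz stereo audio
# [1, 2, 44100] encodes to K = (44100 - 20) // 10 + 1 = 4409 frames, i.e. a
# mixture_w of shape [1, N, 4409].
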

class Decoder(nn.Module):
    def __init__(self, N, L, audio_channels):
        super(Decoder, self).__init__()
        # Hyper-parameters
        self.N, self.L = N, L
        self.audio_channels = audio_channels
        # Components
        self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)

    def forward(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [M, N, K]
            est_mask: [M, C, N, K]
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        # D = W * M
        source_w = torch.unsqueeze(mixture_w, 1) * est_mask  # [M, C, N, K]
        source_w = torch.transpose(source_w, 2, 3)  # [M, C, K, N]
        # S = DV
        est_source = self.basis_signals(source_w)  # [M, C, K, ac * L]
        m, c, k, _ = est_source.size()
        est_source = (
            est_source.view(m, c, k, self.audio_channels, -1)
            .transpose(2, 3)
            .contiguous()
        )
        est_source = overlap_and_add(est_source, self.L // 2)  # M x C x ac x T
        return est_source

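# Shape walkthrough (illustrative): for each masked source, the K encoder
# frames [M, C, K, N] are mapped by `basis_signals` to K waveform frames of
# length audio_channels * L, reshaped to [M, C, audio_channels, K, L], and
# overlap-added with hop L // 2, giving T = (K - 1) * L // 2 + L samples.
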

class TemporalConvNet(nn.Module):
    def __init__(
        self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"
    ):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN or cLN
            causal: causal or non-causal
            mask_nonlinear: non-linear function used to generate the mask
        """
        super(TemporalConvNet, self).__init__()
        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        # Components
        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        for r in range(R):
            blocks = []
            for x in range(X):
                dilation = 2**x
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
                blocks += [
                    TemporalBlock(
                        B,
                        H,
                        P,
                        stride=1,
                        padding=padding,
                        dilation=dilation,
                        norm_type=norm_type,
                        causal=causal,
                    )
                ]
            repeats += [nn.Sequential(*blocks)]
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]
        mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
        # Put together
        self.network = nn.Sequential(
            layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1
        )

    def forward(self, mixture_w):
        """
        Keep this API the same as TasNet.
        Args:
            mixture_w: [M, N, K], M is batch size
        Returns:
            est_mask: [M, C, N, K]
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
        score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask

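# Receptive-field sketch (illustrative arithmetic, not from the original
# file): each dilated block adds (P - 1) * 2**x frames of context, so the
# stack spans 1 + R * (P - 1) * (2**X - 1) frames. With the defaults
# P = 3, X = 8, R = 4 that is 1 + 4 * 2 * 255 = 2041 encoder frames; at a
# hop of L // 2 = 10 samples this is roughly 20410 samples, about 0.46 s
# of audio at 44.1 kHz.
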

class TemporalBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super(TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]
        dsconv = DepthwiseSeparableConv(
            out_channels,
            in_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            norm_type,
            causal,
        )
        # Put together
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        residual = x
        out = self.net(x)
        # TODO: when P = 3 this works fine, but when P = 2 we may need to pad
        return out + residual  # it seems to work better without F.relu here
        # return F.relu(out + residual)


class DepthwiseSeparableConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super(DepthwiseSeparableConv, self).__init__()
        # Use the `groups` option to implement depthwise convolution
        # [M, H, K] -> [M, H, K]
        depthwise_conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        if causal:
            chomp = Chomp1d(padding)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Put together
        if causal:
            self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)


class Chomp1d(nn.Module):
    """To ensure the output length is the same as the input."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K]
        """
        return x[:, :, : -self.chomp_size].contiguous()


def chose_norm(norm_type, channel_size):
    """The input of normalization will be (M, C, K), where M is batch size,
    C is channel size and K is sequence length.
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    elif norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    elif norm_type == "id":
        return nn.Identity()
    else:  # norm_type == "BN"
        # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
        # along M and K, so this BN usage is correct.
        return nn.BatchNorm1d(channel_size)


# TODO: Use nn.LayerNorm to implement cLN to speed up
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN)"""

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return cLN_y


class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN)"""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        # TODO: since torch 1.0, torch.mean() supports a list of dims
        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
        var = (
            (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        )
        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return gLN_y

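# For reference (added for exposition): both norms compute
#   norm(y) = gamma * (y - E[y]) / sqrt(Var[y] + EPS) + beta,
# but cLN takes E[.] and Var[.] over the channel dim only (per time step),
# while gLN takes them jointly over the channel and time dims, so gLN mixes
# information across the whole utterance and is unsuitable for causal use.
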

if __name__ == "__main__":
    torch.manual_seed(123)
    M, N, L, T = 2, 3, 4, 12
    K = 2 * T // L - 1
    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
    audio_channels = 2
    mixture = torch.randint(3, (M, audio_channels, T)).float()
    # test Encoder
    encoder = Encoder(L, N, audio_channels)
    encoder.conv1d_U.weight.data = torch.randint(
        2, encoder.conv1d_U.weight.size()
    ).float()
    mixture_w = encoder(mixture)
    print("mixture", mixture)
    print("U", encoder.conv1d_U.weight)
    print("mixture_w", mixture_w)
    print("mixture_w size", mixture_w.size())

    # test TemporalConvNet
    separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
    est_mask = separator(mixture_w)
    print("est_mask", est_mask)

    # test Decoder
    decoder = Decoder(N, L, audio_channels)
    est_mask = torch.randint(2, (M, C, N, K)).float()
    est_source = decoder(mixture_w, est_mask)
    print("est_source", est_source)

    # test Conv-TasNet
    conv_tasnet = ConvTasNetStereo(
        N, L, B, H, P, X, R, C, audio_channels=audio_channels, norm_type=norm_type
    )
    est_source = conv_tasnet(mixture)
    print("est_source", est_source)
    print("est_source size", est_source.size())