mattricesound committed
Commit ca6b6f7
1 Parent(s): 79a7f1b

Remove previous DPTNet/DCUNet implementations

cfg/model/dptnet.yaml CHANGED
@@ -10,12 +10,13 @@ model:
   network:
     _target_: remfx.models.DPTNetModel
     n_src: 1
-    bn_chan: 128
-    hid_size: 128
+    in_chan: 64
+    out_chan: 64
     chunk_size: 100
-    n_repeats: 6
+    n_repeats: 2
     fb_name: "free"
     kernel_size: 16
-    n_filters: 1
+    n_filters: 64
+    stride: 8
     sample_rate: ${sample_rate}
     num_bins: 1025
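Reviewer note: the new keys (in_chan/out_chan, n_filters, stride, and the lower n_repeats) look like the filterbank and masker arguments of an Asteroid-style DPTNet rather than the removed custom implementation. A minimal, hypothetical sketch of how Hydra would turn this config into a network; the sample_rate value is assumed, and it presumes hydra-core/omegaconf are installed and that remfx.models.DPTNetModel accepts these keyword arguments:

# Hypothetical sketch, not repo code.
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.create(
    """
    sample_rate: 48000        # assumed value; normally interpolated from the top-level config
    network:
      _target_: remfx.models.DPTNetModel
      n_src: 1
      in_chan: 64
      out_chan: 64
      chunk_size: 100
      n_repeats: 2
      fb_name: "free"
      kernel_size: 16
      n_filters: 64
      stride: 8
      sample_rate: ${sample_rate}
      num_bins: 1025
    """
)
network = instantiate(cfg.network)  # resolves _target_ and passes the remaining keys as kwargs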
cfg/model/tcn.yaml CHANGED
@@ -13,7 +13,7 @@ model:
   noutputs: 1
   nblocks: 20
   channel_growth: 0
-  channel_width: 32
+  channel_width: 64
   kernel_size: 7
   stack_size: 10
   dilation_growth: 2
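Reviewer note: doubling channel_width from 32 to 64 only widens the convolutions (more parameters); it does not change the receptive field, which is set by nblocks, kernel_size, dilation_growth, and stack_size. A rough check under the common dilated-TCN convention, which is an assumption here (the remfx TCN may compute dilations differently):

# Back-of-the-envelope receptive-field check (assumption: dilation = dilation_growth ** (block % stack_size)).
nblocks, kernel_size, dilation_growth, stack_size = 20, 7, 2, 10

receptive_field = 1
for n in range(nblocks):
    dilation = dilation_growth ** (n % stack_size)
    receptive_field += (kernel_size - 1) * dilation

print(receptive_field)  # 12277 samples with these settings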
remfx/datasets.py CHANGED
@@ -295,7 +295,7 @@ class EffectDataset(Dataset):
 
         # Up to max_kept_effects
         if self.max_kept_effects != -1:
-            num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects)) + 1
+            num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects))
         else:
            num_kept_effects = len(self.effects_to_keep)
        effect_indices = effect_indices[:num_kept_effects]
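Reviewer note: dropping the `+ 1` shifts the sampled range down by one. The old line drew num_kept_effects uniformly from {1, ..., max_kept_effects}; the new line draws from {0, ..., max_kept_effects - 1}, so keeping zero effects is now possible and the maximum count is never drawn. A small illustration (reviewer sketch, not repo code):

import torch

max_kept_effects = 4

# old line: values in {1, 2, 3, 4}
old = int(torch.rand(1).item() * max_kept_effects) + 1

# new line: values in {0, 1, 2, 3}
new = int(torch.rand(1).item() * max_kept_effects)

# the new expression behaves like an integer draw over [0, max_kept_effects):
new_equiv = torch.randint(0, max_kept_effects, (1,)).item()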
remfx/dcunet.py DELETED
@@ -1,649 +0,0 @@
1
- # Adapted from https://github.com/AppleHolic/source_separation/tree/master/source_separation
2
-
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- import numpy as np
8
- from torch.nn.init import calculate_gain
9
- from typing import Tuple
10
- from scipy.signal import get_window
11
- from librosa.util import pad_center
12
- from remfx.utils import single, concat_complex
13
-
14
-
15
- class ComplexConvBlock(nn.Module):
16
- """
17
- Convolution block
18
- """
19
-
20
- def __init__(
21
- self,
22
- in_channels: int,
23
- out_channels: int,
24
- kernel_size: int,
25
- padding: int = 0,
26
- layers: int = 4,
27
- bn_func=nn.BatchNorm1d,
28
- act_func=nn.LeakyReLU,
29
- skip_res: bool = False,
30
- ):
31
- super().__init__()
32
- # modules
33
- self.blocks = nn.ModuleList()
34
- self.skip_res = skip_res
35
-
36
- for idx in range(layers):
37
- in_ = in_channels if idx == 0 else out_channels
38
- self.blocks.append(
39
- nn.Sequential(
40
- *[
41
- bn_func(in_),
42
- act_func(),
43
- ComplexConv1d(in_, out_channels, kernel_size, padding=padding),
44
- ]
45
- )
46
- )
47
-
48
- def forward(self, x: torch.tensor) -> torch.tensor:
49
- temp = x
50
- for idx, block in enumerate(self.blocks):
51
- x = block(x)
52
-
53
- if temp.size() != x.size() or self.skip_res:
54
- return x
55
- else:
56
- return x + temp
57
-
58
-
59
- class SpectrogramUnet(nn.Module):
60
- def __init__(
61
- self,
62
- spec_dim: int,
63
- hidden_dim: int,
64
- filter_len: int,
65
- hop_len: int,
66
- layers: int = 3,
67
- block_layers: int = 3,
68
- kernel_size: int = 5,
69
- is_mask: bool = False,
70
- norm: str = "bn",
71
- act: str = "tanh",
72
- ):
73
- super().__init__()
74
- self.layers = layers
75
- self.is_mask = is_mask
76
-
77
- # stft modules
78
- self.stft = STFT(filter_len, hop_len)
79
-
80
- if norm == "bn":
81
- self.bn_func = nn.BatchNorm1d
82
- elif norm == "ins":
83
- self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
84
- else:
85
- raise NotImplementedError("{} is not implemented !".format(norm))
86
-
87
- if act == "tanh":
88
- self.act_func = nn.Tanh
89
- self.act_out = nn.Tanh
90
- elif act == "comp":
91
- self.act_func = ComplexActLayer
92
- self.act_out = lambda: ComplexActLayer(is_out=True)
93
- else:
94
- raise NotImplementedError("{} is not implemented !".format(act))
95
-
96
- # prev conv
97
- self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)
98
-
99
- # down
100
- self.down = nn.ModuleList()
101
- self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
102
- for idx in range(self.layers):
103
- block = ComplexConvBlock(
104
- hidden_dim,
105
- hidden_dim,
106
- kernel_size=kernel_size,
107
- padding=kernel_size // 2,
108
- bn_func=self.bn_func,
109
- act_func=self.act_func,
110
- layers=block_layers,
111
- )
112
- self.down.append(block)
113
-
114
- # up
115
- self.up = nn.ModuleList()
116
- for idx in range(self.layers):
117
- in_c = hidden_dim if idx == 0 else hidden_dim * 2
118
- self.up.append(
119
- nn.Sequential(
120
- ComplexConvBlock(
121
- in_c,
122
- hidden_dim,
123
- kernel_size=kernel_size,
124
- padding=kernel_size // 2,
125
- bn_func=self.bn_func,
126
- act_func=self.act_func,
127
- layers=block_layers,
128
- ),
129
- self.bn_func(hidden_dim),
130
- self.act_func(),
131
- ComplexTransposedConv1d(
132
- hidden_dim, hidden_dim, kernel_size=2, stride=2
133
- ),
134
- )
135
- )
136
-
137
- # out_conv
138
- self.out_conv = nn.Sequential(
139
- ComplexConvBlock(
140
- hidden_dim * 2,
141
- spec_dim * 2,
142
- kernel_size=kernel_size,
143
- padding=kernel_size // 2,
144
- bn_func=self.bn_func,
145
- act_func=self.act_func,
146
- ),
147
- self.bn_func(spec_dim * 2),
148
- self.act_func(),
149
- )
150
-
151
- # refine conv
152
- self.refine_conv = nn.Sequential(
153
- ComplexConvBlock(
154
- spec_dim * 4,
155
- spec_dim * 2,
156
- kernel_size=kernel_size,
157
- padding=kernel_size // 2,
158
- bn_func=self.bn_func,
159
- act_func=self.act_func,
160
- ),
161
- self.bn_func(spec_dim * 2),
162
- self.act_func(),
163
- )
164
-
165
- def log_stft(self, wav):
166
- # stft
167
- mag, phase = self.stft.transform(wav)
168
- return torch.log(mag + 1), phase
169
-
170
- def exp_istft(self, log_mag, phase):
171
- # exp
172
- mag = np.e**log_mag - 1
173
- # istft
174
- wav = self.stft.inverse(mag, phase)
175
- return wav
176
-
177
- def adjust_diff(self, x, target):
178
- size_diff = target.size()[-1] - x.size()[-1]
179
- assert size_diff >= 0
180
- if size_diff > 0:
181
- x = F.pad(
182
- x.unsqueeze(1), (size_diff // 2, size_diff // 2), "reflect"
183
- ).squeeze(1)
184
- return x
185
-
186
- def masking(self, mag, phase, origin_mag, origin_phase):
187
- abs_mag = torch.abs(mag)
188
- mag_mask = torch.tanh(abs_mag)
189
- phase_mask = mag / abs_mag
190
-
191
- # masking
192
- mag = mag_mask * origin_mag
193
- phase = phase_mask * (origin_phase + phase)
194
- return mag, phase
195
-
196
- def forward(self, wav):
197
- # stft
198
- origin_mag, origin_phase = self.log_stft(wav)
199
- origin_x = torch.cat([origin_mag, origin_phase], dim=1)
200
-
201
- # prev
202
- x = self.prev_conv(origin_x)
203
-
204
- # body
205
- # down
206
- down_cache = []
207
- for idx, block in enumerate(self.down):
208
- x = block(x)
209
- down_cache.append(x)
210
- x = self.down_pool(x)
211
-
212
- # up
213
- for idx, block in enumerate(self.up):
214
- x = block(x)
215
- res = F.interpolate(
216
- down_cache[self.layers - (idx + 1)],
217
- size=[x.size()[2]],
218
- mode="linear",
219
- align_corners=False,
220
- )
221
- x = concat_complex(x, res, dim=1)
222
-
223
- # match spec dimension
224
- x = self.out_conv(x)
225
- if origin_mag.size(2) != x.size(2):
226
- x = F.interpolate(
227
- x, size=[origin_mag.size(2)], mode="linear", align_corners=False
228
- )
229
-
230
- # refine
231
- x = self.refine_conv(concat_complex(x, origin_x))
232
-
233
- def to_wav(stft):
234
- mag, phase = stft.chunk(2, 1)
235
- if self.is_mask:
236
- mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
237
- out = self.exp_istft(mag, phase)
238
- out = self.adjust_diff(out, wav)
239
- return out
240
-
241
- refine_wav = to_wav(x)
242
-
243
- return refine_wav
244
-
245
-
246
- class RefineSpectrogramUnet(SpectrogramUnet):
247
- def __init__(
248
- self,
249
- spec_dim: int,
250
- hidden_dim: int,
251
- filter_len: int,
252
- hop_len: int,
253
- layers: int = 4,
254
- block_layers: int = 4,
255
- kernel_size: int = 3,
256
- is_mask: bool = True,
257
- norm: str = "ins",
258
- act: str = "comp",
259
- refine_layers: int = 1,
260
- add_spec_results: bool = False,
261
- ):
262
- super().__init__(
263
- spec_dim,
264
- hidden_dim,
265
- filter_len,
266
- hop_len,
267
- layers,
268
- block_layers,
269
- kernel_size,
270
- is_mask,
271
- norm,
272
- act,
273
- )
274
- self.add_spec_results = add_spec_results
275
- # refine conv
276
- self.refine_conv = nn.ModuleList(
277
- [
278
- nn.Sequential(
279
- ComplexConvBlock(
280
- spec_dim * 2,
281
- spec_dim * 2,
282
- kernel_size=kernel_size,
283
- padding=kernel_size // 2,
284
- bn_func=self.bn_func,
285
- act_func=self.act_func,
286
- ),
287
- self.bn_func(spec_dim * 2),
288
- self.act_func(),
289
- )
290
- ]
291
- * refine_layers
292
- )
293
-
294
- def forward(self, wav):
295
- # stft
296
- origin_mag, origin_phase = self.log_stft(wav)
297
- origin_x = torch.cat([origin_mag, origin_phase], dim=1)
298
-
299
- # prev
300
- x = self.prev_conv(origin_x)
301
-
302
- # body
303
- # down
304
- down_cache = []
305
- for idx, block in enumerate(self.down):
306
- x = block(x)
307
- down_cache.append(x)
308
- x = self.down_pool(x)
309
-
310
- # up
311
- for idx, block in enumerate(self.up):
312
- x = block(x)
313
- res = F.interpolate(
314
- down_cache[self.layers - (idx + 1)],
315
- size=[x.size()[2]],
316
- mode="linear",
317
- align_corners=False,
318
- )
319
- x = concat_complex(x, res, dim=1)
320
-
321
- # match spec dimension
322
- x = self.out_conv(x)
323
- if origin_mag.size(2) != x.size(2):
324
- x = F.interpolate(
325
- x, size=[origin_mag.size(2)], mode="linear", align_corners=False
326
- )
327
-
328
- # refine
329
- for idx, refine_module in enumerate(self.refine_conv):
330
- x = refine_module(x)
331
- mag, phase = x.chunk(2, 1)
332
- mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
333
- if idx < len(self.refine_conv) - 1:
334
- x = torch.cat([mag, phase], dim=1)
335
-
336
- # clamp phase
337
- phase = phase.clamp(-np.pi, np.pi)
338
-
339
- out = self.exp_istft(mag, phase)
340
- out = self.adjust_diff(out, wav)
341
-
342
- if self.add_spec_results:
343
- out = (out, mag, phase)
344
-
345
- return out
346
-
347
-
348
- class _ComplexConvNd(nn.Module):
349
- """
350
- Implement Complex Convolution
351
- A: real weight
352
- B: img weight
353
- """
354
-
355
- def __init__(
356
- self,
357
- in_channels,
358
- out_channels,
359
- kernel_size,
360
- stride,
361
- padding,
362
- dilation,
363
- transposed,
364
- output_padding,
365
- ):
366
- super().__init__()
367
- self.in_channels = in_channels
368
- self.out_channels = out_channels
369
- self.kernel_size = kernel_size
370
- self.stride = stride
371
- self.padding = padding
372
- self.dilation = dilation
373
- self.output_padding = output_padding
374
- self.transposed = transposed
375
-
376
- self.A = self.make_weight(in_channels, out_channels, kernel_size)
377
- self.B = self.make_weight(in_channels, out_channels, kernel_size)
378
-
379
- self.reset_parameters()
380
-
381
- def make_weight(self, in_ch, out_ch, kernel_size):
382
- if self.transposed:
383
- tensor = nn.Parameter(torch.Tensor(in_ch, out_ch // 2, *kernel_size))
384
- else:
385
- tensor = nn.Parameter(torch.Tensor(out_ch, in_ch // 2, *kernel_size))
386
- return tensor
387
-
388
- def reset_parameters(self):
389
- # init real weight
390
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.A)
391
-
392
- # init A
393
- gain = calculate_gain("leaky_relu", 0)
394
- std = gain / np.sqrt(fan_in)
395
- bound = np.sqrt(3.0) * std
396
-
397
- with torch.no_grad():
398
- # TODO: find more stable initial values
399
- self.A.uniform_(-bound * (1 / (np.pi**2)), bound * (1 / (np.pi**2)))
400
- #
401
- # B is initialized by pi
402
- # -pi and pi is too big, so it is powed by -1
403
- self.B.uniform_(-1 / np.pi, 1 / np.pi)
404
-
405
-
406
- class ComplexConv1d(_ComplexConvNd):
407
- """
408
- Complex Convolution 1d
409
- """
410
-
411
- def __init__(
412
- self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
413
- ):
414
- kernel_size = single(kernel_size)
415
- stride = single(stride)
416
- # edit padding
417
- padding = padding
418
- dilation = single(dilation)
419
- super(ComplexConv1d, self).__init__(
420
- in_channels,
421
- out_channels,
422
- kernel_size,
423
- stride,
424
- padding,
425
- dilation,
426
- False,
427
- single(0),
428
- )
429
-
430
- def forward(self, x):
431
- """
432
- Implemented complex convolution using combining 'grouped convolution' and
433
- 'real / img weight'
434
- :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
435
- :return: complex conved result
436
- """
437
- # adopt reflect padding
438
- if self.padding:
439
- x = F.pad(x, (self.padding, self.padding), "reflect")
440
-
441
- # forward real
442
- real_part = F.conv1d(
443
- x,
444
- self.A,
445
- None,
446
- stride=self.stride,
447
- padding=0,
448
- dilation=self.dilation,
449
- groups=2,
450
- )
451
-
452
- # forward idea
453
- spl = self.in_channels // 2
454
- weight_B = torch.cat([self.B[:spl].data * (-1), self.B[spl:].data])
455
- idea_part = F.conv1d(
456
- x,
457
- weight_B,
458
- None,
459
- stride=self.stride,
460
- padding=0,
461
- dilation=self.dilation,
462
- groups=2,
463
- )
464
-
465
- return real_part + idea_part
466
-
467
-
468
- class ComplexTransposedConv1d(_ComplexConvNd):
469
- """
470
- Complex Transposed Convolution 1d
471
- """
472
-
473
- def __init__(
474
- self,
475
- in_channels,
476
- out_channels,
477
- kernel_size,
478
- stride=1,
479
- padding=0,
480
- output_padding=0,
481
- dilation=1,
482
- ):
483
- kernel_size = single(kernel_size)
484
- stride = single(stride)
485
- padding = padding
486
- dilation = single(dilation)
487
- super().__init__(
488
- in_channels,
489
- out_channels,
490
- kernel_size,
491
- stride,
492
- padding,
493
- dilation,
494
- True,
495
- output_padding,
496
- )
497
-
498
- def forward(self, x, output_size=None):
499
- """
500
- Implemented complex transposed convolution using combining 'grouped convolution'
501
- and 'real / img weight'
502
- :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
503
- :return: complex transposed convolution result
504
- """
505
- # forward real
506
- if self.padding:
507
- x = F.pad(x, (self.padding, self.padding), "reflect")
508
-
509
- real_part = F.conv_transpose1d(
510
- x,
511
- self.A,
512
- None,
513
- stride=self.stride,
514
- padding=0,
515
- dilation=self.dilation,
516
- groups=2,
517
- )
518
-
519
- # forward idea
520
- spl = self.out_channels // 2
521
- weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
522
- idea_part = F.conv_transpose1d(
523
- x,
524
- weight_B,
525
- None,
526
- stride=self.stride,
527
- padding=0,
528
- dilation=self.dilation,
529
- groups=2,
530
- )
531
-
532
- if self.output_padding:
533
- real_part = F.pad(
534
- real_part, (self.output_padding, self.output_padding), "reflect"
535
- )
536
- idea_part = F.pad(
537
- idea_part, (self.output_padding, self.output_padding), "reflect"
538
- )
539
-
540
- return real_part + idea_part
541
-
542
-
543
- class ComplexActLayer(nn.Module):
544
- """
545
- Activation differently 'real' part and 'img' part
546
- In implemented DCUnet on this repository, Real part is activated to log space.
547
- And Phase(img) part, it is distributed in [-pi, pi]...
548
- """
549
-
550
- def forward(self, x):
551
- real, img = x.chunk(2, 1)
552
- return torch.cat([F.leaky_relu(real), torch.tanh(img) * np.pi], dim=1)
553
-
554
-
555
- class STFT(nn.Module):
556
- """
557
- Re-construct stft for calculating backward operation
558
- refer on : https://github.com/pseeth/torch-stft/blob/master/torch_stft/stft.py
559
- """
560
-
561
- def __init__(
562
- self,
563
- filter_length: int = 1024,
564
- hop_length: int = 512,
565
- win_length: int = None,
566
- window: str = "hann",
567
- ):
568
- super().__init__()
569
- self.filter_length = filter_length
570
- self.hop_length = hop_length
571
- self.win_length = win_length if win_length else filter_length
572
- self.window = window
573
- self.pad_amount = self.filter_length // 2
574
-
575
- # make fft window
576
- assert filter_length >= self.win_length
577
- # get window and zero center pad it to filter_length
578
- fft_window = get_window(window, self.win_length, fftbins=True)
579
- fft_window = pad_center(fft_window, filter_length)
580
- fft_window = torch.from_numpy(fft_window).float()
581
-
582
- # calculate fourer_basis
583
- cut_off = int((self.filter_length / 2 + 1))
584
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
585
- fourier_basis = np.vstack(
586
- [np.real(fourier_basis[:cut_off, :]), np.imag(fourier_basis[:cut_off, :])]
587
- )
588
-
589
- # make forward & inverse basis
590
- self.register_buffer("square_window", fft_window**2)
591
-
592
- forward_basis = torch.FloatTensor(fourier_basis[:, np.newaxis, :]) * fft_window
593
- inverse_basis = (
594
- torch.FloatTensor(
595
- np.linalg.pinv(self.filter_length / self.hop_length * fourier_basis).T[
596
- :, np.newaxis, :
597
- ]
598
- )
599
- * fft_window
600
- )
601
- # torch.pinverse has a bug, so at this time, it is separated into two parts..
602
- self.register_buffer("forward_basis", forward_basis)
603
- self.register_buffer("inverse_basis", inverse_basis)
604
-
605
- def transform(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
606
- # reflect padding
607
- wav = wav.unsqueeze(1).unsqueeze(1)
608
- wav = F.pad(
609
- wav, (self.pad_amount, self.pad_amount, 0, 0), mode="reflect"
610
- ).squeeze(1)
611
-
612
- # conv
613
- forward_trans = F.conv1d(
614
- wav, self.forward_basis, stride=self.hop_length, padding=0
615
- )
616
- real_part, imag_part = forward_trans.chunk(2, 1)
617
-
618
- return torch.sqrt(real_part**2 + imag_part**2), torch.atan2(
619
- imag_part.data, real_part.data
620
- )
621
-
622
- def inverse(
623
- self, magnitude: torch.Tensor, phase: torch.Tensor, eps: float = 1e-9
624
- ) -> torch.Tensor:
625
- comp = torch.cat(
626
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
627
- )
628
- inverse_transform = F.conv_transpose1d(
629
- comp, self.inverse_basis, stride=self.hop_length, padding=0
630
- )
631
-
632
- # remove window effect
633
- n_frames = comp.size(-1)
634
- inverse_size = inverse_transform.size(-1)
635
-
636
- window_filter = torch.ones(1, 1, n_frames).type_as(inverse_transform)
637
-
638
- weight = self.square_window[: self.filter_length].unsqueeze(0).unsqueeze(0)
639
- window_filter = F.conv_transpose1d(
640
- window_filter, weight, stride=self.hop_length, padding=0
641
- )
642
- window_filter = window_filter.squeeze()[:inverse_size] + eps
643
-
644
- inverse_transform /= window_filter
645
-
646
- # scale by hop ratio
647
- inverse_transform *= self.filter_length / self.hop_length
648
-
649
- return inverse_transform[..., self.pad_amount : -self.pad_amount].squeeze(1)
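Reviewer note on the deletion above: the removed STFT module rebuilt the forward and inverse transforms from conv1d/conv_transpose1d so gradients could flow through them. A hypothetical sketch of the same magnitude/phase round trip using torch.stft and torch.istft, assuming a recent PyTorch where both support autograd (this is not how the remaining remfx code necessarily does it):

# Reviewer sketch, not repo code.
import torch

filter_length, hop_length = 1024, 512
window = torch.hann_window(filter_length)

wav = torch.randn(4, 48000)                       # (batch, samples)
spec = torch.stft(
    wav, n_fft=filter_length, hop_length=hop_length,
    window=window, return_complex=True,
)                                                  # (batch, 513, frames)
mag, phase = spec.abs(), spec.angle()

recon = torch.istft(
    mag * torch.exp(1j * phase), n_fft=filter_length,
    hop_length=hop_length, window=window, length=wav.shape[-1],
)                                                  # (batch, samples)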
remfx/dptnet.py DELETED
@@ -1,459 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.nn.modules.container import ModuleList
5
- from torch.nn.modules.activation import MultiheadAttention
6
- from torch.nn.modules.dropout import Dropout
7
- from torch.nn.modules.linear import Linear
8
- from torch.nn.modules.rnn import LSTM
9
- from torch.nn.modules.normalization import LayerNorm
10
- from torch.autograd import Variable
11
- import copy
12
- import math
13
-
14
-
15
- # adapted from https://github.com/ujscjj/DPTNet
16
-
17
-
18
- class DPTNet_base(nn.Module):
19
- def __init__(
20
- self,
21
- enc_dim,
22
- feature_dim,
23
- hidden_dim,
24
- layer,
25
- segment_size=250,
26
- nspk=2,
27
- win_len=2,
28
- ):
29
- super().__init__()
30
- # parameters
31
- self.window = win_len
32
- self.stride = self.window // 2
33
-
34
- self.enc_dim = enc_dim
35
- self.feature_dim = feature_dim
36
- self.hidden_dim = hidden_dim
37
- self.segment_size = segment_size
38
-
39
- self.layer = layer
40
- self.num_spk = nspk
41
- self.eps = 1e-8
42
-
43
- self.dpt_encoder = DPTEncoder(
44
- n_filters=enc_dim,
45
- window_size=win_len,
46
- )
47
- self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=1e-8)
48
- self.dpt_separation = DPTSeparation(
49
- self.enc_dim,
50
- self.feature_dim,
51
- self.hidden_dim,
52
- self.num_spk,
53
- self.layer,
54
- self.segment_size,
55
- )
56
-
57
- self.mask_conv1x1 = nn.Conv1d(self.feature_dim, self.enc_dim, 1, bias=False)
58
- self.decoder = DPTDecoder(n_filters=enc_dim, window_size=win_len)
59
-
60
- def forward(self, mix):
61
- """
62
- mix: shape (batch, T)
63
- """
64
- batch_size = mix.shape[0]
65
- mix = self.dpt_encoder(mix) # (B, E, L)
66
-
67
- score_ = self.enc_LN(mix) # B, E, L
68
- score_ = self.dpt_separation(score_) # B, nspk, T, N
69
- score_ = (
70
- score_.view(batch_size * self.num_spk, -1, self.feature_dim)
71
- .transpose(1, 2)
72
- .contiguous()
73
- ) # B*nspk, N, T
74
- score = self.mask_conv1x1(score_) # [B*nspk, N, L] -> [B*nspk, E, L]
75
- score = score.view(
76
- batch_size, self.num_spk, self.enc_dim, -1
77
- ) # [B*nspk, E, L] -> [B, nspk, E, L]
78
- est_mask = F.relu(score)
79
-
80
- est_source = self.decoder(
81
- mix, est_mask
82
- ) # [B, E, L] + [B, nspk, E, L]--> [B, nspk, T]
83
-
84
- return est_source
85
-
86
-
87
- class DPTEncoder(nn.Module):
88
- def __init__(self, n_filters: int = 64, window_size: int = 2):
89
- super().__init__()
90
- self.conv = nn.Conv1d(
91
- 1, n_filters, kernel_size=window_size, stride=window_size // 2, bias=False
92
- )
93
-
94
- def forward(self, x):
95
- x = x.unsqueeze(1)
96
- x = F.relu(self.conv(x))
97
- return x
98
-
99
-
100
- class TransformerEncoderLayer(torch.nn.Module):
101
- def __init__(
102
- self, d_model, nhead, hidden_size, dim_feedforward, dropout, activation="relu"
103
- ):
104
- super(TransformerEncoderLayer, self).__init__()
105
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
106
-
107
- # Implementation of improved part
108
- self.lstm = LSTM(d_model, hidden_size, 1, bidirectional=True)
109
- self.dropout = Dropout(dropout)
110
- self.linear = Linear(hidden_size * 2, d_model)
111
-
112
- self.norm1 = LayerNorm(d_model)
113
- self.norm2 = LayerNorm(d_model)
114
- self.dropout1 = Dropout(dropout)
115
- self.dropout2 = Dropout(dropout)
116
-
117
- self.activation = _get_activation_fn(activation)
118
-
119
- def __setstate__(self, state):
120
- if "activation" not in state:
121
- state["activation"] = F.relu
122
- super(TransformerEncoderLayer, self).__setstate__(state)
123
-
124
- def forward(self, src, src_mask=None, src_key_padding_mask=None):
125
- r"""Pass the input through the encoder layer.
126
- Args:
127
- src: the sequnce to the encoder layer (required).
128
- src_mask: the mask for the src sequence (optional).
129
- src_key_padding_mask: the mask for the src keys per batch (optional).
130
- Shape:
131
- see the docs in Transformer class.
132
- """
133
- src2 = self.self_attn(
134
- src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
135
- )[0]
136
- src = src + self.dropout1(src2)
137
- src = self.norm1(src)
138
- src2 = self.linear(self.dropout(self.activation(self.lstm(src)[0])))
139
- src = src + self.dropout2(src2)
140
- src = self.norm2(src)
141
- return src
142
-
143
-
144
- def _get_clones(module, N):
145
- return ModuleList([copy.deepcopy(module) for i in range(N)])
146
-
147
-
148
- def _get_activation_fn(activation):
149
- if activation == "relu":
150
- return F.relu
151
- elif activation == "gelu":
152
- return F.gelu
153
-
154
- raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
155
-
156
-
157
- class SingleTransformer(nn.Module):
158
- """
159
- Container module for a single Transformer layer.
160
- args: input_size: int, dimension of the input feature.
161
- The input should have shape (batch, seq_len, input_size).
162
- """
163
-
164
- def __init__(self, input_size, hidden_size, dropout):
165
- super(SingleTransformer, self).__init__()
166
- self.transformer = TransformerEncoderLayer(
167
- d_model=input_size,
168
- nhead=4,
169
- hidden_size=hidden_size,
170
- dim_feedforward=hidden_size * 2,
171
- dropout=dropout,
172
- )
173
-
174
- def forward(self, input):
175
- # input shape: batch, seq, dim
176
- output = input
177
- transformer_output = (
178
- self.transformer(output.permute(1, 0, 2).contiguous())
179
- .permute(1, 0, 2)
180
- .contiguous()
181
- )
182
- return transformer_output
183
-
184
-
185
- # dual-path transformer
186
- class DPT(nn.Module):
187
- """
188
- Deep dual-path transformer.
189
- args:
190
- input_size: int, dimension of the input feature. The input should have shape
191
- (batch, seq_len, input_size).
192
- hidden_size: int, dimension of the hidden state.
193
- output_size: int, dimension of the output size.
194
- num_layers: int, number of stacked Transformer layers. Default is 1.
195
- dropout: float, dropout ratio. Default is 0.
196
- """
197
-
198
- def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0):
199
- super(DPT, self).__init__()
200
-
201
- self.input_size = input_size
202
- self.output_size = output_size
203
- self.hidden_size = hidden_size
204
-
205
- # dual-path transformer
206
- self.row_transformer = nn.ModuleList([])
207
- self.col_transformer = nn.ModuleList([])
208
- for i in range(num_layers):
209
- self.row_transformer.append(
210
- SingleTransformer(input_size, hidden_size, dropout)
211
- )
212
- self.col_transformer.append(
213
- SingleTransformer(input_size, hidden_size, dropout)
214
- )
215
-
216
- # output layer
217
- self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))
218
-
219
- def forward(self, input):
220
- # input shape: batch, N, dim1, dim2
221
- # apply transformer on dim1 first and then dim2
222
- # output shape: B, output_size, dim1, dim2
223
- # input = input.to(device)
224
- batch_size, _, dim1, dim2 = input.shape
225
- output = input
226
- for i in range(len(self.row_transformer)):
227
- row_input = (
228
- output.permute(0, 3, 2, 1)
229
- .contiguous()
230
- .view(batch_size * dim2, dim1, -1)
231
- ) # B*dim2, dim1, N
232
- row_output = self.row_transformer[i](row_input) # B*dim2, dim1, H
233
- row_output = (
234
- row_output.view(batch_size, dim2, dim1, -1)
235
- .permute(0, 3, 2, 1)
236
- .contiguous()
237
- ) # B, N, dim1, dim2
238
- output = row_output
239
-
240
- col_input = (
241
- output.permute(0, 2, 3, 1)
242
- .contiguous()
243
- .view(batch_size * dim1, dim2, -1)
244
- ) # B*dim1, dim2, N
245
- col_output = self.col_transformer[i](col_input) # B*dim1, dim2, H
246
- col_output = (
247
- col_output.view(batch_size, dim1, dim2, -1)
248
- .permute(0, 3, 1, 2)
249
- .contiguous()
250
- ) # B, N, dim1, dim2
251
- output = col_output
252
-
253
- output = self.output(output) # B, output_size, dim1, dim2
254
-
255
- return output
256
-
257
-
258
- # base module for deep DPT
259
- class DPT_base(nn.Module):
260
- def __init__(
261
- self, input_dim, feature_dim, hidden_dim, num_spk=2, layer=6, segment_size=250
262
- ):
263
- super(DPT_base, self).__init__()
264
-
265
- self.input_dim = input_dim
266
- self.feature_dim = feature_dim
267
- self.hidden_dim = hidden_dim
268
-
269
- self.layer = layer
270
- self.segment_size = segment_size
271
- self.num_spk = num_spk
272
-
273
- self.eps = 1e-8
274
-
275
- # bottleneck
276
- self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)
277
-
278
- # DPT model
279
- self.DPT = DPT(
280
- self.feature_dim,
281
- self.hidden_dim,
282
- self.feature_dim * self.num_spk,
283
- num_layers=layer,
284
- )
285
-
286
- def pad_segment(self, input, segment_size):
287
- # input is the features: (B, N, T)
288
- batch_size, dim, seq_len = input.shape
289
- segment_stride = segment_size // 2
290
-
291
- rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
292
- if rest > 0:
293
- pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
294
- input = torch.cat([input, pad], 2)
295
-
296
- pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(
297
- input.type()
298
- )
299
- input = torch.cat([pad_aux, input, pad_aux], 2)
300
-
301
- return input, rest
302
-
303
- def split_feature(self, input, segment_size):
304
- # split the feature into chunks of segment size
305
- # input is the features: (B, N, T)
306
-
307
- input, rest = self.pad_segment(input, segment_size)
308
- batch_size, dim, seq_len = input.shape
309
- segment_stride = segment_size // 2
310
-
311
- segments1 = (
312
- input[:, :, :-segment_stride]
313
- .contiguous()
314
- .view(batch_size, dim, -1, segment_size)
315
- )
316
- segments2 = (
317
- input[:, :, segment_stride:]
318
- .contiguous()
319
- .view(batch_size, dim, -1, segment_size)
320
- )
321
- segments = (
322
- torch.cat([segments1, segments2], 3)
323
- .view(batch_size, dim, -1, segment_size)
324
- .transpose(2, 3)
325
- )
326
-
327
- return segments.contiguous(), rest
328
-
329
- def merge_feature(self, input, rest):
330
- # merge the splitted features into full utterance
331
- # input is the features: (B, N, L, K)
332
-
333
- batch_size, dim, segment_size, _ = input.shape
334
- segment_stride = segment_size // 2
335
- input = (
336
- input.transpose(2, 3)
337
- .contiguous()
338
- .view(batch_size, dim, -1, segment_size * 2)
339
- ) # B, N, K, L
340
-
341
- input1 = (
342
- input[:, :, :, :segment_size]
343
- .contiguous()
344
- .view(batch_size, dim, -1)[:, :, segment_stride:]
345
- )
346
- input2 = (
347
- input[:, :, :, segment_size:]
348
- .contiguous()
349
- .view(batch_size, dim, -1)[:, :, :-segment_stride]
350
- )
351
-
352
- output = input1 + input2
353
- if rest > 0:
354
- output = output[:, :, :-rest]
355
-
356
- return output.contiguous() # B, N, T
357
-
358
- def forward(self, input):
359
- pass
360
-
361
-
362
- class DPTSeparation(DPT_base):
363
- def __init__(self, *args, **kwargs):
364
- super(DPTSeparation, self).__init__(*args, **kwargs)
365
-
366
- # gated output layer
367
- self.output = nn.Sequential(
368
- nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Tanh()
369
- )
370
- self.output_gate = nn.Sequential(
371
- nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Sigmoid()
372
- )
373
-
374
- def forward(self, input):
375
- # input = input.to(device)
376
- # input: (B, E, T)
377
- batch_size, E, seq_length = input.shape
378
-
379
- enc_feature = self.BN(input) # (B, E, L)-->(B, N, L)
380
- # split the encoder output into overlapped, longer segments
381
- enc_segments, enc_rest = self.split_feature(
382
- enc_feature, self.segment_size
383
- ) # B, N, L, K: L is the segment_size
384
- # print('enc_segments.shape {}'.format(enc_segments.shape))
385
- # pass to DPT
386
- output = self.DPT(enc_segments).view(
387
- batch_size * self.num_spk, self.feature_dim, self.segment_size, -1
388
- ) # B*nspk, N, L, K
389
-
390
- # overlap-and-add of the outputs
391
- output = self.merge_feature(output, enc_rest) # B*nspk, N, T
392
-
393
- # gated output layer for filter generation
394
- bf_filter = self.output(output) * self.output_gate(output) # B*nspk, K, T
395
- bf_filter = (
396
- bf_filter.transpose(1, 2)
397
- .contiguous()
398
- .view(batch_size, self.num_spk, -1, self.feature_dim)
399
- ) # B, nspk, T, N
400
-
401
- return bf_filter
402
-
403
-
404
- class DPTDecoder(nn.Module):
405
- def __init__(self, n_filters: int = 64, window_size: int = 2):
406
- super().__init__()
407
- self.W = window_size
408
- self.basis_signals = nn.Linear(n_filters, window_size, bias=False)
409
-
410
- def forward(self, mixture, mask):
411
- """
412
- mixture: (batch, n_filters, L)
413
- mask: (batch, sources, n_filters, L)
414
- """
415
- source_w = torch.unsqueeze(mixture, 1) * mask # [B, C, E, L]
416
- source_w = torch.transpose(source_w, 2, 3) # [B, C, L, E]
417
- # S = DV
418
- est_source = self.basis_signals(source_w) # [B, C, L, W]
419
- est_source = overlap_and_add(est_source, self.W // 2) # B x C x T
420
- return est_source
421
-
422
-
423
- def overlap_and_add(signal, frame_step):
424
- """Reconstructs a signal from a framed representation.
425
- Adds potentially overlapping frames of a signal with shape
426
- `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
427
- The resulting tensor has shape `[..., output_size]` where
428
- output_size = (frames - 1) * frame_step + frame_length
429
- Args:
430
- signal: A [..., frames, frame_length] Tensor.
431
- All dimensions may be unknown, and rank must be at least 2.
432
- frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
433
- Returns:
434
- A Tensor with shape [..., output_size] containing the overlap-added frames of signal's
435
- inner-most two dimensions.
436
- output_size = (frames - 1) * frame_step + frame_length
437
- Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
438
- """
439
- outer_dimensions = signal.size()[:-2]
440
- frames, frame_length = signal.size()[-2:]
441
-
442
- subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
443
- subframe_step = frame_step // subframe_length
444
- subframes_per_frame = frame_length // subframe_length
445
- output_size = frame_step * (frames - 1) + frame_length
446
- output_subframes = output_size // subframe_length
447
-
448
- subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)
449
-
450
- frame = torch.arange(0, output_subframes).unfold(
451
- 0, subframes_per_frame, subframe_step
452
- )
453
- frame = signal.new_tensor(frame).long() # signal may in GPU or CPU
454
- frame = frame.contiguous().view(-1)
455
-
456
- result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
457
- result.index_add_(-2, frame, subframe_signal)
458
- result = result.view(*outer_dimensions, -1)
459
- return result
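Reviewer note on the deletion above: the removed overlap_and_add reconstructs a signal of length (frames - 1) * frame_step + frame_length from framed input. A naive reference loop (reviewer sketch, not repo code) that can be used to sanity-check that formula:

import torch

def overlap_and_add_naive(signal: torch.Tensor, frame_step: int) -> torch.Tensor:
    # signal: [..., frames, frame_length]
    *outer, frames, frame_length = signal.shape
    output_size = (frames - 1) * frame_step + frame_length
    out = signal.new_zeros(*outer, output_size)
    for i in range(frames):
        out[..., i * frame_step : i * frame_step + frame_length] += signal[..., i, :]
    return out

frames = torch.ones(2, 3, 5, 4)                    # (batch, sources, frames, frame_length)
print(overlap_and_add_naive(frames, 2).shape)      # torch.Size([2, 3, 12]) = (5 - 1) * 2 + 4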
remfx/models.py CHANGED
@@ -226,7 +226,7 @@ class DCUNetModel(nn.Module):
 
     def forward(self, batch):
         x, target = batch
-        output = self.model(x.squeeze(1))  # B x 1 x T
+        output = self.model(x.squeeze(1))  # B x T
         # Crop target to match output
         if output.shape[-1] < target.shape[-1]:
             target = causal_crop(target, output.shape[-1])
@@ -234,7 +234,7 @@ class DCUNetModel(nn.Module):
         return loss, output
 
     def sample(self, x: Tensor) -> Tensor:
-        output = self.model(x.squeeze(1))  # B x 1 x T
+        output = self.model(x.squeeze(1))  # B x T
         return output
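Reviewer note: only the shape comment changes here; with a single source the separator output is treated as B x T once the channel dimension has been squeezed off the input. For context, the surrounding logic crops the target when the model returns fewer samples than it consumed. A hypothetical sketch, assuming causal_crop keeps the trailing samples (the usual convention, not verified against remfx.utils):

# Reviewer sketch, not repo code; causal_crop_sketch stands in for remfx.utils.causal_crop.
import torch

def causal_crop_sketch(x: torch.Tensor, length: int) -> torch.Tensor:
    return x[..., -length:]

x = torch.randn(8, 1, 48000)          # (B, 1, T) input batch
output = torch.randn(8, 47000)        # pretend model output, B x T, slightly shorter
target = torch.randn(8, 1, 48000)

if output.shape[-1] < target.shape[-1]:
    target = causal_crop_sketch(target, output.shape[-1])
print(target.shape)                    # torch.Size([8, 1, 47000])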