poiqazwsx committed
Commit a45061e
1 Parent(s): 9db67e9

Upload 7 files

bsroformer/bs_roformer/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from models.bs_roformer.bs_roformer import BSRoformer
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
bsroformer/bs_roformer/attend.py ADDED
@@ -0,0 +1,120 @@
+ from functools import wraps
+ from packaging import version
+ from collections import namedtuple
+
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+
+ from einops import rearrange, reduce
+
+ # constants
+
+ FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def once(fn):
+     called = False
+     @wraps(fn)
+     def inner(x):
+         nonlocal called
+         if called:
+             return
+         called = True
+         return fn(x)
+     return inner
+
+ print_once = once(print)
+
+ # main class
+
+ class Attend(nn.Module):
+     def __init__(
+         self,
+         dropout = 0.,
+         flash = False,
+         scale = None
+     ):
+         super().__init__()
+         self.scale = scale
+         self.dropout = dropout
+         self.attn_dropout = nn.Dropout(dropout)
+
+         self.flash = flash
+         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+         # determine efficient attention configs for cuda and cpu
+
+         self.cpu_config = FlashAttentionConfig(True, True, True)
+         self.cuda_config = None
+
+         if not torch.cuda.is_available() or not flash:
+             return
+
+         device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+         if device_properties.major == 8 and device_properties.minor == 0:
+             print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+             self.cuda_config = FlashAttentionConfig(True, False, False)
+         else:
+             print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+             self.cuda_config = FlashAttentionConfig(False, True, True)
+
+     def flash_attn(self, q, k, v):
+         _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+         if exists(self.scale):
+             default_scale = q.shape[-1] ** -0.5
+             q = q * (self.scale / default_scale)
+
+         # Check if there is a compatible device for flash attention
+
+         config = self.cuda_config if is_cuda else self.cpu_config
+
+         # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+
+         with torch.backends.cuda.sdp_kernel(**config._asdict()):
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 dropout_p = self.dropout if self.training else 0.
+             )
+
+         return out
+
+     def forward(self, q, k, v):
+         """
+         einstein notation
+         b - batch
+         h - heads
+         n, i, j - sequence length (base sequence length, source, target)
+         d - feature dimension
+         """
+
+         q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+         scale = default(self.scale, q.shape[-1] ** -0.5)
+
+         if self.flash:
+             return self.flash_attn(q, k, v)
+
+         # similarity
+
+         sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
+
+         # attention
+
+         attn = sim.softmax(dim=-1)
+         attn = self.attn_dropout(attn)
+
+         # aggregate values
+
+         out = einsum("b h i j, b h j d -> b h i d", attn, v)
+
+         return out
bsroformer/bs_roformer/bs_roformer.py ADDED
@@ -0,0 +1,577 @@
+ from functools import partial
+
+ import torch
+ from torch import nn, einsum, Tensor
+ from torch.nn import Module, ModuleList
+ import torch.nn.functional as F
+
+ from models.bs_roformer.attend import Attend
+
+ from beartype.typing import Tuple, Optional, List, Callable
+ from beartype import beartype
+
+ from rotary_embedding_torch import RotaryEmbedding
+
+ from einops import rearrange, pack, unpack
+ from einops.layers.torch import Rearrange
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+
+ def default(v, d):
+     return v if exists(v) else d
+
+
+ def pack_one(t, pattern):
+     return pack([t], pattern)
+
+
+ def unpack_one(t, ps, pattern):
+     return unpack(t, ps, pattern)[0]
+
+
+ # norm
+
+ def l2norm(t):
+     return F.normalize(t, dim = -1, p = 2)
+
+
+ class RMSNorm(Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.scale = dim ** 0.5
+         self.gamma = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+ # attention
+
+ class FeedForward(Module):
+     def __init__(
+         self,
+         dim,
+         mult=4,
+         dropout=0.
+     ):
+         super().__init__()
+         dim_inner = int(dim * mult)
+         self.net = nn.Sequential(
+             RMSNorm(dim),
+             nn.Linear(dim, dim_inner),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim_inner, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim,
+         heads=8,
+         dim_head=64,
+         dropout=0.,
+         rotary_embed=None,
+         flash=True
+     ):
+         super().__init__()
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+         dim_inner = heads * dim_head
+
+         self.rotary_embed = rotary_embed
+
+         self.attend = Attend(flash=flash, dropout=dropout)
+
+         self.norm = RMSNorm(dim)
+         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+         self.to_gates = nn.Linear(dim, heads)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(dim_inner, dim, bias=False),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         x = self.norm(x)
+
+         q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+         if exists(self.rotary_embed):
+             q = self.rotary_embed.rotate_queries_or_keys(q)
+             k = self.rotary_embed.rotate_queries_or_keys(k)
+
+         out = self.attend(q, k, v)
+
+         gates = self.to_gates(x)
+         out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+
+ class LinearAttention(Module):
+     """
+     this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+     """
+
+     @beartype
+     def __init__(
+         self,
+         *,
+         dim,
+         dim_head=32,
+         heads=8,
+         scale=8,
+         flash=False,
+         dropout=0.
+     ):
+         super().__init__()
+         dim_inner = dim_head * heads
+         self.norm = RMSNorm(dim)
+
+         self.to_qkv = nn.Sequential(
+             nn.Linear(dim, dim_inner * 3, bias=False),
+             Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+         )
+
+         self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+         self.attend = Attend(
+             scale=scale,
+             dropout=dropout,
+             flash=flash
+         )
+
+         self.to_out = nn.Sequential(
+             Rearrange('b h d n -> b n (h d)'),
+             nn.Linear(dim_inner, dim, bias=False)
+         )
+
+     def forward(
+         self,
+         x
+     ):
+         x = self.norm(x)
+
+         q, k, v = self.to_qkv(x)
+
+         q, k = map(l2norm, (q, k))
+         q = q * self.temperature.exp()
+
+         out = self.attend(q, k, v)
+
+         return self.to_out(out)
+
+
+ class Transformer(Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         depth,
+         dim_head=64,
+         heads=8,
+         attn_dropout=0.,
+         ff_dropout=0.,
+         ff_mult=4,
+         norm_output=True,
+         rotary_embed=None,
+         flash_attn=True,
+         linear_attn=False
+     ):
+         super().__init__()
+         self.layers = ModuleList([])
+
+         for _ in range(depth):
+             if linear_attn:
+                 attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
+             else:
+                 attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
+                                  rotary_embed=rotary_embed, flash=flash_attn)
+
+             self.layers.append(ModuleList([
+                 attn,
+                 FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+             ]))
+
+         self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+     def forward(self, x):
+
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+
+         return self.norm(x)
+
+
+ # bandsplit module
+
+ class BandSplit(Module):
+     @beartype
+     def __init__(
+         self,
+         dim,
+         dim_inputs: Tuple[int, ...]
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_features = ModuleList([])
+
+         for dim_in in dim_inputs:
+             net = nn.Sequential(
+                 RMSNorm(dim_in),
+                 nn.Linear(dim_in, dim)
+             )
+
+             self.to_features.append(net)
+
+     def forward(self, x):
+         x = x.split(self.dim_inputs, dim=-1)
+
+         outs = []
+         for split_input, to_feature in zip(x, self.to_features):
+             split_output = to_feature(split_input)
+             outs.append(split_output)
+
+         return torch.stack(outs, dim=-2)
+
+
+ def MLP(
+         dim_in,
+         dim_out,
+         dim_hidden=None,
+         depth=1,
+         activation=nn.Tanh
+ ):
+     dim_hidden = default(dim_hidden, dim_in)
+
+     net = []
+     dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+     for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+         is_last = ind == (len(dims) - 2)
+
+         net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+         if is_last:
+             continue
+
+         net.append(activation())
+
+     return nn.Sequential(*net)
+
+
+ class MaskEstimator(Module):
+     @beartype
+     def __init__(
+         self,
+         dim,
+         dim_inputs: Tuple[int, ...],
+         depth,
+         mlp_expansion_factor=4
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_freqs = ModuleList([])
+         dim_hidden = dim * mlp_expansion_factor
+
+         for dim_in in dim_inputs:
+             net = []
+
+             mlp = nn.Sequential(
+                 MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+                 nn.GLU(dim=-1)
+             )
+
+             self.to_freqs.append(mlp)
+
+     def forward(self, x):
+         x = x.unbind(dim=-2)
+
+         outs = []
+
+         for band_features, mlp in zip(x, self.to_freqs):
+             freq_out = mlp(band_features)
+             outs.append(freq_out)
+
+         return torch.cat(outs, dim=-1)
+
+
+ # main class
+
+ DEFAULT_FREQS_PER_BANDS = (
+     2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+     2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+     2, 2, 2, 2,
+     4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+     12, 12, 12, 12, 12, 12, 12, 12,
+     24, 24, 24, 24, 24, 24, 24, 24,
+     48, 48, 48, 48, 48, 48, 48, 48,
+     128, 129,
+ )
+
+
+ class BSRoformer(Module):
+
+     @beartype
+     def __init__(
+         self,
+         dim,
+         *,
+         depth,
+         stereo=False,
+         num_stems=1,
+         time_transformer_depth=2,
+         freq_transformer_depth=2,
+         linear_transformer_depth=0,
+         freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+         # in the paper, they divide into ~60 bands, test with 1 for starters
+         dim_head=64,
+         heads=8,
+         attn_dropout=0.,
+         ff_dropout=0.,
+         flash_attn=True,
+         dim_freqs_in=1025,
+         stft_n_fft=2048,
+         stft_hop_length=512,
+         # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+         stft_win_length=2048,
+         stft_normalized=False,
+         stft_window_fn: Optional[Callable] = None,
+         mask_estimator_depth=2,
+         multi_stft_resolution_loss_weight=1.,
+         multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+         multi_stft_hop_size=147,
+         multi_stft_normalized=False,
+         multi_stft_window_fn: Callable = torch.hann_window
+     ):
+         super().__init__()
+
+         self.stereo = stereo
+         self.audio_channels = 2 if stereo else 1
+         self.num_stems = num_stems
+
+         self.layers = ModuleList([])
+
+         transformer_kwargs = dict(
+             dim=dim,
+             heads=heads,
+             dim_head=dim_head,
+             attn_dropout=attn_dropout,
+             ff_dropout=ff_dropout,
+             flash_attn=flash_attn,
+             norm_output=False
+         )
+
+         time_rotary_embed = RotaryEmbedding(dim=dim_head)
+         freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+         for _ in range(depth):
+             tran_modules = []
+             if linear_transformer_depth > 0:
+                 tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+             tran_modules.append(
+                 Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+             )
+             tran_modules.append(
+                 Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+             )
+             self.layers.append(nn.ModuleList(tran_modules))
+
+         self.final_norm = RMSNorm(dim)
+
+         self.stft_kwargs = dict(
+             n_fft=stft_n_fft,
+             hop_length=stft_hop_length,
+             win_length=stft_win_length,
+             normalized=stft_normalized
+         )
+
+         self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+         freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
+
+         assert len(freqs_per_bands) > 1
+         assert sum(
+             freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
+
+         freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
+
+         self.band_split = BandSplit(
+             dim=dim,
+             dim_inputs=freqs_per_bands_with_complex
+         )
+
+         self.mask_estimators = nn.ModuleList([])
+
+         for _ in range(num_stems):
+             mask_estimator = MaskEstimator(
+                 dim=dim,
+                 dim_inputs=freqs_per_bands_with_complex,
+                 depth=mask_estimator_depth
+             )
+
+             self.mask_estimators.append(mask_estimator)
+
+         # for the multi-resolution stft loss
+
+         self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+         self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+         self.multi_stft_n_fft = stft_n_fft
+         self.multi_stft_window_fn = multi_stft_window_fn
+
+         self.multi_stft_kwargs = dict(
+             hop_length=multi_stft_hop_size,
+             normalized=multi_stft_normalized
+         )
+
+     def forward(
+         self,
+         raw_audio,
+         target=None,
+         return_loss_breakdown=False
+     ):
+         """
+         einops
+
+         b - batch
+         f - freq
+         t - time
+         s - audio channel (1 for mono, 2 for stereo)
+         n - number of 'stems'
+         c - complex (2)
+         d - feature dimension
+         """
+
+         device = raw_audio.device
+
+         if raw_audio.ndim == 2:
+             raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+         channels = raw_audio.shape[1]
+         assert (not self.stereo and channels == 1) or (
+                 self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+
+         # to stft
+
+         raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+         stft_window = self.stft_window_fn(device=device)
+
+         stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+         stft_repr = torch.view_as_real(stft_repr)
+
+         stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+         stft_repr = rearrange(stft_repr,
+                               'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+
+         x = rearrange(stft_repr, 'b f t c -> b t (f c)')
+
+         x = self.band_split(x)
+
+         # axial / hierarchical attention
+
+         for transformer_block in self.layers:
+
+             if len(transformer_block) == 3:
+                 linear_transformer, time_transformer, freq_transformer = transformer_block
+
+                 x, ft_ps = pack([x], 'b * d')
+                 x = linear_transformer(x)
+                 x, = unpack(x, ft_ps, 'b * d')
+             else:
+                 time_transformer, freq_transformer = transformer_block
+
+             x = rearrange(x, 'b t f d -> b f t d')
+             x, ps = pack([x], '* t d')
+
+             x = time_transformer(x)
+
+             x, = unpack(x, ps, '* t d')
+             x = rearrange(x, 'b f t d -> b t f d')
+             x, ps = pack([x], '* f d')
+
+             x = freq_transformer(x)
+
+             x, = unpack(x, ps, '* f d')
+
+         x = self.final_norm(x)
+
+         num_stems = len(self.mask_estimators)
+
+         mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+         mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
+
+         # modulate frequency representation
+
+         stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+         # complex number multiplication
+
+         stft_repr = torch.view_as_complex(stft_repr)
+         mask = torch.view_as_complex(mask)
+
+         stft_repr = stft_repr * mask
+
+         # istft
+
+         stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+         recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
+
+         recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+
+         if num_stems == 1:
+             recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+         # if a target is passed in, calculate loss for learning
+
+         if not exists(target):
+             return recon_audio
+
+         if self.num_stems > 1:
+             assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+         if target.ndim == 2:
+             target = rearrange(target, '... t -> ... 1 t')
+
+         target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
+
+         loss = F.l1_loss(recon_audio, target)
+
+         multi_stft_resolution_loss = 0.
+
+         for window_size in self.multi_stft_resolutions_window_sizes:
+             res_stft_kwargs = dict(
+                 n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
+                 win_length=window_size,
+                 return_complex=True,
+                 window=self.multi_stft_window_fn(window_size, device=device),
+                 **self.multi_stft_kwargs,
+             )
+
+             recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+             target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+             multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+         weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+         total_loss = loss + weighted_multi_resolution_loss
+
+         if not return_loss_breakdown:
+             return total_loss
+
+         return total_loss, (loss, multi_stft_resolution_loss)
bsroformer/bs_roformer/mel_band_roformer.py ADDED
@@ -0,0 +1,637 @@
+ from functools import partial
+
+ import torch
+ from torch import nn, einsum, Tensor
+ from torch.nn import Module, ModuleList
+ import torch.nn.functional as F
+
+ from models.bs_roformer.attend import Attend
+
+ from beartype.typing import Tuple, Optional, List, Callable
+ from beartype import beartype
+
+ from rotary_embedding_torch import RotaryEmbedding
+
+ from einops import rearrange, pack, unpack, reduce, repeat
+ from einops.layers.torch import Rearrange
+
+ from librosa import filters
+
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+
+ def default(v, d):
+     return v if exists(v) else d
+
+
+ def pack_one(t, pattern):
+     return pack([t], pattern)
+
+
+ def unpack_one(t, ps, pattern):
+     return unpack(t, ps, pattern)[0]
+
+
+ def pad_at_dim(t, pad, dim=-1, value=0.):
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value=value)
+
+
+ def l2norm(t):
+     return F.normalize(t, dim=-1, p=2)
+
+
+ # norm
+
+ class RMSNorm(Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.scale = dim ** 0.5
+         self.gamma = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+ # attention
+
+ class FeedForward(Module):
+     def __init__(
+         self,
+         dim,
+         mult=4,
+         dropout=0.
+     ):
+         super().__init__()
+         dim_inner = int(dim * mult)
+         self.net = nn.Sequential(
+             RMSNorm(dim),
+             nn.Linear(dim, dim_inner),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim_inner, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim,
+         heads=8,
+         dim_head=64,
+         dropout=0.,
+         rotary_embed=None,
+         flash=True
+     ):
+         super().__init__()
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+         dim_inner = heads * dim_head
+
+         self.rotary_embed = rotary_embed
+
+         self.attend = Attend(flash=flash, dropout=dropout)
+
+         self.norm = RMSNorm(dim)
+         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+         self.to_gates = nn.Linear(dim, heads)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(dim_inner, dim, bias=False),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         x = self.norm(x)
+
+         q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+         if exists(self.rotary_embed):
+             q = self.rotary_embed.rotate_queries_or_keys(q)
+             k = self.rotary_embed.rotate_queries_or_keys(k)
+
+         out = self.attend(q, k, v)
+
+         gates = self.to_gates(x)
+         out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+
+ class LinearAttention(Module):
+     """
+     this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+     """
+
+     @beartype
+     def __init__(
+         self,
+         *,
+         dim,
+         dim_head=32,
+         heads=8,
+         scale=8,
+         flash=False,
+         dropout=0.
+     ):
+         super().__init__()
+         dim_inner = dim_head * heads
+         self.norm = RMSNorm(dim)
+
+         self.to_qkv = nn.Sequential(
+             nn.Linear(dim, dim_inner * 3, bias=False),
+             Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+         )
+
+         self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+         self.attend = Attend(
+             scale=scale,
+             dropout=dropout,
+             flash=flash
+         )
+
+         self.to_out = nn.Sequential(
+             Rearrange('b h d n -> b n (h d)'),
+             nn.Linear(dim_inner, dim, bias=False)
+         )
+
+     def forward(
+         self,
+         x
+     ):
+         x = self.norm(x)
+
+         q, k, v = self.to_qkv(x)
+
+         q, k = map(l2norm, (q, k))
+         q = q * self.temperature.exp()
+
+         out = self.attend(q, k, v)
+
+         return self.to_out(out)
+
+
+ class Transformer(Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         depth,
+         dim_head=64,
+         heads=8,
+         attn_dropout=0.,
+         ff_dropout=0.,
+         ff_mult=4,
+         norm_output=True,
+         rotary_embed=None,
+         flash_attn=True,
+         linear_attn=False
+     ):
+         super().__init__()
+         self.layers = ModuleList([])
+
+         for _ in range(depth):
+             if linear_attn:
+                 attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
+             else:
+                 attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
+                                  rotary_embed=rotary_embed, flash=flash_attn)
+
+             self.layers.append(ModuleList([
+                 attn,
+                 FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+             ]))
+
+         self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+     def forward(self, x):
+
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+
+         return self.norm(x)
+
+
+ # bandsplit module
+
+ class BandSplit(Module):
+     @beartype
+     def __init__(
+         self,
+         dim,
+         dim_inputs: Tuple[int, ...]
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_features = ModuleList([])
+
+         for dim_in in dim_inputs:
+             net = nn.Sequential(
+                 RMSNorm(dim_in),
+                 nn.Linear(dim_in, dim)
+             )
+
+             self.to_features.append(net)
+
+     def forward(self, x):
+         x = x.split(self.dim_inputs, dim=-1)
+
+         outs = []
+         for split_input, to_feature in zip(x, self.to_features):
+             split_output = to_feature(split_input)
+             outs.append(split_output)
+
+         return torch.stack(outs, dim=-2)
+
+
+ def MLP(
+         dim_in,
+         dim_out,
+         dim_hidden=None,
+         depth=1,
+         activation=nn.Tanh
+ ):
+     dim_hidden = default(dim_hidden, dim_in)
+
+     net = []
+     dims = (dim_in, *((dim_hidden,) * depth), dim_out)
+
+     for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+         is_last = ind == (len(dims) - 2)
+
+         net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+         if is_last:
+             continue
+
+         net.append(activation())
+
+     return nn.Sequential(*net)
+
+
+ class MaskEstimator(Module):
+     @beartype
+     def __init__(
+         self,
+         dim,
+         dim_inputs: Tuple[int, ...],
+         depth,
+         mlp_expansion_factor=4
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_freqs = ModuleList([])
+         dim_hidden = dim * mlp_expansion_factor
+
+         for dim_in in dim_inputs:
+             net = []
+
+             mlp = nn.Sequential(
+                 MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+                 nn.GLU(dim=-1)
+             )
+
+             self.to_freqs.append(mlp)
+
+     def forward(self, x):
+         x = x.unbind(dim=-2)
+
+         outs = []
+
+         for band_features, mlp in zip(x, self.to_freqs):
+             freq_out = mlp(band_features)
+             outs.append(freq_out)
+
+         return torch.cat(outs, dim=-1)
+
+
+ # main class
+
+ class MelBandRoformer(Module):
+
+     @beartype
+     def __init__(
+         self,
+         dim,
+         *,
+         depth,
+         stereo=False,
+         num_stems=1,
+         time_transformer_depth=2,
+         freq_transformer_depth=2,
+         linear_transformer_depth=0,
+         num_bands=60,
+         dim_head=64,
+         heads=8,
+         attn_dropout=0.1,
+         ff_dropout=0.1,
+         flash_attn=True,
+         dim_freqs_in=1025,
+         sample_rate=44100,  # needed for mel filter bank from librosa
+         stft_n_fft=2048,
+         stft_hop_length=512,
+         # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+         stft_win_length=2048,
+         stft_normalized=False,
+         stft_window_fn: Optional[Callable] = None,
+         mask_estimator_depth=1,
+         multi_stft_resolution_loss_weight=1.,
+         multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+         multi_stft_hop_size=147,
+         multi_stft_normalized=False,
+         multi_stft_window_fn: Callable = torch.hann_window,
+         match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
+     ):
+         super().__init__()
+
+         self.stereo = stereo
+         self.audio_channels = 2 if stereo else 1
+         self.num_stems = num_stems
+
+         self.layers = ModuleList([])
+
+         transformer_kwargs = dict(
+             dim=dim,
+             heads=heads,
+             dim_head=dim_head,
+             attn_dropout=attn_dropout,
+             ff_dropout=ff_dropout,
+             flash_attn=flash_attn
+         )
+
+         time_rotary_embed = RotaryEmbedding(dim=dim_head)
+         freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+         for _ in range(depth):
+             tran_modules = []
+             if linear_transformer_depth > 0:
+                 tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+             tran_modules.append(
+                 Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+             )
+             tran_modules.append(
+                 Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+             )
+             self.layers.append(nn.ModuleList(tran_modules))
+
+         self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+         self.stft_kwargs = dict(
+             n_fft=stft_n_fft,
+             hop_length=stft_hop_length,
+             win_length=stft_win_length,
+             normalized=stft_normalized
+         )
+
+         freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
+
+         # create mel filter bank
+         # with librosa.filters.mel as in section 2 of paper
+
+         mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
+
+         mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
+
+         # for some reason, it doesn't include the first freq? just force a value for now
+
+         mel_filter_bank[0][0] = 1.
+
+         # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
+         # so let's force a positive value
+
+         mel_filter_bank[-1, -1] = 1.
+
+         # binary as in paper (then estimated masks are averaged for overlapping regions)
+
+         freqs_per_band = mel_filter_bank > 0
+         assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
+
+         repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
+         freq_indices = repeated_freq_indices[freqs_per_band]
+
+         if stereo:
+             freq_indices = repeat(freq_indices, 'f -> f s', s=2)
+             freq_indices = freq_indices * 2 + torch.arange(2)
+             freq_indices = rearrange(freq_indices, 'f s -> (f s)')
+
+         self.register_buffer('freq_indices', freq_indices, persistent=False)
+         self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
+
+         num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
+         num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
+
+         self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
+         self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
+
+         # band split and mask estimator
+
+         freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
+
+         self.band_split = BandSplit(
+             dim=dim,
+             dim_inputs=freqs_per_bands_with_complex
+         )
+
+         self.mask_estimators = nn.ModuleList([])
+
+         for _ in range(num_stems):
+             mask_estimator = MaskEstimator(
+                 dim=dim,
+                 dim_inputs=freqs_per_bands_with_complex,
+                 depth=mask_estimator_depth
+             )
+
+             self.mask_estimators.append(mask_estimator)
+
+         # for the multi-resolution stft loss
+
+         self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+         self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+         self.multi_stft_n_fft = stft_n_fft
+         self.multi_stft_window_fn = multi_stft_window_fn
+
+         self.multi_stft_kwargs = dict(
+             hop_length=multi_stft_hop_size,
+             normalized=multi_stft_normalized
+         )
+
+         self.match_input_audio_length = match_input_audio_length
+
+     def forward(
+         self,
+         raw_audio,
+         target=None,
+         return_loss_breakdown=False
+     ):
+         """
+         einops
+
+         b - batch
+         f - freq
+         t - time
+         s - audio channel (1 for mono, 2 for stereo)
+         n - number of 'stems'
+         c - complex (2)
+         d - feature dimension
+         """
+
+         device = raw_audio.device
+
+         if raw_audio.ndim == 2:
+             raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+         batch, channels, raw_audio_length = raw_audio.shape
+
+         istft_length = raw_audio_length if self.match_input_audio_length else None
+
+         assert (not self.stereo and channels == 1) or (
+                 self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+
+         # to stft
+
+         raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+         stft_window = self.stft_window_fn(device=device)
+
+         stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+         stft_repr = torch.view_as_real(stft_repr)
+
+         stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+         stft_repr = rearrange(stft_repr,
+                               'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+
+         # index out all frequencies for all frequency ranges across bands ascending in one go
+
+         batch_arange = torch.arange(batch, device=device)[..., None]
+
+         # account for stereo
+
+         x = stft_repr[batch_arange, self.freq_indices]
+
+         # fold the complex (real and imag) into the frequencies dimension
+
+         x = rearrange(x, 'b f t c -> b t (f c)')
+
+         x = self.band_split(x)
+
+         # axial / hierarchical attention
+
+         for transformer_block in self.layers:
+
+             if len(transformer_block) == 3:
+                 linear_transformer, time_transformer, freq_transformer = transformer_block
+
+                 x, ft_ps = pack([x], 'b * d')
+                 x = linear_transformer(x)
+                 x, = unpack(x, ft_ps, 'b * d')
+             else:
+                 time_transformer, freq_transformer = transformer_block
+
+             x = rearrange(x, 'b t f d -> b f t d')
+             x, ps = pack([x], '* t d')
+
+             x = time_transformer(x)
+
+             x, = unpack(x, ps, '* t d')
+             x = rearrange(x, 'b f t d -> b t f d')
+             x, ps = pack([x], '* f d')
+
+             x = freq_transformer(x)
+
+             x, = unpack(x, ps, '* f d')
+
+         num_stems = len(self.mask_estimators)
+
+         masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+         masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
+
+         # modulate frequency representation
+
+         stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+         # complex number multiplication
+
+         stft_repr = torch.view_as_complex(stft_repr)
+         masks = torch.view_as_complex(masks)
+
+         masks = masks.type(stft_repr.dtype)
+
+         # need to average the estimated mask for the overlapped frequencies
+
+         scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+
+         stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
+         masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
+
+         denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
+
+         masks_averaged = masks_summed / denom.clamp(min=1e-8)
+
+         # modulate stft repr with estimated mask
+
+         stft_repr = stft_repr * masks_averaged
+
+         # istft
+
+         stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+         recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
+                                   length=istft_length)
+
+         recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
+
+         if num_stems == 1:
+             recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+         # if a target is passed in, calculate loss for learning
+
+         if not exists(target):
+             return recon_audio
+
+         if self.num_stems > 1:
+             assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+         if target.ndim == 2:
+             target = rearrange(target, '... t -> ... 1 t')
+
+         target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
+
+         loss = F.l1_loss(recon_audio, target)
+
+         multi_stft_resolution_loss = 0.
+
+         for window_size in self.multi_stft_resolutions_window_sizes:
+             res_stft_kwargs = dict(
+                 n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
+                 win_length=window_size,
+                 return_complex=True,
+                 window=self.multi_stft_window_fn(window_size, device=device),
+                 **self.multi_stft_kwargs,
+             )
+
+             recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+             target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+             multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+         weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+         total_loss = loss + weighted_multi_resolution_loss
+
+         if not return_loss_breakdown:
+             return total_loss
+
+         return total_loss, (loss, multi_stft_resolution_loss)
bsroformer/configs/model_bs_roformer_ep_317_sdr_12.9755.yaml ADDED
@@ -0,0 +1,126 @@
+ audio:
+   chunk_size: 352800
+   dim_f: 1024
+   dim_t: 801 # not used (set in model)
+   hop_length: 441 # not used (set in model)
+   n_fft: 2048
+   num_channels: 2
+   sample_rate: 44100
+   min_mean_abs: 0.000
+
+ model:
+   dim: 512
+   depth: 12
+   stereo: true
+   num_stems: 1
+   time_transformer_depth: 1
+   freq_transformer_depth: 1
+   linear_transformer_depth: 0
+   freqs_per_bands: !!python/tuple
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 128
+     - 129
+   dim_head: 64
+   heads: 8
+   attn_dropout: 0.1
+   ff_dropout: 0.1
+   flash_attn: true
+   dim_freqs_in: 1025
+   stft_n_fft: 2048
+   stft_hop_length: 441
+   stft_win_length: 2048
+   stft_normalized: false
+   mask_estimator_depth: 2
+   multi_stft_resolution_loss_weight: 1.0
+   multi_stft_resolutions_window_sizes: !!python/tuple
+     - 4096
+     - 2048
+     - 1024
+     - 512
+     - 256
+   multi_stft_hop_size: 147
+   multi_stft_normalized: False
+
+ training:
+   batch_size: 2
+   gradient_accumulation_steps: 1
+   grad_clip: 0
+   instruments:
+     - vocals
+     - other
+   lr: 1.0e-05
+   patience: 2
+   reduce_factor: 0.95
+   target_instrument: vocals
+   num_epochs: 1000
+   num_steps: 1000
+   q: 0.95
+   coarse_loss_clip: true
+   ema_momentum: 0.999
+   optimizer: adam
+   other_fix: true # needed when checking on the multisong dataset whether 'other' is actually instrumental
+   use_amp: true # enable or disable mixed precision (float16) - usually it should be true
+
+ inference:
+   batch_size: 4
+   dim_t: 801
+   num_overlap: 2
bsroformer/configs/model_bs_roformer_ep_937_sdr_10.5309.yaml ADDED
@@ -0,0 +1,138 @@
+ audio:
+   chunk_size: 131584
+   dim_f: 1024
+   dim_t: 256
+   hop_length: 512
+   n_fft: 2048
+   num_channels: 2
+   sample_rate: 44100
+   min_mean_abs: 0.001
+
+ model:
+   dim: 384
+   depth: 12
+   stereo: true
+   num_stems: 1
+   time_transformer_depth: 1
+   freq_transformer_depth: 1
+   linear_transformer_depth: 0
+   freqs_per_bands: !!python/tuple
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 128
+     - 129
+   dim_head: 64
+   heads: 8
+   attn_dropout: 0.1
+   ff_dropout: 0.1
+   flash_attn: true
+   dim_freqs_in: 1025
+   stft_n_fft: 2048
+   stft_hop_length: 512
+   stft_win_length: 2048
+   stft_normalized: false
+   mask_estimator_depth: 2
+   multi_stft_resolution_loss_weight: 1.0
+   multi_stft_resolutions_window_sizes: !!python/tuple
+     - 4096
+     - 2048
+     - 1024
+     - 512
+     - 256
+   multi_stft_hop_size: 147
+   multi_stft_normalized: False
+
+ training:
+   batch_size: 4
+   gradient_accumulation_steps: 1
+   grad_clip: 0
+   instruments:
+     - vocals
+     - other
+   lr: 5.0e-05
+   patience: 2
+   reduce_factor: 0.95
+   target_instrument: other
+   num_epochs: 1000
+   num_steps: 1000
+   q: 0.95
+   coarse_loss_clip: true
+   ema_momentum: 0.999
+   optimizer: adam
+   other_fix: false # needed when checking on the multisong dataset whether 'other' is actually instrumental
+   use_amp: true # enable or disable mixed precision (float16) - usually it should be true
+
+ augmentations:
+   enable: true # enable or disable all augmentations (to quickly disable them if needed)
+   loudness: true # randomly change loudness of each stem in the range (loudness_min; loudness_max)
+   loudness_min: 0.5
+   loudness_max: 1.5
+   mixup: true # mix several stems of the same type with some probability (only works for dataset types: 1, 2, 3)
+   mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+     - 0.2
+     - 0.02
+   mixup_loudness_min: 0.5
+   mixup_loudness_max: 1.5
+
+ inference:
+   batch_size: 8
+   dim_t: 512
+   num_overlap: 2
bsroformer/configs/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml ADDED
@@ -0,0 +1,65 @@
+ audio:
+   chunk_size: 352800
+   dim_f: 1024
+   dim_t: 801 # not used (set in model)
+   hop_length: 441 # not used (set in model)
+   n_fft: 2048
+   num_channels: 2
+   sample_rate: 44100
+   min_mean_abs: 0.000
+
+ model:
+   dim: 384
+   depth: 12
+   stereo: true
+   num_stems: 1
+   time_transformer_depth: 1
+   freq_transformer_depth: 1
+   linear_transformer_depth: 0
+   num_bands: 60
+   dim_head: 64
+   heads: 8
+   attn_dropout: 0.1
+   ff_dropout: 0.1
+   flash_attn: True
+   dim_freqs_in: 1025
+   sample_rate: 44100 # needed for the mel filter bank from librosa
+   stft_n_fft: 2048
+   stft_hop_length: 441
+   stft_win_length: 2048
+   stft_normalized: False
+   mask_estimator_depth: 2
+   multi_stft_resolution_loss_weight: 1.0
+   multi_stft_resolutions_window_sizes: !!python/tuple
+     - 4096
+     - 2048
+     - 1024
+     - 512
+     - 256
+   multi_stft_hop_size: 147
+   multi_stft_normalized: False
+
+ training:
+   batch_size: 1
+   gradient_accumulation_steps: 8
+   grad_clip: 0
+   instruments:
+     - vocals
+     - other
+   lr: 4.0e-05
+   patience: 2
+   reduce_factor: 0.95
+   target_instrument: vocals
+   num_epochs: 1000
+   num_steps: 1000
+   q: 0.95
+   coarse_loss_clip: true
+   ema_momentum: 0.999
+   optimizer: adam
+   other_fix: false # needed when checking on the multisong dataset whether 'other' is actually instrumental
+   use_amp: true # enable or disable mixed precision (float16) - usually it should be true
+
+ inference:
+   batch_size: 4
+   dim_t: 801
+   num_overlap: 2
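
The `model` section of each config maps directly onto the constructor keyword arguments of the classes uploaded above. A minimal loading sketch, assuming PyYAML ≥ 5.1 (needed for the `!!python/tuple` tags), PyTorch ≥ 2.0 (the configs set `flash_attn: true`), and that the package is importable as `models.bs_roformer` as in `__init__.py`; checkpoint loading is omitted:

```python
import yaml
import torch

from models.bs_roformer import MelBandRoformer

# parse the uploaded mel-band config; unsafe_load is required for the !!python/tuple tags
with open('bsroformer/configs/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml') as f:
    config = yaml.unsafe_load(f)

# every key under 'model' is a MelBandRoformer constructor argument
model = MelBandRoformer(**config['model']).eval()

# run one stereo training chunk through the model: (batch, channels, samples)
audio = torch.randn(1, 2, config['audio']['chunk_size'])
with torch.no_grad():
    stem = model(audio)  # separated target stem ('vocals' here), shape (batch, channels, samples)
```

The same pattern applies to the two BS-RoFormer configs with `BSRoformer(**config['model'])`; passing a `target` tensor to `forward` returns the combined L1 plus multi-resolution STFT loss instead of audio.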