poiqazwsx committed
Commit 2133f28
1 Parent(s): a45061e

Delete bsroformer/bs_roformer

bsroformer/bs_roformer/__init__.py DELETED
@@ -1,2 +0,0 @@
-from models.bs_roformer.bs_roformer import BSRoformer
-from models.bs_roformer.mel_band_roformer import MelBandRoformer
 
 
 
bsroformer/bs_roformer/attend.py DELETED
@@ -1,120 +0,0 @@
-from functools import wraps
-from packaging import version
-from collections import namedtuple
-
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-
-from einops import rearrange, reduce
-
-# constants
-
-FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
-
-# helpers
-
-def exists(val):
-    return val is not None
-
-def default(v, d):
-    return v if exists(v) else d
-
-def once(fn):
-    called = False
-    @wraps(fn)
-    def inner(x):
-        nonlocal called
-        if called:
-            return
-        called = True
-        return fn(x)
-    return inner
-
-print_once = once(print)
-
-# main class
-
-class Attend(nn.Module):
-    def __init__(
-        self,
-        dropout = 0.,
-        flash = False,
-        scale = None
-    ):
-        super().__init__()
-        self.scale = scale
-        self.dropout = dropout
-        self.attn_dropout = nn.Dropout(dropout)
-
-        self.flash = flash
-        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
-
-        # determine efficient attention configs for cuda and cpu
-
-        self.cpu_config = FlashAttentionConfig(True, True, True)
-        self.cuda_config = None
-
-        if not torch.cuda.is_available() or not flash:
-            return
-
-        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
-
-        if device_properties.major == 8 and device_properties.minor == 0:
-            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
-            self.cuda_config = FlashAttentionConfig(True, False, False)
-        else:
-            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
-            self.cuda_config = FlashAttentionConfig(False, True, True)
-
-    def flash_attn(self, q, k, v):
-        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
-
-        if exists(self.scale):
-            default_scale = q.shape[-1] ** -0.5
-            q = q * (self.scale / default_scale)
-
-        # Check if there is a compatible device for flash attention
-
-        config = self.cuda_config if is_cuda else self.cpu_config
-
-        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
-
-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
-            out = F.scaled_dot_product_attention(
-                q, k, v,
-                dropout_p = self.dropout if self.training else 0.
-            )
-
-        return out
-
-    def forward(self, q, k, v):
-        """
-        einstein notation
-        b - batch
-        h - heads
-        n, i, j - sequence length (base sequence length, source, target)
-        d - feature dimension
-        """
-
-        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
-
-        scale = default(self.scale, q.shape[-1] ** -0.5)
-
-        if self.flash:
-            return self.flash_attn(q, k, v)
-
-        # similarity
-
-        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
-
-        # attention
-
-        attn = sim.softmax(dim=-1)
-        attn = self.attn_dropout(attn)
-
-        # aggregate values
-
-        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
-
-        return out
 
bsroformer/bs_roformer/bs_roformer.py DELETED
@@ -1,577 +0,0 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from models.bs_roformer.attend import Attend
9
-
10
- from beartype.typing import Tuple, Optional, List, Callable
11
- from beartype import beartype
12
-
13
- from rotary_embedding_torch import RotaryEmbedding
14
-
15
- from einops import rearrange, pack, unpack
16
- from einops.layers.torch import Rearrange
17
-
18
- # helper functions
19
-
20
- def exists(val):
21
- return val is not None
22
-
23
-
24
- def default(v, d):
25
- return v if exists(v) else d
26
-
27
-
28
- def pack_one(t, pattern):
29
- return pack([t], pattern)
30
-
31
-
32
- def unpack_one(t, ps, pattern):
33
- return unpack(t, ps, pattern)[0]
34
-
35
-
36
- # norm
37
-
38
- def l2norm(t):
39
- return F.normalize(t, dim = -1, p = 2)
40
-
41
-
42
- class RMSNorm(Module):
43
- def __init__(self, dim):
44
- super().__init__()
45
- self.scale = dim ** 0.5
46
- self.gamma = nn.Parameter(torch.ones(dim))
47
-
48
- def forward(self, x):
49
- return F.normalize(x, dim=-1) * self.scale * self.gamma
50
-
51
-
52
- # attention
53
-
54
- class FeedForward(Module):
55
- def __init__(
56
- self,
57
- dim,
58
- mult=4,
59
- dropout=0.
60
- ):
61
- super().__init__()
62
- dim_inner = int(dim * mult)
63
- self.net = nn.Sequential(
64
- RMSNorm(dim),
65
- nn.Linear(dim, dim_inner),
66
- nn.GELU(),
67
- nn.Dropout(dropout),
68
- nn.Linear(dim_inner, dim),
69
- nn.Dropout(dropout)
70
- )
71
-
72
- def forward(self, x):
73
- return self.net(x)
74
-
75
-
76
- class Attention(Module):
77
- def __init__(
78
- self,
79
- dim,
80
- heads=8,
81
- dim_head=64,
82
- dropout=0.,
83
- rotary_embed=None,
84
- flash=True
85
- ):
86
- super().__init__()
87
- self.heads = heads
88
- self.scale = dim_head ** -0.5
89
- dim_inner = heads * dim_head
90
-
91
- self.rotary_embed = rotary_embed
92
-
93
- self.attend = Attend(flash=flash, dropout=dropout)
94
-
95
- self.norm = RMSNorm(dim)
96
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
97
-
98
- self.to_gates = nn.Linear(dim, heads)
99
-
100
- self.to_out = nn.Sequential(
101
- nn.Linear(dim_inner, dim, bias=False),
102
- nn.Dropout(dropout)
103
- )
104
-
105
- def forward(self, x):
106
- x = self.norm(x)
107
-
108
- q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
109
-
110
- if exists(self.rotary_embed):
111
- q = self.rotary_embed.rotate_queries_or_keys(q)
112
- k = self.rotary_embed.rotate_queries_or_keys(k)
113
-
114
- out = self.attend(q, k, v)
115
-
116
- gates = self.to_gates(x)
117
- out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
118
-
119
- out = rearrange(out, 'b h n d -> b n (h d)')
120
- return self.to_out(out)
121
-
122
-
123
- class LinearAttention(Module):
124
- """
125
- this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
126
- """
127
-
128
- @beartype
129
- def __init__(
130
- self,
131
- *,
132
- dim,
133
- dim_head=32,
134
- heads=8,
135
- scale=8,
136
- flash=False,
137
- dropout=0.
138
- ):
139
- super().__init__()
140
- dim_inner = dim_head * heads
141
- self.norm = RMSNorm(dim)
142
-
143
- self.to_qkv = nn.Sequential(
144
- nn.Linear(dim, dim_inner * 3, bias=False),
145
- Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
146
- )
147
-
148
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
-
150
- self.attend = Attend(
151
- scale=scale,
152
- dropout=dropout,
153
- flash=flash
154
- )
155
-
156
- self.to_out = nn.Sequential(
157
- Rearrange('b h d n -> b n (h d)'),
158
- nn.Linear(dim_inner, dim, bias=False)
159
- )
160
-
161
- def forward(
162
- self,
163
- x
164
- ):
165
- x = self.norm(x)
166
-
167
- q, k, v = self.to_qkv(x)
168
-
169
- q, k = map(l2norm, (q, k))
170
- q = q * self.temperature.exp()
171
-
172
- out = self.attend(q, k, v)
173
-
174
- return self.to_out(out)
175
-
176
-
177
- class Transformer(Module):
178
- def __init__(
179
- self,
180
- *,
181
- dim,
182
- depth,
183
- dim_head=64,
184
- heads=8,
185
- attn_dropout=0.,
186
- ff_dropout=0.,
187
- ff_mult=4,
188
- norm_output=True,
189
- rotary_embed=None,
190
- flash_attn=True,
191
- linear_attn=False
192
- ):
193
- super().__init__()
194
- self.layers = ModuleList([])
195
-
196
- for _ in range(depth):
197
- if linear_attn:
198
- attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
199
- else:
200
- attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
201
- rotary_embed=rotary_embed, flash=flash_attn)
202
-
203
- self.layers.append(ModuleList([
204
- attn,
205
- FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
206
- ]))
207
-
208
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
209
-
210
- def forward(self, x):
211
-
212
- for attn, ff in self.layers:
213
- x = attn(x) + x
214
- x = ff(x) + x
215
-
216
- return self.norm(x)
217
-
218
-
219
- # bandsplit module
220
-
221
- class BandSplit(Module):
222
- @beartype
223
- def __init__(
224
- self,
225
- dim,
226
- dim_inputs: Tuple[int, ...]
227
- ):
228
- super().__init__()
229
- self.dim_inputs = dim_inputs
230
- self.to_features = ModuleList([])
231
-
232
- for dim_in in dim_inputs:
233
- net = nn.Sequential(
234
- RMSNorm(dim_in),
235
- nn.Linear(dim_in, dim)
236
- )
237
-
238
- self.to_features.append(net)
239
-
240
- def forward(self, x):
241
- x = x.split(self.dim_inputs, dim=-1)
242
-
243
- outs = []
244
- for split_input, to_feature in zip(x, self.to_features):
245
- split_output = to_feature(split_input)
246
- outs.append(split_output)
247
-
248
- return torch.stack(outs, dim=-2)
249
-
250
-
251
- def MLP(
252
- dim_in,
253
- dim_out,
254
- dim_hidden=None,
255
- depth=1,
256
- activation=nn.Tanh
257
- ):
258
- dim_hidden = default(dim_hidden, dim_in)
259
-
260
- net = []
261
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
262
-
263
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
264
- is_last = ind == (len(dims) - 2)
265
-
266
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
267
-
268
- if is_last:
269
- continue
270
-
271
- net.append(activation())
272
-
273
- return nn.Sequential(*net)
274
-
275
-
276
- class MaskEstimator(Module):
277
- @beartype
278
- def __init__(
279
- self,
280
- dim,
281
- dim_inputs: Tuple[int, ...],
282
- depth,
283
- mlp_expansion_factor=4
284
- ):
285
- super().__init__()
286
- self.dim_inputs = dim_inputs
287
- self.to_freqs = ModuleList([])
288
- dim_hidden = dim * mlp_expansion_factor
289
-
290
- for dim_in in dim_inputs:
291
- net = []
292
-
293
- mlp = nn.Sequential(
294
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
295
- nn.GLU(dim=-1)
296
- )
297
-
298
- self.to_freqs.append(mlp)
299
-
300
- def forward(self, x):
301
- x = x.unbind(dim=-2)
302
-
303
- outs = []
304
-
305
- for band_features, mlp in zip(x, self.to_freqs):
306
- freq_out = mlp(band_features)
307
- outs.append(freq_out)
308
-
309
- return torch.cat(outs, dim=-1)
310
-
311
-
312
- # main class
313
-
314
- DEFAULT_FREQS_PER_BANDS = (
315
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
316
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
317
- 2, 2, 2, 2,
318
- 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
319
- 12, 12, 12, 12, 12, 12, 12, 12,
320
- 24, 24, 24, 24, 24, 24, 24, 24,
321
- 48, 48, 48, 48, 48, 48, 48, 48,
322
- 128, 129,
323
- )
324
-
325
-
326
- class BSRoformer(Module):
327
-
328
- @beartype
329
- def __init__(
330
- self,
331
- dim,
332
- *,
333
- depth,
334
- stereo=False,
335
- num_stems=1,
336
- time_transformer_depth=2,
337
- freq_transformer_depth=2,
338
- linear_transformer_depth=0,
339
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
340
- # in the paper, they divide into ~60 bands, test with 1 for starters
341
- dim_head=64,
342
- heads=8,
343
- attn_dropout=0.,
344
- ff_dropout=0.,
345
- flash_attn=True,
346
- dim_freqs_in=1025,
347
- stft_n_fft=2048,
348
- stft_hop_length=512,
349
- # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
350
- stft_win_length=2048,
351
- stft_normalized=False,
352
- stft_window_fn: Optional[Callable] = None,
353
- mask_estimator_depth=2,
354
- multi_stft_resolution_loss_weight=1.,
355
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
356
- multi_stft_hop_size=147,
357
- multi_stft_normalized=False,
358
- multi_stft_window_fn: Callable = torch.hann_window
359
- ):
360
- super().__init__()
361
-
362
- self.stereo = stereo
363
- self.audio_channels = 2 if stereo else 1
364
- self.num_stems = num_stems
365
-
366
- self.layers = ModuleList([])
367
-
368
- transformer_kwargs = dict(
369
- dim=dim,
370
- heads=heads,
371
- dim_head=dim_head,
372
- attn_dropout=attn_dropout,
373
- ff_dropout=ff_dropout,
374
- flash_attn=flash_attn,
375
- norm_output=False
376
- )
377
-
378
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
379
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
380
-
381
- for _ in range(depth):
382
- tran_modules = []
383
- if linear_transformer_depth > 0:
384
- tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
385
- tran_modules.append(
386
- Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
387
- )
388
- tran_modules.append(
389
- Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
390
- )
391
- self.layers.append(nn.ModuleList(tran_modules))
392
-
393
- self.final_norm = RMSNorm(dim)
394
-
395
- self.stft_kwargs = dict(
396
- n_fft=stft_n_fft,
397
- hop_length=stft_hop_length,
398
- win_length=stft_win_length,
399
- normalized=stft_normalized
400
- )
401
-
402
- self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
403
-
404
- freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
405
-
406
- assert len(freqs_per_bands) > 1
407
- assert sum(
408
- freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
409
-
410
- freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
411
-
412
- self.band_split = BandSplit(
413
- dim=dim,
414
- dim_inputs=freqs_per_bands_with_complex
415
- )
416
-
417
- self.mask_estimators = nn.ModuleList([])
418
-
419
- for _ in range(num_stems):
420
- mask_estimator = MaskEstimator(
421
- dim=dim,
422
- dim_inputs=freqs_per_bands_with_complex,
423
- depth=mask_estimator_depth
424
- )
425
-
426
- self.mask_estimators.append(mask_estimator)
427
-
428
- # for the multi-resolution stft loss
429
-
430
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
431
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
432
- self.multi_stft_n_fft = stft_n_fft
433
- self.multi_stft_window_fn = multi_stft_window_fn
434
-
435
- self.multi_stft_kwargs = dict(
436
- hop_length=multi_stft_hop_size,
437
- normalized=multi_stft_normalized
438
- )
439
-
440
- def forward(
441
- self,
442
- raw_audio,
443
- target=None,
444
- return_loss_breakdown=False
445
- ):
446
- """
447
- einops
448
-
449
- b - batch
450
- f - freq
451
- t - time
452
- s - audio channel (1 for mono, 2 for stereo)
453
- n - number of 'stems'
454
- c - complex (2)
455
- d - feature dimension
456
- """
457
-
458
- device = raw_audio.device
459
-
460
- if raw_audio.ndim == 2:
461
- raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
462
-
463
- channels = raw_audio.shape[1]
464
- assert (not self.stereo and channels == 1) or (
465
- self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
466
-
467
- # to stft
468
-
469
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
470
-
471
- stft_window = self.stft_window_fn(device=device)
472
-
473
- stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
474
- stft_repr = torch.view_as_real(stft_repr)
475
-
476
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
477
- stft_repr = rearrange(stft_repr,
478
- 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
479
-
480
- x = rearrange(stft_repr, 'b f t c -> b t (f c)')
481
-
482
- x = self.band_split(x)
483
-
484
- # axial / hierarchical attention
485
-
486
- for transformer_block in self.layers:
487
-
488
- if len(transformer_block) == 3:
489
- linear_transformer, time_transformer, freq_transformer = transformer_block
490
-
491
- x, ft_ps = pack([x], 'b * d')
492
- x = linear_transformer(x)
493
- x, = unpack(x, ft_ps, 'b * d')
494
- else:
495
- time_transformer, freq_transformer = transformer_block
496
-
497
- x = rearrange(x, 'b t f d -> b f t d')
498
- x, ps = pack([x], '* t d')
499
-
500
- x = time_transformer(x)
501
-
502
- x, = unpack(x, ps, '* t d')
503
- x = rearrange(x, 'b f t d -> b t f d')
504
- x, ps = pack([x], '* f d')
505
-
506
- x = freq_transformer(x)
507
-
508
- x, = unpack(x, ps, '* f d')
509
-
510
- x = self.final_norm(x)
511
-
512
- num_stems = len(self.mask_estimators)
513
-
514
- mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
515
- mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
516
-
517
- # modulate frequency representation
518
-
519
- stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
520
-
521
- # complex number multiplication
522
-
523
- stft_repr = torch.view_as_complex(stft_repr)
524
- mask = torch.view_as_complex(mask)
525
-
526
- stft_repr = stft_repr * mask
527
-
528
- # istft
529
-
530
- stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
531
-
532
- recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
533
-
534
- recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
535
-
536
- if num_stems == 1:
537
- recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
538
-
539
- # if a target is passed in, calculate loss for learning
540
-
541
- if not exists(target):
542
- return recon_audio
543
-
544
- if self.num_stems > 1:
545
- assert target.ndim == 4 and target.shape[1] == self.num_stems
546
-
547
- if target.ndim == 2:
548
- target = rearrange(target, '... t -> ... 1 t')
549
-
550
- target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
551
-
552
- loss = F.l1_loss(recon_audio, target)
553
-
554
- multi_stft_resolution_loss = 0.
555
-
556
- for window_size in self.multi_stft_resolutions_window_sizes:
557
- res_stft_kwargs = dict(
558
- n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
559
- win_length=window_size,
560
- return_complex=True,
561
- window=self.multi_stft_window_fn(window_size, device=device),
562
- **self.multi_stft_kwargs,
563
- )
564
-
565
- recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
566
- target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
567
-
568
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
569
-
570
- weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
571
-
572
- total_loss = loss + weighted_multi_resolution_loss
573
-
574
- if not return_loss_breakdown:
575
- return total_loss
576
-
577
- return total_loss, (loss, multi_stft_resolution_loss)
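For orientation, a minimal usage sketch of the deleted BSRoformer class. The constructor arguments and forward signature come from the code above; the dim/depth values and the two-second mono clip are illustrative assumptions, not the repository's actual configuration.

import torch
from models.bs_roformer.bs_roformer import BSRoformer

model = BSRoformer(dim=384, depth=6)       # mono, single stem; default flash_attn=True assumes PyTorch >= 2.0

audio = torch.randn(1, 44100 * 2)          # (batch, time) at 44.1 kHz

# inference: returns the separated stem, shape (batch, channels=1, time)
stem = model(audio)

# training: passing a target returns the combined L1 + multi-resolution STFT loss (a scalar)
loss = model(audio, target=torch.randn(1, 44100 * 2))
loss.backward()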
 
bsroformer/bs_roformer/mel_band_roformer.py DELETED
@@ -1,637 +0,0 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from models.bs_roformer.attend import Attend
9
-
10
- from beartype.typing import Tuple, Optional, List, Callable
11
- from beartype import beartype
12
-
13
- from rotary_embedding_torch import RotaryEmbedding
14
-
15
- from einops import rearrange, pack, unpack, reduce, repeat
16
- from einops.layers.torch import Rearrange
17
-
18
- from librosa import filters
19
-
20
-
21
- # helper functions
22
-
23
- def exists(val):
24
- return val is not None
25
-
26
-
27
- def default(v, d):
28
- return v if exists(v) else d
29
-
30
-
31
- def pack_one(t, pattern):
32
- return pack([t], pattern)
33
-
34
-
35
- def unpack_one(t, ps, pattern):
36
- return unpack(t, ps, pattern)[0]
37
-
38
-
39
- def pad_at_dim(t, pad, dim=-1, value=0.):
40
- dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
41
- zeros = ((0, 0) * dims_from_right)
42
- return F.pad(t, (*zeros, *pad), value=value)
43
-
44
-
45
- def l2norm(t):
46
- return F.normalize(t, dim=-1, p=2)
47
-
48
-
49
- # norm
50
-
51
- class RMSNorm(Module):
52
- def __init__(self, dim):
53
- super().__init__()
54
- self.scale = dim ** 0.5
55
- self.gamma = nn.Parameter(torch.ones(dim))
56
-
57
- def forward(self, x):
58
- return F.normalize(x, dim=-1) * self.scale * self.gamma
59
-
60
-
61
- # attention
62
-
63
- class FeedForward(Module):
64
- def __init__(
65
- self,
66
- dim,
67
- mult=4,
68
- dropout=0.
69
- ):
70
- super().__init__()
71
- dim_inner = int(dim * mult)
72
- self.net = nn.Sequential(
73
- RMSNorm(dim),
74
- nn.Linear(dim, dim_inner),
75
- nn.GELU(),
76
- nn.Dropout(dropout),
77
- nn.Linear(dim_inner, dim),
78
- nn.Dropout(dropout)
79
- )
80
-
81
- def forward(self, x):
82
- return self.net(x)
83
-
84
-
85
- class Attention(Module):
86
- def __init__(
87
- self,
88
- dim,
89
- heads=8,
90
- dim_head=64,
91
- dropout=0.,
92
- rotary_embed=None,
93
- flash=True
94
- ):
95
- super().__init__()
96
- self.heads = heads
97
- self.scale = dim_head ** -0.5
98
- dim_inner = heads * dim_head
99
-
100
- self.rotary_embed = rotary_embed
101
-
102
- self.attend = Attend(flash=flash, dropout=dropout)
103
-
104
- self.norm = RMSNorm(dim)
105
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
106
-
107
- self.to_gates = nn.Linear(dim, heads)
108
-
109
- self.to_out = nn.Sequential(
110
- nn.Linear(dim_inner, dim, bias=False),
111
- nn.Dropout(dropout)
112
- )
113
-
114
- def forward(self, x):
115
- x = self.norm(x)
116
-
117
- q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
118
-
119
- if exists(self.rotary_embed):
120
- q = self.rotary_embed.rotate_queries_or_keys(q)
121
- k = self.rotary_embed.rotate_queries_or_keys(k)
122
-
123
- out = self.attend(q, k, v)
124
-
125
- gates = self.to_gates(x)
126
- out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
127
-
128
- out = rearrange(out, 'b h n d -> b n (h d)')
129
- return self.to_out(out)
130
-
131
-
132
- class LinearAttention(Module):
133
- """
134
- this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
135
- """
136
-
137
- @beartype
138
- def __init__(
139
- self,
140
- *,
141
- dim,
142
- dim_head=32,
143
- heads=8,
144
- scale=8,
145
- flash=False,
146
- dropout=0.
147
- ):
148
- super().__init__()
149
- dim_inner = dim_head * heads
150
- self.norm = RMSNorm(dim)
151
-
152
- self.to_qkv = nn.Sequential(
153
- nn.Linear(dim, dim_inner * 3, bias=False),
154
- Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
155
- )
156
-
157
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
158
-
159
- self.attend = Attend(
160
- scale=scale,
161
- dropout=dropout,
162
- flash=flash
163
- )
164
-
165
- self.to_out = nn.Sequential(
166
- Rearrange('b h d n -> b n (h d)'),
167
- nn.Linear(dim_inner, dim, bias=False)
168
- )
169
-
170
- def forward(
171
- self,
172
- x
173
- ):
174
- x = self.norm(x)
175
-
176
- q, k, v = self.to_qkv(x)
177
-
178
- q, k = map(l2norm, (q, k))
179
- q = q * self.temperature.exp()
180
-
181
- out = self.attend(q, k, v)
182
-
183
- return self.to_out(out)
184
-
185
-
186
- class Transformer(Module):
187
- def __init__(
188
- self,
189
- *,
190
- dim,
191
- depth,
192
- dim_head=64,
193
- heads=8,
194
- attn_dropout=0.,
195
- ff_dropout=0.,
196
- ff_mult=4,
197
- norm_output=True,
198
- rotary_embed=None,
199
- flash_attn=True,
200
- linear_attn=False
201
- ):
202
- super().__init__()
203
- self.layers = ModuleList([])
204
-
205
- for _ in range(depth):
206
- if linear_attn:
207
- attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
208
- else:
209
- attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
210
- rotary_embed=rotary_embed, flash=flash_attn)
211
-
212
- self.layers.append(ModuleList([
213
- attn,
214
- FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
215
- ]))
216
-
217
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
218
-
219
- def forward(self, x):
220
-
221
- for attn, ff in self.layers:
222
- x = attn(x) + x
223
- x = ff(x) + x
224
-
225
- return self.norm(x)
226
-
227
-
228
- # bandsplit module
229
-
230
- class BandSplit(Module):
231
- @beartype
232
- def __init__(
233
- self,
234
- dim,
235
- dim_inputs: Tuple[int, ...]
236
- ):
237
- super().__init__()
238
- self.dim_inputs = dim_inputs
239
- self.to_features = ModuleList([])
240
-
241
- for dim_in in dim_inputs:
242
- net = nn.Sequential(
243
- RMSNorm(dim_in),
244
- nn.Linear(dim_in, dim)
245
- )
246
-
247
- self.to_features.append(net)
248
-
249
- def forward(self, x):
250
- x = x.split(self.dim_inputs, dim=-1)
251
-
252
- outs = []
253
- for split_input, to_feature in zip(x, self.to_features):
254
- split_output = to_feature(split_input)
255
- outs.append(split_output)
256
-
257
- return torch.stack(outs, dim=-2)
258
-
259
-
260
- def MLP(
261
- dim_in,
262
- dim_out,
263
- dim_hidden=None,
264
- depth=1,
265
- activation=nn.Tanh
266
- ):
267
- dim_hidden = default(dim_hidden, dim_in)
268
-
269
- net = []
270
- dims = (dim_in, *((dim_hidden,) * depth), dim_out)
271
-
272
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
273
- is_last = ind == (len(dims) - 2)
274
-
275
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
276
-
277
- if is_last:
278
- continue
279
-
280
- net.append(activation())
281
-
282
- return nn.Sequential(*net)
283
-
284
-
285
- class MaskEstimator(Module):
286
- @beartype
287
- def __init__(
288
- self,
289
- dim,
290
- dim_inputs: Tuple[int, ...],
291
- depth,
292
- mlp_expansion_factor=4
293
- ):
294
- super().__init__()
295
- self.dim_inputs = dim_inputs
296
- self.to_freqs = ModuleList([])
297
- dim_hidden = dim * mlp_expansion_factor
298
-
299
- for dim_in in dim_inputs:
300
- net = []
301
-
302
- mlp = nn.Sequential(
303
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
304
- nn.GLU(dim=-1)
305
- )
306
-
307
- self.to_freqs.append(mlp)
308
-
309
- def forward(self, x):
310
- x = x.unbind(dim=-2)
311
-
312
- outs = []
313
-
314
- for band_features, mlp in zip(x, self.to_freqs):
315
- freq_out = mlp(band_features)
316
- outs.append(freq_out)
317
-
318
- return torch.cat(outs, dim=-1)
319
-
320
-
321
- # main class
322
-
323
- class MelBandRoformer(Module):
324
-
325
- @beartype
326
- def __init__(
327
- self,
328
- dim,
329
- *,
330
- depth,
331
- stereo=False,
332
- num_stems=1,
333
- time_transformer_depth=2,
334
- freq_transformer_depth=2,
335
- linear_transformer_depth=0,
336
- num_bands=60,
337
- dim_head=64,
338
- heads=8,
339
- attn_dropout=0.1,
340
- ff_dropout=0.1,
341
- flash_attn=True,
342
- dim_freqs_in=1025,
343
- sample_rate=44100, # needed for mel filter bank from librosa
344
- stft_n_fft=2048,
345
- stft_hop_length=512,
346
- # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
347
- stft_win_length=2048,
348
- stft_normalized=False,
349
- stft_window_fn: Optional[Callable] = None,
350
- mask_estimator_depth=1,
351
- multi_stft_resolution_loss_weight=1.,
352
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
353
- multi_stft_hop_size=147,
354
- multi_stft_normalized=False,
355
- multi_stft_window_fn: Callable = torch.hann_window,
356
- match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
357
- ):
358
- super().__init__()
359
-
360
- self.stereo = stereo
361
- self.audio_channels = 2 if stereo else 1
362
- self.num_stems = num_stems
363
-
364
- self.layers = ModuleList([])
365
-
366
- transformer_kwargs = dict(
367
- dim=dim,
368
- heads=heads,
369
- dim_head=dim_head,
370
- attn_dropout=attn_dropout,
371
- ff_dropout=ff_dropout,
372
- flash_attn=flash_attn
373
- )
374
-
375
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
376
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
377
-
378
- for _ in range(depth):
379
- tran_modules = []
380
- if linear_transformer_depth > 0:
381
- tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
382
- tran_modules.append(
383
- Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
384
- )
385
- tran_modules.append(
386
- Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
387
- )
388
- self.layers.append(nn.ModuleList(tran_modules))
389
-
390
- self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
391
-
392
- self.stft_kwargs = dict(
393
- n_fft=stft_n_fft,
394
- hop_length=stft_hop_length,
395
- win_length=stft_win_length,
396
- normalized=stft_normalized
397
- )
398
-
399
- freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
400
-
401
- # create mel filter bank
402
- # with librosa.filters.mel as in section 2 of paper
403
-
404
- mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
405
-
406
- mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
407
-
408
- # for some reason, it doesn't include the first freq? just force a value for now
409
-
410
- mel_filter_bank[0][0] = 1.
411
-
412
- # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
413
- # so let's force a positive value
414
-
415
- mel_filter_bank[-1, -1] = 1.
416
-
417
- # binary as in paper (then estimated masks are averaged for overlapping regions)
418
-
419
- freqs_per_band = mel_filter_bank > 0
420
- assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
421
-
422
- repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
423
- freq_indices = repeated_freq_indices[freqs_per_band]
424
-
425
- if stereo:
426
- freq_indices = repeat(freq_indices, 'f -> f s', s=2)
427
- freq_indices = freq_indices * 2 + torch.arange(2)
428
- freq_indices = rearrange(freq_indices, 'f s -> (f s)')
429
-
430
- self.register_buffer('freq_indices', freq_indices, persistent=False)
431
- self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
432
-
433
- num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
434
- num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
435
-
436
- self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
437
- self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
438
-
439
- # band split and mask estimator
440
-
441
- freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
442
-
443
- self.band_split = BandSplit(
444
- dim=dim,
445
- dim_inputs=freqs_per_bands_with_complex
446
- )
447
-
448
- self.mask_estimators = nn.ModuleList([])
449
-
450
- for _ in range(num_stems):
451
- mask_estimator = MaskEstimator(
452
- dim=dim,
453
- dim_inputs=freqs_per_bands_with_complex,
454
- depth=mask_estimator_depth
455
- )
456
-
457
- self.mask_estimators.append(mask_estimator)
458
-
459
- # for the multi-resolution stft loss
460
-
461
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
462
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
463
- self.multi_stft_n_fft = stft_n_fft
464
- self.multi_stft_window_fn = multi_stft_window_fn
465
-
466
- self.multi_stft_kwargs = dict(
467
- hop_length=multi_stft_hop_size,
468
- normalized=multi_stft_normalized
469
- )
470
-
471
- self.match_input_audio_length = match_input_audio_length
472
-
473
- def forward(
474
- self,
475
- raw_audio,
476
- target=None,
477
- return_loss_breakdown=False
478
- ):
479
- """
480
- einops
481
-
482
- b - batch
483
- f - freq
484
- t - time
485
- s - audio channel (1 for mono, 2 for stereo)
486
- n - number of 'stems'
487
- c - complex (2)
488
- d - feature dimension
489
- """
490
-
491
- device = raw_audio.device
492
-
493
- if raw_audio.ndim == 2:
494
- raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
495
-
496
- batch, channels, raw_audio_length = raw_audio.shape
497
-
498
- istft_length = raw_audio_length if self.match_input_audio_length else None
499
-
500
- assert (not self.stereo and channels == 1) or (
501
- self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
502
-
503
- # to stft
504
-
505
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
506
-
507
- stft_window = self.stft_window_fn(device=device)
508
-
509
- stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
510
- stft_repr = torch.view_as_real(stft_repr)
511
-
512
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
513
- stft_repr = rearrange(stft_repr,
514
- 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
515
-
516
- # index out all frequencies for all frequency ranges across bands ascending in one go
517
-
518
- batch_arange = torch.arange(batch, device=device)[..., None]
519
-
520
- # account for stereo
521
-
522
- x = stft_repr[batch_arange, self.freq_indices]
523
-
524
- # fold the complex (real and imag) into the frequencies dimension
525
-
526
- x = rearrange(x, 'b f t c -> b t (f c)')
527
-
528
- x = self.band_split(x)
529
-
530
- # axial / hierarchical attention
531
-
532
- for transformer_block in self.layers:
533
-
534
- if len(transformer_block) == 3:
535
- linear_transformer, time_transformer, freq_transformer = transformer_block
536
-
537
- x, ft_ps = pack([x], 'b * d')
538
- x = linear_transformer(x)
539
- x, = unpack(x, ft_ps, 'b * d')
540
- else:
541
- time_transformer, freq_transformer = transformer_block
542
-
543
- x = rearrange(x, 'b t f d -> b f t d')
544
- x, ps = pack([x], '* t d')
545
-
546
- x = time_transformer(x)
547
-
548
- x, = unpack(x, ps, '* t d')
549
- x = rearrange(x, 'b f t d -> b t f d')
550
- x, ps = pack([x], '* f d')
551
-
552
- x = freq_transformer(x)
553
-
554
- x, = unpack(x, ps, '* f d')
555
-
556
- num_stems = len(self.mask_estimators)
557
-
558
- masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
559
- masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
560
-
561
- # modulate frequency representation
562
-
563
- stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
564
-
565
- # complex number multiplication
566
-
567
- stft_repr = torch.view_as_complex(stft_repr)
568
- masks = torch.view_as_complex(masks)
569
-
570
- masks = masks.type(stft_repr.dtype)
571
-
572
- # need to average the estimated mask for the overlapped frequencies
573
-
574
- scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
575
-
576
- stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
577
- masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
578
-
579
- denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
580
-
581
- masks_averaged = masks_summed / denom.clamp(min=1e-8)
582
-
583
- # modulate stft repr with estimated mask
584
-
585
- stft_repr = stft_repr * masks_averaged
586
-
587
- # istft
588
-
589
- stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
590
-
591
- recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
592
- length=istft_length)
593
-
594
- recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
595
-
596
- if num_stems == 1:
597
- recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
598
-
599
- # if a target is passed in, calculate loss for learning
600
-
601
- if not exists(target):
602
- return recon_audio
603
-
604
- if self.num_stems > 1:
605
- assert target.ndim == 4 and target.shape[1] == self.num_stems
606
-
607
- if target.ndim == 2:
608
- target = rearrange(target, '... t -> ... 1 t')
609
-
610
- target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
611
-
612
- loss = F.l1_loss(recon_audio, target)
613
-
614
- multi_stft_resolution_loss = 0.
615
-
616
- for window_size in self.multi_stft_resolutions_window_sizes:
617
- res_stft_kwargs = dict(
618
- n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
619
- win_length=window_size,
620
- return_complex=True,
621
- window=self.multi_stft_window_fn(window_size, device=device),
622
- **self.multi_stft_kwargs,
623
- )
624
-
625
- recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
626
- target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
627
-
628
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
629
-
630
- weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
631
-
632
- total_loss = loss + weighted_multi_resolution_loss
633
-
634
- if not return_loss_breakdown:
635
- return total_loss
636
-
637
- return total_loss, (loss, multi_stft_resolution_loss)
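Similarly, a minimal usage sketch for the deleted MelBandRoformer (the mel-band variant). It additionally needs librosa installed for the mel filter bank; sample_rate and num_bands below mirror the constructor defaults, while dim/depth and the clip length are illustrative assumptions.

import torch
from models.bs_roformer.mel_band_roformer import MelBandRoformer

model = MelBandRoformer(dim=384, depth=6, sample_rate=44100, num_bands=60)

audio = torch.randn(1, 44100 * 2)          # (batch, time), mono

stem = model(audio)                        # (batch, channels=1, time)
loss = model(audio, target=torch.randn(1, 44100 * 2))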