Commit 7ac8557
Christian J. Steinmetz committed
Parent(s): a3e84f7

adding multi-label classification task with CNN

Files changed:
- cfg/model/classifier.yaml +14 -0
- cfg/model/umx.yaml +0 -2
- remfx/models.py +245 -5
cfg/model/classifier.yaml    ADDED

@@ -0,0 +1,14 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 1e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 4096
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+
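For orientation (a sketch, not part of the commit): the nested `_target_` keys are Hydra instantiation points, so `hydra.utils.instantiate` builds the inner `Cnn14` first and passes it to `FXClassifier` as its `network` argument. The concrete values below are stand-ins for the `${sample_rate}` and `${num_classes}` interpolations, which resolve from the global config at run time.

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # stand-in values; the real config resolves ${sample_rate} / ${num_classes}
    cfg = OmegaConf.create(
        {
            "model": {
                "_target_": "remfx.models.FXClassifier",
                "lr": 1e-4,
                "lr_weight_decay": 1e-3,
                "sample_rate": 48000,
                "network": {
                    "_target_": "remfx.models.Cnn14",
                    "num_classes": 5,
                    "n_fft": 4096,
                    "hop_length": 512,
                    "n_mels": 128,
                    "sample_rate": 48000,
                },
            }
        }
    )

    # Hydra instantiates nested _target_ configs recursively by default,
    # so this yields FXClassifier(network=Cnn14(...), lr=1e-4, ...)
    model = instantiate(cfg.model)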
cfg/model/umx.yaml    CHANGED

@@ -11,7 +11,5 @@ model:
   _target_: remfx.models.OpenUnmixModel
   n_fft: 2048
   hop_length: 512
-  n_channels: 1
-  alpha: 0.3
   sample_rate: ${sample_rate}
 
remfx/models.py    CHANGED

@@ -1,15 +1,19 @@
+import wandb
 import torch
-
+import torchaudio
+import torchmetrics
 import pytorch_lightning as pl
+import torch.nn.functional as F
+
+from torch import Tensor, nn
 from einops import rearrange
-import
+from torchaudio.models import HDemucs
 from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss
-from remfx.utils import FADLoss
-
 from umx.openunmix.model import OpenUnmix, Separator
-
+
+from remfx.utils import FADLoss
 
 
 class RemFXModel(pl.LightningModule):

@@ -326,3 +330,239 @@ def spectrogram(
     X = X.view(bs, chs, X.shape[-2], X.shape[-1])
 
     return torch.pow(X.abs() + 1e-8, alpha)
+
+
+# adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
+
+
+def init_layer(layer):
+    """Initialize a Linear or Convolutional layer."""
+    nn.init.xavier_uniform_(layer.weight)
+
+    if hasattr(layer, "bias"):
+        if layer.bias is not None:
+            layer.bias.data.fill_(0.0)
+
+
+def init_bn(bn):
+    """Initialize a Batchnorm layer."""
+    bn.bias.data.fill_(0.0)
+    bn.weight.data.fill_(1.0)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_layer(self.conv1)
+        init_layer(self.conv2)
+        init_bn(self.bn1)
+        init_bn(self.bn2)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        x = F.relu_(self.bn2(self.conv2(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise Exception("Incorrect argument!")
+
+        return x
+
+
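A shape check for the block above (a sketch, not part of the commit; the batch and spatial sizes are arbitrary): each ConvBlock runs two 3x3 conv + BN + in-place ReLU stages and then pools, so `pool_size=(2, 2)` halves both the mel and time axes while the channel count changes to `out_channels`.

    import torch

    block = ConvBlock(in_channels=1, out_channels=64)
    x = torch.randn(8, 1, 128, 100)  # (batch, channels, n_mels, frames)
    y = block(x, pool_size=(2, 2), pool_type="avg")
    print(y.shape)  # torch.Size([8, 64, 64, 50]): channels up, spatial dims halved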
+class Cnn14(nn.Module):
+    def __init__(
+        self,
+        num_classes: int,
+        sample_rate: float,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        n_mels: int = 128,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
+
+        self.melspec = torchaudio.transforms.MelSpectrogram(
+            sample_rate,
+            n_fft,
+            hop_length=hop_length,
+            n_mels=n_mels,
+        )
+
+        self.bn0 = nn.BatchNorm2d(n_mels)
+
+        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+        self.fc1 = nn.Linear(2048, 2048, bias=True)
+        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_bn(self.bn0)
+        init_layer(self.fc1)
+        init_layer(self.fc_audioset)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Input: (batch_size, data_length)"""
+
+        x = self.melspec(x)
+        x = x.permute(0, 2, 1, 3)
+        x = self.bn0(x)
+        x = x.permute(0, 2, 1, 3)
+
+        if self.training:
+            pass
+            # x = self.spec_augmenter(x)
+
+        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        clipwise_output = self.fc_audioset(x)
+
+        return clipwise_output
+
+
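A usage sketch for Cnn14 (not part of the commit; the 5-class count and 48 kHz rate are stand-ins for ${num_classes} and ${sample_rate}). The melspec -> permute -> BatchNorm2d(n_mels) chain implies a 4-D path, so a mono (batch, 1, samples) input appears to be what the layers expect, despite the "(batch_size, data_length)" docstring.

    import torch

    net = Cnn14(num_classes=5, sample_rate=48000, n_fft=4096, hop_length=512, n_mels=128)
    net.eval()

    audio = torch.randn(2, 1, 48000)  # two 1-second mono clips
    with torch.no_grad():
        logits = net(audio)
    print(logits.shape)  # torch.Size([2, 5]): raw per-class logits, no sigmoid/softmax applied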
+def spectrogram(
+    x: torch.Tensor,
+    window: torch.Tensor,
+    n_fft: int,
+    hop_length: int,
+    alpha: float,
+) -> torch.Tensor:
+    bs, chs, samp = x.size()
+    x = x.view(bs * chs, -1)  # move channels onto batch dim
+
+    X = torch.stft(
+        x,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        window=window,
+        return_complex=True,
+    )
+
+    # move channels back
+    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
+
+    return torch.pow(X.abs() + 1e-8, alpha)
+
+
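A shape sketch for the helper above (not part of the commit; sizes and the alpha value are arbitrary): it folds channels into the batch, takes a complex STFT, and returns the magnitude raised to `alpha` with channels restored.

    import torch

    x = torch.randn(4, 1, 48000)  # (batch, channels, samples)
    window = torch.hann_window(4096)
    X = spectrogram(x, window, n_fft=4096, hop_length=512, alpha=0.3)
    print(X.shape)  # torch.Size([4, 1, 2049, 94]): |STFT| ** alpha per channel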
+class FXClassifier(pl.LightningModule):
+    def __init__(
+        self,
+        lr: float,
+        lr_weight_decay: float,
+        sample_rate: float,
+        network: nn.Module,
+    ):
+        super().__init__()
+        self.lr = lr
+        self.lr_weight_decay = lr_weight_decay
+        self.sample_rate = sample_rate
+        self.network = network
+
+    def forward(self, x: torch.Tensor):
+        return self.network(x)
+
+    def common_step(self, batch, batch_idx, mode: str = "train"):
+        x, y, dry_label, wet_label = batch
+        pred_label = self.network(x)
+        loss = torch.nn.functional.cross_entropy(pred_label, dry_label)
+        self.log(
+            f"{mode}_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        self.log(
+            f"{mode}_mAP",
+            torchmetrics.functional.retrieval_average_precision(
+                pred_label, dry_label.long()
+            ),
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="train")
+
+    def validation_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="valid")
+
+    def test_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="test")
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            self.network.parameters(),
+            lr=self.lr,
+            weight_decay=self.lr_weight_decay,
+        )
+        return optimizer
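Finally, a wiring sketch (not part of the commit): `common_step` unpacks each batch as `(x, y, dry_label, wet_label)` and trains against `dry_label`, so any DataLoader yielding that 4-tuple fits. The hyperparameter values below are stand-ins mirroring cfg/model/classifier.yaml.

    import pytorch_lightning as pl

    model = FXClassifier(
        lr=1e-4,
        lr_weight_decay=1e-3,
        sample_rate=48000,  # stand-in for ${sample_rate}
        network=Cnn14(num_classes=5, sample_rate=48000, n_fft=4096, hop_length=512, n_mels=128),
    )

    trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(model, train_loader, val_loader)  # loaders must yield (x, y, dry_label, wet_label)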