Christian J. Steinmetz committed
Commit 7ac8557
1 Parent(s): a3e84f7

adding multi-label classification task with CNN

Files changed (3):
  1. cfg/model/classifier.yaml +14 -0
  2. cfg/model/umx.yaml +0 -2
  3. remfx/models.py +245 -5
cfg/model/classifier.yaml ADDED
@@ -0,0 +1,14 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 1e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 4096
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+
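Since the new config is built through Hydra's `_target_` mechanism, selecting it constructs the `Cnn14` network first and passes it to `FXClassifier` as its `network` argument. A minimal sketch of the equivalent instantiation, assuming `${sample_rate}` and `${num_classes}` are defined at the root of the composed config (the concrete values below are placeholders, not from the commit):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stand-in for the composed config; the ${...} interpolations resolve
# against the root, as they would in the full project config.
cfg = OmegaConf.create(
    {
        "sample_rate": 48000,  # placeholder
        "num_classes": 5,  # placeholder
        "model": {
            "_target_": "remfx.models.FXClassifier",
            "lr": 1e-4,
            "lr_weight_decay": 1e-3,
            "sample_rate": "${sample_rate}",
            "network": {
                "_target_": "remfx.models.Cnn14",
                "num_classes": "${num_classes}",
                "n_fft": 4096,
                "hop_length": 512,
                "n_mels": 128,
                "sample_rate": "${sample_rate}",
            },
        },
    }
)
model = instantiate(cfg.model)  # recursively builds Cnn14, then FXClassifier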
cfg/model/umx.yaml CHANGED
@@ -11,7 +11,5 @@ model:
   _target_: remfx.models.OpenUnmixModel
   n_fft: 2048
   hop_length: 512
-  n_channels: 1
-  alpha: 0.3
   sample_rate: ${sample_rate}
remfx/models.py CHANGED
@@ -1,15 +1,19 @@
+import wandb
 import torch
-from torch import Tensor, nn
+import torchaudio
+import torchmetrics
 import pytorch_lightning as pl
+import torch.nn.functional as F
+
+from torch import Tensor, nn
 from einops import rearrange
-import wandb
+from torchaudio.models import HDemucs
 from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss
-from remfx.utils import FADLoss
-
 from umx.openunmix.model import OpenUnmix, Separator
-from torchaudio.models import HDemucs
+
+from remfx.utils import FADLoss


 class RemFXModel(pl.LightningModule):
@@ -326,3 +330,239 @@ def spectrogram(
     X = X.view(bs, chs, X.shape[-2], X.shape[-1])

     return torch.pow(X.abs() + 1e-8, alpha)
+
+
+# adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
+
+
+def init_layer(layer):
+    """Initialize a Linear or Convolutional layer."""
+    nn.init.xavier_uniform_(layer.weight)
+
+    if hasattr(layer, "bias"):
+        if layer.bias is not None:
+            layer.bias.data.fill_(0.0)
+
+
+def init_bn(bn):
+    """Initialize a Batchnorm layer."""
+    bn.bias.data.fill_(0.0)
+    bn.weight.data.fill_(1.0)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_layer(self.conv1)
+        init_layer(self.conv2)
+        init_bn(self.bn1)
+        init_bn(self.bn2)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        x = F.relu_(self.bn2(self.conv2(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise Exception("Incorrect argument!")
+
+        return x
+
+
+class Cnn14(nn.Module):
+    def __init__(
+        self,
+        num_classes: int,
+        sample_rate: float,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        n_mels: int = 128,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
+
+        self.melspec = torchaudio.transforms.MelSpectrogram(
+            sample_rate,
+            n_fft,
+            hop_length=hop_length,
+            n_mels=n_mels,
+        )
+
+        self.bn0 = nn.BatchNorm2d(n_mels)
+
+        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+        self.fc1 = nn.Linear(2048, 2048, bias=True)
+        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_bn(self.bn0)
+        init_layer(self.fc1)
+        init_layer(self.fc_audioset)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Input: (batch_size, data_length)"""
+
+        x = self.melspec(x)
+        x = x.permute(0, 2, 1, 3)
+        x = self.bn0(x)
+        x = x.permute(0, 2, 1, 3)
+
+        if self.training:
+            pass
+            # x = self.spec_augmenter(x)
+
+        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        clipwise_output = self.fc_audioset(x)
+
+        return clipwise_output
+
+
+def spectrogram(
+    x: torch.Tensor,
+    window: torch.Tensor,
+    n_fft: int,
+    hop_length: int,
+    alpha: float,
+) -> torch.Tensor:
+    bs, chs, samp = x.size()
+    x = x.view(bs * chs, -1)  # move channels onto batch dim
+
+    X = torch.stft(
+        x,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        window=window,
+        return_complex=True,
+    )
+
+    # move channels back
+    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
+
+    return torch.pow(X.abs() + 1e-8, alpha)
+
+
+class FXClassifier(pl.LightningModule):
+    def __init__(
+        self,
+        lr: float,
+        lr_weight_decay: float,
+        sample_rate: float,
+        network: nn.Module,
+    ):
+        super().__init__()
+        self.lr = lr
+        self.lr_weight_decay = lr_weight_decay
+        self.sample_rate = sample_rate
+        self.network = network
+
+    def forward(self, x: torch.Tensor):
+        return self.network(x)
+
+    def common_step(self, batch, batch_idx, mode: str = "train"):
+        x, y, dry_label, wet_label = batch
+        pred_label = self.network(x)
+        loss = torch.nn.functional.cross_entropy(pred_label, dry_label)
+        self.log(
+            f"{mode}_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        self.log(
+            f"{mode}_mAP",
+            torchmetrics.functional.retrieval_average_precision(
+                pred_label, dry_label.long()
+            ),
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="train")
+
+    def validation_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="valid")
+
+    def test_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="test")
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            self.network.parameters(),
+            lr=self.lr,
+            weight_decay=self.lr_weight_decay,
+        )
+        return optimizer
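
Taken together, the commit wires `Cnn14` into `FXClassifier` exactly as the new `classifier.yaml` specifies. A minimal end-to-end sketch of the forward path (not part of the commit; the sample rate, class count, and batch shape are placeholder assumptions):

import torch
from remfx.models import Cnn14, FXClassifier

sample_rate = 48000  # placeholder for ${sample_rate}
num_classes = 5  # placeholder for ${num_classes}

network = Cnn14(
    num_classes=num_classes,
    sample_rate=sample_rate,
    n_fft=4096,
    hop_length=512,
    n_mels=128,
)
model = FXClassifier(
    lr=1e-4,
    lr_weight_decay=1e-3,
    sample_rate=sample_rate,
    network=network,
)

# MelSpectrogram keeps the channel dimension and conv_block1 expects one
# input channel, so mono audio goes in as (batch, 1, samples).
x = torch.randn(2, 1, sample_rate)  # two one-second mono clips
logits = model(x)
print(logits.shape)  # torch.Size([2, 5]): one logit per effect class

During training, `common_step` scores these logits against the batch's `dry_label` with cross-entropy and logs mAP through `torchmetrics.functional.retrieval_average_precision`.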