Christian J. Steinmetz commited on
Commit
4d2eb76
·
1 Parent(s): 52db8b0

updates to score logging fx classifier

Browse files
Files changed (1) hide show
  1. remfx/models.py +36 -13
remfx/models.py CHANGED
@@ -442,6 +442,7 @@ class FXClassifier(pl.LightningModule):
442
  sample_rate: float,
443
  network: nn.Module,
444
  mixup: bool = False,
 
445
  ):
446
  super().__init__()
447
  self.lr = lr
@@ -450,6 +451,9 @@ class FXClassifier(pl.LightningModule):
450
  self.network = network
451
  self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
452
  self.mixup = mixup
 
 
 
453
 
454
  self.train_f1 = torchmetrics.classification.MultilabelF1Score(
455
  5, average="none", multidim_average="global"
@@ -461,12 +465,28 @@ class FXClassifier(pl.LightningModule):
461
  5, average="none", multidim_average="global"
462
  )
463
 
 
 
 
 
 
 
 
 
 
 
464
  self.metrics = {
465
  "train": self.train_f1,
466
  "valid": self.val_f1,
467
  "test": self.test_f1,
468
  }
469
 
 
 
 
 
 
 
470
  def forward(self, x: torch.Tensor, train: bool = False):
471
  return self.network(x, train=train)
472
 
@@ -477,12 +497,12 @@ class FXClassifier(pl.LightningModule):
477
  if mode == "train" and self.mixup:
478
  x_mixed, label_mixed, lam = mixup(x, wet_label)
479
  pred_label = self(x_mixed, train)
480
- loss = nn.functional.cross_entropy(pred_label, label_mixed)
481
  print(torch.sigmoid(pred_label[0, ...]))
482
  print(label_mixed[0, ...])
483
  else:
484
  pred_label = self(x, train)
485
- loss = nn.functional.cross_entropy(pred_label, wet_label)
486
  print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
487
  print(wet_label.long()[0, ...])
488
 
@@ -497,17 +517,6 @@ class FXClassifier(pl.LightningModule):
497
  )
498
 
499
  metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
500
- avg_metrics = torch.mean(metrics)
501
-
502
- self.log(
503
- f"{mode}_f1_avg",
504
- avg_metrics,
505
- on_step=True,
506
- on_epoch=True,
507
- prog_bar=True,
508
- logger=True,
509
- sync_dist=True,
510
- )
511
 
512
  for idx, effect_name in enumerate(self.effects):
513
  self.log(
@@ -520,6 +529,20 @@ class FXClassifier(pl.LightningModule):
520
  sync_dist=True,
521
  )
522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  return loss
524
 
525
  def training_step(self, batch, batch_idx):
 
442
  sample_rate: float,
443
  network: nn.Module,
444
  mixup: bool = False,
445
+ label_smoothing: float = 0.0,
446
  ):
447
  super().__init__()
448
  self.lr = lr
 
451
  self.network = network
452
  self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
453
  self.mixup = mixup
454
+ self.label_smoothing = label_smoothing
455
+
456
+ self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
457
 
458
  self.train_f1 = torchmetrics.classification.MultilabelF1Score(
459
  5, average="none", multidim_average="global"
 
465
  5, average="none", multidim_average="global"
466
  )
467
 
468
+ self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
469
+ 5, threshold=0.5, average="macro", multidim_average="global"
470
+ )
471
+ self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
472
+ 5, threshold=0.5, average="macro", multidim_average="global"
473
+ )
474
+ self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
475
+ 5, threshold=0.5, average="macro", multidim_average="global"
476
+ )
477
+
478
  self.metrics = {
479
  "train": self.train_f1,
480
  "valid": self.val_f1,
481
  "test": self.test_f1,
482
  }
483
 
484
+ self.avg_metrics = {
485
+ "train": self.train_f1_avg,
486
+ "valid": self.val_f1_avg,
487
+ "test": self.test_f1_avg,
488
+ }
489
+
490
  def forward(self, x: torch.Tensor, train: bool = False):
491
  return self.network(x, train=train)
492
 
 
497
  if mode == "train" and self.mixup:
498
  x_mixed, label_mixed, lam = mixup(x, wet_label)
499
  pred_label = self(x_mixed, train)
500
+ loss = self.loss_fn(pred_label, label_mixed)
501
  print(torch.sigmoid(pred_label[0, ...]))
502
  print(label_mixed[0, ...])
503
  else:
504
  pred_label = self(x, train)
505
+ loss = self.loss_fn(pred_label, wet_label)
506
  print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
507
  print(wet_label.long()[0, ...])
508
 
 
517
  )
518
 
519
  metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
 
 
 
 
 
 
 
 
 
 
 
520
 
521
  for idx, effect_name in enumerate(self.effects):
522
  self.log(
 
529
  sync_dist=True,
530
  )
531
 
532
+ avg_metrics = self.avg_metrics[mode](
533
+ torch.sigmoid(pred_label), wet_label.long()
534
+ )
535
+
536
+ self.log(
537
+ f"{mode}_f1_avg",
538
+ avg_metrics,
539
+ on_step=True,
540
+ on_epoch=True,
541
+ prog_bar=True,
542
+ logger=True,
543
+ sync_dist=True,
544
+ )
545
+
546
  return loss
547
 
548
  def training_step(self, batch, batch_idx):