Christian J. Steinmetz committed · Commit 4d2eb76
Parent(s): 52db8b0

updates to score logging fx classifier

remfx/models.py CHANGED (+36 -13)
@@ -442,6 +442,7 @@ class FXClassifier(pl.LightningModule):
         sample_rate: float,
         network: nn.Module,
         mixup: bool = False,
+        label_smoothing: float = 0.0,
     ):
         super().__init__()
         self.lr = lr
@@ -450,6 +451,9 @@ class FXClassifier(pl.LightningModule):
         self.network = network
         self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
         self.mixup = mixup
+        self.label_smoothing = label_smoothing
+
+        self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
 
         self.train_f1 = torchmetrics.classification.MultilabelF1Score(
             5, average="none", multidim_average="global"
@@ -461,12 +465,28 @@ class FXClassifier(pl.LightningModule):
             5, average="none", multidim_average="global"
         )
 
+        self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+
         self.metrics = {
             "train": self.train_f1,
             "valid": self.val_f1,
             "test": self.test_f1,
         }
 
+        self.avg_metrics = {
+            "train": self.train_f1_avg,
+            "valid": self.val_f1_avg,
+            "test": self.test_f1_avg,
+        }
+
     def forward(self, x: torch.Tensor, train: bool = False):
         return self.network(x, train=train)
 
@@ -477,12 +497,12 @@ class FXClassifier(pl.LightningModule):
         if mode == "train" and self.mixup:
             x_mixed, label_mixed, lam = mixup(x, wet_label)
             pred_label = self(x_mixed, train)
-            loss =
+            loss = self.loss_fn(pred_label, label_mixed)
             print(torch.sigmoid(pred_label[0, ...]))
             print(label_mixed[0, ...])
         else:
             pred_label = self(x, train)
-            loss =
+            loss = self.loss_fn(pred_label, wet_label)
             print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
             print(wet_label.long()[0, ...])
 
@@ -497,17 +517,6 @@ class FXClassifier(pl.LightningModule):
         )
 
         metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
-        avg_metrics = torch.mean(metrics)
-
-        self.log(
-            f"{mode}_f1_avg",
-            avg_metrics,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
 
         for idx, effect_name in enumerate(self.effects):
             self.log(
@@ -520,6 +529,20 @@ class FXClassifier(pl.LightningModule):
             sync_dist=True,
         )
 
+        avg_metrics = self.avg_metrics[mode](
+            torch.sigmoid(pred_label), wet_label.long()
+        )
+
+        self.log(
+            f"{mode}_f1_avg",
+            avg_metrics,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
         return loss
 
     def training_step(self, batch, batch_idx):
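The new label_smoothing argument is handed straight to torch.nn.CrossEntropyLoss, which has accepted it since PyTorch 1.10. A minimal sketch of what the knob does, using made-up shapes rather than anything from remfx:

import torch

logits = torch.randn(8, 5)            # illustrative batch: 8 clips, 5 effect classes
targets = torch.randint(0, 5, (8,))   # hard class indices

# With label_smoothing=0.1 the implicit one-hot target is blended with a
# uniform distribution: the true class keeps 1 - 0.1 + 0.1/5 = 0.92 of the
# probability mass and each other class receives 0.1/5 = 0.02.
plain = torch.nn.CrossEntropyLoss()
smoothed = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

print(plain(logits, targets))
print(smoothed(logits, targets))

Because the smoothed target never assigns a class probability of exactly one, the loss penalizes fully confident logits, which is why label smoothing is commonly used as a regularizer for classifiers.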
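The deleted torch.mean over the per-class F1 tensor is replaced by dedicated macro-averaged metric objects. The difference between the two averaging modes can be seen with toy tensors (again an illustrative sketch, not code from the repository):

import torch
from torchmetrics.classification import MultilabelF1Score

preds = torch.rand(16, 5)                # post-sigmoid probabilities
target = torch.randint(0, 2, (16, 5))    # multi-hot effect labels

per_class = MultilabelF1Score(5, average="none", multidim_average="global")
macro = MultilabelF1Score(5, threshold=0.5, average="macro", multidim_average="global")

print(per_class(preds, target))   # tensor of five per-class F1 scores
print(macro(preds, target))       # one scalar, the unweighted mean over classes

Keeping a separate stateful metric for the average, rather than taking the mean of the per-class values at each step, lets torchmetrics accumulate its counts across the whole epoch and report an epoch-level macro F1 instead of an average of per-step averages.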
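The mixup branch calls a mixup(x, wet_label) helper defined elsewhere in remfx; its implementation is not part of this diff. Below is a plausible stand-in with the same (x_mixed, label_mixed, lam) return signature, offered only as an assumption about what such a helper typically looks like:

import torch

def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    # Draw a mixing coefficient from Beta(alpha, alpha), then blend each
    # example, and its label, with a randomly permuted partner example.
    lam = torch.distributions.Beta(alpha, alpha).sample()
    perm = torch.randperm(x.shape[0])
    x_mixed = lam * x + (1 - lam) * x[perm]
    y_mixed = lam * y + (1 - lam) * y[perm]
    return x_mixed, y_mixed, lam

Because the mixed labels are soft, the new loss_fn call in this branch relies on CrossEntropyLoss accepting probability targets, support that also arrived in PyTorch 1.10. Note that the lam value returned here is assigned but not used by the caller in this diff.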
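One caveat about the metric bookkeeping, noted as general torchmetrics guidance rather than something this commit changes: metrics held in a plain Python dict are not registered as submodules, so they will not follow the LightningModule when it moves to the GPU. The usual recommendation is torch.nn.ModuleDict, sketched here with a hypothetical build_avg_metrics helper:

import torch
from torchmetrics.classification import MultilabelF1Score

def build_avg_metrics(num_labels: int = 5) -> torch.nn.ModuleDict:
    # ModuleDict registers each metric as a submodule, so .to(device) on the
    # parent module moves the metric state as well. The keys are suffixed
    # because a bare "train" key is rejected: nn.Module already defines .train().
    return torch.nn.ModuleDict(
        {
            f"{stage}_f1_avg": MultilabelF1Score(
                num_labels, threshold=0.5, average="macro",
                multidim_average="global",
            )
            for stage in ("train", "valid", "test")
        }
    )

The lookup site would then index with the suffixed key, e.g. self.avg_metrics[f"{mode}_f1_avg"].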