Christian J. Steinmetz committed · commit 29f23c3 · 1 parent: 0b430bb

moving to multi-way binary classification task

Files changed:
- remfx/classifier.py (+19, -7)
- remfx/models.py (+70, -50)
remfx/classifier.py CHANGED

@@ -170,7 +170,11 @@ class Cnn14(nn.Module):
         self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
 
         self.fc1 = nn.Linear(2048, 2048, bias=True)
-        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+        # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+        self.heads = torch.nn.ModuleList()
+        for _ in range(num_classes):
+            self.heads.append(nn.Linear(2048, 1, bias=True))
 
         self.init_weight()
 
@@ -186,7 +190,7 @@ class Cnn14(nn.Module):
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
-        init_layer(self.fc_audioset)
+        # init_layer(self.fc_audioset)
 
     def forward(self, x: torch.Tensor, train: bool = False):
         """
@@ -206,9 +210,12 @@ class Cnn14(nn.Module):
         # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
         # plt.savefig("spec_augment.png", dpi=300)
 
-        x = x.permute(0, 2, 1, 3)
-        x = self.bn0(x)
-        x = x.permute(0, 2, 1, 3)
+        # x = x.permute(0, 2, 1, 3)
+        # x = self.bn0(x)
+        # x = x.permute(0, 2, 1, 3)
+
+        # apply standardization
+        x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
 
         x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
         x = F.dropout(x, p=0.2, training=train)
@@ -229,9 +236,14 @@ class Cnn14(nn.Module):
         x = x1 + x2
         x = F.dropout(x, p=0.5, training=train)
         x = F.relu_(self.fc1(x))
-        clipwise_output = self.fc_audioset(x)
-
-        return clipwise_output
+
+        outputs = []
+        for head in self.heads:
+            outputs.append(torch.sigmoid(head(x)))
+
+        # clipwise_output = self.fc_audioset(x)
+
+        return outputs
 
 
 class ConvBlock(nn.Module):
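
This hunk swaps Cnn14's single fc_audioset output layer for one independent sigmoid head per effect, so each effect gets its own binary present/absent decision. A minimal self-contained sketch of that pattern (the 2048-dim embedding size matches the diff; the batch size and class count below are illustrative):

    import torch
    import torch.nn as nn


    class MultiHeadBinaryClassifier(nn.Module):
        # one Linear(embed_dim, 1) head per class, mirroring the diff,
        # instead of a single Linear(embed_dim, num_classes) output layer
        def __init__(self, embed_dim=2048, num_classes=5):
            super().__init__()
            self.heads = nn.ModuleList(
                [nn.Linear(embed_dim, 1, bias=True) for _ in range(num_classes)]
            )

        def forward(self, x):
            # each head independently answers "is this effect present?"
            return [torch.sigmoid(head(x)) for head in self.heads]


    emb = torch.randn(8, 2048)                 # illustrative batch of embeddings
    probs = MultiHeadBinaryClassifier()(emb)   # list of five (8, 1) probability tensors

Unlike a softmax over classes, the sigmoids are not mutually exclusive, which is what makes multiple simultaneous effects representable.
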
remfx/models.py CHANGED

@@ -423,13 +423,20 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
     """
     batch_size = x.size(0)
     if alpha > 0:
-        lam = np.random.beta(alpha, alpha)
+        # lam = np.random.beta(alpha, alpha)
+        lam = np.random.uniform(0.25, 0.75, batch_size)
+        lam = torch.from_numpy(lam).float().to(x.device).view(batch_size, 1, 1)
     else:
         lam = 1
 
-    index = torch.randperm(batch_size).to(x.device)
-    mixed_x = lam * x + (1 - lam) * x[index, :]
-    mixed_y = torch.logical_or(y, y[index, :]).float()
+    print(lam)
+    if np.random.rand() > 0.5:
+        index = torch.randperm(batch_size).to(x.device)
+        mixed_x = lam * x + (1 - lam) * x[index, :]
+        mixed_y = torch.logical_or(y, y[index, :]).float()
+    else:
+        mixed_x = x
+        mixed_y = y
 
     return mixed_x, mixed_y, lam
 
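
The reworked mixup draws a per-example lam from U(0.25, 0.75) instead of one Beta sample shared across the batch, only mixes with probability 0.5, and unions the two examples' multi-hot labels. A usage sketch, assuming (batch, channels, samples) audio and (batch, num_classes) multi-hot labels as elsewhere in this file, and that the function is importable as remfx.models.mixup:

    import torch
    from remfx.models import mixup

    x = torch.randn(4, 1, 48000)           # (batch, channels, samples) audio
    y = (torch.rand(4, 5) > 0.5).float()   # multi-hot effect-presence labels

    mixed_x, mixed_y, lam = mixup(x, y)
    # when the coin flip mixes, lam has shape (4, 1, 1) so it broadcasts over
    # channels and samples; mixed_y is the logical OR of the two source labels,
    # i.e. a mixed clip "contains" every effect present in either source clip
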
@@ -454,38 +461,52 @@ class FXClassifier(pl.LightningModule):
         self.label_smoothing = label_smoothing
 
         self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+        self.loss_fn = torch.nn.BCELoss()
 
-        self.train_f1 = torchmetrics.classification.MultilabelF1Score(
-            5, average="none", multidim_average="global"
-        )
-        self.val_f1 = torchmetrics.classification.MultilabelF1Score(
-            5, average="none", multidim_average="global"
-        )
-        self.test_f1 = torchmetrics.classification.MultilabelF1Score(
-            5, average="none", multidim_average="global"
-        )
+        if False:
+            self.train_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )
+            self.val_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )
+            self.test_f1 = torchmetrics.classification.MultilabelF1Score(
+                5, average="none", multidim_average="global"
+            )
 
-        self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
-            5, threshold=0.5, average="macro", multidim_average="global"
-        )
-        self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
-            5, threshold=0.5, average="macro", multidim_average="global"
-        )
-        self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
-            5, threshold=0.5, average="macro", multidim_average="global"
-        )
+            self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                5, threshold=0.5, average="macro", multidim_average="global"
+            )
+            self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                5, threshold=0.5, average="macro", multidim_average="global"
+            )
+            self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
+                5, threshold=0.5, average="macro", multidim_average="global"
+            )
 
-        self.metrics = {
-            "train": self.train_acc,
-            "valid": self.val_acc,
-            "test": self.test_acc,
-        }
-
-        self.avg_metrics = {
-            "train": self.train_f1_avg,
-            "valid": self.val_f1_avg,
-            "test": self.test_f1_avg,
-        }
+            self.metrics = {
+                "train": self.train_acc,
+                "valid": self.val_acc,
+                "test": self.test_acc,
+            }
+
+            self.avg_metrics = {
+                "train": self.train_f1_avg,
+                "valid": self.val_f1_avg,
+                "test": self.test_f1_avg,
+            }
+
+        self.metrics = torch.nn.ModuleDict()
+        for effect in self.effects:
+            self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )
+            self.metrics[f"valid_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )
+            self.metrics[f"test_{effect}_acc"] = torchmetrics.classification.Accuracy(
+                task="binary"
+            )
 
     def forward(self, x: torch.Tensor, train: bool = False):
         return self.network(x, train=train)
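
With one binary head per effect, the per-mode metrics become a torch.nn.ModuleDict of binary accuracies keyed by "{mode}_{effect}_acc". A small sketch of how those metric objects behave (the effect names below are placeholders; the real list comes from self.effects):

    import torch
    import torchmetrics

    effects = ["distortion", "chorus", "delay"]   # placeholder names
    metrics = torch.nn.ModuleDict()
    for mode in ("train", "valid", "test"):
        for effect in effects:
            metrics[f"{mode}_{effect}_acc"] = torchmetrics.classification.Accuracy(
                task="binary"
            )

    preds = torch.rand(8)                  # sigmoid outputs of one head
    target = torch.randint(0, 2, (8,))     # 0/1 presence labels
    acc = metrics["train_distortion_acc"](preds, target)  # thresholds at 0.5

Registering the metrics in a ModuleDict rather than a plain dict lets Lightning move their internal state to the right device along with the model.
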
@@ -496,15 +517,15 @@ class FXClassifier(pl.LightningModule):
 
         if mode == "train" and self.mixup:
             x_mixed, label_mixed, lam = mixup(x, wet_label)
-            [four removed lines illegible in source; the second began "loss ="]
+            outputs = self(x_mixed, train)
+            loss = 0
+            for idx, output in enumerate(outputs):
+                loss += self.loss_fn(output.squeeze(-1), label_mixed[..., idx])
         else:
-            [four removed lines illegible in source; the second began "loss ="]
+            outputs = self(x, train)
+            loss = 0
+            for idx, output in enumerate(outputs):
+                loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])
 
         self.log(
             f"{mode}_loss",
@@ -516,26 +537,25 @@ class FXClassifier(pl.LightningModule):
             sync_dist=True,
         )
 
-        [two removed lines illegible in source]
+        acc_metrics = []
         for idx, effect_name in enumerate(self.effects):
+            acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
+                outputs[idx].squeeze(-1), wet_label[..., idx]
+            )
             self.log(
-                f"{mode}… [two removed lines illegible past this point]
+                f"{mode}_{effect_name}_acc",
+                acc_metric,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
                 sync_dist=True,
             )
-        [one removed line illegible in source]
-        avg_metrics = self.avg_metrics[mode](
-            torch.sigmoid(pred_label), wet_label.long()
-        )
+            acc_metrics.append(acc_metric)
 
         self.log(
-            f"{mode}… [two removed lines illegible past this point]
+            f"{mode}_avg_acc",
+            torch.mean(torch.stack(acc_metrics)),
             on_step=True,
             on_epoch=True,
             prog_bar=True,
[remainder of the diff truncated in source]
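
The new training step sums an independent BCE term per head against the matching column of the multi-hot target. A self-contained sketch of that loss computation, with random stand-ins for the network outputs and labels (sizes are illustrative):

    import torch

    loss_fn = torch.nn.BCELoss()
    batch, num_classes = 8, 5
    # stand-ins for the classifier's list of (batch, 1) sigmoid outputs
    outputs = [torch.rand(batch, 1, requires_grad=True) for _ in range(num_classes)]
    wet_label = (torch.rand(batch, num_classes) > 0.5).float()  # multi-hot target

    loss = 0
    for idx, output in enumerate(outputs):
        # squeeze (batch, 1) -> (batch,) to match the label column
        loss += loss_fn(output.squeeze(-1), wet_label[..., idx])
    loss.backward()   # one summed objective trains all heads jointly

Since the network already applies sigmoid to each head, plain BCELoss matches its outputs; keeping the heads in logit space and using BCEWithLogitsLoss would be the numerically safer variant of the same objective.
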