Christian J. Steinmetz committed on
Commit
29f23c3
·
1 Parent(s): 0b430bb

moving to multi-way binary classification task

Browse files
Files changed (2) hide show
  1. remfx/classifier.py +19 -7
  2. remfx/models.py +70 -50
remfx/classifier.py CHANGED
@@ -170,7 +170,11 @@ class Cnn14(nn.Module):
170
  self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
171
 
172
  self.fc1 = nn.Linear(2048, 2048, bias=True)
173
- self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
 
 
 
 
174
 
175
  self.init_weight()
176
 
@@ -186,7 +190,7 @@ class Cnn14(nn.Module):
186
  def init_weight(self):
187
  init_bn(self.bn0)
188
  init_layer(self.fc1)
189
- init_layer(self.fc_audioset)
190
 
191
  def forward(self, x: torch.Tensor, train: bool = False):
192
  """
@@ -206,9 +210,12 @@ class Cnn14(nn.Module):
206
  # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
207
  # plt.savefig("spec_augment.png", dpi=300)
208
 
209
- x = x.permute(0, 2, 1, 3)
210
- x = self.bn0(x)
211
- x = x.permute(0, 2, 1, 3)
 
 
 
212
 
213
  x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
214
  x = F.dropout(x, p=0.2, training=train)
@@ -229,9 +236,14 @@ class Cnn14(nn.Module):
229
  x = x1 + x2
230
  x = F.dropout(x, p=0.5, training=train)
231
  x = F.relu_(self.fc1(x))
232
- clipwise_output = self.fc_audioset(x)
233
 
234
- return clipwise_output
 
 
 
 
 
 
235
 
236
 
237
  class ConvBlock(nn.Module):
 
170
  self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
171
 
172
  self.fc1 = nn.Linear(2048, 2048, bias=True)
173
+
174
+ # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
175
+ self.heads = torch.nn.ModuleList()
176
+ for _ in range(num_classes):
177
+ self.heads.append(nn.Linear(2048, 1, bias=True))
178
 
179
  self.init_weight()
180
 
 
190
  def init_weight(self):
191
  init_bn(self.bn0)
192
  init_layer(self.fc1)
193
+ # init_layer(self.fc_audioset)
194
 
195
  def forward(self, x: torch.Tensor, train: bool = False):
196
  """
 
210
  # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
211
  # plt.savefig("spec_augment.png", dpi=300)
212
 
213
+ # x = x.permute(0, 2, 1, 3)
214
+ # x = self.bn0(x)
215
+ # x = x.permute(0, 2, 1, 3)
216
+
217
+ # apply standardization
218
+ x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
219
 
220
  x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
221
  x = F.dropout(x, p=0.2, training=train)
 
236
  x = x1 + x2
237
  x = F.dropout(x, p=0.5, training=train)
238
  x = F.relu_(self.fc1(x))
 
239
 
240
+ outputs = []
241
+ for head in self.heads:
242
+ outputs.append(torch.sigmoid(head(x)))
243
+
244
+ # clipwise_output = self.fc_audioset(x)
245
+
246
+ return outputs
247
 
248
 
249
  class ConvBlock(nn.Module):
remfx/models.py CHANGED
@@ -423,13 +423,20 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
423
  """
424
  batch_size = x.size(0)
425
  if alpha > 0:
426
- lam = np.random.beta(alpha, alpha)
 
 
427
  else:
428
  lam = 1
429
 
430
- index = torch.randperm(batch_size).to(x.device)
431
- mixed_x = lam * x + (1 - lam) * x[index, :]
432
- mixed_y = lam * y + (1 - lam) * y[index, :]
 
 
 
 
 
433
 
434
  return mixed_x, mixed_y, lam
435
 
@@ -454,38 +461,52 @@ class FXClassifier(pl.LightningModule):
454
  self.label_smoothing = label_smoothing
455
 
456
  self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
 
457
 
458
- self.train_f1 = torchmetrics.classification.MultilabelF1Score(
459
- 5, average="none", multidim_average="global"
460
- )
461
- self.val_f1 = torchmetrics.classification.MultilabelF1Score(
462
- 5, average="none", multidim_average="global"
463
- )
464
- self.test_f1 = torchmetrics.classification.MultilabelF1Score(
465
- 5, average="none", multidim_average="global"
466
- )
 
467
 
468
- self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
469
- 5, threshold=0.5, average="macro", multidim_average="global"
470
- )
471
- self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
472
- 5, threshold=0.5, average="macro", multidim_average="global"
473
- )
474
- self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
475
- 5, threshold=0.5, average="macro", multidim_average="global"
476
- )
477
 
478
- self.metrics = {
479
- "train": self.train_f1,
480
- "valid": self.val_f1,
481
- "test": self.test_f1,
482
- }
483
 
484
- self.avg_metrics = {
485
- "train": self.train_f1_avg,
486
- "valid": self.val_f1_avg,
487
- "test": self.test_f1_avg,
488
- }
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
  def forward(self, x: torch.Tensor, train: bool = False):
491
  return self.network(x, train=train)
@@ -496,15 +517,15 @@ class FXClassifier(pl.LightningModule):
496
 
497
  if mode == "train" and self.mixup:
498
  x_mixed, label_mixed, lam = mixup(x, wet_label)
499
- pred_label = self(x_mixed, train)
500
- loss = self.loss_fn(pred_label, label_mixed)
501
- print(torch.sigmoid(pred_label[0, ...]))
502
- print(label_mixed[0, ...])
503
  else:
504
- pred_label = self(x, train)
505
- loss = self.loss_fn(pred_label, wet_label)
506
- print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
507
- print(wet_label.long()[0, ...])
508
 
509
  self.log(
510
  f"{mode}_loss",
@@ -516,26 +537,25 @@ class FXClassifier(pl.LightningModule):
516
  sync_dist=True,
517
  )
518
 
519
- metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
520
-
521
  for idx, effect_name in enumerate(self.effects):
 
 
 
522
  self.log(
523
- f"{mode}_f1_{effect_name}",
524
- metrics[idx],
525
  on_step=True,
526
  on_epoch=True,
527
  prog_bar=True,
528
  logger=True,
529
  sync_dist=True,
530
  )
531
-
532
- avg_metrics = self.avg_metrics[mode](
533
- torch.sigmoid(pred_label), wet_label.long()
534
- )
535
 
536
  self.log(
537
- f"{mode}_f1_avg",
538
- avg_metrics,
539
  on_step=True,
540
  on_epoch=True,
541
  prog_bar=True,
 
423
  """
424
  batch_size = x.size(0)
425
  if alpha > 0:
426
+ # lam = np.random.beta(alpha, alpha)
427
+ lam = np.random.uniform(0.25, 0.75, batch_size)
428
+ lam = torch.from_numpy(lam).float().to(x.device).view(batch_size, 1, 1)
429
  else:
430
  lam = 1
431
 
432
+ print(lam)
433
+ if np.random.rand() > 0.5:
434
+ index = torch.randperm(batch_size).to(x.device)
435
+ mixed_x = lam * x + (1 - lam) * x[index, :]
436
+ mixed_y = torch.logical_or(y, y[index, :]).float()
437
+ else:
438
+ mixed_x = x
439
+ mixed_y = y
440
 
441
  return mixed_x, mixed_y, lam
442
 
 
461
  self.label_smoothing = label_smoothing
462
 
463
  self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
464
+ self.loss_fn = torch.nn.BCELoss()
465
 
466
+ if False:
467
+ self.train_f1 = torchmetrics.classification.MultilabelF1Score(
468
+ 5, average="none", multidim_average="global"
469
+ )
470
+ self.val_f1 = torchmetrics.classification.MultilabelF1Score(
471
+ 5, average="none", multidim_average="global"
472
+ )
473
+ self.test_f1 = torchmetrics.classification.MultilabelF1Score(
474
+ 5, average="none", multidim_average="global"
475
+ )
476
 
477
+ self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
478
+ 5, threshold=0.5, average="macro", multidim_average="global"
479
+ )
480
+ self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
481
+ 5, threshold=0.5, average="macro", multidim_average="global"
482
+ )
483
+ self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
484
+ 5, threshold=0.5, average="macro", multidim_average="global"
485
+ )
486
 
487
+ self.metrics = {
488
+ "train": self.train_acc,
489
+ "valid": self.val_acc,
490
+ "test": self.test_acc,
491
+ }
492
 
493
+ self.avg_metrics = {
494
+ "train": self.train_f1_avg,
495
+ "valid": self.val_f1_avg,
496
+ "test": self.test_f1_avg,
497
+ }
498
+
499
+ self.metrics = torch.nn.ModuleDict()
500
+ for effect in self.effects:
501
+ self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
502
+ task="binary"
503
+ )
504
+ self.metrics[f"valid_{effect}_acc"] = torchmetrics.classification.Accuracy(
505
+ task="binary"
506
+ )
507
+ self.metrics[f"test_{effect}_acc"] = torchmetrics.classification.Accuracy(
508
+ task="binary"
509
+ )
510
 
511
  def forward(self, x: torch.Tensor, train: bool = False):
512
  return self.network(x, train=train)
 
517
 
518
  if mode == "train" and self.mixup:
519
  x_mixed, label_mixed, lam = mixup(x, wet_label)
520
+ outputs = self(x_mixed, train)
521
+ loss = 0
522
+ for idx, output in enumerate(outputs):
523
+ loss += self.loss_fn(output.squeeze(-1), label_mixed[..., idx])
524
  else:
525
+ outputs = self(x, train)
526
+ loss = 0
527
+ for idx, output in enumerate(outputs):
528
+ loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])
529
 
530
  self.log(
531
  f"{mode}_loss",
 
537
  sync_dist=True,
538
  )
539
 
540
+ acc_metrics = []
 
541
  for idx, effect_name in enumerate(self.effects):
542
+ acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
543
+ outputs[idx].squeeze(-1), wet_label[..., idx]
544
+ )
545
  self.log(
546
+ f"{mode}_{effect_name}_acc",
547
+ acc_metric,
548
  on_step=True,
549
  on_epoch=True,
550
  prog_bar=True,
551
  logger=True,
552
  sync_dist=True,
553
  )
554
+ acc_metrics.append(acc_metric)
 
 
 
555
 
556
  self.log(
557
+ f"{mode}_avg_acc",
558
+ torch.mean(torch.stack(acc_metrics)),
559
  on_step=True,
560
  on_epoch=True,
561
  prog_bar=True,