waidhoferj committed on
Commit
1c22425
1 Parent(s): c95adc4

fixed metrics and weights

Browse files
Files changed (2) hide show
  1. models/utils.py +3 -1
  2. preprocessing/dataset.py +3 -1
models/utils.py CHANGED
@@ -37,7 +37,9 @@ def calculate_metrics(
37
  pred, target, threshold=0.5, prefix="", multi_label=True
38
  ) -> dict[str, torch.Tensor]:
39
  target = target.detach().cpu().numpy()
40
- pred = pred.detach().cpu().numpy()
 
 
41
  params = {
42
  "y_true": target if multi_label else target.argmax(1),
43
  "y_pred": np.array(pred > threshold, dtype=float)
 
37
  pred, target, threshold=0.5, prefix="", multi_label=True
38
  ) -> dict[str, torch.Tensor]:
39
  target = target.detach().cpu().numpy()
40
+ pred = pred.detach().cpu()
41
+ pred = nn.functional.softmax(pred, dim=1)
42
+ pred = pred.numpy()
43
  params = {
44
  "y_true": target if multi_label else target.argmax(1),
45
  "y_pred": np.array(pred > threshold, dtype=float)
preprocessing/dataset.py CHANGED
@@ -80,7 +80,9 @@ class SongDataset(Dataset):
80
 
81
  def get_label_weights(self):
82
  n_examples, n_classes = self.dance_labels.shape
83
- return torch.from_numpy(n_examples / (n_classes * sum(self.dance_labels)))
 
 
84
 
85
  def _backtrace_audio_path(self, index: int) -> str:
86
  return self.audio_paths[self._idx2audio_idx(index)]
 
80
 
81
  def get_label_weights(self):
82
  n_examples, n_classes = self.dance_labels.shape
83
+ weights = n_examples / (n_classes * sum(self.dance_labels))
84
+ weights[np.isinf(weights)] = 0.0
85
+ return torch.from_numpy(weights)
86
 
87
  def _backtrace_audio_path(self, index: int) -> str:
88
  return self.audio_paths[self._idx2audio_idx(index)]