Spaces:
Runtime error
Runtime error
waidhoferj
commited on
Commit
•
1c22425
1
Parent(s):
c95adc4
fixed metrics and weights
Browse files- models/utils.py +3 -1
- preprocessing/dataset.py +3 -1
models/utils.py
CHANGED
@@ -37,7 +37,9 @@ def calculate_metrics(
|
|
37 |
pred, target, threshold=0.5, prefix="", multi_label=True
|
38 |
) -> dict[str, torch.Tensor]:
|
39 |
target = target.detach().cpu().numpy()
|
40 |
-
pred = pred.detach().cpu()
|
|
|
|
|
41 |
params = {
|
42 |
"y_true": target if multi_label else target.argmax(1),
|
43 |
"y_pred": np.array(pred > threshold, dtype=float)
|
|
|
37 |
pred, target, threshold=0.5, prefix="", multi_label=True
|
38 |
) -> dict[str, torch.Tensor]:
|
39 |
target = target.detach().cpu().numpy()
|
40 |
+
pred = pred.detach().cpu()
|
41 |
+
pred = nn.functional.softmax(pred, dim=1)
|
42 |
+
pred = pred.numpy()
|
43 |
params = {
|
44 |
"y_true": target if multi_label else target.argmax(1),
|
45 |
"y_pred": np.array(pred > threshold, dtype=float)
|
preprocessing/dataset.py
CHANGED
@@ -80,7 +80,9 @@ class SongDataset(Dataset):
|
|
80 |
|
81 |
def get_label_weights(self):
|
82 |
n_examples, n_classes = self.dance_labels.shape
|
83 |
-
|
|
|
|
|
84 |
|
85 |
def _backtrace_audio_path(self, index: int) -> str:
|
86 |
return self.audio_paths[self._idx2audio_idx(index)]
|
|
|
80 |
|
81 |
def get_label_weights(self):
|
82 |
n_examples, n_classes = self.dance_labels.shape
|
83 |
+
weights = n_examples / (n_classes * sum(self.dance_labels))
|
84 |
+
weights[np.isinf(weights)] = 0.0
|
85 |
+
return torch.from_numpy(weights)
|
86 |
|
87 |
def _backtrace_audio_path(self, index: int) -> str:
|
88 |
return self.audio_paths[self._idx2audio_idx(index)]
|