Spaces:
Runtime error
Runtime error
waidhoferj
commited on
Commit
•
17a2a7d
1
Parent(s):
ad4c4e2
added evaulations
Browse files- models/training_environment.py +55 -9
- models/utils.py +29 -2
models/training_environment.py
CHANGED
@@ -1,10 +1,16 @@
|
|
1 |
import importlib
|
2 |
-
from models.utils import calculate_metrics
|
3 |
-
|
4 |
from abc import ABC, abstractmethod
|
5 |
import pytorch_lightning as pl
|
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
class TrainingEnvironment(pl.LightningModule):
|
@@ -27,8 +33,8 @@ class TrainingEnvironment(pl.LightningModule):
|
|
27 |
config["training_environment"].get("loggers", {})
|
28 |
)
|
29 |
self.config = config
|
30 |
-
self.has_multi_label_predictions = (
|
31 |
-
|
32 |
)
|
33 |
self.save_hyperparameters(
|
34 |
{
|
@@ -44,6 +50,8 @@ class TrainingEnvironment(pl.LightningModule):
|
|
44 |
) -> torch.Tensor:
|
45 |
features, labels = batch
|
46 |
outputs = self.model(features)
|
|
|
|
|
47 |
loss = self.criterion(outputs, labels)
|
48 |
metrics = calculate_metrics(
|
49 |
outputs,
|
@@ -62,6 +70,8 @@ class TrainingEnvironment(pl.LightningModule):
|
|
62 |
):
|
63 |
x, y = batch
|
64 |
preds = self.model(x)
|
|
|
|
|
65 |
metrics = calculate_metrics(
|
66 |
preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
|
67 |
)
|
@@ -71,12 +81,48 @@ class TrainingEnvironment(pl.LightningModule):
|
|
71 |
def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
|
72 |
x, y = batch
|
73 |
preds = self.model(x)
|
74 |
-
self.
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
def configure_optimizers(self):
|
82 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
|
|
1 |
import importlib
|
2 |
+
from models.utils import calculate_metrics, plot_to_image, get_dance_mapping
|
3 |
+
import numpy as np
|
4 |
from abc import ABC, abstractmethod
|
5 |
import pytorch_lightning as pl
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
import torch
|
8 |
import torch.nn as nn
|
9 |
+
from sklearn.metrics import (
|
10 |
+
roc_auc_score,
|
11 |
+
confusion_matrix,
|
12 |
+
ConfusionMatrixDisplay,
|
13 |
+
)
|
14 |
|
15 |
|
16 |
class TrainingEnvironment(pl.LightningModule):
|
|
|
33 |
config["training_environment"].get("loggers", {})
|
34 |
)
|
35 |
self.config = config
|
36 |
+
self.has_multi_label_predictions = not (
|
37 |
+
type(criterion).__name__ == "CrossEntropyLoss"
|
38 |
)
|
39 |
self.save_hyperparameters(
|
40 |
{
|
|
|
50 |
) -> torch.Tensor:
|
51 |
features, labels = batch
|
52 |
outputs = self.model(features)
|
53 |
+
if self.has_multi_label_predictions:
|
54 |
+
outputs = nn.functional.sigmoid(outputs)
|
55 |
loss = self.criterion(outputs, labels)
|
56 |
metrics = calculate_metrics(
|
57 |
outputs,
|
|
|
70 |
):
|
71 |
x, y = batch
|
72 |
preds = self.model(x)
|
73 |
+
if self.has_multi_label_predictions:
|
74 |
+
preds = nn.functional.sigmoid(preds)
|
75 |
metrics = calculate_metrics(
|
76 |
preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
|
77 |
)
|
|
|
81 |
def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
|
82 |
x, y = batch
|
83 |
preds = self.model(x)
|
84 |
+
if self.has_multi_label_predictions:
|
85 |
+
preds = nn.functional.sigmoid(preds)
|
86 |
+
metrics = calculate_metrics(
|
87 |
+
preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
|
88 |
+
)
|
89 |
+
if not self.has_multi_label_predictions:
|
90 |
+
preds = nn.functional.softmax(preds, dim=1)
|
91 |
+
y = y.detach().cpu().numpy()
|
92 |
+
preds = preds.detach().cpu().numpy()
|
93 |
+
# ROC-auc score
|
94 |
+
try:
|
95 |
+
metrics["test/roc_auc_score"] = torch.tensor(
|
96 |
+
roc_auc_score(y, preds), dtype=torch.float32
|
97 |
+
)
|
98 |
+
except ValueError:
|
99 |
+
# If there is only one class, roc_auc_score will throw an error
|
100 |
+
pass
|
101 |
+
|
102 |
+
pass
|
103 |
+
self.log_dict(metrics, prog_bar=True)
|
104 |
+
# Create confusion matrix
|
105 |
+
|
106 |
+
preds = preds.argmax(axis=1)
|
107 |
+
y = y.argmax(axis=1)
|
108 |
+
cm = confusion_matrix(
|
109 |
+
preds, y, normalize="all", labels=np.arange(len(self.config["dance_ids"]))
|
110 |
)
|
111 |
+
if hasattr(self, "test_cm"):
|
112 |
+
self.test_cm += cm
|
113 |
+
else:
|
114 |
+
self.test_cm = cm
|
115 |
+
|
116 |
+
def on_test_end(self):
|
117 |
+
dance_ids = sorted(self.config["dance_ids"])
|
118 |
+
np.fill_diagonal(self.test_cm, 0)
|
119 |
+
cm = self.test_cm / self.test_cm.max()
|
120 |
+
ConfusionMatrixDisplay(cm, display_labels=dance_ids).plot()
|
121 |
+
image = plot_to_image(plt.gcf())
|
122 |
+
image = torch.tensor(image, dtype=torch.uint8)
|
123 |
+
image = image.permute(2, 0, 1)
|
124 |
+
self.logger.experiment.add_image("test/confusion_matrix", image, 0)
|
125 |
+
delattr(self, "test_cm")
|
126 |
|
127 |
def configure_optimizers(self):
|
128 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
models/utils.py
CHANGED
@@ -2,6 +2,11 @@ import torch.nn as nn
|
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
class LabelWeightedBCELoss(nn.Module):
|
@@ -38,10 +43,13 @@ def calculate_metrics(
|
|
38 |
) -> dict[str, torch.Tensor]:
|
39 |
target = target.detach().cpu().numpy()
|
40 |
pred = pred.detach().cpu()
|
41 |
-
|
|
|
42 |
pred = pred.numpy()
|
43 |
params = {
|
44 |
-
"y_true": target
|
|
|
|
|
45 |
"y_pred": np.array(pred > threshold, dtype=float)
|
46 |
if multi_label
|
47 |
else pred.argmax(1),
|
@@ -85,3 +93,22 @@ def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
|
|
85 |
def compute_hf_metrics(eval_pred):
|
86 |
predictions = np.argmax(eval_pred.predictions, axis=1)
|
87 |
return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
5 |
+
from functools import cache
|
6 |
+
import pandas as pd
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import io
|
9 |
+
from PIL import Image
|
10 |
|
11 |
|
12 |
class LabelWeightedBCELoss(nn.Module):
|
|
|
43 |
) -> dict[str, torch.Tensor]:
|
44 |
target = target.detach().cpu().numpy()
|
45 |
pred = pred.detach().cpu()
|
46 |
+
if not multi_label:
|
47 |
+
pred = nn.functional.softmax(pred, dim=1)
|
48 |
pred = pred.numpy()
|
49 |
params = {
|
50 |
+
"y_true": np.array(target > 0.0, dtype=float)
|
51 |
+
if multi_label
|
52 |
+
else target.argmax(1),
|
53 |
"y_pred": np.array(pred > threshold, dtype=float)
|
54 |
if multi_label
|
55 |
else pred.argmax(1),
|
|
|
93 |
def compute_hf_metrics(eval_pred):
|
94 |
predictions = np.argmax(eval_pred.predictions, axis=1)
|
95 |
return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)
|
96 |
+
|
97 |
+
|
98 |
+
@cache
|
99 |
+
def get_dance_mapping(mapping_file: str) -> dict[str, str]:
|
100 |
+
mapping_df = pd.read_csv(mapping_file)
|
101 |
+
return {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
|
102 |
+
|
103 |
+
|
104 |
+
def plot_to_image(figure) -> np.ndarray:
|
105 |
+
"""Converts the matplotlib plot specified by 'figure' to a PNG image and
|
106 |
+
returns it. The supplied figure is closed and inaccessible after this call."""
|
107 |
+
# Save the plot to a PNG in memory.
|
108 |
+
buf = io.BytesIO()
|
109 |
+
plt.savefig(buf, format="png")
|
110 |
+
# Closing the figure prevents it from being displayed directly inside
|
111 |
+
# the notebook.
|
112 |
+
plt.close(figure)
|
113 |
+
buf.seek(0)
|
114 |
+
return np.array(Image.open(buf))
|