Spaces:
Runtime error
Runtime error
import importlib | |
from models.utils import calculate_metrics, plot_to_image, get_dance_mapping | |
import numpy as np | |
from abc import ABC, abstractmethod | |
import pytorch_lightning as pl | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
from sklearn.metrics import ( | |
roc_auc_score, | |
confusion_matrix, | |
ConfusionMatrixDisplay, | |
) | |
class TrainingEnvironment(pl.LightningModule): | |
def __init__( | |
self, | |
model: nn.Module, | |
criterion: nn.Module, | |
config: dict, | |
learning_rate=1e-4, | |
*args, | |
**kwargs, | |
): | |
super().__init__() | |
self.model = model | |
self.criterion = criterion | |
self.learning_rate = config["training_environment"].get( | |
"learning_rate", learning_rate | |
) | |
self.experiment_loggers = load_loggers( | |
config["training_environment"].get("loggers", {}) | |
) | |
self.config = config | |
self.has_multi_label_predictions = not ( | |
type(criterion).__name__ == "CrossEntropyLoss" | |
) | |
self.save_hyperparameters( | |
{ | |
"model": type(model).__name__, | |
"loss": type(criterion).__name__, | |
"config": config, | |
**kwargs, | |
} | |
) | |
def training_step( | |
self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int | |
) -> torch.Tensor: | |
features, labels = batch | |
outputs = self.model(features) | |
if self.has_multi_label_predictions: | |
outputs = nn.functional.sigmoid(outputs) | |
loss = self.criterion(outputs, labels) | |
metrics = calculate_metrics( | |
outputs, | |
labels, | |
prefix="train/", | |
multi_label=self.has_multi_label_predictions, | |
) | |
self.log_dict(metrics, prog_bar=True) | |
return loss | |
def validation_step( | |
self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int | |
): | |
x, y = batch | |
preds = self.model(x) | |
if self.has_multi_label_predictions: | |
preds = nn.functional.sigmoid(preds) | |
metrics = calculate_metrics( | |
preds, y, prefix="val/", multi_label=self.has_multi_label_predictions | |
) | |
metrics["val/loss"] = self.criterion(preds, y) | |
self.log_dict(metrics, prog_bar=True, sync_dist=True) | |
def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int): | |
x, y = batch | |
preds = self.model(x) | |
if self.has_multi_label_predictions: | |
preds = nn.functional.sigmoid(preds) | |
metrics = calculate_metrics( | |
preds, y, prefix="test/", multi_label=self.has_multi_label_predictions | |
) | |
if not self.has_multi_label_predictions: | |
preds = nn.functional.softmax(preds, dim=1) | |
y = y.detach().cpu().numpy() | |
preds = preds.detach().cpu().numpy() | |
# ROC-auc score | |
try: | |
metrics["test/roc_auc_score"] = torch.tensor( | |
roc_auc_score(y, preds), dtype=torch.float32 | |
) | |
except ValueError: | |
# If there is only one class, roc_auc_score will throw an error | |
pass | |
pass | |
self.log_dict(metrics, prog_bar=True) | |
# Create confusion matrix | |
preds = preds.argmax(axis=1) | |
y = y.argmax(axis=1) | |
cm = confusion_matrix( | |
preds, y, normalize="all", labels=np.arange(len(self.config["dance_ids"])) | |
) | |
if hasattr(self, "test_cm"): | |
self.test_cm += cm | |
else: | |
self.test_cm = cm | |
def on_test_end(self): | |
dance_ids = sorted(self.config["dance_ids"]) | |
np.fill_diagonal(self.test_cm, 0) | |
cm = self.test_cm / self.test_cm.max() | |
cm_plot = ConfusionMatrixDisplay(cm, display_labels=dance_ids) | |
fig, ax = plt.subplots(figsize=(12, 12)) | |
cm_plot.plot(ax=ax) | |
image = plot_to_image(fig) | |
image = torch.tensor(image, dtype=torch.uint8) | |
image = image.permute(2, 0, 1) | |
self.logger.experiment.add_image("test/confusion_matrix", image, 0) | |
delattr(self, "test_cm") | |
def configure_optimizers(self): | |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) | |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") | |
return { | |
"optimizer": optimizer, | |
"lr_scheduler": scheduler, | |
"monitor": "val/loss", | |
} | |
class ExperimentLogger(ABC): | |
def step(self, experiment, data): | |
pass | |
class SpectrogramLogger(ExperimentLogger): | |
def __init__(self, frequency=100) -> None: | |
self.frequency = frequency | |
self.counter = 0 | |
def step(self, experiment, batch_index, x, label): | |
if self.counter == self.frequency: | |
self.counter = 0 | |
img_index = torch.randint(0, len(x), (1,)).item() | |
img = x[img_index][0] | |
img = (img - img.min()) / (img.max() - img.min()) | |
experiment.add_image( | |
f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW" | |
) | |
self.counter += 1 | |
def load_loggers(logger_config: dict) -> list[ExperimentLogger]: | |
loggers = [] | |
for logger_path, kwargs in logger_config.items(): | |
module_name, class_name = logger_path.rsplit(".", 1) | |
module = importlib.import_module(module_name) | |
Logger = getattr(module, class_name) | |
loggers.append(Logger(**kwargs)) | |
return loggers | |