from typing import List

import torch
import torch.nn.functional as F
import torchmetrics
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import wandb
from tqdm import tqdm

from backbones import get_backbone
from utils.confusion_viz import ConfusionVisualizer
from utils.plots import get_confusion_matrix_figure
from utils.grad_cam import GradCAMBuilder
from loss import ordinal_regression_loss, get_breadstick_probabilities, focal_loss
from utils.val_loop_hook import ValidationLoopHook


class VerseFxClassifier(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(dict(hparams))
        self.backbone = get_backbone(self.hparams)

        metric_args = dict(average='macro', num_classes=self.hparams.num_classes)
        metrics = torchmetrics.MetricCollection([
            torchmetrics.Accuracy(**metric_args),
            torchmetrics.F1(**metric_args),
            torchmetrics.Precision(**metric_args),
            torchmetrics.Recall(**metric_args)
        ])
        self.train_metrics = metrics.clone(prefix='train/')
        self.val_metrics = metrics.clone(prefix='val/')
        self.class_weights = None

        # Channel count: raw image, plus an optional mask channel, plus optional
        # coordinate channels (one per spatial dimension).
        image_shape = (
            1 + (self.hparams.mask == 'channel') + self.hparams.input_dim * self.hparams.coordinates,
        ) + (self.hparams.input_size,) * self.hparams.input_dim
        grad_cam_builder = GradCAMBuilder(
            image_shape, target_category=0 if self.hparams.task == 'detection' else None)
        confusion_visualizer = ConfusionVisualizer(
            image_shape, 2 if self.hparams.task == 'detection' else self.hparams.num_classes)
        self.validation_hooks: List[ValidationLoopHook] = [grad_cam_builder, confusion_visualizer]

    def get_class_weights(self, dm: pl.LightningDataModule):
        targets = []
        for batch in tqdm(dm.train_dataloader(), desc="Determining class weights"):
            targets.append(self.batch_to_targets(batch))
        targets = torch.cat(targets)
        classes, counts = torch.unique(targets, return_counts=True)
        # Inverse-frequency weights, normalized so a balanced dataset yields all ones.
        return (1 / counts) * torch.sum(counts) / classes.shape[0]
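    # Sanity-check sketch for the weighting above (comments only, not executed):
    # with counts [90, 10] the weights come out inversely proportional to class
    # frequency,
    #
    #   counts = torch.tensor([90., 10.])
    #   (1 / counts) * counts.sum() / 2   # -> tensor([0.5556, 5.0000])
    #
    # and a perfectly balanced dataset would yield all-ones weights, leaving the
    # loss unchanged.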
    def on_pretrain_routine_start(self):
        # FIXME This is slightly inefficient if multiple GPUs are used, as this
        #  routine is called once per device. There might be a better hook available.
        super().on_pretrain_routine_start()
        if self.hparams.weighted_loss:
            self.class_weights = self.get_class_weights(self.trainer.datamodule).to(self.device)
            if self.hparams.loss == 'binary_cross_entropy':
                # Only keep the positive class weight
                self.class_weights = self.class_weights[-1]

    def forward(self, x):
        return self.backbone(x)

    def loss(self, logits, targets):
        if self.hparams.loss == 'cross_entropy':
            return F.cross_entropy(logits, targets, weight=self.class_weights)
        elif self.hparams.loss == 'binary_cross_entropy':
            return F.binary_cross_entropy_with_logits(
                logits.squeeze(-1), targets.float(), pos_weight=self.class_weights)
        elif self.hparams.loss == 'ordinal_regression':
            return ordinal_regression_loss(logits, targets, class_weights=self.class_weights)
        elif self.hparams.loss == 'focal':
            return focal_loss(logits.squeeze(-1), targets.float())
        else:
            raise ValueError(f"Unknown loss: {self.hparams.loss}")

    def logits_to_predictions(self, logits):
        if self.hparams.loss == 'binary_cross_entropy' \
                or (self.hparams.loss == 'focal' and self.hparams.task == 'detection'):
            probs = torch.sigmoid(logits.squeeze(-1))
            preds = probs.gt(0.5).long()
        elif self.hparams.loss == 'cross_entropy' or self.hparams.loss == 'focal':
            probs = torch.softmax(logits, dim=-1)
            preds = probs.argmax(-1)
        elif self.hparams.loss == 'ordinal_regression':
            probs = get_breadstick_probabilities(logits)
            preds = probs.argmax(-1)
        else:
            raise ValueError(f"Unknown loss: {self.hparams.loss}")
        return probs, preds

    def batch_to_targets(self, batch):
        if self.hparams.task == 'detection':
            return batch['fx'].long()
        elif self.hparams.task == 'grading':
            return batch['fx_grading'].long()
        elif self.hparams.task == 'simple_grading':
            # Clone so the in-place remapping below cannot mutate the batch.
            targets = batch['fx_grading'].long().clone()
            # Merge grades 2 and 3 into grade 1, then close the gap:
            # 0 -> 0, 1/2/3 -> 1, 4 -> 2, 5 -> 3.
            targets[torch.bitwise_or(targets == 2, targets == 3)] = 1
            targets[targets > 3] -= 2
            return targets
        else:
            raise ValueError(f"Unknown task: {self.hparams.task}")

    def training_step(self, batch, batch_idx):
        logits = self(batch['image'])
        targets = self.batch_to_targets(batch)
        loss = self.loss(logits, targets)
        probs, preds = self.logits_to_predictions(logits)
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=self.hparams.batch_size)
        return {'loss': loss, 'probs': probs.detach(), 'preds': preds.detach(),
                'targets': targets.detach()}

    def training_epoch_end(self, outputs):
        outputs = {k: torch.cat([d[k] for d in outputs]) for k in outputs[0] if k != 'loss'}
        metrics = self.train_metrics(outputs['probs'], outputs['targets'])
        self.log_dict(metrics)

        targets_flat = outputs['targets'].cpu().numpy()
        preds_flat = outputs['preds'].cpu().numpy()
        # sklearn confusion matrix
        self.logger.experiment.log({
            "train/confusion_matrix": get_confusion_matrix_figure(
                targets_flat, preds_flat, title="Training Confusion Matrix")
        })
        plt.close('all')

    def validation_step(self, batch, batch_idx):
        logits = self(batch['image'])
        targets = self.batch_to_targets(batch)
        loss = self.loss(logits, targets)
        probs, preds = self.logits_to_predictions(logits)
        self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=self.hparams.batch_size)
        for val_hook in self.validation_hooks:
            val_hook.process(batch, targets, logits, preds)
        # Metrics are computed once over the full epoch in validation_epoch_end;
        # updating self.val_metrics here as well would double-count samples.
        return {'loss': loss, 'probs': probs.detach(), 'preds': preds.detach(),
                'targets': targets.detach()}
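    # The validation hooks follow the process/trigger/reset protocol used above:
    # `process` sees every batch, `trigger` is handed the LightningModule once
    # per epoch, and `reset` clears accumulated state. A minimal custom hook
    # might look like this sketch (`PredictionDumper` is hypothetical, not part
    # of this repo):
    #
    #   class PredictionDumper(ValidationLoopHook):
    #       def __init__(self):
    #           self.rows = []
    #
    #       def process(self, batch, targets, logits, preds):
    #           self.rows += list(zip(targets.tolist(), preds.tolist()))
    #
    #       def trigger(self, module):
    #           module.logger.experiment.log({"val/predictions": wandb.Table(
    #               columns=["target", "pred"], data=self.rows)})
    #
    #       def reset(self):
    #           self.rows = []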
    def validation_epoch_end(self, outputs):
        outputs = {k: torch.cat([d[k] for d in outputs]) for k in outputs[0] if k != 'loss'}
        metrics = self.val_metrics(outputs['probs'], outputs['targets'])
        self.log_dict(metrics)

        targets_flat = outputs['targets'].cpu().numpy()
        preds_flat = outputs['preds'].cpu().numpy()
        # sklearn confusion matrix
        self.logger.experiment.log({
            "val/confusion_matrix": get_confusion_matrix_figure(
                targets_flat, preds_flat, title="Validation Confusion Matrix")
        })
        plt.close('all')
        # wandb confusion matrix
        self.logger.experiment.log({
            "full_fx_grading": wandb.plot.confusion_matrix(
                preds=list(preds_flat),
                y_true=list(targets_flat),
                class_names=None,
            ),
            "epoch": self.current_epoch
        })

    def on_train_epoch_end(self):
        # Trigger all validation hooks and reset them afterwards
        for val_hook in self.validation_hooks:
            val_hook.trigger(self)
            val_hook.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
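# Minimal usage sketch (assumptions: `hparams` carries the fields referenced
# above, `max_epochs` is illustrative, and `VerseFxDataModule` is a hypothetical
# LightningDataModule whose batches provide 'image' plus 'fx'/'fx_grading'
# targets -- it is not part of this module):
#
#   model = VerseFxClassifier(hparams)
#   trainer = pl.Trainer(gpus=1, max_epochs=hparams.max_epochs)
#   trainer.fit(model, datamodule=VerseFxDataModule(hparams))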