from typing import Dict

import pytorch_lightning as pl
import torch
from datasets import load_metric
from torch import nn
from transformers import SegformerForSemanticSegmentation


class SidewalkSegmentationModel(pl.LightningModule):
    """
    LightningModule that fine-tunes a pretrained SegFormer backbone
    (nvidia/mit-b0 .. mit-b5) for semantic segmentation, tracking mean IoU
    and mean accuracy for both training and validation.
    """

    def __init__(
        self,
        num_labels: int,
        id2label: Dict[int, str],
        model_flavor: int = 0,
        learning_rate: float = 6e-5,
    ):
        super().__init__()
        self.id2label = id2label
        self.label2id = {v: k for k, v in id2label.items()}
        self.learning_rate = learning_rate
        # Keep separate mean-IoU metrics per stage so training and
        # validation batches are never mixed in one accumulator.
        self.metrics = {
            "train": load_metric("mean_iou"),
            "val": load_metric("mean_iou"),
        }
        # model_flavor selects the SegFormer encoder size (mit-b0 .. mit-b5).
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-b{model_flavor}",
            num_labels=num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        # Delegate directly to the underlying SegFormer model.
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        # Passing labels makes the model compute the cross-entropy loss itself.
        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        self.add_batch_to_metric("train", logits, labels)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        self.add_batch_to_metric("val", logits, labels)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"val_loss": loss}

    def training_epoch_end(self, training_step_outputs):
        """
        Log the training metrics.
        """
        metrics = self.metrics["train"].compute(
            num_labels=len(self.id2label), ignore_index=255, reduce_labels=False
        )
        self.log("train_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def validation_epoch_end(self, validation_step_outputs):
        """
        Log the validation metrics.
        """
        metrics = self.metrics["val"].compute(
            num_labels=len(self.id2label), ignore_index=255, reduce_labels=False
        )
        self.log("val_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def add_batch_to_metric(self, stage: str, logits: torch.Tensor, labels: torch.Tensor):
        """
        Add the current batch to the metric.

        Parameters
        ----------
        stage : str
            Stage of the training. Either "train" or "val".
        logits : torch.Tensor
            Predicted logits.
        labels : torch.Tensor
            Ground truth labels.
        """
        with torch.no_grad():
            # SegFormer outputs logits at a reduced resolution, so upsample
            # them to the label size before taking the per-pixel argmax.
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)

            self.metrics[stage].add_batch(
                predictions=predicted.detach().cpu().numpy(),
                references=labels.detach().cpu().numpy(),
            )

    def configure_optimizers(self) -> torch.optim.AdamW:
        """
        Configure the optimizer.

        Returns
        -------
        torch.optim.AdamW
            Optimizer for the model
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
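

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes batches are dicts with "pixel_values" (B, 3, H, W) float tensors and
# "labels" (B, H, W) integer masks, which is what the steps above expect. The
# dataset and id2label map below are synthetic stand-ins for the real data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class RandomSegmentationDataset(Dataset):
        """Tiny synthetic dataset used only to exercise the module."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {
                "pixel_values": torch.randn(3, 128, 128),
                "labels": torch.randint(0, 3, (128, 128)),
            }

    # Hypothetical label map; replace with the real dataset's id2label.
    id2label = {0: "unlabeled", 1: "road", 2: "sidewalk"}

    model = SidewalkSegmentationModel(
        num_labels=len(id2label),
        id2label=id2label,
        model_flavor=0,  # nvidia/mit-b0, the smallest SegFormer encoder
        learning_rate=6e-5,
    )

    # fast_dev_run runs a single train/val batch as a smoke test.
    trainer = pl.Trainer(accelerator="auto", devices=1, fast_dev_run=True)
    trainer.fit(
        model,
        train_dataloaders=DataLoader(RandomSegmentationDataset(), batch_size=2),
        val_dataloaders=DataLoader(RandomSegmentationDataset(), batch_size=2),
    )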