# segformer-sidewalk / model.py
# Author: chainyo
# Commit: 8e87ed1 ("create pl model")
import pytorch_lightning as pl
import torch
from datasets import load_metric
from torch import nn
from transformers import SegformerForSemanticSegmentation
from typing import Dict
class SidewalkSegmentationModel(pl.LightningModule):
    """LightningModule that fine-tunes a SegFormer model for semantic segmentation.

    Wraps ``SegformerForSemanticSegmentation`` (loaded from the
    ``nvidia/mit-b{model_flavor}`` checkpoint) and tracks mean IoU and mean
    accuracy for the training and validation stages with the ``datasets``
    ``mean_iou`` metric.

    Parameters
    ----------
    num_labels : int
        Number of segmentation classes predicted by the decode head. Also used
        as the class count when computing the IoU metric.
    id2label : Dict[int, str]
        Mapping from class index to class name; ``label2id`` is derived from it.
    model_flavor : int, optional
        Suffix of the ``nvidia/mit-b{model_flavor}`` checkpoint, by default 0.
    learning_rate : float, optional
        Learning rate for AdamW, by default 6e-5.
    ignore_index : int, optional
        Label value excluded from both the model loss and the IoU metric,
        by default 255.
    """

    def __init__(
        self,
        num_labels: int,
        id2label: Dict[int, str],
        model_flavor: int = 0,
        learning_rate: float = 6e-5,
        ignore_index: int = 255,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.id2label = id2label
        self.label2id = {v: k for k, v in id2label.items()}
        self.learning_rate = learning_rate
        self.ignore_index = ignore_index
        # One independent accumulator per stage so train and val batches never mix.
        self.metrics = {
            "train": load_metric("mean_iou"),
            "val": load_metric("mean_iou"),
        }
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-b{model_flavor}",
            num_labels=num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
            # Keep the loss's ignored label identical to the metric's.
            semantic_loss_ignore_index=ignore_index,
        )
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped SegFormer model."""
        return self.model(*args, **kwargs)

    def _shared_step(self, stage: str, batch) -> torch.Tensor:
        """Run one forward pass, update the *stage* metric, log and return the loss.

        Assumes ``batch`` is a mapping with ``"pixel_values"`` and ``"labels"``
        entries (TODO confirm against the datamodule).
        """
        labels = batch["labels"]
        outputs = self(pixel_values=batch["pixel_values"], labels=labels)
        self.add_batch_to_metric(stage, outputs.logits, labels)
        self.log(f"{stage}_loss", outputs.loss, prog_bar=True, on_step=False, on_epoch=True)
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        return {"loss": self._shared_step("train", batch)}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        return {"val_loss": self._shared_step("val", batch)}

    def _log_epoch_metrics(self, stage: str) -> None:
        """Compute the accumulated *stage* metric and log mean IoU / mean accuracy."""
        metrics = self.metrics[stage].compute(
            num_labels=self.num_labels, ignore_index=self.ignore_index, reduce_labels=False
        )
        self.log(f"{stage}_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{stage}_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    # NOTE(review): the `*_epoch_end(outputs)` hooks exist in Lightning 1.x only;
    # porting to Lightning 2.x requires renaming them to `on_*_epoch_end(self)`.
    def training_epoch_end(self, training_step_outputs):
        """Log the training metrics."""
        self._log_epoch_metrics("train")

    def validation_epoch_end(self, validation_step_outputs):
        """Log the validation metrics."""
        self._log_epoch_metrics("val")

    def add_batch_to_metric(self, stage: str, logits: torch.Tensor, labels: torch.Tensor):
        """
        Add the current batch to the metric.

        Parameters
        ----------
        stage : str
            Stage of the training. Either "train" or "val".
        logits : torch.Tensor
            Predicted logits; the class axis is dim 1.
        labels : torch.Tensor
            Ground truth labels; the last two dims are the target spatial size.
        """
        with torch.no_grad():
            # SegFormer emits logits at reduced resolution; resize them to the
            # label resolution before taking the per-pixel argmax.
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)
            self.metrics[stage].add_batch(
                predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy()
            )

    def configure_optimizers(self) -> torch.optim.AdamW:
        """
        Configure the optimizer.

        Returns
        -------
        torch.optim.AdamW
            AdamW over all module parameters with ``self.learning_rate``.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)