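"""PyTorch Lightning module for fine-tuning SegFormer on sidewalk semantic segmentation."""
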
import pytorch_lightning as pl
import torch

from datasets import load_metric
from torch import nn
from transformers import SegformerForSemanticSegmentation
from typing import Dict


class SidewalkSegmentationModel(pl.LightningModule):
    def __init__(
        self,
        num_labels: int,
        id2label: Dict[int, str],
        model_flavor: int = 0,
        learning_rate: float = 6e-5,
    ):
        """
        Parameters
        ----------
        num_labels : int
            Number of segmentation classes.
        id2label : Dict[int, str]
            Mapping from class index to human-readable class name.
        model_flavor : int
            SegFormer backbone size (0-5); selects the ``nvidia/mit-b{model_flavor}`` checkpoint.
        learning_rate : float
            Learning rate for the AdamW optimizer.
        """
        super().__init__()
        self.id2label = id2label
        self.label2id = {v: k for k, v in id2label.items()}
        self.learning_rate = learning_rate
        # Separate metric instances so train and val batches are never mixed.
        self.metrics = {
            "train": load_metric("mean_iou"),
            "val": load_metric("mean_iou"),
        }

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-b{model_flavor}",
            num_labels=num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped SegFormer model."""
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        """
        Run one training step: forward pass, metric accumulation, loss logging.
        """
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]

        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        self.add_batch_to_metric("train", logits, labels)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """
        Run one validation step: forward pass, metric accumulation, loss logging.
        """
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]

        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        self.add_batch_to_metric("val", logits, labels)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"val_loss": loss}

    def training_epoch_end(self, training_step_outputs):
        """
        Compute and log the training metrics accumulated over the epoch.
        """
        metrics = self.metrics["train"].compute(
            num_labels=len(self.id2label), ignore_index=255, reduce_labels=False
        )
        self.log("train_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def validation_epoch_end(self, validation_step_outputs):
        """
        Compute and log the validation metrics accumulated over the epoch.
        """
        metrics = self.metrics["val"].compute(
            num_labels=len(self.id2label), ignore_index=255, reduce_labels=False
        )
        self.log("val_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def add_batch_to_metric(self, stage: str, logits: torch.Tensor, labels: torch.Tensor):
        """
        Add the current batch to the metric for the given stage.

        Parameters
        ----------
        stage : str
            Stage of the training. Either "train" or "val".
        logits : torch.Tensor
            Predicted logits.
        labels : torch.Tensor
            Ground truth labels.
        """
        with torch.no_grad():
            # SegFormer outputs logits at 1/4 of the input resolution, so
            # upsample them to the label size before taking the argmax.
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)
            self.metrics[stage].add_batch(
                predictions=predicted.detach().cpu().numpy(),
                references=labels.detach().cpu().numpy(),
            )

    def configure_optimizers(self) -> torch.optim.AdamW:
        """
        Configure the optimizer.

        Returns
        -------
        torch.optim.AdamW
            Optimizer for the model
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
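

# Minimal smoke test (a sketch, not part of the original module): the 2-class
# id2label mapping and random tensors below are hypothetical stand-ins for a
# real dataloader over the sidewalk dataset.
if __name__ == "__main__":
    id2label = {0: "background", 1: "sidewalk"}  # hypothetical label mapping
    model = SidewalkSegmentationModel(num_labels=len(id2label), id2label=id2label)

    pixel_values = torch.randn(2, 3, 512, 512)   # batch of 2 RGB crops
    labels = torch.randint(0, 2, (2, 512, 512))  # per-pixel class indices

    outputs = model(pixel_values=pixel_values, labels=labels)
    print(outputs.loss.item(), outputs.logits.shape)  # logits come out at 1/4 resolution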