chainyo committed on
Commit 8e87ed1
1 Parent(s): 0719c14

create pl model

Files changed (1): model.py +111 -0
model.py ADDED
@@ -0,0 +1,111 @@
import pytorch_lightning as pl
import torch

from datasets import load_metric
from torch import nn
from transformers import SegformerForSemanticSegmentation
from typing import Dict


class SidewalkSegmentationModel(pl.LightningModule):
    def __init__(
        self,
        num_labels: int,
        id2label: Dict[int, str],
        model_flavor: int = 0,
        learning_rate: float = 6e-5,
    ):
        super().__init__()
        self.id2label = id2label
        self.label2id = {v: k for k, v in id2label.items()}
        self.learning_rate = learning_rate
        self.metrics = {
            "train": load_metric("mean_iou"),
            "val": load_metric("mean_iou"),
        }

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-b{model_flavor}",
            num_labels=num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]

        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        self.add_batch_to_metric("train", logits, labels)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]

        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        self.add_batch_to_metric("val", logits, labels)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"val_loss": loss}

    def training_epoch_end(self, training_step_outputs):
        """
        Log the training metrics.
        """
        metrics = self.metrics["train"].compute(num_labels=len(self.id2label), ignore_index=255, reduce_labels=False)
        self.log("train_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def validation_epoch_end(self, validation_step_outputs):
        """
        Log the validation metrics.
        """
        metrics = self.metrics["val"].compute(num_labels=len(self.id2label), ignore_index=255, reduce_labels=False)
        self.log("val_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def add_batch_to_metric(self, stage: str, logits: torch.Tensor, labels: torch.Tensor):
        """
        Add the current batch to the metric.

        Parameters
        ----------
        stage : str
            Stage of the training. Either "train" or "val".
        logits : torch.Tensor
            Predicted logits.
        labels : torch.Tensor
            Ground truth labels.
        """
        with torch.no_grad():
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)
            self.metrics[stage].add_batch(
                predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy()
            )

    def configure_optimizers(self) -> torch.optim.AdamW:
        """
        Configure the optimizer.

        Returns
        -------
        torch.optim.AdamW
            Optimizer for the model.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
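
For context, here is a minimal usage sketch that is not part of this commit. It assumes DataLoaders named train_loader and val_loader that yield batches with "pixel_values" and "labels" keys, a placeholder id2label mapping, and arbitrary Trainer settings; only SidewalkSegmentationModel comes from the file above.

# Hypothetical usage sketch -- not part of this commit.
# train_loader / val_loader are assumed to yield {"pixel_values": ..., "labels": ...} batches.
import pytorch_lightning as pl

from model import SidewalkSegmentationModel

id2label = {0: "unlabeled", 1: "road", 2: "sidewalk"}  # placeholder mapping

model = SidewalkSegmentationModel(
    num_labels=len(id2label),
    id2label=id2label,
    model_flavor=0,       # loads the nvidia/mit-b0 backbone
    learning_rate=6e-5,
)

trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)  # example settings
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)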