i4ata commited on
Commit
e3cba03
·
verified ·
1 Parent(s): c113524

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +49 -0
model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as L
2
+ from lightning.pytorch.utilities.model_summary import ModelSummary
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.nn as nn
7
+
8
+ import torchmetrics
9
+ from torchvision import transforms
10
+
11
+ from typing import Optional
12
+
13
+ class ClassifierModel(L.LightningModule):
14
+
15
+ def __init__(self, model: nn.Module, image_size: int = 500, learning_rate: float = 1e-3, num_classes: int = 3,
16
+ train_transform: Optional[transforms.Compose] = None, val_transform: Optional[transforms.Compose] = None) -> None:
17
+ super().__init__()
18
+ self.model = model
19
+ self.learning_rate = learning_rate
20
+ self.example_input_array = torch.Tensor(5, 3, image_size, image_size)
21
+ self.f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)
22
+ self.train_transform = train_transform
23
+ self.val_transform = val_transform
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ return self.model(x)
27
+
28
+ def print_summary(self) -> None:
29
+ print(ModelSummary(self, max_depth=-1))
30
+
31
+ def configure_optimizers(self) -> torch.optim.Optimizer:
32
+ return torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate)
33
+
34
+ def training_step(self, batch: tuple, batch_idx: int) -> float:
35
+ X, y = batch
36
+ y_pred = self(X)
37
+ loss = F.cross_entropy(y_pred, y)
38
+ self.log_dict({'Train loss': loss, f'Train F1 score': self.f1_score(y_pred, y)},
39
+ on_step=False, on_epoch=True)
40
+ return loss
41
+
42
+ def validation_step(self, batch: tuple, batch_idx: int) -> float:
43
+ X, y = batch
44
+ y_pred = self(X)
45
+ loss = F.cross_entropy(y_pred, y)
46
+ self.log_dict({'Validation loss': loss, f'Validation F1 score': self.f1_score(y_pred, y)},
47
+ on_step=False, on_epoch=True)
48
+ return loss
49
+