# resnet18/model.py
# Author: anilbhatt1 — commit b447fcc ("initial commit"), 2.46 kB
import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
class LightningModel(L.LightningModule):
    """Lightning wrapper that trains a multiclass classifier with SGD.

    The wrapped ``model`` produces logits. These are scored with
    cross-entropy and tracked with per-stage multiclass accuracy.

    Args:
        model: Network whose forward pass returns class logits of shape
            ``(batch, num_classes)``. Predictions use ``argmax(dim=1)``.
        learning_rate: Learning rate passed to ``torch.optim.SGD``.
        cosine_t_max: ``T_max`` for ``CosineAnnealingLR``. The scheduler is
            stepped every training batch, so this counts optimizer steps.
        mode: ``'lrfind'`` returns the bare optimizer with no scheduler.
            Any other value enables per-step cosine annealing.
        num_classes: Number of target classes for the accuracy metrics.
            Defaults to 10, the value previously hard-coded here.
    """

    def __init__(self, model, learning_rate, cosine_t_max, mode, num_classes=10):
        super().__init__()
        self.learning_rate = learning_rate
        self.cosine_t_max = cosine_t_max
        self.model = model
        # Example batch that Lightning can use for model summaries.
        # Use zeros rather than the uninitialized torch.Tensor(...) so its
        # contents are deterministic.
        # NOTE(review): hard-codes 3-channel 32x32 inputs — confirm this
        # matches the dataset.
        self.example_input_array = torch.zeros(1, 3, 32, 32)
        self.mode = mode
        # Exclude the network itself: it is a module, not a hyperparameter.
        self.save_hyperparameters(ignore=["model"])
        # Give each stage its own metric object so accumulated state never mixes.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's output (logits) for input ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute loss and predictions for a ``(features, labels)`` batch.

        Returns:
            Tuple ``(loss, true_labels, predicted_labels)``, where
            ``predicted_labels`` is the argmax of the logits over dim 1.
        """
        features, true_labels = batch
        logits = self(features)
        loss = F.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        """Log per-step training loss and epoch-level training accuracy."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_acc(predicted_labels, true_labels)
        # Log the metric object (not a tensor) so Lightning aggregates it per epoch.
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test accuracy (the loss is computed but not logged)."""
        _, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Build SGD, optionally paired with a per-step cosine LR schedule.

        Returns:
            The bare optimizer when ``mode == 'lrfind'``. Otherwise, a
            Lightning optimizer/scheduler config dict.
        """
        opt = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        if self.mode == 'lrfind':
            # An LR range test drives the learning rate itself; no scheduler.
            return opt

        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.cosine_t_max)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                # NOTE(review): 'monitor' is only consulted by metric-driven
                # schedulers (e.g. ReduceLROnPlateau). It is kept unchanged here.
                "monitor": "train_loss",
                "interval": "step",  # step the scheduler after every batch
                "frequency": 1,
            },
        }