# Assignment12 / models / CUSTOMRESNET.py
# Uploaded by SahithiR ("Upload 21 files", commit 0253766).
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy
from utils.visualize import find_lr
class MyModel(nn.Module):
    """Custom ResNet-style CNN classifier.

    Structure: a prep conv (3 -> 64 channels), then three stages that each
    double the channel count and halve the spatial size with a 2x2 max-pool
    (64 -> 128 -> 256 -> 512). Stages 1 and 3 add a residual block of two
    3x3 convs on top of the stage output. A final max-pool and one linear
    layer produce the logits.

    For 32x32 inputs the feature map goes 32 -> 16 -> 8 -> 4, and the final
    kernel-4 max-pool reduces it to 1x1, leaving exactly the 512 features the
    linear head expects. Other input sizes will not match the head.

    Args:
        num_classes: number of output logits. Defaults to 10, the value that
            was previously hard-coded.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Prep layer. Conv weights: 3*3*3*64 = 1,728.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Layer 1: conv, 2x2 max-pool (halves H and W), BN, ReLU.
        # Conv weights: 3*3*64*128 = 73,728.
        self.conv11 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        # Layer 1 residual block; its output is added to conv11's output in
        # forward(). Conv weights: 2 * (3*3*128*128) = 2 * 147,456.
        self.conv12 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        # Layer 2 (no residual). Conv weights: 3*3*128*256 = 294,912.
        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        # Layer 3. Conv weights: 3*3*256*512 = 1,179,648.
        self.conv31 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        # Layer 3 residual block; added to conv31's output in forward().
        # Conv weights: 2 * (3*3*512*512) = 2 * 2,359,296.
        self.conv32 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        # Collapses the 4x4 map (for 32x32 inputs) to 1x1.
        self.maxpool = nn.MaxPool2d(kernel_size=4, stride=2)
        # Fully connected head: 512 features -> num_classes logits.
        self.fc = nn.Linear(512, num_classes, bias=True)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a (batch, 3, H, W) input."""
        x = self.conv11(self.conv1(x))
        x = self.conv12(x) + x  # residual connection around conv12
        x = self.conv31(self.conv2(x))
        x = self.conv32(x) + x  # residual connection around conv32
        x = self.maxpool(x)
        # (batch, 512, 1, 1) -> (batch, 512); flatten keeps the batch dim for
        # any batch size, including 1.
        x = torch.flatten(x, 1)
        return self.fc(x)
class Model(LightningModule):
    """LightningModule that trains :class:`MyModel` as a 10-class classifier.

    Args:
        dataset: data holder exposing ``train_loader``, ``test_loader`` and a
            ``download()`` method (the attributes used below).
        max_epochs: number of epochs the OneCycle LR schedule is laid out
            over. NOTE(review): presumably should match the Trainer's
            ``max_epochs``; OneCycleLR raises if stepped past its total steps.
    """

    def __init__(self, dataset, max_epochs=24):
        super().__init__()
        self.dataset = dataset
        self.network = MyModel()
        self.criterion = nn.CrossEntropyLoss()
        # Separate metric objects so train and val accuracy accumulate
        # independently.
        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.max_epochs = max_epochs

    def forward(self, x):
        """Return the network's logits for input batch ``x``."""
        return self.network(x)

    def common_step(self, batch, mode):
        """Compute the loss for an ``(x, y)`` batch and update ``{mode}_accuracy``.

        Args:
            batch: ``(inputs, targets)`` pair.
            mode: ``'train'`` or ``'val'``; selects which accuracy metric to update.

        Returns:
            The cross-entropy loss tensor.
        """
        x, y = batch
        # Call the module (not .forward) so registered nn.Module hooks run.
        logits = self(x)
        loss = self.criterion(logits, y)
        acc_metric = getattr(self, f'{mode}_accuracy')
        acc_metric(logits, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy for one batch."""
        loss = self.common_step(batch, 'train')
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", self.train_accuracy, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        loss = self.common_step(batch, 'val')
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", self.val_accuracy, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return logits for a prediction batch.

        Accepts either a bare input tensor or an ``(inputs, targets)`` pair;
        the default collate produces a list, while custom collate functions
        often produce a tuple, so both are unpacked.
        """
        if isinstance(batch, (list, tuple)):
            x, _ = batch
        else:
            x = batch
        return self(x)

    def configure_optimizers(self):
        """Build Adam plus a per-step OneCycle schedule peaking at a found LR.

        Returns:
            Lightning optimizer config dict with the scheduler stepped every
            optimizer step (``interval="step"``).
        """
        # Use one loader for both the LR search and the schedule length so
        # they cannot disagree.
        train_loader = self.train_dataloader()
        optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
        # NOTE(review): find_lr lives in utils.visualize and is not visible
        # here; presumably it runs an LR range test over train_loader and
        # returns the suggested max LR. Confirm it restores the model and
        # optimizer state afterwards.
        best_lr = find_lr(self, train_loader, optimizer, self.criterion)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=best_lr,
            steps_per_epoch=len(train_loader),
            epochs=self.max_epochs,
            # Warm up over the first 5 epochs. OneCycleLR requires
            # 0 <= pct_start <= 1, so this needs max_epochs >= 5.
            pct_start=5 / self.max_epochs,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

    def prepare_data(self):
        """Download the dataset (delegates to ``dataset.download()``)."""
        self.dataset.download()

    def train_dataloader(self):
        """Return the dataset's training loader."""
        return self.dataset.train_loader

    def val_dataloader(self):
        """Return the dataset's test loader, used for validation."""
        return self.dataset.test_loader

    def predict_dataloader(self):
        """Predict over the same loader used for validation."""
        return self.val_dataloader()