import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

from utils.visualize import find_lr


def _conv_bn_relu(in_channels, out_channels, pool=False):
    """Return the layers of one 3x3 conv -> [2x2 max-pool] -> BatchNorm -> ReLU unit.

    The layers are returned as a list (not a module) so that callers can
    splice one or several units into a single ``nn.Sequential`` and keep the
    flat child indices (and therefore the ``state_dict`` keys) unchanged.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        pool: When true, insert a stride-2 max-pool between the convolution
            and the batch norm, halving the spatial resolution.
    """
    layers = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,  # the following BatchNorm supplies the shift term
        ),
    ]
    if pool:
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace=True))
    return layers


class MyModel(nn.Module):
    """Small residual CNN producing 10 class logits from 3-channel images.

    Convolution weight counts (3x3 kernels, no bias, BatchNorm excluded):

    * ``conv1``  : 3*3*3*64          =     1,728
    * ``conv11`` : 3*3*64*128        =    73,728
    * ``conv12`` : 2 * 3*3*128*128   =   294,912  (147,456 per conv)
    * ``conv2``  : 3*3*128*256       =   294,912
    * ``conv31`` : 3*3*256*512       = 1,179,648
    * ``conv32`` : 2 * 3*3*512*512   = 4,718,592  (2,359,296 per conv)

    The three stride-2 pools followed by the final 4x4 pool reduce a 32x32
    input to a 1x1 map, which is what the 512-feature linear head requires.
    """

    def __init__(self):
        super().__init__()
        # Stem
        self.conv1 = nn.Sequential(*_conv_bn_relu(3, 64))

        # Layer 1: down-sampling unit followed by a two-conv residual branch
        self.conv11 = nn.Sequential(*_conv_bn_relu(64, 128, pool=True))
        self.conv12 = nn.Sequential(
            *_conv_bn_relu(128, 128),
            *_conv_bn_relu(128, 128),
        )

        # Layer 2: down-sampling unit only (no residual branch)
        self.conv2 = nn.Sequential(*_conv_bn_relu(128, 256, pool=True))

        # Layer 3: down-sampling unit followed by a two-conv residual branch
        self.conv31 = nn.Sequential(*_conv_bn_relu(256, 512, pool=True))
        self.conv32 = nn.Sequential(
            *_conv_bn_relu(512, 512),
            *_conv_bn_relu(512, 512),
        )

        self.maxpool = nn.MaxPool2d(kernel_size=4, stride=2)

        # Fully connected classification head
        self.fc = nn.Linear(512, 10, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Batched image tensor with 3 channels (BatchNorm2d requires a
                4-D input). The spatial size must collapse to 1x1 after the
                final pool, e.g. 32x32 images.

        Returns:
            Tensor of shape ``(batch, 10)`` holding unnormalised logits.
        """
        x = self.conv1(x)

        x = self.conv11(x)
        x = self.conv12(x) + x  # residual connection around conv12

        x = self.conv2(x)

        x = self.conv31(x)
        x = self.conv32(x) + x  # residual connection around conv32

        x = self.maxpool(x)
        # Collapse (batch, 512, 1, 1) to (batch, 512). Unlike chained
        # squeeze() calls this never silently keeps stray spatial dims: an
        # unexpected input size surfaces as a shape error in the linear layer.
        x = torch.flatten(x, 1)
        return self.fc(x)


class Model(LightningModule):
    """Lightning wrapper that trains ``MyModel`` with a one-cycle LR schedule.

    Args:
        dataset: Data provider. The code relies on it exposing a
            ``download()`` method plus ``train_loader`` and ``test_loader``
            attributes.
        max_epochs: Number of epochs the one-cycle schedule is laid out for.
            The warm-up fraction is ``5 / max_epochs``; ``OneCycleLR`` rejects
            fractions above 1, so values below 5 will fail when the scheduler
            is built.
    """

    def __init__(self, dataset, max_epochs=24):
        super().__init__()
        self.dataset = dataset
        self.network = MyModel()
        self.criterion = nn.CrossEntropyLoss()
        # Separate metric objects so train and validation statistics never mix.
        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.max_epochs = max_epochs

    def forward(self, x):
        """Delegate to the wrapped network and return its logits."""
        return self.network(x)

    def common_step(self, batch, mode):
        """Run the forward pass, update the accuracy metric and return the loss.

        Args:
            batch: ``(inputs, targets)`` pair.
            mode: ``'train'`` or ``'val'``; selects ``<mode>_accuracy``.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        acc_metric = getattr(self, f'{mode}_accuracy')
        acc_metric(logits, y)  # updates the metric state; logged by the caller
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy for one batch."""
        loss = self.common_step(batch, 'train')
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", self.train_accuracy, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        loss = self.common_step(batch, 'val')
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", self.val_accuracy, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return logits for a batch, ignoring any labels bundled with it.

        A list or tuple batch is treated as ``(inputs, ...)`` and only its
        first element is fed to the network; anything else is assumed to be
        the input tensor itself.
        """
        # Accept tuples as well as lists (custom collate functions may yield
        # either) and do not insist on exactly two elements.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.forward(x)

    def configure_optimizers(self):
        """Build the Adam optimizer and a per-step ``OneCycleLR`` schedule.

        The peak learning rate comes from the project helper ``find_lr``.
        NOTE(review): ``find_lr`` is defined outside this file; it presumably
        runs an LR range test and may touch model/optimizer state - confirm.
        """
        train_loader = self.train_dataloader()
        # The tiny initial lr is only a starting point; OneCycleLR overrides it.
        optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
        best_lr = find_lr(self, train_loader, optimizer, self.criterion)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=best_lr,
            steps_per_epoch=len(train_loader),
            epochs=self.max_epochs,
            pct_start=5 / self.max_epochs,  # warm up over the first 5 epochs
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "interval": "step",  # the schedule is defined per batch
            },
        }

    def prepare_data(self):
        """Ask the dataset object to download its data."""
        self.dataset.download()

    def train_dataloader(self):
        """Return the dataset's training loader."""
        return self.dataset.train_loader

    def val_dataloader(self):
        """Return the dataset's ``test_loader``, which is used for validation."""
        return self.dataset.test_loader

    def predict_dataloader(self):
        """Predict on the same loader that is used for validation."""
        return self.val_dataloader()