# NOTE(review): the original paste carried extraction artifacts — a stray
# "Spaces:" / "Runtime error" header and a trailing " | |" on every line —
# which made the file un-importable; cleaned up here.
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy
from utils.visualize import find_lr
class MyModel(nn.Module):
    """ResNet9-style CNN for 10-class classification of 3x32x32 images (e.g. CIFAR-10).

    Layout: prep conv -> layer1 (downsampling conv + residual block) ->
    layer2 (downsampling conv) -> layer3 (downsampling conv + residual block)
    -> max-pool to 1x1 -> linear head producing (N, 10) logits.
    """

    def __init__(self):
        super().__init__()
        # Prep: 3 -> 64 channels, spatial size unchanged (32 -> 32).
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )  # params = 3*3*3*64 = 1,728
        # Layer 1: 64 -> 128 channels, spatial /2 (32 -> 16).
        self.conv11 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )  # params = 3*3*64*128 = 73,728
        # Residual block for layer 1 (channels and spatial size preserved).
        self.conv12 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )  # params = 2 * 3*3*128*128 = 294,912 (original comment said 73,728 — corrected)
        # Layer 2: 128 -> 256 channels, spatial /2 (16 -> 8).
        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        # Layer 3: 256 -> 512 channels, spatial /2 (8 -> 4).
        self.conv31 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        # Residual block for layer 3 (channels and spatial size preserved).
        self.conv32 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        # 4x4 pooling collapses the 4x4 feature map to 1x1 for 32x32 inputs.
        self.maxpool = nn.MaxPool2d(kernel_size=4, stride=2)
        # Classification head.
        self.fc = nn.Linear(512, 10, bias=True)

    def forward(self, x):
        """Return (N, 10) class logits for a batch of (N, 3, 32, 32) images."""
        x = self.conv1(x)
        x = self.conv11(x)
        residual = x
        x = self.conv12(x) + residual  # residual connection, layer 1
        x = self.conv2(x)
        x = self.conv31(x)
        residual = x
        x = self.conv32(x) + residual  # residual connection, layer 3
        x = self.maxpool(x)            # (N, 512, 1, 1) for 32x32 inputs
        # Replaces the original double squeeze(dim=2) + redundant view(-1, 10):
        # flatten(x, 1) yields the same (N, 512) tensor for a 1x1 feature map.
        x = torch.flatten(x, 1)
        return self.fc(x)
class Model(LightningModule):
    """LightningModule that trains MyModel with cross-entropy, Adam and OneCycleLR.

    Args:
        dataset: project data object; must expose ``download()``,
            ``train_loader`` and ``test_loader`` (DataLoader-like).
        max_epochs: total epochs the OneCycle schedule is planned for.
    """

    def __init__(self, dataset, max_epochs=24):
        super().__init__()  # modern zero-arg super(); was super(Model, self)
        self.dataset = dataset
        self.network = MyModel()
        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.max_epochs = max_epochs

    def forward(self, x):
        return self.network(x)

    def common_step(self, batch, mode):
        """Shared train/val step: forward pass, loss, and accuracy update.

        ``mode`` selects which Accuracy metric ('train' or 'val') to update;
        the metric object itself is logged by the calling step hook.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        acc_metric = getattr(self, f'{mode}_accuracy')
        acc_metric(logits, y)  # accumulates state; reset/compute handled by Lightning
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, 'train')
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", self.train_accuracy, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, 'val')
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", self.val_accuracy, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Prediction batches may arrive with or without labels.
        if isinstance(batch, list):
            x, _ = batch
        else:
            x = batch
        return self.forward(x)

    def configure_optimizers(self):
        """Build Adam + OneCycleLR, seeding max_lr from an LR-range test."""
        # Start from a tiny LR; find_lr presumably sweeps the LR range and
        # returns a suggested peak — defined in utils.visualize (TODO confirm).
        optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
        best_lr = find_lr(self, self.train_dataloader(), optimizer, self.criterion)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=best_lr,
            # Consistency fix: use the same accessor as the find_lr call above
            # (was len(self.dataset.train_loader) — the same loader object).
            steps_per_epoch=len(self.train_dataloader()),
            epochs=self.max_epochs,
            pct_start=5 / self.max_epochs,  # 5 warm-up epochs out of max_epochs
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "interval": "step",  # OneCycleLR must step once per batch
            },
        }

    def prepare_data(self):
        self.dataset.download()

    def train_dataloader(self):
        return self.dataset.train_loader

    def val_dataloader(self):
        return self.dataset.test_loader

    def predict_dataloader(self):
        return self.val_dataloader()