import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics import Accuracy

BATCH_SIZE = 256


class ResBlocks(nn.Module):
    """Basic two-convolution residual block with an optional 1x1 projection shortcut."""

    def __init__(self, inchannels, outchannels, stride):
        super().__init__()
        self.conv1 = self.make_conv(inchannels, outchannels, stride=stride)
        self.conv2 = self.make_conv(outchannels, outchannels)
        # Project the input with a 1x1 convolution whenever the spatial size or
        # channel count changes, so it can be added to the block output;
        # otherwise pass it through unchanged.
        if stride != 1 or inchannels != outchannels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannels, outchannels, kernel_size=1, stride=stride)
            )
        else:
            self.shortcut = nn.Identity()

    def make_conv(self, inchannels, outchannels, kernel=3, padding=1, stride=1):
        return nn.Sequential(
            nn.Conv2d(in_channels=inchannels, out_channels=outchannels,
                      kernel_size=kernel, padding=padding, stride=stride),
            nn.BatchNorm2d(outchannels),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out + self.shortcut(x)


class ResNet18(LightningModule):
    def __init__(self, lr=0.05):
        super().__init__()
        self.save_hyperparameters()
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

        # Four stages of two blocks each, doubling the channels and halving the
        # spatial size from the second stage on -- the ResNet-18 layout.
        self.in_layers = [64, 64, 128, 256]
        self.out_layers = [64, 128, 256, 512]
        self.strides = [1, 2, 2, 2]
        self.num = [2, 2, 2, 2]

        # CIFAR-style stem: a single 3x3 convolution instead of the ImageNet
        # 7x7 conv + max-pool. Note it uses no padding, so 32x32 inputs become
        # 30x30; the stage strides still reduce this to 4x4 before the pool.
        self.convin = nn.Sequential(
            nn.Conv2d(3, 64, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.res_layers = nn.ModuleList([
            self.make_res(self.in_layers[i], self.out_layers[i],
                          self.num[i], self.strides[i])
            for i in range(len(self.in_layers))
        ])
        self.avgpool = nn.AvgPool2d(kernel_size=4)
        self.fc = self.make_FC()

    def make_res(self, inchannels, outchannels, num, stride):
        # Only the first block of a stage downsamples; the rest keep stride 1.
        strides = [stride] + [1] * (num - 1)
        layers = []
        for stride in strides:
            layers.append(ResBlocks(inchannels=inchannels,
                                    outchannels=outchannels, stride=stride))
            inchannels = outchannels
        return nn.Sequential(*layers)

    def make_FC(self):
        # Classifier head; LogSoftmax pairs with the nll_loss used below.
        return nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        x = self.convin(x)
        for layer in self.res_layers:
            x = layer(x)
        x = self.avgpool(x)
        x = x.view(-1, 512)
        return self.fc(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=5e-4,
        )
        # 45,000 training images: CIFAR-10 with a 45k/5k train/val split.
        steps_per_epoch = 45000 // BATCH_SIZE
        scheduler_dict = {
            # OneCycleLR drives the actual learning rate, ramping up to max_lr
            # over the first 20% of training and then annealing linearly.
            "scheduler": OneCycleLR(
                optimizer,
                max_lr=1.26e-2,
                steps_per_epoch=steps_per_epoch,
                epochs=20,
                pct_start=0.2,
                div_factor=10,
                three_phase=False,
                final_div_factor=10,
                anneal_strategy="linear",
            ),
            "interval": "step",  # step the scheduler every batch, not every epoch
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
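

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original model definition above). The
# dummy input shape (CIFAR-10-style 3x32x32 images), the torchvision data
# pipeline, and the Trainer settings below are assumptions for illustration;
# swap in your own dataloaders as needed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import CIFAR10

    # Shape smoke test: a batch of 8 fake images -> 8 x 10 log-probabilities.
    model = ResNet18(lr=0.05)
    logits = model(torch.randn(8, 3, 32, 32))
    assert logits.shape == (8, 10)

    # 45k/5k train/val split, matching steps_per_epoch in configure_optimizers.
    transform = transforms.ToTensor()
    full_train = CIFAR10("data", train=True, download=True, transform=transform)
    train_set, val_set = random_split(full_train, [45000, 5000])
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, num_workers=2)

    # max_epochs matches the epochs baked into the OneCycleLR schedule.
    trainer = Trainer(max_epochs=20, accelerator="auto")
    trainer.fit(model, train_loader, val_loader)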