"""
ResNet in PyTorch, wrapped in a PyTorch Lightning module for CIFAR-10.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""
# imports
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer
# from pytorch_lightning.callbacks import ModelSummary
# from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR10
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolution + batch-norm layers (ResNet-18/34 style).

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    The shortcut is the identity unless the stride or the channel count
    changes. In that case a strided 1x1 convolution followed by batch norm
    projects ``x`` to the output shape so the two branches can be added.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (times ``expansion``, which is 1 here).
        stride: Stride of the first 3x3 convolution and of the projection shortcut.
    """

    # Output channels are ``planes * expansion``; the basic block does not widen.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already line up: an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            # Project the input to the residual branch's shape before the add.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(residual + self.shortcut(x))
class CIFAR10Model(LightningModule):
    """ResNet classifier for CIFAR-10 packaged as a LightningModule.

    The network is the CIFAR variant of ResNet: a single 3x3 stem convolution,
    four residual stages with 64/128/256/512 base channels (the last three
    downsample by 2), 4x4 average pooling and a linear classifier. The module
    also owns its data pipeline: download, a 90/10 train/validation split of
    the CIFAR-10 training set, and the dataloaders.

    Args:
        block: Residual block class (e.g. ``BasicBlock``). It must expose an
            integer ``expansion`` attribute and accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Number of output classes (also used by the accuracy metric).
        data_dir: Directory CIFAR-10 is downloaded to and read from. When None,
            the ``PATH_DATASETS`` environment variable is used, falling back to ``"."``.
        learning_rate: SGD learning rate.
        batch_size: Batch size for all dataloaders. When None, 256 is used if
            CUDA is available, otherwise 64.
    """

    def __init__(self, block, num_blocks, num_classes=10, data_dir=None, learning_rate=0.01, batch_size=None):
        super().__init__()
        # Resolved here instead of in the signature so the class no longer
        # depends on module-level PATH_DATASETS / BATCH_SIZE globals that this
        # file never defines.
        if data_dir is None:
            data_dir = os.environ.get("PATH_DATASETS", ".")
        if batch_size is None:
            batch_size = 256 if torch.cuda.is_available() else 64

        # Torchvision transforms. The normalization constants are the
        # per-channel values commonly used for CIFAR-10.
        normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2434, 0.2615))
        # Training-time augmentation. Padding before the 32x32 crop is what
        # makes the crop a random translation; on 32x32 images an unpadded
        # 32x32 crop always returns the whole image.
        self.transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # Deterministic pipeline for validation/test: no random augmentation.
        self.eval_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        # _make_layer reads and advances this running channel count.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

        self.data_dir = data_dir
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # A 4x4 pool is global for 32x32 inputs (spatial size 4 after three stride-2 stages).
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        return optimizer

    def prepare_data(self):
        # Download only; no state is assigned here.
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` ("fit", "validate", "test", or None for all)."""
        if stage in ("fit", "validate", None):
            train_augmented = CIFAR10(self.data_dir, train=True, transform=self.transform)
            train_plain = CIFAR10(self.data_dir, train=True, transform=self.eval_transform)
            n_total = len(train_augmented)
            train_size = int(n_total * 0.9)
            val_size = n_total - train_size
            # Split indices once, then view the same indices through two
            # dataset instances so validation images are not augmented.
            train_idx, val_idx = random_split(range(n_total), [train_size, val_size])
            self.cifar10_train = Subset(train_augmented, train_idx.indices)
            self.cifar10_val = Subset(train_plain, val_idx.indices)
        if stage in ("test", None):
            self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.eval_transform)

    @staticmethod
    def _num_workers():
        """Return a valid DataLoader worker count (os.cpu_count() may return None)."""
        return os.cpu_count() or 0

    def train_dataloader(self):
        return DataLoader(
            self.cifar10_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self._num_workers(),
        )

    def val_dataloader(self):
        workers = self._num_workers()
        # persistent_workers=True is rejected by DataLoader when num_workers == 0.
        return DataLoader(
            self.cifar10_val,
            batch_size=self.batch_size,
            num_workers=workers,
            persistent_workers=workers > 0,
        )

    def test_dataloader(self):
        return DataLoader(self.cifar10_test, batch_size=self.batch_size, num_workers=self._num_workers())
def ResNet18(**kwargs):
    """Build a ResNet-18 CIFAR-10 model (``BasicBlock``, [2, 2, 2, 2] blocks per stage).

    Keyword arguments (e.g. ``num_classes``, ``data_dir``, ``learning_rate``,
    ``batch_size``) are forwarded to ``CIFAR10Model``.
    """
    return CIFAR10Model(BasicBlock, [2, 2, 2, 2], **kwargs)
def ResNet34(**kwargs):
    """Build a ResNet-34 CIFAR-10 model (``BasicBlock``, [3, 4, 6, 3] blocks per stage).

    Previously this called an undefined ``ResNet`` class (NameError). It now
    builds the same ``CIFAR10Model`` as ``ResNet18``. Keyword arguments are
    forwarded to ``CIFAR10Model``.
    """
    return CIFAR10Model(BasicBlock, [3, 4, 6, 3], **kwargs)