|
|
import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L
|
|
class BasicBlock(nn.Module):
    """Standard ResNet basic block: two 3x3 convolutions with a residual shortcut."""

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Use a 1x1 projection on the shortcut when the block changes the
        # spatial resolution or channel count; otherwise keep the identity.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # residual connection
        out = F.relu(out)
        return out
|
|
class ResNet18_CIFAR10(nn.Module):
    """ResNet-18 adapted for 32x32 CIFAR-10 inputs."""

    def __init__(self, num_classes=10):
        super().__init__()

        # CIFAR-style stem: a 3x3 stride-1 convolution with no max pooling,
        # instead of the 7x7 stride-2 stem used for ImageNet-sized inputs.
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Four stages of two BasicBlocks each; stages 2-4 halve the spatial
        # resolution and double the channel count.
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(512 * BasicBlock.expansion, num_classes)
        )

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        # Only the first block in a stage downsamples; the rest keep stride 1.
        layers = [BasicBlock(in_c, out_c, stride)]
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_c, out_c, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avg_pool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out
|
|
class CIFARCNN(L.LightningModule):
    """LightningModule wrapping ResNet18_CIFAR10 with SGD + cosine annealing."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        # Example input lets Lightning log the model graph and summary shapes.
        self.example_input_array = torch.Tensor(64, 3, 32, 32)

        self.net = ResNet18_CIFAR10(num_classes=10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("train_loss", loss, on_step=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # sync_dist averages the metric across processes when running with DDP.
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc", acc, prog_bar=True, sync_dist=True)
        return {"val_loss": loss, "val_acc": acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return {"test_loss": loss, "test_acc": acc}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4
        )
        # Anneal the learning rate over the full run; Lightning steps the
        # scheduler once per epoch by default.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
|
|
if __name__ == "__main__":
    # Quick shape sanity check: a forward pass on a random batch should
    # produce logits of shape (4, 10).
    model = CIFARCNN()
    x = torch.randn(4, 3, 32, 32).to(model.device)
    logits = model(x)
    print(logits.shape)
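
    # Optional end-to-end training sketch (not part of the original script):
    # it assumes torchvision is available for the CIFAR-10 dataset, and the
    # augmentations and hyperparameters below are illustrative defaults only.
    # Flip RUN_TRAINING to True to try it.
    RUN_TRAINING = False
    if RUN_TRAINING:
        from torch.utils.data import DataLoader
        from torchvision import datasets, transforms

        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        train_set = datasets.CIFAR10("data", train=True, download=True, transform=transform)
        train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)

        trainer = L.Trainer(max_epochs=30, accelerator="auto")
        trainer.fit(model, train_loader)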