# CIFAR10_ResNet18 / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L

class BasicBlock(nn.Module):
    expansion = 1  # ResNet18/34 use expansion=1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Downsample the identity path when the shapes don't match
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
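
# Quick shape check for the downsampling path (illustrative sketch, not part of
# the original upload): a stride-2 block halves the spatial resolution and widens
# the channels, with the 1x1 shortcut conv matching the residual branch's shape.
#
#   blk = BasicBlock(64, 128, stride=2)
#   y = blk(torch.randn(2, 64, 32, 32))
#   assert y.shape == (2, 128, 16, 16)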

class ResNet18_CIFAR10(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Replace the stem with a CIFAR10-friendly 3x3 conv and drop the maxpool
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # ResNet stages
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)   # 32x32 -> 16x16
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)  # 16x16 -> 8x8
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)  # 8x8 -> 4x4

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(512 * BasicBlock.expansion, num_classes)
        )

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        layers = []
        layers.append(BasicBlock(in_c, out_c, stride))
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_c, out_c, stride=1))  # remaining blocks use stride=1
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))  # note the ReLU after the stem
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)     # [B, 512, 1, 1]
        out = torch.flatten(out, 1)  # [B, 512]
        out = self.fc(out)           # [B, num_classes]
        return out
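
# Rough sanity numbers for the full model (illustrative, not part of the original
# upload): this CIFAR-style ResNet-18 expects 32x32 RGB inputs and has roughly
# 11.2M parameters (the 3x3 stem and 10-way head are smaller than ImageNet ResNet-18's).
#
#   m = ResNet18_CIFAR10()
#   print(sum(p.numel() for p in m.parameters()))        # ~11.2e6
#   print(m(torch.randn(2, 3, 32, 32)).shape)            # torch.Size([2, 10])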

class CIFARCNN(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.example_input_array = torch.Tensor(64, 3, 32, 32)
        self.net = ResNet18_CIFAR10(num_classes=10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):  # batch_idx is not used here
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log("train_loss", loss, on_step=True, prog_bar=True)  # log at every step
        self.log("train_acc", acc, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        # Validation-specific logging: show val_loss on Lightning's progress bar;
        # sync_dist=True synchronizes the metric across devices in distributed training
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc", acc, prog_bar=True, sync_dist=True)
        return {"val_loss": loss, "val_acc": acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return {"test_loss": loss, "test_acc": acc}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


if __name__ == "__main__":
    # Quick forward-pass sanity check
    model = CIFARCNN()
    x = torch.randn(4, 3, 32, 32).to(model.device)
    logits = model(x)
    print(logits.shape)  # torch.Size([4, 10])
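
# A minimal end-to-end training sketch (illustrative, not part of the original
# upload), assuming the standard torchvision CIFAR-10 dataset and Lightning's
# Trainer; the data path, batch sizes, epoch count, and normalization stats
# below are assumptions.
#
#   from torch.utils.data import DataLoader
#   from torchvision import datasets, transforms
#
#   norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
#   train_tfm = transforms.Compose([
#       transforms.RandomCrop(32, padding=4),
#       transforms.RandomHorizontalFlip(),
#       transforms.ToTensor(),
#       norm,
#   ])
#   eval_tfm = transforms.Compose([transforms.ToTensor(), norm])
#
#   train_ds = datasets.CIFAR10("data", train=True, download=True, transform=train_tfm)
#   val_ds = datasets.CIFAR10("data", train=False, download=True, transform=eval_tfm)
#
#   trainer = L.Trainer(max_epochs=50, accelerator="auto")
#   trainer.fit(
#       CIFARCNN(lr=0.1),
#       DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4),
#       DataLoader(val_ds, batch_size=256, num_workers=4),
#   )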