---
datasets:
  - cifar10
metrics:
  - accuracy
library_name: pytorch
pipeline_tag: image-classification
---

custom-cnn-cifar2

A custom convolutional neural network (CNN) trained on CIFAR-2, a two-class subset of CIFAR-10 for classifying 'airplane' vs. 'bird'.

This model pertains to Exercise 1 of Chapter 8 of the book "Deep Learning with PyTorch" by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Note: The exercise also tried (5, 5) and (1, 3) convolution kernel sizes, but neither outperformed the baseline network with (3, 3) kernels. Hence, this checkpoint keeps the (3, 3) kernel size.
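For a rough sense of what these kernel sizes change, the sketch below counts the parameters in the two convolutional layers of this network (3 -> 16 and 16 -> 8 channels) for each kernel size. It assumes "same" padding ((2, 2) for (5, 5), (0, 1) for (1, 3)) so a 32 x 32 input keeps its spatial size; the exercise notebook may have padded differently, and the `configs` dict and `conv_param_count` helper are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn

# Kernel sizes from the note, each paired with an ASSUMED "same" padding
# that keeps a 32x32 input at 32x32 with stride 1.
configs = {
    "(3, 3)": ((3, 3), (1, 1)),
    "(5, 5)": ((5, 5), (2, 2)),
    "(1, 3)": ((1, 3), (0, 1)),
}

def conv_param_count(kernel_size, padding):
    """Hypothetical helper: parameters in the two conv layers for one kernel size."""
    conv1 = nn.Conv2d(3, 16, kernel_size=kernel_size, padding=padding)
    conv2 = nn.Conv2d(16, 8, kernel_size=kernel_size, padding=padding)
    # Sanity check: with the assumed padding, the 32x32 spatial size is preserved.
    out = conv2(conv1(torch.zeros(1, 3, 32, 32)))
    assert out.shape == (1, 8, 32, 32)
    return sum(p.numel() for p in conv1.parameters()) + sum(p.numel() for p in conv2.parameters())

counts = {name: conv_param_count(k, p) for name, (k, p) in configs.items()}
for name, n in counts.items():
    print(f"{name}: {n} conv parameters")
# (3, 3): 1608 conv parameters
# (5, 5): 4424 conv parameters
# (1, 3): 552 conv parameters
```

Under these assumptions, (5, 5) kernels give the conv layers roughly 2.75 times as many parameters as (3, 3) kernels, while (1, 3) kernels give them about a third as many.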

Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_8/exercise_1.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/custom-cnn-cifar2

Usage

```
# Notebook/Colab syntax; drop the leading "!" when running in a terminal.
!pip install -q datasets
```

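```python
from datasets import load_dataset

# Load CIFAR-10 and keep only 'airplane' (label 0) and 'bird' (label 2),
# remapping them to 0 and 1 respectively.
cifar10 = load_dataset("cifar10")
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]

# Take the first validation example: a PIL image and its CIFAR-2 label.
example = cifar2_val[0]
img, label = example

import torch
from torchvision.transforms import v2

# PIL image -> float32 tensor scaled to [0, 1] -> per-channel normalization.
tfms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
])
img = tfms(img)
batch = img.unsqueeze(0)  # Add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32)
```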

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

# Inheriting from PyTorchModelHubMixin adds from_pretrained() for loading weights from the Hugging Face Hub.
class Net(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1, stride=1)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), kernel_size=2, stride=2) # Output shape: (batch_size, 16, 16, 16)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), kernel_size=2, stride=2) # Output shape: (batch_size, 8, 8, 8)
        out = out.view(-1, 8 * 8 * 8) # Output shape: (batch_size, 512)
        out = torch.tanh(self.fc1(out)) # Output shape: (batch_size, 32)
        out = self.fc2(out) # Output shape: (batch_size, 2)
        return out

# Download the trained weights and switch to evaluation mode
# (good practice, although this network has no dropout or batch norm).
model = Net.from_pretrained("sadhaklal/custom-cnn-cifar2")
model.eval()

with torch.no_grad():
    logits = model(batch)
    pred = logits[0].argmax().item()  # Index of the highest-scoring class
    proba = torch.softmax(logits, dim=1)  # Convert logits to class probabilities

print(f"Predicted class: {class_names[pred]}")
print(f"Predicted class probabilities ('airplane' vs. 'bird'): {proba[0].tolist()}")
```
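The shape comments in `Net.forward` can be checked without downloading the checkpoint. The sketch below re-declares the same layers as a plain `nn.Module` (`PlainNet` is a hypothetical name; it omits `PyTorchModelHubMixin` and has random weights, so its outputs are not meaningful predictions), feeds a dummy batch of four 32 x 32 RGB images through it, and counts its parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PlainNet(nn.Module):
    """Hypothetical copy of Net's layers without the Hub mixin; weights are random."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1, stride=1)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), kernel_size=2, stride=2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), kernel_size=2, stride=2)
        out = out.view(-1, 8 * 8 * 8)
        out = torch.tanh(self.fc1(out))
        return self.fc2(out)

net = PlainNet()
x = torch.zeros(4, 3, 32, 32)  # Dummy batch of 4 CIFAR-sized images

with torch.no_grad():
    # Recompute the two conv/pool stages by hand to inspect intermediate shapes.
    h1 = F.max_pool2d(torch.tanh(net.conv1(x)), kernel_size=2, stride=2)
    h2 = F.max_pool2d(torch.tanh(net.conv2(h1)), kernel_size=2, stride=2)
    logits = net(x)

print(tuple(h1.shape), tuple(h2.shape), tuple(logits.shape))  # (4, 16, 16, 16) (4, 8, 8, 8) (4, 2)

n_params = sum(p.numel() for p in net.parameters())
print(n_params)  # 18090
```

Most of the 18,090 parameters (16,416) sit in `fc1`, whose input size is fixed by the 8 x 8 x 8 flattened feature map.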

Metric

Accuracy on cifar2_val (the 2,000 'airplane' and 'bird' images from the CIFAR-10 test split): 0.8995
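One way to compute an accuracy figure like this is to run the model over cifar2_val in batches and compare argmax predictions with the labels. The sketch below is a hypothetical helper (`evaluate_accuracy` is not part of the original code), and the evaluation behind the reported number is in the linked notebook and may differ in details such as batch size.

```python
import torch

def evaluate_accuracy(model, pairs, transform, batch_size=64):
    """Hypothetical helper: fraction of (image, label) pairs whose argmax prediction matches the label.

    Call model.eval() beforehand; this function only disables gradient tracking.
    """
    correct = 0
    with torch.no_grad():
        for start in range(0, len(pairs), batch_size):
            chunk = pairs[start:start + batch_size]
            # Transform each image and stack into a (B, C, H, W) batch.
            imgs = torch.stack([transform(img) for img, _ in chunk])
            labels = torch.tensor([label for _, label in chunk])
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
    return correct / len(pairs)

# Usage with the objects from the Usage section:
# model.eval()
# acc = evaluate_accuracy(model, cifar2_val, tfms)
# print(f"Accuracy on cifar2_val: {acc:.4f}")
```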