mlp-cifar2
Multi-layer perceptron (MLP) trained on CIFAR-2 (a subset of CIFAR-10 for classifying 'airplane' vs. 'bird').
RandomCrop(24) was applied to the training set images, and Resize(24) was applied to the validation set images. Both yield 24×24 inputs from the 32×32 CIFAR images.
This model pertains to Exercise 1 of Chapter 7 of the book "Deep Learning with PyTorch" by Eli Stevens, Luca Antiga, and Thomas Viehmann.
Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_7/exercise_1.ipynb
Experiment tracking: https://wandb.ai/sadhaklal/mlp-cifar2
Usage
!pip install -q datasets
from datasets import load_dataset
cifar10 = load_dataset("cifar10")
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]
example = cifar2_val[0]
img, label = example
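A quick sanity check on the filtered lists is to count the examples in each class. CIFAR-10 has 5,000 training images and 1,000 test images per class. So cifar2_train should hold 5,000 images of each class, and cifar2_val 1,000 of each. The helper class_counts below is hypothetical and not part of the original card.

```python
from collections import Counter


def class_counts(pairs, class_names):
    """Count examples per class in a list of (image, label) pairs."""
    counts = Counter(label for _, label in pairs)
    # Report counts by class name, in label order (0 -> 'airplane', 1 -> 'bird').
    return {class_names[label]: counts[label] for label in range(len(class_names))}
```

With the lists above, class_counts(cifar2_train, class_names) should give {'airplane': 5000, 'bird': 5000}.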
import torch
from torchvision.transforms import v2
val_tfms = v2.Compose([
    v2.Resize(24),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
])
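Normalize subtracts a per-channel mean and divides by a per-channel standard deviation. These statistics apply to pixel values that ToDtype(..., scale=True) has already scaled to [0, 1]. The card does not say where the constants come from. Assuming they are CIFAR-10 training-set statistics, as is common, they could be recomputed with a helper like the hypothetical channel_mean_std below.

```python
import torch


def channel_mean_std(images: torch.Tensor):
    """Per-channel mean and std of a float image batch shaped (N, C, H, W)."""
    # Move channels to dim 0 and flatten everything else: (C, N*H*W).
    per_channel = images.transpose(0, 1).reshape(images.shape[1], -1)
    return per_channel.mean(dim=1), per_channel.std(dim=1)
```

As an example, you could stack the training images (after ToImage and ToDtype only) into an (N, 3, 32, 32) tensor and pass it to channel_mean_std. The result gives statistics to compare against the hard-coded constants.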
img = val_tfms(img)
batch = img.reshape(-1).unsqueeze(0)  # Flatten to shape (1, 1728), i.e. 3 * 24 * 24 features.
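The model takes flat 1,728-dimensional vectors (3 × 24 × 24), so each image must be flattened and given a leading batch dimension. The sketch below uses random tensors as hypothetical stand-ins for transformed images. It shows the single-image case above and batching several images at once.

```python
import torch

# Hypothetical stand-ins for images after val_tfms, each shaped (3, 24, 24).
demo_imgs = [torch.randn(3, 24, 24) for _ in range(5)]

# Single image: flatten all 1728 values, then add a batch dimension.
single = demo_imgs[0].reshape(-1).unsqueeze(0)  # Shape (1, 1728).

# Several images: stack into (5, 3, 24, 24), then flatten every dim after the batch dim.
stacked = torch.stack(demo_imgs).flatten(start_dim=1)  # Shape (5, 1728).
```

flatten(start_dim=1) flattens each image independently while keeping the batch dimension. Because both approaches flatten in row-major order, the first row of stacked is identical to single.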
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class MLPForCIFAR2(nn.Module, PyTorchModelHubMixin):
    """Multi-layer perceptron (MLP) for classifying 'airplane' vs. 'bird' in the CIFAR-2 dataset (a subset of CIFAR-10)."""
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1728, 1024), # Hidden layer.
            nn.Tanh(),
            nn.Linear(1024, 512), # Hidden layer.
            nn.Tanh(),
            nn.Linear(512, 128), # Hidden layer.
            nn.Tanh(),
            nn.Linear(128, 2) # Output layer.
        )

    def forward(self, x):
        return self.mlp(x)
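Tanh layers have no parameters, so the model's size follows directly from the Linear layer shapes. Each nn.Linear(n_in, n_out) holds n_in × n_out weights plus n_out biases. The arithmetic is sketched below (the helper count_linear_params is hypothetical).

```python
def count_linear_params(layer_sizes):
    """Total weights + biases for a stack of fully connected layers."""
    # Pair consecutive sizes: (1728, 1024), (1024, 512), ...
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


MLP_SIZES = [1728, 1024, 512, 128, 2]  # Input, three hidden layers, output.
total_params = count_linear_params(MLP_SIZES)  # 2361218
```

The first layer alone accounts for 1,770,496 of the 2,361,218 parameters (about 75%). You can cross-check the total against the loaded model with sum(p.numel() for p in model.parameters()).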
model = MLPForCIFAR2.from_pretrained("sadhaklal/mlp-cifar2")
model.eval()
import torch.nn.functional as F
with torch.no_grad():
    logits = model(batch)
    pred = logits[0].argmax().item()
    proba = F.softmax(logits, dim=1)
print(f"Predicted class: {class_names[pred]}")
print(f"Predicted class probabilities ('airplane' vs. 'bird'): {proba[0].tolist()}")
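Softmax is monotonic, so taking the argmax of the raw logits (as pred does) gives the same class as taking the argmax of proba. With two classes, softmax reduces to a sigmoid of the logit difference: p(airplane) = 1 / (1 + exp(-(z_airplane - z_bird))). Here is a pure-Python sketch with illustrative logits (not outputs of the actual model):

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # Subtract the max so exp() cannot overflow.
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [1.5, -0.5]  # Illustrative logits for ('airplane', 'bird').
example_proba = softmax(example_logits)  # Roughly [0.881, 0.119].
example_pred = max(range(len(example_logits)), key=lambda i: example_logits[i])  # 0 -> 'airplane'.
```

Subtracting the maximum logit before exponentiating leaves the result unchanged but avoids overflow when logits are large.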
Metric
Accuracy on cifar2_val: 0.8090
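The reported accuracy is the fraction of cifar2_val images whose predicted class matches the label. The card does not show the author's evaluation code. The sketch below uses a hypothetical helper, evaluate_accuracy, which batches images through val_tfms and the same flattening step.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, pairs, tfms, batch_size: int = 256) -> float:
    """Fraction of (image, label) pairs the model classifies correctly."""
    model.eval()
    correct = 0
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        # Transform each image, stack to (B, 3, 24, 24), then flatten to (B, 1728).
        x = torch.stack([tfms(img) for img, _ in chunk]).flatten(start_dim=1)
        y = torch.tensor([label for _, label in chunk])
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
    return correct / len(pairs)
```

Call it as evaluate_accuracy(model, cifar2_val, val_tfms) with the objects from the Usage section. It computes the same quantity as the Metric above, though exact agreement assumes the preprocessing matches the author's. With 2,000 validation images, 0.8090 corresponds to 1,618 correct predictions.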