

A multi-layer perceptron (MLP) trained on CIFAR-2, a two-class subset of CIFAR-10, to classify images as 'airplane' or 'bird'.

`RandomCrop(24)` was applied to the training images and `Resize(24)` to the validation images. Either way, each 32×32 CIFAR image becomes 24×24, so the flattened input has 3 × 24 × 24 = 1728 features.

This model pertains to Exercise 1 of Chapter 7 of the book "Deep Learning with PyTorch" by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_7/exercise_1.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/mlp-cifar2


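```python
!pip install -q datasets  # Notebook shell command (Jupyter/Colab).

from datasets import load_dataset

cifar10 = load_dataset("cifar10")
# Keep only 'airplane' (CIFAR-10 label 0) and 'bird' (CIFAR-10 label 2), remapped to 0 and 1.
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
# CIFAR-10's test split serves as the validation set.
cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]

# Take the first validation example: a PIL image and its 0/1 label.
example = cifar2_val[0]
img, label = example
```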
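To sanity-check the filtering, one can count the labels in each split. The helper `class_counts` below is hypothetical (not part of the original notebook) and is demonstrated on toy `(image, label)` pairs with placeholder strings instead of images. Because CIFAR-10 has 5,000 training and 1,000 test images per class, applying it to the real `cifar2_train` and `cifar2_val` should report 5,000 and 1,000 images per class, respectively.

```python
from collections import Counter

def class_counts(pairs, class_names):
    """Count how many (image, label) pairs fall into each class (hypothetical helper)."""
    counts = Counter(label for _, label in pairs)
    return {name: counts.get(i, 0) for i, name in enumerate(class_names)}

# Toy stand-in: CIFAR-10 labels 0 ('airplane'), 1, 2 ('bird'), 2, 5 with placeholder images.
label_map = {0: 0, 2: 1}
raw = [("img_a", 0), ("img_b", 1), ("img_c", 2), ("img_d", 2), ("img_e", 5)]
toy_cifar2 = [(img, label_map[label]) for img, label in raw if label in label_map]

print(class_counts(toy_cifar2, ['airplane', 'bird']))
# On the real splits: class_counts(cifar2_train, class_names), class_counts(cifar2_val, class_names)
```

A balanced split means a trivial "always predict one class" baseline scores 0.5, which is a useful reference point for the reported accuracy.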

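```python
import torch
from torchvision.transforms import v2

# Validation-time preprocessing: PIL image -> normalized float tensor of shape (3, 24, 24).
val_tfms = v2.Compose([
    v2.ToImage(),  # PIL image -> uint8 tensor of shape (3, 32, 32).
    v2.Resize(24),  # Downscale to (3, 24, 24), matching the validation setup.
    v2.ToDtype(torch.float32, scale=True),  # Scale pixel values to [0, 1].
    v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])  # Per-channel CIFAR-10 statistics.
])
img = val_tfms(img)
batch = img.reshape(-1).unsqueeze(0)  # Flatten to shape (1, 1728).
```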

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MLPForCIFAR2(nn.Module, PyTorchModelHubMixin):
    """Multi-layer perceptron (MLP) for classifying 'airplane' vs. 'bird' in the CIFAR-2 dataset (a subset of CIFAR-10)."""

    def __init__(self):
        super().__init__()
        # Tanh activations are assumed here, following the book's Chapter 7 MLPs;
        # the layer order must match the checkpoint's state dict for loading to succeed.
        self.mlp = nn.Sequential(
            nn.Linear(1728, 1024), # Hidden layer.
            nn.Tanh(),
            nn.Linear(1024, 512), # Hidden layer.
            nn.Tanh(),
            nn.Linear(512, 128), # Hidden layer.
            nn.Tanh(),
            nn.Linear(128, 2) # Output layer.
        )

    def forward(self, x):
        return self.mlp(x)

model = MLPForCIFAR2.from_pretrained("sadhaklal/mlp-cifar2")
```
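The layer sizes determine the model's size: each `nn.Linear(n_in, n_out)` holds an `n_in × n_out` weight matrix plus `n_out` biases, and activation layers have no parameters. The pure-Python arithmetic below (no download needed) totals the parameters of the 1728 → 1024 → 512 → 128 → 2 stack.

```python
# Widths of the MLP's linear layers, from input features to output logits.
layer_sizes = [1728, 1024, 512, 128, 2]

def linear_params(n_in, n_out):
    """Parameters of one nn.Linear(n_in, n_out): weights plus biases."""
    return n_in * n_out + n_out

per_layer = [linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)  # → [1770496, 524800, 65664, 258] 2361218
```

Roughly three quarters of the parameters sit in the first layer, because it connects every input pixel value to every hidden unit. With the model loaded, `sum(p.numel() for p in model.parameters())` should report the same total.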

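```python
import torch.nn.functional as F

with torch.no_grad():
    logits = model(batch)  # Shape (1, 2): one raw score per class.
    pred = logits[0].argmax().item()  # Index of the highest-scoring class.
    proba = F.softmax(logits, dim=1)  # Convert scores to probabilities that sum to 1.

print(f"Predicted class: {class_names[pred]}")
print(f"Predicted class probabilities ('airplane' vs. 'bird'): {proba[0].tolist()}")
```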
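With only two classes, the softmax probabilities depend only on the difference between the two logits: p(bird) = σ(z_bird − z_airplane), where σ is the logistic sigmoid. A small check with made-up logits (not actual model outputs):

```python
import torch
import torch.nn.functional as F

# Made-up logits for three images; columns are ('airplane', 'bird').
logits = torch.tensor([[2.0, 0.0], [0.5, 1.5], [-1.0, 0.0]])

proba = F.softmax(logits, dim=1)
# For two classes, softmax reduces to a sigmoid of the logit difference.
p_bird = torch.sigmoid(logits[:, 1] - logits[:, 0])

preds = logits.argmax(dim=1)  # 0 = 'airplane', 1 = 'bird'
print(proba, p_bird, preds)
```

Consequently, taking the argmax is the same as predicting 'bird' whenever p(bird) exceeds 0.5, and adding a constant to both logits leaves the probabilities unchanged.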


Accuracy on cifar2_val: 0.8090
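Accuracy is the fraction of validation images whose argmax prediction matches the label. Since `cifar2_val` holds 2,000 images (the 1,000 'airplane' and 1,000 'bird' images of the CIFAR-10 test split), 0.8090 corresponds to 1,618 correct predictions. A minimal evaluation loop, assuming preprocessed inputs of shape `(N, 1728)` and integer labels; `evaluate_accuracy` is a hypothetical helper, and the demo uses a stand-in linear model with fixed weights rather than the released checkpoint.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, inputs, labels, batch_size=64):
    """Fraction of samples whose argmax prediction equals the label (hypothetical helper)."""
    loader = DataLoader(TensorDataset(inputs, labels), batch_size=batch_size)
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
    return correct / len(labels)

# Stand-in model: logits = (sum of features, 0), so it predicts 'airplane' (0)
# when the features sum to a positive number and 'bird' (1) otherwise.
stand_in = nn.Linear(1728, 2, bias=False)
with torch.no_grad():
    stand_in.weight.zero_()
    stand_in.weight[0].fill_(1.0)

# Four synthetic "images": two all-positive, two all-negative, with labels 0, 0, 1, 0.
inputs = torch.cat([torch.ones(2, 1728), -torch.ones(2, 1728)])
labels = torch.tensor([0, 0, 1, 0])
acc = evaluate_accuracy(stand_in, inputs, labels, batch_size=3)
print(acc)  # → 0.75
```

With the real model, `inputs` would be the `val_tfms`-preprocessed, flattened `cifar2_val` images stacked into one tensor, and `labels` their 0/1 labels.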
