
mlp-cifar2

Multi-layer perceptron (MLP) trained on CIFAR-2 (a subset of CIFAR-10 for classifying 'airplane' vs. 'bird').

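During training, RandomCrop(24) was applied to the training images as data augmentation. For validation, Resize(24) was applied instead. Both produce 24×24 RGB images, i.e. 3 × 24 × 24 = 1,728 input features for the MLP's first linear layer.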

This model pertains to Exercise 1 of Chapter 7 of the book "Deep Learning with PyTorch" by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_7/exercise_1.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/mlp-cifar2

Usage

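!pip install -q datasets huggingface_hub torch "torchvision>=0.16" # v2.ToImage and ToDtype(scale=True) need torchvision 0.16+; drop the "!" outside a notebook.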

from datasets import load_dataset

cifar10 = load_dataset("cifar10")
label_map = {0: 0, 2: 1} # CIFAR-10 label -> CIFAR-2 label (airplane -> 0, bird -> 1).
class_names = ['airplane', 'bird']
cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]

example = cifar2_val[0]
img, label = example

import torch
from torchvision.transforms import v2

val_tfms = v2.Compose([
    v2.Resize(24),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
])
img = val_tfms(img)
batch = img.reshape(-1).unsqueeze(0) # Flatten (3, 24, 24) -> (1728,), then add a batch dimension: (1, 1728).

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MLPForCIFAR2(nn.Module, PyTorchModelHubMixin):
    """Multi-layer perceptron (MLP) for classifying 'airplane' vs. 'bird' in the CIFAR-2 dataset (a subset of CIFAR-10)."""

    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1728, 1024), # Hidden layer.
            nn.Tanh(),
            nn.Linear(1024, 512), # Hidden layer.
            nn.Tanh(),
            nn.Linear(512, 128), # Hidden layer.
            nn.Tanh(),
            nn.Linear(128, 2) # Output layer.
        )

    def forward(self, x):
        return self.mlp(x)

model = MLPForCIFAR2.from_pretrained("sadhaklal/mlp-cifar2")
model.eval()

import torch.nn.functional as F

with torch.no_grad():
    logits = model(batch)
    pred = logits[0].argmax().item()
    proba = F.softmax(logits, dim=1)

print(f"Predicted class: {class_names[pred]}")
print(f"Predicted class probabilities ('airplane' vs. 'bird'): {proba[0].tolist()}")
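The script above classifies a single image. To classify several images at once, stack the transformed images and flatten everything after the batch dimension. The sketch below wraps this in a hypothetical `predict_batch` helper, which is not part of the model's API. Because the checkpoint would need a download, it uses an offline stand-in with the same layer stack as `MLPForCIFAR2` but random weights. The stand-in has 2,361,218 parameters, about three quarters of them in the first `Linear(1728, 1024)` layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class_names = ['airplane', 'bird']

def predict_batch(model: nn.Module, images: list[torch.Tensor]):
    """Hypothetical helper: classify images already passed through val_tfms.

    Each image is a (3, 24, 24) tensor; returns (class name, probabilities) per image.
    """
    batch = torch.stack(images).flatten(start_dim=1)  # (N, 3, 24, 24) -> (N, 1728)
    model.eval()
    with torch.no_grad():
        logits = model(batch)             # (N, 2) raw scores
        proba = F.softmax(logits, dim=1)  # each row sums to 1
    preds = proba.argmax(dim=1)           # same as logits.argmax: softmax preserves order
    return [(class_names[i], p.tolist()) for i, p in zip(preds.tolist(), proba)]

# Offline stand-in: same layers as MLPForCIFAR2.mlp, randomly initialized.
torch.manual_seed(0)
stand_in = nn.Sequential(
    nn.Linear(1728, 1024), nn.Tanh(),
    nn.Linear(1024, 512), nn.Tanh(),
    nn.Linear(512, 128), nn.Tanh(),
    nn.Linear(128, 2),
)
n_params = sum(p.numel() for p in stand_in.parameters())
print(n_params)  # 2361218

fake_images = [torch.randn(3, 24, 24) for _ in range(4)]  # stand-ins for val_tfms output
results = predict_batch(stand_in, fake_images)
for name, probs in results:
    print(name, probs)
```

With the real model, the call would be `predict_batch(model, [val_tfms(img) for img, _ in cifar2_val[:8]])`. `flatten(start_dim=1)` uses the same channel-major order as `reshape(-1)` on a single image, so each row matches what the single-image path feeds the model.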

Metric

Accuracy on cifar2_val (the 2,000 airplane and bird images from the CIFAR-10 test split): 0.8090
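Accuracy is the fraction of validation images whose highest-scoring logit matches the label. Over 2,000 images, 0.8090 corresponds to 1,618 correct predictions. The sketch below shows one way to compute this metric in mini-batches. It does not reproduce the author's evaluation code. `evaluate_accuracy` is a hypothetical helper. The toy model, which always predicts class 0, is an assumption used only to check the arithmetic offline.

```python
import torch
import torch.nn as nn

def evaluate_accuracy(model: nn.Module, inputs: torch.Tensor, labels: torch.Tensor,
                      batch_size: int = 256) -> float:
    """Hypothetical helper: fraction of rows in `inputs` whose argmax logit equals `labels`."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            logits = model(inputs[start:start + batch_size])
            correct += (logits.argmax(dim=1) == labels[start:start + batch_size]).sum().item()
    return correct / len(inputs)

# Toy check: zero weights and bias [1, 0] make every prediction class 0 (airplane).
toy_model = nn.Linear(1728, 2)
with torch.no_grad():
    toy_model.weight.zero_()
    toy_model.bias.copy_(torch.tensor([1.0, 0.0]))
toy_inputs = torch.randn(4, 1728)
toy_labels = torch.tensor([0, 0, 1, 0])
print(evaluate_accuracy(toy_model, toy_inputs, toy_labels))  # 0.75 (3 of 4 correct)

# With the objects from the Usage section (requires the dataset and weights):
# inputs = torch.stack([val_tfms(img).reshape(-1) for img, _ in cifar2_val])
# labels = torch.tensor([label for _, label in cifar2_val])
# print(evaluate_accuracy(model, inputs, labels))
```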
