mlp-cifar2-v2

A multi-layer perceptron (MLP) trained on CIFAR-2, a subset of CIFAR-10 restricted to the 'airplane' and 'bird' classes, for binary classification.
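
For a sense of scale: the MLPForCIFAR2 class in the Usage section flattens each 3 × 32 × 32 image into 3072 features and maps them through a 64-unit Tanh hidden layer to a single output. That gives 3072·64 + 64 + 64·1 + 1 = 196,737 trainable parameters. The sketch below checks the count on a plain nn.Sequential replica of the same layers (the Hub mixin is not needed just to count parameters).

```python
import torch
import torch.nn as nn

# Replica of the layers inside MLPForCIFAR2 (without PyTorchModelHubMixin).
mlp = nn.Sequential(
    nn.Linear(3 * 32 * 32, 64),  # Hidden layer: 3072 -> 64.
    nn.Tanh(),
    nn.Linear(64, 1),            # Output layer: one logit.
)

n_params = sum(p.numel() for p in mlp.parameters() if p.requires_grad)
print(n_params)  # 196737

# One flattened 3x32x32 image has shape (1, 3072), matching the first layer's in_features.
x = torch.zeros(3, 32, 32).reshape(-1).unsqueeze(0)
print(x.shape, mlp(x).shape)  # torch.Size([1, 3072]) torch.Size([1, 1])
```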

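The model was trained with nn.BCEWithLogitsLoss, so it outputs a single raw logit per image; applying a sigmoid to that logit gives the probability of the 'bird' class (label 1.0).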
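
As a rough illustration of what this loss computes: nn.BCEWithLogitsLoss fuses a sigmoid with binary cross-entropy in a numerically stable way, and it expects float targets. That requirement is presumably why label_map in the Usage code maps the labels to 0.0 and 1.0. The sketch below uses made-up logits and targets (not values from this model) and checks that the fused loss matches nn.BCELoss applied to sigmoid outputs, as well as the explicit formula.

```python
import torch
import torch.nn as nn

# Made-up logits for a batch of 4 images (shape: batch x 1), as the MLP would output.
logits = torch.tensor([[2.0], [-1.5], [0.3], [-0.7]])
# Float targets: 0.0 = 'airplane', 1.0 = 'bird' (same encoding as label_map).
targets = torch.tensor([[1.0], [0.0], [0.0], [1.0]])

# Fused, numerically stable version (the loss used for training).
fused = nn.BCEWithLogitsLoss()(logits, targets)

# Unfused equivalent: sigmoid first, then binary cross-entropy.
unfused = nn.BCELoss()(torch.sigmoid(logits), targets)

# Explicit formula: mean of -[y * log(p) + (1 - y) * log(1 - p)].
p = torch.sigmoid(logits)
manual = -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()

print(fused.item(), unfused.item(), manual.item())
```

In practice the fused version is preferred because it avoids taking the log of probabilities that have saturated to exactly 0 or 1.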

This model pertains to Exercise 2 of Chapter 7 of the book "Deep Learning with PyTorch" by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_7/exercise_2.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/mlp-cifar2-v2

Usage

!pip install -q datasets  # Notebook syntax; torch and torchvision (0.16 or later, for v2.ToImage) must also be available.

from datasets import load_dataset

cifar10 = load_dataset("cifar10")
label_map = {0: 0.0, 2: 1.0}
class_names = ['airplane', 'bird']
cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]

example = cifar2_val[0]
img, label = example

import torch
from torchvision.transforms import v2

val_tfms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
])
img = val_tfms(img)
batch = img.reshape(-1).unsqueeze(0) # Flatten.

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MLPForCIFAR2(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(3072, 64), # Hidden layer.
            nn.Tanh(),
            nn.Linear(64, 1) # Output layer.
        )

    def forward(self, x):
        return self.mlp(x)

model = MLPForCIFAR2.from_pretrained("sadhaklal/mlp-cifar2-v2")
model.eval()

with torch.no_grad():
    logits = model(batch)  # Shape (1, 1): one raw logit for the single image.
    proba = torch.sigmoid(logits.squeeze())  # P('bird'), because 'bird' was mapped to label 1.0.
    pred = int(proba.item() > 0.5)  # 0 -> 'airplane', 1 -> 'bird'.

print(f"Predicted class: {class_names[pred]}")
print(f"Predicted class probabilities ('airplane' vs. 'bird'): {[1 - proba.item(), proba.item()]}")
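
Because the sigmoid is monotonic and sigmoid(0) = 0.5, thresholding the probability at 0.5 is the same as checking whether the logit is positive. The sketch below extends the single-image prediction to a batch. To run on its own it uses a hypothetical stand-in: an untrained nn.Sequential with the same layers as MLPForCIFAR2 and random input tensors, so its predictions are meaningless; only the shapes and the decision rule carry over.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
class_names = ['airplane', 'bird']

# Hypothetical stand-in for the loaded MLPForCIFAR2 model: same layers,
# random (untrained) weights.
model = nn.Sequential(nn.Linear(3072, 64), nn.Tanh(), nn.Linear(64, 1))
model.eval()

# Random stand-in for 5 preprocessed, flattened images (shape: 5 x 3072).
batch = torch.randn(5, 3072)

with torch.no_grad():
    logits = model(batch).squeeze(1)  # Shape (5,): one logit per image.
    p_bird = torch.sigmoid(logits)    # P('bird') for each image.

preds = (p_bird > 0.5).long()         # 0 -> 'airplane', 1 -> 'bird'.
# Column 0 is P('airplane'), column 1 is P('bird').
probs = torch.stack([1 - p_bird, p_bird], dim=1)

for pred, row in zip(preds.tolist(), probs.tolist()):
    print(class_names[pred], row)
```

With the real model, the stand-in would be replaced by the `model` loaded above, and `batch` by a stack of flattened, val_tfms-transformed images.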

Metric

Accuracy on cifar2_val (the 2,000 'airplane' and 'bird' images from the CIFAR-10 test split): 0.829
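
The card does not show how this figure was computed. Presumably it is the fraction of cifar2_val images whose thresholded prediction matches the label. Below is a minimal sketch of such an evaluation, intended to be called as `evaluate_accuracy(model, cifar2_val, val_tfms)` with the objects from the Usage section. The helper name `evaluate_accuracy` and its batch size are hypothetical choices, and the lines at the bottom feed it a random stand-in model and random tensors so the sketch runs without downloading anything.

```python
import torch
import torch.nn as nn

def evaluate_accuracy(model, samples, tfms, batch_size=256):
    """Hypothetical helper: fraction of (img, label) pairs classified correctly.

    `tfms` maps one image to a (3, 32, 32) float tensor (e.g. val_tfms);
    labels are 0.0 ('airplane') or 1.0 ('bird'), as produced by label_map.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for start in range(0, len(samples), batch_size):
            chunk = samples[start:start + batch_size]
            # Transform and flatten each image, then stack into shape (B, 3072).
            x = torch.stack([tfms(img).reshape(-1) for img, _ in chunk])
            y = torch.tensor([label for _, label in chunk])
            logits = model(x).squeeze(1)
            preds = (logits > 0).float()  # Same decision as sigmoid(logits) > 0.5.
            correct += (preds == y).sum().item()
    return correct / len(samples)

# Stand-ins so the sketch runs without CIFAR-10 or the pretrained weights.
torch.manual_seed(0)
stand_in_model = nn.Sequential(nn.Linear(3072, 64), nn.Tanh(), nn.Linear(64, 1))
fake_samples = [(torch.randn(3, 32, 32), float(i % 2)) for i in range(10)]
acc = evaluate_accuracy(stand_in_model, fake_samples, tfms=lambda t: t, batch_size=4)
print(f"Accuracy: {acc:.3f}")
```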
