---
datasets:
- cifar10
metrics:
- accuracy
pipeline_tag: image-classification
library_name: pytorch
---

# mlp-cifar2

Multi-layer perceptron (MLP) trained on CIFAR-2 (a subset of CIFAR-10 for classifying 'airplane' vs. 'bird').

`RandomCrop(24)` was applied to the training set images, and `Resize(24)` was applied to the validation set images.

This model pertains to Exercise 1 of Chapter 7 of the book "Deep Learning with PyTorch" by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_7/exercise_1.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/mlp-cifar2

## Usage

```
!pip install -q datasets

from datasets import load_dataset

cifar10 = load_dataset("cifar10")

label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]

example = cifar2_val[0]
img, label = example

import torch
from torchvision.transforms import v2

val_tfms = v2.Compose([
    v2.Resize(24),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
])
img = val_tfms(img)
batch = img.reshape(-1).unsqueeze(0)  # Flatten.

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MLPForCIFAR2(nn.Module, PyTorchModelHubMixin):
    """Multi-layer perceptron (MLP) for classifying 'airplane' vs. 'bird' in the CIFAR-2 dataset (a subset of CIFAR-10)."""

    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1728, 1024),  # Hidden layer (24 * 24 * 3 = 1728 input features).
            nn.Tanh(),
            nn.Linear(1024, 512),  # Hidden layer.
            nn.Tanh(),
            nn.Linear(512, 128),  # Hidden layer.
            nn.Tanh(),
            nn.Linear(128, 2)  # Output layer.
        )

    def forward(self, x):
        return self.mlp(x)

model = MLPForCIFAR2.from_pretrained("sadhaklal/mlp-cifar2")
model.eval()

import torch.nn.functional as F

with torch.no_grad():
    logits = model(batch)

pred = logits[0].argmax().item()
proba = F.softmax(logits, dim=1)
print(f"Predicted class: {class_names[pred]}")
print(f"Predicted class probabilities ('airplane' vs. 'bird'): {proba[0].tolist()}")
```

## Metric

Accuracy on `cifar2_val`: 0.8090
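
The reported accuracy can be checked with a simple evaluation loop like the one below. This is a minimal sketch, not the original evaluation code: it reuses `cifar2_val`, `val_tfms`, and `model` from the "Usage" snippet above and scores one image at a time rather than in batches.

```
import torch

# Assumes `cifar2_val`, `val_tfms`, and `model` are defined as in the "Usage" snippet above.
correct = 0
total = 0
with torch.no_grad():
    for img, label in cifar2_val:
        x = val_tfms(img).reshape(-1).unsqueeze(0)  # Same preprocessing and flattening as above.
        pred = model(x)[0].argmax().item()
        correct += int(pred == label)
        total += 1

print(f"Accuracy on cifar2_val: {correct / total:.4f}")
```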