MNIST CNN Classifier

This repository contains a convolutional neural network (CNN) that classifies MNIST handwritten digits. It was trained with dlab, and the published checkpoint was selected on validation loss.

Architecture

[Figure: MNIST CNN architecture]

Results

Held-out test evaluation over 5 seeds (mean ± sample standard deviation; pp = percentage points):

metric                  value
test accuracy           99.6140% ± 0.0802 pp
test loss               0.14677 ± 0.00272
best validation loss    0.14282 ± 0.00138

Per-seed held-out test results:

seed  W&B run   test accuracy  test loss  best validation loss
1     5um57rnu  99.6100%       0.14772    0.14124
2     23f1frqb  99.6800%       0.14437    0.14317
3     25yaaj1o  99.4800%       0.15107    0.14403
4     3rrlxghp  99.6300%       0.14573    0.14150
5     y51200ov  99.6700%       0.14494    0.14418
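The losses above sit near 0.14 even though accuracy is about 99.6%. One plausible explanation is that smoothed cross-entropy cannot reach zero. This assumes the logged losses include the 0.02 label smoothing listed under Model Details and follow PyTorch's `CrossEntropyLoss(label_smoothing=...)` convention, which mixes the one-hot target with a uniform distribution over all 10 classes. The sketch below computes the resulting floor, which comes out at roughly 0.13. Under that assumption, most of each reported loss is this floor rather than misclassification.

```python
import math

NUM_CLASSES = 10
EPSILON = 0.02  # label smoothing value from Model Details

# Assumed PyTorch-style smoothing: target = (1 - eps) * one_hot + eps / K,
# so the true class gets 1 - eps + eps/K and each other class gets eps/K.
p_true = 1.0 - EPSILON + EPSILON / NUM_CLASSES
p_other = EPSILON / NUM_CLASSES

# Cross-entropy H(q, p) = H(q) + KL(q || p) is minimized when the prediction p
# equals the smoothed target q, where it equals the entropy H(q).
floor = -(p_true * math.log(p_true) + (NUM_CLASSES - 1) * p_other * math.log(p_other))
print(f"minimum achievable smoothed loss: {floor:.4f}")
```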

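The ONNX model was exported from the seed-1 checkpoint, which had the lowest best validation loss (0.14124) of the five seeds in the final evaluation sweep. Selection used validation loss only, so the exported model is not the seed with the highest test accuracy (seed 2). Test metrics were logged in W&B sweep ikfs5ox8 but played no part in checkpoint selection.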
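The selection rule and the summary row can be recomputed from the per-seed table. This sketch works from the rounded per-seed values. It assumes the "±" figures are sample standard deviations (n - 1 denominator, as in Python's `statistics.stdev`), which reproduces the reported test accuracy and test loss summaries. The `RESULTS` mapping is just a hand copy of the table, not a file shipped with the model.

```python
import statistics

# Hand copy of the per-seed table:
# seed -> (W&B run, test accuracy %, test loss, best validation loss)
RESULTS = {
    1: ("5um57rnu", 99.61, 0.14772, 0.14124),
    2: ("23f1frqb", 99.68, 0.14437, 0.14317),
    3: ("25yaaj1o", 99.48, 0.15107, 0.14403),
    4: ("3rrlxghp", 99.63, 0.14573, 0.14150),
    5: ("y51200ov", 99.67, 0.14494, 0.14418),
}

# Select by validation loss only; test metrics are never consulted here.
selected_seed = min(RESULTS, key=lambda seed: RESULTS[seed][3])

accuracies = [row[1] for row in RESULTS.values()]
test_losses = [row[2] for row in RESULTS.values()]

# statistics.stdev is the sample standard deviation (n - 1 denominator).
acc_mean, acc_std = statistics.mean(accuracies), statistics.stdev(accuracies)
loss_mean, loss_std = statistics.mean(test_losses), statistics.stdev(test_losses)

print(f"selected seed: {selected_seed} ({RESULTS[selected_seed][0]})")
print(f"test accuracy: {acc_mean:.4f}% ± {acc_std:.4f} pp")
print(f"test loss: {loss_mean:.5f} ± {loss_std:.5f}")
```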

Model Details

  • Dataset: MNIST
  • Architecture: CNN
  • Channels: [32, 64, 128]
  • Convolutions per stage: 2
  • Batch normalization: enabled
  • Dropout: 0.1
  • Optimizer: Adam
  • Learning rate: 0.001
  • Weight decay: 0.0001
  • Scheduler: OneCycleLR
  • Label smoothing: 0.02
  • Weight averaging: EMA
  • Batch size: 512
  • Training augmentation: random affine rotation/translation/scale
  • Early stopping: validation loss, patience 8, min delta 0.0005
  • Source W&B run: 5um57rnu
  • Source W&B sweep: ikfs5ox8
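The early-stopping setting can be sketched in plain Python. This assumes PyTorch Lightning-style `EarlyStopping` semantics. Under those semantics, an epoch counts as an improvement only if its validation loss beats the best so far by more than `min_delta`, and training stops once `patience` consecutive epochs pass without one. The helper `stopping_epoch` and the loss curve below are hypothetical and are not taken from the training code.

```python
from typing import Optional

# Values from the Model Details list above.
PATIENCE = 8
MIN_DELTA = 0.0005


def stopping_epoch(val_losses: list[float]) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    An epoch is an improvement only if its validation loss is lower than the
    best loss so far by more than MIN_DELTA.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - MIN_DELTA:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= PATIENCE:
                return epoch
    return None


# Hypothetical validation curve: after epoch 3 every gain is smaller than MIN_DELTA.
history = [0.30, 0.20, 0.16, 0.150] + [0.1498] * 8
print(stopping_epoch(history))
```

With this curve, none of the eight epochs at 0.1498 beats 0.150 by more than 0.0005, so training stops at epoch 11.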

Input / Output

Use model.onnx for inference that does not depend on the training code. Its input and output tensors are:

  • Input name: images
  • Input shape: [batch, 1, 28, 28]
  • Input dtype: float32
  • Output name: logits
  • Output shape: [batch, 10]

Preprocessing:

  • Convert image to grayscale.
  • Resize to 28 x 28.
  • Scale pixel values to [0, 1].
  • Normalize with mean 0.1307 and standard deviation 0.3081.
  • Arrange the tensor as channels-first [batch, 1, 28, 28].
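The preprocessing steps can be wrapped in a small NumPy helper that also builds batches, since the input shape lists a symbolic `batch` dimension. This sketch starts from an image that has already been converted to 28 x 28 grayscale `uint8`, for example with Pillow. `preprocess` and its `invert` flag are hypothetical names, not part of the documented pipeline. The flag exists because MNIST digits are light strokes on a dark background, so dark-on-light images may need inverting first.

```python
import numpy as np

# Normalization constants from the preprocessing steps above.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess(gray: np.ndarray, invert: bool = False) -> np.ndarray:
    """Turn one 28 x 28 grayscale uint8 image into a [1, 1, 28, 28] float32 tensor.

    `invert` is a hypothetical convenience flag for dark-on-light images.
    """
    if gray.shape != (28, 28):
        raise ValueError(f"expected a 28 x 28 image, got {gray.shape}")
    x = gray.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    if invert:
        x = 1.0 - x  # light-on-dark, like MNIST
    x = (x - MNIST_MEAN) / MNIST_STD  # normalize with the MNIST mean and std
    return x[np.newaxis, np.newaxis, :, :].astype(np.float32)  # channels-first


# Several images can share one forward pass by concatenating along the batch axis.
images = [np.zeros((28, 28), dtype=np.uint8), np.full((28, 28), 255, dtype=np.uint8)]
batch = np.concatenate([preprocess(img) for img in images], axis=0)
print(batch.shape, batch.dtype)
```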

Usage

Install the runtime dependencies:

pip install huggingface_hub onnxruntime pillow numpy

Run inference with the ONNX model:

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

LABELS = {
    0: "0",
    1: "1",
    2: "2",
    3: "3",
    4: "4",
    5: "5",
    6: "6",
    7: "7",
    8: "8",
    9: "9",
}

model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-cnn",
    filename="model.onnx",
)

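# Load the image (example.png is a placeholder path) as 8-bit grayscale at 28 x 28.
image = Image.open("example.png").convert("L").resize((28, 28))
# Scale to [0, 1], then normalize with the MNIST mean and standard deviation.
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
# Add batch and channel axes to get a channels-first [1, 1, 28, 28] tensor.
x = x[None, None, :, :].astype(np.float32)

# Run the model on CPU and take the highest-scoring class.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])

print(prediction, LABELS[prediction])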
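The example above keeps only the arg-max. To see how confident the model is, apply a softmax to the logits and list the top candidates. This is a minimal NumPy sketch. `softmax` and `top_k` are illustrative helpers, and the logits are made-up values standing in for the output of `session.run`.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over the class axis of a [batch, 10] logits array."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_k(logits: np.ndarray, k: int = 3) -> list[list[tuple[int, float]]]:
    """Return the k most likely digits and their probabilities for each row."""
    probs = softmax(logits)
    order = np.argsort(-probs, axis=1)[:, :k]  # class indices by descending probability
    return [[(int(i), float(row[i])) for i in idx] for row, idx in zip(probs, order)]


# Made-up logits standing in for session.run(["logits"], {"images": x})[0].
fake_logits = np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 1.0, 0.0]], dtype=np.float32)
for digit, prob in top_k(fake_logits, k=2)[0]:
    print(f"{digit}: {prob:.3f}")
```

Class ids map directly to digit labels, so the index is the digit. Softmax scores are not guaranteed to be calibrated probabilities, so treat them as relative confidence unless calibration has been checked.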

Labels

MNIST labels:

id label
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9

Files

  • model.onnx: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
  • model.ckpt: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
  • config.yaml: resolved Hydra training config.
  • metrics.csv: training metrics from the uploaded checkpoint run.
  • metrics_summary.csv: compact 5-seed final evaluation summary.
  • metadata.json: compact metadata for inference and provenance.

Limitations

This compact CNN is close to the accuracy ceiling for MNIST. Remaining errors are expected to be rare, and they typically involve visually ambiguous or unusually written digits.
