Dataset: ylecun/mnist
This repository contains a validation-selected MNIST CNN digit classifier trained with dlab.
5-seed held-out test evaluation:
| metric | value |
|---|---|
| test accuracy | 99.6140% ± 0.0802 pp |
| test loss | 0.14677 ± 0.00272 |
| best validation loss | 0.14282 ± 0.00138 |
Per-seed held-out test results:

| seed | W&B run | test accuracy | test loss | best validation loss |
|---|---|---|---|---|
| 1 | 5um57rnu | 99.6100% | 0.14772 | 0.14124 |
| 2 | 23f1frqb | 99.6800% | 0.14437 | 0.14317 |
| 3 | 25yaaj1o | 99.4800% | 0.15107 | 0.14403 |
| 4 | 3rrlxghp | 99.6300% | 0.14573 | 0.14150 |
| 5 | y51200ov | 99.6700% | 0.14494 | 0.14418 |
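The summary row above can be recomputed from the per-seed results. The following is a minimal sketch using Python's standard `statistics` module. It assumes the "±" values are sample standard deviations (n − 1 denominator), which reproduces the reported accuracy and test-loss spreads. Because the per-seed values are rounded, the recomputed validation-loss spread can differ from the summary in the last digit. `PER_SEED` and `summarize` are hypothetical names used only for illustration.

```python
from statistics import mean, stdev

# Per-seed held-out results copied from the table above:
# seed -> (test accuracy in %, test loss, best validation loss)
PER_SEED = {
    1: (99.61, 0.14772, 0.14124),
    2: (99.68, 0.14437, 0.14317),
    3: (99.48, 0.15107, 0.14403),
    4: (99.63, 0.14573, 0.14150),
    5: (99.67, 0.14494, 0.14418),
}


def summarize(column: int) -> tuple[float, float]:
    """Mean and sample standard deviation (n - 1 denominator) of one metric column."""
    values = [row[column] for row in PER_SEED.values()]
    return mean(values), stdev(values)


acc_mean, acc_std = summarize(0)
loss_mean, loss_std = summarize(1)
val_mean, val_std = summarize(2)

print(f"test accuracy        {acc_mean:.4f}% ± {acc_std:.4f} pp")
print(f"test loss            {loss_mean:.5f} ± {loss_std:.5f}")
print(f"best validation loss {val_mean:.5f} ± {val_std:.5f}")
```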
The ONNX model was exported from the seed-1 checkpoint, which had the best validation loss in the final 5-seed evaluation sweep. Test metrics were not used for checkpoint selection and were logged in W&B sweep ikfs5ox8.
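The selection rule described above can be sketched as follows: pick the seed with the lowest best-validation loss, and only afterwards look up its test metrics. `BEST_VAL_LOSS`, `TEST_ACCURACY`, and `select_seed` are hypothetical names holding values copied from the per-seed table. This is an illustration of the rule, not the dlab selection code.

```python
# Best validation loss per seed, copied from the per-seed table.
BEST_VAL_LOSS = {1: 0.14124, 2: 0.14317, 3: 0.14403, 4: 0.14150, 5: 0.14418}
# Held-out test accuracy per seed (%); deliberately not consulted for selection.
TEST_ACCURACY = {1: 99.61, 2: 99.68, 3: 99.48, 4: 99.63, 5: 99.67}


def select_seed(val_loss: dict[int, float]) -> int:
    """Return the seed whose checkpoint reached the lowest validation loss."""
    return min(val_loss, key=val_loss.__getitem__)


selected = select_seed(BEST_VAL_LOSS)
print(
    f"selected seed {selected}: validation loss {BEST_VAL_LOSS[selected]}, "
    f"test accuracy {TEST_ACCURACY[selected]}%"
)
```

Seed 2 has the highest test accuracy (99.68%) but is not chosen, because test results play no part in the choice. This keeps the reported test metrics an unbiased estimate for the selected checkpoint.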
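[32, 64, 128]20.10.0010.00010.025128, min delta 0.00055um57rnuikfs5ox8

Use model.onnx for code-independent inference.

| tensor | name | shape | dtype |
|---|---|---|---|
| input | images | [batch, 1, 28, 28] | float32 |
| output | logits | [batch, 10] | |

Preprocessing:

- Convert the image to grayscale and resize it to 28 x 28.
- Scale pixel values to [0, 1].
- Normalize with mean 0.1307 and standard deviation 0.3081.
- Reshape the input to [batch, 1, 28, 28].

Install the runtime dependencies: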
```bash
pip install huggingface_hub onnxruntime pillow numpy
```
Run inference with the ONNX model:
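```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Class index -> digit label (MNIST labels are the digits themselves).
LABELS = {
    0: "0",
    1: "1",
    2: "2",
    3: "3",
    4: "4",
    5: "5",
    6: "6",
    7: "7",
    8: "8",
    9: "9",
}

# Download the exported ONNX model from the Hugging Face Hub (cached locally).
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-cnn",
    filename="model.onnx",
)

# Preprocess: grayscale, 28 x 28, scale to [0, 1], then normalize.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
# Add batch and channel axes: [28, 28] -> [1, 1, 28, 28].
x = x[None, None, :, :].astype(np.float32)

# Run the model on CPU and take the highest-scoring class.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])
print(prediction, LABELS[prediction])
```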
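The example above scores a single image. Because the model accepts a `[batch, 1, 28, 28]` input, several images can be scored in one call. The following is a minimal NumPy sketch. It assumes each image is already a 28 x 28 grayscale array, for example the output of `Image.open(...).convert("L").resize((28, 28))`. `preprocess_batch` and `softmax` are hypothetical helpers. The softmax step assumes the `logits` output holds unnormalized class scores, as its name suggests; the model itself does not output probabilities.

```python
import numpy as np

MEAN, STD = 0.1307, 0.3081  # normalization constants from the preprocessing steps


def preprocess_batch(images: list[np.ndarray]) -> np.ndarray:
    """Stack 28x28 grayscale images (0-255) into a normalized [batch, 1, 28, 28] float32 array."""
    batch = np.stack([np.asarray(img, dtype=np.float32) for img in images])  # [batch, H, W]
    if batch.shape[1:] != (28, 28):
        raise ValueError(f"expected 28x28 images, got {batch.shape[1:]}")
    batch = (batch / 255.0 - MEAN) / STD  # scale to [0, 1], then normalize
    return batch[:, None, :, :].astype(np.float32)  # insert the channel axis


def softmax(logits: np.ndarray) -> np.ndarray:
    """Convert [batch, 10] logits to per-class probabilities (max-shifted for stability)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)
```

With the `session` from the example above, `probs = softmax(session.run(["logits"], {"images": preprocess_batch(arrays)})[0])` yields one row of 10 probabilities per image, and `probs.argmax(axis=1)` gives the predicted digits. Note that `argmax` on the raw logits gives the same predictions, so softmax is only needed when calibrated-looking scores are wanted.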
MNIST labels:
| id | label |
|---|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |
- model.onnx: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- model.ckpt: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
- config.yaml: resolved Hydra training config.
- metrics.csv: training metrics from the uploaded checkpoint run.
- metrics_summary.csv: compact 5-seed final evaluation summary.
- metadata.json: compact metadata for inference and provenance.

This compact CNN is near MNIST saturation. Remaining errors are expected to be rare and to involve mostly visually ambiguous or unusually written digits.