Model Card for resnet_mnist_digits

This model is a Residual Neural Network (ResNet) for classifying handwritten digits in the MNIST dataset. It has 1.35M parameters and achieves 99.04% accuracy on the MNIST test set (i.e., on digits not seen during training).
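
To put that accuracy in absolute terms, the short sketch below converts it to a count of misclassified digits. It assumes the standard MNIST test split of 10,000 images; the variable names are illustrative and not part of the model's code.

```python
# Assumption: the standard MNIST test split contains 10,000 images.
TEST_SET_SIZE = 10_000
ACCURACY_PERCENT = 99.04  # test accuracy reported in this card

# Convert the percentage into whole-image counts.
correct = round(TEST_SET_SIZE * ACCURACY_PERCENT / 100)
misclassified = TEST_SET_SIZE - correct

print(f"correct: {correct}, misclassified: {misclassified}")
```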

Model Details

Model Description

This model takes as input a 224x224 array holding an MNIST digit (upscaled from the original 28x28 image), with pixel values normalized to [0, 1]. The 224x224 input size is intended to allow comparison with 224x224 vision transformers. The model was trained using Keras on an Nvidia Ampere A100.

  • Developed by: Phillip Allen Lane
  • Model type: ResNet
  • License: afl-3.0

How to Get Started with the Model

Use the code below to get started with the model.

import numpy as np
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import get_file, to_categorical

# load the MNIST dataset test images and labels
(_, _), (test_images, test_labels) = mnist.load_data()

# normalize the images to [0, 1]
test_images = test_images.astype('float32') / 255
# add a channel axis and repeat it three times: (N, 28, 28) -> (N, 28, 28, 3)
test_images = np.expand_dims(test_images, axis=-1)
test_images = np.repeat(test_images, 3, axis=-1)
# upscale to 224x224: (N, 28, 28, 3) -> (N, 224, 224, 3)
test_images = tf.image.resize(test_images, [224, 224]).numpy()

# create one-hot labels
test_labels_onehot = to_categorical(test_labels)

# download the model
model_path = get_file('/path/to/large_resnet_mnist.hdf5', 'https://huggingface.co/lane99/resnet_mnist_digits_highres/resolve/main/large-resnet-mnist.hdf5')
# import the model
resnet = models.load_model(model_path)

# evaluate the model
evaluation_conv = resnet.evaluate(test_images[...,0], test_labels_onehot)
print("Accuracy: ", str(evaluation_conv[1]))
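
The preprocessing above upscales every test image to 224x224 with three channels before evaluation, which takes a fair amount of memory. The following back-of-the-envelope sketch estimates the float32 array sizes involved. It assumes the standard 10,000-image MNIST test split, and the helper name `array_bytes` is made up for illustration.

```python
# Assumption: the standard MNIST test split contains 10,000 images.
NUM_IMAGES = 10_000
HEIGHT, WIDTH = 224, 224
BYTES_PER_FLOAT32 = 4


def array_bytes(channels: int) -> int:
    """Bytes in a float32 array of shape (NUM_IMAGES, HEIGHT, WIDTH, channels)."""
    return NUM_IMAGES * HEIGHT * WIDTH * channels * BYTES_PER_FLOAT32


three_channel = array_bytes(3)   # full array after np.repeat and tf.image.resize
single_channel = array_bytes(1)  # data in the test_images[..., 0] slice

print(f"3-channel array: {three_channel / 2**30:.2f} GiB")
print(f"1-channel slice: {single_channel / 2**30:.2f} GiB")
```

Under these assumptions the three-channel array is about 5.6 GiB and the single-channel data about 1.9 GiB, so on a memory-constrained machine it may help to preprocess and evaluate the test set in smaller chunks.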

Training Details

Training Data

This model was trained on the 60,000 entries in the MNIST training dataset.

Training Procedure

This model was trained with a 0.1 validation split for 10 epochs using a batch size of 128.
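
As a rough illustration of what these settings imply, the sketch below works out the sample counts and the number of optimizer steps. It assumes the 0.1 validation split is held out of the 60,000 training entries (as Keras does with `validation_split`) and that the final partial batch of each epoch is still used; neither detail is stated in this card.

```python
import math

TRAIN_SET_SIZE = 60_000   # MNIST training entries
VALIDATION_SPLIT = 0.1
BATCH_SIZE = 128
EPOCHS = 10

# Assumption: the validation samples are held out of the 60,000 entries.
val_samples = round(TRAIN_SET_SIZE * VALIDATION_SPLIT)
train_samples = TRAIN_SET_SIZE - val_samples

# Assumption: the last, smaller batch of each epoch is not dropped.
steps_per_epoch = math.ceil(train_samples / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS

print(f"training samples:   {train_samples}")
print(f"validation samples: {val_samples}")
print(f"steps per epoch:    {steps_per_epoch}")
print(f"total steps:        {total_steps}")
```

Under these assumptions, 54,000 samples are used for weight updates and 6,000 for validation, giving 422 steps per epoch and 4,220 steps over the 10 epochs.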
