# mnist_demo / app.py
# Author: hiwei
# Commit 78d1a78 ("Update app.py", verified)
import tensorflow as tf
import gradio as gr
import numpy as np
from PIL import Image
# Header text passed to gr.Interface below; the description uses Markdown
# link syntax for the reference to the original lab notebook.
TITLE = "Handwritten Digit Recognition Demo"
DESCRIPTION = (
    "This demo employs a basic CNN architecture inspired by "
    "[MIT 6.S191’s Lab2 Part1]"
    "(https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "
    "It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, "
    "particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
)
# Resampling filters selectable in the UI dropdown, keyed by the name shown
# to the user. Insertion order determines the dropdown's option order.
PIL_INTERPOLATION_METHODS = dict(
    nearest=Image.Resampling.NEAREST,
    bilinear=Image.Resampling.BILINEAR,
    bicubic=Image.Resampling.BICUBIC,
    hamming=Image.Resampling.HAMMING,
    box=Image.Resampling.BOX,
    lanczos=Image.Resampling.LANCZOS,
)
# Load the trained Keras model once at import time; predict() reuses this
# single module-level instance for every request.
_MODEL_PATH = "tf_model_mnist"
model = tf.keras.saving.load_model(_MODEL_PATH)
def preprocess(image, resample_method):
    """Normalize a Gradio sketch into a 28x28 MNIST-style model input.

    The drawing is downscaled to 28x28 and then binarized on its alpha
    channel: fully transparent pixels become black (0) and any pixel with
    non-zero opacity becomes white (255), giving light strokes on a dark
    background. NOTE(review): this assumes the canvas background is
    transparent and strokes are opaque -- confirm against the Sketchpad setup.

    Args:
        image: PIL image taken from the Sketchpad's ``"composite"`` entry.
            It must carry an alpha band (e.g. mode "RGBA" or "LA").
        resample_method: Key of ``PIL_INTERPOLATION_METHODS`` selecting the
            filter used for downscaling.

    Returns:
        ``(image_array, new_image)``: ``image_array`` is a float32 batch of
        one image with values 0.0/1.0, shaped ``(1, 28, 28, 1)`` under Keras'
        default channels-last layout; ``new_image`` is the binarized 28x28
        grayscale ("L") PIL image, returned for display.

    Raises:
        KeyError: If ``resample_method`` is not a known filter name.
        ValueError: If ``image`` has no alpha channel (from ``getchannel``).
    """
    resized = image.resize((28, 28), PIL_INTERPOLATION_METHODS[resample_method])
    # Look the alpha band up by name rather than by position 3, so non-RGBA
    # input fails with a clear message instead of an IndexError.
    alpha = np.asarray(resized.getchannel("A"))
    # Vectorized replacement for the former per-pixel Python loop.
    binary = np.where(alpha == 0, 0, 255).astype(np.uint8)
    # A 2-D uint8 array maps to a mode "L" image; no RGBA round-trip needed.
    new_image = Image.fromarray(binary)
    image_array = tf.keras.utils.img_to_array(new_image)
    # Add the batch axis and scale 0/255 to 0.0/1.0.
    image_array = (np.expand_dims(image_array, axis=0) / 255.).astype(np.float32)
    return image_array, new_image
def predict(img, resample_method):
    """Classify the digit drawn on the Sketchpad.

    Args:
        img: Sketchpad value, a dict whose ``"composite"`` entry holds the
            flattened drawing as a PIL image.
        resample_method: Filter name forwarded to ``preprocess``.

    Returns:
        ``(scores, new_image)``: ``scores`` maps each class index, as a string
        (``"0"``, ``"1"``, ...), to the model's output score as a Python float
        for the ``label`` output. ``new_image`` is the preprocessed 28x28
        image actually fed to the model, for the ``image`` output.

    Raises:
        gr.Error: If no drawing was submitted.
    """
    composite = img.get("composite") if img is not None else None
    if composite is None:
        # Show a readable message in the UI instead of a TypeError traceback.
        raise gr.Error("Please draw a digit before submitting.")
    input_arr, new_image = preprocess(composite, resample_method)
    print("input:", input_arr.shape)
    predictions = model.predict(input_arr)
    # predictions[0] holds the scores for the single image in the batch;
    # presumably 10 softmax probabilities, one per digit -- not verifiable
    # here. float() turns numpy scalars into plain, JSON-friendly floats.
    return {str(digit): float(score) for digit, score in enumerate(predictions[0])}, new_image
# 500x500 drawing canvas without layer controls; values arrive as PIL images.
input_image = gr.Sketchpad(
    layers=False,
    type="pil",
    canvas_size=(500, 500),
)

# Lets the user pick which Pillow filter shrinks the drawing to 28x28.
resample_method = gr.Dropdown(
    choices=[*PIL_INTERPOLATION_METHODS],
    value="bilinear",
)

# Wire the drawing and the filter choice into predict(); it returns the class
# scores (shown as a label) and the preprocessed image (shown as an image).
demo = gr.Interface(
    fn=predict,
    inputs=[input_image, resample_method],
    outputs=["label", "image"],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()