Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
# User-facing text rendered at the top of the demo page.
TITLE = "Handwritten Digit Recognition Demo"

# Adjacent string literals inside parentheses are joined at compile time, so
# the description stays a single string without a backslash continuation.
DESCRIPTION = (
    "This demo employs a basic CNN architecture inspired by [MIT 6.S191’s Lab2 Part1](https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "
    "It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
)
# Dropdown label -> Pillow resampling filter. Every label is the lower-cased
# name of the matching ``Image.Resampling`` member, so the table is derived
# from the names; the tuple order fixes the order of the dict keys.
PIL_INTERPOLATION_METHODS = {
    method_name: getattr(Image.Resampling, method_name.upper())
    for method_name in ("nearest", "bilinear", "bicubic", "hamming", "box", "lanczos")
}
# Load the saved model from the "tf_model_mnist" path once, at import time,
# and keep it in a module-level name for reuse.
model = tf.keras.saving.load_model(filepath="tf_model_mnist")
def preprocess(image, resample_method):
    """Convert a sketch image into an MNIST-style model input.

    The image is resized to 28x28 and binarised using only its alpha channel:
    fully transparent pixels become black background (0) and every pixel with
    any opacity becomes white ink (255).

    Args:
        image: PIL image of the drawing. An image that is not already in RGBA
            mode is converted first; modes without transparency gain a fully
            opaque alpha channel, so the whole frame is then treated as ink.
        resample_method: Key of ``PIL_INTERPOLATION_METHODS`` selecting the
            Pillow filter used for the resize.

    Returns:
        A ``(batch, binarised)`` tuple. ``batch`` is the float32 array built
        from the binarised image with a leading batch axis added and values
        scaled to [0, 1]; ``binarised`` is the 28x28 grayscale PIL image it
        was built from.

    Raises:
        KeyError: If ``resample_method`` is not a known interpolation name.
    """
    if image.mode != "RGBA":
        # The binarisation below reads channel 3, so guarantee it exists.
        image = image.convert("RGBA")
    image = image.resize((28, 28), PIL_INTERPOLATION_METHODS[resample_method])

    # One vectorised comparison replaces the per-pixel Python loop.
    alpha = np.asarray(image)[:, :, 3]
    ink = (alpha != 0).astype(np.uint8) * 255

    # A 2-D uint8 array is read by Pillow as an 8-bit grayscale ("L") image,
    # which is what an opaque black/white RGBA image converts to anyway.
    new_image = Image.fromarray(ink)

    image_array = tf.keras.utils.img_to_array(new_image)
    image_array = (np.expand_dims(image_array, axis=0) / 255.).astype(np.float32)
    return image_array, new_image
def predict(img, resample_method):
    """Classify the digit drawn on the sketchpad.

    Args:
        img: Sketchpad value; its ``"composite"`` entry is read as the drawing.
        resample_method: Interpolation name forwarded to ``preprocess``.

    Returns:
        A ``(confidences, preview)`` tuple. ``confidences`` maps the labels
        "0" through "9" to the model's scores as Python floats; ``preview`` is
        the preprocessed image returned by ``preprocess``.

    Raises:
        gr.Error: If there is no drawing to classify.
    """
    composite = img["composite"] if img is not None else None
    if composite is None:
        # Surface an actionable message instead of an opaque internal
        # TypeError/AttributeError when nothing was submitted.
        raise gr.Error("Please draw a digit before submitting.")

    input_arr, new_image = preprocess(composite, resample_method)
    print("input:", input_arr.shape)
    predictions = model.predict(input_arr)

    # Only the first (and only) row of the batch is relevant; plain floats
    # keep the label payload independent of NumPy scalar types.
    scores = predictions[0]
    return {str(i): float(scores[i]) for i in range(10)}, new_image
# --- Input widgets ----------------------------------------------------------
# Dropdown listing every supported interpolation name, preselecting bilinear.
resample_method = gr.Dropdown(
    value='bilinear',
    choices=list(PIL_INTERPOLATION_METHODS),
)

# 500x500 drawing canvas with layers disabled, configured for PIL images.
input_image = gr.Sketchpad(type="pil", canvas_size=(500, 500), layers=False)

# --- App wiring -------------------------------------------------------------
# The inputs are handed to ``predict`` in list order; its two return values
# feed the label and image outputs respectively.
demo = gr.Interface(
    fn=predict,
    inputs=[input_image, resample_method],
    outputs=['label', 'image'],
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch()