|
|
import gradio as gr |
|
|
import tensorflow as tf |
|
|
import numpy as np |
|
|
from PIL import Image, ImageOps |
|
|
|
|
|
|
|
|
|
|
|
# Load the trained classifier once, at import time; predict_digit reuses this
# single instance for every request. The relative path is resolved against the
# current working directory of the process that launches the app.
model = tf.keras.models.load_model(filepath="mnist_model.keras")
|
|
|
|
|
def predict_digit(image):
    """Classify one handwritten-digit image with the module-level ``model``.

    Preprocessing steps, in order:
    - convert the image to 8-bit grayscale;
    - resize it to 28x28;
    - invert it when it looks like dark ink on a light background;
    - scale pixels to [0, 1];
    - flatten it into a single-row batch of 784 values.

    Args:
        image: A PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the input is empty or has been cleared.

    Returns:
        ``None`` when no image was supplied. Otherwise, a dict mapping each
        output index as a string (``"0"``, ``"1"``, ...) to the model's score
        for that index, as a Python float, suitable for a ``gr.Label`` output.
    """
    if image is None:
        return None

    # Single-channel grayscale at MNIST resolution.
    image = image.convert("L").resize((28, 28))

    # Standard MNIST digits are light strokes on a dark background. A mean
    # brightness above mid-grey (127) suggests the opposite polarity, so flip
    # it. NOTE(review): assumes the model was trained on un-inverted MNIST.
    if np.mean(image) > 127:
        image = ImageOps.invert(image)

    # Scale to [0, 1] and flatten into a (1, 784) batch: the input layout
    # this model is fed.
    pixels = np.array(image) / 255.0
    batch = pixels.reshape(1, 28 * 28)

    # verbose=0 keeps Keras from printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]

    # Enumerate the actual output row instead of hard-coding 10 classes, so
    # the label dict always matches the model's output width.
    return {str(index): float(score) for index, score in enumerate(scores)}
|
|
|
|
|
# Gradio UI: one uploaded image (delivered as a PIL image) in, the three
# highest-scoring classes out.
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Label(num_top_classes=3, label="Predictions"),
    title="MNIST Digit Classifier",
    description=(
        "Upload an image of a handwritten digit (0-9) to see the prediction. "
        "Works best with a single digit centered in the image."
    ),
)


if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    iface.launch()
|
|
|