# app.py — MNIST digit classifier (Gradio demo)
# Author: nathan ayers · commit fe70043 "Create app.py" · 733 bytes
import pickle
import numpy as np
from PIL import Image
import gradio as gr
# 1) Load your pretrained model.
# The file is opened in a `with` block so the handle is closed right after
# unpickling, instead of being left to the garbage collector.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever load a mnist_model.pkl that comes from a trusted source.
with open("mnist_model.pkl", "rb") as model_file:
    model = pickle.load(model_file)
# 2) Define a prediction function.
def classify_digit(img):
    """Classify one handwritten-digit image with the module-level ``model``.

    The image is converted to 8-bit grayscale ("L" mode), resized to 28x28
    pixels, and flattened into a single row of 28 * 28 = 784 pixel values
    before being passed to ``model.predict``.

    Args:
        img: A PIL image as delivered by ``gr.Image(type="pil")``, or ``None``
            when the form is submitted without an uploaded image.

    Returns:
        ``"Predicted digit: <label>"`` for a real image, or a prompt asking
        the user to upload an image when ``img`` is ``None``.
    """
    # Gradio's Image input yields None when nothing was uploaded; without this
    # guard the .convert() call below would raise AttributeError.
    if img is None:
        return "Please upload an image of a digit."

    # Convert to grayscale first so the array has a single channel, then shrink to 28x28.
    # NOTE(review): assumes the model was trained on raw 0-255 pixel values with
    # the same foreground/background polarity as the uploads — confirm.
    img = img.convert("L").resize((28, 28))

    # One sample with 784 features: a (1, 784) row, which is why [0] below
    # picks the single prediction. Presumably a scikit-learn-style estimator.
    arr = np.array(img).reshape(1, -1)
    pred = model.predict(arr)[0]
    return f"Predicted digit: {pred}"
# 3) Wire up Gradio: an image upload goes through classify_digit into a textbox.
iface = gr.Interface(
    classify_digit,
    gr.Image(type="pil", label="Upload a 28×28 digit"),
    gr.Textbox(label="Prediction"),
    title="MNIST Digit Classifier",
    description="Upload a handwritten digit and get a prediction!",
)

if __name__ == "__main__":
    # Launch the interface only when this file is executed as a script,
    # not when it is imported as a module.
    iface.launch()