# Aaron Collins
# Removed 'Live' functionality so the prediction doesn't start with an error
# f0c61f1
from pathlib import Path

import gradio as gr
from fastai.vision.all import *
from PIL import Image, ImageOps
import numpy as np
# Load the trained model. The path is resolved next to this script rather than
# against the current working directory, so the app also works when launched
# from elsewhere.
# NOTE: load_learner unpickles the file; only load .pkl files you trust.
MODEL_PATH = Path(__file__).with_name("mnist_resnet18.pkl")
learn = load_learner(MODEL_PATH)
def _composite_from_sketch(img):
    """Extract the drawn image array from a Gradio sketchpad value.

    Newer Gradio sketchpads pass a dict with a "composite" entry; older ones
    pass the array directly. Returns None when there is no image at all.
    """
    if img is None:
        return None
    if isinstance(img, dict):
        return img.get("composite")
    return img


def _preprocess_sketch(composite):
    """Convert a sketch array into the 224x224 RGB image fed to the model.

    The strokes are flattened onto a white background, converted to
    grayscale and inverted so the digit is light on a dark background
    (presumably matching the MNIST-style training images — confirm), then
    resized and expanded to three channels.

    Returns None if the canvas is blank (nothing drawn).
    """
    pil_img = Image.fromarray(composite).convert("RGBA")
    # A plain RGBA -> L conversion ignores alpha, so transparent pixels would
    # keep whatever RGB values they store. Compositing over opaque white
    # makes the undrawn area white first.
    background = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
    pil_img = Image.alpha_composite(background, pil_img).convert("L")
    pil_img = ImageOps.invert(pil_img)
    # After inversion an untouched canvas is all zeros, and getbbox() returns
    # None when no pixel is non-zero.
    if pil_img.getbbox() is None:
        return None
    return pil_img.resize((224, 224)).convert("RGB")


def predict_from_sketch(img):
    """Classify a hand-drawn digit from the Gradio sketchpad.

    Args:
        img: the sketchpad value. This is a dict with a "composite" image
            array, or the array itself, or None.

    Returns:
        A dict mapping each class label (as str) to its softmax probability.
        Returns None when nothing was drawn, which leaves the gr.Label output
        empty.
    """
    composite = _composite_from_sketch(img)
    if composite is None:
        return None
    pil_img = _preprocess_sketch(composite)
    if pil_img is None:
        return None

    # test_dl runs the image through the Learner's inference-time transforms.
    dl = learn.dls.test_dl([pil_img])
    xb = dl.one_batch()[0]
    with torch.no_grad():
        outputs = learn.model.eval()(xb)
    # Single-image batch: take row 0 and turn the raw outputs into probabilities.
    probs = outputs.softmax(dim=1)[0]
    return {str(label): float(p) for label, p in zip(learn.dls.vocab, probs)}
# Gradio UI. The Interface is not created with live=True, so predictions run
# when the user clicks Submit. The description states that instead of
# promising real-time updates.
demo = gr.Interface(
    predict_from_sketch,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    title="MNIST Digit Classifier",
    description="Draw a digit (0–9) and click Submit to see the top 3 predictions.",
)

if __name__ == "__main__":
    demo.launch()