# Spaces: Sleeping
import gradio as gr | |
from fastai.vision.all import * | |
from PIL import Image, ImageOps | |
import numpy as np | |
# Load the trained model a single time at import; predict_from_sketch reads
# this module-level learner on every call instead of reloading it.
learn = load_learner(fname='mnist_resnet18.pkl')
def predict_from_sketch(img):
    """Classify a digit drawn on the Gradio sketchpad.

    Args:
        img: Sketchpad payload. Its ``"composite"`` entry is read as the
            drawing to classify (assumed to be an array accepted by
            ``Image.fromarray`` -- TODO confirm against the installed Gradio
            version). ``None``, or a payload whose composite is ``None``, is
            treated as "nothing drawn".

    Returns:
        A dict mapping ``str(learn.dls.vocab[i])`` to the softmax probability
        of class ``i``, or ``None`` when there is nothing to classify.

    Raises:
        KeyError: If ``img`` is a mapping without a ``"composite"`` key.
    """
    # An empty or cleared canvas would otherwise crash in Image.fromarray.
    composite = None if img is None else img["composite"]
    if composite is None:
        return None

    # Flatten any transparency onto an opaque white background before the
    # alpha channel is dropped by the grayscale conversion.
    sketch = Image.fromarray(composite).convert("RGBA")
    white = Image.new("RGBA", sketch.size, (255, 255, 255, 255))
    gray = Image.alpha_composite(white, sketch).convert("L")

    # Invert the grayscale image (presumably to match the light-on-dark
    # polarity of the training data -- verify), then resize and expand back
    # to three channels.
    model_input = ImageOps.invert(gray).resize((224, 224)).convert("RGB")

    # Build a one-item test batch through the learner's own data pipeline.
    batch = learn.dls.test_dl([model_input]).one_batch()[0]
    with torch.no_grad():
        logits = learn.model.eval()(batch)

    probs = logits.softmax(dim=1).squeeze()
    vocab = learn.dls.vocab
    return {str(vocab[i]): float(p) for i, p in enumerate(probs)}
# Web UI: a sketchpad input wired to predict_from_sketch, with the three
# highest-scoring classes displayed in a label widget.
top_classes = gr.Label(num_top_classes=3)
demo = gr.Interface(
    fn=predict_from_sketch,
    inputs="sketchpad",
    outputs=top_classes,
    title="MNIST Digit Classifier",
    description="Draw a digit (0–9) and get predictions in real time!",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()