Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from model import * | |
from PIL import Image | |
import torchvision.transforms as transforms | |
# Heading and blurb rendered at the top of the demo page.
title = "Digit Classifier"
description = (
    "Multilayer-Perceptron built for the fast.ai "
    "'Deep Learning' course to classify handwritten "
    "digits from the MNIST dataset. "
)
# Path handed to gr.Interface(examples=...); presumably a directory of sample
# digit images shipped with the Space — confirm it exists in the repo.
examples = "examples"
# Input/output widgets for the Interface: an image in, a ranked label out.
inputs = gr.Image()
outputs = gr.Label()
# The checkpoint is a fully pickled model object rather than a state_dict
# (predict_digit calls it directly), so unpickling needs its class definition —
# presumably supplied by `from model import *`.
# torch >= 2.6 defaults to weights_only=True, which refuses arbitrary pickled
# objects, so opt out explicitly. This runs pickle code: only ever load this
# bundled, trusted checkpoint — never a user-supplied file.
model = torch.load(
    "model/digit_classifier.pt",
    map_location=torch.device("cpu"),
    weights_only=False,
)
# Inference only: switch off training-mode behavior (dropout, batch-norm
# statistic updates) in case the network contains such layers.
model.eval()
# Class names "0".."9", paired positionally with the model's output scores.
labels = list(map(str, range(10)))
# Preprocessing applied to each input image before inference:
#   1. resize to the 28x28 MNIST resolution,
#   2. collapse to a single luminance channel,
#   3. convert to a float tensor in [0, 1] of shape (1, 28, 28).
# The leading size-1 axis is presumably treated by the model as a batch of one
# (predict_digit takes softmax over dim=1) — confirm against model.py.
# (The former x[0] / unsqueeze(0) Lambda pair removed and re-added that same
# axis, so it was a no-op and has been dropped.)
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)
def predict_digit(img):
    """Classify one handwritten-digit image.

    Args:
        img: Image array from the Gradio ``Image`` input (presumably a NumPy
            array, gr.Image's default type), or ``None`` when the user
            submits without providing an image.

    Returns:
        A dict mapping each label ``"0"``..``"9"`` to its softmax probability,
        suitable for a ``gr.Label`` output; ``None`` when no image was given.
    """
    if img is None:
        # Nothing uploaded or drawn: Image.fromarray(None) would raise.
        return None
    x = transform(Image.fromarray(img))
    # Pure inference: skip autograd graph construction and its memory cost.
    with torch.no_grad():
        logits = model(x)
        probs = torch.nn.functional.softmax(logits, dim=1)
    # .tolist() yields plain Python floats, which gr.Label can serialize.
    return dict(zip(labels, probs.flatten()[:10].tolist()))
# Build the app: one tab containing the digit-classification Interface.
with gr.Blocks() as demo:
    with gr.Tab("Digit Prediction"):
        gr.Interface(
            fn=predict_digit,
            inputs=inputs,
            outputs=outputs,
            examples=examples,
            title=title,
            description=description,
        )

# Queue settings must be applied to the Blocks that is actually launched
# (`demo`); a queue configured on the nested Interface is never started, so
# its concurrency limit would not take effect.
demo.queue(default_concurrency_limit=5)

if __name__ == "__main__":
    demo.launch()