Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from model import model, classes | |
from torchvision import transforms | |
# Restore the trained weights from 'mnist_model.pth' (tensors are mapped onto
# the CPU, so no GPU is required) and switch the network to inference mode.
# The checkpoint is expected to be a dict holding the weights under the
# 'model_state_dict' key.
checkpoint = torch.load("mnist_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # evaluation mode: dropout/batch-norm layers (if any) behave deterministically
def _to_grayscale(pixels):
    """Collapse a float tensor of shape (H, W) or (H, W, C) to (H, W) intensities.

    Raises:
        ValueError: if the tensor is not 2-D, or is 3-D with a channel count
            other than 1, 3 or 4.
    """
    if pixels.dim() == 2:
        return pixels
    if pixels.dim() == 3 and pixels.shape[-1] == 1:
        return pixels[..., 0]
    if pixels.dim() == 3 and pixels.shape[-1] in (3, 4):
        # BT.601 luma weights (the ones OpenCV's *2GRAY conversions use).
        # Channels are taken as R, G, B -- NOTE(review): assumes RGB ordering as
        # delivered by Gradio/PIL; any 4th (alpha) channel is ignored.
        red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]
        # Round to whole intensity levels, as an 8-bit conversion would.
        return (0.299 * red + 0.587 * green + 0.114 * blue).round()
    raise ValueError(
        f"preprocess_image: expected an (H, W) or (H, W, 1|3|4) image, "
        f"got shape {tuple(pixels.shape)}"
    )


def _adaptive_threshold_inv(gray, block_size=11, offset=2.0, max_value=255.0):
    """Inverted binary threshold against a Gaussian-weighted local mean.

    This is a torch version of
    ``cv2.adaptiveThreshold(gray, max_value, ADAPTIVE_THRESH_GAUSSIAN_C,
    THRESH_BINARY_INV, block_size, offset)``, equivalent up to rounding.
    A pixel becomes *max_value* when it is <= (Gaussian mean of its
    block_size x block_size neighbourhood) - offset, and 0 otherwise.
    Image borders are replicated.

    Args:
        gray: (H, W) float tensor of intensities in 0..255.
        block_size: odd neighbourhood size.
        offset: constant subtracted from the local mean.
        max_value: value written for pixels that pass the test.

    Returns:
        (H, W) float32 tensor containing only 0 and *max_value*.
    """
    radius = block_size // 2
    # OpenCV's default sigma for a Gaussian kernel of this size (2.0 for 11).
    sigma = 0.3 * ((block_size - 1) * 0.5 - 1) + 0.8
    taps = torch.arange(block_size, dtype=torch.float32) - radius
    kernel_1d = torch.exp(-(taps ** 2) / (2 * sigma ** 2))
    kernel_1d = kernel_1d / kernel_1d.sum()
    # Separable Gaussian -> 2-D kernel shaped (out_ch, in_ch, kH, kW) for conv2d.
    kernel = (kernel_1d[:, None] * kernel_1d[None, :]).view(1, 1, block_size, block_size)
    padded = torch.nn.functional.pad(gray[None, None], (radius,) * 4, mode="replicate")
    # Round the local mean like OpenCV's 8-bit blur output.
    local_mean = torch.nn.functional.conv2d(padded, kernel)[0, 0].round()
    return (gray <= local_mean - offset).to(torch.float32) * max_value


def preprocess_image(image):
    """Turn a drawn digit into a normalized 28x28 model input.

    Pipeline:
      1. Grayscale.
      2. Gaussian adaptive threshold: inverted binary, 11x11 block, C=2.
      3. Area-resize to 28x28.
      4. Scale to [0, 1].
      5. Normalize with mean=0.5 and std=0.5, so values lie in [-1, 1].

    This is the same sequence of steps the OpenCV calls expressed, implemented
    with torch alone. ``cv2`` is not imported in this module, so those calls
    could not run.

    NOTE(review): THRESH_BINARY_INV marks pixels darker than their surroundings,
    i.e. it assumes dark strokes on a light canvas. Confirm this matches what
    the sketchpad emits and what the model was trained on.

    Args:
        image: array-like (numpy array, tensor or nested lists) of shape
            (H, W), (H, W, 1), (H, W, 3) or (H, W, 4), pixel values in 0..255.

    Returns:
        float32 tensor of shape (1, 1, 28, 28): a batch of one single-channel image.

    Raises:
        ValueError: if *image* is empty or has an unsupported shape.
    """
    pixels = torch.as_tensor(image, dtype=torch.float32)
    if pixels.numel() == 0:
        raise ValueError("preprocess_image: got an empty image")

    gray = _to_grayscale(pixels)
    binary = _adaptive_threshold_inv(gray, block_size=11, offset=2.0, max_value=255.0)

    # 'area' interpolation averages source pixels per output cell, the analogue
    # of cv2.INTER_AREA.
    resized = torch.nn.functional.interpolate(binary[None, None], size=(28, 28), mode="area")

    # Quantize to 8-bit levels, then scale to [0, 1] (what ToTensor does for
    # uint8 input) and apply Normalize((0.5,), (0.5,)).
    scaled = resized.round() / 255.0
    return (scaled - 0.5) / 0.5
def classify(image):
    """Predict which digit is drawn in *image*.

    Args:
        image: the sketchpad value handed over by Gradio. It may be ``None``
            when there is no drawing (e.g. the canvas was cleared while the
            interface runs with ``live=True``).

    Returns:
        The index of the highest-scoring class as a string (e.g. ``"7"``), or
        ``None`` when there is no image, which the label output presumably
        shows as empty.
    """
    # Guard: with live updates the function can fire without a drawing, and
    # pushing None through the preprocessing pipeline would raise.
    if image is None:
        return None

    tensor = preprocess_image(image)
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(tensor)
    # argmax over the class dimension; the batch holds a single image, so
    # .item() yields a plain int.
    prediction = logits.argmax(dim=1).item()
    return str(prediction)
# Gradio front end: a drawing canvas feeds `classify`, whose result is shown in
# a label component. live=True re-runs `classify` whenever the drawing changes.
iface = gr.Interface(
    fn=classify,
    inputs="sketchpad",
    outputs="label",
    live=True,
    title="Digit Recognition",
    description="Draw a Digit 0-9 and the algorithm will detect it in real time!",
    article="<p style='text-align: center'>Digit Recognition | Demo Model by Jugal</p>",
    theme="huggingface",
)
# Start the server. Per Gradio's docs, debug=True blocks the main thread and
# prints errors to the console.
iface.launch(debug=True)