Jugal-sheth's picture
Add application file
52a915d
raw
history blame
1.32 kB
import cv2
import gradio as gr
import torch
from torchvision import transforms

from model import model, classes
# Restore the trained weights onto the CPU. The checkpoint file holds a dict
# whose "model_state_dict" entry carries the parameters for `model`.
checkpoint = torch.load("mnist_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

# From here on the network is only used for inference, so switch it to
# evaluation mode (affects layers such as dropout / batch-norm, if present).
model.eval()
def _to_grayscale(image):
    """Return *image* as a single-channel (H, W) array for OpenCV thresholding.

    Accepts the forms a Gradio sketchpad may hand over. Which form arrives
    depends on the Gradio version (NOTE(review): confirm against the
    installed release):

    * an editor dict whose flattened drawing sits under ``"composite"``;
    * an already-grayscale ``(H, W)`` array (returned unchanged);
    * an ``(H, W, 1)`` array (channel axis dropped);
    * an RGBA ``(H, W, 4)`` array (flattened onto a white background);
    * an RGB ``(H, W, 3)`` array (converted with RGB luminance weights).

    Raises:
        ValueError: if there is no image data to convert.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise ValueError("preprocess_image received no image data")

    if image.ndim == 2:
        # Already single-channel; cvtColor would reject it.
        return image
    if image.ndim == 3 and image.shape[2] == 1:
        return image[:, :, 0]
    if image.ndim == 3 and image.shape[2] == 4:
        # Transparent pixels often carry RGB == 0. Ignoring alpha would then
        # read the empty background as black ink, so composite onto white.
        # Assumes 8-bit channels (alpha in 0..255).
        rgb = image[:, :, :3].astype("float32")
        alpha = image[:, :, 3:].astype("float32") / 255.0
        image = (rgb * alpha + 255.0 * (1.0 - alpha)).round().astype("uint8")
    # Gradio image arrays are RGB-ordered, not OpenCV's native BGR.
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)


def preprocess_image(image):
    """Convert a sketchpad drawing into the model's input tensor.

    Pipeline: grayscale -> Gaussian adaptive threshold (inverted, 11x11
    neighbourhood, offset 2) -> area-resize to 28x28 -> ``ToTensor`` ->
    ``Normalize(0.5, 0.5)``.

    Args:
        image: Sketchpad output, either a NumPy image array or a Gradio
            editor dict (see ``_to_grayscale`` for the accepted forms).

    Returns:
        A float tensor of shape ``(1, 1, 28, 28)`` with values in ``[-1, 1]``.

    Raises:
        ValueError: if *image* carries no image data.
    """
    gray = _to_grayscale(image)
    # THRESH_BINARY_INV: pixels at or below their local Gaussian-weighted mean
    # minus 2 become 255, everything else becomes 0.
    threshold = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )
    resized = cv2.resize(threshold, (28, 28), interpolation=cv2.INTER_AREA)
    # ToTensor: (H, W) uint8 -> (1, H, W) float in [0, 1]; unsqueeze adds the
    # batch axis.
    tensor = transforms.ToTensor()(resized).unsqueeze(0)
    # Map [0, 1] -> [-1, 1]; presumably the normalisation used at training
    # time. Confirm against the training script.
    tensor = transforms.Normalize((0.5,), (0.5,))(tensor)
    return tensor
def classify(image):
    """Predict the digit drawn on the sketchpad.

    Args:
        image: Sketchpad value passed in by Gradio. It may be ``None``, or an
            editor dict without a ``"composite"`` image, when nothing is drawn
            (e.g. after the canvas is cleared while ``live=True``).

    Returns:
        The predicted class index as a string (e.g. ``"7"``). Returns ``None``
        when there is no drawing, instead of raising on every live update.
    """
    if image is None or (isinstance(image, dict) and image.get("composite") is None):
        return None
    tensor = preprocess_image(image)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        scores = model(tensor)
    # Single-image batch: the highest-scoring class along dim 1 is the answer.
    prediction = scores.argmax(dim=1).item()
    return str(prediction)
# Web demo: a drawing canvas feeds `classify` and the result is shown as a
# label. live=True re-runs the prediction whenever the drawing changes.
iface = gr.Interface(
    classify,
    "sketchpad",
    "label",
    live=True,
    title="Digit Recognition",
    description="Draw a Digit 0-9 and the algorithm will detect it in real time!",
    article="<p style='text-align: center'>Digit Recognition | Demo Model by Jugal</p>",
    theme="huggingface",
)

# debug=True: block the main thread and surface errors in the console.
iface.launch(debug=True)