DigitClassifier / app.py
hkanumilli's picture
updating with final model
dcb34da
import torch
import gradio as gr
import torchvision.transforms as transforms
from neural_network import MNISTNetwork
# Preprocessing applied to every sketch before inference:
#   1. ToTensor turns the incoming image into a float tensor.
#   2. Normalize rescales each value as (x - 0.5) / 0.5.
# NOTE(review): presumably this must match the preprocessing used when
# MNISTModel.pth was trained -- confirm against the training script.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Load the trained model.
net = MNISTNetwork()
# The network is built on the CPU (it is never moved to another device),
# so map the checkpoint's tensors to the CPU as well. This lets a
# checkpoint saved from a GPU run load on a CPU-only host.
net.load_state_dict(torch.load('MNISTModel.pth', map_location="cpu"))
# Put the module in inference mode. Layers such as dropout or batch norm
# (if MNISTNetwork has any -- its definition is not visible here) would
# otherwise keep training-time behavior and make predictions unstable.
net.eval()

# Class names in output-index order: predict() maps output index i to the
# label LABELS[i].
LABELS = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
def predict(drawing):
    """Classify a hand-drawn digit and return a probability per label.

    Args:
        drawing: The value produced by the sketchpad input, or ``None``
            when nothing has been drawn. Presumably a 28x28 grayscale
            image (the sketchpad is created with ``shape=(28, 28)``) --
            confirm against the Gradio version in use.

    Returns:
        A short prompt string when ``drawing`` is ``None``; otherwise a
        dict mapping every label in ``LABELS`` to its softmax
        probability, inserted in descending-probability order.
    """
    if drawing is None:
        return "Draw a number"

    x = transform(drawing)
    # Pure inference: no gradients are needed.
    with torch.no_grad():
        logits = net(x)

    # NOTE(review): assumes the first axis of the network output is the
    # batch axis, so logits[0] holds the class scores for this one image.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # Rank all labels. Deriving k from LABELS keeps the two in sync
    # instead of relying on a separately hard-coded class count.
    values, indices = torch.topk(probabilities, len(LABELS))
    return {LABELS[int(i)]: v.item() for i, v in zip(indices, values)}
# Drawing canvas used as the app's input, sized at 28x28.
sketchpad_input = gr.Sketchpad(shape=(28, 28))

# Connect the canvas to predict() and display its result with Gradio's
# "label" output component. live=True asks Gradio to re-run predict()
# as the input changes.
interface = gr.Interface(
    live=True,
    outputs="label",
    inputs=sketchpad_input,
    fn=predict,
)

# Start the web app. This runs at module level, so it happens whenever
# this file is executed or imported.
interface.launch()