# Spaces status: Runtime error
import gradio as gr
import torch
import torchvision.transforms as transforms

from neural_network import MNISTNetwork
# Preprocessing applied to every sketch before it reaches the network:
# ToTensor converts the image to a tensor, Normalize then rescales the single
# channel using mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)
# Load the trained model.
net = MNISTNetwork()
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# CPU-only host; the visible code never moves ``net`` to another device.
net.load_state_dict(torch.load('MNISTModel.pth', map_location='cpu'))
# The model is only used for inference, so put layers such as dropout and
# batch norm (if the network has any) into evaluation behaviour.
net.eval()

# Class names; predict() looks them up by the index of each network output.
LABELS = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
def predict(drawing):
    """Classify a sketched digit and return a confidence for every label.

    Args:
        drawing: The image delivered by the sketchpad input, or ``None`` when
            nothing has been drawn yet. It is handed unchanged to the
            module-level ``transform``.

    Returns:
        A prompt string when ``drawing`` is ``None``; otherwise a dict mapping
        entries of ``LABELS`` to their softmax probability, inserted from the
        most likely class to the least likely one.
    """
    if drawing is None:
        # Neutral prompt: the previous wording ended in an insult to the user.
        return "Draw a number"

    x = transform(drawing)
    # NOTE(review): no batch axis is added here, yet ``output[0]`` below treats
    # the first axis of the result as a batch axis. This presumably relies on
    # the leading axis produced by ``transform`` standing in for the batch --
    # confirm against MNISTNetwork.forward.
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        output = net(x)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # Rank as many classes as there are labels/outputs instead of a
    # hard-coded 10, so the count cannot drift away from LABELS.
    k = min(len(LABELS), probabilities.numel())
    values, indices = torch.topk(probabilities, k)
    # tolist() yields plain Python ints/floats, so LABELS is indexed with real
    # integers and the returned dict holds ordinary floats.
    return {
        LABELS[index]: confidence
        for index, confidence in zip(indices.tolist(), values.tolist())
    }
# Drawing canvas configured with a 28x28 shape.
# NOTE(review): the ``shape`` keyword belongs to the Gradio 3.x image
# components; newer major versions appear to have dropped it, which may be
# the cause of the "Runtime error" status -- confirm against the pinned
# gradio version.
sketchpad_input = gr.Sketchpad(shape=(28, 28))

# Wire predict() to the canvas and show its result in a "label" output.
# live=True asks Gradio to re-run predict whenever the input changes.
interface = gr.Interface(predict, sketchpad_input, "label", live=True)
interface.launch()