Spaces:
Runtime error
Runtime error
File size: 1,101 Bytes
6c0d444 aaa317a 6c0d444 dcb34da 6c0d444 aaa317a 6c0d444 aaa317a 6c0d444 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch
import gradio as gr
import torchvision.transforms as transforms
from neural_network import MNISTNetwork
# Preprocessing applied to every drawing before it reaches the network:
# convert it to a tensor, then normalize each channel as (x - 0.5) / 0.5.
# NOTE(review): presumably this mirrors the preprocessing used during
# training — confirm against the training script.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Load the trained model weights. map_location="cpu" lets a checkpoint that was
# saved on a GPU machine load on a CPU-only host (e.g. a hosted demo).
net = MNISTNetwork()
net.load_state_dict(torch.load('MNISTModel.pth', map_location="cpu"))
# Put the network in inference mode so layers whose behaviour differs between
# training and evaluation (dropout, batch-norm — if MNISTNetwork has any) act
# deterministically when serving predictions.
net.eval()

# Human-readable class names, indexed by the network's output position.
LABELS = [str(digit) for digit in range(10)]
def predict(drawing):
    """Classify a hand-drawn digit and return per-label probabilities.

    Args:
        drawing: The value produced by the sketchpad input, or ``None`` when
            nothing has been drawn. Presumably a 28x28 grayscale array given
            ``gr.Sketchpad(shape=(28, 28))`` — TODO confirm for the installed
            Gradio version.

    Returns:
        A prompt string when ``drawing`` is ``None``; otherwise a dict mapping
        each entry of ``LABELS`` to its softmax probability, inserted in
        descending order of probability.
    """
    if drawing is None:
        return "Draw a number"

    # Prepend an explicit batch dimension so the network always receives a
    # batch containing exactly one sample. Without it, the first dimension of
    # the preprocessed tensor would be misused as the batch axis.
    x = transform(drawing).unsqueeze(0)

    with torch.no_grad():
        output = net(x)

    # output[0] is the score vector for the single sample in the batch.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # Rank every class; k follows LABELS instead of a hard-coded 10.
    values, indices = torch.topk(probabilities, len(LABELS))
    return {LABELS[i]: p for i, p in zip(indices.tolist(), values.tolist())}
# Web UI: a 28x28 drawing canvas whose contents are passed to predict(), with
# the result rendered by a "label" output. live=True asks Gradio to re-run
# predict() as the drawing changes.
# NOTE(review): `shape=` is a Gradio 3.x Sketchpad argument; confirm the
# installed Gradio version still accepts it.
sketchpad_input = gr.Sketchpad(shape=(28, 28))

# Positional order is (fn, inputs, outputs).
interface = gr.Interface(predict, sketchpad_input, "label", live=True)
interface.launch()
|