File size: 1,101 Bytes
6c0d444
 
 
 
 
 
 
 
aaa317a
6c0d444
 
dcb34da
 
 
6c0d444
 
 
 
aaa317a
6c0d444
 
aaa317a
 
6c0d444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import gradio as gr
import torchvision.transforms as transforms
from neural_network import MNISTNetwork


# Preprocessing pipeline applied to every incoming sketch before inference.
# ToTensor turns an HxW(xC) image into a float (C, H, W) tensor (scaled to
# [0, 1] for uint8 input); Normalize then maps each channel via
# (x - 0.5) / 0.5.
# NOTE(review): these statistics must match the ones used when training
# MNISTModel.pth — confirm against the training script.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Load the trained model once at import time; every request reuses `net`.
net = MNISTNetwork()
# map_location='cpu': the model is never moved to a GPU in this file, so map
# any CUDA-saved tensors onto the CPU (otherwise loading fails on a CPU-only
# host).
net.load_state_dict(torch.load('MNISTModel.pth', map_location='cpu'))
# Switch to inference mode; without this, layers such as dropout or
# batch-norm (if MNISTNetwork has any) keep training-time behaviour and
# predictions become noisy/non-deterministic.
net.eval()

# Display names for the classes, indexed by the network's output position.
LABELS = [str(digit) for digit in range(10)]

def predict(drawing):
    """Classify a hand-drawn digit.

    Args:
        drawing: The image produced by the Sketchpad input, or ``None`` when
            there is nothing to classify. Assumed to be something
            ``transform`` (``ToTensor``) accepts, e.g. a 28x28 numpy array —
            TODO confirm for the installed Gradio version.

    Returns:
        A prompt string when ``drawing`` is ``None``; otherwise a dict mapping
        each label in ``LABELS`` to its softmax probability, inserted in
        descending order of probability (the format Gradio's "label" output
        renders).
    """
    if drawing is None:
        return "Draw a number"

    x = transform(drawing)
    # ToTensor yields an unbatched (C, H, W) tensor; add a leading batch axis
    # so the network sees (1, C, H, W) and output[0] below is this image's row.
    if x.dim() == 3:
        x = x.unsqueeze(0)

    with torch.no_grad():
        output = net(x)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # Rank every class; tie k to LABELS rather than a hard-coded 10.
    values, indices = torch.topk(probabilities, len(LABELS))
    return {LABELS[i]: v.item() for i, v in zip(indices.tolist(), values)}


# Drawing canvas used as the interface input.
# NOTE(review): `shape` is assumed to make Gradio resize the drawing to 28x28
# to match the network input — confirm for the installed Gradio version.
sketchpad_input = gr.Sketchpad(shape=(28, 28))

# Wire the canvas to `predict`; its result (a class->probability dict, or a
# prompt string) is shown through Gradio's "label" output. live=True asks
# Gradio to re-run `predict` whenever the input changes.
interface = gr.Interface(
    predict,
    sketchpad_input,
    "label",
    live=True,
)
interface.launch()