File size: 661 Bytes
0d94b00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import gradio as gr
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import model

# Load the pickled model object from disk. map_location='cpu' keeps the load
# working on hosts that lack the device the checkpoint was saved from, and
# matches predict(), which feeds the model CPU tensors.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source. Presumably the otherwise unused
# `import model` exists so the pickled class can be resolved; confirm.
net = torch.load('mnist.pth', map_location='cpu')
# Switch to evaluation mode before serving predictions.
net.eval()

def predict(img):
    """Classify a sketched digit and return the two top-scoring classes.

    Args:
        img: Image data from the sketchpad; anything ``np.array`` accepts.
            Pixel values are assumed to lie in [0, 255].

    Returns:
        A two-element list of class indices rendered as strings, ordered
        from highest to lowest model score.
    """
    # Scale to [0, 1]; assumes the sketchpad yields values in [0, 255].
    pixels = np.array(img) / 255
    # Prepend a batch axis so the model receives a batch of one.
    batch = torch.from_numpy(np.expand_dims(pixels, axis=0)).float()
    # Inference only: skip autograd bookkeeping so no graph is built or
    # retained for each request.
    with torch.no_grad():
        scores = net(batch)
    # Assumes the model output is (batch, num_classes) -- TODO confirm.
    _, top_indices = torch.topk(scores, 2)
    return [str(idx) for idx in top_indices[0].tolist()]


# 28x28 drawing canvas used as the input component.
sp = gr.Sketchpad(shape=(28, 28))

# Wire predict() to the canvas, render its two return values as labels,
# then start the web UI.
demo = gr.Interface(fn=predict, inputs=sp, outputs=["label", "label"])
demo.launch()