File size: 2,651 Bytes
0d94b00
 
 
 
 
6afcd6e
0d94b00
 
 
 
a8cceae
 
 
0d94b00
 
 
 
 
 
 
 
a8cceae
 
 
 
 
 
 
 
 
 
 
 
 
4417b5c
 
 
a8cceae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4417b5c
a8cceae
 
 
 
 
4417b5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from models import Net,NetConv

# Load the two trained MNIST classifiers once, at import time.
# NOTE(review): torch.load unpickles whole nn.Module objects here, which is
# presumably why `from models import Net, NetConv` is imported even though
# those names are never referenced directly: unpickling needs the classes to
# be importable. Only load .pth files from a trusted source, since unpickling
# can execute arbitrary code. Newer torch releases change torch.load's
# `weights_only` default, which may reject full-module pickles. Confirm
# against the installed torch version.
#
# map_location='cpu': predict()/predict_conv() build their inputs as CPU
# tensors (torch.from_numpy), so the weights must live on the CPU even if a
# checkpoint was saved from a CUDA session.
net = torch.load('mnist.pth', map_location='cpu')
net.eval()  # inference mode (affects dropout / batch-norm layers, if any)

net_conv = torch.load('mnist_conv.pth', map_location='cpu')
net_conv.eval()

def predict(img, model=None):
    """Classify a hand-drawn digit with the fully-connected model.

    Args:
        img: The drawing as an array-like of pixel values, assumed to be in
            [0, 255] (they are scaled to [0, 1] here), or ``None`` when
            nothing has been drawn. Presumably a 2-D grayscale array, since
            only a batch dimension is added below.
        model: Optional classifier to use instead of the module-level
            ``net``; mainly useful for testing.

    Returns:
        A two-element list of strings holding the class indices with the
        highest and second-highest scores, or ``['', '']`` when ``img`` is
        ``None``.
    """
    if img is None:
        # Empty canvas: blank both labels instead of crashing on
        # ``np.array(None) / 255``.
        return ['', '']
    model = net if model is None else model

    arr = np.array(img) / 255  # Assuming img is in the range [0, 255]
    arr = np.expand_dims(arr, axis=0)  # Add batch dimension
    batch = torch.from_numpy(arr).float()  # Convert to a float32 tensor

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    _, topk_indices = torch.topk(output, 2)  # Top 2 classes, best first
    return [str(k) for k in topk_indices[0].tolist()]

def predict_conv(img, model=None):
    """Classify a hand-drawn digit with the convolutional model.

    Args:
        img: The drawing as an array-like of pixel values, assumed to be in
            [0, 255] (they are scaled to [0, 1] here), or ``None`` when
            nothing has been drawn. Presumably a 2-D grayscale array, which
            the two expand_dims calls turn into a (1, 1, H, W) batch.
        model: Optional classifier to use instead of the module-level
            ``net_conv``; mainly useful for testing.

    Returns:
        A two-element list of strings holding the class indices with the
        highest and second-highest scores, or ``['', '']`` when ``img`` is
        ``None``.
    """
    if img is None:
        # Empty canvas: blank both labels instead of crashing on
        # ``np.array(None) / 255``.
        return ['', '']
    model = net_conv if model is None else model

    arr = np.array(img) / 255  # Assuming img is in the range [0, 255]
    arr = np.expand_dims(arr, axis=0)  # Conv needs a channel dimension
    arr = np.expand_dims(arr, axis=0)  # Add batch dimension
    batch = torch.from_numpy(arr).float()  # Convert to a float32 tensor

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    _, topk_indices = torch.topk(output, 2)  # Top 2 classes, best first
    return [str(k) for k in topk_indices[0].tolist()]





with gr.Blocks() as iface:
    gr.Markdown("# MNIST + Gradio End to End")
    gr.HTML("Shows end to end MNIST training with Gradio interface")

    # Tab 1: draw a digit and classify it with the fully-connected model.
    with gr.Tab("Linear Model"):
        with gr.Row():
            with gr.Column():
                sp = gr.Sketchpad(shape=(28, 28))
                with gr.Row():
                    with gr.Column():
                        pred_button = gr.Button("Predict")
                    with gr.Column():
                        clear_button = gr.Button("Clear")
            with gr.Column():
                label1 = gr.Label(label='1st Pred')
                label2 = gr.Label(label='2nd Pred')

    # Tab 2: same layout, classified with the convolutional model.
    with gr.Tab("Convolution Model"):
        with gr.Row():
            with gr.Column():
                sp_conv = gr.Sketchpad(shape=(28, 28))
                with gr.Row():
                    with gr.Column():
                        pred_conv_button = gr.Button("Predict")
                    with gr.Column():
                        clear_button_conv = gr.Button("Clear")
            with gr.Column():
                label1_conv = gr.Label(label='1st Pred')
                label2_conv = gr.Label(label='2nd Pred')

    def clear():
        """Return the reset values for one tab.

        The list is ordered to match ``[1st label, 2nd label, sketchpad]``:
        both prediction labels are blanked and the sketchpad is emptied.
        """
        return ['', '', None]

    pred_button.click(predict, inputs=sp, outputs=[label1, label2])
    pred_conv_button.click(predict_conv, inputs=sp_conv,
                           outputs=[label1_conv, label2_conv])
    # Both Clear buttons share one handler; queue=False runs the reset
    # outside the request queue.
    clear_button.click(clear, None, [label1, label2, sp], queue=False)
    clear_button_conv.click(clear, None, [label1_conv, label2_conv, sp_conv],
                            queue=False)


iface.launch()