mnist / app.py
gaviego's picture
clear working
a8cceae
raw
history blame
2.65 kB
import gradio as gr
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from models import Net,NetConv
# Restore the two pre-trained classifiers and put them in inference mode
# (eval() disables training-only behaviour such as dropout / batch-norm updates).
#
# map_location='cpu' keeps the weights on the CPU regardless of the device the
# checkpoints were saved from; the predict functions feed CPU tensors built
# with torch.from_numpy, so model and input must live on the same device.
#
# NOTE(review): the loaded objects are used directly as models, so these files
# appear to hold whole pickled modules rather than state dicts. That means the
# classes imported from `models` must stay importable, and PyTorch releases
# that default to weights_only=True will need weights_only=False here --
# confirm against the pinned torch version. Only load trusted checkpoints:
# unpickling can execute arbitrary code.
net = torch.load('mnist.pth', map_location='cpu')
net.eval()
net_conv = torch.load('mnist_conv.pth', map_location='cpu')
net_conv.eval()
def predict(img):
    """Return the two most likely classes for a sketch using the linear model.

    Args:
        img: Image-like value accepted by ``np.array``; pixel values are
            assumed to be in the range [0, 255]. ``None`` is accepted and
            means "nothing to classify".

    Returns:
        A two-element list of strings holding the top-1 and top-2 class
        indices, or ``['', '']`` when ``img`` is ``None``.
    """
    if img is None:
        # The Clear handler resets the sketchpad to None, and
        # np.array(None) / 255 raises TypeError. Return the same blank label
        # values the Clear handler uses instead of crashing.
        return ['', '']

    arr = np.array(img) / 255  # scale to [0, 1]
    arr = np.expand_dims(arr, axis=0)  # add batch dimension
    arr = torch.from_numpy(arr).float()  # numpy yields float64; cast to float32

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = net(arr)

    _, topk_indices = torch.topk(output, 2)  # indices of the 2 highest scores
    return [str(k) for k in topk_indices[0].tolist()]
def predict_conv(img):
    """Return the two most likely classes for a sketch using the conv model.

    Args:
        img: Image-like value accepted by ``np.array``; pixel values are
            assumed to be in the range [0, 255]. ``None`` is accepted and
            means "nothing to classify".

    Returns:
        A two-element list of strings holding the top-1 and top-2 class
        indices, or ``['', '']`` when ``img`` is ``None``.
    """
    if img is None:
        # The Clear handler resets the sketchpad to None, and
        # np.array(None) / 255 raises TypeError. Return the same blank label
        # values the Clear handler uses instead of crashing.
        return ['', '']

    arr = np.array(img) / 255  # scale to [0, 1]
    arr = np.expand_dims(arr, axis=0)  # extra leading axis the conv model needs
    arr = np.expand_dims(arr, axis=0)  # add batch dimension
    arr = torch.from_numpy(arr).float()  # numpy yields float64; cast to float32

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = net_conv(arr)

    _, topk_indices = torch.topk(output, 2)  # indices of the 2 highest scores
    return [str(k) for k in topk_indices[0].tolist()]
# Build the demo UI: one tab per model, each with a 28x28 sketchpad, Predict
# and Clear buttons, and two labels for the top-2 predictions.
# NOTE(review): the Row/Column nesting below is the most plausible reading of
# the statement order -- confirm against the rendered layout.
with gr.Blocks() as iface:
    gr.Markdown("# MNIST + Gradio End to End")
    gr.HTML("Shows end to end MNIST training with Gradio interface")
    # Tab wired to predict().
    with gr.Tab("Linear Model"):
        with gr.Row():
            with gr.Column():
                sp = gr.Sketchpad(shape=(28, 28))
                with gr.Row():
                    with gr.Column():
                        pred_button = gr.Button("Predict")
                    with gr.Column():
                        clear_button = gr.Button("Clear")
            with gr.Column():
                label1 = gr.Label(label='1st Pred')
                label2 = gr.Label(label='2nd Pred')
    # Tab wired to predict_conv(); same widgets as the tab above.
    with gr.Tab("Convolution Model"):
        with gr.Row():
            with gr.Column():
                sp_conv = gr.Sketchpad(shape=(28, 28))
                with gr.Row():
                    with gr.Column():
                        pred_conv_button = gr.Button("Predict")
                    with gr.Column():
                        clear_button_conv = gr.Button("Clear")
            with gr.Column():
                label1_conv = gr.Label(label='1st Pred')
                label2_conv = gr.Label(label='2nd Pred')
    # NOTE(review): clear() is not referenced anywhere in the visible code; the
    # Clear buttons below use inline lambdas instead. Its six-element return
    # does not match either three-element output list -- likely a leftover.
    def clear():
        return ['','',None,'','',None]
    # Predict: pass the sketch to the model function; the two returned strings
    # fill the 1st/2nd prediction labels in order.
    pred_button.click(predict, inputs=sp, outputs=[label1,label2])
    pred_conv_button.click(predict_conv, inputs=sp_conv, outputs=[label1_conv,label2_conv])
    # Clear: blank both labels and reset that tab's sketchpad to None.
    clear_button.click( lambda: ['','',None], None, [label1,label2,sp,], queue=False)
    clear_button_conv.click( lambda: ['','',None], None, [label1_conv,label2_conv, sp_conv], queue=False)
# Start the Gradio server for the app defined above.
iface.launch()