ffcm's picture
Initial commit
35cfa5c
raw
history blame
734 Bytes
import gradio as gr
import pickle
import torch
import numpy as np
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SECURITY: pickle.load can execute arbitrary code stored in the file.
# Only load a checkpoint that comes from a trusted source (presumably this
# one is bundled next to the script -- confirm before swapping the file).
with open('cnn_model_9961_sm.bin', 'rb') as f:
    nn = pickle.load(f)

nn.to(device)
# Switch layers such as dropout / batch-norm to inference behaviour.
# Assumes the unpickled object is a torch.nn.Module (it is already moved
# with .to() and called like one below).
nn.eval()
def predict(input):
    """Run the loaded model on one drawing and return its per-index outputs.

    Args:
        input: The value delivered by the Gradio input component, or
            ``None`` when there is nothing to classify. Presumably a 2-D
            array of pixel values -- confirm against the Sketchpad setup.

    Returns:
        The string ``'None'`` when ``input`` is ``None``; otherwise a dict
        mapping each position in the model's first output row to the value
        at that position.
    """
    if input is None:
        return 'None'

    # Wrap the input in two extra leading axes of length 1
    # (presumably batch and channel -- confirm against the model).
    x = np.array([[input]])
    x = torch.tensor(x).to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        p = nn(x)

    # Take the first row of the output and bring it back to host memory.
    p = p[0].cpu().detach().numpy()
    return dict(enumerate(p.tolist()))
# Build the UI: a 28x28 sketchpad feeding `predict`, with the result shown
# in a label component limited to the top 3 classes. `live=True` asks
# Gradio to re-run the function as the input changes; flagging is disabled.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Sketchpad(
            shape=(28, 28),
            brush_radius=1.2,
        )
    ],
    outputs=[
        gr.Label(
            num_top_classes=3,
            scale=2,
        )
    ],
    live=True,
    allow_flagging='never',
)

# Launch as a separate statement so that `demo` stays bound to the
# Interface object itself rather than to whatever launch() returns.
demo.launch()