ffcm's picture
updated description
00e0338
raw
history blame
No virus
885 Bytes
import gradio as gr
import pickle
import torch
import numpy as np
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The checkpoint appears to hold a complete model object (it is called
# directly in predict), not a state_dict. It is loaded onto the CPU first and
# then moved to `device`.
# NOTE: weights_only=False is required to load a full-model pickle on
# torch >= 2.6, where the default became True. It un-pickles arbitrary
# objects, so only ever point this at a trusted checkpoint file.
with open('cnn_model.bin', 'rb') as f:
    nn = torch.load(f, map_location=torch.device('cpu'), weights_only=False)
nn.to(device)
# Loaded models keep whatever train/eval mode they were saved in; force
# inference behaviour for dropout / batch-norm layers (if the model has any).
nn.eval()
def predict(input):
    """Classify a hand-drawn digit with the loaded model.

    Args:
        input: The Sketchpad value, or ``None`` when there is nothing to
            classify. Presumably a 28x28 array given ``shape=(28, 28)`` on the
            Sketchpad (TODO confirm against the installed Gradio version).

    Returns:
        The string ``'None'`` when ``input`` is ``None``; otherwise a dict
        mapping each index of the model's first output row to its value, the
        format ``gr.Label`` consumes. Whether those values are probabilities
        or raw logits depends on the model, which is not visible here.
    """
    if input is None:
        return 'None'
    # Prepend two leading axes: (H, W) -> (1, 1, H, W), presumably the
    # batch and channel dimensions the CNN expects.
    x = torch.tensor(np.array([[input]])).to(device)
    # Pure inference: don't build an autograd graph on every live update.
    with torch.no_grad():
        p = nn(x)
    scores = p[0].cpu().detach().numpy()
    return dict(enumerate(scores.tolist()))
demo = gr.Interface(
    fn=predict,
    title='ConvNet for handwritten digits classification',
    description='Created with PyTorch.',
    inputs=[
        # `shape` asks Gradio (3.x semantics) to resize the sketch to 28x28
        # before it reaches predict — presumably the model's input size.
        gr.Sketchpad(
            shape=(28, 28),
            brush_radius=1.2,
        )
    ],
    outputs=[
        gr.Label(
            num_top_classes=3,
            scale=3,
        )
    ],
    live=True,  # re-run predict whenever the drawing changes
    allow_flagging='never',
)
# Keep `demo` bound to the Interface itself rather than to launch()'s return
# value, so tools that look the app up by name (e.g. `gradio app.py` reload
# mode) find the right object.
demo.launch()