# Captured from a Hugging Face Space whose status was shown as "Runtime error".
import gradio as gr | |
import pickle | |
import torch | |
import numpy as np | |
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the whole pickled model object (it is later called as `nn(x)`, so this
# is not a bare state_dict). Tensors are first mapped to the CPU so a checkpoint
# saved on a GPU machine still loads on a CPU-only host, then moved to `device`.
# weights_only=False is required for full-module checkpoints on PyTorch >= 2.6,
# where torch.load defaults to weights_only=True and rejects arbitrary pickled
# classes. Only do this for trusted files: unpickling can execute code.
# NOTE(review): the `weights_only` keyword itself needs PyTorch >= 1.13.
# NOTE(review): `nn` shadows the conventional `torch.nn` alias; kept because
# `predict` refers to this global by name.
with open('cnn_model_9961_sm.bin', 'rb') as f:
    nn = torch.load(f, map_location=torch.device('cpu'), weights_only=False)
nn.to(device)
# Inference only: switch dropout / batch-norm layers (if the model has any)
# to their deterministic evaluation behaviour.
nn.eval()
def predict(input):
    """Run the loaded model on one sketch and return its per-index scores.

    Args:
        input: The Sketchpad value, or ``None`` when there is nothing to
            classify. Presumably a 28x28 array, since the Sketchpad is
            configured with ``shape=(28, 28)``. TODO: confirm the dtype and
            layout against the installed Gradio version. (The parameter name
            shadows the builtin ``input``; it is kept for interface stability.)

    Returns:
        The string ``'None'`` when ``input`` is ``None``. Otherwise, a dict that
        maps each index of the model's first output row (0, 1, ...) to its
        value as a Python float, in the form ``gr.Label`` accepts.
        NOTE(review): this assumes the model outputs one score or probability
        per digit class; confirm this against the training code.
    """
    if input is None:
        return 'None'
    # Wrapping in two extra lists prepends two size-1 axes. For a 2-D image this
    # gives (1, 1, H, W), i.e. a batch of one single-channel image.
    x = np.array([[input]])
    x = torch.tensor(x).to(device)
    # Inference only: disabling autograd avoids building a graph on every live
    # update. Outputs then carry no grad, so no `.detach()` is needed.
    with torch.no_grad():
        p = nn(x)
    # Take the only batch row and move it to host memory for conversion.
    p = p[0].cpu().numpy()
    return dict(enumerate(p.tolist()))
# Build the web UI around `predict` and start the server.
# `demo` holds the Interface object itself rather than the return value of
# `launch()`. Tools such as Gradio's reload mode look up a module-level `demo`
# and need the app object.
demo = gr.Interface(
    fn=predict,
    title='ConvNet for handwritten digits classification',
    # Absolute URL: a bare "google.com" would be resolved relative to the app's
    # own host and produce a broken link.
    description='Check this out <a href="https://www.google.com">Google</a>',
    inputs=[
        gr.Sketchpad(
            shape=(28, 28),
            brush_radius=1.2,
        )
    ],
    outputs=[
        gr.Label(
            # Show only the three highest-scoring classes.
            num_top_classes=3,
            scale=3,
        )
    ],
    # Re-run `predict` whenever the sketch changes instead of waiting for a
    # Submit click.
    live=True,
    allow_flagging='never',
)
demo.launch()