Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the CNN. The checkpoint is first mapped to CPU memory so it loads on
# machines without CUDA, then moved to `device`.
# `nn` is later called directly as a model (see `predict`), so the file holds a
# whole pickled module rather than a bare state_dict. torch>=2.6 defaults to
# weights_only=True, which refuses to unpickle such objects, hence the explicit
# weights_only=False (the parameter itself needs torch>=1.13).
# SECURITY: weights_only=False executes arbitrary pickle code on load — only
# ship/load checkpoints from a trusted source.
with open('cnn_model.bin', 'rb') as f:
    nn = torch.load(f, map_location=torch.device('cpu'), weights_only=False)
nn.to(device)
# Inference only: put layers such as dropout / batch-norm (if the model has
# any) into evaluation behaviour instead of training behaviour.
nn.eval()
def predict(input):
    """Classify one sketch with the loaded CNN and return per-class scores.

    Args:
        input: The value produced by the Sketchpad input component, or None
            when there is nothing to classify. Presumably a 2-D pixel array
            (28x28, given ``shape=(28, 28)`` on the component) — TODO confirm
            against the installed gradio version.

    Returns:
        The string ``'None'`` when ``input`` is None. Otherwise a dict mapping
        each output index of the model (0, 1, ...) to its score, in the form
        ``gr.Label`` accepts. Whether the scores are probabilities or raw
        logits depends on the saved model — not visible here.
    """
    # The live interface can invoke the callback with no drawing.
    if input is None:
        return 'None'
    # Add two leading dimensions: for a 2-D image this yields (1, 1, H, W),
    # i.e. a batch of one single-channel image.
    x = torch.tensor(np.array([[input]])).to(device)
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        p = nn(x)
    # Take the single batch row and bring it back to host memory.
    scores = p[0].cpu().detach().numpy()
    return dict(enumerate(scores.tolist()))
# Description text passed to the gradio Interface (one sentence per line).
desc = (
    "This project uses a Convolutional Neural Network to classify handwritten digits.\n"
    "Trained on the MNIST dataset.\n"
    "Use most of the drawing area for better results.\n"
)
# Build the web UI: `predict` is re-run whenever the drawing changes
# (live=True), and the Label output lists the three highest-scoring classes.
# NOTE(review): `shape` and `brush_radius` are gradio 3.x Sketchpad arguments;
# gradio 4 removed them (and changed the Sketchpad value to a dict), so this
# app presumably needs gradio<4 — confirm the pinned version in requirements.
demo = gr.Interface(
    fn=predict,
    title='ConvNet for handwritten digits classification',
    description=desc,
    inputs=[
        gr.Sketchpad(
            shape=(28, 28),  # presumably resizes the drawing to MNIST's 28x28
            brush_radius=1.2,
        )
    ],
    outputs=[
        gr.Label(
            num_top_classes=3,
            scale=3,
        )
    ],
    live=True,
    allow_flagging='never',
)
# Launch as a separate statement so `demo` stays bound to the Interface
# itself; chaining `.launch()` into the assignment stored launch()'s return
# value in `demo` instead of the app object.
demo.launch()