Spaces:
Runtime error
Runtime error
File size: 1,003 Bytes
35cfa5c 9229393 d2b6714 35cfa5c 908930d 35cfa5c d2b6714 908930d 35cfa5c d2b6714 35cfa5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import gradio as gr
import torch
import numpy as np
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pickled model object from disk. Deserialising onto the CPU first
# keeps loading working on CPU-only hosts even if the file was saved from a
# GPU process; the model is then moved to `device`.
#
# The file holds a whole module object (it is later called and `.to()`-ed),
# not a plain state_dict. Since torch 2.6 `torch.load` defaults to
# `weights_only=True`, which refuses to unpickle arbitrary classes, so the
# flag must be passed explicitly (the parameter exists since torch 1.13).
# SECURITY: this runs pickle on the file's contents -- acceptable only because
# `cnn_model.bin` ships with this app; never point it at untrusted data.
with open('cnn_model.bin', 'rb') as f:
    nn = torch.load(f, map_location=torch.device('cpu'), weights_only=False)
nn.to(device)
# Inference only: disable training-time behaviour (e.g. dropout, batch-norm
# running-stat updates) if the architecture has any, so predictions are
# deterministic.
nn.eval()
def predict(input):
    """Classify a sketched digit with the module-level CNN `nn`.

    Args:
        input: The Sketchpad value, or None when there is nothing to classify.
            Presumably a 2-D (28, 28) array, since the Sketchpad is created
            with shape=(28, 28). NOTE(review): the dtype/value range is passed
            to the model unchanged -- confirm it matches what `nn` expects.

    Returns:
        The string 'None' when `input` is None; otherwise a dict mapping each
        output index (0, 1, ...) to the model's score at that index, in the
        form gr.Label accepts.
    """
    if input is None:
        return 'None'
    # Add two leading axes: (H, W) -> (1, 1, H, W), i.e. a batch of one
    # single-channel image (assuming `input` is 2-D).
    x = torch.tensor(np.array([[input]])).to(device)
    # Pure inference: skip autograd graph construction. This function is
    # called on every sketch update, so avoiding the graph saves time and
    # memory without changing the returned values.
    with torch.no_grad():
        p = nn(x)
    # First (only) element of the batch, copied back to host memory.
    scores = p[0].cpu().detach().numpy()
    return dict(enumerate(scores.tolist()))
# Help text for the demo page; passed as the Interface `description`.
# One sentence per line, each terminated by a newline (including the last).
desc = (
    "This project uses a Convolutional Neural Network to classify handwritten digits.\n"
    "Trained on the MNIST dataset.\n"
    "Use most of the drawing area for better results.\n"
)
# Build the web UI: a 28x28 sketch canvas feeding `predict`, with the top-3
# predicted classes rendered by a Label component.
# NOTE(review): `gr.Sketchpad(shape=..., brush_radius=...)` is a Gradio 3.x API;
# Gradio 4 rebuilt Sketchpad on ImageEditor (those parameters are gone and the
# value becomes a dict), and `allow_flagging` is deprecated in Gradio 5. Either
# pin gradio<4 or migrate this block and `predict` together.
demo = gr.Interface(
    fn=predict,
    title='ConvNet for handwritten digits classification',
    description=desc,
    inputs=[
        gr.Sketchpad(
            shape=(28, 28),
            brush_radius=1.2,
        )
    ],
    outputs=[
        gr.Label(
            num_top_classes=3,
            scale=3,
        )
    ],
    live=True,  # re-run predict whenever the sketch changes (no submit button)
    allow_flagging='never',
)
# Keep `demo` bound to the Interface itself rather than to launch()'s return
# value (an (app, local_url, share_url) tuple), so tooling that looks up the
# `demo` object -- e.g. `gradio app.py` reload mode -- can find the app.
demo.launch()
|