ffcm's picture
updates desc
908930d
import gradio as gr
import torch
import numpy as np
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the whole pickled model (it is later moved with .to() and called
# directly, so it is presumably an nn.Module rather than a state_dict).
# NOTE(review): torch.load unpickles arbitrary objects; that is acceptable
# only because cnn_model.bin ships with this app and is trusted. On newer
# torch releases torch.load defaults to weights_only=True, which rejects a
# pickled module -- pass weights_only=False there if loading fails.
with open('cnn_model.bin', 'rb') as f:
    nn = torch.load(f, map_location=torch.device('cpu'))
nn.to(device)
# Put the model in inference mode so layers such as dropout / batch-norm
# (if the model has any) behave deterministically rather than as in training.
nn.eval()
def predict(input):
    """Classify a sketch drawn on the Gradio sketchpad.

    Args:
        input: The sketchpad image, or ``None`` when nothing is drawn.
            NOTE(review): with ``shape=(28, 28)`` Gradio 3.x presumably hands
            over a 2-D numpy array; confirm its dtype/scaling match what the
            model was trained on.

    Returns:
        The string ``'None'`` when there is no input. Otherwise, a dict that
        maps each output index (presumably the digit 0-9) to the model's raw
        score for it. Whether those scores are logits or probabilities depends
        on the model's final layer, which is not visible here.
    """
    # `input` shadows the builtin, but the parameter name is kept so any
    # keyword callers keep working.
    if input is None:
        return 'None'
    # Wrap twice to add batch and channel axes (for a 2-D image: H x W -> 1 x 1 x H x W).
    batch = torch.tensor(np.array([[input]])).to(device)
    # Inference only: don't build an autograd graph on every live update.
    with torch.no_grad():
        scores = nn(batch)
    return dict(enumerate(scores[0].cpu().numpy().tolist()))
# Description text handed to gr.Interface for the app page; one sentence per
# line, with a trailing newline.
desc = (
    "This project uses a Convolutional Neural Network to classify handwritten digits.\n"
    "Trained on the MNIST dataset.\n"
    "Use most of the drawing area for better results.\n"
)
# Build the app and keep a handle on the Interface itself. launch() returns
# server information rather than the app, so it is called separately below.
# This lets `demo` refer to the UI, which is what tooling such as Gradio's
# reload mode looks for.
demo = gr.Interface(
    fn=predict,
    title='ConvNet for handwritten digits classification',
    description=desc,
    inputs=[
        # The sketch is resized to 28x28 (MNIST's image size) before it is
        # passed to predict.
        gr.Sketchpad(
            shape=(28, 28),
            brush_radius=1.2,
        )
    ],
    outputs=[
        # Display the three highest-scoring classes.
        gr.Label(
            num_top_classes=3,
            scale=3,
        )
    ],
    # Re-run predict whenever the sketch changes.
    live=True,
    allow_flagging='never',
)
demo.launch()