"""Gradio demo: classify hand-drawn digits with a pre-trained CNN.

The model is unpickled from ``cnn_model.bin`` and served through a live
``gr.Interface`` whose input is a 28x28 sketchpad.
"""
import gradio as gr
import torch
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The checkpoint is a whole pickled module (it is called directly below), not a
# state_dict, so it has to be unpickled. torch>=2.6 defaults to
# ``weights_only=True``, which rejects such files, so opt out explicitly.
# Unpickling can execute arbitrary code: only ship checkpoints you trust.
with open('cnn_model.bin', 'rb') as f:
    nn = torch.load(f, map_location=torch.device('cpu'), weights_only=False)
nn.to(device)
# Inference only: make dropout / batch-norm layers (if the model has any)
# behave deterministically instead of in training mode.
nn.eval()


def predict(input):
    """Return the model's per-class scores for a sketchpad drawing.

    Args:
        input: The sketchpad value, or ``None`` when the canvas is empty.
            NOTE(review): presumably a 2-D (28, 28) array given
            ``gr.Sketchpad(shape=(28, 28))``; the dtype/value range the model
            expects is not visible here — TODO confirm the model normalizes it.

    Returns:
        The string ``'None'`` for an empty canvas; otherwise a dict mapping
        each class index (position in the model's output row) to its score.
    """
    if input is None:
        return 'None'
    # Wrap twice to add batch and channel axes, e.g. (H, W) -> (1, 1, H, W).
    x = torch.tensor(np.array([[input]])).to(device)
    # No autograd graph is needed; this runs on every stroke (live=True).
    with torch.no_grad():
        scores = nn(x)
    # NOTE(review): scores are displayed as-is; if the model returns raw
    # logits rather than probabilities, the Label's confidences are not in [0, 1].
    return dict(enumerate(scores[0].cpu().numpy().tolist()))


desc = """\
This project uses a Convolutional Neural Network to classify handwritten digits.
Trained on the MNIST dataset.
Use most of the drawing area for better results.
"""

# Keep a reference to the Interface itself (not the return value of launch())
# so tooling that imports this module can find ``demo``.
demo = gr.Interface(
    fn=predict,
    title='ConvNet for handwritten digits classification',
    description=desc,
    inputs=[
        gr.Sketchpad(
            shape=(28, 28),
            brush_radius=1.2,
        )
    ],
    outputs=[
        gr.Label(
            num_top_classes=3,
            scale=3,
        )
    ],
    live=True,
    allow_flagging='never',
)

if __name__ == '__main__':
    demo.launch()