equ1 commited on
Commit
d14bc60
1 Parent(s): d767e62

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torch.nn.functional as F
5
+ import gradio as gr
6
+ from urllib.request import urlretrieve
7
+
8
+ # Loads latest model state from Github
9
+ urlretrieve("https://github.com/equ1/mnist-interface/tree/main/demo_model.pt", "demo_model.pt")
10
+
11
+ if torch.cuda.is_available():
12
+ dev = "cuda:0"
13
+ else:
14
+ dev = "cpu"
15
+
16
+ device = torch.device(dev)
17
+
18
+ model = torch.load(f"./demo_model.pt", map_location=device)
19
+ model.eval()
20
+
21
+ # inference function
22
+ def inference(img):
23
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((28, 28))])
24
+ img = transform(img).unsqueeze(0) # transforms ndarray and adds batch dimension
25
+
26
+ with torch.no_grad():
27
+ output_probabilities = F.softmax(model(img), dim=1)[0] # probability prediction for each label
28
+
29
+ return {labels[i]: float(output_probabilities[i]) for i in range(len(labels))}
30
+
31
+ # Creates and launches gradio interface
32
+ labels = range(10) # 1-9 labels
33
+ outputs = gr.outputs.Label(num_top_classes=5)
34
+ gr.Interface(fn=inference, inputs='sketchpad', outputs=outputs, title="MNIST Interface",
35
+ description="Draw a number from 0-9 in the box and click submit to see the model's predictions.").launch()