equ1 commited on
Commit
d7c377f
1 Parent(s): 17984df

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torch.nn.functional as F
5
+ import gradio as gr
6
+
7
+ from model import Net
8
+
9
+ # loads demo model
10
+ if torch.cuda.is_available():
11
+ dev = "cuda:0"
12
+ else:
13
+ dev = "cpu"
14
+
15
+ device = torch.device(dev)
16
+
17
+ model = torch.load(f"./demo_model.pt", map_location=device)
18
+
19
+ model.eval()
20
+
21
+ # inference function
22
+ def inference(img):
23
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((28, 28))])
24
+ img = transform(img).unsqueeze(0) # transforms ndarray and adds batch dimension
25
+
26
+ with torch.no_grad():
27
+ output_probabilities = F.softmax(model(img), dim=1)[0] # probability prediction for each label
28
+
29
+ return {labels[i]: float(output_probabilities[i]) for i in range(len(labels))}
30
+
31
+ # Creates and launches gradio interface
32
+ labels = range(10) # 1-9 labels
33
+ outputs = gr.outputs.Label(num_top_classes=5)
34
+ gr.Interface(fn=inference, inputs='sketchpad', outputs=outputs, title="MNIST Interface",
35
+ description="Draw a number from 0-9 in the box and click submit to see the model's predictions.").launch()