equ1 commited on
Commit
b6b005d
1 Parent(s): f87147a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torch.nn.functional as F
5
+ import gradio as gr
6
+ from urllib.request import urlretrieve
7
+ from model import Net
8
+
9
+ # Loads latest model state from Github
10
+ urlretrieve("https://github.com/equ1/mnist-interface/tree/main/saved_models")
11
+
12
+ model_timestamps = [filename[10:-3]
13
+ for filename in os.listdir("./saved_models")]
14
+ latest_timestamp = max(model_timestamps)
15
+
16
+ if torch.cuda.is_available():
17
+ dev = "cuda:0"
18
+ else:
19
+ dev = "cpu"
20
+
21
+ device = torch.device(dev)
22
+
23
+ model = torch.load(f"./saved_models/mnist-cnn-{latest_timestamp}.pt", map_location=device)
24
+ model.eval()
25
+
26
+ # inference function
27
+ def inference(img):
28
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((28, 28))])
29
+ img = transform(img).unsqueeze(0) # transforms ndarray and adds batch dimension
30
+
31
+ with torch.no_grad():
32
+ output_probabilities = F.softmax(model(img), dim=1)[0] # probability prediction for each label
33
+
34
+ return {labels[i]: float(output_probabilities[i]) for i in range(len(labels))}
35
+
36
+ # Creates and launches gradio interface
37
+ labels = range(10) # 1-9 labels
38
+ outputs = gr.outputs.Label(num_top_classes=5)
39
+ gr.Interface(fn=inference, inputs='sketchpad', outputs=outputs, title="MNIST interface",
40
+ description="Draw a number from 0-9 in the box and click submit to see the model's predictions.").launch()