equ1 commited on
Commit
7ef1a1b
1 Parent(s): 016194a

Upload interface.py

Browse files
Files changed (1) hide show
  1. interface.py +41 -0
interface.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torch.nn.functional as F
5
+ import gradio as gr
6
+ from urllib.request import urlretrieve
7
+ from model import Net
8
+
9
+ # Loads latest model state from Github
10
+ urlretrieve("https://github.com/equ1/mnist-interface/tree/main/saved_models")
11
+
12
+ model_timestamps = [filename[10:-3]
13
+ for filename in os.listdir("./saved_models")]
14
+ latest_timestamp = max(model_timestamps)
15
+
16
+ if torch.cuda.is_available():
17
+ dev = "cuda:0"
18
+ else:
19
+ dev = "cpu"
20
+
21
+ device = torch.device(dev)
22
+
23
+ model = Net()
24
+ model.load_state_dict(torch.load(f"./saved_models/mnist-cnn-{latest_timestamp}.pt", map_location=device))
25
+ model.eval()
26
+
27
+ # inference function
28
+ def inference(img):
29
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((28, 28))])
30
+ img = transform(img).unsqueeze(0) # transforms ndarray and adds batch dimension
31
+
32
+ with torch.no_grad():
33
+ output_probabilities = F.softmax(model(img), dim=1)[0] # probability prediction for each label
34
+
35
+ return {labels[i]: float(output_probabilities[i]) for i in range(len(labels))}
36
+
37
+ # Creates and launches gradio interface
38
+ labels = range(10) # 1-9 labels
39
+ outputs = gr.outputs.Label(num_top_classes=5)
40
+ gr.Interface(fn=inference, inputs='sketchpad', outputs=outputs, title="MNIST interface",
41
+ description="Draw a number from 0-9 in the box and click submit to see the model's predictions.").launch()