a10 commited on
Commit
1e8608f
1 Parent(s): 2e75cf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -0
app.py CHANGED
@@ -7,6 +7,8 @@
7
  #iface = gr.Interface(fn=greet, inputs="text", outputs="text")
8
  #iface.launch()
9
 
 
 
10
  import torch
11
  from torch import nn
12
 
@@ -28,3 +30,27 @@ model = nn.Sequential(
28
  state_dict = torch.load('pytorch_model.bin', map_location='cpu')
29
  model.load_state_dict(state_dict, strict=False)
30
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  #iface = gr.Interface(fn=greet, inputs="text", outputs="text")
8
  #iface.launch()
9
 
10
+ # Setting up the Sketch Recognition Model
11
+
12
  import torch
13
  from torch import nn
14
 
 
30
  state_dict = torch.load('pytorch_model.bin', map_location='cpu')
31
  model.load_state_dict(state_dict, strict=False)
32
  model.eval()
33
+
34
+ # Defining a predict function
35
+
36
+ from pathlib import Path
37
+
38
+ LABELS = Path('class_names.txt').read_text().splitlines()
39
+
40
+ def predict(img):
41
+ x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
42
+ with torch.no_grad():
43
+ out = model(x)
44
+ probabilities = torch.nn.functional.softmax(out[0], dim=0)
45
+ values, indices = torch.topk(probabilities, 5)
46
+ confidences = {LABELS[i]: v.item() for i, v in zip(indices, values)}
47
+ return confidences
48
+
49
+ # Creating a Gradio Interface
50
+
51
+ import gradio as gr
52
+
53
+ gr.Interface(fn=predict,
54
+ inputs="sketchpad",
55
+ outputs="label",
56
+ live=True).launch()