a10 commited on
Commit
f0841e3
1 Parent(s): 1e8608f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -16
app.py CHANGED
@@ -9,9 +9,15 @@
9
 
10
  # Setting up the Sketch Recognition Model
11
 
 
 
12
  import torch
 
13
  from torch import nn
14
 
 
 
 
15
  model = nn.Sequential(
16
  nn.Conv2d(1, 32, 3, padding='same'),
17
  nn.ReLU(),
@@ -27,30 +33,22 @@ model = nn.Sequential(
27
  nn.ReLU(),
28
  nn.Linear(256, len(LABELS)),
29
  )
30
- state_dict = torch.load('pytorch_model.bin', map_location='cpu')
31
  model.load_state_dict(state_dict, strict=False)
32
  model.eval()
33
 
34
- # Defining a predict function
 
35
 
36
- from pathlib import Path
37
-
38
- LABELS = Path('class_names.txt').read_text().splitlines()
39
-
40
- def predict(img):
41
- x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
42
  with torch.no_grad():
43
  out = model(x)
 
44
  probabilities = torch.nn.functional.softmax(out[0], dim=0)
 
45
  values, indices = torch.topk(probabilities, 5)
46
- confidences = {LABELS[i]: v.item() for i, v in zip(indices, values)}
47
- return confidences
48
 
49
- # Creating a Gradio Interface
50
 
51
- import gradio as gr
52
 
53
- gr.Interface(fn=predict,
54
- inputs="sketchpad",
55
- outputs="label",
56
- live=True).launch()
 
9
 
10
  # Setting up the Sketch Recognition Model
11
 
12
+ from pathlib import Path
13
+
14
  import torch
15
+ import gradio as gr
16
  from torch import nn
17
 
18
+
19
+ LABELS = Path('class_names.txt').read_text().splitlines()
20
+
21
  model = nn.Sequential(
22
  nn.Conv2d(1, 32, 3, padding='same'),
23
  nn.ReLU(),
 
33
  nn.ReLU(),
34
  nn.Linear(256, len(LABELS)),
35
  )
36
+ state_dict = torch.load('pytorch_model.bin', map_location='cpu')
37
  model.load_state_dict(state_dict, strict=False)
38
  model.eval()
39
 
40
+ def predict(im):
41
+ x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
42
 
 
 
 
 
 
 
43
  with torch.no_grad():
44
  out = model(x)
45
+
46
  probabilities = torch.nn.functional.softmax(out[0], dim=0)
47
+
48
  values, indices = torch.topk(probabilities, 5)
 
 
49
 
50
+ return {LABELS[i]: v.item() for i, v in zip(indices, values)}
51
 
 
52
 
53
+ interface = gr.Interface(predict, inputs='sketchpad', outputs='label', live=True)
54
+ interface.launch(debug=True)