jingwora commited on
Commit
7f933e9
·
1 Parent(s): 0c99b1a

Add application file

Browse files
Files changed (4) hide show
  1. app.py +49 -0
  2. class_names.txt +100 -0
  3. pytorch_model.bin +3 -0
  4. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+ import torch
4
+ import gradio as gr
5
+ from torch import nn
6
+
7
+ LABELS = Path("class_names.txt").read_text().splitlines()
8
+
9
+ model = nn.Sequential(
10
+ nn.Conv2d(1, 32, 3, padding="same"),
11
+ nn.ReLU(),
12
+ nn.MaxPool2d(2),
13
+ nn.Conv2d(32, 64, 3, padding="same"),
14
+ nn.ReLU(),
15
+ nn.MaxPool2d(2),
16
+ nn.Conv2d(64, 128, 3, padding="same"),
17
+ nn.ReLU(),
18
+ nn.MaxPool2d(2),
19
+ nn.Flatten(),
20
+ nn.Linear(1152, 256),
21
+ nn.ReLU(),
22
+ nn.Linear(256, len(LABELS)),
23
+ )
24
+ state_dict = torch.load("pytorch_model.bin", map_location="cpu")
25
+ model.load_state_dict(state_dict, strict=False)
26
+ model.eval()
27
+
28
+
29
+ def predict(im):
30
+ x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
31
+ with torch.no_grad():
32
+ out = model(x)
33
+ probabilities = torch.nn.functional.softmax(out[0], dim=0)
34
+ values, indices = torch.topk(probabilities, 5)
35
+ return {LABELS[i]: v.item() for i, v in zip(indices, values)}
36
+
37
+
38
+ demo = gr.Interface(
39
+ predict,
40
+ inputs="sketchpad",
41
+ outputs="label",
42
+ theme="freddyaboulton/dracula_revamped",
43
+ title="Sketch Recognition",
44
+ description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
45
+ article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
46
+ live=True,
47
+ )
48
+
49
+ demo.launch()
class_names.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ airplane
2
+ alarm_clock
3
+ anvil
4
+ apple
5
+ axe
6
+ baseball
7
+ baseball_bat
8
+ basketball
9
+ beard
10
+ bed
11
+ bench
12
+ bicycle
13
+ bird
14
+ book
15
+ bread
16
+ bridge
17
+ broom
18
+ butterfly
19
+ camera
20
+ candle
21
+ car
22
+ cat
23
+ ceiling_fan
24
+ cell_phone
25
+ chair
26
+ circle
27
+ clock
28
+ cloud
29
+ coffee_cup
30
+ cookie
31
+ cup
32
+ diving_board
33
+ donut
34
+ door
35
+ drums
36
+ dumbbell
37
+ envelope
38
+ eye
39
+ eyeglasses
40
+ face
41
+ fan
42
+ flower
43
+ frying_pan
44
+ grapes
45
+ hammer
46
+ hat
47
+ headphones
48
+ helmet
49
+ hot_dog
50
+ ice_cream
51
+ key
52
+ knife
53
+ ladder
54
+ laptop
55
+ light_bulb
56
+ lightning
57
+ line
58
+ lollipop
59
+ microphone
60
+ moon
61
+ mountain
62
+ moustache
63
+ mushroom
64
+ pants
65
+ paper_clip
66
+ pencil
67
+ pillow
68
+ pizza
69
+ power_outlet
70
+ radio
71
+ rainbow
72
+ rifle
73
+ saw
74
+ scissors
75
+ screwdriver
76
+ shorts
77
+ shovel
78
+ smiley_face
79
+ snake
80
+ sock
81
+ spider
82
+ spoon
83
+ square
84
+ star
85
+ stop_sign
86
+ suitcase
87
+ sun
88
+ sword
89
+ syringe
90
+ t-shirt
91
+ table
92
+ tennis_racquet
93
+ tent
94
+ tooth
95
+ traffic_light
96
+ tree
97
+ triangle
98
+ umbrella
99
+ wheel
100
+ wristwatch
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:effb6ea6f1593c09e8247944028ed9c309b5ff1cef82ba38b822bee2ca4d0f3c
3
+ size 1656903
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+ gradio==3.36.1
3
+ torch==2.0.1