Tinsae commited on
Commit
fbae273
1 Parent(s): d2713de

created first space

Browse files
Files changed (4) hide show
  1. app.py +46 -0
  2. class_names.txt +100 -0
  3. pytorch_model.bin +3 -0
  4. requirement.txt +2 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ import gradio as gr
4
+ from torch import nn
5
+
6
+ LABELS = Path("class_names.txt").read_text().splitlines()
7
+
8
+ model = nn.Sequential(
9
+ nn.Conv2d(1, 32, 3, padding="same"),
10
+ nn.ReLU(),
11
+ nn.MaxPool2d(2),
12
+ nn.Conv2d(32, 64, 3, padding="same"),
13
+ nn.ReLU(),
14
+ nn.MaxPool2d(2),
15
+ nn.Conv2d(64, 128, 3, padding="same"),
16
+ nn.ReLU(),
17
+ nn.MaxPool2d(2),
18
+ nn.Flatten(),
19
+ nn.Linear(1152, 256),
20
+ nn.ReLU(),
21
+ nn.Linear(256, len(LABELS)),
22
+ )
23
+ state_dict = torch.load("pytorch_model.bin", map_location="cpu")
24
+ model.load_state_dict(state_dict, strict=False)
25
+ model.eval()
26
+
27
+
28
+ def predict(im):
29
+ x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
30
+ with torch.no_grad():
31
+ out = model(x)
32
+ probabilities = torch.nn.functional.softmax(out[0], dim=0)
33
+ values, indices = torch.topk(probabilities, 5)
34
+ return {LABELS[i]: v.item() for i, v in zip(indices, values)}
35
+
36
+ interface = gr.Interface(
37
+ predict,
38
+ inputs="sketchpad",
39
+ outputs="label",
40
+ theme="huggingface",
41
+ title="Sketch Recognition",
42
+ description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
43
+ article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
44
+ live=True,
45
+ )
46
+ interface.launch(share=True)
class_names.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ airplane
2
+ alarm_clock
3
+ anvil
4
+ apple
5
+ axe
6
+ baseball
7
+ baseball_bat
8
+ basketball
9
+ beard
10
+ bed
11
+ bench
12
+ bicycle
13
+ bird
14
+ book
15
+ bread
16
+ bridge
17
+ broom
18
+ butterfly
19
+ camera
20
+ candle
21
+ car
22
+ cat
23
+ ceiling_fan
24
+ cell_phone
25
+ chair
26
+ circle
27
+ clock
28
+ cloud
29
+ coffee_cup
30
+ cookie
31
+ cup
32
+ diving_board
33
+ donut
34
+ door
35
+ drums
36
+ dumbbell
37
+ envelope
38
+ eye
39
+ eyeglasses
40
+ face
41
+ fan
42
+ flower
43
+ frying_pan
44
+ grapes
45
+ hammer
46
+ hat
47
+ headphones
48
+ helmet
49
+ hot_dog
50
+ ice_cream
51
+ key
52
+ knife
53
+ ladder
54
+ laptop
55
+ light_bulb
56
+ lightning
57
+ line
58
+ lollipop
59
+ microphone
60
+ moon
61
+ mountain
62
+ moustache
63
+ mushroom
64
+ pants
65
+ paper_clip
66
+ pencil
67
+ pillow
68
+ pizza
69
+ power_outlet
70
+ radio
71
+ rainbow
72
+ rifle
73
+ saw
74
+ scissors
75
+ screwdriver
76
+ shorts
77
+ shovel
78
+ smiley_face
79
+ snake
80
+ sock
81
+ spider
82
+ spoon
83
+ square
84
+ star
85
+ stop_sign
86
+ suitcase
87
+ sun
88
+ sword
89
+ syringe
90
+ t-shirt
91
+ table
92
+ tennis_racquet
93
+ tent
94
+ tooth
95
+ traffic_light
96
+ tree
97
+ triangle
98
+ umbrella
99
+ wheel
100
+ wristwatch
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:effb6ea6f1593c09e8247944028ed9c309b5ff1cef82ba38b822bee2ca4d0f3c
3
+ size 1656903
requirement.txt ADDED
@@ -0,0 +1,2 @@
 
 
1
+ torch
2
+ gradio