thomasinovic commited on
Commit
efc35c0
·
0 Parent(s):

initialize repo

Browse files
Files changed (7) hide show
  1. .gitattributes +35 -0
  2. CNN.py +33 -0
  3. README.md +13 -0
  4. app.py +68 -0
  5. labels.json +1 -0
  6. model_weights.pth +3 -0
  7. requirements.txt +2 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
CNN.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class CNN(nn.Module):
5
+ def __init__(self, n_filters, hidden_dim, n_layers, n_classes):
6
+ super().__init__()
7
+ self.conv1 = nn.Conv2d(1, n_filters, 5)
8
+ self.relu1 = nn.ReLU()
9
+ self.maxpool1 = nn.MaxPool2d(2)
10
+ self.conv2 = nn.Conv2d(n_filters, 2*n_filters, 5)
11
+ self.relu2 = nn.ReLU()
12
+ self.maxpool2 = nn.MaxPool2d(2)
13
+ self.input_dim = 960
14
+ self.flatten = nn.Flatten()
15
+ self.inp_layer = nn.Linear(self.input_dim, hidden_dim)
16
+ self.classifier = nn.ModuleList([
17
+ nn.Sequential(
18
+ nn.Linear(hidden_dim, hidden_dim),
19
+ nn.BatchNorm1d(hidden_dim),
20
+ nn.ReLU(),
21
+ nn.Dropout(p=0.3)
22
+ ) for i in range(n_layers)
23
+ ])
24
+ self.out_layer = nn.Linear(hidden_dim, n_classes)
25
+
26
+ def forward(self, x):
27
+ x = self.maxpool1(self.relu1(self.conv1(x)))
28
+ x = self.maxpool2(self.relu2(self.conv2(x)))
29
+ x = self.inp_layer(torch.flatten(x, start_dim=1))
30
+ for layer in self.classifier:
31
+ x = layer(x)
32
+ x = self.out_layer(x)
33
+ return x
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sketch Recognition
3
+ emoji: 🏢
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.49.1
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: This space uses a CNN to classify drawings.
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as T
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ from CNN import CNN
9
+
10
+ # def greet(name):
11
+ # return "Hello " + name + "!!"
12
+
13
+ # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
14
+ # demo.launch()
15
+
16
+ # Load the model
17
+ n_classes = 345
18
+ params = {
19
+ 'n_filters': 30,
20
+ 'hidden_dim': 100,
21
+ 'n_layers': 2,
22
+ 'n_classes': n_classes
23
+ }
24
+ print('testesesesf')
25
+ model = CNN(**params)
26
+ model.load_state_dict(torch.load('model_weights.ptn', map_location='cpu'))
27
+ model.eval()
28
+
29
+ # utils
30
+ labels_path = 'labels.json'
31
+ with open(labels_path, 'r') as f:
32
+ names = json.load(f)
33
+
34
+ transform = T.Compose([
35
+ T.ToTensor(), # (1, H, W), values in [0, 1], white=1 black=0
36
+ T.Lambda(lambda x: 1.0 - x), # invert -> white=0, black=1
37
+ T.Resize((28, 28), interpolation=T.InterpolationMode.BILINEAR),
38
+ # T.Normalize((0.5,), (0.5,)) # optional if your model expects [-1, 1]
39
+ ])
40
+
41
+ def predict(input_image):
42
+ img = input_image['composite']
43
+ if img is None:
44
+ return {"No drawing detected": 1.0}
45
+ img = transform(img)
46
+ img = img.unsqueeze(0).to(torch.float32) # add batch dimension
47
+ # torch.save(img, )
48
+ with torch.no_grad():
49
+ out = model(img)
50
+ # idx = torch.argmax(out).item()
51
+ probs = F.softmax(out, dim=1).squeeze(0)
52
+ res = {names[i]:proba.item() for i, proba in enumerate(probs)}
53
+ return res
54
+
55
+ demo = gr.Interface(
56
+ fn=predict,
57
+ inputs=gr.Sketchpad(
58
+ label="Draw a sketch",
59
+ image_mode='L',
60
+ brush=gr.Brush(default_size=15, default_color='black', colors=['black'], color_mode='fixed')
61
+ ),
62
+ outputs=gr.Label(num_top_classes=5),
63
+ title="Sketch Recognition model",
64
+ clear_btn=gr.ClearButton(),
65
+ live=True
66
+ )
67
+ print('test')
68
+ demo.launch()
labels.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["aircraft carrier", "airplane", "alarm clock", "ambulance", "angel", "animal migration", "ant", "anvil", "apple", "arm", "asparagus", "axe", "backpack", "banana", "bandage", "barn", "baseball bat", "baseball", "basket", "basketball", "bat", "bathtub", "beach", "bear", "beard", "bed", "bee", "belt", "bench", "bicycle", "binoculars", "bird", "birthday cake", "blackberry", "blueberry", "book", "boomerang", "bottlecap", "bowtie", "bracelet", "brain", "bread", "bridge", "broccoli", "broom", "bucket", "bulldozer", "bus", "bush", "butterfly", "cactus", "cake", "calculator", "calendar", "camel", "camera", "camouflage", "campfire", "candle", "cannon", "canoe", "car", "carrot", "castle", "cat", "ceiling fan", "cell phone", "cello", "chair", "chandelier", "church", "circle", "clarinet", "clock", "cloud", "coffee cup", "compass", "computer", "cookie", "cooler", "couch", "cow", "crab", "crayon", "crocodile", "crown", "cruise ship", "cup", "diamond", "dishwasher", "diving board", "dog", "dolphin", "donut", "door", "dragon", "dresser", "drill", "drums", "duck", "dumbbell", "ear", "elbow", "elephant", "envelope", "eraser", "eye", "eyeglasses", "face", "fan", "feather", "fence", "finger", "fire hydrant", "fireplace", "firetruck", "fish", "flamingo", "flashlight", "flip flops", "floor lamp", "flower", "flying saucer", "foot", "fork", "frog", "frying pan", "garden hose", "garden", "giraffe", "goatee", "golf club", "grapes", "grass", "guitar", "hamburger", "hammer", "hand", "harp", "hat", "headphones", "hedgehog", "helicopter", "helmet", "hexagon", "hockey puck", "hockey stick", "horse", "hospital", "hot air balloon", "hot dog", "hot tub", "hourglass", "house plant", "house", "hurricane", "ice cream", "jacket", "jail", "kangaroo", "key", "keyboard", "knee", "knife", "ladder", "lantern", "laptop", "leaf", "leg", "light bulb", "lighter", "lighthouse", "lightning", "line", "lion", "lipstick", "lobster", "lollipop", "mailbox", "map", "marker", "matches", "megaphone", "mermaid", "microphone", "microwave", "monkey", "moon", "mosquito", "motorbike", "mountain", "mouse", "moustache", "mouth", "mug", "mushroom", "nail", "necklace", "nose", "ocean", "octagon", "octopus", "onion", "oven", "owl", "paint can", "paintbrush", "palm tree", "panda", "pants", "paper clip", "parachute", "parrot", "passport", "peanut", "pear", "peas", "pencil", "penguin", "piano", "pickup truck", "picture frame", "pig", "pillow", "pineapple", "pizza", "pliers", "police car", "pond", "pool", "popsicle", "postcard", "potato", "power outlet", "purse", "rabbit", "raccoon", "radio", "rain", "rainbow", "rake", "remote control", "rhinoceros", "rifle", "river", "roller coaster", "rollerskates", "sailboat", "sandwich", "saw", "saxophone", "school bus", "scissors", "scorpion", "screwdriver", "sea turtle", "see saw", "shark", "sheep", "shoe", "shorts", "shovel", "sink", "skateboard", "skull", "skyscraper", "sleeping bag", "smiley face", "snail", "snake", "snorkel", "snowflake", "snowman", "soccer ball", "sock", "speedboat", "spider", "spoon", "spreadsheet", "square", "squiggle", "squirrel", "stairs", "star", "steak", "stereo", "stethoscope", "stitches", "stop sign", "stove", "strawberry", "streetlight", "string bean", "submarine", "suitcase", "sun", "swan", "sweater", "swing set", "sword", "syringe", "t-shirt", "table", "teapot", "teddy-bear", "telephone", "television", "tennis racquet", "tent", "The Eiffel Tower", "The Great Wall of China", "The Mona Lisa", "tiger", "toaster", "toe", "toilet", "tooth", "toothbrush", "toothpaste", "tornado", "tractor", "traffic light", "train", "tree", "triangle", "trombone", "truck", "trumpet", "umbrella", "underwear", "van", "vase", "violin", "washing machine", "watermelon", "waterslide", "whale", "wheel", "windmill", "wine bottle", "wine glass", "wristwatch", "yoga", "zebra", "zigzag"]
model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9451738bf2b8d46c28ce059e9020e65acccb1130123eceaf00e6083f3fb94c4
3
+ size 798193
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ json