Spaces:
Sleeping
Sleeping
Commit
·
efc35c0
0
Parent(s):
initialize repo
Browse files- .gitattributes +35 -0
- CNN.py +33 -0
- README.md +13 -0
- app.py +68 -0
- labels.json +1 -0
- model_weights.pth +3 -0
- requirements.txt +2 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
CNN.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class CNN(nn.Module):
|
| 5 |
+
def __init__(self, n_filters, hidden_dim, n_layers, n_classes):
|
| 6 |
+
super().__init__()
|
| 7 |
+
self.conv1 = nn.Conv2d(1, n_filters, 5)
|
| 8 |
+
self.relu1 = nn.ReLU()
|
| 9 |
+
self.maxpool1 = nn.MaxPool2d(2)
|
| 10 |
+
self.conv2 = nn.Conv2d(n_filters, 2*n_filters, 5)
|
| 11 |
+
self.relu2 = nn.ReLU()
|
| 12 |
+
self.maxpool2 = nn.MaxPool2d(2)
|
| 13 |
+
self.input_dim = 960
|
| 14 |
+
self.flatten = nn.Flatten()
|
| 15 |
+
self.inp_layer = nn.Linear(self.input_dim, hidden_dim)
|
| 16 |
+
self.classifier = nn.ModuleList([
|
| 17 |
+
nn.Sequential(
|
| 18 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 19 |
+
nn.BatchNorm1d(hidden_dim),
|
| 20 |
+
nn.ReLU(),
|
| 21 |
+
nn.Dropout(p=0.3)
|
| 22 |
+
) for i in range(n_layers)
|
| 23 |
+
])
|
| 24 |
+
self.out_layer = nn.Linear(hidden_dim, n_classes)
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
x = self.maxpool1(self.relu1(self.conv1(x)))
|
| 28 |
+
x = self.maxpool2(self.relu2(self.conv2(x)))
|
| 29 |
+
x = self.inp_layer(torch.flatten(x, start_dim=1))
|
| 30 |
+
for layer in self.classifier:
|
| 31 |
+
x = layer(x)
|
| 32 |
+
x = self.out_layer(x)
|
| 33 |
+
return x
|
README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Sketch Recognition
|
| 3 |
+
emoji: 🏢
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.49.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
short_description: This space uses a CNN to classify drawings.
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.transforms as T
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from CNN import CNN
|
| 9 |
+
|
| 10 |
+
# def greet(name):
|
| 11 |
+
# return "Hello " + name + "!!"
|
| 12 |
+
|
| 13 |
+
# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
| 14 |
+
# demo.launch()
|
| 15 |
+
|
| 16 |
+
# Load the model
|
| 17 |
+
n_classes = 345
|
| 18 |
+
params = {
|
| 19 |
+
'n_filters': 30,
|
| 20 |
+
'hidden_dim': 100,
|
| 21 |
+
'n_layers': 2,
|
| 22 |
+
'n_classes': n_classes
|
| 23 |
+
}
|
| 24 |
+
print('testesesesf')
|
| 25 |
+
model = CNN(**params)
|
| 26 |
+
model.load_state_dict(torch.load('model_weights.ptn', map_location='cpu'))
|
| 27 |
+
model.eval()
|
| 28 |
+
|
| 29 |
+
# utils
|
| 30 |
+
labels_path = 'labels.json'
|
| 31 |
+
with open(labels_path, 'r') as f:
|
| 32 |
+
names = json.load(f)
|
| 33 |
+
|
| 34 |
+
transform = T.Compose([
|
| 35 |
+
T.ToTensor(), # (1, H, W), values in [0, 1], white=1 black=0
|
| 36 |
+
T.Lambda(lambda x: 1.0 - x), # invert -> white=0, black=1
|
| 37 |
+
T.Resize((28, 28), interpolation=T.InterpolationMode.BILINEAR),
|
| 38 |
+
# T.Normalize((0.5,), (0.5,)) # optional if your model expects [-1, 1]
|
| 39 |
+
])
|
| 40 |
+
|
| 41 |
+
def predict(input_image):
|
| 42 |
+
img = input_image['composite']
|
| 43 |
+
if img is None:
|
| 44 |
+
return {"No drawing detected": 1.0}
|
| 45 |
+
img = transform(img)
|
| 46 |
+
img = img.unsqueeze(0).to(torch.float32) # add batch dimension
|
| 47 |
+
# torch.save(img, )
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
out = model(img)
|
| 50 |
+
# idx = torch.argmax(out).item()
|
| 51 |
+
probs = F.softmax(out, dim=1).squeeze(0)
|
| 52 |
+
res = {names[i]:proba.item() for i, proba in enumerate(probs)}
|
| 53 |
+
return res
|
| 54 |
+
|
| 55 |
+
demo = gr.Interface(
|
| 56 |
+
fn=predict,
|
| 57 |
+
inputs=gr.Sketchpad(
|
| 58 |
+
label="Draw a sketch",
|
| 59 |
+
image_mode='L',
|
| 60 |
+
brush=gr.Brush(default_size=15, default_color='black', colors=['black'], color_mode='fixed')
|
| 61 |
+
),
|
| 62 |
+
outputs=gr.Label(num_top_classes=5),
|
| 63 |
+
title="Sketch Recognition model",
|
| 64 |
+
clear_btn=gr.ClearButton(),
|
| 65 |
+
live=True
|
| 66 |
+
)
|
| 67 |
+
print('test')
|
| 68 |
+
demo.launch()
|
labels.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
["aircraft carrier", "airplane", "alarm clock", "ambulance", "angel", "animal migration", "ant", "anvil", "apple", "arm", "asparagus", "axe", "backpack", "banana", "bandage", "barn", "baseball bat", "baseball", "basket", "basketball", "bat", "bathtub", "beach", "bear", "beard", "bed", "bee", "belt", "bench", "bicycle", "binoculars", "bird", "birthday cake", "blackberry", "blueberry", "book", "boomerang", "bottlecap", "bowtie", "bracelet", "brain", "bread", "bridge", "broccoli", "broom", "bucket", "bulldozer", "bus", "bush", "butterfly", "cactus", "cake", "calculator", "calendar", "camel", "camera", "camouflage", "campfire", "candle", "cannon", "canoe", "car", "carrot", "castle", "cat", "ceiling fan", "cell phone", "cello", "chair", "chandelier", "church", "circle", "clarinet", "clock", "cloud", "coffee cup", "compass", "computer", "cookie", "cooler", "couch", "cow", "crab", "crayon", "crocodile", "crown", "cruise ship", "cup", "diamond", "dishwasher", "diving board", "dog", "dolphin", "donut", "door", "dragon", "dresser", "drill", "drums", "duck", "dumbbell", "ear", "elbow", "elephant", "envelope", "eraser", "eye", "eyeglasses", "face", "fan", "feather", "fence", "finger", "fire hydrant", "fireplace", "firetruck", "fish", "flamingo", "flashlight", "flip flops", "floor lamp", "flower", "flying saucer", "foot", "fork", "frog", "frying pan", "garden hose", "garden", "giraffe", "goatee", "golf club", "grapes", "grass", "guitar", "hamburger", "hammer", "hand", "harp", "hat", "headphones", "hedgehog", "helicopter", "helmet", "hexagon", "hockey puck", "hockey stick", "horse", "hospital", "hot air balloon", "hot dog", "hot tub", "hourglass", "house plant", "house", "hurricane", "ice cream", "jacket", "jail", "kangaroo", "key", "keyboard", "knee", "knife", "ladder", "lantern", "laptop", "leaf", "leg", "light bulb", "lighter", "lighthouse", "lightning", "line", "lion", "lipstick", "lobster", "lollipop", "mailbox", "map", "marker", "matches", "megaphone", "mermaid", "microphone", "microwave", "monkey", "moon", "mosquito", "motorbike", "mountain", "mouse", "moustache", "mouth", "mug", "mushroom", "nail", "necklace", "nose", "ocean", "octagon", "octopus", "onion", "oven", "owl", "paint can", "paintbrush", "palm tree", "panda", "pants", "paper clip", "parachute", "parrot", "passport", "peanut", "pear", "peas", "pencil", "penguin", "piano", "pickup truck", "picture frame", "pig", "pillow", "pineapple", "pizza", "pliers", "police car", "pond", "pool", "popsicle", "postcard", "potato", "power outlet", "purse", "rabbit", "raccoon", "radio", "rain", "rainbow", "rake", "remote control", "rhinoceros", "rifle", "river", "roller coaster", "rollerskates", "sailboat", "sandwich", "saw", "saxophone", "school bus", "scissors", "scorpion", "screwdriver", "sea turtle", "see saw", "shark", "sheep", "shoe", "shorts", "shovel", "sink", "skateboard", "skull", "skyscraper", "sleeping bag", "smiley face", "snail", "snake", "snorkel", "snowflake", "snowman", "soccer ball", "sock", "speedboat", "spider", "spoon", "spreadsheet", "square", "squiggle", "squirrel", "stairs", "star", "steak", "stereo", "stethoscope", "stitches", "stop sign", "stove", "strawberry", "streetlight", "string bean", "submarine", "suitcase", "sun", "swan", "sweater", "swing set", "sword", "syringe", "t-shirt", "table", "teapot", "teddy-bear", "telephone", "television", "tennis racquet", "tent", "The Eiffel Tower", "The Great Wall of China", "The Mona Lisa", "tiger", "toaster", "toe", "toilet", "tooth", "toothbrush", "toothpaste", "tornado", "tractor", "traffic light", "train", "tree", "triangle", "trombone", "truck", "trumpet", "umbrella", "underwear", "van", "vase", "violin", "washing machine", "watermelon", "waterslide", "whale", "wheel", "windmill", "wine bottle", "wine glass", "wristwatch", "yoga", "zebra", "zigzag"]
|
model_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b9451738bf2b8d46c28ce059e9020e65acccb1130123eceaf00e6083f3fb94c4
|
| 3 |
+
size 798193
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
json
|