versae committed on
Commit 0f14459
1 Parent(s): 193f456

Make CLIP Sketch

Files changed (6)
  1. app.py +114 -0
  2. class_names.txt +100 -0
  3. clip.csv +3 -0
  4. clip.npy +3 -0
  5. pytorch_model.bin +3 -0
  6. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,114 @@
+import os
+
+from pathlib import Path
+import pandas as pd, numpy as np
+from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
+import torch
+from torch import nn
+import gradio as gr
+import requests
+from PIL import Image, ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+LABELS = Path('class_names.txt').read_text().splitlines()
+class_model = nn.Sequential(
+    nn.Conv2d(1, 32, 3, padding='same'),
+    nn.ReLU(),
+    nn.MaxPool2d(2),
+    nn.Conv2d(32, 64, 3, padding='same'),
+    nn.ReLU(),
+    nn.MaxPool2d(2),
+    nn.Conv2d(64, 128, 3, padding='same'),
+    nn.ReLU(),
+    nn.MaxPool2d(2),
+    nn.Flatten(),
+    nn.Linear(1152, 256),
+    nn.ReLU(),
+    nn.Linear(256, len(LABELS)),
+)
+state_dict = torch.load('pytorch_model.bin', map_location='cpu')
+class_model.load_state_dict(state_dict, strict=False)
+class_model.eval()
+
+
+model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+df = pd.read_csv('clip.csv')
+embeddings_npy = np.load('clip.npy')
+embeddings = np.divide(embeddings_npy, np.sqrt(np.sum(embeddings_npy**2, axis=1, keepdims=True)))
+
+
+def compute_text_embeddings(list_of_strings):
+    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
+    return model.get_text_features(**inputs)
+
+
+def compute_image_embeddings(list_of_images):
+    inputs = processor(images=list_of_images, return_tensors="pt", padding=True)
+    return model.get_image_features(**inputs)
+
+
+def load_image(image, same_height=False):
+    # im = Image.open(path)
+    im = Image.fromarray(np.uint8(image))
+    if im.mode != 'RGB':
+        im = im.convert('RGB')
+    if same_height:
+        ratio = 224/im.size[1]
+        return im.resize((int(im.size[0]*ratio), int(im.size[1]*ratio)))
+    else:
+        ratio = 224/min(im.size)
+        return im.resize((int(im.size[0]*ratio), int(im.size[1]*ratio)))
+
+
+def download_img(identifier, url):
+    local_path = f"{identifier}.jpg"
+    if not os.path.isfile(local_path):
+        img_data = requests.get(url).content
+        with open(local_path, 'wb') as handler:
+            handler.write(img_data)
+    return local_path
+
+
+def predict(image=None, text=None, sketch=None):
+    if image is not None:
+        input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
+        topk = {"local": 100}
+    else:
+        if text:
+            query = text
+            topk = {text: 100}
+        else:
+            x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
+            with torch.no_grad():
+                out = class_model(x)
+            probabilities = torch.nn.functional.softmax(out[0], dim=0)
+            values, indices = torch.topk(probabilities, 5)
+            query = LABELS[indices[0]]
+            topk = {LABELS[i]: v.item() for i, v in zip(indices, values)}
+        input_embeddings = compute_text_embeddings([query]).detach().numpy()
+
+    n_results = 3
+    results = np.argsort((embeddings @ input_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
+    outputs = [download_img(df.iloc[i]['id'], df.iloc[i]['thumbnail']) for i in results]
+    outputs.insert(0, topk)
+    print(outputs)
+    return outputs
+
+
+def predict_sketch(sketch):
+    return predict(None, None, sketch)
+
+
+title = "Type or draw to search in the Nasjonalbiblioteket"
+description = "Find images in the Nasjonalbiblioteket image collections based on what you draw or type"
+interface = gr.Interface(
+    fn=[predict_sketch],
+    inputs=["sketch"],
+    outputs=[gr.outputs.Label(num_top_classes=3)] + 3 * [gr.outputs.Image(type="file")],
+    title=title,
+    description=description,
+    live=True
+)
+interface.launch(debug=True)
class_names.txt ADDED
@@ -0,0 +1,100 @@
+airplane
+alarm_clock
+anvil
+apple
+axe
+baseball
+baseball_bat
+basketball
+beard
+bed
+bench
+bicycle
+bird
+book
+bread
+bridge
+broom
+butterfly
+camera
+candle
+car
+cat
+ceiling_fan
+cell_phone
+chair
+circle
+clock
+cloud
+coffee_cup
+cookie
+cup
+diving_board
+donut
+door
+drums
+dumbbell
+envelope
+eye
+eyeglasses
+face
+fan
+flower
+frying_pan
+grapes
+hammer
+hat
+headphones
+helmet
+hot_dog
+ice_cream
+key
+knife
+ladder
+laptop
+light_bulb
+lightning
+line
+lollipop
+microphone
+moon
+mountain
+moustache
+mushroom
+pants
+paper_clip
+pencil
+pillow
+pizza
+power_outlet
+radio
+rainbow
+rifle
+saw
+scissors
+screwdriver
+shorts
+shovel
+smiley_face
+snake
+sock
+spider
+spoon
+square
+star
+stop_sign
+suitcase
+sun
+sword
+syringe
+t-shirt
+table
+tennis_racquet
+tent
+tooth
+traffic_light
+tree
+triangle
+umbrella
+wheel
+wristwatch
clip.csv ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4ecd243d042c4b0c0f93de5df51a444b9e3076c17dee21313403e066ec15750
+size 10444275
clip.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f1cee3eb68b43e8af5c5aab5955c473d76104e2568815d9c9285f4ba079d6ac
+size 112642176
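
clip.csv and clip.npy are tracked with Git LFS, so only their pointer files appear in this diff. From the way app.py reads them, clip.csv holds one metadata row per collection image (at least an 'id' and a 'thumbnail' column) and clip.npy holds the matching CLIP image embeddings, one row per image, which app.py L2-normalizes at load time. A minimal sketch of how such files could be precomputed with the same model and processor; the loop below is an assumption for illustration, not part of this commit:

# Illustrative offline step (assumed, not part of the commit): compute one CLIP
# image embedding per row of clip.csv and store the stacked matrix as clip.npy.
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

df = pd.read_csv("clip.csv")  # one row per image: id, thumbnail, ...
vectors = []
for url in df["thumbnail"]:
    image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
    inputs = processor(images=[image], return_tensors="pt")
    vectors.append(model.get_image_features(**inputs).detach().numpy()[0])

np.save("clip.npy", np.stack(vectors))  # rows stay aligned with clip.csv
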
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:effb6ea6f1593c09e8247944028ed9c309b5ff1cef82ba38b822bee2ca4d0f3c
+size 1656903
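
pytorch_model.bin is the checkpoint that app.py loads into the small sketch classifier with torch.load and load_state_dict(strict=False). The training run that produced it is not part of this commit; as a hedged sketch of the save side, the architecture below mirrors class_model so the state-dict keys line up (the training loop itself is assumed and omitted):

# Illustrative (assumed, not part of the commit): produce a checkpoint whose
# state-dict keys match class_model in app.py. The input is a 28x28 grayscale
# sketch; 1152 = 128 channels * 3 * 3 spatial after three 2x2 max-poolings.
import torch
from torch import nn

num_classes = 100  # one per line of class_names.txt
classifier = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256), nn.ReLU(),
    nn.Linear(256, num_classes),
)
# ... train classifier on 28x28 sketches of the 100 classes here (assumed) ...
torch.save(classifier.state_dict(), "pytorch_model.bin")
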
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch
+transformers
+numpy
+pandas
+ftfy
+pillow