Fabrice-TIERCELIN commited on
Commit
7ff865a
1 Parent(s): 67efc35

Delete clipseg/Visual_Feature_Engineering.ipynb

Browse files
clipseg/Visual_Feature_Engineering.ipynb DELETED
@@ -1,366 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# Systematic"
8
- ]
9
- },
10
- {
11
- "cell_type": "code",
12
- "execution_count": null,
13
- "metadata": {},
14
- "outputs": [],
15
- "source": [
16
- "%load_ext autoreload\n",
17
- "%autoreload 2\n",
18
- "\n",
19
- "import clip\n",
20
- "from evaluation_utils import norm, denorm\n",
21
- "from general_utils import *\n",
22
- "from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
23
- "\n",
24
- "clip_device = 'cuda'\n",
25
- "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
26
- "clip_model.eval();\n",
27
- "\n",
28
- "from models.clipseg import CLIPDensePredTMasked\n",
29
- "\n",
30
- "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
31
- "clip_mask_model.eval();"
32
- ]
33
- },
34
- {
35
- "cell_type": "code",
36
- "execution_count": null,
37
- "metadata": {},
38
- "outputs": [],
39
- "source": [
40
- "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
41
- " text_class_labels=True, image_size=352, min_area=0.1,\n",
42
- " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
43
- ]
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": null,
48
- "metadata": {},
49
- "outputs": [],
50
- "source": [
51
- "plot_data(lvis)"
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": null,
57
- "metadata": {},
58
- "outputs": [],
59
- "source": [
60
- "from collections import defaultdict\n",
61
- "import json\n",
62
- "\n",
63
- "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
64
- "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
65
- "\n",
66
- "objects_per_image = defaultdict(lambda : set())\n",
67
- "for ann in lvis_raw['annotations']:\n",
68
- " objects_per_image[ann['image_id']].add(ann['category_id'])\n",
69
- " \n",
70
- "for ann in lvis_val_raw['annotations']:\n",
71
- " objects_per_image[ann['image_id']].add(ann['category_id']) \n",
72
- " \n",
73
- "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
74
- "\n",
75
- "del lvis_raw, lvis_val_raw"
76
- ]
77
- },
78
- {
79
- "cell_type": "code",
80
- "execution_count": null,
81
- "metadata": {},
82
- "outputs": [],
83
- "source": [
84
- "#bs = 32\n",
85
- "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
86
- ]
87
- },
88
- {
89
- "cell_type": "code",
90
- "execution_count": null,
91
- "metadata": {},
92
- "outputs": [],
93
- "source": [
94
- "from general_utils import get_batch\n",
95
- "from functools import partial\n",
96
- "from evaluation_utils import img_preprocess\n",
97
- "import torch\n",
98
- "\n",
99
- "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
100
- "\n",
101
- " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
102
- "\n",
103
- " all_prompts = []\n",
104
- " \n",
105
- " with torch.no_grad():\n",
106
- " valid_sims = []\n",
107
- " torch.manual_seed(571)\n",
108
- " \n",
109
- " if type(batches_or_dataset) == list:\n",
110
- " loader = batches_or_dataset # already loaded\n",
111
- " max_iter = float('inf')\n",
112
- " else:\n",
113
- " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
114
- " max_iter = 50\n",
115
- " \n",
116
- " global batch\n",
117
- " for i_batch, (batch, batch_y) in enumerate(loader):\n",
118
- " \n",
119
- " if i_batch >= max_iter: break\n",
120
- " \n",
121
- " processed_batch = process(batch)\n",
122
- " if type(processed_batch) == dict:\n",
123
- " \n",
124
- " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
125
- " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
126
- " else:\n",
127
- " processed_batch = process(batch).to(clip_device)\n",
128
- " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
129
- " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
130
- " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
131
- " \n",
132
- " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
133
- " bs = len(batch[0])\n",
134
- " for j in range(bs):\n",
135
- " \n",
136
- " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
137
- " support_image = basename(lvis.samples[c][sid])\n",
138
- " \n",
139
- " img_objs = [o for o in objects_per_image[int(support_image)]]\n",
140
- " img_objs = [o.replace('_', ' ') for o in img_objs]\n",
141
- " \n",
142
- " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
143
- " if o != batch_y[2][j]]\n",
144
- " \n",
145
- " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
146
- " all_prompts += [prompts]\n",
147
- " \n",
148
- " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
149
- " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
150
- "\n",
151
- " global logits\n",
152
- " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
153
- "\n",
154
- " global sim\n",
155
- " sim = torch.softmax(logits, dim=-1)\n",
156
- " \n",
157
- " valid_sims += [sim]\n",
158
- " \n",
159
- " #valid_sims = torch.stack(valid_sims)\n",
160
- " return valid_sims, all_prompts\n",
161
- " \n",
162
- "\n",
163
- "def new_img_preprocess(x):\n",
164
- " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
165
- " \n",
166
- "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
167
- "get_similarities(lvis, lambda x: x[1]);"
168
- ]
169
- },
170
- {
171
- "cell_type": "code",
172
- "execution_count": null,
173
- "metadata": {},
174
- "outputs": [],
175
- "source": [
176
- "preprocessing_functions = [\n",
177
- "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
178
- "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
179
- "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
180
- "# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
181
- "# ['add red outline', partial(img_preprocess, outline=True)],\n",
182
- " \n",
183
- "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
184
- "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
185
- "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
186
- "# ['BG blur', partial(img_preprocess, blur=3)],\n",
187
- "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
188
- " \n",
189
- "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
190
- "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
191
- " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
192
- " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
193
- "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
194
- "]\n",
195
- "\n",
196
- "preprocessing_functions = preprocessing_functions\n",
197
- "\n",
198
- "base, base_p = get_similarities(lvis, lambda x: x[1])\n",
199
- "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
200
- ]
201
- },
202
- {
203
- "cell_type": "code",
204
- "execution_count": null,
205
- "metadata": {},
206
- "outputs": [],
207
- "source": [
208
- "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
209
- ]
210
- },
211
- {
212
- "cell_type": "code",
213
- "execution_count": null,
214
- "metadata": {},
215
- "outputs": [],
216
- "source": [
217
- "for j in range(1):\n",
218
- " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
219
- ]
220
- },
221
- {
222
- "cell_type": "code",
223
- "execution_count": null,
224
- "metadata": {},
225
- "outputs": [],
226
- "source": [
227
- "from pandas import DataFrame\n",
228
- "tab = dict()\n",
229
- "for j, (name, _) in enumerate(preprocessing_functions):\n",
230
- " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
231
- " \n",
232
- " \n",
233
- "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
234
- ]
235
- },
236
- {
237
- "cell_type": "markdown",
238
- "metadata": {},
239
- "source": [
240
- "# Visual"
241
- ]
242
- },
243
- {
244
- "cell_type": "code",
245
- "execution_count": null,
246
- "metadata": {},
247
- "outputs": [],
248
- "source": [
249
- "from evaluation_utils import denorm, norm"
250
- ]
251
- },
252
- {
253
- "cell_type": "code",
254
- "execution_count": null,
255
- "metadata": {},
256
- "outputs": [],
257
- "source": [
258
- "def load_sample(filename, filename2):\n",
259
- " from os.path import join\n",
260
- " bp = expanduser('~/cloud/resources/sample_images')\n",
261
- " tf = transforms.Compose([\n",
262
- " transforms.ToTensor(),\n",
263
- " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
264
- " transforms.Resize(224),\n",
265
- " transforms.CenterCrop(224)\n",
266
- " ])\n",
267
- " tf2 = transforms.Compose([\n",
268
- " transforms.ToTensor(),\n",
269
- " transforms.Resize(224),\n",
270
- " transforms.CenterCrop(224)\n",
271
- " ])\n",
272
- " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
273
- " inp1[1] = inp1[1].unsqueeze(0)\n",
274
- " inp1[2] = inp1[2][:1] \n",
275
- " return inp1\n",
276
- "\n",
277
- "def all_preprocessing(inp1):\n",
278
- " return [\n",
279
- " img_preprocess(inp1),\n",
280
- " img_preprocess(inp1, colorize=True),\n",
281
- " img_preprocess(inp1, outline=True), \n",
282
- " img_preprocess(inp1, blur=3),\n",
283
- " img_preprocess(inp1, bg_fac=0.1),\n",
284
- " #img_preprocess(inp1, bg_fac=0.5),\n",
285
- " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
286
- " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
287
- " ]\n",
288
- "\n"
289
- ]
290
- },
291
- {
292
- "cell_type": "code",
293
- "execution_count": null,
294
- "metadata": {},
295
- "outputs": [],
296
- "source": [
297
- "from torchvision import transforms\n",
298
- "from PIL import Image\n",
299
- "from matplotlib import pyplot as plt\n",
300
- "from evaluation_utils import img_preprocess\n",
301
- "import clip\n",
302
- "\n",
303
- "images_queries = [\n",
304
- " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
305
- " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
306
- "]\n",
307
- "\n",
308
- "\n",
309
- "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
310
- "\n",
311
- "for j, (images, objects) in enumerate(images_queries):\n",
312
- " \n",
313
- " joint_image = all_preprocessing(images)\n",
314
- " \n",
315
- " joint_image = torch.stack(joint_image)[:,0]\n",
316
- " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
317
- " image_features = clip_model.encode_image(joint_image)\n",
318
- " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
319
- " \n",
320
- " prompts = [f'a photo of a {obj}'for obj in objects]\n",
321
- " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
322
- " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
323
- " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
324
- " sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
325
- "\n",
326
- " for i, img in enumerate(joint_image):\n",
327
- " ax[2*j, i].axis('off')\n",
328
- " \n",
329
- " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
330
- " ax[2*j+ 1, i].grid(True)\n",
331
- " \n",
332
- " ax[2*j + 1, i].set_ylim(0,1)\n",
333
- " ax[2*j + 1, i].set_yticklabels([])\n",
334
- " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
335
- "# ax[1, i].set_xticklabels(objects, rotation=90)\n",
336
- " for k in range(len(sim[i])):\n",
337
- " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
338
- " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
339
- "\n",
340
- "plt.tight_layout()\n",
341
- "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
342
- ]
343
- }
344
- ],
345
- "metadata": {
346
- "kernelspec": {
347
- "display_name": "env2",
348
- "language": "python",
349
- "name": "env2"
350
- },
351
- "language_info": {
352
- "codemirror_mode": {
353
- "name": "ipython",
354
- "version": 3
355
- },
356
- "file_extension": ".py",
357
- "mimetype": "text/x-python",
358
- "name": "python",
359
- "nbconvert_exporter": "python",
360
- "pygments_lexer": "ipython3",
361
- "version": "3.8.8"
362
- }
363
- },
364
- "nbformat": 4,
365
- "nbformat_minor": 4
366
- }