Fabrice-TIERCELIN committed on
Commit 48c8053
1 Parent(s): a13ea9a

Delete clipseg/models/clipseg.py

Files changed (1)
  1. clipseg/models/clipseg.py +0 -552
clipseg/models/clipseg.py DELETED
@@ -1,552 +0,0 @@
- import math
- from os.path import basename, dirname, join, isfile
- import torch
- from torch import nn
- from torch.nn import functional as nnf
- from torch.nn.modules.activation import ReLU
-
-
- def precompute_clip_vectors():
-
-     from trails.initialization import init_dataset
-     lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True,
-                         reduce_factor=None, add_bar=False, negative_prob=0.5)
-
-     all_names = list(lvis.category_names.values())
-
-     import clip
-     from models.clip_prompts import imagenet_templates
-     clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0]
-     prompt_vectors = {}
-     for name in all_names[:100]:
-         with torch.no_grad():
-             conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates]
-             text_tokens = clip.tokenize(conditionals).cuda()
-             cond = clip_model.encode_text(text_tokens).cpu()
-
-             for cond, vec in zip(conditionals, cond):
-                 prompt_vectors[cond] = vec.cpu()
-
-     import pickle
-
-     pickle.dump(prompt_vectors, open('precomputed_prompt_vectors.pickle', 'wb'))
-
-
- def get_prompt_list(prompt):
-     if prompt == 'plain':
-         return ['{}']
-     elif prompt == 'fixed':
-         return ['a photo of a {}.']
-     elif prompt == 'shuffle':
-         return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
-     elif prompt == 'shuffle+':
-         return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
-                 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
-                 'a bad photo of a {}.', 'a photo of the {}.']
-     elif prompt == 'shuffle_clip':
-         from models.clip_prompts import imagenet_templates
-         return imagenet_templates
-     else:
-         raise ValueError('Invalid value for prompt')
-
-
- def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
-     """
-     Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
-     The mlp and layer norm come from CLIP.
-     x: input.
-     b: multihead attention module.
-     """
-
-     x_ = b.ln_1(x)
-     q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
-     tgt_len, bsz, embed_dim = q.size()
-
-     head_dim = embed_dim // b.attn.num_heads
-     scaling = float(head_dim) ** -0.5
-
-     q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
-     k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
-     v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
-
-     q = q * scaling
-
-     attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # n_heads * batch_size, tokens^2, tokens^2
-     if attn_mask is not None:
-
-
-         attn_mask_type, attn_mask = attn_mask
-         n_heads = attn_output_weights.size(0) // attn_mask.size(0)
-         attn_mask = attn_mask.repeat(n_heads, 1)
-
-         if attn_mask_type == 'cls_token':
-             # the mask only affects similarities compared to the readout-token.
-             attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...]
-             # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0]
-
-         if attn_mask_type == 'all':
-             # print(attn_output_weights.shape, attn_mask[:, None].shape)
-             attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
-
-
-     attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
-
-     attn_output = torch.bmm(attn_output_weights, v)
-     attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-     attn_output = b.attn.out_proj(attn_output)
-
-     x = x + attn_output
-     x = x + b.mlp(b.ln_2(x))
-
-     if with_aff:
-         return x, attn_output_weights
-     else:
-         return x
-
-
- class CLIPDenseBase(nn.Module):
-
-     def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
-         super().__init__()
-
-         import clip
-
-         # prec = torch.FloatTensor
-         self.clip_model, _ = clip.load(version, device='cpu', jit=False)
-         self.model = self.clip_model.visual
-
-         # if not None, scale conv weights such that we obtain n_tokens.
-         self.n_tokens = n_tokens
-
-         for p in self.clip_model.parameters():
-             p.requires_grad_(False)
-
-         # conditional
-         if reduce_cond is not None:
-             self.reduce_cond = nn.Linear(512, reduce_cond)
-             for p in self.reduce_cond.parameters():
-                 p.requires_grad_(False)
-         else:
-             self.reduce_cond = None
-
-         self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
-         self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
-
-         self.reduce = nn.Linear(768, reduce_dim)
-
-         self.prompt_list = get_prompt_list(prompt)
-
-         # precomputed prompts
-         import pickle
-         if isfile('precomputed_prompt_vectors.pickle'):
-             precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
-             self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
-         else:
-             self.precomputed_prompts = dict()
-
-     def rescaled_pos_emb(self, new_size):
-         assert len(new_size) == 2
-
-         a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
-         b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
-         return torch.cat([self.model.positional_embedding[:1], b])
-
-     def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
-
-
-         with torch.no_grad():
-
-             inp_size = x_inp.shape[2:]
-
-             if self.n_tokens is not None:
-                 stride2 = x_inp.shape[2] // self.n_tokens
-                 conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
-                 x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation)
-             else:
-                 x = self.model.conv1(x_inp)  # shape = [*, width, grid, grid]
-
-             x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
-             x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
-
-             x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
-
-             standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197
-
-             if x.shape[1] != standard_n_tokens:
-                 new_shape = int(math.sqrt(x.shape[1]-1))
-                 x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:]
-             else:
-                 x = x + self.model.positional_embedding.to(x.dtype)
-
-             x = self.model.ln_pre(x)
-
-             x = x.permute(1, 0, 2)  # NLD -> LND
-
-             activations, affinities = [], []
-             for i, res_block in enumerate(self.model.transformer.resblocks):
-
-                 if mask is not None:
-                     mask_layer, mask_type, mask_tensor = mask
-                     if mask_layer == i or mask_layer == 'all':
-                         # import ipdb; ipdb.set_trace()
-                         size = int(math.sqrt(x.shape[0] - 1))
-
-                         attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size))
-
-                     else:
-                         attn_mask = None
-                 else:
-                     attn_mask = None
-
-                 x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)
-
-                 if i in extract_layers:
-                     affinities += [aff_per_head]
-
-                     #if self.n_tokens is not None:
-                     #    activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)]
-                     #else:
-                     activations += [x]
-
-                 if len(extract_layers) > 0 and i == max(extract_layers) and skip:
-                     print('early skip')
-                     break
-
-             x = x.permute(1, 0, 2)  # LND -> NLD
-             x = self.model.ln_post(x[:, 0, :])
-
-             if self.model.proj is not None:
-                 x = x @ self.model.proj
-
-             return x, activations, affinities
-
-     def sample_prompts(self, words, prompt_list=None):
-
-         prompt_list = prompt_list if prompt_list is not None else self.prompt_list
-
-         prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
-         prompts = [prompt_list[i] for i in prompt_indices]
-         return [promt.format(w) for promt, w in zip(prompts, words)]
-
-     def get_cond_vec(self, conditional, batch_size):
-         # compute conditional from a single string
-         if conditional is not None and type(conditional) == str:
-             cond = self.compute_conditional(conditional)
-             cond = cond.repeat(batch_size, 1)
-
-         # compute conditional from string list/tuple
-         elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
-             assert len(conditional) == batch_size
-             cond = self.compute_conditional(conditional)
-
-         # use conditional directly
-         elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
-             cond = conditional
-
-         # compute conditional from image
-         elif conditional is not None and type(conditional) == torch.Tensor:
-             with torch.no_grad():
-                 cond, _, _ = self.visual_forward(conditional)
-         else:
-             raise ValueError('invalid conditional')
-         return cond
-
-     def compute_conditional(self, conditional):
-         import clip
-
-         dev = next(self.parameters()).device
-
-         if type(conditional) in {list, tuple}:
-             text_tokens = clip.tokenize(conditional).to(dev)
-             cond = self.clip_model.encode_text(text_tokens)
-         else:
-             if conditional in self.precomputed_prompts:
-                 cond = self.precomputed_prompts[conditional].float().to(dev)
-             else:
-                 text_tokens = clip.tokenize([conditional]).to(dev)
-                 cond = self.clip_model.encode_text(text_tokens)[0]
-
-         if self.shift_vector is not None:
-             return cond + self.shift_vector
-         else:
-             return cond
-
-
- def clip_load_untrained(version):
-     assert version == 'ViT-B/16'
-     from clip.model import CLIP
-     from clip.clip import _MODELS, _download
-     model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval()
-     state_dict = model.state_dict()
-
-     vision_width = state_dict["visual.conv1.weight"].shape[0]
-     vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
-     vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
-     grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
-     image_resolution = vision_patch_size * grid_size
-     embed_dim = state_dict["text_projection"].shape[1]
-     context_length = state_dict["positional_embedding"].shape[0]
-     vocab_size = state_dict["token_embedding.weight"].shape[0]
-     transformer_width = state_dict["ln_final.weight"].shape[0]
-     transformer_heads = transformer_width // 64
-     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
-
-     return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
-                 context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
-
-
- class CLIPDensePredT(CLIPDenseBase):
-
-     def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
-                  extra_blocks=0, reduce_cond=None, fix_shift=False,
-                  learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
-                  add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
-
-         super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
-         # device = 'cpu'
-
-         self.extract_layers = extract_layers
-         self.cond_layer = cond_layer
-         self.limit_to_clip_only = limit_to_clip_only
-         self.process_cond = None
-         self.rev_activations = rev_activations
-
-         depth = len(extract_layers)
-
-         if add_calibration:
-             self.calibration_conds = 1
-
-         self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
-
-         self.add_activation1 = True
-
-         self.version = version
-
-         self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
-
-         if fix_shift:
-             # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False)
-             self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False)
-             # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False)
-         else:
-             self.shift_vector = None
-
-         if trans_conv is None:
-             trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
-         else:
-             # explicitly define transposed conv kernel size
-             trans_conv_ks = (trans_conv, trans_conv)
-
-         self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
-
-         assert len(self.extract_layers) == depth
-
-         self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
-         self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
-         self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
-
-         # refinement and trans conv
-
-         if learn_trans_conv_only:
-             for p in self.parameters():
-                 p.requires_grad_(False)
-
-             for p in self.trans_conv.parameters():
-                 p.requires_grad_(True)
-
-         self.prompt_list = get_prompt_list(prompt)
-
-
-     def forward(self, inp_image, conditional=None, return_features=False, mask=None):
-
-         assert type(return_features) == bool
-
-         inp_image = inp_image.to(self.model.positional_embedding.device)
-
-         if mask is not None:
-             raise ValueError('mask not supported')
-
-         # x_inp = normalize(inp_image)
-         x_inp = inp_image
-
-         bs, dev = inp_image.shape[0], x_inp.device
-
-         cond = self.get_cond_vec(conditional, bs)
-
-         visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
-
-         activation1 = activations[0]
-         activations = activations[1:]
-
-         _activations = activations[::-1] if not self.rev_activations else activations
-
-         a = None
-         for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)):
-
-             if a is not None:
-                 a = reduce(activation) + a
-             else:
-                 a = reduce(activation)
-
-             if i == self.cond_layer:
-                 if self.reduce_cond is not None:
-                     cond = self.reduce_cond(cond)
-
-                 a = self.film_mul(cond) * a + self.film_add(cond)
-
-             a = block(a)
-
-         for block in self.extra_blocks:
-             a = a + block(a)
-
-         a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens
-
-         size = int(math.sqrt(a.shape[2]))
-
-         a = a.view(bs, a.shape[1], size, size)
-
-         a = self.trans_conv(a)
-
-         if self.n_tokens is not None:
-             a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)
-
-         if self.upsample_proj is not None:
-             a = self.upsample_proj(a)
-             a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
-
-         if return_features:
-             return a, visual_q, cond, [activation1] + activations
-         else:
-             return a,
-
-
-
- class CLIPDensePredTMasked(CLIPDensePredT):
-
-     def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
-                  prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
-                  refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
-
-         super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim,
-                          n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond,
-                          fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only,
-                          limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration,
-                          n_tokens=n_tokens)
-
-     def visual_forward_masked(self, img_s, seg_s):
-         return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s))
-
-     def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
-
-         if seg_s is None:
-             cond = cond_or_img_s
-         else:
-             img_s = cond_or_img_s
-
-             with torch.no_grad():
-                 cond, _, _ = self.visual_forward_masked(img_s, seg_s)
-
-         return super().forward(img_q, cond, return_features=return_features)
-
-
-
- class CLIPDenseBaseline(CLIPDenseBase):
-
-     def __init__(self, version='ViT-B/32', cond_layer=0,
-                  extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
-                  reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
-
-         super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
-         device = 'cpu'
-
-         # self.cond_layer = cond_layer
-         self.extract_layer = extract_layer
-         self.limit_to_clip_only = limit_to_clip_only
-         self.shift_vector = None
-
-         self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
-
-         assert reduce2_dim is not None
-
-         self.reduce2 = nn.Sequential(
-             nn.Linear(reduce_dim, reduce2_dim),
-             nn.ReLU(),
-             nn.Linear(reduce2_dim, reduce_dim)
-         )
-
-         trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
-         self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
-
-
-     def forward(self, inp_image, conditional=None, return_features=False):
-
-         inp_image = inp_image.to(self.model.positional_embedding.device)
-
-         # x_inp = normalize(inp_image)
-         x_inp = inp_image
-
-         bs, dev = inp_image.shape[0], x_inp.device
-
-         cond = self.get_cond_vec(conditional, bs)
-
-         visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer])
-
-         a = activations[0]
-         a = self.reduce(a)
-         a = self.film_mul(cond) * a + self.film_add(cond)
-
-         if self.reduce2 is not None:
-             a = self.reduce2(a)
-
-         # the original model would execute a transformer block here
-
-         a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens
-
-         size = int(math.sqrt(a.shape[2]))
-
-         a = a.view(bs, a.shape[1], size, size)
-         a = self.trans_conv(a)
-
-         if return_features:
-             return a, visual_q, cond, activations
-         else:
-             return a,
-
-
-
- class CLIPSegMultiLabel(nn.Module):
-
-     def __init__(self, model) -> None:
-         super().__init__()
-
-         from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
-
-         self.pascal_classes = VOC
-
-         from models.clipseg import CLIPDensePredT
-         from general_utils import load_model
-         # self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
-         self.clipseg = load_model(model, strict=False)
-
-         self.clipseg.eval()
-
-     def forward(self, x):
-
-         bs = x.shape[0]
-         out = torch.ones(21, bs, 352, 352).to(x.device) * -10
-
-         for class_id, class_name in enumerate(self.pascal_classes):
-
-             fac = 3 if class_name == 'background' else 1
-
-             with torch.no_grad():
-                 pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac
-
-             out[class_id] += pred
-
-
-         out = out.permute(1, 0, 2, 3)
-
-         return out
-
-         # construct output tensor
-
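
For context, the deleted module's main entry point was `CLIPDensePredT`. The sketch below shows how it was typically driven before this commit; it is a minimal illustration, not part of the diff. The import path, the `weights/rd64-uni.pth` checkpoint, and the `example.jpg` input are assumptions about the surrounding setup.

```python
# Minimal usage sketch for the deleted CLIPDensePredT class (paths and checkpoint are assumptions).
import torch
from PIL import Image
from torchvision import transforms

from clipseg.models.clipseg import CLIPDensePredT  # module as it existed before this commit

model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
# strict=False: the checkpoint does not contain the frozen CLIP weights loaded in CLIPDenseBase
model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location='cpu'), strict=False)

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    transforms.Resize((352, 352)),
])
img = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    # forward() returns a 1-tuple; element 0 is the (B, 1, H, W) logit map for the text prompt
    logits = model(img, 'a glass')[0]
mask = torch.sigmoid(logits)
```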