Fabrice-TIERCELIN committed
Commit
98c0c4d
1 Parent(s): 48c8053

Delete clipseg/models/vitseg.py

Files changed (1)
  1. clipseg/models/vitseg.py +0 -286
clipseg/models/vitseg.py DELETED
@@ -1,286 +0,0 @@
-import math
-from posixpath import basename, dirname, join
-# import clip
-from clip.model import convert_weights
-import torch
-import json
-from torch import nn
-from torch.nn import functional as nnf
-from torch.nn.modules import activation
-from torch.nn.modules.activation import ReLU
-from torchvision import transforms
-
-normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
-
-from torchvision.models import ResNet
-
-
-def process_prompts(conditional, prompt_list, conditional_map):
-    # DEPRECATED
-
-    # randomly sample a synonym
-    words = [conditional_map[int(i)] for i in conditional]
-    words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
-    words = [w.replace('_', ' ') for w in words]
-
-    if prompt_list is not None:
-        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
-        prompts = [prompt_list[i] for i in prompt_indices]
-    else:
-        prompts = ['a photo of {}'] * (len(words))
-
-    return [promt.format(w) for promt, w in zip(prompts, words)]
-
-
-class VITDenseBase(nn.Module):
-
-    def rescaled_pos_emb(self, new_size):
-        assert len(new_size) == 2
-
-        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
-        b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
-        return torch.cat([self.model.positional_embedding[:1], b])
-
-    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
-
-        with torch.no_grad():
-
-            x_inp = nnf.interpolate(x_inp, (384, 384))
-
-            x = self.model.patch_embed(x_inp)
-            cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-            if self.model.dist_token is None:
-                x = torch.cat((cls_token, x), dim=1)
-            else:
-                x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
-            x = self.model.pos_drop(x + self.model.pos_embed)
-
-            activations = []
-            for i, block in enumerate(self.model.blocks):
-                x = block(x)
-
-                if i in extract_layers:
-                    # permute to be compatible with CLIP
-                    activations += [x.permute(1, 0, 2)]
-
-            x = self.model.norm(x)
-            x = self.model.head(self.model.pre_logits(x[:, 0]))
-
-            # again for CLIP compatibility
-            # x = x.permute(1, 0, 2)
-
-        return x, activations, None
-
-    def sample_prompts(self, words, prompt_list=None):
-
-        prompt_list = prompt_list if prompt_list is not None else self.prompt_list
-
-        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
-        prompts = [prompt_list[i] for i in prompt_indices]
-        return [promt.format(w) for promt, w in zip(prompts, words)]
-
-    def get_cond_vec(self, conditional, batch_size):
-        # compute conditional from a single string
-        if conditional is not None and type(conditional) == str:
-            cond = self.compute_conditional(conditional)
-            cond = cond.repeat(batch_size, 1)
-
-        # compute conditional from string list/tuple
-        elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
-            assert len(conditional) == batch_size
-            cond = self.compute_conditional(conditional)
-
-        # use conditional directly
-        elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
-            cond = conditional
-
-        # compute conditional from image
-        elif conditional is not None and type(conditional) == torch.Tensor:
-            with torch.no_grad():
-                cond, _, _ = self.visual_forward(conditional)
-        else:
-            raise ValueError('invalid conditional')
-        return cond
-
-    def compute_conditional(self, conditional):
-        import clip
-
-        dev = next(self.parameters()).device
-
-        if type(conditional) in {list, tuple}:
-            text_tokens = clip.tokenize(conditional).to(dev)
-            cond = self.clip_model.encode_text(text_tokens)
-        else:
-            if conditional in self.precomputed_prompts:
-                cond = self.precomputed_prompts[conditional].float().to(dev)
-            else:
-                text_tokens = clip.tokenize([conditional]).to(dev)
-                cond = self.clip_model.encode_text(text_tokens)[0]
-
-        return cond
-
-
-class VITDensePredT(VITDenseBase):
-
-    def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
-                 depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
-                 learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
-                 add_calibration=False, process_cond=None, not_pretrained=False):
-        super().__init__()
-        # device = 'cpu'
-
-        self.extract_layers = extract_layers
-        self.cond_layer = cond_layer
-        self.limit_to_clip_only = limit_to_clip_only
-        self.process_cond = None
-
-        if add_calibration:
-            self.calibration_conds = 1
-
-        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
-
-        self.add_activation1 = True
-
-        import timm
-        self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
-        self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)
-
-        for p in self.model.parameters():
-            p.requires_grad_(False)
-
-        import clip
-        self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
-        # del self.clip_model.visual
-
-
-        self.token_shape = (14, 14)
-
-        # conditional
-        if reduce_cond is not None:
-            self.reduce_cond = nn.Linear(512, reduce_cond)
-            for p in self.reduce_cond.parameters():
-                p.requires_grad_(False)
-        else:
-            self.reduce_cond = None
-
-        # self.film = AVAILABLE_BLOCKS['film'](512, 128)
-        self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
-        self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
-
-        # DEPRECATED
-        # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}
-
-        assert len(self.extract_layers) == depth
-
-        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
-        self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
-        self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
-
-        trans_conv_ks = (16, 16)
-        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
-
-        # refinement and trans conv
-
-        if learn_trans_conv_only:
-            for p in self.parameters():
-                p.requires_grad_(False)
-
-            for p in self.trans_conv.parameters():
-                p.requires_grad_(True)
-
-        if prompt == 'fixed':
-            self.prompt_list = ['a photo of a {}.']
-        elif prompt == 'shuffle':
-            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
-        elif prompt == 'shuffle+':
-            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
-                                'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
-                                'a bad photo of a {}.', 'a photo of the {}.']
-        elif prompt == 'shuffle_clip':
-            from models.clip_prompts import imagenet_templates
-            self.prompt_list = imagenet_templates
-
-        if process_cond is not None:
-            if process_cond == 'clamp' or process_cond[0] == 'clamp':
-
-                val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2
-
-                def clamp_vec(x):
-                    return torch.clamp(x, -val, val)
-
-                self.process_cond = clamp_vec
-
-            elif process_cond.endswith('.pth'):
-
-                shift = torch.load(process_cond)
-                def add_shift(x):
-                    return x + shift.to(x.device)
-
-                self.process_cond = add_shift
-
-        import pickle
-        precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
-        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
-
-
-    def forward(self, inp_image, conditional=None, return_features=False, mask=None):
-
-        assert type(return_features) == bool
-
-        # inp_image = inp_image.to(self.model.positional_embedding.device)
-
-        if mask is not None:
-            raise ValueError('mask not supported')
-
-        # x_inp = normalize(inp_image)
-        x_inp = inp_image
-
-        bs, dev = inp_image.shape[0], x_inp.device
-
-        inp_image_size = inp_image.shape[2:]
-
-        cond = self.get_cond_vec(conditional, bs)
-
-        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
-
-        activation1 = activations[0]
-        activations = activations[1:]
-
-        a = None
-        for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):
-
-            if a is not None:
-                a = reduce(activation) + a
-            else:
-                a = reduce(activation)
-
-            if i == self.cond_layer:
-                if self.reduce_cond is not None:
-                    cond = self.reduce_cond(cond)
-
-                a = self.film_mul(cond) * a + self.film_add(cond)
-
-            a = block(a)
-
-        for block in self.extra_blocks:
-            a = a + block(a)
-
-        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens
-
-        size = int(math.sqrt(a.shape[2]))
-
-        a = a.view(bs, a.shape[1], size, size)
-
-        if self.trans_conv is not None:
-            a = self.trans_conv(a)
-
-        if self.upsample_proj is not None:
-            a = self.upsample_proj(a)
-            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
-
-        a = nnf.interpolate(a, inp_image_size)
-
-        if return_features:
-            return a, visual_q, cond, [activation1] + activations
-        else:
-            return a,
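
For reference, the deleted module defined VITDenseBase and a VITDensePredT head that combines a frozen timm ViT (vit_base_patch16_384) with CLIP text embeddings via FiLM-style conditioning. The sketch below shows how the class could be instantiated and called, judging only from the signatures in the diff above; it assumes a checkout from before this commit (so the file still exists and is importable as models.vitseg), installed torch, timm, and clip packages, and the repository's precomputed_prompt_vectors.pickle in the working directory. The 352x352 input size is an arbitrary example, not a value taken from this commit.

import torch
from models.vitseg import VITDensePredT  # import path prior to this commit; adjust to your checkout

# Build the ViT-based dense prediction head (downloads timm and CLIP weights on first use).
model = VITDensePredT(extract_layers=(3, 6, 9), reduce_dim=128, prompt='fixed')
model.eval()

image = torch.randn(1, 3, 352, 352)  # dummy RGB batch standing in for a normalized image
with torch.no_grad():
    # forward() returns a 1-tuple unless return_features=True
    pred, = model(image, conditional='a photo of a dog.')

print(pred.shape)  # expected: torch.Size([1, 1, 352, 352]), one mask logit per pixel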