altndrr committed
Commit: fcf6714
Parent: c275a12

Load CLIP model from transformers

Files changed (3):
  1. requirements.txt +1 -2
  2. src/nn.py +35 -179
  3. src/retrieval.py +2 -2
requirements.txt CHANGED
@@ -6,5 +6,4 @@ gradio==3.33.1
 gdown==4.4.0
 inflect==6.0.4
 nltk==3.8.1
-open_clip_torch==2.20.0
-transformers==4.26.1
+transformers==4.29.2

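As a quick sanity check (not part of the commit), the pinned transformers release should expose the CLIP classes that src/nn.py now imports:

```python
# Sanity-check sketch: confirm the pinned transformers release provides the
# CLIP classes used in src/nn.py. Not part of the commit itself.
import transformers
from transformers import CLIPModel, CLIPProcessor  # should import cleanly

print(transformers.__version__)  # expected: 4.29.2 with this requirements.txt
```
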
src/nn.py CHANGED
@@ -6,10 +6,9 @@ from typing import Optional
 import faiss
 import gdown
 import numpy as np
-import open_clip
 import torch
-from open_clip.transformer import Transformer
 from PIL import Image
+from transformers import CLIPModel, CLIPProcessor

 from src.retrieval import ArrowMetadataProvider
 from src.transforms import TextCompose, default_vocabulary_transforms
@@ -28,73 +27,46 @@ class CaSED(torch.nn.Module):
     Args:
         index_name (str): Name of the faiss index to use.
         vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
-        model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14".
-        pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai".

     Extra hparams:
         alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
         artifact_dir (str): Path to the directory where the databases are stored. Defaults to
             "artifacts/".
         retrieval_num_results (int): Number of results to return. Defaults to 10.
-        vocabulary_prompt (str): Prompt to use for the vocabulary. Defaults to "{}".
-        tau (float): Temperature to use for the classifier. Defaults to 1.0.
     """

     def __init__(
         self,
         index_name: str = "ViT-L-14_CC12M",
         vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
-        model_name: str = "ViT-L-14",
-        pretrained: str = "openai",
-        vocabulary_prompt: str = "{}",
         **kwargs,
     ):
         super().__init__()
-        self._prev_vocab_words = None
-        self._prev_used_prompts = None
-        self._prev_vocab_words_z = None
-
-        model, _, preprocess = open_clip.create_model_and_transforms(
-            model_name, pretrained=pretrained, device="cpu"
-        )
-        tokenizer = open_clip.get_tokenizer(model_name)
-        self.tokenizer = tokenizer
-        self.preprocess = preprocess

+        # load CLIP
+        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
+        self.index_name = index_name
+        self.vocabulary_transforms = vocabulary_transforms
+        self.vision_encoder = model.vision_model
+        self.vision_proj = model.visual_projection
+        self.language_encoder = model.text_model
+        self.language_proj = model.text_projection
+        self.logit_scale = model.logit_scale.exp()
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+        # set hparams
         kwargs["alpha"] = kwargs.get("alpha", 0.5)
         kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
         kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
-        vocabulary_prompt = kwargs.get("vocabulary_prompt", "{}")
-        kwargs["vocabulary_prompts"] = [vocabulary_prompt]
-        kwargs["tau"] = kwargs.get("tau", 1.0)
         self.hparams = kwargs

-        language_encoder = LanguageTransformer(
-            model.transformer,
-            model.token_embedding,
-            model.positional_embedding,
-            model.ln_final,
-            model.text_projection,
-            model.attn_mask,
-        )
-        scale = model.logit_scale.exp().item()
-        classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"])
-
-        self.index_name = index_name
-        self.vocabulary_transforms = vocabulary_transforms
-        self.vision_encoder = model.visual
-        self.language_encoder = language_encoder
-        self.classifier = classifier
-
         # download databases
         self.prepare_data()

-        # load faiss indices
+        # load faiss indices and metadata providers
         indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
         indices_fp = indices_list_dir / "indices.json"
         self.indices = json.load(open(indices_fp))
-
-        # load faiss indices and metadata providers
         self.resources = {}
         for name, index_fp in self.indices.items():
             text_index_fp = Path(index_fp) / "text.index"
@@ -107,7 +79,7 @@ class CaSED(torch.nn.Module):

             self.resources[name] = {
                 "device": DEVICE,
-                "model": model_name,
+                "model": "ViT-L-14",
                 "text_index": text_index,
                 "metadata_provider": metadata_provider,
             }
@@ -175,156 +147,40 @@ class CaSED(torch.nn.Module):

         return vocabularies

-    @torch.no_grad()
-    def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor:
-        """Encode a vocabulary.
-
-        Args:
-            vocabulary (list): List of words.
-        """
-        # check if vocabulary has changed
-        if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts:
-            return self._prev_vocab_words_z
-
-        # tokenize vocabulary
-        classes = [c.replace("_", " ") for c in vocabulary]
-        prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"]
-        texts_views = [[p.format(c) for c in classes] for p in prompts]
-        tokenized_texts_views = [
-            torch.cat([self.tokenizer(prompt) for prompt in class_prompts])
-            for class_prompts in texts_views
-        ]
-        tokenized_texts_views = torch.stack(tokenized_texts_views).to(DEVICE)
-
-        # encode vocabulary
-        T, C, _ = tokenized_texts_views.shape
-        texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1))
-        texts_z_views = texts_z_views.view(T, C, -1)
-        texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True)
-
-        # cache vocabulary
-        self._prev_vocab_words = vocabulary
-        self._prev_used_prompts = use_prompts
-        self._prev_vocab_words_z = texts_z_views
-
-        return texts_z_views
-
     @torch.no_grad()
     def forward(self, image_fp: str, alpha: Optional[float] = None) -> torch.Tensor():
-        image = self.preprocess(Image.open(image_fp)).unsqueeze(0)
-        image_z = self.vision_encoder(image.to(DEVICE))
-
-        # get the vocabulary
-        vocabulary = self.query_index(image_z)
+        # forward the image
+        image = self.processor(images=Image.open(image_fp), return_tensors="pt")
+        image["pixel_values"] = image["pixel_values"].to(DEVICE)
+        image_z = self.vision_proj(self.vision_encoder(**image)[1])

         # generate a single text embedding from the unfiltered vocabulary
-        unfiltered_vocabulary_z = self.encode_vocabulary(vocabulary).squeeze(0)
-        text_z = unfiltered_vocabulary_z.mean(dim=0)
-        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-        text_z = text_z.unsqueeze(0)
+        vocabulary = self.query_index(image_z)
+        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
+        text["input_ids"] = text["input_ids"][:, :77].to(DEVICE)
+        text["attention_mask"] = text["attention_mask"][:, :77].to(DEVICE)
+        text_z = self.language_encoder(**text)[1]
+        text_z = self.language_proj(text_z)

         # filter the vocabulary, embed it, and get its mean embedding
         vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
-        vocabulary_z = self.encode_vocabulary(vocabulary, use_prompts=True)
-        mean_vocabulary_z = vocabulary_z.mean(dim=0)
-        mean_vocabulary_z = mean_vocabulary_z / mean_vocabulary_z.norm(dim=-1, keepdim=True)
+        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
+        text = {k: v.to(DEVICE) for k, v in text.items()}
+        vocabulary_z = self.language_encoder(**text)[1]
+        vocabulary_z = self.language_proj(vocabulary_z)
+        vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)

         # get the image and text predictions
-        image_p = self.classifier(image_z, vocabulary_z)
-        text_p = self.classifier(text_z, vocabulary_z)
+        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
+        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+        image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
+        text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)

         # average the image and text predictions
         alpha = alpha or self.hparams["alpha"]
         sample_p = alpha * image_p + (1 - alpha) * text_p

         # get the scores
-        sample_p = sample_p.cpu()
-        scores = sample_p[0].tolist()
-
-        del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z, mean_vocabulary_z
-        del image_p, text_p, sample_p
+        scores = sample_p[0].cpu().tolist()

         return vocabulary, scores
-
-
-class NearestNeighboursClassifier(torch.nn.Module):
-    """Nearest neighbours classifier.
-
-    It computes the similarity between the query and the supports using the
-    cosine similarity and then applies a softmax to obtain the logits.
-
-    Args:
-        scale (float): Scale for the logits of the query. Defaults to 1.0.
-        tau (float): Temperature for the softmax. Defaults to 1.0.
-    """
-
-    def __init__(self, scale: float = 1.0, tau: float = 1.0):
-        super().__init__()
-        self.scale = scale
-        self.tau = tau
-
-    def forward(self, query: torch.Tensor, supports: torch.Tensor):
-        query = query / query.norm(dim=-1, keepdim=True)
-        supports = supports / supports.norm(dim=-1, keepdim=True)
-
-        if supports.dim() == 2:
-            supports = supports.unsqueeze(0)
-
-        Q, _ = query.shape
-        N, C, _ = supports.shape
-
-        supports = supports.mean(dim=0)
-        supports = supports / supports.norm(dim=-1, keepdim=True)
-        similarity = self.scale * query @ supports.T
-        similarity = similarity / self.tau if self.tau != 1.0 else similarity
-        logits = similarity.softmax(dim=-1)
-
-        return logits
-
-
-class LanguageTransformer(torch.nn.Module):
-    """Language Transformer for CLIP.
-
-    Args:
-        transformer (Transformer): Transformer model.
-        token_embedding (torch.nn.Embedding): Token embedding.
-        positional_embedding (torch.nn.Parameter): Positional embedding.
-        ln_final (torch.nn.LayerNorm): Layer norm.
-        text_projection (torch.nn.Parameter): Text projection.
-    """
-
-    def __init__(
-        self,
-        model: Transformer,
-        token_embedding: torch.nn.Embedding,
-        positional_embedding: torch.nn.Parameter,
-        ln_final: torch.nn.LayerNorm,
-        text_projection: torch.nn.Parameter,
-        attn_mask: torch.Tensor,
-    ):
-        super().__init__()
-        self.transformer = model
-        self.token_embedding = token_embedding
-        self.positional_embedding = positional_embedding
-        self.ln_final = ln_final
-        self.text_projection = text_projection
-
-        self.register_buffer("attn_mask", attn_mask, persistent=False)
-
-    def forward(self, text: torch.Tensor) -> torch.Tensor:
-        """Forward pass for the text encoder."""
-        cast_dtype = self.transformer.get_cast_dtype()
-
-        x = self.token_embedding(text).to(cast_dtype)
-
-        x = x + self.positional_embedding.to(cast_dtype)
-        x = x.permute(1, 0, 2)
-        x = self.transformer(x, attn_mask=self.attn_mask)
-        x = x.permute(1, 0, 2)
-        x = self.ln_final(x)
-
-        # x.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
-
-        return x

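For context, here is a minimal, self-contained sketch of the transformers-based CLIP path this commit switches to. The checkpoint name and the logit-scale softmax scoring come from the diff above; the image path and candidate captions are placeholders, and `get_image_features`/`get_text_features` are used as a convenience shorthand for the explicit encoder-plus-projection calls written out in `forward`.

```python
# Sketch only: mirrors the transformers-based path in src/nn.py under the
# assumption that "example.jpg" exists and the candidate captions stand in
# for the retrieved vocabulary.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

candidates = ["a photo of a dog", "a photo of a cat"]  # stand-in vocabulary
inputs = processor(
    text=candidates, images=Image.open("example.jpg"), return_tensors="pt", padding=True
).to(device)

with torch.no_grad():
    # get_image_features/get_text_features apply the projection heads,
    # shorthand for vision_proj(vision_encoder(...)[1]) etc. in the diff
    image_z = model.get_image_features(pixel_values=inputs["pixel_values"])
    text_z = model.get_text_features(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )

# cosine similarity scaled by logit_scale, then softmax, as in forward()
image_z = image_z / image_z.norm(dim=-1, keepdim=True)
text_z = text_z / text_z.norm(dim=-1, keepdim=True)
probs = (model.logit_scale.exp() * image_z @ text_z.T).softmax(dim=-1)
print(dict(zip(candidates, probs[0].tolist())))
```
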
src/retrieval.py CHANGED
@@ -11,8 +11,8 @@ class ArrowMetadataProvider:
     Code taken from: https://github.dev/rom1504/clip-retrieval
     """

-    def __init__(self, arrow_folder: str):
-        arrow_files = [str(a) for a in sorted(Path(arrow_folder).glob("**/*")) if a.is_file()]
+    def __init__(self, arrow_folder: Path):
+        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
         self.table = pa.concat_tables(
             [
                 pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
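
A small usage sketch of the updated signature: callers now pass a `pathlib.Path` instead of a string. The folder name below is a hypothetical artifact layout, not taken from the commit.

```python
# Usage sketch for the updated ArrowMetadataProvider signature.
# The metadata folder path is hypothetical; adjust to the actual artifact layout.
from pathlib import Path

from src.retrieval import ArrowMetadataProvider

metadata_dir = Path("artifacts/models/retrieval/ViT-L-14_CC12M/metadata")
provider = ArrowMetadataProvider(metadata_dir)  # Path object, not str
print(provider.table.num_rows)  # pyarrow table concatenated from the .arrow files
```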