altndrr committed
Commit 0dc766b
1 parent: aa4749c

Sync model code with repo code

README.md CHANGED
@@ -38,7 +38,7 @@ outputs = model(images, alpha=0.5)
 labels, scores = outputs["vocabularies"][0], outputs["scores"][0]
 
 # print the top 5 most likely labels for the image
-values, indices = scores.topk(5)
+values, indices = scores.sort(dim=-1, descending=True)
 print("\nTop predictions:\n")
 for value, index in zip(values, indices):
     print(f"{labels[index]:>16s}: {100 * value.item():.2f}%")
@@ -54,7 +54,7 @@ pip install torch faiss-cpu flair inflect nltk transformers
 
 ```latex
 @misc{conti2023vocabularyfree,
-      title={Vocabulary-free Image Classification},
+      title={Vocabulary-free Image Classification},
       author={Alessandro Conti and Enrico Fini and Massimiliano Mancini and Paolo Rota and Yiming Wang and Elisa Ricci},
       year={2023},
       eprint={2306.00917},
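
With the switch from `topk(5)` to a full descending sort, the loop above now iterates over every retrieved label rather than only five. A small follow-up sketch (the slicing is an illustration, not part of the commit) that keeps the printout to the top five:

```python
# keep only the five highest-scoring entries from the sorted outputs
top_values, top_indices = values[:5], indices[:5]
for value, index in zip(top_values, top_indices):
    print(f"{labels[index]:>16s}: {100 * value.item():.2f}%")
```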
config.json CHANGED
@@ -7,6 +7,7 @@
     "AutoConfig": "configuration_cased.CaSEDConfig",
     "AutoModel": "modeling_cased.CaSEDModel"
   },
+  "cache_dir": "~/.cache/cased",
   "index_name": "cc12m",
   "model_type": "cased",
   "retrieval_num_results": 10,
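
The added `cache_dir` entry is exposed as a plain attribute on the loaded configuration. A hedged sketch, assuming the checkpoint id is `altndrr/cased` (the id is not stated in this diff) and using the standard `trust_remote_code` flow implied by the `auto_map` entries:

```python
from transformers import AutoConfig

# assumed repo id; trust_remote_code resolves configuration_cased.CaSEDConfig via auto_map
config = AutoConfig.from_pretrained("altndrr/cased", trust_remote_code=True)
print(config.index_name, config.retrieval_num_results, config.cache_dir)
```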
configuration_cased.py CHANGED
@@ -1,3 +1,5 @@
+import os
+
 from transformers.modeling_utils import PretrainedConfig
 
 
@@ -5,9 +7,10 @@ class CaSEDConfig(PretrainedConfig):
     """Configuration class for CaSED.
 
     Args:
-        index_name (str, optional): Name of the index. Defaults to "cc12m".
-        alpha (float, optional): Weight of the vision loss. Defaults to 0.5.
-        retrieval_num_results (int, optional): Number of results to return. Defaults to 10.
+        index_name (str): Name of the index. Defaults to "cc12m".
+        alpha (float): Weight of the vision loss. Defaults to 0.5.
+        retrieval_num_results (int): Number of results to return. Defaults to 10.
+        cache_dir (str): Path to cache directory. Defaults to "~/.cache/cased".
     """
 
     model_type = "cased"
@@ -18,9 +21,11 @@ class CaSEDConfig(PretrainedConfig):
         index_name: str = "cc12m",
         alpha: float = 0.5,
         retrieval_num_results: int = 10,
+        cache_dir: str = os.path.expanduser("~/.cache/cased"),
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.index_name = index_name
         self.alpha = alpha
         self.retrieval_num_results = retrieval_num_results
+        self.cache_dir = cache_dir
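
For reference, a minimal sketch of constructing the updated config directly, overriding the new `cache_dir` next to the existing knobs (the alternative cache location below is hypothetical):

```python
import os

from configuration_cased import CaSEDConfig  # local module, as named in the diff

config = CaSEDConfig(
    index_name="cc12m",
    alpha=0.5,
    retrieval_num_results=10,
    cache_dir=os.path.expanduser("~/.cache/cased-demo"),  # hypothetical override
)
print(config.cache_dir)
```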
modeling_cased.py CHANGED
@@ -1,66 +1,21 @@
 import os
-import tarfile
-from pathlib import Path
-from typing import Optional
+from typing import Callable, Optional
 
-import faiss
 import numpy as np
-import pyarrow as pa
-import requests
 import torch
-from tqdm import tqdm
 from transformers import CLIPModel, CLIPProcessor
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_cased import CaSEDConfig
+from .retrieval_cased import RetrievalDatabase, download_retrieval_databases
 from .transforms_cased import default_vocabulary_transforms
 
-DATABASES = {
-    "cc12m": {
-        "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz",
-        "cache_subdir": "./cc12m/vit-l-14/",
-    },
-}
-
-
-class MetadataProvider:
-    """Metadata provider.
-
-    It uses arrow files to store metadata and retrieve it efficiently.
-
-    Code reference:
-    - https://github.dev/rom1504/clip-retrieval
-    """
-
-    def __init__(self, arrow_folder: Path):
-        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
-        self.table = pa.concat_tables(
-            [
-                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
-                for arrow_file in arrow_files
-            ]
-        )
-
-    def get(self, ids: np.ndarray, cols: Optional[list] = None):
-        """Get arrow metadata from ids.
-
-        Args:
-            ids (np.ndarray): Ids to retrieve.
-            cols (Optional[list], optional): Columns to retrieve. Defaults to None.
-        """
-        if cols is None:
-            cols = self.table.schema.names
-        else:
-            cols = list(set(self.table.schema.names) & set(cols))
-        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
-        return t.select(cols).to_pandas().to_dict("records")
-
 
 class CaSEDModel(PreTrainedModel):
     """Transformers module for Category Search from External Databases (CaSED).
 
     Reference:
-    - Conti et al. Vocabulary-free Image Classification. arXiv 2023.
+    - Conti et al. Vocabulary-free Image Classification. NeurIPS 2023.
 
     Args:
         config (CaSEDConfig): Configuration class for CaSED.
@@ -80,121 +35,57 @@ class CaSEDModel(PreTrainedModel):
         self.logit_scale = model.logit_scale.exp()
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
-        # load transforms
-        self.vocabulary_transforms = default_vocabulary_transforms()
-
         # set hparams
         self.hparams = {}
         self.hparams["alpha"] = config.alpha
         self.hparams["index_name"] = config.index_name
         self.hparams["retrieval_num_results"] = config.retrieval_num_results
+        self.hparams["cache_dir"] = config.cache_dir
 
-        # set cache dir
-        self.hparams["cache_dir"] = Path(os.path.expanduser("~/.cache/cased"))
+        # create cache dir
        os.makedirs(self.hparams["cache_dir"], exist_ok=True)
 
-        # download databases
-        self.prepare_data()
-
-        # load faiss indices and metadata providers
-        self.resources = {}
-        for name, items in DATABASES.items():
-            database_path = self.hparams["cache_dir"] / "databases" / items["cache_subdir"]
-            text_index_fp = database_path / "text.index"
-            metadata_fp = database_path / "metadata/"
-
-            text_index = faiss.read_index(
-                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
-            )
-            metadata_provider = MetadataProvider(metadata_fp)
-
-            self.resources[name] = {
-                "device": self.device,
-                "model": "ViT-L-14",
-                "text_index": text_index,
-                "metadata_provider": metadata_provider,
-            }
-
-    def prepare_data(self):
-        """Download data if needed."""
-        databases_path = Path(self.hparams["cache_dir"]) / "databases"
-
-        for name, items in DATABASES.items():
-            url = items["url"]
-            database_path = Path(databases_path, name)
-            if database_path.exists():
-                continue
-
-            # download data
-            target_path = Path(databases_path, name + ".tar.gz")
-            os.makedirs(target_path.parent, exist_ok=True)
-            with requests.get(url, stream=True) as r:
-                r.raise_for_status()
-                total_bytes_size = int(r.headers.get('content-length', 0))
-                chunk_size = 8192
-                p_bar = tqdm(
-                    desc="Downloading cc12m index",
-                    total=total_bytes_size,
-                    unit='iB',
-                    unit_scale=True,
-                )
-                with open(target_path, 'wb') as f:
-                    for chunk in r.iter_content(chunk_size=chunk_size):
-                        f.write(chunk)
-                        p_bar.update(len(chunk))
-                p_bar.close()
-
-            # extract data
-            tar = tarfile.open(target_path, "r:gz")
-            tar.extractall(target_path.parent)
-            tar.close()
-            target_path.unlink()
-
-    @torch.no_grad()
-    def query_index(self, sample_z: torch.Tensor) -> torch.Tensor:
-        """Query the external database index.
+        # download data
+        download_retrieval_databases(cache_dir=self.hparams["cache_dir"])
+
+        # setup vocabulary
+        self.vocabulary = RetrievalDatabase("cc12m", self.hparams["cache_dir"])
+        self._vocab_transform = default_vocabulary_transforms()
+
+    @property
+    def vocab_transform(self) -> Callable:
+        """Get image preprocess transform.
+
+        The getter wraps the transform in a map_reduce function and applies it to a list of images.
+        If interested in the transform itself, use `self._vocab_transform`.
+        """
+        vocab_transform = self._vocab_transform
+
+        def vocabs_transforms(texts: list[str]) -> list[torch.Tensor]:
+            return [vocab_transform(text) for text in texts]
+
+        return vocabs_transforms
+
+    def get_vocabulary(self, images_z: Optional[torch.Tensor] = None) -> list[list[str]]:
+        """Get the vocabulary for a batch of images.
 
         Args:
-            sample_z (torch.Tensor): Sample to query the index.
+            images_z (torch.Tensor): Batch of image embeddings.
         """
-        # get the index
-        resources = self.resources[self.hparams["index_name"]]
-        text_index = resources["text_index"]
-        metadata_provider = resources["metadata_provider"]
-
-        # query the index
-        sample_z = sample_z.squeeze(0)
-        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
-        query_input = sample_z.cpu().detach().numpy().tolist()
-        query = np.expand_dims(np.array(query_input).astype("float32"), 0)
-
-        distances, idxs, _ = text_index.search_and_reconstruct(
-            query, self.hparams["retrieval_num_results"]
-        )
-        results = idxs[0]
-        nb_results = np.where(results == -1)[0]
-        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
-        indices = results[:nb_results]
-        distances = distances[0][:nb_results]
-
-        if len(distances) == 0:
-            return []
-
-        # get the metadata
-        results = []
-        metadata = metadata_provider.get(indices[:20], ["caption"])
-        for key, (d, i) in enumerate(zip(distances, indices)):
-            output = {}
-            meta = None if key + 1 > len(metadata) else metadata[key]
-            if meta is not None:
-                output.update(meta)
-            output["id"] = i.item()
-            output["similarity"] = d.item()
-            results.append(output)
-
-        # get the captions only
-        vocabularies = [result["caption"] for result in results]
+        num_samples = self.hparams["retrieval_num_results"]
 
+        assert images_z is not None
+
+        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
+        images_z = images_z.cpu().detach().numpy().tolist()
+
+        if isinstance(images_z[0], float):
+            images_z = [images_z]
+
+        query = np.matrix(images_z).astype("float32")
+        results = self.vocabulary.query(query, modality="text", num_samples=num_samples)
+
+        vocabularies = [[r["caption"] for r in result] for result in results]
        return vocabularies
 
     def forward(self, images: dict, alpha: Optional[float] = None) -> torch.Tensor:
@@ -205,52 +96,62 @@ class CaSEDModel(PreTrainedModel):
                 - pixel_values (torch.Tensor): Pixel values of the images.
             alpha (Optional[float]): Alpha value for the interpolation.
         """
+        alpha = alpha or self.hparams["alpha"]
+
         # forward the images
         images["pixel_values"] = images["pixel_values"].to(self.device)
         images_z = self.vision_proj(self.vision_encoder(**images)[1])
-
-        vocabularies, samples_p = [], []
-        for image_z in images_z:
-            image_z = image_z.unsqueeze(0)
-
-            # generate a single text embedding from the unfiltered vocabulary
-            vocabulary = self.query_index(image_z)
-            text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
-            text["input_ids"] = text["input_ids"][:, :77].to(self.device)
-            text["attention_mask"] = text["attention_mask"][:, :77].to(self.device)
-            text_z = self.language_encoder(**text)[1]
-            text_z = self.language_proj(text_z)
-            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-            text_z = text_z.mean(dim=0).unsqueeze(0)
-            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-
-            # filter the vocabulary, embed it, and get its mean embedding
-            vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
-            text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
-            text = {k: v.to(self.device) for k, v in text.items()}
-            vocabulary_z = self.language_encoder(**text)[1]
-            vocabulary_z = self.language_proj(vocabulary_z)
-            vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)
-
-            # get the image and text predictions
-            image_z = image_z / image_z.norm(dim=-1, keepdim=True)
-            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-            image_p = (self.logit_scale * image_z @ vocabulary_z.T).softmax(dim=-1)
-            text_p = (self.logit_scale * text_z @ vocabulary_z.T).softmax(dim=-1)
-
-            # average the image and text predictions
-            alpha = alpha or self.hparams["alpha"]
-            sample_p = alpha * image_p + (1 - alpha) * text_p
-
-            # save the results
-            samples_p.append(sample_p)
-            vocabularies.append(vocabulary)
-
-        # get the scores
-        samples_p = torch.cat(samples_p, dim=0)
-        scores = samples_p.cpu()
-
-        # define the results
-        results = {"vocabularies": vocabularies, "scores": scores}
-
-        return results
+        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
+        vocabularies = self.get_vocabulary(images_z=images_z)
+
+        # encode unfiltered words
+        unfiltered_words = sum(vocabularies, [])
+        texts_z = self.processor(unfiltered_words, return_tensors="pt", padding=True)
+        texts_z["input_ids"] = texts_z["input_ids"][:, :77].to(self.device)
+        texts_z["attention_mask"] = texts_z["attention_mask"][:, :77].to(self.device)
+        texts_z = self.language_encoder(**texts_z)[1]
+        texts_z = self.language_proj(texts_z)
+        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
+
+        # generate a text embedding for each image from their unfiltered words
+        unfiltered_words_per_image = [len(vocab) for vocab in vocabularies]
+        texts_z = torch.split(texts_z, unfiltered_words_per_image)
+        texts_z = torch.stack([text_z.mean(dim=0) for text_z in texts_z])
+        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
+
+        # filter the words and embed them
+        vocabularies = self.vocab_transform(vocabularies)
+        vocabularies = [vocab or ["object"] for vocab in vocabularies]
+        words = sum(vocabularies, [])
+        words_z = self.processor(words, return_tensors="pt", padding=True)
+        words_z = {k: v.to(self.device) for k, v in words_z.items()}
+        words_z = self.language_encoder(**words_z)[1]
+        words_z = self.language_proj(words_z)
+        words_z = words_z / words_z.norm(dim=-1, keepdim=True)
+
+        # create a one-hot relation mask between images and words
+        words_per_image = [len(vocab) for vocab in vocabularies]
+        col_indices = torch.arange(sum(words_per_image))
+        row_indices = torch.arange(len(images_z)).repeat_interleave(torch.tensor(words_per_image))
+        mask = torch.zeros(len(images_z), sum(words_per_image), device=self.device)
+        mask[row_indices, col_indices] = 1
+
+        # get the image and text similarities
+        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
+        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
+        words_z = words_z / words_z.norm(dim=-1, keepdim=True)
+        images_sim = self.logit_scale * images_z @ words_z.T
+        texts_sim = self.logit_scale * texts_z @ words_z.T
+
+        # mask unrelated words
+        images_sim = torch.masked_fill(images_sim, mask == 0, float("-inf"))
+        texts_sim = torch.masked_fill(texts_sim, mask == 0, float("-inf"))
+
+        # get the image and text predictions
+        images_p = images_sim.softmax(dim=-1)
+        texts_p = texts_sim.softmax(dim=-1)
+
+        # average the image and text predictions
+        samples_p = alpha * images_p + (1 - alpha) * texts_p
+
+        return {"scores": samples_p, "words": words, "vocabularies": vocabularies}
retrieval_cased.py ADDED
@@ -0,0 +1,278 @@
+import tarfile
+from collections import defaultdict
+from pathlib import Path
+
+import faiss
+import numpy as np
+import pyarrow as pa
+import requests
+from tqdm import tqdm
+
+__all__ = ["RetrievalDatabase", "download_retrieval_databases"]
+
+RETRIEVAL_DATABASES_URLS = {
+    "cc12m": {
+        "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz",
+        "cache_subdir": "./cc12m/vit-l-14/",
+    },
+}
+
+
+def download_retrieval_databases(cache_dir: str = "~/.cache/cased"):
+    """Download data if needed.
+
+    Args:
+        cache_dir (str): Path to cache directory. Defaults to "~/.cache/cased".
+    """
+    databases_path = Path(cache_dir, "databases")
+
+    for name, items in RETRIEVAL_DATABASES_URLS.items():
+        url = items["url"]
+        database_path = Path(databases_path, name)
+        if database_path.exists():
+            continue
+
+        # download data
+        target_path = Path(databases_path, name + ".tar.gz")
+        target_path.parent.mkdir(parents=True, exist_ok=True)
+        with requests.get(url, stream=True) as r:
+            r.raise_for_status()
+            total_bytes_size = int(r.headers.get("content-length", 0))
+            chunk_size = 8192
+            p_bar = tqdm(
+                desc="Downloading cc12m index",
+                total=total_bytes_size,
+                unit="iB",
+                unit_scale=True,
+            )
+            with open(target_path, "wb") as f:
+                for chunk in r.iter_content(chunk_size=chunk_size):
+                    f.write(chunk)
+                    p_bar.update(len(chunk))
+            p_bar.close()
+
+        # extract data
+        tar = tarfile.open(target_path, "r:gz")
+        tar.extractall(target_path.parent)
+        tar.close()
+        target_path.unlink()
+
+
+class RetrievalDatabaseMetadataProvider:
+    """Metadata provider for the retrieval database.
+
+    Args:
+        metadata_dir (str): Path to the metadata directory.
+    """
+
+    def __init__(self, metadata_dir: str):
+        metadatas = [str(a) for a in sorted(Path(metadata_dir).glob("**/*")) if a.is_file()]
+        self.table = pa.concat_tables(
+            [
+                pa.ipc.RecordBatchFileReader(pa.memory_map(metadata, "r")).read_all()
+                for metadata in metadatas
+            ]
+        )
+
+    def get(self, ids):
+        """Get the metadata for the given ids.
+
+        Args:
+            ids (list): List of ids.
+        """
+        columns = self.table.schema.names
+        end_ids = [i + 1 for i in ids]
+        t = pa.concat_tables([self.table[start:end] for start, end in zip(ids, end_ids)])
+        return t.select(columns).to_pandas().to_dict("records")
+
+
+class RetrievalDatabase:
+    """Retrieval database.
+
+    Args:
+        database_name (str): Name of the database.
+        cache_dir (str): Path to cache directory. Defaults to "~/.cache/cased".
+    """
+
+    def __init__(self, database_name: str, cache_dir: str = "~/.cache/cased"):
+        assert database_name in RETRIEVAL_DATABASES_URLS.keys(), (
+            f"Database name should be one of "
+            f"{list(RETRIEVAL_DATABASES_URLS.keys())}, got {database_name}."
+        )
+
+        database_dir = Path(cache_dir) / "databases"
+        database_dir = database_dir / RETRIEVAL_DATABASES_URLS[database_name]["cache_subdir"]
+        self._database_dir = database_dir
+
+        image_index_fp = Path(database_dir) / "image.index"
+        text_index_fp = Path(database_dir) / "text.index"
+
+        image_index = (
+            faiss.read_index(str(image_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
+            if image_index_fp.exists()
+            else None
+        )
+        text_index = (
+            faiss.read_index(str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
+            if text_index_fp.exists()
+            else None
+        )
+
+        metadata_dir = str(Path(database_dir) / "metadata")
+        metadata_provider = RetrievalDatabaseMetadataProvider(metadata_dir)
+
+        self._image_index = image_index
+        self._text_index = text_index
+        self._metadata_provider = metadata_provider
+
+    def _map_to_metadata(self, indices: list, distances: list, embs: list, num_images: int):
+        """Map the indices to metadata.
+
+        Args:
+            indices (list): List of indices.
+            distances (list): List of distances.
+            embs (list): List of results embeddings.
+            num_images (int): Number of images.
+        """
+        results = []
+        metas = self._metadata_provider.get(indices[:num_images])
+        for key, (d, i, emb) in enumerate(zip(distances, indices, embs)):
+            output = {}
+            meta = None if key + 1 > len(metas) else metas[key]
+            if meta is not None:
+                output.update(self._meta_to_dict(meta))
+            output["id"] = i.item()
+            output["similarity"] = d.item()
+            output["sample_z"] = emb.tolist()
+            results.append(output)
+
+        return results
+
+    def _meta_to_dict(self, metadata):
+        """Convert metadata to dict.
+
+        Args:
+            metadata (dict): Metadata.
+        """
+        output = {}
+        for k, v in metadata.items():
+            if isinstance(v, bytes):
+                v = v.decode()
+            elif type(v).__module__ == np.__name__:
+                v = v.item()
+            output[k] = v
+        return output
+
+    def _get_connected_components(self, neighbors):
+        """Find connected components in a graph.
+
+        Args:
+            neighbors (dict): Dictionary of neighbors.
+        """
+        seen = set()
+
+        def component(node):
+            r = []
+            nodes = {node}
+            while nodes:
+                node = nodes.pop()
+                seen.add(node)
+                nodes |= set(neighbors[node]) - seen
+                r.append(node)
+            return r
+
+        u = []
+        for node in neighbors:
+            if node not in seen:
+                u.append(component(node))
+        return u
+
+    def _deduplicate_embeddings(self, embeddings, threshold=0.94):
+        """Deduplicate embeddings.
+
+        Args:
+            embeddings (np.matrix): Embeddings to deduplicate.
+            threshold (float): Threshold to use for deduplication. Default is 0.94.
+        """
+        index = faiss.IndexFlatIP(embeddings.shape[1])
+        index.add(embeddings)
+        l, _, indices = index.range_search(embeddings, threshold)
+
+        same_mapping = defaultdict(list)
+
+        for i in range(embeddings.shape[0]):
+            start = l[i]
+            end = l[i + 1]
+            for j in indices[start:end]:
+                same_mapping[int(i)].append(int(j))
+
+        groups = self._get_connected_components(same_mapping)
+        non_uniques = set()
+        for g in groups:
+            for e in g[1:]:
+                non_uniques.add(e)
+
+        return set(list(non_uniques))
+
+    def query(
+        self, query: np.matrix, modality: str = "text", num_samples: int = 10
+    ) -> list[list[dict]]:
+        """Query the database.
+
+        Args:
+            query (np.matrix): Query to search.
+            modality (str): Modality to search. One of `image` or `text`. Default to `text`.
+            num_samples (int): Number of samples to return. Default is 40.
+        """
+        index = self._image_index if modality == "image" else self._text_index
+
+        distances, indices, embeddings = index.search_and_reconstruct(query, num_samples)
+        results = [indices[i] for i in range(len(indices))]
+
+        nb_results = [np.where(r == -1)[0] for r in results]
+        total_distances = []
+        total_indices = []
+        total_embeddings = []
+        for i in range(len(results)):
+            num_res = nb_results[i][0] if len(nb_results[i]) > 0 else len(results[i])
+
+            result_indices = results[i][:num_res]
+            result_distances = distances[i][:num_res]
+            result_embeddings = embeddings[i][:num_res]
+
+            # normalise embeddings
+            l2 = np.atleast_1d(np.linalg.norm(result_embeddings, 2, -1))
+            l2[l2 == 0] = 1
+            result_embeddings = result_embeddings / np.expand_dims(l2, -1)
+
+            # deduplicate embeddings
+            local_indices_to_remove = self._deduplicate_embeddings(result_embeddings)
+            indices_to_remove = set()
+            for local_index in local_indices_to_remove:
+                indices_to_remove.add(result_indices[local_index])
+
+            curr_indices = []
+            curr_distances = []
+            curr_embeddings = []
+            for ind, dis, emb in zip(result_indices, result_distances, result_embeddings):
+                if ind not in indices_to_remove:
+                    indices_to_remove.add(ind)
+                    curr_indices.append(ind)
+                    curr_distances.append(dis)
+                    curr_embeddings.append(emb)
+
+            total_indices.append(curr_indices)
+            total_distances.append(curr_distances)
+            total_embeddings.append(curr_embeddings)
+
+        if len(total_distances) == 0:
+            return []
+
+        total_results = []
+        for i in range(len(total_distances)):
+            results = self._map_to_metadata(
+                total_indices[i], total_distances[i], total_embeddings[i], num_samples
+            )
+            total_results.append(results)
+
+        return total_results
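
The new module is self-contained, so the database can also be queried outside the model. A minimal sketch under stated assumptions: the cc12m text index stores CLIP ViT-L/14 embeddings (assumed 768-dimensional), and the random normalized vector below is only a stand-in for a real image or text embedding:

```python
import os

import numpy as np

from retrieval_cased import RetrievalDatabase, download_retrieval_databases

cache_dir = os.path.expanduser("~/.cache/cased")
download_retrieval_databases(cache_dir=cache_dir)  # skips databases already on disk
db = RetrievalDatabase("cc12m", cache_dir=cache_dir)

# stand-in query; in practice this is a normalized CLIP ViT-L/14 embedding
query = np.random.randn(1, 768).astype("float32")
query /= np.linalg.norm(query, axis=-1, keepdims=True)

results = db.query(query, modality="text", num_samples=10)
print([r["caption"] for r in results[0]])
```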
transforms_cased.py CHANGED
@@ -1,6 +1,6 @@
 import re
 from abc import ABC, abstractmethod
-from typing import Any, Union
+from typing import Union
 
 import inflect
 import nltk