altndrr committed on
Commit a3ee979
1 Parent(s): 3070a83

Add first version

.gitignore ADDED
@@ -0,0 +1,151 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # pytype type checking
+ .pytype/
+
+ # Pyre type checker
+ .pyre/
+
+ ### VisualStudioCode
+ .vscode/*
+ !.vscode/settings.json
+ !.vscode/tasks.json
+ !.vscode/launch.json
+ !.vscode/extensions.json
+ *.code-workspace
+ **/.vscode
+
+ # JetBrains
+ .idea/
+
+ # Data & Models
+ *.h5
+ *.tar
+ *.tar.gz
+
+ # Template
+ /artifacts/models/databases/*/
README.md CHANGED
@@ -1,10 +1,11 @@
  ---
- title: Vic
+ title: Vocabulary-free Image Classification
  emoji: 🌍
- colorFrom: gray
- colorTo: gray
+ colorFrom: green
+ colorTo: yellow
  sdk: gradio
  sdk_version: 3.33.1
+ python_version: 3.9
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,78 @@
+ from typing import Optional
+
+ import gradio as gr
+ import torch
+
+ from src.nn import CaSED
+
+ PAPER_TITLE = "Vocabulary-free Image Classification"
+ PAPER_DESCRIPTION = """
+
+
+ <div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
+     <a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem;">
+         <img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
+     </a>
+     <a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;">
+         <img src="https://img.shields.io/badge/paper-arXiv%3A2306.00917-B31B1B.svg"/>
+     </a>
+     <a href="https://altndrr.github.io/vic/" style="margin-right: 0.5rem;">
+         <img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
+     </a>
+ </div>
+
+
+ Vocabulary-free Image Classification aims to assign a class to an image *without* prior knowledge
+ of the list of class names, thus operating on the semantic class space that contains all possible
+ concepts. Our proposed method CaSED finds the best-matching category within this unconstrained
+ semantic space by leveraging multimodal data from large vision-language databases. We first
+ retrieve the semantically most similar captions from a database, from which we extract a set of
+ candidate categories by applying text parsing and filtering techniques. We then score the
+ candidates using the multimodal aligned representation of a large pre-trained VLM, *i.e.* CLIP,
+ to obtain the best-matching category, using *alpha* as a hyperparameter to control the trade-off
+ between the visual and textual similarity.
+ """
+ PAPER_URL = "https://arxiv.org/abs/2306.00917"
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ model = CaSED().to(DEVICE).eval()
+
+
+ def vic(filename: str, alpha: Optional[float] = None):
+     # get the outputs of the model
+     vocabulary, scores = model(filename, alpha=alpha)
+     confidences = dict(zip(vocabulary, scores))
+
+     return confidences
+
+ def resize_image(image, max_size: int = 256):
+     """Resize image so that its shorter side equals max_size, keeping the aspect ratio."""
+     width, height = image.size
+     if width > height:
+         ratio = width / height
+         new_width = max_size * ratio
+         new_height = max_size
+     else:
+         ratio = height / width
+         new_width = max_size
+         new_height = max_size * ratio
+     return image.resize((int(new_width), int(new_height)))
+
+
+ demo = gr.Interface(
+     fn=vic,
+     inputs=[
+         gr.Image(type="filepath", label="input"),
+         gr.Slider(0.0, 1.0, value=0.5, label="alpha"),
+     ],
+     outputs=[gr.Label(num_top_classes=5, label="output")],
+     title=PAPER_TITLE,
+     description=PAPER_DESCRIPTION,
+     article=f"Check out <a href='{PAPER_URL}'>the original paper</a> for more information.",
+     examples="./artifacts/examples/",
+     allow_flagging="never",
+     theme=gr.themes.Soft(),
+ )
+
+ demo.launch(share=False)
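
For reference, the `vic` function above is a thin wrapper around a single call to the CaSED model; a minimal sketch of calling the model directly, outside the Gradio interface (it assumes the cc12m retrieval database has already been fetched by `CaSED.prepare_data()` and uses one of the bundled example images):

# Minimal sketch: query the model directly instead of through the Gradio demo.
# Assumes the database files exist under artifacts/models/databases/.
from src.nn import CaSED

model = CaSED().eval()

# alpha weights the image-based prediction against the caption-based one.
vocabulary, scores = model("artifacts/examples/ramen.jpg", alpha=0.5)

# Print the candidate categories from most to least confident.
for word, score in sorted(zip(vocabulary, scores), key=lambda pair: -pair[1]):
    print(f"{word}: {score:.3f}")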
artifacts/examples/basketball.jpg ADDED
artifacts/examples/cassowary.jpg ADDED
artifacts/examples/colosseum.jpg ADDED
artifacts/examples/desk.jpg ADDED
artifacts/examples/kitchen.jpg ADDED
artifacts/examples/log.csv ADDED
@@ -0,0 +1,11 @@
+ image_fp
+ basketball.jpg
+ cassowary.jpg
+ colosseum.jpg
+ desk.jpg
+ kitchen.jpg
+ monkey.jpg
+ park.jpg
+ ramen.jpg
+ sagrada.jpg
+ venice.jpg
artifacts/examples/monkey.jpg ADDED
artifacts/examples/park.jpg ADDED
artifacts/examples/ramen.jpg ADDED
artifacts/examples/sagrada.jpg ADDED
artifacts/examples/venice.jpg ADDED
artifacts/models/databases/.gitkeep ADDED
File without changes
artifacts/models/retrieval/indices.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "ViT-L-14_CC12M": "./artifacts/models/databases/cc12m/vit-l-14/"
+ }
flagged/.gitkeep ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.0.1
+ torchvision==0.15.2
+ faiss-cpu==1.7.4
+ flair==0.12.2
+ gradio==3.33.1
+ gdown==4.4.0
+ inflect==6.0.4
+ nltk==3.8.1
+ open_clip_torch==2.20.0
+ transformers==4.26.1
src/nn.py ADDED
@@ -0,0 +1,330 @@
+ import json
+ import tarfile
+ from pathlib import Path
+ from typing import Optional
+
+ import faiss
+ import gdown
+ import numpy as np
+ import open_clip
+ import torch
+ from open_clip.transformer import Transformer
+ from PIL import Image
+
+ from src.retrieval import ArrowMetadataProvider, meta_to_dict
+ from src.transforms import TextCompose, default_vocabulary_transforms
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ RETRIEVAL_DATABASES = {
+     "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
+ }
+
+
+ class CaSED(torch.nn.Module):
+     """Torch module for Category Search from External Databases (CaSED).
+
+     Args:
+         index_name (str): Name of the faiss index to use.
+         vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
+         model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14".
+         pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai".
+
+     Extra hparams:
+         alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
+         artifact_dir (str): Path to the directory where the databases are stored. Defaults to
+             "artifacts/".
+         retrieval_num_results (int): Number of results to return. Defaults to 10.
+         vocabulary_prompt (str): Prompt to use for the vocabulary. Defaults to "{}".
+         tau (float): Temperature to use for the classifier. Defaults to 1.0.
+     """
+
+     def __init__(
+         self,
+         index_name: str = "ViT-L-14_CC12M",
+         vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
+         model_name: str = "ViT-L-14",
+         pretrained: str = "openai",
+         vocabulary_prompt: str = "{}",
+         **kwargs,
+     ):
+         super().__init__()
+         self._prev_vocab_words = None
+         self._prev_used_prompts = None
+         self._prev_vocab_words_z = None
+
+         model, _, preprocess = open_clip.create_model_and_transforms(
+             model_name, pretrained=pretrained, device="cpu"
+         )
+         tokenizer = open_clip.get_tokenizer(model_name)
+         self.tokenizer = tokenizer
+         self.preprocess = preprocess
+
+         kwargs["alpha"] = kwargs.get("alpha", 0.5)
+         kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
+         kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
+         vocabulary_prompt = kwargs.get("vocabulary_prompt", "{}")
+         kwargs["vocabulary_prompts"] = [vocabulary_prompt]
+         kwargs["tau"] = kwargs.get("tau", 1.0)
+         self.hparams = kwargs
+
+         language_encoder = LanguageTransformer(
+             model.transformer,
+             model.token_embedding,
+             model.positional_embedding,
+             model.ln_final,
+             model.text_projection,
+             model.attn_mask,
+         )
+         scale = model.logit_scale.exp().item()
+         classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"])
+
+         self.index_name = index_name
+         self.vocabulary_transforms = vocabulary_transforms
+         self.vision_encoder = model.visual
+         self.language_encoder = language_encoder
+         self.classifier = classifier
+
+         # download databases
+         self.prepare_data()
+
+         # load faiss indices
+         indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
+         indices_fp = indices_list_dir / "indices.json"
+         self.indices = json.load(open(indices_fp, "r"))
+
+         # load faiss indices and metadata providers
+         self.resources = {}
+         for name, index_fp in self.indices.items():
+             text_index_fp = Path(index_fp) / "text.index"
+             metadata_fp = Path(index_fp) / "metadata/"
+
+             text_index = faiss.read_index(
+                 str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
+             )
+             metadata_provider = ArrowMetadataProvider(metadata_fp)
+
+             self.resources[name] = {
+                 "device": DEVICE,
+                 "model": model_name,
+                 "text_index": text_index,
+                 "metadata_provider": metadata_provider,
+             }
+
+     def prepare_data(self):
+         """Download data if needed."""
+         databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
+
+         for name, url in RETRIEVAL_DATABASES.items():
+             database_path = Path(databases_path, name)
+             if database_path.exists():
+                 continue
+
+             # download data
+             target_path = Path(databases_path, name + ".tar.gz")
+             try:
+                 gdown.download(url, str(target_path), quiet=False)
+                 tar = tarfile.open(target_path, "r:gz")
+                 tar.extractall(target_path.parent)
+                 tar.close()
+                 target_path.unlink()
+             except FileNotFoundError:
+                 print(f"Could not download {url}.")
+                 print(f"Please download it manually and place it in {target_path.parent}.")
+
+     @torch.no_grad()
+     def query_index(self, sample_z: torch.Tensor) -> list:
+         # get the index
+         resources = self.resources[self.index_name]
+         text_index = resources["text_index"]
+         metadata_provider = resources["metadata_provider"]
+
+         # query the index
+         sample_z = sample_z.squeeze(0)
+         sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
+         query_input = sample_z.cpu().detach().numpy().tolist()
+         query = np.expand_dims(np.array(query_input).astype("float32"), 0)
+
+         distances, idxs, _ = text_index.search_and_reconstruct(
+             query, self.hparams["retrieval_num_results"]
+         )
+         results = idxs[0]
+         nb_results = np.where(results == -1)[0]
+         nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
+         indices = results[:nb_results]
+         distances = distances[0][:nb_results]
+
+         if len(distances) == 0:
+             return []
+
+         # get the metadata
+         results = []
+         metadata = metadata_provider.get(indices[:20], ["caption"])
+         for key, (d, i) in enumerate(zip(distances, indices)):
+             output = {}
+             meta = None if key + 1 > len(metadata) else metadata[key]
+             if meta is not None:
+                 output.update(meta_to_dict(meta))
+             output["id"] = i.item()
+             output["similarity"] = d.item()
+             results.append(output)
+
+         # get the captions only
+         vocabularies = [result["caption"] for result in results]
+
+         return vocabularies
+
+     @torch.no_grad()
+     def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor:
+         """Encode a vocabulary.
+
+         Args:
+             vocabulary (list): List of words.
+         """
+         # check if vocabulary has changed
+         if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts:
+             return self._prev_vocab_words_z
+
+         # tokenize vocabulary
+         classes = [c.replace("_", " ") for c in vocabulary]
+         prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"]
+         texts_views = [[p.format(c) for c in classes] for p in prompts]
+         tokenized_texts_views = [
+             torch.cat([self.tokenizer(prompt) for prompt in class_prompts])
+             for class_prompts in texts_views
+         ]
+         tokenized_texts_views = torch.stack(tokenized_texts_views).to(DEVICE)
+
+         # encode vocabulary
+         T, C, _ = tokenized_texts_views.shape
+         texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1))
+         texts_z_views = texts_z_views.view(T, C, -1)
+         texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True)
+
+         # cache vocabulary
+         self._prev_vocab_words = vocabulary
+         self._prev_used_prompts = use_prompts
+         self._prev_vocab_words_z = texts_z_views
+
+         return texts_z_views
+
+     @torch.no_grad()
+     def forward(self, image_fp: str, alpha: Optional[float] = None) -> tuple:
+         image = self.preprocess(Image.open(image_fp)).unsqueeze(0)
+         image_z = self.vision_encoder(image.to(DEVICE))
+
+         # get the vocabulary
+         vocabulary = self.query_index(image_z)
+
+         # generate a single text embedding from the unfiltered vocabulary
+         unfiltered_vocabulary_z = self.encode_vocabulary(vocabulary).squeeze(0)
+         text_z = unfiltered_vocabulary_z.mean(dim=0)
+         text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+         text_z = text_z.unsqueeze(0)
+
+         # filter the vocabulary, embed it, and get its mean embedding
+         vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
+         vocabulary_z = self.encode_vocabulary(vocabulary, use_prompts=True)
+         mean_vocabulary_z = vocabulary_z.mean(dim=0)
+         mean_vocabulary_z = mean_vocabulary_z / mean_vocabulary_z.norm(dim=-1, keepdim=True)
+
+         # get the image and text predictions
+         image_p = self.classifier(image_z, vocabulary_z)
+         text_p = self.classifier(text_z, vocabulary_z)
+
+         # average the image and text predictions
+         alpha = self.hparams["alpha"] if alpha is None else alpha
+         sample_p = alpha * image_p + (1 - alpha) * text_p
+
+         # get the scores
+         sample_p = sample_p.cpu()
+         scores = sample_p[0].tolist()
+
+         del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z, mean_vocabulary_z
+         del image_p, text_p, sample_p
+
+         return vocabulary, scores
+
+
+ class NearestNeighboursClassifier(torch.nn.Module):
+     """Nearest neighbours classifier.
+
+     It computes the similarity between the query and the supports using the
+     cosine similarity and then applies a softmax to obtain the logits.
+
+     Args:
+         scale (float): Scale for the logits of the query. Defaults to 1.0.
+         tau (float): Temperature for the softmax. Defaults to 1.0.
+     """
+
+     def __init__(self, scale: float = 1.0, tau: float = 1.0):
+         super().__init__()
+         self.scale = scale
+         self.tau = tau
+
+     def forward(self, query: torch.Tensor, supports: torch.Tensor):
+         query = query / query.norm(dim=-1, keepdim=True)
+         supports = supports / supports.norm(dim=-1, keepdim=True)
+
+         if supports.dim() == 2:
+             supports = supports.unsqueeze(0)
+
+         Q, _ = query.shape
+         N, C, _ = supports.shape
+
+         supports = supports.mean(dim=0)
+         supports = supports / supports.norm(dim=-1, keepdim=True)
+         similarity = self.scale * query @ supports.T
+         similarity = similarity / self.tau if self.tau != 1.0 else similarity
+         logits = similarity.softmax(dim=-1)
+
+         return logits
+
+
+ class LanguageTransformer(torch.nn.Module):
+     """Language Transformer for CLIP.
+
+     Args:
+         transformer (Transformer): Transformer model.
+         token_embedding (torch.nn.Embedding): Token embedding.
+         positional_embedding (torch.nn.Parameter): Positional embedding.
+         ln_final (torch.nn.LayerNorm): Layer norm.
+         text_projection (torch.nn.Parameter): Text projection.
+     """
+
+     def __init__(
+         self,
+         model: Transformer,
+         token_embedding: torch.nn.Embedding,
+         positional_embedding: torch.nn.Parameter,
+         ln_final: torch.nn.LayerNorm,
+         text_projection: torch.nn.Parameter,
+         attn_mask: torch.Tensor,
+     ):
+         super().__init__()
+         self.transformer = model
+         self.token_embedding = token_embedding
+         self.positional_embedding = positional_embedding
+         self.ln_final = ln_final
+         self.text_projection = text_projection
+
+         self.register_buffer("attn_mask", attn_mask, persistent=False)
+
+     def forward(self, text: torch.Tensor) -> torch.Tensor:
+         """Forward pass for the text encoder."""
+         cast_dtype = self.transformer.get_cast_dtype()
+
+         x = self.token_embedding(text).to(cast_dtype)
+
+         x = x + self.positional_embedding.to(cast_dtype)
+         x = x.permute(1, 0, 2)
+         x = self.transformer(x, attn_mask=self.attn_mask)
+         x = x.permute(1, 0, 2)
+         x = self.ln_final(x)
+
+         # x.shape = [batch_size, n_ctx, transformer.width]
+         # take features from the eot embedding (eot_token is the highest number in each sequence)
+         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+         return x
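
To make the scoring step in `CaSED.forward` concrete, here is a small self-contained sketch of `NearestNeighboursClassifier` and the alpha blending on placeholder embeddings (768 is the ViT-L/14 embedding size; the random tensors are stand-ins, not real CLIP features):

import torch

from src.nn import NearestNeighboursClassifier

# Placeholder embeddings: one image query, one text query, one prompt view over five candidates.
image_z = torch.randn(1, 768)
text_z = torch.randn(1, 768)
vocabulary_z = torch.randn(1, 5, 768)  # (views, candidates, dim)

classifier = NearestNeighboursClassifier(scale=100.0, tau=1.0)
image_p = classifier(image_z, vocabulary_z)  # (1, 5) softmax over the candidates
text_p = classifier(text_z, vocabulary_z)

# Same blending as in CaSED.forward: alpha trades visual against textual similarity.
alpha = 0.5
sample_p = alpha * image_p + (1 - alpha) * text_p
print(sample_p.argmax(dim=-1))  # index of the best-matching candidate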
src/retrieval.py ADDED
@@ -0,0 +1,42 @@
+ from pathlib import Path
+
+ import numpy as np
+ import pyarrow as pa
+
+
+ class ArrowMetadataProvider:
+     """The arrow metadata provider provides metadata from contiguous ids using arrow.
+
+     Code taken from:
+     https://github.dev/rom1504/clip-retrieval
+     """
+
+     def __init__(self, arrow_folder):
+         arrow_files = [str(a) for a in sorted(Path(arrow_folder).glob("**/*")) if a.is_file()]
+         self.table = pa.concat_tables(
+             [
+                 pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
+                 for arrow_file in arrow_files
+             ]
+         )
+
+     def get(self, ids, cols=None):
+         """Implement the get method from the arrow metadata provider; get metadata from ids."""
+         if cols is None:
+             cols = self.table.schema.names
+         else:
+             cols = list(set(self.table.schema.names) & set(cols))
+         t = pa.concat_tables([self.table[i:(i + 1)] for i in ids])
+         return t.select(cols).to_pandas().to_dict("records")
+
+
+ def meta_to_dict(meta):
+     """Convert a metadata record to a plain dictionary."""
+     output = {}
+     for k, v in meta.items():
+         if isinstance(v, bytes):
+             v = v.decode()
+         elif type(v).__module__ == np.__name__:
+             v = v.item()
+         output[k] = v
+     return output
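
A short usage sketch of the two helpers above, mirroring how `CaSED.query_index` uses them; the metadata path below is the one implied by `artifacts/models/retrieval/indices.json` and assumes the cc12m database has already been downloaded:

from src.retrieval import ArrowMetadataProvider, meta_to_dict

# Metadata folder of the cc12m index (assumes the database is present locally).
provider = ArrowMetadataProvider("./artifacts/models/databases/cc12m/vit-l-14/metadata/")

# Look up the captions of a few arbitrary row ids, e.g. those returned by the faiss search.
records = provider.get([0, 1, 2], cols=["caption"])
captions = [meta_to_dict(record)["caption"] for record in records]
print(captions)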
src/transforms.py ADDED
@@ -0,0 +1,506 @@
+ import re
+ from abc import ABC, abstractmethod
+ from typing import Any, Optional, Union, cast
+
+ import inflect
+ import nltk
+ import numpy as np
+ import PIL.Image
+ import torch
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as F
+ from flair.data import Sentence
+ from flair.models import SequenceTagger
+
+ __all__ = [
+     "DynamicResize",
+     "DropFileExtensions",
+     "DropNonAlpha",
+     "DropShortWords",
+     "DropSpecialCharacters",
+     "DropTokens",
+     "DropURLs",
+     "DropWords",
+     "FilterPOS",
+     "FrequencyMinWordCount",
+     "FrequencyTopK",
+     "ReplaceSeparators",
+     "ToRGBTensor",
+     "ToLowercase",
+     "ToSingular",
+ ]
+
+
+ class BaseTextTransform(ABC):
+     """Base class for string transforms."""
+
+     @abstractmethod
+     def __call__(self, text: str):
+         raise NotImplementedError
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
+
+
+ class DynamicResize(T.Resize):
+     """Resize the input PIL Image to the given size.
+
+     Extends the torchvision Resize transform to dynamically evaluate the second dimension of the
+     output size based on the aspect ratio of the first input image.
+     """
+
+     def forward(self, img):
+         if isinstance(self.size, int):
+             _, h, w = F.get_dimensions(img)
+             aspect_ratio = w / h
+             side = self.size
+
+             if aspect_ratio < 1.0:
+                 self.size = int(side / aspect_ratio), side
+             else:
+                 self.size = side, int(side * aspect_ratio)
+
+         return super().forward(img)
+
+
+ class DropFileExtensions(BaseTextTransform):
+     """Remove file extensions from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove file extensions from.
+         """
+         text = re.sub(r"\.\w+", "", text)
+
+         return text
+
+
+ class DropNonAlpha(BaseTextTransform):
+     """Remove non-alphabetic characters from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove non-alphabetic characters from.
+         """
+         text = re.sub(r"[^a-zA-Z\s]", "", text)
+
+         return text
+
+
+ class DropShortWords(BaseTextTransform):
+     """Remove short words from the input text.
+
+     Args:
+         min_length (int): Minimum length of words to keep.
+     """
+
+     def __init__(self, min_length) -> None:
+         super().__init__()
+         self.min_length = min_length
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove short words from.
+         """
+         text = " ".join([word for word in text.split() if len(word) >= self.min_length])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(min_length={self.min_length})"
+
+
+ class DropSpecialCharacters(BaseTextTransform):
+     """Remove special characters from the input text.
+
+     Special characters are defined as any character that is not a word character, whitespace,
+     hyphen, period, apostrophe, or ampersand.
+     """
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove special characters from.
+         """
+         text = re.sub(r"[^\w\s\-\.\'\&]", "", text)
+
+         return text
+
+
+ class DropTokens(BaseTextTransform):
+     """Remove tokens from the input text.
+
+     Tokens are defined as strings enclosed in angle brackets, e.g. <token>.
+     """
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove tokens from.
+         """
+         text = re.sub(r"<[^>]+>", "", text)
+
+         return text
+
+
+ class DropURLs(BaseTextTransform):
+     """Remove URLs from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove URLs from.
+         """
+         text = re.sub(r"http\S+", "", text)
+
+         return text
+
+
+ class DropWords(BaseTextTransform):
+     """Remove words from the input text.
+
+     It is case-insensitive and supports singular and plural forms of the words.
+     """
+
+     def __init__(self, words: list[str]) -> None:
+         super().__init__()
+         self.words = words
+         self.pattern = r"\b(?:{})\b".format("|".join(words))
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove words from.
+         """
+         text = re.sub(self.pattern, "", text, flags=re.IGNORECASE)
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(pattern={self.pattern})"
+
+
+ class FilterPOS(BaseTextTransform):
+     """Filter words by POS tags.
+
+     Args:
+         tags (list): List of POS tags to filter by (removed with "nltk", kept with "flair").
+         engine (str): POS tagger to use. Must be one of "nltk" or "flair". Defaults to "nltk".
+         keep_compound_nouns (bool): Whether to keep composed words. Defaults to True.
+     """
+
+     def __init__(self, tags: list, engine: str = "nltk", keep_compound_nouns: bool = True) -> None:
+         super().__init__()
+         self.tags = tags
+         self.engine = engine
+         self.keep_compound_nouns = keep_compound_nouns
+
+         if engine == "nltk":
+             nltk.download("averaged_perceptron_tagger", quiet=True)
+             nltk.download("punkt", quiet=True)
+             self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
+         elif engine == "flair":
+             self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to filter by POS tags.
+         """
+         if self.engine == "nltk":
+             word_tags = self.tagger(text)
+             text = " ".join([word for word, tag in word_tags if tag not in self.tags])
+         elif self.engine == "flair":
+             sentence = Sentence(text)
+             self.tagger(sentence)
+             text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
+
+         if self.keep_compound_nouns:
+             compound_nouns = []
+
+             if self.engine == "nltk":
+                 for i in range(len(word_tags) - 1):
+                     if word_tags[i][1] == "NN" and word_tags[i + 1][1] == "NN":
+                         # if they are the same word, skip
+                         if word_tags[i][0] == word_tags[i + 1][0]:
+                             continue
+
+                         compound_noun = word_tags[i][0] + "_" + word_tags[i + 1][0]
+                         compound_nouns.append(compound_noun)
+             elif self.engine == "flair":
+                 for i in range(len(sentence.tokens) - 1):
+                     if sentence.tokens[i].tag == "NN" and sentence.tokens[i + 1].tag == "NN":
+                         # if they are the same word, skip
+                         if sentence.tokens[i].text == sentence.tokens[i + 1].text:
+                             continue
+
+                         compound_noun = sentence.tokens[i].text + "_" + sentence.tokens[i + 1].text
+                         compound_nouns.append(compound_noun)
+
+             text = " ".join([text, " ".join(compound_nouns)])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(tags={self.tags}, engine={self.engine})"
+
+
+ class FrequencyMinWordCount(BaseTextTransform):
+     """Keep only words that occur at least a minimum number of times in the input text.
+
+     If the threshold is too strong and no words pass it, the threshold is reduced to the count of
+     the most frequent word.
+
+     Args:
+         min_count (int): Minimum number of occurrences of a word to keep.
+     """
+
+     def __init__(self, min_count) -> None:
+         super().__init__()
+         self.min_count = min_count
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove infrequent words from.
+         """
+         if self.min_count <= 1:
+             return text
+
+         words = text.split()
+         word_counts = {word: words.count(word) for word in words}
+
+         # if nothing passes the threshold, reduce the threshold to the most frequent word
+         max_word_count = max(word_counts.values() or [0])
+         min_count = max_word_count if self.min_count > max_word_count else self.min_count
+
+         text = " ".join([word for word in words if word_counts[word] >= min_count])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(min_count={self.min_count})"
+
+
+ class FrequencyTopK(BaseTextTransform):
+     """Keep only the top k most frequent words in the input text.
+
+     In case of a tie, all words with the same count as the last word are kept.
+
+     Args:
+         top_k (int): Number of top words to keep.
+     """
+
+     def __init__(self, top_k: int) -> None:
+         super().__init__()
+         self.top_k = top_k
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove infrequent words from.
+         """
+         if self.top_k < 1:
+             return text
+
+         words = text.split()
+         word_counts = {word: words.count(word) for word in words}
+         top_words = sorted(word_counts, key=word_counts.get, reverse=True)
+
+         # in case of a tie, keep all words with the same count as the k-th word
+         min_top_count = word_counts[top_words[: self.top_k][-1]]
+         top_words = [word for word in top_words if word_counts[word] >= min_top_count]
+
+         text = " ".join([word for word in words if word in top_words])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(top_k={self.top_k})"
+
+
+ class ReplaceSeparators(BaseTextTransform):
+     """Replace underscores and dashes with spaces."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to replace separators in.
+         """
+         text = re.sub(r"[_\-]", " ", text)
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
+
+
+ class RemoveDuplicates(BaseTextTransform):
+     """Remove duplicate words from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove duplicate words from.
+         """
+         text = " ".join(list(set(text.split())))
+
+         return text
+
+
+ class TextCompose:
+     """Compose several transforms together.
+
+     It differs from the torchvision.transforms.Compose class in that it applies the transforms to
+     a string instead of a PIL Image or Tensor. In addition, it automatically joins the list of
+     input strings into a single string and splits the output string into a list of words.
+
+     Args:
+         transforms (list): List of transforms to compose.
+     """
+
+     def __init__(self, transforms: list[BaseTextTransform]) -> None:
+         self.transforms = transforms
+
+     def __call__(self, text: Union[str, list[str]]) -> Any:
+         if isinstance(text, list):
+             text = " ".join(text)
+
+         for t in self.transforms:
+             text = t(text)
+         return text.split()
+
+     def __repr__(self) -> str:
+         format_string = self.__class__.__name__ + "("
+         for t in self.transforms:
+             format_string += "\n"
+             format_string += f"    {t}"
+         format_string += "\n)"
+         return format_string
+
+
+ class ToRGBTensor(T.ToTensor):
+     """Convert a `PIL Image` or `numpy.ndarray` to tensor.
+
+     Compared with the torchvision `ToTensor` transform, it converts images with a single channel
+     to RGB images. In addition, the conversion to tensor is done only if the input is not already
+     a tensor.
+     """
+
+     def __call__(self, pic: Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
+         """
+         Args:
+             pic (PIL Image | numpy.ndarray | torch.Tensor): Image to be converted to tensor.
+         """
+         img = pic if isinstance(pic, torch.Tensor) else F.to_tensor(pic)
+         img = cast(torch.Tensor, img)
+
+         if img.shape[0] == 1:
+             img = img.repeat(3, 1, 1)
+
+         return img
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
+
+
+ class ToLowercase(BaseTextTransform):
+     """Convert text to lowercase."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to convert to lowercase.
+         """
+         text = text.lower()
+
+         return text
+
+
+ class ToSingular(BaseTextTransform):
+     """Convert plural words to singular form."""
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.transform = inflect.engine().singular_noun
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to convert to singular form.
+         """
+         words = text.split()
+         for i, word in enumerate(words):
+             if not word.endswith("s"):
+                 continue
+
+             if word[-2:] in ["ss", "us", "is"]:
+                 continue
+
+             if word[-3:] in ["ies", "oes"]:
+                 continue
+
+             words[i] = self.transform(word) or word
+
+         text = " ".join(words)
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
+
+
+ def default_preprocess(size: Optional[int] = None) -> T.Compose:
+     """Preprocess input images with preprocessing transforms.
+
+     Args:
+         size (int): Size to resize image to.
+     """
+     transforms = []
+     if size is not None:
+         transforms.append(DynamicResize(size, interpolation=T.InterpolationMode.BICUBIC))
+     transforms.append(ToRGBTensor())
+     transforms = T.Compose(transforms)
+
+     return transforms
+
+
+ def default_vocabulary_transforms() -> TextCompose:
+     """Preprocess input text with preprocessing transforms."""
+     words_to_drop = [
+         "image",
+         "photo",
+         "picture",
+         "thumbnail",
+         "logo",
+         "symbol",
+         "clipart",
+         "portrait",
+         "painting",
+         "illustration",
+         "icon",
+         "profile",
+     ]
+     pos_tags = ["NN", "NNS", "NNP", "NNPS", "JJ", "JJR", "JJS", "VBG", "VBN"]
+
+     transforms = []
+     transforms.append(DropTokens())
+     transforms.append(DropURLs())
+     transforms.append(DropSpecialCharacters())
+     transforms.append(DropFileExtensions())
+     transforms.append(ReplaceSeparators())
+     transforms.append(DropShortWords(min_length=3))
+     transforms.append(DropNonAlpha())
+     transforms.append(ToLowercase())
+     transforms.append(ToSingular())
+     transforms.append(DropWords(words=words_to_drop))
+     transforms.append(FrequencyMinWordCount(min_count=2))
+     transforms.append(FilterPOS(tags=pos_tags, engine="flair", keep_compound_nouns=False))
+     transforms.append(RemoveDuplicates())
+
+     transforms = TextCompose(transforms)
+
+     return transforms
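
As an illustration of the pipeline returned by `default_vocabulary_transforms()`, a brief sketch that turns two made-up captions into candidate category names (the first call downloads the flair POS tagger, so it needs network access):

from src.transforms import default_vocabulary_transforms

transforms = default_vocabulary_transforms()

# Made-up captions standing in for the ones retrieved from the database.
captions = [
    "A photo of a bowl of ramen noodles on a wooden table <token>",
    "Delicious ramen noodles with egg and pork, https://example.com/ramen.jpg",
]

# Returns a cleaned, singular, de-duplicated list of candidate words.
print(transforms(captions))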