import json
import tarfile
from pathlib import Path
from typing import Optional, Tuple

import faiss
import gdown
import numpy as np
import open_clip
import torch
from open_clip.transformer import Transformer
from PIL import Image

from src.retrieval import ArrowMetadataProvider, meta_to_dict
from src.transforms import TextCompose, default_vocabulary_transforms

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


RETRIEVAL_DATABASES = {
    "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
}


class CaSED(torch.nn.Module):
    """Torch module for Category Search from External Databases (CaSED).

    Args:
        index_name (str): Name of the faiss index to use.
        vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
        model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14".
        pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai".

    Extra hparams:
        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
            "artifacts/".
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        vocabulary_prompt (str): Prompt to use for the vocabulary. Defaults to "{}".
        tau (float): Temperature to use for the classifier. Defaults to 1.0.
    """

    def __init__(
        self,
        index_name: str = "ViT-L-14_CC12M",
        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
        model_name: str = "ViT-L-14",
        pretrained: str = "openai",
        vocabulary_prompt: str = "{}",
        **kwargs,
    ):
        super().__init__()
        self._prev_vocab_words = None
        self._prev_used_prompts = None
        self._prev_vocab_words_z = None

        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, device="cpu"
        )
        tokenizer = open_clip.get_tokenizer(model_name)
        self.tokenizer = tokenizer
        self.preprocess = preprocess

        kwargs["alpha"] = kwargs.get("alpha", 0.5)
        kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
        kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
        kwargs["vocabulary_prompts"] = [vocabulary_prompt]
        kwargs["tau"] = kwargs.get("tau", 1.0)
        self.hparams = kwargs

        language_encoder = LanguageTransformer(
            model.transformer,
            model.token_embedding,
            model.positional_embedding,
            model.ln_final,
            model.text_projection,
            model.attn_mask,
        )
        scale = model.logit_scale.exp().item()
        classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"])

        self.index_name = index_name
        self.vocabulary_transforms = vocabulary_transforms
        self.vision_encoder = model.visual
        self.language_encoder = language_encoder
        self.classifier = classifier

        # download databases
        self.prepare_data()

        # load faiss indices
        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
        indices_fp = indices_list_dir / "indices.json"
        with open(indices_fp) as fp:
            self.indices = json.load(fp)

        # load faiss indices and metadata providers
        self.resources = {}
        for name, index_fp in self.indices.items():
            text_index_fp = Path(index_fp) / "text.index"
            metadata_fp = Path(index_fp) / "metadata/"

            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = ArrowMetadataProvider(metadata_fp)

            self.resources[name] = {
                "device": DEVICE,
                "model": model_name,
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download data if needed."""
        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"

        for name, url in RETRIEVAL_DATABASES.items():
            database_path = Path(databases_path, name)
            if database_path.exists():
                continue

            # download data
            target_path = Path(databases_path, name + ".tar.gz")
            try:
                gdown.download(url, str(target_path), quiet=False)
                with tarfile.open(target_path, "r:gz") as tar:
                    tar.extractall(target_path.parent)
                target_path.unlink()
            except FileNotFoundError:
                print(f"Could not download {url}.")
                print(f"Please download it manually and place it in {target_path.parent}.")

    @torch.no_grad()
    def query_index(self, sample_z: torch.Tensor) -> list:
        """Retrieve the captions of the nearest neighbours of the sample embedding."""
        # get the index
        resources = self.resources[self.index_name]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # query the index
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query = sample_z.detach().cpu().numpy().astype("float32")[None, :]

        distances, idxs, _ = text_index.search_and_reconstruct(
            query, self.hparams["retrieval_num_results"]
        )
        results = idxs[0]
        nb_results = np.where(results == -1)[0]
        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
        indices = results[:nb_results]
        distances = distances[0][:nb_results]

        if len(distances) == 0:
            return []

        # get the metadata
        results = []
        metadata = metadata_provider.get(indices[:20], ["caption"])
        for key, (d, i) in enumerate(zip(distances, indices)):
            output = {}
            meta = None if key + 1 > len(metadata) else metadata[key]
            if meta is not None:
                output.update(meta_to_dict(meta))
            output["id"] = i.item()
            output["similarity"] = d.item()
            results.append(output)

        # get the captions only
        vocabularies = [result["caption"] for result in results]

        return vocabularies

    @torch.no_grad()
    def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor:
        """Encode a vocabulary.

        Args:
            vocabulary (list): List of words.
            use_prompts (bool): Whether to format the words with the vocabulary prompts.
        """
        # check if vocabulary has changed
        if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts:
            return self._prev_vocab_words_z

        # tokenize vocabulary
        classes = [c.replace("_", " ") for c in vocabulary]
        prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"]
        texts_views = [[p.format(c) for c in classes] for p in prompts]
        tokenized_texts_views = [
            torch.cat([self.tokenizer(prompt) for prompt in class_prompts])
            for class_prompts in texts_views
        ]
        tokenized_texts_views = torch.stack(tokenized_texts_views).to(DEVICE)

        # encode vocabulary
        T, C, _ = tokenized_texts_views.shape
        texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1))
        texts_z_views = texts_z_views.view(T, C, -1)
        texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True)

        # cache vocabulary
        self._prev_vocab_words = vocabulary
        self._prev_used_prompts = use_prompts
        self._prev_vocab_words_z = texts_z_views

        return texts_z_views

    @torch.no_grad()
    def forward(self, image_fp: str, alpha: Optional[float] = None) -> Tuple[list, list]:
        """Classify the image at the given path, returning the vocabulary and its scores."""
        image = self.preprocess(Image.open(image_fp)).unsqueeze(0)
        image_z = self.vision_encoder(image.to(DEVICE))

        # get the vocabulary
        vocabulary = self.query_index(image_z)

        # generate a single text embedding from the unfiltered vocabulary
        unfiltered_vocabulary_z = self.encode_vocabulary(vocabulary).squeeze(0)
        text_z = unfiltered_vocabulary_z.mean(dim=0)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        text_z = text_z.unsqueeze(0)

        # filter the vocabulary, embed it, and get its mean embedding
        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
        vocabulary_z = self.encode_vocabulary(vocabulary, use_prompts=True)
        mean_vocabulary_z = vocabulary_z.mean(dim=0)
        mean_vocabulary_z = mean_vocabulary_z / mean_vocabulary_z.norm(dim=-1, keepdim=True)

        # get the image and text predictions
        image_p = self.classifier(image_z, vocabulary_z)
        text_p = self.classifier(text_z, vocabulary_z)

        # average the image and text predictions
        alpha = self.hparams["alpha"] if alpha is None else alpha
        sample_p = alpha * image_p + (1 - alpha) * text_p

        # get the scores
        sample_p = sample_p.cpu()
        scores = sample_p[0].tolist()

        del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z, mean_vocabulary_z
        del image_p, text_p, sample_p

        return vocabulary, scores


class NearestNeighboursClassifier(torch.nn.Module):
    """Nearest neighbours classifier.

    It computes the similarity between the query and the supports using the
    cosine similarity and then applies a softmax to obtain the logits.

    Args:
        scale (float): Scale for the logits of the query. Defaults to 1.0.
        tau (float): Temperature for the softmax. Defaults to 1.0.
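
    Shapes:
        query: (Q, D) tensor of query embeddings.
        supports: (T, C, D) tensor of T prompt views over C classes; a (C, D) tensor
            is also accepted and treated as a single view.
        output: (Q, C) probabilities over the C classes.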
    """

    def __init__(self, scale: float = 1.0, tau: float = 1.0):
        super().__init__()
        self.scale = scale
        self.tau = tau

    def forward(self, query: torch.Tensor, supports: torch.Tensor):
        query = query / query.norm(dim=-1, keepdim=True)
        supports = supports / supports.norm(dim=-1, keepdim=True)

        if supports.dim() == 2:
            supports = supports.unsqueeze(0)

        Q, _ = query.shape
        N, C, _ = supports.shape

        supports = supports.mean(dim=0)
        supports = supports / supports.norm(dim=-1, keepdim=True)
        similarity = self.scale * query @ supports.T
        similarity = similarity / self.tau if self.tau != 1.0 else similarity
        logits = similarity.softmax(dim=-1)

        return logits


class LanguageTransformer(torch.nn.Module):
    """Language Transformer for CLIP.

    Args:
        transformer (Transformer): Transformer model.
        token_embedding (torch.nn.Embedding): Token embedding.
        positional_embedding (torch.nn.Parameter): Positional embedding.
        ln_final (torch.nn.LayerNorm): Layer norm.
        text_projection (torch.nn.Parameter): Text projection.
        attn_mask (torch.Tensor): Attention mask for the transformer.
    """

    def __init__(
        self,
        transformer: Transformer,
        token_embedding: torch.nn.Embedding,
        positional_embedding: torch.nn.Parameter,
        ln_final: torch.nn.LayerNorm,
        text_projection: torch.nn.Parameter,
        attn_mask: torch.Tensor,
    ):
        super().__init__()
        self.transformer = transformer
        self.token_embedding = token_embedding
        self.positional_embedding = positional_embedding
        self.ln_final = ln_final
        self.text_projection = text_projection

        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Forward pass for the text encoder."""
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x
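

if __name__ == "__main__":
    # Minimal usage sketch: it assumes the CC12M retrieval database can be downloaded
    # into "artifacts/" (see prepare_data) and that an image exists at the hypothetical
    # path "assets/example.jpg"; the hparam values below are illustrative.
    cased = CaSED(alpha=0.5, retrieval_num_results=10).to(DEVICE).eval()
    vocabulary, scores = cased("assets/example.jpg")

    # print the highest-scoring candidate category
    best = max(range(len(scores)), key=scores.__getitem__)
    print(f"predicted category: {vocabulary[best]} (score={scores[best]:.3f})")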