import json import tarfile from pathlib import Path from typing import Optional import faiss import gdown import numpy as np import open_clip import torch from open_clip.transformer import Transformer from PIL import Image from src.retrieval import ArrowMetadataProvider, meta_to_dict from src.transforms import TextCompose, default_vocabulary_transforms DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") RETRIEVAL_DATABASES = { "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t", } class CaSED(torch.nn.Module): """Torch module for Category Search from External Databases (CaSED). Args: index_name (str): Name of the faiss index to use. vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary. model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14". pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai". Extra hparams: alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5. artifact_dir (str): Path to the directory where the databases are stored. Defaults to "artifacts/". retrieval_num_results (int): Number of results to return. Defaults to 10. vocabulary_prompt (str): Prompt to use for the vocabulary. Defaults to "{}". tau (float): Temperature to use for the classifier. Defaults to 1.0. """ def __init__( self, index_name: str = "ViT-L-14_CC12M", vocabulary_transforms: TextCompose = default_vocabulary_transforms(), model_name: str = "ViT-L-14", pretrained: str = "openai", vocabulary_prompt: str = "{}", **kwargs, ): super().__init__() self._prev_vocab_words = None self._prev_used_prompts = None self._prev_vocab_words_z = None model, _, preprocess = open_clip.create_model_and_transforms( model_name, pretrained=pretrained, device="cpu" ) tokenizer = open_clip.get_tokenizer(model_name) self.tokenizer = tokenizer self.preprocess = preprocess kwargs["alpha"] = kwargs.get("alpha", 0.5) kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/") kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10) vocabulary_prompt = kwargs.get("vocabulary_prompt", "{}") kwargs["vocabulary_prompts"] = [vocabulary_prompt] kwargs["tau"] = kwargs.get("tau", 1.0) self.hparams = kwargs language_encoder = LanguageTransformer( model.transformer, model.token_embedding, model.positional_embedding, model.ln_final, model.text_projection, model.attn_mask, ) scale = model.logit_scale.exp().item() classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"]) self.index_name = index_name self.vocabulary_transforms = vocabulary_transforms self.vision_encoder = model.visual self.language_encoder = language_encoder self.classifier = classifier # download databases self.prepare_data() # load faiss indices indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval" indices_fp = indices_list_dir / "indices.json" self.indices = json.load(open(indices_fp, "r")) # load faiss indices and metadata providers self.resources = {} for name, index_fp in self.indices.items(): text_index_fp = Path(index_fp) / "text.index" metadata_fp = Path(index_fp) / "metadata/" text_index = faiss.read_index( str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY ) metadata_provider = ArrowMetadataProvider(metadata_fp) self.resources[name] = { "device": DEVICE, "model": model_name, "text_index": text_index, "metadata_provider": metadata_provider, } def prepare_data(self): """Download data if needed.""" databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases" for name, url in RETRIEVAL_DATABASES.items(): database_path = Path(databases_path, name) if database_path.exists(): continue # download data target_path = Path(databases_path, name + ".tar.gz") try: gdown.download(url, str(target_path), quiet=False) tar = tarfile.open(target_path, "r:gz") tar.extractall(target_path.parent) tar.close() target_path.unlink() except FileNotFoundError: print(f"Could not download {url}.") print(f"Please download it manually and place it in {target_path.parent}.") @torch.no_grad() def query_index(self, sample_z: torch.Tensor) -> torch.Tensor: # get the index resources = self.resources[self.index_name] text_index = resources["text_index"] metadata_provider = resources["metadata_provider"] # query the index sample_z = sample_z.squeeze(0) sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True) query_input = sample_z.cpu().detach().numpy().tolist() query = np.expand_dims(np.array(query_input).astype("float32"), 0) distances, idxs, _ = text_index.search_and_reconstruct( query, self.hparams["retrieval_num_results"] ) results = idxs[0] nb_results = np.where(results == -1)[0] nb_results = nb_results[0] if len(nb_results) > 0 else len(results) indices = results[:nb_results] distances = distances[0][:nb_results] if len(distances) == 0: return [] # get the metadata results = [] metadata = metadata_provider.get(indices[:20], ["caption"]) for key, (d, i) in enumerate(zip(distances, indices)): output = {} meta = None if key + 1 > len(metadata) else metadata[key] if meta is not None: output.update(meta_to_dict(meta)) output["id"] = i.item() output["similarity"] = d.item() results.append(output) # get the captions only vocabularies = [result["caption"] for result in results] return vocabularies @torch.no_grad() def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor: """Encode a vocabulary. Args: vocabulary (list): List of words. """ # check if vocabulary has changed if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts: return self._prev_vocab_words_z # tokenize vocabulary classes = [c.replace("_", " ") for c in vocabulary] prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"] texts_views = [[p.format(c) for c in classes] for p in prompts] tokenized_texts_views = [ torch.cat([self.tokenizer(prompt) for prompt in class_prompts]) for class_prompts in texts_views ] tokenized_texts_views = torch.stack(tokenized_texts_views).to(DEVICE) # encode vocabulary T, C, _ = tokenized_texts_views.shape texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1)) texts_z_views = texts_z_views.view(T, C, -1) texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True) # cache vocabulary self._prev_vocab_words = vocabulary self._prev_used_prompts = use_prompts self._prev_vocab_words_z = texts_z_views return texts_z_views @torch.no_grad() def forward(self, image_fp: str, alpha: Optional[float] = None) -> torch.Tensor(): image = self.preprocess(Image.open(image_fp)).unsqueeze(0) image_z = self.vision_encoder(image.to(DEVICE)) # get the vocabulary vocabulary = self.query_index(image_z) # generate a single text embedding from the unfiltered vocabulary unfiltered_vocabulary_z = self.encode_vocabulary(vocabulary).squeeze(0) text_z = unfiltered_vocabulary_z.mean(dim=0) text_z = text_z / text_z.norm(dim=-1, keepdim=True) text_z = text_z.unsqueeze(0) # filter the vocabulary, embed it, and get its mean embedding vocabulary = self.vocabulary_transforms(vocabulary) or ["object"] vocabulary_z = self.encode_vocabulary(vocabulary, use_prompts=True) mean_vocabulary_z = vocabulary_z.mean(dim=0) mean_vocabulary_z = mean_vocabulary_z / mean_vocabulary_z.norm(dim=-1, keepdim=True) # get the image and text predictions image_p = self.classifier(image_z, vocabulary_z) text_p = self.classifier(text_z, vocabulary_z) # average the image and text predictions alpha = alpha or self.hparams["alpha"] sample_p = alpha * image_p + (1 - alpha) * text_p # get the scores sample_p = sample_p.cpu() scores = sample_p[0].tolist() del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z, mean_vocabulary_z del image_p, text_p, sample_p return vocabulary, scores class NearestNeighboursClassifier(torch.nn.Module): """Nearest neighbours classifier. It computes the similarity between the query and the supports using the cosine similarity and then applies a softmax to obtain the logits. Args: scale (float): Scale for the logits of the query. Defaults to 1.0. tau (float): Temperature for the softmax. Defaults to 1.0. """ def __init__(self, scale: float = 1.0, tau: float = 1.0): super().__init__() self.scale = scale self.tau = tau def forward(self, query: torch.Tensor, supports: torch.Tensor): query = query / query.norm(dim=-1, keepdim=True) supports = supports / supports.norm(dim=-1, keepdim=True) if supports.dim() == 2: supports = supports.unsqueeze(0) Q, _ = query.shape N, C, _ = supports.shape supports = supports.mean(dim=0) supports = supports / supports.norm(dim=-1, keepdim=True) similarity = self.scale * query @ supports.T similarity = similarity / self.tau if self.tau != 1.0 else similarity logits = similarity.softmax(dim=-1) return logits class LanguageTransformer(torch.nn.Module): """Language Transformer for CLIP. Args: transformer (Transformer): Transformer model. token_embedding (torch.nn.Embedding): Token embedding. positional_embedding (torch.nn.Parameter): Positional embedding. ln_final (torch.nn.LayerNorm): Layer norm. text_projection (torch.nn.Parameter): Text projection. """ def __init__( self, model: Transformer, token_embedding: torch.nn.Embedding, positional_embedding: torch.nn.Parameter, ln_final: torch.nn.LayerNorm, text_projection: torch.nn.Parameter, attn_mask: torch.Tensor, ): super().__init__() self.transformer = model self.token_embedding = token_embedding self.positional_embedding = positional_embedding self.ln_final = ln_final self.text_projection = text_projection self.register_buffer("attn_mask", attn_mask, persistent=False) def forward(self, text: torch.Tensor) -> torch.Tensor: cast_dtype = self.transformer.get_cast_dtype() """Forward pass for the text encoder.""" x = self.token_embedding(text).to(cast_dtype) x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) x = self.transformer(x, attn_mask=self.attn_mask) x = x.permute(1, 0, 2) x = self.ln_final(x) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x