import json
import tarfile
from pathlib import Path
from typing import Optional

import faiss
import gdown
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from src.retrieval import ArrowMetadataProvider
from src.transforms import TextCompose, default_vocabulary_transforms

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

RETRIEVAL_DATABASES = {
    "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
}


class CaSED(torch.nn.Module):
    """Torch module for Category Search from External Databases (CaSED).

    Args:
        index_name (str): Name of the faiss index to use.
        vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.

    Extra hparams:
        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
            "artifacts/".
        retrieval_num_results (int): Number of results to return. Defaults to 10.
    """

    def __init__(
        self,
        index_name: str = "ViT-L-14_CC12M",
        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
        **kwargs,
    ):
        super().__init__()

        # load CLIP
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
        self.index_name = index_name
        self.vocabulary_transforms = vocabulary_transforms
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # set hparams
        kwargs["alpha"] = kwargs.get("alpha", 0.5)
        kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
        kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
        self.hparams = kwargs

        # download databases
        self.prepare_data()

        # load faiss indices and metadata providers
        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
        indices_fp = indices_list_dir / "indices.json"
        with open(indices_fp) as f:
            self.indices = json.load(f)
        self.resources = {}
        for name, index_fp in self.indices.items():
            text_index_fp = Path(index_fp) / "text.index"
            metadata_fp = Path(index_fp) / "metadata/"

            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = ArrowMetadataProvider(metadata_fp)

            self.resources[name] = {
                "device": DEVICE,
                "model": "ViT-L-14",
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download data if needed."""
        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
        databases_path.mkdir(parents=True, exist_ok=True)

        for name, url in RETRIEVAL_DATABASES.items():
            database_path = Path(databases_path, name)
            if database_path.exists():
                continue

            # download data
            target_path = Path(databases_path, name + ".tar.gz")
            try:
                gdown.download(url, str(target_path), quiet=False)
                tar = tarfile.open(target_path, "r:gz")
                tar.extractall(target_path.parent)
                tar.close()
                target_path.unlink()
            except FileNotFoundError:
                print(f"Could not download {url}.")
                print(f"Please download it manually and place it in {target_path.parent}.")

    @torch.no_grad()
    def query_index(self, sample_z: torch.Tensor) -> list[str]:
        """Query the external database and return the captions of the nearest entries."""
        # get the index
        resources = self.resources[self.index_name]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # query the index
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query_input = sample_z.cpu().detach().numpy().tolist()
        query = np.expand_dims(np.array(query_input).astype("float32"), 0)

        distances, idxs, _ = text_index.search_and_reconstruct(
            query, self.hparams["retrieval_num_results"]
        )
        results = idxs[0]
        nb_results = np.where(results == -1)[0]
        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
        indices = results[:nb_results]
        distances = distances[0][:nb_results]

        if len(distances) == 0:
            return []

        # get the metadata
        results = []
        metadata = metadata_provider.get(indices[:20], ["caption"])
        for key, (d, i) in enumerate(zip(distances, indices)):
            output = {}
            meta = None if key + 1 > len(metadata) else metadata[key]
            if meta is not None:
                output.update(meta)
            output["id"] = i.item()
            output["similarity"] = d.item()
            results.append(output)

        # get the captions only
        vocabularies = [result["caption"] for result in results]
        return vocabularies

    @torch.no_grad()
    def forward(
        self, image_fp: str, alpha: Optional[float] = None
    ) -> tuple[list[str], list[float]]:
        """Predict a vocabulary and per-category scores for the input image."""
        # forward the image
        image = self.processor(images=Image.open(image_fp), return_tensors="pt")
        image["pixel_values"] = image["pixel_values"].to(DEVICE)
        image_z = self.vision_proj(self.vision_encoder(**image)[1])

        # generate a single text embedding from the unfiltered vocabulary
        vocabulary = self.query_index(image_z)
        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
        text["input_ids"] = text["input_ids"][:, :77].to(DEVICE)
        text["attention_mask"] = text["attention_mask"][:, :77].to(DEVICE)
        text_z = self.language_encoder(**text)[1]
        text_z = self.language_proj(text_z)
        # average the caption embeddings into a single text embedding
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        text_z = text_z.mean(dim=0, keepdim=True)

        # filter the vocabulary, embed it, and get its mean embedding
        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
        text = {k: v.to(DEVICE) for k, v in text.items()}
        vocabulary_z = self.language_encoder(**text)[1]
        vocabulary_z = self.language_proj(vocabulary_z)
        vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)

        # get the image and text predictions
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
        text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)

        # average the image and text predictions
        alpha = self.hparams["alpha"] if alpha is None else alpha
        sample_p = alpha * image_p + (1 - alpha) * text_p

        # get the scores
        scores = sample_p[0].cpu().tolist()

        return vocabulary, scores
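

# Minimal usage sketch (not part of the original module): instantiate CaSED and classify a
# single image from disk. The image path below is a placeholder; any RGB image works. The
# retrieval database and faiss index are downloaded and loaded on first initialization, so
# the first run may take a while and requires the artifact directory layout described above.
if __name__ == "__main__":
    cased = CaSED().eval()
    vocabulary, scores = cased("assets/example.jpg")  # hypothetical image path
    # print the top-5 predicted categories with their scores
    for category, score in sorted(zip(vocabulary, scores), key=lambda x: -x[1])[:5]:
        print(f"{category}: {score:.3f}")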