import os from typing import Callable, Optional import numpy as np import torch from transformers import CLIPModel, CLIPProcessor from transformers.modeling_utils import PreTrainedModel from .configuration_cased import CaSEDConfig from .retrieval_cased import RetrievalDatabase, download_retrieval_databases from .transforms_cased import default_vocabulary_transforms class CaSEDModel(PreTrainedModel): """Transformers module for Category Search from External Databases (CaSED). Reference: - Conti et al. Vocabulary-free Image Classification. NeurIPS 2023. Args: config (CaSEDConfig): Configuration class for CaSED. """ config_class = CaSEDConfig def __init__(self, config: CaSEDConfig): super().__init__(config) # load CLIP model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") self.vision_encoder = model.vision_model self.vision_proj = model.visual_projection self.language_encoder = model.text_model self.language_proj = model.text_projection self.logit_scale = model.logit_scale.exp() self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") # set hparams self.hparams = {} self.hparams["alpha"] = config.alpha self.hparams["index_name"] = config.index_name self.hparams["retrieval_num_results"] = config.retrieval_num_results self.hparams["cache_dir"] = config.cache_dir # create cache dir os.makedirs(self.hparams["cache_dir"], exist_ok=True) # download data download_retrieval_databases(cache_dir=self.hparams["cache_dir"]) # setup vocabulary self.vocabulary = RetrievalDatabase("cc12m", self.hparams["cache_dir"]) self._vocab_transform = default_vocabulary_transforms() @property def vocab_transform(self) -> Callable: """Get image preprocess transform. The getter wraps the transform in a map_reduce function and applies it to a list of images. If interested in the transform itself, use `self._vocab_transform`. """ vocab_transform = self._vocab_transform def vocabs_transforms(texts: list[str]) -> list[torch.Tensor]: return [vocab_transform(text) for text in texts] return vocabs_transforms def get_vocabulary(self, images_z: Optional[torch.Tensor] = None) -> list[list[str]]: """Get the vocabulary for a batch of images. Args: images_z (torch.Tensor): Batch of image embeddings. """ num_samples = self.hparams["retrieval_num_results"] assert images_z is not None images_z = images_z / images_z.norm(dim=-1, keepdim=True) images_z = images_z.cpu().detach().numpy().tolist() if isinstance(images_z[0], float): images_z = [images_z] query = np.matrix(images_z).astype("float32") results = self.vocabulary.query(query, modality="text", num_samples=num_samples) vocabularies = [[r["caption"] for r in result] for result in results] return vocabularies def forward(self, images: dict, alpha: Optional[float] = None) -> torch.Tensor: """Forward pass. Args: images (dict): Dictionary with the images. The expected keys are: - pixel_values (torch.Tensor): Pixel values of the images. alpha (Optional[float]): Alpha value for the interpolation. """ alpha = alpha or self.hparams["alpha"] # forward the images images["pixel_values"] = images["pixel_values"].to(self.device) images_z = self.vision_proj(self.vision_encoder(**images)[1]) images_z = images_z / images_z.norm(dim=-1, keepdim=True) vocabularies = self.get_vocabulary(images_z=images_z) # encode unfiltered words unfiltered_words = sum(vocabularies, []) texts_z = self.processor(unfiltered_words, return_tensors="pt", padding=True) texts_z["input_ids"] = texts_z["input_ids"][:, :77].to(self.device) texts_z["attention_mask"] = texts_z["attention_mask"][:, :77].to(self.device) texts_z = self.language_encoder(**texts_z)[1] texts_z = self.language_proj(texts_z) texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True) # generate a text embedding for each image from their unfiltered words unfiltered_words_per_image = [len(vocab) for vocab in vocabularies] texts_z = torch.split(texts_z, unfiltered_words_per_image) texts_z = torch.stack([text_z.mean(dim=0) for text_z in texts_z]) texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True) # filter the words and embed them vocabularies = self.vocab_transform(vocabularies) vocabularies = [vocab or ["object"] for vocab in vocabularies] words = sum(vocabularies, []) words_z = self.processor(words, return_tensors="pt", padding=True) words_z = {k: v.to(self.device) for k, v in words_z.items()} words_z = self.language_encoder(**words_z)[1] words_z = self.language_proj(words_z) words_z = words_z / words_z.norm(dim=-1, keepdim=True) # create a one-hot relation mask between images and words words_per_image = [len(vocab) for vocab in vocabularies] col_indices = torch.arange(sum(words_per_image)) row_indices = torch.arange(len(images_z)).repeat_interleave(torch.tensor(words_per_image)) mask = torch.zeros(len(images_z), sum(words_per_image), device=self.device) mask[row_indices, col_indices] = 1 # get the image and text similarities images_z = images_z / images_z.norm(dim=-1, keepdim=True) texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True) words_z = words_z / words_z.norm(dim=-1, keepdim=True) images_sim = self.logit_scale * images_z @ words_z.T texts_sim = self.logit_scale * texts_z @ words_z.T # mask unrelated words images_sim = torch.masked_fill(images_sim, mask == 0, float("-inf")) texts_sim = torch.masked_fill(texts_sim, mask == 0, float("-inf")) # get the image and text predictions images_p = images_sim.softmax(dim=-1) texts_p = texts_sim.softmax(dim=-1) # average the image and text predictions samples_p = alpha * images_p + (1 - alpha) * texts_p return {"scores": samples_p, "words": words, "vocabularies": vocabularies}