# cased/modeling_cased.py
import os
from typing import Callable, Optional
import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor
from transformers.modeling_utils import PreTrainedModel
from .configuration_cased import CaSEDConfig
from .retrieval_cased import RetrievalDatabase, download_retrieval_databases
from .transforms_cased import default_vocabulary_transforms
class CaSEDModel(PreTrainedModel):
"""Transformers module for Category Search from External Databases (CaSED).
Reference:
- Conti et al. Vocabulary-free Image Classification. NeurIPS 2023.
Args:
config (CaSEDConfig): Configuration class for CaSED.
"""
config_class = CaSEDConfig
def __init__(self, config: CaSEDConfig):
super().__init__(config)
# load CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
self.vision_encoder = model.vision_model
self.vision_proj = model.visual_projection
self.language_encoder = model.text_model
self.language_proj = model.text_projection
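        # the learned CLIP temperature is exponentiated once and stored as a fixed
        # scaling factor for the similarity logits computed in `forward`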
self.logit_scale = model.logit_scale.exp()
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
# set hparams
self.hparams = {}
self.hparams["alpha"] = config.alpha
self.hparams["index_name"] = config.index_name
self.hparams["retrieval_num_results"] = config.retrieval_num_results
self.hparams["cache_dir"] = config.cache_dir
# create cache dir
os.makedirs(self.hparams["cache_dir"], exist_ok=True)
# download data
download_retrieval_databases(cache_dir=self.hparams["cache_dir"])
# setup vocabulary
self.vocabulary = RetrievalDatabase("cc12m", self.hparams["cache_dir"])
self._vocab_transform = default_vocabulary_transforms()
@property
def vocab_transform(self) -> Callable:
"""Get image preprocess transform.
The getter wraps the transform in a map_reduce function and applies it to a list of images.
If interested in the transform itself, use `self._vocab_transform`.
"""
vocab_transform = self._vocab_transform
def vocabs_transforms(texts: list[str]) -> list[torch.Tensor]:
return [vocab_transform(text) for text in texts]
return vocabs_transforms
def get_vocabulary(self, images_z: Optional[torch.Tensor] = None) -> list[list[str]]:
"""Get the vocabulary for a batch of images.
Args:
images_z (torch.Tensor): Batch of image embeddings.
"""
num_samples = self.hparams["retrieval_num_results"]
assert images_z is not None
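        # normalize the image embeddings and build a float32 query matrix for the
        # retrieval database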
images_z = images_z / images_z.norm(dim=-1, keepdim=True)
images_z = images_z.cpu().detach().numpy().tolist()
if isinstance(images_z[0], float):
images_z = [images_z]
        query = np.asarray(images_z, dtype="float32")
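        # retrieve the most similar captions for each image to form its candidate vocabulary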
results = self.vocabulary.query(query, modality="text", num_samples=num_samples)
vocabularies = [[r["caption"] for r in result] for result in results]
return vocabularies
    def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
        """Forward pass.

        Args:
            images (dict): Dictionary with the images. The expected keys are:
                - pixel_values (torch.Tensor): Pixel values of the images.
            alpha (Optional[float]): Weight of the image-based predictions when
                interpolating them with the caption-based predictions.
        """
        alpha = alpha if alpha is not None else self.hparams["alpha"]
# forward the images
images["pixel_values"] = images["pixel_values"].to(self.device)
images_z = self.vision_proj(self.vision_encoder(**images)[1])
images_z = images_z / images_z.norm(dim=-1, keepdim=True)
vocabularies = self.get_vocabulary(images_z=images_z)
# encode unfiltered words
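        # flatten the per-image captions into a single list to encode them in one batch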
unfiltered_words = sum(vocabularies, [])
texts_z = self.processor(unfiltered_words, return_tensors="pt", padding=True)
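        # truncate to CLIP's maximum text length of 77 tokens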
texts_z["input_ids"] = texts_z["input_ids"][:, :77].to(self.device)
texts_z["attention_mask"] = texts_z["attention_mask"][:, :77].to(self.device)
texts_z = self.language_encoder(**texts_z)[1]
texts_z = self.language_proj(texts_z)
texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
# generate a text embedding for each image from their unfiltered words
unfiltered_words_per_image = [len(vocab) for vocab in vocabularies]
texts_z = torch.split(texts_z, unfiltered_words_per_image)
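        # average each image's caption embeddings into a single centroid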
texts_z = torch.stack([text_z.mean(dim=0) for text_z in texts_z])
texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
# filter the words and embed them
vocabularies = self.vocab_transform(vocabularies)
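        # fall back to the generic class "object" when filtering removes every candidate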
vocabularies = [vocab or ["object"] for vocab in vocabularies]
words = sum(vocabularies, [])
words_z = self.processor(words, return_tensors="pt", padding=True)
words_z = {k: v.to(self.device) for k, v in words_z.items()}
words_z = self.language_encoder(**words_z)[1]
words_z = self.language_proj(words_z)
words_z = words_z / words_z.norm(dim=-1, keepdim=True)
# create a one-hot relation mask between images and words
words_per_image = [len(vocab) for vocab in vocabularies]
col_indices = torch.arange(sum(words_per_image))
row_indices = torch.arange(len(images_z)).repeat_interleave(torch.tensor(words_per_image))
mask = torch.zeros(len(images_z), sum(words_per_image), device=self.device)
mask[row_indices, col_indices] = 1
# get the image and text similarities
images_z = images_z / images_z.norm(dim=-1, keepdim=True)
texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
words_z = words_z / words_z.norm(dim=-1, keepdim=True)
images_sim = self.logit_scale * images_z @ words_z.T
texts_sim = self.logit_scale * texts_z @ words_z.T
# mask unrelated words
images_sim = torch.masked_fill(images_sim, mask == 0, float("-inf"))
texts_sim = torch.masked_fill(texts_sim, mask == 0, float("-inf"))
# get the image and text predictions
images_p = images_sim.softmax(dim=-1)
texts_p = texts_sim.softmax(dim=-1)
        # interpolate between the image-based and text-based predictions with weight alpha
samples_p = alpha * images_p + (1 - alpha) * texts_p
return {"scores": samples_p, "words": words, "vocabularies": vocabularies}