import json
import tarfile
from pathlib import Path
from typing import Optional

import faiss
import gdown
import numpy as np
import open_clip
import torch
from open_clip.transformer import Transformer
from PIL import Image

from src.retrieval import ArrowMetadataProvider, meta_to_dict
from src.transforms import TextCompose, default_vocabulary_transforms

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

RETRIEVAL_DATABASES = {
    "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
}
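
# Each archive listed above is downloaded and extracted under
# "<artifact_dir>/models/databases/" by CaSED.prepare_data() below.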


class CaSED(torch.nn.Module):
    """Torch module for Category Search from External Databases (CaSED).

    Args:
        index_name (str): Name of the faiss index to use.
        vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
        model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14".
        pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai".

    Extra hparams:
        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
            "artifacts/".
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        vocabulary_prompt (str): Prompt to use for the vocabulary. Defaults to "{}".
        tau (float): Temperature to use for the classifier. Defaults to 1.0.
    """

    def __init__(
        self,
        index_name: str = "ViT-L-14_CC12M",
        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
        model_name: str = "ViT-L-14",
        pretrained: str = "openai",
        vocabulary_prompt: str = "{}",
        **kwargs,
    ):
        super().__init__()
        self._prev_vocab_words = None
        self._prev_used_prompts = None
        self._prev_vocab_words_z = None

        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, device="cpu"
        )
        tokenizer = open_clip.get_tokenizer(model_name)
        self.tokenizer = tokenizer
        self.preprocess = preprocess
kwargs["alpha"] = kwargs.get("alpha", 0.5) | |
kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/") | |
kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10) | |
vocabulary_prompt = kwargs.get("vocabulary_prompt", "{}") | |
kwargs["vocabulary_prompts"] = [vocabulary_prompt] | |
kwargs["tau"] = kwargs.get("tau", 1.0) | |
self.hparams = kwargs | |

        language_encoder = LanguageTransformer(
            model.transformer,
            model.token_embedding,
            model.positional_embedding,
            model.ln_final,
            model.text_projection,
            model.attn_mask,
        )
        scale = model.logit_scale.exp().item()
        classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"])

        self.index_name = index_name
        self.vocabulary_transforms = vocabulary_transforms
        self.vision_encoder = model.visual
        self.language_encoder = language_encoder
        self.classifier = classifier

        # download databases
        self.prepare_data()

        # load faiss indices
        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
        indices_fp = indices_list_dir / "indices.json"
        with open(indices_fp, "r") as f:
            self.indices = json.load(f)

        # load faiss indices and metadata providers
        self.resources = {}
        for name, index_fp in self.indices.items():
            text_index_fp = Path(index_fp) / "text.index"
            metadata_fp = Path(index_fp) / "metadata/"
            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = ArrowMetadataProvider(metadata_fp)
            self.resources[name] = {
                "device": DEVICE,
                "model": model_name,
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download data if needed."""
        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
        databases_path.mkdir(parents=True, exist_ok=True)
        for name, url in RETRIEVAL_DATABASES.items():
            database_path = Path(databases_path, name)
            if database_path.exists():
                continue

            # download and extract the database archive
            target_path = Path(databases_path, name + ".tar.gz")
            try:
                gdown.download(url, str(target_path), quiet=False)
                with tarfile.open(target_path, "r:gz") as tar:
                    tar.extractall(target_path.parent)
                target_path.unlink()
            except FileNotFoundError:
                print(f"Could not download {url}.")
                print(f"Please download it manually and place it in {target_path.parent}.")

    def query_index(self, sample_z: torch.Tensor) -> list:
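        """Retrieve the captions most similar to an image embedding from the text index."""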
        # get the index
        resources = self.resources[self.index_name]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # query the index
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query_input = sample_z.cpu().detach().numpy().tolist()
        query = np.expand_dims(np.array(query_input).astype("float32"), 0)
        distances, idxs, _ = text_index.search_and_reconstruct(
            query, self.hparams["retrieval_num_results"]
        )
        # faiss pads missing neighbours with -1, so keep only the valid prefix
        results = idxs[0]
        nb_results = np.where(results == -1)[0]
        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
        indices = results[:nb_results]
        distances = distances[0][:nb_results]
        if len(distances) == 0:
            return []

        # get the metadata
        results = []
        metadata = metadata_provider.get(indices[:20], ["caption"])
        for key, (d, i) in enumerate(zip(distances, indices)):
            output = {}
            meta = None if key + 1 > len(metadata) else metadata[key]
            if meta is not None:
                output.update(meta_to_dict(meta))
            output["id"] = i.item()
            output["similarity"] = d.item()
            results.append(output)

        # get the captions only
        vocabularies = [result["caption"] for result in results]
        return vocabularies

    def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor:
        """Encode a vocabulary into text embeddings.

        Args:
            vocabulary (list): List of words.
            use_prompts (bool): Whether to wrap each word in the vocabulary prompts.
        """

        # check if vocabulary has changed
        if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts:
            return self._prev_vocab_words_z

        # tokenize vocabulary
        classes = [c.replace("_", " ") for c in vocabulary]
        prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"]
        texts_views = [[p.format(c) for c in classes] for p in prompts]
        tokenized_texts_views = [
            torch.cat([self.tokenizer(prompt) for prompt in class_prompts])
            for class_prompts in texts_views
        ]
        tokenized_texts_views = torch.stack(tokenized_texts_views).to(DEVICE)

        # encode vocabulary
        T, C, _ = tokenized_texts_views.shape
        texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1))
        texts_z_views = texts_z_views.view(T, C, -1)
        texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True)

        # cache vocabulary
        self._prev_vocab_words = vocabulary
        self._prev_used_prompts = use_prompts
        self._prev_vocab_words_z = texts_z_views

        return texts_z_views

    def forward(self, image_fp: str, alpha: Optional[float] = None) -> "tuple[list, list]":
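        """Classify an image against a vocabulary retrieved from the external database.

        Returns a tuple ``(vocabulary, scores)`` with the candidate category names and
        their blended image/text probabilities.
        """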
        image = self.preprocess(Image.open(image_fp)).unsqueeze(0)
        image_z = self.vision_encoder(image.to(DEVICE))

        # get the vocabulary
        vocabulary = self.query_index(image_z)

        # generate a single text embedding from the unfiltered vocabulary
        unfiltered_vocabulary_z = self.encode_vocabulary(vocabulary).squeeze(0)
        text_z = unfiltered_vocabulary_z.mean(dim=0)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        text_z = text_z.unsqueeze(0)

        # filter the vocabulary, embed it, and get its mean embedding
        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
        vocabulary_z = self.encode_vocabulary(vocabulary, use_prompts=True)
        mean_vocabulary_z = vocabulary_z.mean(dim=0)
        mean_vocabulary_z = mean_vocabulary_z / mean_vocabulary_z.norm(dim=-1, keepdim=True)

        # get the image and text predictions
        image_p = self.classifier(image_z, vocabulary_z)
        text_p = self.classifier(text_z, vocabulary_z)

        # average the image and text predictions
        alpha = self.hparams["alpha"] if alpha is None else alpha
        sample_p = alpha * image_p + (1 - alpha) * text_p

        # get the scores
        sample_p = sample_p.cpu()
        scores = sample_p[0].tolist()

        # free memory
        del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z, mean_vocabulary_z
        del image_p, text_p, sample_p

        return vocabulary, scores


class NearestNeighboursClassifier(torch.nn.Module):
    """Nearest neighbours classifier.

    It computes the similarity between the query and the supports using the cosine
    similarity and then applies a softmax to obtain the logits.

    Args:
        scale (float): Scale for the logits of the query. Defaults to 1.0.
        tau (float): Temperature for the softmax. Defaults to 1.0.
    """

    def __init__(self, scale: float = 1.0, tau: float = 1.0):
        super().__init__()
        self.scale = scale
        self.tau = tau

    def forward(self, query: torch.Tensor, supports: torch.Tensor):
        query = query / query.norm(dim=-1, keepdim=True)
        supports = supports / supports.norm(dim=-1, keepdim=True)

        if supports.dim() == 2:
            supports = supports.unsqueeze(0)
        Q, _ = query.shape
        N, C, _ = supports.shape

        supports = supports.mean(dim=0)
        supports = supports / supports.norm(dim=-1, keepdim=True)
        similarity = self.scale * query @ supports.T
        similarity = similarity / self.tau if self.tau != 1.0 else similarity
        logits = similarity.softmax(dim=-1)

        return logits


class LanguageTransformer(torch.nn.Module):
    """Language Transformer for CLIP.

    Args:
        transformer (Transformer): Transformer model.
        token_embedding (torch.nn.Embedding): Token embedding.
        positional_embedding (torch.nn.Parameter): Positional embedding.
        ln_final (torch.nn.LayerNorm): Layer norm.
        text_projection (torch.nn.Parameter): Text projection.
        attn_mask (torch.Tensor): Attention mask for the text transformer.
    """

    def __init__(
        self,
        transformer: Transformer,
        token_embedding: torch.nn.Embedding,
        positional_embedding: torch.nn.Parameter,
        ln_final: torch.nn.LayerNorm,
        text_projection: torch.nn.Parameter,
        attn_mask: torch.Tensor,
    ):
        super().__init__()
        self.transformer = transformer
        self.token_embedding = token_embedding
        self.positional_embedding = positional_embedding
        self.ln_final = ln_final
        self.text_projection = text_projection
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Forward pass for the text encoder."""
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)
        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
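

# Minimal usage sketch (not part of the original module): assumes the retrieval artifacts
# are available under "artifacts/" and that "assets/example.jpg" is a placeholder image path.
if __name__ == "__main__":
    cased = CaSED().to(DEVICE).eval()
    with torch.no_grad():
        vocabulary, scores = cased("assets/example.jpg")
    for word, score in sorted(zip(vocabulary, scores), key=lambda pair: -pair[1]):
        print(f"{score:.3f}\t{word}")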