# Source: vic/src/nn.py
# Commit a3ee979 by altndrr: "Add first version"
import json
import tarfile
from pathlib import Path
from typing import Optional
import faiss
import gdown
import numpy as np
import open_clip
import torch
from open_clip.transformer import Transformer
from PIL import Image
from src.retrieval import ArrowMetadataProvider, meta_to_dict
from src.transforms import TextCompose, default_vocabulary_transforms
# Device used for inference: CUDA when available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Retrieval databases downloaded by `CaSED.prepare_data`, keyed by database name.
# Each value is the download URL of a `.tar.gz` archive.
RETRIEVAL_DATABASES = dict(
    cc12m="https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
)
class CaSED(torch.nn.Module):
    """Torch module for Category Search from External Databases (CaSED).

    Given an image, CaSED retrieves captions from an external text database (a faiss
    index), turns them into a candidate vocabulary, and scores every candidate by mixing
    an image-based and a text-based nearest-neighbour prediction.

    Args:
        index_name (str): Name of the faiss index to use.
        vocabulary_transforms (TextCompose, optional): Transforms applied to the retrieved
            captions to obtain the candidate vocabulary. Defaults to a fresh
            ``default_vocabulary_transforms()`` instance per model.
        model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14".
        pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai".
        vocabulary_prompt (str): Prompt used to embed the filtered vocabulary, with ``{}``
            as the placeholder for each word. Defaults to "{}".

    Extra hparams:
        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
            "artifacts/".
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        tau (float): Temperature to use for the classifier. Defaults to 1.0.
    """

    def __init__(
        self,
        index_name: str = "ViT-L-14_CC12M",
        vocabulary_transforms: Optional[TextCompose] = None,
        model_name: str = "ViT-L-14",
        pretrained: str = "openai",
        vocabulary_prompt: str = "{}",
        **kwargs,
    ):
        super().__init__()
        # cache of the last vocabulary encoded by `encode_vocabulary`
        self._prev_vocab_words = None
        self._prev_used_prompts = None
        self._prev_vocab_words_z = None

        # NOTE(review): the CLIP model is built on CPU and nothing here moves it to
        # DEVICE; presumably the caller does (e.g. `.to(DEVICE)`) — confirm.
        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, device="cpu"
        )
        self.tokenizer = open_clip.get_tokenizer(model_name)
        self.preprocess = preprocess

        # fill in defaults for the extra hyperparameters
        kwargs["alpha"] = kwargs.get("alpha", 0.5)
        kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
        kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
        # `vocabulary_prompt` is a named parameter, so it can never appear in `kwargs`;
        # use the argument itself instead of silently falling back to "{}".
        kwargs["vocabulary_prompts"] = [vocabulary_prompt]
        kwargs["tau"] = kwargs.get("tau", 1.0)
        self.hparams = kwargs

        # build the default here rather than in the signature, so each instance gets its own
        if vocabulary_transforms is None:
            vocabulary_transforms = default_vocabulary_transforms()

        # wrap CLIP's text tower so it can be called directly on token ids
        language_encoder = LanguageTransformer(
            model.transformer,
            model.token_embedding,
            model.positional_embedding,
            model.ln_final,
            model.text_projection,
            model.attn_mask,
        )
        # the stored logit scale is exponentiated to get the multiplicative scale
        scale = model.logit_scale.exp().item()
        classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"])

        self.index_name = index_name
        self.vocabulary_transforms = vocabulary_transforms
        self.vision_encoder = model.visual
        self.language_encoder = language_encoder
        self.classifier = classifier

        # download databases
        self.prepare_data()

        # load the mapping from index name to index directory
        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
        indices_fp = indices_list_dir / "indices.json"
        with open(indices_fp, "r", encoding="utf-8") as indices_file:
            self.indices = json.load(indices_file)

        # load faiss indices (memory-mapped, read-only) and metadata providers
        self.resources = {}
        for name, index_fp in self.indices.items():
            text_index_fp = Path(index_fp) / "text.index"
            metadata_fp = Path(index_fp) / "metadata/"
            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = ArrowMetadataProvider(metadata_fp)
            self.resources[name] = {
                "device": DEVICE,
                "model": model_name,
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download and extract every retrieval database that is not present yet.

        A database counts as present when ``<artifact_dir>/models/databases/<name>``
        exists. Otherwise its ``.tar.gz`` archive is downloaded next to that path,
        extracted there, and then deleted.
        """
        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
        for name, url in RETRIEVAL_DATABASES.items():
            if (databases_path / name).exists():
                continue

            # make sure the download target directory exists
            databases_path.mkdir(parents=True, exist_ok=True)
            target_path = databases_path / f"{name}.tar.gz"
            try:
                gdown.download(url, str(target_path), quiet=False)
                # NOTE(review): extractall trusts member paths inside the archive; this is
                # acceptable only because the archive comes from the fixed URL above.
                with tarfile.open(target_path, "r:gz") as tar:
                    tar.extractall(target_path.parent)
                target_path.unlink()
            except FileNotFoundError:
                print(f"Could not download {url}.")
                print(f"Please download it manually and place it in {target_path.parent}.")

    @torch.no_grad()
    def query_index(self, sample_z: torch.Tensor) -> list:
        """Retrieve the captions closest to an embedding from the text index.

        Args:
            sample_z (torch.Tensor): Query embedding. A leading singleton dimension is
                squeezed away, and the embedding is L2-normalised before the search.

        Returns:
            list: Captions of the retrieved entries, in faiss result order. Entries without
            a caption in the metadata are skipped. The list is empty if nothing was retrieved.
        """
        resources = self.resources[self.index_name]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # normalise and convert to a (1, dim) float32 array, as faiss expects
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query = np.expand_dims(sample_z.detach().float().cpu().numpy(), 0)
        _, idxs = text_index.search(query, self.hparams["retrieval_num_results"])

        # faiss pads missing neighbours with -1: keep only the results before the first -1
        results = idxs[0]
        missing = np.flatnonzero(results == -1)
        indices = results[: missing[0]] if len(missing) > 0 else results
        if len(indices) == 0:
            return []

        # fetch metadata for every retrieved id (a previous hard cap of 20 made any
        # `retrieval_num_results` > 20 fail with a KeyError on "caption")
        metadata = metadata_provider.get(indices, ["caption"])
        captions = []
        for key in range(len(indices)):
            # the provider may return fewer (or empty) entries than requested
            if key >= len(metadata) or metadata[key] is None:
                continue
            meta = meta_to_dict(metadata[key])
            if "caption" in meta:
                captions.append(meta["caption"])
        return captions

    @torch.no_grad()
    def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor:
        """Encode a vocabulary with the CLIP text encoder.

        The result of the most recent call is cached. If the same vocabulary and
        ``use_prompts`` flag are requested again, the cached tensor is returned.

        Args:
            vocabulary (list): List of words. Underscores are replaced by spaces.
            use_prompts (bool): If True, every word is formatted into each prompt of
                ``hparams["vocabulary_prompts"]``. Otherwise the bare word is embedded.

        Returns:
            torch.Tensor: L2-normalised embeddings of shape
            (num_prompts, len(vocabulary), dim).
        """
        # check if vocabulary has changed
        if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts:
            return self._prev_vocab_words_z

        # tokenize vocabulary: one row of texts per prompt
        classes = [c.replace("_", " ") for c in vocabulary]
        prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"]
        texts_views = [[p.format(c) for c in classes] for p in prompts]
        tokenized_texts_views = torch.stack(
            [torch.cat([self.tokenizer(text) for text in texts]) for texts in texts_views]
        ).to(DEVICE)

        # encode all (prompt, word) pairs in a single batch, then restore the grid
        T, C, _ = tokenized_texts_views.shape
        texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1))
        texts_z_views = texts_z_views.view(T, C, -1)
        texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True)

        # cache a copy, so later in-place edits of the caller's list cannot fake a cache hit
        self._prev_vocab_words = list(vocabulary)
        self._prev_used_prompts = use_prompts
        self._prev_vocab_words_z = texts_z_views
        return texts_z_views

    @torch.no_grad()
    def forward(self, image_fp: str, alpha: Optional[float] = None) -> tuple:
        """Predict category names for an image.

        Args:
            image_fp (str): Path of the image to classify.
            alpha (float, optional): Weight of the image-based prediction. The text-based
                prediction gets ``1 - alpha``. Defaults to ``hparams["alpha"]``.

        Returns:
            tuple: ``(vocabulary, scores)``, where ``vocabulary`` is the list of candidate
            names and ``scores`` holds one score per name.
        """
        # close the image file once it has been preprocessed
        with Image.open(image_fp) as img:
            image = self.preprocess(img).unsqueeze(0)
        image_z = self.vision_encoder(image.to(DEVICE))

        # get the vocabulary
        vocabulary = self.query_index(image_z)

        # generate a single text embedding from the unfiltered vocabulary
        unfiltered_vocabulary_z = self.encode_vocabulary(vocabulary).squeeze(0)
        text_z = unfiltered_vocabulary_z.mean(dim=0)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        text_z = text_z.unsqueeze(0)

        # filter the vocabulary (falling back to "object" if nothing survives) and embed it
        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
        vocabulary_z = self.encode_vocabulary(vocabulary, use_prompts=True)

        # get the image and text predictions
        image_p = self.classifier(image_z, vocabulary_z)
        text_p = self.classifier(text_z, vocabulary_z)

        # average the predictions; compare against None so an explicit alpha=0.0 is honoured
        if alpha is None:
            alpha = self.hparams["alpha"]
        sample_p = alpha * image_p + (1 - alpha) * text_p

        # get the scores
        scores = sample_p.cpu()[0].tolist()
        # drop references to the intermediate tensors before returning
        del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z
        del image_p, text_p, sample_p
        return vocabulary, scores
class NearestNeighboursClassifier(torch.nn.Module):
    """Cosine-similarity nearest-neighbours classifier.

    Every query is compared with a set of class embeddings (the supports) by cosine
    similarity. The similarities are then turned into a probability distribution over
    the classes with a temperature-scaled softmax.

    Args:
        scale (float): Multiplier applied to the cosine similarities. Defaults to 1.0.
        tau (float): Softmax temperature. Similarities are divided by it. Defaults to 1.0.
    """

    def __init__(self, scale: float = 1.0, tau: float = 1.0):
        super().__init__()
        self.scale = scale
        self.tau = tau

    def forward(self, query: torch.Tensor, supports: torch.Tensor):
        """Return class probabilities for each query.

        Args:
            query (torch.Tensor): Query embeddings of shape (Q, dim).
            supports (torch.Tensor): Class embeddings of shape (C, dim), or (N, C, dim)
                with N views per class. The views are averaged into one prototype per class.

        Returns:
            torch.Tensor: Softmax probabilities of shape (Q, C).
        """

        def _unit(tensor):
            # scale each vector along the last dimension to unit L2 norm
            return tensor / tensor.norm(dim=-1, keepdim=True)

        query = _unit(query)
        views = _unit(supports)
        if views.dim() == 2:
            views = views[None]

        # the unpacking doubles as a rank check: query must be 2-D and views 3-D
        _num_queries, _ = query.shape
        _num_views, _num_classes, _ = views.shape

        prototypes = _unit(views.mean(dim=0))
        similarity = self.scale * query @ prototypes.T
        if self.tau != 1.0:
            similarity = similarity / self.tau
        return similarity.softmax(dim=-1)
class LanguageTransformer(torch.nn.Module):
    """Language Transformer for CLIP.

    Re-assembles the text tower of a CLIP model from its parts, so that it can be called
    directly on token ids.

    Args:
        model (Transformer): Transformer model.
        token_embedding (torch.nn.Embedding): Token embedding.
        positional_embedding (torch.nn.Parameter): Positional embedding.
        ln_final (torch.nn.LayerNorm): Layer norm.
        text_projection (torch.nn.Parameter): Text projection.
        attn_mask (torch.Tensor): Attention mask passed to the transformer. It is
            registered as a non-persistent buffer, so it is not saved in the state dict.
    """

    def __init__(
        self,
        model: Transformer,
        token_embedding: torch.nn.Embedding,
        positional_embedding: torch.nn.Parameter,
        ln_final: torch.nn.LayerNorm,
        text_projection: torch.nn.Parameter,
        attn_mask: torch.Tensor,
    ):
        super().__init__()
        self.transformer = model
        self.token_embedding = token_embedding
        self.positional_embedding = positional_embedding
        self.ln_final = ln_final
        self.text_projection = text_projection
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Forward pass for the text encoder.

        Args:
            text (torch.Tensor): Token ids of shape (batch_size, n_ctx).

        Returns:
            torch.Tensor: Projected text features, one row per input sequence.
        """
        cast_dtype = self.transformer.get_cast_dtype()
        x = self.token_embedding(text).to(cast_dtype)
        x = x + self.positional_embedding.to(cast_dtype)
        # NOTE(review): older open_clip `Transformer`s have no `batch_first` attribute and
        # expect sequence-first (LND) input. Newer ones expose `batch_first`, take
        # batch-first (NLD) input, and transpose internally. Confirm against the pinned
        # open_clip version.
        seq_first = not hasattr(self.transformer, "batch_first")
        if seq_first:
            x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        if seq_first:
            x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x