from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import CLIPVisionModelWithProjection

from data_utils import collate_fn
from models import Phi

# Use half precision on GPU and full precision on CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

def extract_image_features(dataset: Dataset, clip_model: CLIPVisionModelWithProjection, batch_size: int = 32,
                           num_workers: int = 10) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP vision model.

    Returns a tensor of stacked image features and the list of corresponding image names.
    """
    # Create data loader
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)

    index_features = []
    index_names = []
    print(f"Extracting image features {dataset.__class__.__name__} - {getattr(dataset, 'split', '')}")

    # Extract features
    for batch in tqdm(loader):
        # Datasets may yield either 'image'/'image_name' or 'reference_image'/'reference_name'
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(clip_model.device)
        with torch.no_grad():
            batch_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds
            index_features.append(batch_features.cpu())
            index_names.extend(names)

    index_features = torch.vstack(index_features)
    return index_features, index_names

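# Example usage (illustrative; the checkpoint name and `my_dataset` are placeholders):
#
#   clip_model = CLIPVisionModelWithProjection.from_pretrained(
#       "openai/clip-vit-large-patch14").to(device, dtype=dtype).eval()
#   index_features, index_names = extract_image_features(my_dataset, clip_model)
#   torch.save(index_features, "index_features.pt")
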
def contrastive_loss(v1: torch.Tensor, v2: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Symmetric contrastive (NT-Xent-style) loss between two batches of paired embeddings.
    Based on https://github.com/NVlabs/PALAVRA/blob/main/utils/nv.py
    """
    v1 = F.normalize(v1, dim=1)
    v2 = F.normalize(v2, dim=1)

    # Similarities of the matched pairs (the diagonal), duplicated for both directions
    numerator = torch.exp(torch.diag(torch.inner(v1, v2)) / temperature)
    numerator = torch.cat((numerator, numerator), 0)

    # All pairwise similarities within the joint batch, with self-similarities zeroed out
    joint_vector = torch.cat((v1, v2), 0)
    pairs_product = torch.exp(torch.mm(joint_vector, joint_vector.t()) / temperature)
    denominator = torch.sum(pairs_product - pairs_product * torch.eye(joint_vector.shape[0]).to(device), 0)

    loss = -torch.mean(torch.log(numerator / denominator))
    return loss

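# Example usage (illustrative): with a batch of B matched pairs, (v1[i], v2[i]) is the
# positive pair and the other 2B - 2 rows of the joint batch act as negatives. The
# temperature below is a common CLIP-style value, not one fixed by this module:
#
#   v1 = torch.randn(8, 768, device=device)  # e.g. phi-predicted features
#   v2 = torch.randn(8, 768, device=device)  # e.g. target text features
#   loss = contrastive_loss(v1, v2, temperature=0.07)
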
def extract_pseudo_tokens_with_phi(clip_model: CLIPVisionModelWithProjection, phi: Phi, dataset: Dataset,
                                   args) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts pseudo tokens from a dataset using a CLIP vision model and a phi model.

    Returns a tensor of stacked pseudo tokens and the list of corresponding image names.
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    predicted_tokens = []
    names_list = []
    print("Extracting pseudo tokens with the phi model")

    for batch in tqdm(data_loader):
        # Datasets may yield either 'image'/'image_name' or 'reference_image'/'reference_name'
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        with torch.no_grad():
            image_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds
            if args.l2_normalize:
                image_features = F.normalize(image_features, dim=-1)
            batch_predicted_tokens = phi(image_features)

        predicted_tokens.append(batch_predicted_tokens.cpu())
        names_list.extend(names)

    predicted_tokens = torch.vstack(predicted_tokens)
    return predicted_tokens, names_list

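# Example usage (illustrative; `my_dataset` is a placeholder and the Namespace stands
# in for the script's parsed CLI arguments):
#
#   from argparse import Namespace
#
#   phi = phi.to(device).eval()
#   pseudo_tokens, names = extract_pseudo_tokens_with_phi(
#       clip_model, phi, my_dataset, Namespace(l2_normalize=False))
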
def extract_image_features_with_names(clip_model: CLIPVisionModelWithProjection,
                                      dataset: Dataset) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP vision model.

    Mirrors `extract_pseudo_tokens_with_phi`, but returns the raw image features instead
    of phi-predicted pseudo tokens.
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    image_features_list = []
    names_list = []
    print("Extracting image features")

    for batch in tqdm(data_loader):
        # Datasets may yield either 'image'/'image_name' or 'reference_image'/'reference_name'
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        with torch.no_grad():
            image_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds

        image_features_list.append(image_features.cpu())
        names_list.extend(names)

    image_features = torch.vstack(image_features_list)
    return image_features, names_list

class CustomTensorDataset(Dataset):
    """
    Custom tensor dataset which yields image features and image names.
    """

    def __init__(self, images: torch.Tensor, names: List[str]):
        self.images = images
        self.names = names

    def __getitem__(self, index) -> dict:
        return {'image': self.images[index],
                'image_name': self.names[index]}

    def __len__(self):
        return len(self.images)

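# Example usage (illustrative): re-serve cached features through a DataLoader so phi
# can be trained or evaluated without re-encoding the images:
#
#   features, names = extract_image_features(my_dataset, clip_model)
#   cached_dataset = CustomTensorDataset(features, names)
#   cached_loader = DataLoader(cached_dataset, batch_size=32)
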
def get_templates():
    """
    Return a list of caption templates.
    Same templates as in PALAVRA: https://arxiv.org/abs/2204.01694
    """
    return [
        "This is a photo of a {}",
        "This photo contains a {}",
        "A photo of a {}",
        "This is an illustration of a {}",
        "This illustration contains a {}",
"An illustrations of a {}", | |
"This is a sketch of a {}", | |
"This sketch contains a {}", | |
"A sketch of a {}", | |
"This is a diagram of a {}", | |
"This diagram contains a {}", | |
"A diagram of a {}", | |
"A {}", | |
"We see a {}", | |
"{}", | |
"We see a {} in this photo", | |
"We see a {} in this image", | |
"We see a {} in this illustration", | |
"We see a {} photo", | |
"We see a {} image", | |
"We see a {} illustration", | |
"{} photo", | |
"{} image", | |
"{} illustration", | |
] | |
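
# Example usage (illustrative): each template has a single '{}' slot for a concept
# word or a pseudo-token placeholder:
#
#   prompts = [template.format("dog") for template in get_templates()]
#   # ['This is a photo of a dog', 'This photo contains a dog', ...]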