LinCIR / utils.py
Geonmo's picture
initial commit
cacafc1
raw history blame
No virus
6.23 kB
from typing import Optional, Tuple, List
import torch
import torch.nn.functional as F
from clip.model import CLIP
from transformers import CLIPVisionModelWithProjection
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
from data_utils import collate_fn
from models import Phi
# Pick the compute device once at import time: CUDA when it is available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision accompanies the CUDA device; the CPU path stays in float32.
dtype = torch.float16 if device.type == "cuda" else torch.float32
@torch.no_grad()
def extract_image_features(dataset: Dataset, clip_model: CLIPVisionModelWithProjection, batch_size: Optional[int] = 32,
                           num_workers: Optional[int] = 10) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model.

    Args:
        dataset: Dataset whose collated batches are dict-like and carry the images under
            'image' (falling back to 'reference_image') and the names under 'image_name'
            (falling back to 'reference_name').
        clip_model: CLIP vision model; every batch is moved to its device and cast to its
            dtype before the forward pass.
        batch_size: Batch size handed to the DataLoader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A tuple ``(features, names)``: the per-batch ``image_embeds`` moved to the CPU and
        stacked with ``torch.vstack``, and the names collected in the same order.

    Note:
        If the loader yields no batch, ``torch.vstack`` is called with an empty list.
    """
    # Create data loader
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)

    index_features = []
    index_names = []

    # The banner is best-effort: datasets without a `split` attribute simply skip it,
    # without hiding unrelated errors behind a blanket `except`.
    if hasattr(dataset, 'split'):
        print(f"extracting image features {dataset.__class__.__name__} - {dataset.split}")

    # Extract features (gradients are already disabled by the decorator)
    for batch in tqdm(loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        # Single transfer: match both the model's device and its parameter dtype.
        images = images.to(device=clip_model.device, dtype=clip_model.dtype)
        batch_features = clip_model(pixel_values=images).image_embeds

        # Accumulate on the CPU so the device only ever holds one batch of features.
        index_features.append(batch_features.cpu())
        index_names.extend(names)

    index_features = torch.vstack(index_features)
    return index_features, index_names
def contrastive_loss(v1: torch.Tensor, v2: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Contrastive loss between two batches of paired embeddings.

    Row ``i`` of ``v1`` and row ``i`` of ``v2`` form the positive pair; every other row of
    the concatenated batch ``[v1; v2]`` contributes to the denominator.

    Args:
        v1: 2-D tensor of embeddings (one row per sample).
        v2: 2-D tensor with the same shape as ``v1``.
        temperature: Divisor applied to the cosine similarities before exponentiation.

    Returns:
        A scalar tensor holding the mean negative log-ratio over the concatenated batch.
    """
    # Based on https://github.com/NVlabs/PALAVRA/blob/main/utils/nv.py
    v1 = F.normalize(v1, dim=1)
    v2 = F.normalize(v2, dim=1)

    # Similarity of each positive pair, repeated so that every row of the
    # concatenated batch has its own numerator.
    numerator = torch.exp(torch.diag(torch.inner(v1, v2)) / temperature)
    numerator = torch.cat((numerator, numerator), 0)

    joint_vector = torch.cat((v1, v2), 0)
    pairs_product = torch.exp(torch.mm(joint_vector, joint_vector.t()) / temperature)

    # Build the identity mask on the inputs' own device (not a module-level global) so the
    # loss works wherever the embeddings live. Subtracting the masked product removes each
    # sample's similarity with itself from the sum.
    self_mask = torch.eye(joint_vector.shape[0], device=joint_vector.device)
    denominator = torch.sum(pairs_product - pairs_product * self_mask, 0)

    loss = -torch.mean(torch.log(numerator / denominator))
    return loss
@torch.no_grad()
def extract_pseudo_tokens_with_phi(clip_model: CLIPVisionModelWithProjection, phi: Phi, dataset: Dataset, args) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts pseudo tokens from a dataset using a CLIP model and a phi model

    Args:
        clip_model: CLIP vision model; every batch is moved to its device and cast to its
            dtype before the forward pass.
        phi: Network applied to the (optionally normalized) image embeddings to predict
            the pseudo tokens.
        dataset: Dataset whose collated batches are dict-like and carry the images under
            'image' (falling back to 'reference_image') and the names under 'image_name'
            (falling back to 'reference_name').
        args: Configuration object; only ``args.l2_normalize`` is read here.

    Returns:
        A tuple ``(tokens, names)``: the per-batch phi outputs moved to the CPU and stacked
        with ``torch.vstack``, and the names collected in the same order.
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    predicted_tokens = []
    names_list = []
    print("Extracting tokens using phi model")
    for batch in tqdm(data_loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        # Follow the model's own device/dtype instead of a hard-coded half-precision cast,
        # so float32 (e.g. CPU) models work as well as float16 ones.
        images = images.to(device=clip_model.device, dtype=clip_model.dtype)
        image_features = clip_model(pixel_values=images).image_embeds
        if args.l2_normalize:
            image_features = F.normalize(image_features, dim=-1)
        batch_predicted_tokens = phi(image_features)

        # Accumulate on the CPU so the device only ever holds one batch of tokens.
        predicted_tokens.append(batch_predicted_tokens.cpu())
        names_list.extend(names)

    predicted_tokens = torch.vstack(predicted_tokens)
    return predicted_tokens, names_list
@torch.no_grad()
def extract_image_features_with_names(clip_model: CLIPVisionModelWithProjection, dataset: Dataset) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model

    Args:
        clip_model: CLIP vision model; every batch is moved to its device and cast to its
            dtype before the forward pass.
        dataset: Dataset whose collated batches are dict-like and carry the images under
            'image' (falling back to 'reference_image') and the names under 'image_name'
            (falling back to 'reference_name').

    Returns:
        A tuple ``(features, names)``: the per-batch ``image_embeds`` moved to the CPU and
        stacked with ``torch.vstack``, and the names collected in the same order.
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    features_list = []
    names_list = []
    # No phi model is involved here, so the progress message says what is actually extracted.
    print("Extracting image features")
    for batch in tqdm(data_loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        # Follow the model's own device/dtype rather than a module-level global.
        images = images.to(device=clip_model.device, dtype=clip_model.dtype)
        image_features = clip_model(pixel_values=images).image_embeds

        # Accumulate on the CPU so the device only ever holds one batch of features.
        features_list.append(image_features.cpu())
        names_list.extend(names)

    return torch.vstack(features_list), names_list
class CustomTensorDataset(Dataset):
    """
    Dataset wrapper that pairs stored images with their names.

    Indexing returns a dict with the keys 'image' and 'image_name'.
    """

    def __init__(self, images: torch.Tensor, names: torch.Tensor):
        # NOTE(review): `names` is annotated as a tensor but is only ever indexed here;
        # presumably any indexable sequence of names works -- confirm against callers.
        self.images, self.names = images, names

    def __len__(self):
        # The length follows the image container; `names` is not consulted.
        return len(self.images)

    def __getitem__(self, index) -> dict:
        image = self.images[index]
        name = self.names[index]
        return dict(image=image, image_name=name)
def get_templates():
    """
    Build the list of prompt templates; each one has a single `{}` placeholder.
    Same templates as in PALAVRA: https://arxiv.org/abs/2204.01694

    A new list is created on every call.
    """
    # Templates that name the medium before the placeholder.
    medium_first = (
        "This is a photo of a {}",
        "This photo contains a {}",
        "A photo of a {}",
        "This is an illustration of a {}",
        "This illustration contains a {}",
        "An illustrations of a {}",
        "This is a sketch of a {}",
        "This sketch contains a {}",
        "A sketch of a {}",
        "This is a diagram of a {}",
        "This diagram contains a {}",
        "A diagram of a {}",
    )
    # Templates without any medium word.
    medium_free = (
        "A {}",
        "We see a {}",
        "{}",
    )
    # Templates that name the medium after the placeholder.
    medium_last = (
        "We see a {} in this photo",
        "We see a {} in this image",
        "We see a {} in this illustration",
        "We see a {} photo",
        "We see a {} image",
        "We see a {} illustration",
        "{} photo",
        "{} image",
        "{} illustration",
    )
    return [*medium_first, *medium_free, *medium_last]