# LinCIR / validate.py
import json
import pickle
from argparse import ArgumentParser
from typing import List, Dict, Tuple
import clip
import numpy as np
import torch
import torch.nn.functional as F
from clip.model import CLIP
from transformers import CLIPTextModelWithProjection
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
from data_utils import collate_fn, PROJECT_ROOT, targetpad_transform
from loader import FashionIQDataset, CIRRDataset, CIRCODataset
from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
from models import build_text_encoder, Phi, PIC2WORD
from utils import extract_image_features, device, extract_pseudo_tokens_with_phi
torch.multiprocessing.set_sharing_strategy('file_system')
@torch.no_grad()
def fiq_generate_val_predictions(clip_model, relative_val_dataset: Dataset, ref_names_list: List[str],
pseudo_tokens: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
"""
Generates features predictions for the validation set of Fashion IQ.
"""
# Create data loader
relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
pin_memory=False, collate_fn=collate_fn, shuffle=False)
predicted_features_list = []
target_names_list = []
# Compute features
for batch in tqdm(relative_val_loader):
reference_names = batch['reference_name']
target_names = batch['target_name']
relative_captions = batch['relative_captions']
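        # FashionIQ provides two relative captions per triplet; flatten them and
        # join each pair in both orders so the two orderings can be ensembled below.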
flattened_captions: list = np.array(relative_captions).T.flatten().tolist()
input_captions = [
f"{flattened_captions[i].strip('.?, ')} and {flattened_captions[i + 1].strip('.?, ')}" for
i in range(0, len(flattened_captions), 2)]
input_captions_reversed = [
f"{flattened_captions[i + 1].strip('.?, ')} and {flattened_captions[i].strip('.?, ')}" for
i in range(0, len(flattened_captions), 2)]
input_captions = [
f"a photo of $ that {in_cap}" for in_cap in input_captions]
batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
input_captions_reversed = [
f"a photo of $ that {in_cap}" for in_cap in input_captions_reversed]
tokenized_input_captions_reversed = clip.tokenize(input_captions_reversed, context_length=77).to(device)
text_features_reversed = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions_reversed,
batch_tokens)
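        # Ensemble the two caption orderings: average the L2-normalized features and re-normalize.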
predicted_features = F.normalize((F.normalize(text_features) + F.normalize(text_features_reversed)) / 2)
# predicted_features = F.normalize((text_features + text_features_reversed) / 2)
predicted_features_list.append(predicted_features)
target_names_list.extend(target_names)
predicted_features = torch.vstack(predicted_features_list)
return predicted_features, target_names_list
@torch.no_grad()
def fiq_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
-> Dict[str, float]:
"""
Compute the retrieval metrics on the FashionIQ validation set given the dataset, pseudo tokens and the reference names
"""
# Generate the predicted features
predicted_features, target_names = fiq_generate_val_predictions(clip_model, relative_val_dataset, ref_names_list,
pseudo_tokens)
# Move the features to the device
index_features = index_features.to(device)
predicted_features = predicted_features.to(device)
# Normalize the features
index_features = F.normalize(index_features.float())
# Compute the distances
distances = 1 - predicted_features @ index_features.T
sorted_indices = torch.argsort(distances, dim=-1).cpu()
sorted_index_names = np.array(index_names)[sorted_indices]
# Check if the target names are in the top 10 and top 50
labels = torch.tensor(
sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1))
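    # Sanity check: every target name appears exactly once in the index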
assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
# Compute the metrics
recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
return {'fiq_recall_at10': recall_at10,
'fiq_recall_at50': recall_at50}
@torch.no_grad()
def fiq_val_retrieval(dataset_path: str, dress_type: str, image_encoder, text_encoder, ref_names_list: List[str],
pseudo_tokens: torch.Tensor, preprocess: callable) -> Dict[str, float]:
"""
Compute the retrieval metrics on the FashionIQ validation set given the pseudo tokens and the reference names
"""
# Load the model
#clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
#clip_model = clip_model.float().eval().requires_grad_(False)
# Extract the index features
classic_val_dataset = FashionIQDataset(dataset_path, 'val', [dress_type], 'classic', preprocess)
index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)
# Define the relative dataset
relative_val_dataset = FashionIQDataset(dataset_path, 'val', [dress_type], 'relative', preprocess)
return fiq_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names, ref_names_list,
pseudo_tokens)
@torch.no_grad()
def cirr_generate_val_predictions(clip_model: CLIPTextModelWithProjection, relative_val_dataset: Dataset, ref_names_list: List[str],
pseudo_tokens: torch.Tensor) -> \
Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
"""
Generates features predictions for the validation set of CIRR
"""
# Define the dataloader
relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
pin_memory=False, collate_fn=collate_fn)
predicted_features_list = []
target_names_list = []
group_members_list = []
reference_names_list = []
for batch in tqdm(relative_val_loader):
reference_names = batch['reference_name']
target_names = batch['target_name']
relative_captions = batch['relative_caption']
group_members = batch['group_members']
group_members = np.array(group_members).T.tolist()
input_captions = [
f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]
batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
predicted_features = F.normalize(text_features)
predicted_features_list.append(predicted_features)
target_names_list.extend(target_names)
group_members_list.extend(group_members)
reference_names_list.extend(reference_names)
predicted_features = torch.vstack(predicted_features_list)
return predicted_features, reference_names_list, target_names_list, group_members_list
@torch.no_grad()
def cirr_generate_val_predictions_with_phi(clip_model: CLIPTextModelWithProjection, phi, relative_val_dataset: Dataset, ref_names_list: List[str],
image_features: torch.Tensor) -> \
Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
"""
Generates features predictions for the validation set of CIRR
"""
# Define the dataloader
relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
pin_memory=False, collate_fn=collate_fn)
predicted_features_list = []
target_names_list = []
group_members_list = []
reference_names_list = []
for batch in tqdm(relative_val_loader):
reference_names = batch['reference_name']
target_names = batch['target_name']
relative_captions = batch['relative_caption']
group_members = batch['group_members']
group_members = np.array(group_members).T.tolist()
input_captions = [
f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]
        # Predict the pseudo tokens on the fly from the reference image features instead of using precomputed ones
selected_image_features = torch.vstack([image_features[ref_names_list.index(ref)] for ref in reference_names])
tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
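        # Textual context for phi: token embeddings plus position embeddings of the tokenized captions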
        context = (clip_model.text_model.embeddings.token_embedding(tokenized_input_captions)
                   + clip_model.text_model.embeddings.position_embedding(clip_model.text_model.embeddings.position_ids))
batch_tokens = phi(selected_image_features, context)
#batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
predicted_features = F.normalize(text_features)
predicted_features_list.append(predicted_features)
target_names_list.extend(target_names)
group_members_list.extend(group_members)
reference_names_list.extend(reference_names)
predicted_features = torch.vstack(predicted_features_list)
return predicted_features, reference_names_list, target_names_list, group_members_list
@torch.no_grad()
def cirr_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
-> Dict[str, float]:
"""
Compute the retrieval metrics on the CIRR validation set given the dataset, pseudo tokens and the reference names
"""
# Generate the predicted features
predicted_features, reference_names, target_names, group_members = \
cirr_generate_val_predictions(clip_model, relative_val_dataset, ref_names_list, pseudo_tokens)
index_features = index_features.to(device)
predicted_features = predicted_features.to(device)
# Normalize the index features
index_features = F.normalize(index_features, dim=-1).float()
predicted_features = predicted_features.float()
# Compute the distances and sort the results
distances = 1 - predicted_features @ index_features.T
sorted_indices = torch.argsort(distances, dim=-1).cpu()
sorted_index_names = np.array(index_names)[sorted_indices]
# Delete the reference image from the results
reference_mask = torch.tensor(
sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
sorted_index_names.shape[1] - 1)
# Compute the ground-truth labels wrt the predictions
labels = torch.tensor(
sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
# Compute the subset predictions and ground-truth labels
group_members = np.array(group_members)
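    # group_mask keeps only the columns whose index names are among the query's group members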
group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
group_labels = labels[group_mask].reshape(labels.shape[0], -1)
assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
# Compute the metrics
recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
return {
'cirr_recall_at1': recall_at1,
'cirr_recall_at5': recall_at5,
'cirr_recall_at10': recall_at10,
'cirr_recall_at50': recall_at50,
'cirr_group_recall_at1': group_recall_at1,
'cirr_group_recall_at2': group_recall_at2,
'cirr_group_recall_at3': group_recall_at3,
}
@torch.no_grad()
def cirr_compute_val_metrics_with_phi(relative_val_dataset: Dataset, clip_model: CLIPTextModelWithProjection, phi, index_features: torch.Tensor,
index_names: List[str], ref_names_list: List[str], image_features: torch.Tensor) \
-> Dict[str, float]:
"""
Compute the retrieval metrics on the CIRR validation set given the dataset, pseudo tokens and the reference names
"""
# Generate the predicted features
predicted_features, reference_names, target_names, group_members = \
cirr_generate_val_predictions_with_phi(clip_model, phi, relative_val_dataset, ref_names_list, image_features)
index_features = index_features.to(device)
predicted_features = predicted_features.to(device)
# Normalize the index features
index_features = F.normalize(index_features, dim=-1).float()
predicted_features = predicted_features.float()
# Compute the distances and sort the results
distances = 1 - predicted_features @ index_features.T
sorted_indices = torch.argsort(distances, dim=-1).cpu()
sorted_index_names = np.array(index_names)[sorted_indices]
# Delete the reference image from the results
reference_mask = torch.tensor(
sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
sorted_index_names.shape[1] - 1)
# Compute the ground-truth labels wrt the predictions
labels = torch.tensor(
sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
# Compute the subset predictions and ground-truth labels
group_members = np.array(group_members)
group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
group_labels = labels[group_mask].reshape(labels.shape[0], -1)
assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
# Compute the metrics
recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
return {
'cirr_recall_at1': recall_at1,
'cirr_recall_at5': recall_at5,
'cirr_recall_at10': recall_at10,
'cirr_recall_at50': recall_at50,
'cirr_group_recall_at1': group_recall_at1,
'cirr_group_recall_at2': group_recall_at2,
'cirr_group_recall_at3': group_recall_at3,
}
@torch.no_grad()
def cirr_val_retrieval(dataset_path: str, image_encoder, text_encoder, ref_names_list: list, pseudo_tokens: torch.Tensor,
preprocess: callable) -> Dict[str, float]:
"""
Compute the retrieval metrics on the CIRR validation set given the pseudo tokens and the reference names
"""
# Load the model
#clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
#clip_model = clip_model.float().eval().requires_grad_(False)
# Extract the index features
classic_val_dataset = CIRRDataset(dataset_path, 'val', 'classic', preprocess)
index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)
# Define the relative validation dataset
relative_val_dataset = CIRRDataset(dataset_path, 'val', 'relative', preprocess)
return cirr_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names,
ref_names_list, pseudo_tokens)
@torch.no_grad()
def circo_generate_val_predictions(clip_model, relative_val_dataset: Dataset, ref_names_list: List[str],
pseudo_tokens: torch.Tensor) -> Tuple[
torch.Tensor, List[str], list]:
"""
Generates features predictions for the validation set of CIRCO
"""
# Create the data loader
relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
pin_memory=False, collate_fn=collate_fn, shuffle=False)
predicted_features_list = []
target_names_list = []
gts_img_ids_list = []
# Compute the features
for batch in tqdm(relative_val_loader):
reference_names = batch['reference_name']
target_names = batch['target_name']
relative_captions = batch['relative_caption']
gt_img_ids = batch['gt_img_ids']
gt_img_ids = np.array(gt_img_ids).T.tolist()
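        # gt_img_ids lists are padded with empty strings so they can be collated; the padding is removed before computing metrics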
input_captions = [f"a photo of $ that {caption}" for caption in relative_captions]
batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
predicted_features = F.normalize(text_features)
predicted_features_list.append(predicted_features)
target_names_list.extend(target_names)
gts_img_ids_list.extend(gt_img_ids)
predicted_features = torch.vstack(predicted_features_list)
return predicted_features, target_names_list, gts_img_ids_list
@torch.no_grad()
def circo_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
-> Dict[str, float]:
"""
Compute the retrieval metrics on the CIRCO validation set given the dataset, pseudo tokens and the reference names
"""
# Generate the predicted features
predicted_features, target_names, gts_img_ids = circo_generate_val_predictions(clip_model, relative_val_dataset,
ref_names_list, pseudo_tokens)
ap_at5 = []
ap_at10 = []
ap_at25 = []
ap_at50 = []
recall_at5 = []
recall_at10 = []
recall_at25 = []
recall_at50 = []
# Move the features to the device
index_features = index_features.to(device)
predicted_features = predicted_features.to(device)
# Normalize the features
index_features = F.normalize(index_features.float())
for predicted_feature, target_name, gt_img_ids in tqdm(zip(predicted_features, target_names, gts_img_ids)):
gt_img_ids = np.array(gt_img_ids)[
np.array(gt_img_ids) != ''] # remove trailing empty strings added for collate_fn
similarity = predicted_feature @ index_features.T
sorted_indices = torch.topk(similarity, dim=-1, k=50).indices.cpu()
sorted_index_names = np.array(index_names)[sorted_indices]
map_labels = torch.tensor(np.isin(sorted_index_names, gt_img_ids), dtype=torch.uint8)
precisions = torch.cumsum(map_labels, dim=0) * map_labels # Consider only positions corresponding to GTs
precisions = precisions / torch.arange(1, map_labels.shape[0] + 1) # Compute precision for each position
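        # AP@k: sum of precisions at ground-truth positions, normalized by min(#GTs, k)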
ap_at5.append(float(torch.sum(precisions[:5]) / min(len(gt_img_ids), 5)))
ap_at10.append(float(torch.sum(precisions[:10]) / min(len(gt_img_ids), 10)))
ap_at25.append(float(torch.sum(precisions[:25]) / min(len(gt_img_ids), 25)))
ap_at50.append(float(torch.sum(precisions[:50]) / min(len(gt_img_ids), 50)))
assert target_name == gt_img_ids[0], f"Target name not in GTs {target_name} {gt_img_ids}"
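        # Recall@k here checks whether the single annotated target appears in the top-k results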
single_gt_labels = torch.tensor(sorted_index_names == target_name)
recall_at5.append(float(torch.sum(single_gt_labels[:5])))
recall_at10.append(float(torch.sum(single_gt_labels[:10])))
recall_at25.append(float(torch.sum(single_gt_labels[:25])))
recall_at50.append(float(torch.sum(single_gt_labels[:50])))
map_at5 = np.mean(ap_at5) * 100
map_at10 = np.mean(ap_at10) * 100
map_at25 = np.mean(ap_at25) * 100
map_at50 = np.mean(ap_at50) * 100
recall_at5 = np.mean(recall_at5) * 100
recall_at10 = np.mean(recall_at10) * 100
recall_at25 = np.mean(recall_at25) * 100
recall_at50 = np.mean(recall_at50) * 100
return {
'circo_map_at5': map_at5,
'circo_map_at10': map_at10,
'circo_map_at25': map_at25,
'circo_map_at50': map_at50,
'circo_recall_at5': recall_at5,
'circo_recall_at10': recall_at10,
'circo_recall_at25': recall_at25,
'circo_recall_at50': recall_at50,
}
@torch.no_grad()
def circo_val_retrieval(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str], pseudo_tokens: torch.Tensor,
preprocess: callable) -> Dict[str, float]:
"""
Compute the retrieval metrics on the CIRCO validation set given the pseudo tokens and the reference names
"""
# Load the model
#clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
#clip_model = clip_model.float().eval().requires_grad_(False)
# Extract the index features
classic_val_dataset = CIRCODataset(dataset_path, 'val', 'classic', preprocess)
index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)
# Define the relative validation dataset
relative_val_dataset = CIRCODataset(dataset_path, 'val', 'relative', preprocess)
return circo_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names, ref_names_list,
pseudo_tokens)
def main():
parser = ArgumentParser()
parser.add_argument("--exp-name", type=str, help="Experiment to evaluate")
parser.add_argument("--eval-type", type=str, choices=['oti', 'phi', 'searle', 'searle-xl', 'pic2word'], required=True,
help="If 'oti' evaluate directly using the inverted oti pseudo tokens, "
"if 'phi' predicts the pseudo tokens using the phi network, "
"if 'searle' uses the pre-trained SEARLE model to predict the pseudo tokens, "
"if 'searle-xl' uses the pre-trained SEARLE-XL model to predict the pseudo tokens"
)
parser.add_argument("--dataset", type=str, required=True, choices=['cirr', 'fashioniq', 'circo'],
help="Dataset to use")
parser.add_argument("--dataset-path", type=str, help="Path to the dataset", required=True)
parser.add_argument("--preprocess-type", default="clip", type=str, choices=['clip', 'targetpad'],
help="Preprocess pipeline to use")
parser.add_argument("--phi-checkpoint-name", type=str,
help="Phi checkpoint to use, needed when using phi, e.g. 'phi_20.pt'")
parser.add_argument("--clip_model_name", default="giga", type=str)
parser.add_argument("--cache_dir", default="./hf_models", type=str)
parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")
args = parser.parse_args()
#if args.eval_type in ['phi', 'oti'] and args.exp_name is None:
# raise ValueError("Experiment name is required when using phi or oti evaluation type")
if args.eval_type == 'phi' and args.phi_checkpoint_name is None:
raise ValueError("Phi checkpoint name is required when using phi evaluation type")
if args.eval_type == 'oti':
experiment_path = PROJECT_ROOT / 'data' / "oti_pseudo_tokens" / args.dataset.lower() / 'val' / args.exp_name
if not experiment_path.exists():
raise ValueError(f"Experiment {args.exp_name} not found")
with open(experiment_path / 'hyperparameters.json') as f:
hyperparameters = json.load(f)
pseudo_tokens = torch.load(experiment_path / 'ema_oti_pseudo_tokens.pt', map_location=device)
with open(experiment_path / 'image_names.pkl', 'rb') as f:
ref_names_list = pickle.load(f)
clip_model_name = hyperparameters['clip_model_name']
clip_model, clip_preprocess = clip.load(clip_model_name, device='cpu', jit=False)
if args.preprocess_type == 'targetpad':
print('Target pad preprocess pipeline is used')
preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
elif args.preprocess_type == 'clip':
print('CLIP preprocess pipeline is used')
preprocess = clip_preprocess
else:
raise ValueError("Preprocess type not supported")
elif args.eval_type in ['phi', 'searle', 'searle-xl', 'pic2word']:
if args.eval_type == 'phi':
args.mixed_precision = 'fp16'
image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
phi = Phi(input_dim=text_encoder.config.projection_dim,
hidden_dim=text_encoder.config.projection_dim * 4,
output_dim=text_encoder.config.hidden_size, dropout=0.5).to(
device)
phi.load_state_dict(
torch.load(args.phi_checkpoint_name, map_location=device)[
phi.__class__.__name__])
phi = phi.eval()
elif args.eval_type == 'pic2word':
args.mixed_precision = 'fp16'
image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
phi = PIC2WORD(embed_dim=text_encoder.config.projection_dim,
output_dim=text_encoder.config.hidden_size,
).to(device)
sd = torch.load(args.phi_checkpoint_name, map_location=device)['state_dict_img2text']
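            # Strip the 'module.' prefix (typically left by a DataParallel/DDP-wrapped model when the checkpoint was saved)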
sd = {k[len('module.'):]: v for k, v in sd.items()}
phi.load_state_dict(sd)
phi = phi.eval()
else: # searle or searle-xl
if args.eval_type == 'searle':
clip_model_name = 'ViT-B/32'
else: # args.eval_type == 'searle-xl':
clip_model_name = 'ViT-L/14'
phi, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
backbone=clip_model_name)
phi = phi.to(device).eval()
clip_model, clip_preprocess = clip.load(clip_model_name, device=device, jit=False)
if args.preprocess_type == 'targetpad':
print('Target pad preprocess pipeline is used')
preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
elif args.preprocess_type == 'clip':
print('CLIP preprocess pipeline is used')
preprocess = clip_preprocess
else:
raise ValueError("Preprocess type not supported")
if args.dataset.lower() == 'fashioniq':
relative_val_dataset = FashionIQDataset(args.dataset_path, 'val', ['dress', 'toptee', 'shirt'],
'relative', preprocess, no_duplicates=True)
elif args.dataset.lower() == 'cirr':
relative_val_dataset = CIRRDataset(args.dataset_path, 'val', 'relative', preprocess,
no_duplicates=True)
elif args.dataset.lower() == 'circo':
relative_val_dataset = CIRCODataset(args.dataset_path, 'val', 'relative', preprocess)
else:
raise ValueError("Dataset not supported")
#clip_model = clip_model.float().to(device)
image_encoder = image_encoder.float().to(device)
text_encoder = text_encoder.float().to(device)
pseudo_tokens, ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi, relative_val_dataset, args)
pseudo_tokens = pseudo_tokens.to(device)
else:
raise ValueError("Eval type not supported")
print(f"Eval type = {args.eval_type} \t exp name = {args.exp_name} \t")
if args.dataset.lower() == 'fashioniq':
recalls_at10 = []
recalls_at50 = []
for dress_type in ['shirt', 'dress', 'toptee']:
fiq_metrics = fiq_val_retrieval(args.dataset_path, dress_type, image_encoder, text_encoder, ref_names_list,
pseudo_tokens, preprocess)
recalls_at10.append(fiq_metrics['fiq_recall_at10'])
recalls_at50.append(fiq_metrics['fiq_recall_at50'])
for k, v in fiq_metrics.items():
print(f"{dress_type}_{k} = {v:.2f}")
print("\n")
print(f"average_fiq_recall_at10 = {np.mean(recalls_at10):.2f}")
print(f"average_fiq_recall_at50 = {np.mean(recalls_at50):.2f}")
elif args.dataset.lower() == 'cirr':
cirr_metrics = cirr_val_retrieval(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
preprocess)
for k, v in cirr_metrics.items():
print(f"{k} = {v:.2f}")
elif args.dataset.lower() == 'circo':
        circo_metrics = circo_val_retrieval(args.dataset_path, image_encoder, text_encoder, ref_names_list,
                                            pseudo_tokens, preprocess)
for k, v in circo_metrics.items():
print(f"{k} = {v:.2f}")
if __name__ == '__main__':
main()