Spaces:
Runtime error
Runtime error
import os | |
import json | |
import pickle | |
from argparse import ArgumentParser | |
from typing import List, Tuple, Dict | |
import clip | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from clip.model import CLIP | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from data_utils import PROJECT_ROOT, targetpad_transform | |
from loader import CIRRDataset, CIRCODataset | |
from encode_with_pseudo_tokens import encode_with_pseudo_tokens, encode_with_pseudo_tokens_HF | |
from models import build_text_encoder, Phi, PIC2WORD | |
from utils import extract_image_features, device, collate_fn, extract_pseudo_tokens_with_phi | |
def cirr_generate_test_submission_file(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str],
                                       pseudo_tokens: torch.Tensor, preprocess: callable, submission_name: str) -> None:
    """
    Write the two CIRR test submission files for the given pseudo tokens.

    The gallery (split 'test1', mode 'classic') is embedded with ``image_encoder``. The relative queries
    (split 'test1', mode 'relative') are ranked against it by ``cirr_generate_test_dicts``. Two JSON files are
    written under ``./submission/cirr``:
    ``<submission_name>.json`` holds the global ranking, tagged ``'metric': 'recall'``.
    ``subset_<submission_name>.json`` holds the ranking restricted to each query's group members, tagged
    ``'metric': 'recall_subset'``.

    :param dataset_path: root folder of the CIRR dataset
    :param image_encoder: model handed to ``extract_image_features`` to embed the gallery images
    :param text_encoder: model handed to ``cirr_generate_test_dicts`` to embed the queries
    :param ref_names_list: image names, aligned position-by-position with ``pseudo_tokens``
    :param pseudo_tokens: one pseudo token per entry of ``ref_names_list``
    :param preprocess: image preprocessing used by both datasets
    :param submission_name: base file name (without extension) of the generated files
    """
    gallery_dataset = CIRRDataset(dataset_path, 'test1', 'classic', preprocess)
    index_features, index_names = extract_image_features(gallery_dataset, image_encoder)
    query_dataset = CIRRDataset(dataset_path, 'test1', 'relative', preprocess)

    recall_predictions, subset_predictions = cirr_generate_test_dicts(
        query_dataset, text_encoder, index_features, index_names, ref_names_list, pseudo_tokens)

    # Header keys come first; the per-pair predictions are merged in after them.
    submission = {'version': 'rc2', 'metric': 'recall', **recall_predictions}
    group_submission = {'version': 'rc2', 'metric': 'recall_subset', **subset_predictions}

    output_dir = os.path.join('./submission', 'cirr')
    os.makedirs(output_dir, exist_ok=True)
    outputs = (
        (f"{submission_name}.json", submission),
        (f"subset_{submission_name}.json", group_submission),
    )
    for file_name, payload in outputs:
        with open(os.path.join(output_dir, file_name), 'w+') as file:
            json.dump(payload, file, sort_keys=True)
def cirr_generate_test_dicts(relative_test_dataset: CIRRDataset, clip_model, index_features: torch.Tensor,
                             index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
        -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Rank the CIRR test gallery for every relative query and build the two submission dicts.

    :param relative_test_dataset: CIRR 'relative' test dataset providing the queries
    :param clip_model: text encoder forwarded to ``cirr_generate_test_predictions``
    :param index_features: gallery image features, one row per entry of ``index_names``
    :param index_names: gallery image names, aligned with the rows of ``index_features``
    :param ref_names_list: image names aligned position-by-position with ``pseudo_tokens``
    :param pseudo_tokens: pseudo tokens indexed by position in ``ref_names_list``
    :return: ``(pair_id -> top-50 gallery names, pair_id -> top-3 names among the query's group members)``,
             with names ordered most similar first and pair ids stringified
    :raises ValueError: if a query's reference image does not occur exactly once in ``index_names``
    """
    predicted_features, reference_names, pairs_id, group_members = \
        cirr_generate_test_predictions(clip_model, relative_test_dataset, ref_names_list, pseudo_tokens)

    print("Compute CIRR prediction dicts")

    # Both sides are L2-normalised, so 1 - dot product is the cosine distance. Ascending argsort puts the
    # most similar gallery image first. Casting both sides to float avoids a dtype mismatch in the matmul.
    index_features = F.normalize(index_features.to(device), dim=-1).float()
    distances = 1 - predicted_features.float() @ index_features.T
    sorted_indices = torch.argsort(distances, dim=-1).cpu()
    sorted_index_names = np.array(index_names)[sorted_indices.numpy()]

    # Remove each query's own reference image from its ranking. The rectangular reshape below is only
    # correct if every row loses exactly one entry. Otherwise rows would silently shift or reshape would fail.
    reference_mask = sorted_index_names != np.array(reference_names)[:, None]
    hits_per_row = (~reference_mask).sum(axis=1)
    if np.any(hits_per_row != 1):
        bad_row = int(np.argmax(hits_per_row != 1))
        raise ValueError(f"Reference image {reference_names[bad_row]!r} occurs {int(hits_per_row[bad_row])} "
                         f"time(s) in the index; expected exactly once")
    sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
                                                                    sorted_index_names.shape[1] - 1)

    # Subset ranking: keep, in ranked order, only the gallery images listed in the query's group members.
    group_members = np.array(group_members)
    group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).any(-1)
    sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1)

    pairid_to_retrieved_images = {str(int(pair_id)): prediction[:50].tolist()
                                  for pair_id, prediction in zip(pairs_id, sorted_index_names)}
    pairid_to_group_retrieved_images = {str(int(pair_id)): prediction[:3].tolist()
                                        for pair_id, prediction in zip(pairs_id, sorted_group_names)}
    return pairid_to_retrieved_images, pairid_to_group_retrieved_images
def cirr_generate_test_predictions(clip_model, relative_test_dataset: CIRRDataset, ref_names_list: List[str],
                                   pseudo_tokens: torch.Tensor) -> \
        Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
    """
    Encode every CIRR relative test query into an L2-normalised text feature.

    Each query becomes the prompt ``"a photo of $ that <relative caption>"``. The pseudo token of its reference
    image, looked up by name in ``ref_names_list``, is passed with it to ``encode_with_pseudo_tokens_HF``.

    :param clip_model: text encoder forwarded to ``encode_with_pseudo_tokens_HF``
    :param relative_test_dataset: CIRR 'relative' test dataset
    :param ref_names_list: image names; entry ``i`` names the owner of ``pseudo_tokens[i]``
    :param pseudo_tokens: pseudo tokens indexed by position in ``ref_names_list``
    :return: (stacked predicted features, reference names, pair ids, group members), all in loader order
    :raises ValueError: if a query's reference image has no entry in ``ref_names_list``
    """
    relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, num_workers=10,
                                      pin_memory=False)

    # Build the name -> row lookup once instead of scanning ref_names_list for every query.
    # setdefault keeps the first occurrence, matching list.index.
    name_to_row = {}
    for row, name in enumerate(ref_names_list):
        name_to_row.setdefault(name, row)

    predicted_features_list = []
    reference_names_list = []
    pair_id_list = []
    group_members_list = []
    # Retrieval only. Without no_grad, if the encoder's parameters require grad, every stored feature
    # would keep its batch's autograd graph alive until the function returns.
    with torch.no_grad():
        for batch in tqdm(relative_test_loader):
            reference_names = batch['reference_name']
            missing = [ref for ref in reference_names if ref not in name_to_row]
            if missing:
                raise ValueError(f"No pseudo token available for reference image(s): {missing}")

            # default_collate transposes per-sample lists (one list per member position).
            # Flip back to one list of members per query.
            group_members = np.array(batch['group_members']).T.tolist()

            input_captions = [f"a photo of $ that {rel_caption}" for rel_caption in batch['relative_caption']]
            batch_tokens = torch.vstack([pseudo_tokens[name_to_row[ref]].unsqueeze(0) for ref in reference_names])
            tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
            text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)

            predicted_features_list.append(F.normalize(text_features))
            reference_names_list.extend(reference_names)
            pair_id_list.extend(batch['pair_id'])
            group_members_list.extend(group_members)

    predicted_features = torch.vstack(predicted_features_list)
    return predicted_features, reference_names_list, pair_id_list, group_members_list
def circo_generate_test_submission_file(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str],
                                        pseudo_tokens: torch.Tensor, preprocess: callable,
                                        submission_name: str) -> None:
    """
    Write the CIRCO test submission file for the given pseudo tokens.

    The gallery (split 'test', mode 'classic') is embedded with ``image_encoder``. The relative queries
    (split 'test', mode 'relative') are ranked against it by ``circo_generate_test_dict``. The resulting
    ``query id -> retrieved image names`` mapping is dumped to ``./submission/circo/<submission_name>.json``.

    :param dataset_path: root folder of the CIRCO dataset
    :param image_encoder: model handed to ``extract_image_features`` to embed the gallery images
    :param text_encoder: model handed to ``circo_generate_test_dict`` to embed the queries
    :param ref_names_list: image names, aligned position-by-position with ``pseudo_tokens``
    :param pseudo_tokens: one pseudo token per entry of ``ref_names_list``
    :param preprocess: image preprocessing used by both datasets
    :param submission_name: base file name (without extension) of the generated file
    """
    gallery_dataset = CIRCODataset(dataset_path, 'test', 'classic', preprocess)
    index_features, index_names = extract_image_features(gallery_dataset, image_encoder)
    query_dataset = CIRCODataset(dataset_path, 'test', 'relative', preprocess)

    predictions = circo_generate_test_dict(query_dataset, text_encoder, index_features, index_names,
                                           ref_names_list, pseudo_tokens)

    output_dir = os.path.join('./submission', 'circo')
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{submission_name}.json")
    with open(output_path, 'w+') as file:
        json.dump(predictions, file, sort_keys=True)
def circo_generate_test_predictions(clip_model, relative_test_dataset: CIRCODataset, ref_names_list: List[str],
                                    pseudo_tokens: torch.Tensor) -> Tuple[torch.Tensor, List]:
    """
    Encode every CIRCO relative test query into an L2-normalised text feature.

    Each query becomes the prompt ``"a photo of $ that <relative caption>"``. The pseudo token of its reference
    image, looked up by name in ``ref_names_list``, is passed with it to ``encode_with_pseudo_tokens_HF``.

    :param clip_model: text encoder forwarded to ``encode_with_pseudo_tokens_HF``
    :param relative_test_dataset: CIRCO 'relative' test dataset
    :param ref_names_list: image names; entry ``i`` names the owner of ``pseudo_tokens[i]``
    :param pseudo_tokens: pseudo tokens indexed by position in ``ref_names_list``
    :return: (stacked predicted features, query ids), both in loader order (the loader does not shuffle)
    :raises ValueError: if a query's reference image has no entry in ``ref_names_list``
    """
    relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, num_workers=10,
                                      pin_memory=False, collate_fn=collate_fn, shuffle=False)

    # Build the name -> row lookup once instead of scanning ref_names_list for every query.
    # setdefault keeps the first occurrence, matching list.index.
    name_to_row = {}
    for row, name in enumerate(ref_names_list):
        name_to_row.setdefault(name, row)

    predicted_features_list = []
    query_ids_list = []
    # Retrieval only. Without no_grad, if the encoder's parameters require grad, every stored feature
    # would keep its batch's autograd graph alive until the function returns.
    with torch.no_grad():
        for batch in tqdm(relative_test_loader):
            reference_names = batch['reference_name']
            missing = [ref for ref in reference_names if ref not in name_to_row]
            if missing:
                raise ValueError(f"No pseudo token available for reference image(s): {missing}")

            input_captions = [f"a photo of $ that {caption}" for caption in batch['relative_caption']]
            batch_tokens = torch.vstack([pseudo_tokens[name_to_row[ref]].unsqueeze(0) for ref in reference_names])
            tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
            text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)

            predicted_features_list.append(F.normalize(text_features))
            query_ids_list.extend(batch['query_id'])

    predicted_features = torch.vstack(predicted_features_list)
    return predicted_features, query_ids_list
def circo_generate_test_dict(relative_test_dataset: CIRCODataset, clip_model, index_features: torch.Tensor,
                             index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
        -> Dict[str, List[str]]:
    """
    Rank the CIRCO test gallery for every relative query and build the submission dict.

    :param relative_test_dataset: CIRCO 'relative' test dataset providing the queries
    :param clip_model: text encoder forwarded to ``circo_generate_test_predictions``
    :param index_features: gallery image features, one row per entry of ``index_names``
    :param index_names: gallery image names, aligned with the rows of ``index_features``
    :param ref_names_list: image names aligned position-by-position with ``pseudo_tokens``
    :param pseudo_tokens: pseudo tokens indexed by position in ``ref_names_list``
    :return: query id -> names of the (at most) 50 most similar gallery images, most similar first
    """
    predicted_features, query_ids = circo_generate_test_predictions(clip_model, relative_test_dataset,
                                                                    ref_names_list, pseudo_tokens)

    # Both sides are L2-normalised, so the dot product is the cosine similarity.
    index_features = F.normalize(index_features.float().to(device), dim=-1)
    similarity = predicted_features.float() @ index_features.T

    # The submission lists 50 images per query. Clamp k so a gallery smaller than that does not make topk
    # raise. topk returns its results sorted, most similar first.
    top_k = min(50, similarity.shape[-1])
    sorted_indices = torch.topk(similarity, k=top_k, dim=-1).indices.cpu()
    sorted_index_names = np.array(index_names)[sorted_indices.numpy()]

    return {query_id: query_sorted_names.tolist()
            for query_id, query_sorted_names in zip(query_ids, sorted_index_names)}
def _load_phi(args, text_encoder):
    """
    Build the pseudo-token network for ``--eval-type`` 'phi' or 'pic2word' and load its weights.

    Layer sizes come from ``text_encoder.config``; weights come from ``--phi-checkpoint-name``.
    The network is returned in eval mode on ``device``.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(args.phi_checkpoint_name, map_location=device)
    if args.eval_type == 'phi':
        phi = Phi(input_dim=text_encoder.config.projection_dim,
                  hidden_dim=text_encoder.config.projection_dim * 4,
                  output_dim=text_encoder.config.hidden_size, dropout=0.5).to(device)
        phi.load_state_dict(checkpoint[phi.__class__.__name__])
    else:  # 'pic2word'
        phi = PIC2WORD(embed_dim=text_encoder.config.projection_dim,
                       output_dim=text_encoder.config.hidden_size,
                       ).to(device)
        # Weights saved through a (Distributed)DataParallel wrapper carry a 'module.' key prefix. Strip it only
        # where present: blindly cutting 7 characters would mangle checkpoints saved without the wrapper.
        prefix = 'module.'
        state_dict = {(key[len(prefix):] if key.startswith(prefix) else key): value
                      for key, value in checkpoint['state_dict_img2text'].items()}
        phi.load_state_dict(state_dict)
    return phi.eval()


def _select_preprocess(preprocess_type: str, clip_preprocess: callable, image_encoder) -> callable:
    """Return the image preprocessing pipeline requested by ``--preprocess-type``."""
    if preprocess_type == 'targetpad':
        print('Target pad preprocess pipeline is used')
        # NOTE(review): assumes image_encoder is a Hugging Face vision model exposing config.image_size.
        # No openai-CLIP model (and so no visual.input_resolution) is loaded on this path.
        return targetpad_transform(1.25, image_encoder.config.image_size)
    if preprocess_type == 'clip':
        print('CLIP preprocess pipeline is used')
        return clip_preprocess
    raise ValueError("Preprocess type not supported")


def main():
    """
    Command-line entry point.

    Steps:
    1. Build the image/text encoders with ``build_text_encoder``.
    2. Load the pseudo-token network (phi or pic2word).
    3. Predict a pseudo token for every reference image of the chosen test split.
    4. Write the CIRR or CIRCO submission file(s).
    """
    parser = ArgumentParser()
    parser.add_argument("--submission-name", type=str, required=True, help="Filename of the generated submission file")
    parser.add_argument("--exp-name", type=str, help="Experiment to evaluate")
    parser.add_argument("--dataset", type=str, required=True, choices=['cirr', 'circo'], help="Dataset to use")
    parser.add_argument("--dataset-path", type=str, help="Path to the dataset", required=True)
    parser.add_argument("--eval-type", type=str, choices=['oti', 'phi', 'searle', 'searle-xl', 'pic2word'], required=True,
                        help="If 'oti' evaluate directly using the inverted oti pseudo tokens, "
                             "if 'phi' predicts the pseudo tokens using the phi network, "
                             "if 'searle' uses the pre-trained SEARLE model to predict the pseudo tokens, "
                             "if 'searle-xl' uses the pre-trained SEARLE-XL model to predict the pseudo tokens")
    parser.add_argument("--preprocess-type", default="clip", type=str, choices=['clip', 'targetpad'],
                        help="Preprocess pipeline to use")
    parser.add_argument("--phi-checkpoint-name", type=str,
                        help="Phi checkpoint to use, needed when using phi, e.g. 'phi_20.pt'")
    parser.add_argument("--clip_model_name", default="giga", type=str)
    parser.add_argument("--cache_dir", default="./hf_models", type=str)
    parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")
    args = parser.parse_args()

    # The submission helpers embed the gallery with the HF image_encoder and the queries with
    # encode_with_pseudo_tokens_HF on the HF text_encoder. Only 'phi' and 'pic2word' build those encoders.
    # The 'oti' and 'searle*' variants load an openai-CLIP model and never define them, so they are rejected
    # up front instead of hitting an UnboundLocalError after the models have been loaded.
    if args.eval_type not in ('phi', 'pic2word'):
        parser.error(f"--eval-type {args.eval_type!r} is not supported by this script; use 'phi' or 'pic2word'")
    if args.phi_checkpoint_name is None:
        parser.error(f"--phi-checkpoint-name is required for --eval-type {args.eval_type!r}")

    args.mixed_precision = 'fp16'
    image_encoder, clip_preprocess, text_encoder, _tokenizer = build_text_encoder(args)
    phi = _load_phi(args, text_encoder)
    preprocess = _select_preprocess(args.preprocess_type, clip_preprocess, image_encoder)

    if args.dataset.lower() == 'cirr':
        relative_test_dataset = CIRRDataset(args.dataset_path, 'test', 'relative', preprocess, no_duplicates=True)
    elif args.dataset.lower() == 'circo':
        relative_test_dataset = CIRCODataset(args.dataset_path, 'test', 'relative', preprocess)
    else:
        raise ValueError("Dataset not supported")

    image_encoder = image_encoder.float().to(device)
    text_encoder = text_encoder.float().to(device)
    pseudo_tokens, ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi, relative_test_dataset, args)
    pseudo_tokens = pseudo_tokens.to(device)

    print(f"Eval type = {args.eval_type} \t exp name = {args.exp_name} \t")
    if args.dataset == 'cirr':
        cirr_generate_test_submission_file(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
                                           preprocess, args.submission_name)
    elif args.dataset == 'circo':
        circo_generate_test_submission_file(args.dataset_path, image_encoder, text_encoder, ref_names_list,
                                            pseudo_tokens, preprocess, args.submission_name)
    else:
        raise ValueError("Dataset not supported")
if __name__ == '__main__':
    # Run the command-line entry point only when the file is executed as a script, not when it is imported.
    main()