initial commit
Files changed:
- README.md +1 -1
- app.py +220 -0
- data_utils.py +67 -0
- encode_with_pseudo_tokens.py +54 -0
- eval_templates.py +70 -0
- generate_test_submission.py +363 -0
- loader.py +632 -0
- requirements.txt +8 -0
- train_phi.py +317 -0
- utils.py +182 -0
- validate.py +650 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: LinCIR
-emoji:
+emoji: π
 colorFrom: purple
 colorTo: yellow
 sdk: gradio
app.py
ADDED
@@ -0,0 +1,220 @@
'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import os
import time
from argparse import ArgumentParser

import numpy as np
import torch
import gradio as gr
from clip_retrieval.clip_client import ClipClient

from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
from models import build_text_encoder, Phi, PIC2WORD

import transformers
from huggingface_hub import hf_hub_url, cached_download


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--lincir_ckpt_path", default=None, type=str,
                        help="The output directory where the model predictions and checkpoints will be written")
    parser.add_argument("--pic2word_ckpt_path", default=None, type=str)
    parser.add_argument("--cache_dir", default="./hf_models", type=str,
                        help="Path to model cache folder")
    parser.add_argument("--clip_model_name", default="large", type=str,
                        help="CLIP model to use, e.g. 'large', 'huge', 'giga'")
    parser.add_argument("--mixed_precision", default="fp16", type=str)
    parser.add_argument("--test_fps", action="store_true")
    args = parser.parse_args()
    return args


def load_models(args):
    if torch.cuda.is_available():
        device = 'cuda:0'
        dtype = torch.float16
    else:
        device = 'cpu'
        dtype = torch.float32

    clip_vision_model, clip_preprocess, clip_text_model, tokenizer = build_text_encoder(args)

    tokenizer.add_special_tokens({'additional_special_tokens': ["[$]"]})  # 49408

    # ours
    phi = Phi(input_dim=clip_text_model.config.projection_dim,
              hidden_dim=clip_text_model.config.projection_dim * 4,
              output_dim=clip_text_model.config.hidden_size, dropout=0.0)
    phi.eval()

    # searle
    phi_searle, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
                                   backbone='ViT-L/14')
    phi_searle.eval()

    # pic2word
    phi_pic2word = PIC2WORD(embed_dim=clip_text_model.config.projection_dim,
                            output_dim=clip_text_model.config.hidden_size)
    phi_pic2word.eval()

    clip_vision_model.to(device, dtype=dtype)
    clip_text_model.to(device, dtype=dtype)

    if not args.test_fps:
        # download and load sd
        if not os.path.exists('./pretrained_models/lincir_large.pt'):
            model_file_url = hf_hub_url(repo_id='navervision/zeroshot-cir-models', filename='lincir_large.pt')
            cached_download(model_file_url, cache_dir='./pretrained_models', force_filename='lincir_large.pt')
        state_dict = torch.load('./pretrained_models/lincir_large.pt', map_location=device)
        phi.load_state_dict(state_dict['Phi'])

        if not os.path.exists('./pretrained_models/pic2word_large.pt'):
            model_file_url = hf_hub_url(repo_id='navervision/zeroshot-cir-models', filename='pic2word_large.pt')
            cached_download(model_file_url, cache_dir='./pretrained_models', force_filename='pic2word_large.pt')
        sd = torch.load('./pretrained_models/pic2word_large.pt', map_location=device)['state_dict_img2text']
        sd = {k[len('module.'):]: v for k, v in sd.items()}
        phi_pic2word.load_state_dict(sd)

    phi.to(device, dtype=dtype)
    phi_searle.to(device, dtype=dtype)
    phi_pic2word.to(device, dtype=dtype)

    decoder = None

    return {'clip_vision_model': clip_vision_model,
            'clip_preprocess': clip_preprocess,
            'clip_text_model': clip_text_model,
            'tokenizer': tokenizer,
            'phi': phi,
            'phi_searle': phi_searle,
            'phi_pic2word': phi_pic2word,
            'decoder': decoder,
            'device': device,
            'dtype': dtype,
            'clip_model_name': args.clip_model_name,
            }


def predict(images, input_text, model_name):
    start_time = time.time()
    input_images = model_dict['clip_preprocess'](images, return_tensors='pt')['pixel_values'].to(model_dict['device'])
    input_text = input_text.replace('$', '[$]')
    input_tokens = model_dict['tokenizer'](text=input_text, return_tensors='pt', padding='max_length', truncation=True)['input_ids'].to(model_dict['device'])
    input_tokens = torch.where(input_tokens == 49408,
                               torch.ones_like(input_tokens) * 259,
                               input_tokens)
    image_features = model_dict['clip_vision_model'](pixel_values=input_images.to(model_dict['dtype'])).image_embeds
    clip_image_time = time.time() - start_time

    start_time = time.time()
    if model_name == 'lincir':
        estimated_token_embeddings = model_dict['phi'](image_features)
    elif model_name == 'searle':
        estimated_token_embeddings = model_dict['phi_searle'](image_features)
    else:  # model_name == 'pic2word'
        estimated_token_embeddings = model_dict['phi_pic2word'](image_features)
    phi_time = time.time() - start_time

    start_time = time.time()
    text_embeddings, text_last_hidden_states = encode_with_pseudo_tokens_HF(model_dict['clip_text_model'], input_tokens, estimated_token_embeddings, return_last_states=True)
    clip_text_time = time.time() - start_time

    start_time = time.time()
    results = client.query(embedding_input=text_embeddings[0].tolist())
    retrieval_time = time.time() - start_time

    output = ''

    for idx, result in enumerate(results):
        image_url = result['url']
        output += f'![image]({image_url})\n'

    time_output = {'CLIP visual extractor': clip_image_time,
                   'CLIP textual extractor': clip_text_time,
                   'Phi projection': phi_time,
                   'CLIP retrieval': retrieval_time,
                   }
    setup_output = {'device': model_dict['device'],
                    'dtype': model_dict['dtype'],
                    'Phi': model_name,
                    'CLIP': model_dict['clip_model_name'],
                    }

    return {'time': time_output, 'setup': setup_output}, output


def test_fps(batch_size=1):
    dummy_images = torch.rand([batch_size, 3, 224, 224])

    todo_list = ['phi', 'phi_pic2word']

    input_tokens = model_dict['tokenizer'](text=['a photo of $1 with flowers'] * batch_size, return_tensors='pt', padding='max_length', truncation=True)['input_ids'].to(model_dict['device'])
    input_tokens = torch.where(input_tokens == 49409,
                               torch.ones_like(input_tokens) * 259,
                               input_tokens)

    for model_name in todo_list:
        time_array = []
        n_repeat = 100
        for _ in range(n_repeat):
            start_time = time.time()
            image_features = model_dict['clip_vision_model'](pixel_values=dummy_images.to(model_dict['clip_vision_model'].device, dtype=model_dict['clip_vision_model'].dtype)).image_embeds
            token_embeddings = model_dict[model_name](image_features)
            text_embeddings = encode_with_pseudo_tokens_HF(model_dict['clip_text_model'], input_tokens, token_embeddings)
            end_time = time.time()
            if _ > 5:
                time_array.append(end_time - start_time)
        print(f"{model_name}: {np.mean(time_array):.4f}")


if __name__ == '__main__':
    args = parse_args()

    global model_dict, client

    model_dict = load_models(args)

    if args.test_fps:
        # check FPS of all models.
        test_fps(1)
        exit()

    client = ClipClient(url="https://knn.laion.ai/knn-service",
                        indice_name="laion5B-H-14" if args.clip_model_name == "huge" else "laion5B-L-14",
                        )

    title = 'Zeroshot CIR demo'

    md_title = f'''# {title}
[LinCIR](https://arxiv.org/abs/2312.01998): Language-only Training of Zero-shot Composed Image Retrieval
[SEARLE](https://arxiv.org/abs/2303.15247): Zero-shot Composed Image Retrieval with Textual Inversion
[Pic2Word](https://arxiv.org/abs/2302.03084): Mapping Pictures to Words for Zero-shot Composed Image Retrieval

The K-NN index for the retrieval results is built over the entire LAION-5B image set. This is made possible thanks to the great work of [rom1504](https://github.com/rom1504/clip-retrieval).
'''

    with gr.Blocks(title=title) as demo:
        gr.Markdown(md_title)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_source = gr.Image(type='pil', label='image1')
                model_name = gr.Radio(['lincir', 'searle', 'pic2word'], label='Phi model', value='lincir')
                text_input = gr.Textbox(value='', label='Input text guidance. Special token is $')
                submit_button = gr.Button('Submit')
                gr.Examples([["example1.jpg", "$, pencil sketch", 'lincir']], inputs=[image_source, text_input, model_name])
            with gr.Column():
                json_output = gr.JSON(label='Processing time')
                md_output = gr.Markdown(label='Output')

        submit_button.click(predict, inputs=[image_source, text_input, model_name], outputs=[json_output, md_output])

    demo.queue()

    demo.launch()
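A note on the retrieval backend used by app.py above: the modified text query is embedded and sent to the public LAION k-NN service through `clip_retrieval.ClipClient`. Below is a minimal, hedged sketch of that retrieval call on its own, assuming `clip-retrieval` is installed and the service is reachable; the plain-text query is an illustrative simplification, since the app queries with a precomputed embedding via `embedding_input`.

# Illustrative sketch only: query the same LAION-5B k-NN backend that app.py uses.
from clip_retrieval.clip_client import ClipClient

client = ClipClient(url="https://knn.laion.ai/knn-service",
                    indice_name="laion5B-L-14",  # app.py switches to laion5B-H-14 for the 'huge' CLIP
                    num_images=10)
results = client.query(text="a pencil sketch of a cat")
for result in results:
    # each result is a dict that includes an image 'url'
    print(result["url"])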
data_utils.py
ADDED
@@ -0,0 +1,67 @@
from pathlib import Path

import PIL
import torch
import torchvision.transforms.functional as FT
from torch.utils.data import Dataset
from torchvision.transforms import Compose, CenterCrop, ToTensor, Normalize, Resize
from torchvision.transforms import InterpolationMode

PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def collate_fn(batch):
    '''
    function which discards None images in a batch when using torch DataLoader
    :param batch: input_batch
    :return: output_batch = input_batch - None_values
    '''
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)


class TargetPad:
    """
    If an image aspect ratio is above a target ratio, pad the image to match such target ratio.
    For more details see Baldrati et al. 'Effective conditioned and composed image retrieval combining clip-based features.' Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022).
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: target ratio
        :param size: preprocessing output dimension
        """
        self.size = size
        self.target_ratio = target_ratio

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        w, h = image.size
        actual_ratio = max(w, h) / min(w, h)
        if actual_ratio < self.target_ratio:  # check if the ratio is above or below the target ratio
            return image
        scaled_max_wh = max(w, h) / self.target_ratio  # rescale the pad to match the target ratio
        hp = max(int((scaled_max_wh - w) / 2), 0)
        vp = max(int((scaled_max_wh - h) / 2), 0)
        padding = [hp, vp, hp, vp]
        return FT.pad(image, padding, 0, 'constant')


def targetpad_transform(target_ratio: float, dim: int) -> torch.Tensor:
    """
    CLIP-like preprocessing transform computed after using TargetPad pad
    :param target_ratio: target ratio for TargetPad
    :param dim: image output dimension
    :return: CLIP-like torchvision Compose transform
    """
    return Compose([
        TargetPad(target_ratio, dim),
        Resize(dim, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(dim),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
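A hedged usage sketch of the transform above: the 1.25 target ratio and 224-pixel output match the defaults used elsewhere in this repo, while the image path `example.jpg` is a made-up placeholder.

# Minimal usage sketch for targetpad_transform (illustrative only).
import PIL.Image
from data_utils import targetpad_transform

preprocess = targetpad_transform(target_ratio=1.25, dim=224)
image = PIL.Image.open('example.jpg')   # hypothetical input image
tensor = preprocess(image)              # 3 x 224 x 224 normalized tensor
print(tensor.shape)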
encode_with_pseudo_tokens.py
ADDED
@@ -0,0 +1,54 @@
'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import torch
from clip.model import CLIP
from transformers import CLIPTextModelWithProjection


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    Copy-paste from https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/clip/modeling_clip.py#L679-L693
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def encode_with_pseudo_tokens_HF(clip_model: CLIPTextModelWithProjection, text: torch.Tensor, pseudo_tokens: torch.Tensor,
                                 num_tokens=1, return_last_states=False) -> torch.Tensor:
    x = clip_model.text_model.embeddings.token_embedding(text).type(clip_model.dtype)  # [batch_size, n_ctx, d_model]
    x = torch.where(text.unsqueeze(-1) == 259,
                    pseudo_tokens.unsqueeze(1).type(clip_model.dtype),
                    x)
    x = x + clip_model.text_model.embeddings.position_embedding(clip_model.text_model.embeddings.position_ids)
    _causal_attention_mask = _make_causal_mask(text.shape, x.dtype, device=x.device)
    x = clip_model.text_model.encoder(inputs_embeds=x,
                                      attention_mask=None,
                                      causal_attention_mask=_causal_attention_mask,
                                      output_attentions=False,
                                      output_hidden_states=False,
                                      return_dict=False)
    x = x[0]
    x_last = clip_model.text_model.final_layer_norm(x)
    x = x_last[torch.arange(x_last.shape[0], device=x_last.device),
               text.to(dtype=torch.int, device=x_last.device).argmax(dim=-1),
               ]
    if hasattr(clip_model, 'text_projection'):
        x = clip_model.text_projection(x)

    if return_last_states:
        return x, x_last
    else:
        return x
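The key step in `encode_with_pseudo_tokens_HF` is the `torch.where` swap: wherever token id 259 appears in the tokenized prompt, the looked-up word embedding is replaced by the projected image feature before the text encoder runs. A self-contained sketch of just that substitution, using a toy embedding table (the dimensions and token ids are illustrative, not the repo's values):

# Toy illustration of the pseudo-token substitution used above (not the repo's API).
import torch

vocab_size, dim = 1000, 8
token_embedding = torch.nn.Embedding(vocab_size, dim)

text = torch.tensor([[49, 259, 7, 0]])          # one prompt; id 259 marks the pseudo-token slot
pseudo_token = torch.randn(1, dim)              # e.g. the output of Phi(image_features)

x = token_embedding(text)                       # [1, seq_len, dim]
x = torch.where(text.unsqueeze(-1) == 259,      # mask broadcasts over the embedding dimension
                pseudo_token.unsqueeze(1),      # [1, 1, dim] broadcasts over sequence positions
                x)
print(x.shape)  # torch.Size([1, 4, 8])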
eval_templates.py
ADDED
@@ -0,0 +1,70 @@
'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
templates = [
    lambda caption: f"a photo of $ that {caption}",
    lambda caption: f"$ that {caption}",
    lambda caption: f"$ with {caption}",
    lambda caption: f"$ , {caption}",
    lambda caption: f"$ adapted to {caption}",
    lambda caption: f"$ modified by {caption}",
    lambda caption: f"$ in response to {caption}",
    lambda caption: f"$ transformed by {caption}",
    lambda caption: f"$ influenced by {caption}",
    lambda caption: f"Retrieval of $ using feedback {caption}",
    lambda caption: f"$ guided by {caption}",
    lambda caption: f"$ adjusted to {caption}",
    lambda caption: f"$ in alignment with {caption}",
    lambda caption: f"$ in correspondence to {caption}",
    lambda caption: f"$ refined with {caption}",
    lambda caption: f"$ as directed by {caption}",
    lambda caption: f"$ evolved from {caption}",
    lambda caption: f"$ inspired by {caption}",
    lambda caption: f"$ with adjustments from {caption}",
    lambda caption: f"$ in consideration of {caption}",
    lambda caption: f"$ , taking into account {caption}",
    lambda caption: f"$ as influenced by the query {caption}",
    lambda caption: f"$ reshaped by {caption}",
    lambda caption: f"$ curated based on {caption}",
    lambda caption: f"$ showcasing {caption}",
    lambda caption: f"An instance of $ where {caption}",
    lambda caption: f"$ highlighting {caption}",
    lambda caption: f"A depiction of $ exhibiting {caption}",
    lambda caption: f"$ as exemplified by {caption}",
    lambda caption: f"$ demonstrating {caption}",
    lambda caption: f"An illustration of $ portraying {caption}",
    lambda caption: f"$ in the context of {caption}",
    lambda caption: f"$ as influenced by {caption}",
    lambda caption: f"$ characterized by {caption}",
    lambda caption: f"$ : An exploration of {caption}",
    lambda caption: f"A presentation of $ underlined by {caption}",
    lambda caption: f"A manifestation of $ reflecting {caption}",
    lambda caption: f"$ in light of {caption}",
    lambda caption: f"$ as a testament to {caption}",
    lambda caption: f"$ intertwined with {caption}",
    lambda caption: f"$ complemented by {caption}",
    lambda caption: f"$ juxtaposed with {caption}",
    lambda caption: f"A representation of $ in relation to {caption}",
    lambda caption: f"$ that {caption}",
    lambda caption: f"$ which {caption}",
    lambda caption: f"$ where it {caption}",
    lambda caption: f"Discover $ that {caption}",
    lambda caption: f"Retrieve $ that {caption}",
    lambda caption: f"Search for $ that {caption}",
    lambda caption: f"Identify $ which {caption}",
    lambda caption: f"Highlight $ that {caption}",
    lambda caption: f"Present $ where it {caption}",
    lambda caption: f"Showcase $ that {caption}",
    lambda caption: f"Explore $ which {caption}",
    lambda caption: f"Find $ that {caption}",
    lambda caption: f"Source $ which {caption}",
    lambda caption: f"View $ where it {caption}",
    lambda caption: f"Examine $ that {caption}",
    lambda caption: f"Analyze $ which {caption}",
    lambda caption: f"Observe $ that {caption}",
    lambda caption: f"Report $ which {caption}",
    lambda caption: f"See $ where it {caption}",
    lambda caption: f"Document $ that {caption}"
]
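Each entry above is a plain lambda that splices a relative caption around the `$` placeholder; downstream code then maps `$` to the special `[$]` token before tokenization, as in `predict` in app.py. A minimal sketch (the caption is made up):

# Illustrative use of one evaluation template (values are made up).
from eval_templates import templates

caption = "is wearing a red hat"
prompt = templates[0](caption)          # "a photo of $ that is wearing a red hat"
prompt = prompt.replace('$', '[$]')     # mark the pseudo-token slot, as app.py does
print(prompt)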
generate_test_submission.py
ADDED
@@ -0,0 +1,363 @@
import os
import json
import pickle
from argparse import ArgumentParser
from typing import List, Tuple, Dict

import clip
import numpy as np
import torch
import torch.nn.functional as F
from clip.model import CLIP
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utils import PROJECT_ROOT, targetpad_transform
from loader import CIRRDataset, CIRCODataset
from encode_with_pseudo_tokens import encode_with_pseudo_tokens, encode_with_pseudo_tokens_HF
from models import build_text_encoder, Phi, PIC2WORD
from utils import extract_image_features, device, collate_fn, extract_pseudo_tokens_with_phi


@torch.no_grad()
def cirr_generate_test_submission_file(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str],
                                        pseudo_tokens: torch.Tensor, preprocess: callable, submission_name: str) -> None:
    """
    Generate the test submission file for the CIRR dataset given the pseudo tokens
    """

    # Load the CLIP model
    #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
    #clip_model = clip_model.float().eval()

    # Compute the index features
    classic_test_dataset = CIRRDataset(dataset_path, 'test1', 'classic', preprocess)
    index_features, index_names = extract_image_features(classic_test_dataset, image_encoder)

    relative_test_dataset = CIRRDataset(dataset_path, 'test1', 'relative', preprocess)

    # Get the predictions dicts
    pairid_to_retrieved_images, pairid_to_group_retrieved_images = \
        cirr_generate_test_dicts(relative_test_dataset, text_encoder, index_features, index_names,
                                 ref_names_list, pseudo_tokens)

    submission = {
        'version': 'rc2',
        'metric': 'recall'
    }
    group_submission = {
        'version': 'rc2',
        'metric': 'recall_subset'
    }

    submission.update(pairid_to_retrieved_images)
    group_submission.update(pairid_to_group_retrieved_images)

    submissions_folder_path = os.path.join('./submission', 'cirr')
    os.makedirs(submissions_folder_path, exist_ok=True)

    with open(os.path.join(submissions_folder_path, f"{submission_name}.json"), 'w+') as file:
        json.dump(submission, file, sort_keys=True)

    with open(os.path.join(submissions_folder_path, f"subset_{submission_name}.json"), 'w+') as file:
        json.dump(group_submission, file, sort_keys=True)


def cirr_generate_test_dicts(relative_test_dataset: CIRRDataset, clip_model, index_features: torch.Tensor,
                             index_names: List[str], ref_names_list: List[str], pseudo_tokens: List[str]) \
        -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Generate the test submission dicts for the CIRR dataset given the pseudo tokens
    """

    # Get the predicted features
    predicted_features, reference_names, pairs_id, group_members = \
        cirr_generate_test_predictions(clip_model, relative_test_dataset, ref_names_list, pseudo_tokens)

    print(f"Compute CIRR prediction dicts")

    # Normalize the index features
    index_features = index_features.to(device)
    index_features = F.normalize(index_features, dim=-1).float()

    # Compute the distances and sort the results
    distances = 1 - predicted_features @ index_features.T
    sorted_indices = torch.argsort(distances, dim=-1).cpu()
    sorted_index_names = np.array(index_names)[sorted_indices]

    # Delete the reference image from the results
    reference_mask = torch.tensor(
        sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names),
                                                                                              -1))
    sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
                                                                    sorted_index_names.shape[1] - 1)
    # Compute the subset predictions
    group_members = np.array(group_members)
    group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
    sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1)

    # Generate prediction dicts
    pairid_to_retrieved_images = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in
                                  zip(pairs_id, sorted_index_names)}
    pairid_to_group_retrieved_images = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in
                                        zip(pairs_id, sorted_group_names)}

    return pairid_to_retrieved_images, pairid_to_group_retrieved_images


def cirr_generate_test_predictions(clip_model, relative_test_dataset: CIRRDataset, ref_names_list: List[str],
                                   pseudo_tokens: torch.Tensor) -> \
        Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
    """
    Generate the test prediction features for the CIRR dataset given the pseudo tokens
    """

    # Create the test dataloader
    relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, num_workers=10,
                                      pin_memory=False)

    predicted_features_list = []
    reference_names_list = []
    pair_id_list = []
    group_members_list = []

    # Compute the predictions
    for batch in tqdm(relative_test_loader):
        reference_names = batch['reference_name']
        pairs_id = batch['pair_id']
        relative_captions = batch['relative_caption']
        group_members = batch['group_members']

        group_members = np.array(group_members).T.tolist()

        input_captions = [
            f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]

        batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
        tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
        text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)

        predicted_features = F.normalize(text_features)

        predicted_features_list.append(predicted_features)
        reference_names_list.extend(reference_names)
        pair_id_list.extend(pairs_id)
        group_members_list.extend(group_members)

    predicted_features = torch.vstack(predicted_features_list)

    return predicted_features, reference_names_list, pair_id_list, group_members_list


@torch.no_grad()
def circo_generate_test_submission_file(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str],
                                         pseudo_tokens: torch.Tensor, preprocess: callable,
                                         submission_name: str) -> None:
    """
    Generate the test submission file for the CIRCO dataset given the pseudo tokens
    """

    # Load the CLIP model
    #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
    #clip_model = clip_model.float().eval().requires_grad_(False)

    # Compute the index features
    classic_test_dataset = CIRCODataset(dataset_path, 'test', 'classic', preprocess)
    index_features, index_names = extract_image_features(classic_test_dataset, image_encoder)

    relative_test_dataset = CIRCODataset(dataset_path, 'test', 'relative', preprocess)

    # Get the predictions dict
    queryid_to_retrieved_images = circo_generate_test_dict(relative_test_dataset, text_encoder, index_features,
                                                           index_names, ref_names_list, pseudo_tokens)

    submissions_folder_path = os.path.join('./submission', 'circo')
    os.makedirs(submissions_folder_path, exist_ok=True)

    with open(os.path.join(submissions_folder_path, f"{submission_name}.json"), 'w+') as file:
        json.dump(queryid_to_retrieved_images, file, sort_keys=True)


def circo_generate_test_predictions(clip_model, relative_test_dataset: CIRCODataset, ref_names_list: List[str],
                                    pseudo_tokens: torch.Tensor) -> [torch.Tensor, List[List[str]]]:
    """
    Generate the test prediction features for the CIRCO dataset given the pseudo tokens
    """

    # Create the test dataloader
    relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, num_workers=10,
                                      pin_memory=False, collate_fn=collate_fn, shuffle=False)

    predicted_features_list = []
    query_ids_list = []

    # Compute the predictions
    for batch in tqdm(relative_test_loader):
        reference_names = batch['reference_name']
        relative_captions = batch['relative_caption']
        query_ids = batch['query_id']

        input_captions = [f"a photo of $ that {caption}" for caption in relative_captions]
        batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
        tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
        text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
        predicted_features = F.normalize(text_features)

        predicted_features_list.append(predicted_features)
        query_ids_list.extend(query_ids)

    predicted_features = torch.vstack(predicted_features_list)
    return predicted_features, query_ids_list


def circo_generate_test_dict(relative_test_dataset: CIRCODataset, clip_model, index_features: torch.Tensor,
                             index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
        -> Dict[str, List[str]]:
    """
    Generate the test submission dicts for the CIRCO dataset given the pseudo tokens
    """

    # Get the predicted features
    predicted_features, query_ids = circo_generate_test_predictions(clip_model, relative_test_dataset,
                                                                    ref_names_list, pseudo_tokens)

    # Normalize the features
    index_features = index_features.float().to(device)
    index_features = F.normalize(index_features, dim=-1)

    # Compute the similarity
    similarity = predicted_features @ index_features.T
    sorted_indices = torch.topk(similarity, dim=-1, k=50).indices.cpu()
    sorted_index_names = np.array(index_names)[sorted_indices]

    # Generate prediction dicts
    queryid_to_retrieved_images = {query_id: query_sorted_names[:50].tolist() for
                                   (query_id, query_sorted_names) in zip(query_ids, sorted_index_names)}

    return queryid_to_retrieved_images


def main():
    parser = ArgumentParser()
    parser.add_argument("--submission-name", type=str, required=True, help="Filename of the generated submission file")
    parser.add_argument("--exp-name", type=str, help="Experiment to evaluate")
    parser.add_argument("--dataset", type=str, required=True, choices=['cirr', 'circo'], help="Dataset to use")
    parser.add_argument("--dataset-path", type=str, help="Path to the dataset", required=True)
    parser.add_argument("--eval-type", type=str, choices=['oti', 'phi', 'searle', 'searle-xl', 'pic2word'], required=True,
                        help="If 'oti' evaluate directly using the inverted oti pseudo tokens, "
                             "if 'phi' predicts the pseudo tokens using the phi network, "
                             "if 'searle' uses the pre-trained SEARLE model to predict the pseudo tokens, "
                             "if 'searle-xl' uses the pre-trained SEARLE-XL model to predict the pseudo tokens")

    parser.add_argument("--preprocess-type", default="clip", type=str, choices=['clip', 'targetpad'],
                        help="Preprocess pipeline to use")
    parser.add_argument("--phi-checkpoint-name", type=str,
                        help="Phi checkpoint to use, needed when using phi, e.g. 'phi_20.pt'")

    parser.add_argument("--clip_model_name", default="giga", type=str)
    parser.add_argument("--cache_dir", default="./hf_models", type=str)

    parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")

    args = parser.parse_args()

    if args.eval_type == 'oti':
        experiment_path = PROJECT_ROOT / 'data' / "oti_pseudo_tokens" / args.dataset.lower() / 'test' / args.exp_name

        with open(experiment_path / 'hyperparameters.json') as f:
            hyperparameters = json.load(f)

        pseudo_tokens = torch.load(experiment_path / 'ema_oti_pseudo_tokens.pt', map_location=device)
        with open(experiment_path / 'image_names.pkl', 'rb') as f:
            ref_names_list = pickle.load(f)

        clip_model_name = hyperparameters['clip_model_name']
        clip_model, clip_preprocess = clip.load(clip_model_name, device='cpu', jit=False)

        if args.preprocess_type == 'targetpad':
            print('Target pad preprocess pipeline is used')
            preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
        elif args.preprocess_type == 'clip':
            print('CLIP preprocess pipeline is used')
            preprocess = clip_preprocess
        else:
            raise ValueError("Preprocess type not supported")

    elif args.eval_type in ['phi', 'searle', 'searle-xl', 'pic2word']:
        if args.eval_type == 'phi':
            args.mixed_precision = 'fp16'
            image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)

            phi = Phi(input_dim=text_encoder.config.projection_dim,
                      hidden_dim=text_encoder.config.projection_dim * 4,
                      output_dim=text_encoder.config.hidden_size, dropout=0.5).to(
                device)

            phi.load_state_dict(
                torch.load(args.phi_checkpoint_name, map_location=device)[
                    phi.__class__.__name__])
            phi = phi.eval()

        elif args.eval_type == 'pic2word':
            args.mixed_precision = 'fp16'
            image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)

            phi = PIC2WORD(embed_dim=text_encoder.config.projection_dim,
                           output_dim=text_encoder.config.hidden_size,
                           ).to(device)
            sd = torch.load(args.phi_checkpoint_name, map_location=device)['state_dict_img2text']
            sd = {k[len('module.'):]: v for k, v in sd.items()}
            phi.load_state_dict(sd)
            phi = phi.eval()

        else:  # searle or searle-xl
            if args.eval_type == 'searle':
                clip_model_name = 'ViT-B/32'
            else:  # args.eval_type == 'searle-xl':
                clip_model_name = 'ViT-L/14'
            phi, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
                                    backbone=clip_model_name)

            phi = phi.to(device).eval()
            clip_model, clip_preprocess = clip.load(clip_model_name, device=device, jit=False)

        if args.preprocess_type == 'targetpad':
            print('Target pad preprocess pipeline is used')
            preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
        elif args.preprocess_type == 'clip':
            print('CLIP preprocess pipeline is used')
            preprocess = clip_preprocess
        else:
            raise ValueError("Preprocess type not supported")

        if args.dataset.lower() == 'cirr':
            relative_test_dataset = CIRRDataset(args.dataset_path, 'test', 'relative', preprocess, no_duplicates=True)
        elif args.dataset.lower() == 'circo':
            relative_test_dataset = CIRCODataset(args.dataset_path, 'test', 'relative', preprocess)
        else:
            raise ValueError("Dataset not supported")

        #clip_model = clip_model.float().to(device)
        image_encoder = image_encoder.float().to(device)
        text_encoder = text_encoder.float().to(device)
        pseudo_tokens, ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi, relative_test_dataset, args)
        pseudo_tokens = pseudo_tokens.to(device)
    else:
        raise ValueError("Eval type not supported")

    print(f"Eval type = {args.eval_type} \t exp name = {args.exp_name} \t")

    if args.dataset == 'cirr':
        cirr_generate_test_submission_file(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
                                           preprocess, args.submission_name)
    elif args.dataset == 'circo':
        circo_generate_test_submission_file(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
                                            preprocess, args.submission_name)
    else:
        raise ValueError("Dataset not supported")


if __name__ == '__main__':
    main()
loader.py
ADDED
@@ -0,0 +1,632 @@
1 |
+
'''
|
2 |
+
LinCIR
|
3 |
+
Copyright (c) 2023-present NAVER Corp.
|
4 |
+
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
|
5 |
+
'''
|
6 |
+
import os
|
7 |
+
import functools
|
8 |
+
import glob
|
9 |
+
import random
|
10 |
+
import json
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import List, Optional, Union, Dict, Literal
|
13 |
+
import PIL
|
14 |
+
import PIL.Image
|
15 |
+
import torch
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
import webdataset as wds
|
18 |
+
import spacy
|
19 |
+
import numpy as np
|
20 |
+
import sng_parser
|
21 |
+
import datasets
|
22 |
+
|
23 |
+
|
24 |
+
def extract_keywords(spacy_nlp, caption):
|
25 |
+
candidates = []
|
26 |
+
nlp_caption = caption
|
27 |
+
|
28 |
+
doc = spacy_nlp(nlp_caption)
|
29 |
+
|
30 |
+
tmp = ''
|
31 |
+
for word in doc:
|
32 |
+
if word.pos_ == 'ADJ':
|
33 |
+
if tmp == '':
|
34 |
+
tmp += word.text
|
35 |
+
else:
|
36 |
+
tmp += ' ' + word.text
|
37 |
+
elif word.pos_ == 'NOUN' or word.pos_ == 'PROPN':
|
38 |
+
if tmp == '':
|
39 |
+
tmp += word.text
|
40 |
+
else:
|
41 |
+
tmp += ' ' + word.text
|
42 |
+
else:
|
43 |
+
if tmp != '':
|
44 |
+
candidates.append(tmp)
|
45 |
+
tmp = ''
|
46 |
+
if tmp != '':
|
47 |
+
candidates.append(tmp)
|
48 |
+
|
49 |
+
candidates = list(set(candidates))
|
50 |
+
|
51 |
+
return candidates
|
52 |
+
|
53 |
+
|
54 |
+
def extract_keywords_spacy(spacy_nlp, caption):
|
55 |
+
sequences = []
|
56 |
+
current_sequence = []
|
57 |
+
doc = spacy_nlp(caption)
|
58 |
+
for token in doc:
|
59 |
+
# Check if the token is a noun, proper noun, or adjective
|
60 |
+
if token.pos_ in ['NOUN', 'PROPN', 'ADJ', 'DET']:
|
61 |
+
current_sequence.append(token.text)
|
62 |
+
else:
|
63 |
+
# If we encounter a token that's not one of the desired POS and current_sequence is not empty
|
64 |
+
if current_sequence:
|
65 |
+
sequences.append(" ".join(current_sequence))
|
66 |
+
current_sequence = []
|
67 |
+
|
68 |
+
# Adding any remaining sequence after the loop
|
69 |
+
if current_sequence:
|
70 |
+
sequences.append(" ".join(current_sequence))
|
71 |
+
|
72 |
+
return sequences
|
73 |
+
|
74 |
+
|
75 |
+
def extract_sng(caption):
|
76 |
+
graph = sng_parser.parse(caption)
|
77 |
+
entities = [x['head'] for i, x in enumerate(graph['entities'])]
|
78 |
+
relations = [{'subject': entities[x['subject']], 'object': entities[x['object']], 'relation': x['relation']} for x in graph['relations']]
|
79 |
+
return entities, relations
|
80 |
+
|
81 |
+
|
82 |
+
def clean_caption(caption, tokenizer):
|
83 |
+
if caption is None:
|
84 |
+
caption = ''
|
85 |
+
if '<PERSON>' in caption: # to handle with GCC12M
|
86 |
+
caption = caption.replace('<PERSON>', 'person')
|
87 |
+
caption = caption.lower().replace('$', '').strip()
|
88 |
+
tokens = tokenizer.encode(caption, padding='longest', return_tensors='pt')
|
89 |
+
if tokens.shape[1] > 77:
|
90 |
+
caption = tokenizer.batch_decode(tokens[:,1:76])[0]
|
91 |
+
return caption
|
92 |
+
|
93 |
+
|
94 |
+
def preprocess_precomputed_base(sample, spacy_nlp, keywords_list, tokenizer):
|
95 |
+
'''
|
96 |
+
'image_feature.npy','json'
|
97 |
+
'''
|
98 |
+
image_feature, image_feature_giga, meta = sample
|
99 |
+
|
100 |
+
caption = clean_caption(meta['source_caption'], tokenizer)
|
101 |
+
|
102 |
+
keywords = ['']
|
103 |
+
try:
|
104 |
+
keywords = extract_keywords_spacy(spacy_nlp, caption)
|
105 |
+
except Exception as e:
|
106 |
+
#print(e)
|
107 |
+
pass
|
108 |
+
|
109 |
+
# for keywords
|
110 |
+
indicator = 1
|
111 |
+
replaced_caption = caption
|
112 |
+
for keyword in keywords:
|
113 |
+
if keyword != '' and keyword in caption:
|
114 |
+
replaced_caption = replaced_caption.replace(keyword, '[$]')
|
115 |
+
else:
|
116 |
+
tmp_keywords = caption.split(' ')
|
117 |
+
if len(tmp_keywords) > 0:
|
118 |
+
selected_keywords = random.sample(tmp_keywords, k=min(int(len(tmp_keywords) * 1.0), 1))
|
119 |
+
for selected_keyword in selected_keywords:
|
120 |
+
replaced_caption = replaced_caption.replace(selected_keyword, '[$]')
|
121 |
+
else:
|
122 |
+
replaced_caption = f'a photo of [$] that {caption}'
|
123 |
+
indicator = 0
|
124 |
+
break
|
125 |
+
|
126 |
+
token_dict = tokenizer(text=caption, return_tensors='pt', padding='max_length', truncation=True)
|
127 |
+
tokens, attention_mask = token_dict['input_ids'][0], token_dict['attention_mask'][0]
|
128 |
+
|
129 |
+
replaced_token_dict = tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
|
130 |
+
replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
|
131 |
+
|
132 |
+
replaced_tokens = torch.where(replaced_tokens == 49408,
|
133 |
+
torch.ones_like(replaced_tokens) * 259,
|
134 |
+
replaced_tokens)
|
135 |
+
|
136 |
+
if 259 not in replaced_tokens:
|
137 |
+
replaced_caption = 'a photo of [$]'
|
138 |
+
replaced_token_dict = tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
|
139 |
+
replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
|
140 |
+
|
141 |
+
replaced_tokens = torch.where(replaced_tokens == 49408,
|
142 |
+
torch.ones_like(replaced_tokens) * 259,
|
143 |
+
replaced_tokens)
|
144 |
+
indicator = 0
|
145 |
+
|
146 |
+
new_sample = [tokens, replaced_tokens, indicator]
|
147 |
+
|
148 |
+
return tuple(new_sample)
|
149 |
+
|
150 |
+
|
151 |
+
class CaptionDataset(Dataset):
|
152 |
+
def __init__(self, captions, tokenizer, spacy_nlp):
|
153 |
+
self.captions = captions
|
154 |
+
self.tokenizer = tokenizer
|
155 |
+
self.spacy_nlp = spacy_nlp
|
156 |
+
|
157 |
+
def __len__(self):
|
158 |
+
return len(self.captions)
|
159 |
+
|
160 |
+
def __getitem__(self, idx):
|
161 |
+
caption = self.captions[idx]
|
162 |
+
|
163 |
+
caption = clean_caption(caption, self.tokenizer)
|
164 |
+
|
165 |
+
keywords = [""]
|
166 |
+
try:
|
167 |
+
keywords = extract_keywords_spacy(self.spacy_nlp, caption)
|
168 |
+
except Exception as e:
|
169 |
+
#print(e)
|
170 |
+
pass
|
171 |
+
|
172 |
+
# for keywords
|
173 |
+
indicator = 1
|
174 |
+
replaced_caption = caption
|
175 |
+
|
176 |
+
if len(keywords) == 0:
|
177 |
+
keywords = [""]
|
178 |
+
|
179 |
+
for keyword in keywords:
|
180 |
+
if keyword != '' and keyword in caption:
|
181 |
+
replaced_caption = replaced_caption.replace(keyword, '[$]')
|
182 |
+
else:
|
183 |
+
tmp_keywords = caption.split(' ')
|
184 |
+
if len(tmp_keywords) > 0:
|
185 |
+
selected_keywords = random.sample(tmp_keywords, k=min(int(len(tmp_keywords) * 1.0), 1))
|
186 |
+
for selected_keyword in selected_keywords:
|
187 |
+
replaced_caption = replaced_caption.replace(selected_keyword, '[$]')
|
188 |
+
else:
|
189 |
+
replaced_caption = f'a photo of [$] that {caption}'
|
190 |
+
indicator = 0
|
191 |
+
break
|
192 |
+
|
193 |
+
token_dict = self.tokenizer(text=caption, return_tensors='pt', padding='max_length', truncation=True)
|
194 |
+
tokens, attention_mask = token_dict['input_ids'][0], token_dict['attention_mask'][0]
|
195 |
+
|
196 |
+
replaced_token_dict = self.tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
|
197 |
+
replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
|
198 |
+
|
199 |
+
replaced_tokens = torch.where(replaced_tokens == 49408,
|
200 |
+
torch.ones_like(replaced_tokens) * 259,
|
201 |
+
replaced_tokens)
|
202 |
+
|
203 |
+
if 259 not in replaced_tokens:
|
204 |
+
replaced_caption = 'a photo of [$]'
|
205 |
+
replaced_token_dict = self.tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
|
206 |
+
replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
|
207 |
+
|
208 |
+
replaced_tokens = torch.where(replaced_tokens == 49408,
|
209 |
+
torch.ones_like(replaced_tokens) * 259,
|
210 |
+
replaced_tokens)
|
211 |
+
indicator = 0
|
212 |
+
|
213 |
+
return tokens, replaced_tokens, indicator
|
214 |
+
|
215 |
+
|
216 |
+
def build_loader(args, tokenizer, accelerator):
|
217 |
+
data_names = {'dataset1': 'dangne/gcc_caption_only',
|
218 |
+
'dataset2': 'FredZhang7/stable-diffusion-prompts-2.47M',
|
219 |
+
'dataset3': 'Geonmo/midjourney-prompts-only',
|
220 |
+
}
|
221 |
+
|
222 |
+
for k, v in data_names.items():
|
223 |
+
if not os.path.exists(os.path.join('./datasets', k)):
|
224 |
+
if accelerator.is_main_process:
|
225 |
+
print('Downloading captions is required')
|
226 |
+
db = datasets.load_dataset(v, cache_dir=os.path.join('./datasets', k))
|
227 |
+
|
228 |
+
captions = []
|
229 |
+
for k, v in data_names.items():
|
230 |
+
db = datasets.load_dataset(v, cache_dir=os.path.join('./datasets', k))
|
231 |
+
captions += db['train']['text']
|
232 |
+
|
233 |
+
dataset = CaptionDataset(captions, tokenizer, spacy.load('en_core_web_sm'))
|
234 |
+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True, shuffle=True)
|
235 |
+
|
236 |
+
return data_loader
|
237 |
+
|
238 |
+
|
239 |
+
class FashionIQDataset(Dataset):
|
240 |
+
"""
|
241 |
+
Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
|
242 |
+
FashionIQ dataset class for PyTorch.
|
243 |
+
The dataset can be used in 'relative' or 'classic' mode:
|
244 |
+
- In 'classic' mode the dataset yield :a dict with keys ['image', 'image_name']
|
245 |
+
- In 'relative' mode the dataset yield dict with keys:
|
246 |
+
- ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions'] when
|
247 |
+
split in ['train', 'val']
|
248 |
+
- ['reference_image', 'reference_name', 'relative_captions'] when split == test
|
249 |
+
"""
|
250 |
+
|
251 |
+
def __init__(self, dataset_path: Union[Path, str], split: Literal['train', 'val', 'test'], dress_types: List[str],
|
252 |
+
mode: Literal['relative', 'classic'], preprocess: callable, no_duplicates: Optional[bool] = False):
|
253 |
+
"""
|
254 |
+
:param dataset_path: path to the FashionIQ dataset
|
255 |
+
:param split: dataset split, should be in ['train, 'val', 'test']
|
256 |
+
:param dress_types: list of fashionIQ categories, each category should be in ['dress', 'shirt', 'toptee']
|
257 |
+
:param mode: dataset mode, should be in ['relative', 'classic']:
|
258 |
+
- In 'classic' mode the dataset yield a dict with keys ['image', 'image_name']
|
259 |
+
- In 'relative' mode the dataset yield dict with keys:
|
260 |
+
- ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions']
|
261 |
+
when split in ['train', 'val']
|
262 |
+
- ['reference_image', 'reference_name', 'relative_captions'] when split == test
|
263 |
+
:param preprocess: function which preprocesses the image
|
264 |
+
:param no_duplicates: if True, the dataset will not yield duplicate images in relative mode, does not affect classic mode
|
265 |
+
"""
|
266 |
+
dataset_path = Path(dataset_path)
|
267 |
+
self.dataset_path = dataset_path
|
268 |
+
self.mode = mode
|
269 |
+
self.dress_types = dress_types
|
270 |
+
self.split = split
|
271 |
+
self.no_duplicates = no_duplicates
|
272 |
+
|
273 |
+
# Validate the inputs
|
274 |
+
if mode not in ['relative', 'classic']:
|
275 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
276 |
+
if split not in ['test', 'train', 'val']:
|
277 |
+
raise ValueError("split should be in ['test', 'train', 'val']")
|
278 |
+
for dress_type in dress_types:
|
279 |
+
if dress_type not in ['dress', 'shirt', 'toptee']:
|
280 |
+
raise ValueError("dress_type should be in ['dress', 'shirt', 'toptee']")
|
281 |
+
|
282 |
+
self.preprocess = preprocess
|
283 |
+
|
284 |
+
# get triplets made by (reference_image, target_image, a pair of relative captions)
|
285 |
+
self.triplets: List[dict] = []
|
286 |
+
for dress_type in dress_types:
|
287 |
+
with open(dataset_path / 'captions' / f'cap.{dress_type}.{split}.json') as f:
|
288 |
+
self.triplets.extend(json.load(f))
|
289 |
+
|
290 |
+
# Remove duplicats from
|
291 |
+
if self.no_duplicates:
|
292 |
+
seen = set()
|
293 |
+
new_triplets = []
|
294 |
+
for triplet in self.triplets:
|
295 |
+
if triplet['candidate'] not in seen:
|
296 |
+
seen.add(triplet['candidate'])
|
297 |
+
new_triplets.append(triplet)
|
298 |
+
self.triplets = new_triplets
|
299 |
+
|
300 |
+
# get the image names
|
301 |
+
self.image_names: list = []
|
302 |
+
for dress_type in dress_types:
|
303 |
+
with open(dataset_path / 'image_splits' / f'split.{dress_type}.{split}.json') as f:
|
304 |
+
self.image_names.extend(json.load(f))
|
305 |
+
|
306 |
+
print(f"FashionIQ {split} - {dress_types} dataset in {mode} mode initialized")
|
307 |
+
|
308 |
+
def __getitem__(self, index) -> dict:
|
309 |
+
try:
|
310 |
+
if self.mode == 'relative':
|
311 |
+
relative_captions = self.triplets[index]['captions']
|
312 |
+
reference_name = self.triplets[index]['candidate']
|
313 |
+
|
314 |
+
if self.split in ['train', 'val']:
|
315 |
+
reference_image_path = self.dataset_path / 'images' / f"{reference_name}.jpg"
|
316 |
+
reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
|
317 |
+
target_name = self.triplets[index]['target']
|
318 |
+
target_image_path = self.dataset_path / 'images' / f"{target_name}.jpg"
|
319 |
+
target_image = self.preprocess(PIL.Image.open(target_image_path), return_tensors='pt')['pixel_values'][0]
|
320 |
+
|
321 |
+
return {
|
322 |
+
'reference_image': reference_image,
|
323 |
+
'reference_name': reference_name,
|
324 |
+
'target_image': target_image,
|
325 |
+
'target_name': target_name,
|
326 |
+
'relative_captions': relative_captions
|
327 |
+
}
|
328 |
+
|
329 |
+
elif self.split == 'test':
|
330 |
+
reference_image_path = self.dataset_path / 'images' / f"{reference_name}.jpg"
|
331 |
+
reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
|
332 |
+
|
333 |
+
return {
|
334 |
+
'reference_image': reference_image,
|
335 |
+
'reference_name': reference_name,
|
336 |
+
'relative_captions': relative_captions
|
337 |
+
}
|
338 |
+
|
339 |
+
elif self.mode == 'classic':
|
340 |
+
image_name = self.image_names[index]
|
341 |
+
image_path = self.dataset_path / 'images' / f"{image_name}.jpg"
|
342 |
+
image = self.preprocess(PIL.Image.open(image_path), return_tensors='pt')['pixel_values'][0]
|
343 |
+
|
344 |
+
return {
|
345 |
+
'image': image,
|
346 |
+
'image_name': image_name
|
347 |
+
}
|
348 |
+
|
349 |
+
else:
|
350 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
351 |
+
except Exception as e:
|
352 |
+
print(f"Exception: {e}")
|
353 |
+
|
354 |
+
def __len__(self):
|
355 |
+
if self.mode == 'relative':
|
356 |
+
return len(self.triplets)
|
357 |
+
elif self.mode == 'classic':
|
358 |
+
return len(self.image_names)
|
359 |
+
else:
|
360 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
361 |
+
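As a quick orientation, here is a minimal, hypothetical way to exercise this dataset class on its own; the dataset root './fashionIQ_dataset' and the CLIP checkpoint id are illustrative assumptions, not paths defined by this repo.
# Hypothetical usage sketch (dataset path and CLIP checkpoint are assumptions)
from transformers import CLIPImageProcessor

preprocess = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
fiq_val = FashionIQDataset('./fashionIQ_dataset', 'val', ['dress'], 'relative', preprocess)
sample = fiq_val[0]
print(len(fiq_val), sample['reference_name'], sample['relative_captions'])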
|
362 |
+
|
363 |
+
class CIRRDataset(Dataset):
|
364 |
+
"""
|
365 |
+
Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
|
366 |
+
CIRR dataset class for PyTorch dataloader.
|
367 |
+
The dataset can be used in 'relative' or 'classic' mode:
|
368 |
+
- In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
- In 'relative' mode the dataset yields a dict with keys:
|
370 |
+
- ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_caption', 'group_members']
|
371 |
+
when split in ['train', 'val']
|
372 |
+
- ['reference_image', 'reference_name', 'relative_caption', 'group_members', 'pair_id'] when split == test
|
373 |
+
"""
|
374 |
+
|
375 |
+
def __init__(self, dataset_path: Union[Path, str], split: Literal['train', 'val', 'test'],
|
376 |
+
mode: Literal['relative', 'classic'], preprocess: callable, no_duplicates: Optional[bool] = False):
|
377 |
+
"""
|
378 |
+
:param dataset_path: path to the CIRR dataset
|
379 |
+
:param split: dataset split, should be in ['train', 'val', 'test']
|
380 |
+
:param mode: dataset mode, should be in ['relative', 'classic']:
|
381 |
+
- In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
- In 'relative' mode the dataset yields a dict with keys:
|
383 |
+
- ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_caption',
|
384 |
+
'group_members'] when split in ['train', 'val']
|
385 |
+
- ['reference_image', 'reference_name', 'relative_caption', 'group_members', 'pair_id'] when split == test
|
386 |
+
:param preprocess: function which preprocesses the image
|
387 |
+
:param no_duplicates: if True, the dataset will not yield duplicate images in relative mode, does not affect classic mode
|
388 |
+
"""
|
389 |
+
dataset_path = Path(dataset_path)
|
390 |
+
self.dataset_path = dataset_path
|
391 |
+
self.preprocess = preprocess
|
392 |
+
self.mode = mode
|
393 |
+
self.split = split
|
394 |
+
self.no_duplicates = no_duplicates
|
395 |
+
|
396 |
+
if split == "test":
|
397 |
+
split = "test1"
|
398 |
+
self.split = "test1"
|
399 |
+
|
400 |
+
# Validate inputs
|
401 |
+
if split not in ['test1', 'train', 'val']:
|
402 |
+
raise ValueError("split should be in ['test1', 'train', 'val']")
|
403 |
+
if mode not in ['relative', 'classic']:
|
404 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
405 |
+
|
406 |
+
# get triplets made by (reference_image, target_image, relative caption)
|
407 |
+
with open(dataset_path / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f:
|
408 |
+
self.triplets = json.load(f)
|
409 |
+
|
410 |
+
# Remove duplicates from triplets
|
411 |
+
if self.no_duplicates:
|
412 |
+
seen = set()
|
413 |
+
new_triplets = []
|
414 |
+
for triplet in self.triplets:
|
415 |
+
if triplet['reference'] not in seen:
|
416 |
+
seen.add(triplet['reference'])
|
417 |
+
new_triplets.append(triplet)
|
418 |
+
self.triplets = new_triplets
|
419 |
+
|
420 |
+
# get a mapping from image name to relative path
|
421 |
+
with open(dataset_path / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f:
|
422 |
+
self.name_to_relpath = json.load(f)
|
423 |
+
|
424 |
+
print(f"CIRR {split} dataset in {mode} mode initialized")
|
425 |
+
|
426 |
+
def __getitem__(self, index) -> dict:
|
427 |
+
try:
|
428 |
+
if self.mode == 'relative':
|
429 |
+
group_members = self.triplets[index]['img_set']['members']
|
430 |
+
reference_name = self.triplets[index]['reference']
|
431 |
+
relative_caption = self.triplets[index]['caption']
|
432 |
+
|
433 |
+
if self.split in ['train', 'val']:
|
434 |
+
reference_image_path = self.dataset_path / self.name_to_relpath[reference_name]
|
435 |
+
reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
|
436 |
+
target_hard_name = self.triplets[index]['target_hard']
|
437 |
+
target_image_path = self.dataset_path / self.name_to_relpath[target_hard_name]
|
438 |
+
target_image = self.preprocess(PIL.Image.open(target_image_path), return_tensors='pt')['pixel_values'][0]
|
439 |
+
|
440 |
+
return {
|
441 |
+
'reference_image': reference_image,
|
442 |
+
'reference_name': reference_name,
|
443 |
+
'target_image': target_image,
|
444 |
+
'target_name': target_hard_name,
|
445 |
+
'relative_caption': relative_caption,
|
446 |
+
'group_members': group_members
|
447 |
+
}
|
448 |
+
|
449 |
+
elif self.split == 'test1':
|
450 |
+
pair_id = self.triplets[index]['pairid']
|
451 |
+
reference_image_path = self.dataset_path / self.name_to_relpath[reference_name]
|
452 |
+
reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
|
453 |
+
return {
|
454 |
+
'reference_image': reference_image,
|
455 |
+
'reference_name': reference_name,
|
456 |
+
'relative_caption': relative_caption,
|
457 |
+
'group_members': group_members,
|
458 |
+
'pair_id': pair_id
|
459 |
+
}
|
460 |
+
|
461 |
+
elif self.mode == 'classic':
|
462 |
+
image_name = list(self.name_to_relpath.keys())[index]
|
463 |
+
image_path = self.dataset_path / self.name_to_relpath[image_name]
|
464 |
+
im = PIL.Image.open(image_path)
|
465 |
+
image = self.preprocess(im, return_tensors='pt')['pixel_values'][0]
|
466 |
+
|
467 |
+
return {
|
468 |
+
'image': image,
|
469 |
+
'image_name': image_name
|
470 |
+
}
|
471 |
+
|
472 |
+
else:
|
473 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
474 |
+
|
475 |
+
except Exception as e:
|
476 |
+
print(f"Exception: {e}")
|
477 |
+
|
478 |
+
def __len__(self):
|
479 |
+
if self.mode == 'relative':
|
480 |
+
return len(self.triplets)
|
481 |
+
elif self.mode == 'classic':
|
482 |
+
return len(self.name_to_relpath)
|
483 |
+
else:
|
484 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
485 |
+
|
486 |
+
|
487 |
+
class CIRCODataset(Dataset):
|
488 |
+
"""
|
489 |
+
Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
|
490 |
+
CIRCO dataset class for PyTorch.
|
491 |
+
The dataset can be used in 'relative' or 'classic' mode:
|
492 |
+
- In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
- In 'relative' mode the dataset yields a dict with keys:
|
494 |
+
- ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions', 'shared_concept',
|
495 |
+
'gt_img_ids', 'query_id'] when split == 'val'
|
496 |
+
- ['reference_image', 'reference_name', 'relative_captions', 'shared_concept', 'query_id'] when split == test
|
497 |
+
"""
|
498 |
+
|
499 |
+
def __init__(self, dataset_path: Union[str, Path], split: Literal['val', 'test'],
|
500 |
+
mode: Literal['relative', 'classic'], preprocess: callable):
|
501 |
+
"""
|
502 |
+
Args:
|
503 |
+
dataset_path (Union[str, Path]): path to CIRCO dataset
|
504 |
+
split (str): dataset split, should be in ['test', 'val']
|
505 |
+
mode (str): dataset mode, should be in ['relative', 'classic']
|
506 |
+
preprocess (callable): function which preprocesses the image
|
507 |
+
"""
|
508 |
+
|
509 |
+
# Set dataset paths and configurations
|
510 |
+
dataset_path = Path(dataset_path)
|
511 |
+
self.mode = mode
|
512 |
+
self.split = split
|
513 |
+
self.preprocess = preprocess
|
514 |
+
self.data_path = dataset_path
|
515 |
+
|
516 |
+
# Ensure input arguments are valid
|
517 |
+
if mode not in ['relative', 'classic']:
|
518 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
519 |
+
if split not in ['test', 'val']:
|
520 |
+
raise ValueError("split should be in ['test', 'val']")
|
521 |
+
|
522 |
+
# Load COCO images information
|
523 |
+
with open(dataset_path / 'COCO2017_unlabeled' / "annotations" / "image_info_unlabeled2017.json", "r") as f:
|
524 |
+
imgs_info = json.load(f)
|
525 |
+
|
526 |
+
self.img_paths = [dataset_path / 'COCO2017_unlabeled' / "unlabeled2017" / img_info["file_name"] for img_info in
|
527 |
+
imgs_info["images"]]
|
528 |
+
self.img_ids = [img_info["id"] for img_info in imgs_info["images"]]
|
529 |
+
self.img_ids_indexes_map = {str(img_id): i for i, img_id in enumerate(self.img_ids)}
|
530 |
+
|
531 |
+
# get CIRCO annotations
|
532 |
+
with open(dataset_path / 'annotations' / f'{split}.json', "r") as f:
|
533 |
+
self.annotations: List[dict] = json.load(f)
|
534 |
+
|
535 |
+
# Get maximum number of ground truth images (for padding when loading the images)
|
536 |
+
self.max_num_gts = 23 # Maximum number of ground truth images
|
537 |
+
|
538 |
+
print(f"CIRCODataset {split} dataset in {mode} mode initialized")
|
539 |
+
|
540 |
+
def get_target_img_ids(self, index) -> Dict[str, int]:
|
541 |
+
"""
|
542 |
+
Returns the id of the target image and ground truth images for a given query
|
543 |
+
|
544 |
+
Args:
|
545 |
+
index (int): id of the query
|
546 |
+
|
547 |
+
Returns:
|
548 |
+
Dict[str, int]: dictionary containing target image id and a list of ground truth image ids
|
549 |
+
"""
|
550 |
+
|
551 |
+
return {
|
552 |
+
'target_img_id': self.annotations[index]['target_img_id'],
|
553 |
+
'gt_img_ids': self.annotations[index]['gt_img_ids']
|
554 |
+
}
|
555 |
+
|
556 |
+
def __getitem__(self, index) -> dict:
|
557 |
+
"""
|
558 |
+
Returns a specific item from the dataset based on the index.
|
559 |
+
|
560 |
+
In 'classic' mode, the dataset yields a dictionary with the following keys: [img, img_id]
|
561 |
+
In 'relative' mode, the dataset yields dictionaries with the following keys:
|
562 |
+
- [reference_img, reference_img_id, target_img, target_img_id, relative_caption, shared_concept, gt_img_ids,
|
563 |
+
query_id]
|
564 |
+
if split == val
|
565 |
+
- [reference_img, reference_img_id, relative_caption, shared_concept, query_id] if split == test
|
566 |
+
"""
|
567 |
+
|
568 |
+
if self.mode == 'relative':
|
569 |
+
# Get the query id
|
570 |
+
query_id = str(self.annotations[index]['id'])
|
571 |
+
|
572 |
+
# Get relative caption and shared concept
|
573 |
+
relative_caption = self.annotations[index]['relative_caption']
|
574 |
+
shared_concept = self.annotations[index]['shared_concept']
|
575 |
+
|
576 |
+
# Get the reference image
|
577 |
+
reference_img_id = str(self.annotations[index]['reference_img_id'])
|
578 |
+
reference_img_path = self.img_paths[self.img_ids_indexes_map[reference_img_id]]
|
579 |
+
reference_img = self.preprocess(PIL.Image.open(reference_img_path), return_tensors='pt')['pixel_values'][0]
|
580 |
+
|
581 |
+
if self.split == 'val':
|
582 |
+
# Get the target image and ground truth images
|
583 |
+
target_img_id = str(self.annotations[index]['target_img_id'])
|
584 |
+
gt_img_ids = [str(x) for x in self.annotations[index]['gt_img_ids']]
|
585 |
+
target_img_path = self.img_paths[self.img_ids_indexes_map[target_img_id]]
|
586 |
+
target_img = self.preprocess(PIL.Image.open(target_img_path), return_tensors='pt')['pixel_values'][0]
|
587 |
+
|
588 |
+
# Pad ground truth image IDs with zeros for collate_fn
|
589 |
+
gt_img_ids += [''] * (self.max_num_gts - len(gt_img_ids))
|
590 |
+
|
591 |
+
return {
|
592 |
+
'reference_image': reference_img,
|
593 |
+
'reference_name': reference_img_id,
|
594 |
+
'target_image': target_img,
|
595 |
+
'target_name': target_img_id,
|
596 |
+
'relative_caption': relative_caption,
|
597 |
+
'shared_concept': shared_concept,
|
598 |
+
'gt_img_ids': gt_img_ids,
|
599 |
+
'query_id': query_id,
|
600 |
+
}
|
601 |
+
|
602 |
+
elif self.split == 'test':
|
603 |
+
return {
|
604 |
+
'reference_image': reference_img,
|
605 |
+
'reference_name': reference_img_id,
|
606 |
+
'relative_caption': relative_caption,
|
607 |
+
'shared_concept': shared_concept,
|
608 |
+
'query_id': query_id,
|
609 |
+
}
|
610 |
+
|
611 |
+
elif self.mode == 'classic':
|
612 |
+
# Get image ID and image path
|
613 |
+
img_id = str(self.img_ids[index])
|
614 |
+
img_path = self.img_paths[index]
|
615 |
+
|
616 |
+
# Preprocess image and return
|
617 |
+
img = self.preprocess(PIL.Image.open(img_path), return_tensors='pt')['pixel_values'][0]
|
618 |
+
return {
|
619 |
+
'image': img,
|
620 |
+
'image_name': img_id
|
621 |
+
}
|
622 |
+
|
623 |
+
def __len__(self):
|
624 |
+
"""
|
625 |
+
Returns the length of the dataset.
|
626 |
+
"""
|
627 |
+
if self.mode == 'relative':
|
628 |
+
return len(self.annotations)
|
629 |
+
elif self.mode == 'classic':
|
630 |
+
return len(self.img_ids)
|
631 |
+
else:
|
632 |
+
raise ValueError("mode should be in ['relative', 'classic']")
|
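Note that `gt_img_ids` is padded with empty strings up to `max_num_gts` so the default collate can batch it, which means downstream code has to strip that padding again. A minimal sketch with made-up ids:
# Illustrative only: remove the '' padding produced by CIRCODataset.__getitem__
batched_gt_img_ids = [['123', '456', '', ''], ['789', '', '', '']]
unpadded = [[g for g in gts if g != ''] for gts in batched_gt_img_ids]
print(unpadded)  # [['123', '456'], ['789']]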
models.py
ADDED
@@ -0,0 +1,192 @@
1 |
+
'''
|
2 |
+
LinCIR
|
3 |
+
Copyright (c) 2023-present NAVER Corp.
|
4 |
+
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
|
5 |
+
'''
|
6 |
+
import copy
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor, CLIPTokenizer
|
10 |
+
|
11 |
+
|
12 |
+
def build_text_encoder(args):
|
13 |
+
clip_model_dict = {'base32': 'openai/clip-vit-base-patch32',
|
14 |
+
'base': 'openai/clip-vit-base-patch16',
|
15 |
+
'large': 'openai/clip-vit-large-patch14',
|
16 |
+
'huge': 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
|
17 |
+
'giga': 'Geonmo/CLIP-Giga-config-fixed',
|
18 |
+
'meta-large': 'facebook/metaclip-l14-fullcc2.5b',
|
19 |
+
'meta-huge': 'facebook/metaclip-h14-fullcc2.5b',
|
20 |
+
}
|
21 |
+
|
22 |
+
clip_preprocess = CLIPImageProcessor(crop_size={'height': 224, 'width': 224},
|
23 |
+
do_center_crop=True,
|
24 |
+
do_convert_rgb=True,
|
25 |
+
do_normalize=True,
|
26 |
+
do_rescale=True,
|
27 |
+
do_resize=True,
|
28 |
+
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
29 |
+
image_std=[0.26862954, 0.26130258, 0.27577711],
|
30 |
+
resample=3,
|
31 |
+
size={'shortest_edge': 224},
|
32 |
+
)
|
33 |
+
|
34 |
+
clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name], torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32, cache_dir=args.cache_dir)
|
35 |
+
|
36 |
+
clip_text_model = CLIPTextModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name], torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32, cache_dir=args.cache_dir)
|
37 |
+
|
38 |
+
tokenizer = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer_2', cache_dir=args.cache_dir)
|
39 |
+
tokenizer.add_special_tokens({'additional_special_tokens':["[$]"]}) # NOTE: 49408
|
40 |
+
|
41 |
+
return clip_vision_model, clip_preprocess, clip_text_model, tokenizer
|
42 |
+
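A minimal sketch of calling this factory outside of the training scripts; the `Namespace` fields mirror the CLI arguments used elsewhere in this repo and the values are illustrative.
# Illustrative only: build the CLIP encoders with a hand-made args object
from argparse import Namespace

example_args = Namespace(clip_model_name='large', mixed_precision='no', cache_dir='./hf_models')
vision_model, preprocess, text_model, tokenizer = build_text_encoder(example_args)
print(text_model.config.projection_dim, text_model.config.hidden_size)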
|
43 |
+
|
44 |
+
class Phi(nn.Module):
|
45 |
+
"""
|
46 |
+
Textual Inversion Phi network.
|
47 |
+
Takes as input the visual features of an image and outputs the pseudo-word embedding.
|
48 |
+
Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/phi.py
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float):
|
52 |
+
super().__init__()
|
53 |
+
self.layers = nn.Sequential(
|
54 |
+
nn.Linear(input_dim, hidden_dim),
|
55 |
+
nn.GELU(),
|
56 |
+
nn.Dropout(p=dropout),
|
57 |
+
nn.Linear(hidden_dim, hidden_dim),
|
58 |
+
nn.GELU(),
|
59 |
+
nn.Dropout(p=dropout),
|
60 |
+
nn.Linear(hidden_dim, output_dim),
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
#x = F.normalize(x, dim=-1)
|
65 |
+
return self.layers(x)
|
66 |
+
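Phi maps a feature from the CLIP projection space to a token embedding of the text encoder's hidden size; a quick shape check, assuming ViT-L/14-sized dimensions (768 for both):
# Shape sanity check (768/768 matches CLIP ViT-L/14; other backbones differ)
phi_example = Phi(input_dim=768, hidden_dim=768 * 4, output_dim=768, dropout=0.5)
fake_features = torch.randn(4, 768)      # batch of projected CLIP features
print(phi_example(fake_features).shape)  # torch.Size([4, 768])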
|
67 |
+
|
68 |
+
class EMAModel:
|
69 |
+
"""
|
70 |
+
Exponential Moving Average of models weights
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, parameters, decay=0.9999):
|
74 |
+
parameters = list(parameters)
|
75 |
+
self.shadow_params = [p.clone().detach() for p in parameters]
|
76 |
+
|
77 |
+
self.collected_params = None
|
78 |
+
|
79 |
+
self.decay = decay
|
80 |
+
self.optimization_step = 0
|
81 |
+
|
82 |
+
@torch.no_grad()
|
83 |
+
def step(self, parameters):
|
84 |
+
parameters = list(parameters)
|
85 |
+
|
86 |
+
self.optimization_step += 1
|
87 |
+
|
88 |
+
# Compute the decay factor for the exponential moving average.
|
89 |
+
value = (1 + self.optimization_step) / (10 + self.optimization_step)
|
90 |
+
one_minus_decay = 1 - min(self.decay, value)
|
91 |
+
|
92 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
93 |
+
if param.requires_grad:
|
94 |
+
s_param.sub_(one_minus_decay * (s_param - param))
|
95 |
+
else:
|
96 |
+
s_param.copy_(param)
|
97 |
+
|
98 |
+
torch.cuda.empty_cache()
|
99 |
+
|
100 |
+
def copy_to(self, parameters) -> None:
|
101 |
+
"""
|
102 |
+
Copy current averaged parameters into given collection of parameters.
|
103 |
+
Args:
|
104 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
105 |
+
updated with the stored moving averages. If `None`, the
|
106 |
+
parameters with which this `ExponentialMovingAverage` was
|
107 |
+
initialized will be used.
|
108 |
+
"""
|
109 |
+
parameters = list(parameters)
|
110 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
111 |
+
param.data.copy_(s_param.data)
|
112 |
+
|
113 |
+
def to(self, device=None, dtype=None) -> None:
|
114 |
+
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
115 |
+
Args:
|
116 |
+
device: like `device` argument to `torch.Tensor.to`
|
117 |
+
"""
|
118 |
+
# .to() on the tensors handles None correctly
|
119 |
+
self.shadow_params = [
|
120 |
+
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
|
121 |
+
for p in self.shadow_params
|
122 |
+
]
|
123 |
+
|
124 |
+
def state_dict(self) -> dict:
|
125 |
+
r"""
|
126 |
+
Returns the state of the ExponentialMovingAverage as a dict.
|
127 |
+
This method is used by accelerate during checkpointing to save the ema state dict.
|
128 |
+
"""
|
129 |
+
# Following PyTorch conventions, references to tensors are returned:
|
130 |
+
# "returns a reference to the state and not its copy!" -
|
131 |
+
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
|
132 |
+
return {
|
133 |
+
"decay": self.decay,
|
134 |
+
"optimization_step": self.optimization_step,
|
135 |
+
"shadow_params": self.shadow_params,
|
136 |
+
"collected_params": self.collected_params,
|
137 |
+
}
|
138 |
+
|
139 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
140 |
+
r"""
|
141 |
+
Loads the ExponentialMovingAverage state.
|
142 |
+
This method is used by accelerate during checkpointing to load the ema state dict.
|
143 |
+
Args:
|
144 |
+
state_dict (dict): EMA state. Should be an object returned
|
145 |
+
from a call to :meth:`state_dict`.
|
146 |
+
"""
|
147 |
+
# deepcopy, to be consistent with module API
|
148 |
+
state_dict = copy.deepcopy(state_dict)
|
149 |
+
|
150 |
+
self.decay = state_dict["decay"]
|
151 |
+
if self.decay < 0.0 or self.decay > 1.0:
|
152 |
+
raise ValueError("Decay must be between 0 and 1")
|
153 |
+
|
154 |
+
self.optimization_step = state_dict["optimization_step"]
|
155 |
+
if not isinstance(self.optimization_step, int):
|
156 |
+
raise ValueError("Invalid optimization_step")
|
157 |
+
|
158 |
+
self.shadow_params = state_dict["shadow_params"]
|
159 |
+
if not isinstance(self.shadow_params, list):
|
160 |
+
raise ValueError("shadow_params must be a list")
|
161 |
+
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
|
162 |
+
raise ValueError("shadow_params must all be Tensors")
|
163 |
+
|
164 |
+
self.collected_params = state_dict["collected_params"]
|
165 |
+
if self.collected_params is not None:
|
166 |
+
if not isinstance(self.collected_params, list):
|
167 |
+
raise ValueError("collected_params must be a list")
|
168 |
+
if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
|
169 |
+
raise ValueError("collected_params must all be Tensors")
|
170 |
+
if len(self.collected_params) != len(self.shadow_params):
|
171 |
+
raise ValueError("collected_params and shadow_params must have the same length")
|
172 |
+
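A compressed sketch of how this EMA helper is meant to be driven; the weight-perturbation loop below is only a stand-in for real optimizer steps.
# Illustrative EMA usage: track a module, then copy the averaged weights back
phi_net = Phi(input_dim=768, hidden_dim=3072, output_dim=768, dropout=0.0)
ema = EMAModel(phi_net.parameters())
for _ in range(3):                              # stand-in for real optimizer steps
    for p in phi_net.parameters():
        p.data.add_(0.01 * torch.randn_like(p))
    ema.step(phi_net.parameters())
ema.copy_to(phi_net.parameters())               # phi_net now holds the EMA weights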
|
173 |
+
|
174 |
+
class PIC2WORD(nn.Module):
|
175 |
+
def __init__(self, embed_dim=512, middle_dim=512, output_dim=512, n_layer=2, dropout=0.1):
|
176 |
+
super().__init__()
|
177 |
+
self.fc_out = nn.Linear(middle_dim, output_dim)
|
178 |
+
layers = []
|
179 |
+
dim = embed_dim
|
180 |
+
for _ in range(n_layer):
|
181 |
+
block = []
|
182 |
+
block.append(nn.Linear(dim, middle_dim))
|
183 |
+
block.append(nn.Dropout(dropout))
|
184 |
+
block.append(nn.ReLU())
|
185 |
+
dim = middle_dim
|
186 |
+
layers.append(nn.Sequential(*block))
|
187 |
+
self.layers = nn.Sequential(*layers)
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor):
|
190 |
+
for layer in self.layers:
|
191 |
+
x = layer(x)
|
192 |
+
return self.fc_out(x)
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
numpy
torch
transformers
diffusers
accelerate
datasets
spacy
clip-retrieval
train_phi.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
'''
|
2 |
+
LinCIR
|
3 |
+
Copyright (c) 2023-present NAVER Corp.
|
4 |
+
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
|
5 |
+
'''
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import pickle
|
9 |
+
import random
|
10 |
+
import math
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import Literal, Tuple, Dict, List, Set
|
14 |
+
import logging
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
from loader import build_loader, CIRRDataset
|
22 |
+
from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
|
23 |
+
from models import build_text_encoder, Phi, EMAModel
|
24 |
+
from utils import extract_image_features, extract_pseudo_tokens_with_phi
|
25 |
+
from validate import cirr_compute_val_metrics
|
26 |
+
|
27 |
+
import transformers
|
28 |
+
from transformers import get_scheduler
|
29 |
+
from accelerate import Accelerator, DeepSpeedPlugin
|
30 |
+
from accelerate.logging import get_logger
|
31 |
+
from accelerate.utils import set_seed
|
32 |
+
from accelerate.state import AcceleratorState
|
34 |
+
|
35 |
+
|
36 |
+
logger = get_logger(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
def parse_args():
|
40 |
+
parser = ArgumentParser()
|
41 |
+
|
42 |
+
parser.add_argument("--output_dir", default="trained_models", type=str,
|
43 |
+
help="The output directory where the model predictions and checkpoints will be written")
|
44 |
+
parser.add_argument("--logging_dir", default="logs", type=str, help="tensorboard logs will saved here")
|
45 |
+
parser.add_argument("--cache_dir", default="./hf_models", type=str,
|
46 |
+
help="Path to model cache folder")
|
47 |
+
parser.add_argument("--report_to", default="tensorboard", type=str, help="")
|
48 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
49 |
+
|
50 |
+
parser.add_argument("--clip_model_name", default="giga", type=str,
|
51 |
+
help="CLIP model to use, e.g 'large', 'giga'")
|
52 |
+
parser.add_argument("--cirr_dataset_path", type=str, help="Path to CIRR dataset", required=True)
|
53 |
+
parser.add_argument("--keywords_path", type=str, help="Path to keywords json file")
|
54 |
+
parser.add_argument("--resume", default=None, type=str, help="Path to pretrained ckpt")
|
55 |
+
|
56 |
+
parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
|
57 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
58 |
+
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
59 |
+
help="")
|
60 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
|
61 |
+
parser.add_argument("--max_train_steps", type=int, default=50000, help="Total number of training steps to perform")
|
62 |
+
parser.add_argument("--phi_dropout", default=0.5, type=float, help="Dropout probability for the phi network")
|
63 |
+
parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")
|
64 |
+
parser.add_argument("--batch_size", default=256, type=int, help="Phi training batch size")
|
65 |
+
parser.add_argument("--num_workers", default=10, type=int, help="Number of workers")
|
66 |
+
parser.add_argument("--learning_rate", default=1e-4, type=float, help="Learning rate")
|
67 |
+
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
|
68 |
+
parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Number of updates steps to accumulate before performing a backward/update pass")
|
69 |
+
parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.")
|
70 |
+
parser.add_argument("--mixed_precision", default=None, type=str, choices=["no", "fp16", "bf16"], help="mixed precision")
|
71 |
+
parser.add_argument("--validation_steps", default=1, type=int, help="Validation frequency expressed in epochs")
|
72 |
+
parser.add_argument("--checkpointing_steps", default=None, type=int, help="Save a checkpoint of the training state every X updates")
|
73 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
74 |
+
|
75 |
+
parser.add_argument("--seed", type=int, default=None, help="seed for reproducibility")
|
76 |
+
|
77 |
+
args = parser.parse_args()
|
78 |
+
|
79 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
80 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
81 |
+
args.local_rank = env_local_rank
|
82 |
+
|
83 |
+
return args
|
84 |
+
|
85 |
+
|
86 |
+
def save_phi(name: str, cur_epoch: int, model_to_save: Phi, training_path: Path) -> None:
|
87 |
+
"""
|
88 |
+
Save the weights of Phi during training
|
89 |
+
"""
|
90 |
+
models_path = os.path.join(training_path, "checkpoints")
|
91 |
+
os.makedirs(models_path, exist_ok=True)
|
92 |
+
model_name = model_to_save.__class__.__name__
|
93 |
+
torch.save({
|
94 |
+
'epoch': cur_epoch,
|
95 |
+
model_name: model_to_save.state_dict(),
|
96 |
+
}, os.path.join(models_path, f'{name}.pt'))
|
97 |
+
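Checkpoints written by `save_phi` store the current step/epoch plus the state dict keyed by the class name, so they can be restored like this (the path is illustrative):
# Illustrative: restore a checkpoint produced by save_phi
ckpt = torch.load('trained_models/checkpoints/phi_best.pt', map_location='cpu')
restored_phi = Phi(input_dim=768, hidden_dim=768 * 4, output_dim=768, dropout=0.0)
restored_phi.load_state_dict(ckpt[restored_phi.__class__.__name__])
print('saved at step:', ckpt['epoch'])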
|
98 |
+
|
99 |
+
def train_phi(args):
|
100 |
+
# We are going to use the pre-extracted CLIP image features, so we do not need the image_encoder anymore.
|
101 |
+
|
102 |
+
### init accelerator here
|
103 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
104 |
+
accelerator = Accelerator(
|
105 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
106 |
+
mixed_precision=args.mixed_precision,
|
107 |
+
log_with=args.report_to,
|
108 |
+
project_dir=logging_dir,
|
109 |
+
)
|
110 |
+
|
111 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
112 |
+
|
113 |
+
logging.basicConfig(
|
114 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
115 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
116 |
+
level=logging.INFO,
|
117 |
+
)
|
118 |
+
logger.info(accelerator.state, main_process_only=False)
|
119 |
+
|
120 |
+
if accelerator.is_local_main_process:
|
121 |
+
transformers.utils.logging.set_verbosity_info()
|
122 |
+
else:
|
123 |
+
transformers.utils.logging.set_verbosity_error()
|
124 |
+
|
125 |
+
if args.seed is not None:
|
126 |
+
set_seed(args.seed)
|
127 |
+
|
128 |
+
### Define the text encoder from clip
|
129 |
+
image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
|
130 |
+
|
131 |
+
### Define the phi model
|
132 |
+
phi = Phi(input_dim=text_encoder.config.projection_dim,
|
133 |
+
hidden_dim=text_encoder.config.projection_dim * 4,
|
134 |
+
output_dim=text_encoder.config.hidden_size, dropout=args.phi_dropout)
|
135 |
+
|
136 |
+
if args.resume:
|
137 |
+
phi.load_state_dict(
|
138 |
+
torch.load(args.resume, map_location=accelerator.device)[
|
139 |
+
phi.__class__.__name__])
|
140 |
+
|
141 |
+
|
142 |
+
### GPU handling
|
143 |
+
weight_dtype = torch.float32
|
144 |
+
if accelerator.mixed_precision == "fp16":
|
145 |
+
weight_dtype = torch.float16
|
146 |
+
elif accelerator.mixed_precision == "bf16":
|
147 |
+
weight_dtype = torch.bfloat16
|
148 |
+
|
149 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
150 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
151 |
+
|
152 |
+
image_encoder.requires_grad_(False)
|
153 |
+
text_encoder.requires_grad_(False)
|
154 |
+
|
155 |
+
if args.use_ema:
|
156 |
+
import copy
|
157 |
+
ema_phi = copy.deepcopy(phi)
|
158 |
+
ema_phi = EMAModel(ema_phi.parameters())
|
159 |
+
ema_phi.to(accelerator.device, dtype=weight_dtype)
|
160 |
+
|
161 |
+
### Define the train datasets
|
162 |
+
print('pytorch loader')
|
163 |
+
train_dataset = build_loader(args, tokenizer, accelerator)
|
164 |
+
|
165 |
+
## evaluator
|
166 |
+
if accelerator.is_main_process:
|
167 |
+
## Define CIRR validation set
|
168 |
+
cirr_relative_val_dataset = CIRRDataset(args.cirr_dataset_path, 'val', 'relative', clip_preprocess)
|
169 |
+
cirr_classic_val_dataset = CIRRDataset(args.cirr_dataset_path, 'val', 'classic', clip_preprocess)
|
170 |
+
|
171 |
+
# Extract the features for the CIRR validation set
|
172 |
+
cirr_val_index_features, cirr_val_index_names = extract_image_features(cirr_classic_val_dataset, image_encoder)
|
173 |
+
|
174 |
+
# Define the optimizer, the loss and the grad scaler
|
175 |
+
if args.use_8bit_adam:
|
176 |
+
try:
|
177 |
+
import bitsandbytes as bnb
|
178 |
+
except ImportError:
|
179 |
+
raise ImportError(
|
180 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
181 |
+
)
|
182 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
183 |
+
else:
|
184 |
+
optimizer_cls = torch.optim.AdamW
|
185 |
+
|
186 |
+
optimizer = optimizer_cls(phi.parameters(),
|
187 |
+
lr=args.learning_rate,
|
188 |
+
weight_decay=args.weight_decay)
|
189 |
+
|
190 |
+
lr_scheduler = get_scheduler(
|
191 |
+
args.lr_scheduler,
|
192 |
+
optimizer=optimizer,
|
193 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes,
|
194 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps * accelerator.num_processes,
|
195 |
+
)
|
196 |
+
|
197 |
+
phi, optimizer, lr_scheduler, train_dataset = accelerator.prepare(
|
198 |
+
phi, optimizer, lr_scheduler, train_dataset
|
199 |
+
)
|
200 |
+
|
201 |
+
if accelerator.is_main_process:
|
202 |
+
accelerator.init_trackers("zeroshot-cir", config=vars(args))
|
203 |
+
|
204 |
+
# Start with the training loop
|
205 |
+
total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
206 |
+
|
207 |
+
logger.info("***** Running training *****")
|
208 |
+
logger.info(f" Instantaneous batch size per device = {args.batch_size}")
|
209 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
210 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
211 |
+
logger.info(f" Total steps = {args.max_train_steps}")
|
212 |
+
|
213 |
+
phi.train()
|
214 |
+
|
215 |
+
train_loss = 0.0
|
216 |
+
global_step = 0
|
217 |
+
best_recall = -1
|
218 |
+
|
219 |
+
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
220 |
+
progress_bar.set_description("Steps")
|
221 |
+
|
222 |
+
while True:
|
223 |
+
for idx, (original_tokens, replaced_tokens, indicators) in enumerate(train_dataset):
|
224 |
+
original_tokens = original_tokens.to(accelerator.device)
|
225 |
+
replaced_tokens = replaced_tokens.to(accelerator.device)
|
226 |
+
|
227 |
+
org = text_encoder(input_ids=original_tokens)
|
228 |
+
original_text_embeddings, original_last_hidden_states = org.text_embeds, org.last_hidden_state
|
229 |
+
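# Work on a copy of the projected text embedding and perturb it with zero-mean
# Gaussian noise whose magnitude is scaled per sample by a uniform factor in
# [0, 1); phi must then recover a pseudo token from this noisy feature.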
input_features = original_text_embeddings.clone()
|
230 |
+
input_features += 1.0 * torch.rand(input_features.shape[0], device=input_features.device).unsqueeze(-1) * torch.randn(input_features.shape, device=input_features.device)
|
231 |
+
|
232 |
+
# normalize test
|
233 |
+
if args.l2_normalize:
|
234 |
+
input_features = F.normalize(input_features, dim=-1)
|
235 |
+
#################
|
236 |
+
|
237 |
+
estimated_token_embeddings = phi(input_features)
|
238 |
+
|
239 |
+
replaced_text_embeddings, replaced_last_hidden_states = encode_with_pseudo_tokens_HF(text_encoder, replaced_tokens, estimated_token_embeddings, return_last_states=True)
|
240 |
+
|
241 |
+
loss = F.mse_loss(replaced_text_embeddings.float(), original_text_embeddings.float(), reduction="mean")
|
242 |
+
|
243 |
+
avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
|
244 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
245 |
+
|
246 |
+
# Backpropagation
|
247 |
+
accelerator.backward(loss)
|
248 |
+
if accelerator.sync_gradients and args.max_grad_norm is not None:
|
249 |
+
accelerator.clip_grad_norm_(phi.parameters(), args.max_grad_norm)
|
250 |
+
optimizer.step()
|
251 |
+
lr_scheduler.step()
|
252 |
+
optimizer.zero_grad()
|
253 |
+
|
254 |
+
if accelerator.sync_gradients:
|
255 |
+
if args.use_ema:
|
256 |
+
ema_phi.step(phi.module.parameters())
|
257 |
+
progress_bar.update(1)
|
258 |
+
global_step += 1
|
259 |
+
accelerator.log({"train/train_loss": train_loss}, step=global_step)
|
260 |
+
train_loss = 0.0
|
261 |
+
|
262 |
+
accelerator.log({'train/lr': lr_scheduler.get_last_lr()[0]}, step=global_step)
|
263 |
+
accelerator.log({'train/preproc_rate': torch.sum(indicators).item() / len(indicators)}, step=global_step)
|
264 |
+
if args.checkpointing_steps and global_step % args.checkpointing_steps == 0:
|
265 |
+
if accelerator.is_main_process:
|
266 |
+
logger.info(f"model saving... step: {global_step}")
|
267 |
+
save_phi(f"phi_{global_step:09}", global_step, accelerator.unwrap_model(phi), args.output_dir)
|
268 |
+
save_phi(f"phi_latest", global_step, accelerator.unwrap_model(phi), args.output_dir)
|
269 |
+
if args.use_ema:
|
270 |
+
phi_for_saving = copy.deepcopy(accelerator.unwrap_model(phi))
|
271 |
+
ema_phi.copy_to(phi_for_saving.parameters())
|
272 |
+
save_phi(f"ema_phi_{global_step:09}", global_step, phi_for_saving, args.output_dir)
|
273 |
+
save_phi(f"ema_phi_latest", global_step, phi_for_saving, args.output_dir)
|
274 |
+
|
275 |
+
if global_step % args.validation_steps == 0 or global_step == 50:
|
276 |
+
if accelerator.is_main_process:
|
277 |
+
logger.info(f"evaluate model... step: {global_step}")
|
278 |
+
|
279 |
+
if args.use_ema:
|
280 |
+
phi_for_eval = copy.deepcopy(accelerator.unwrap_model(phi))
|
281 |
+
ema_phi.copy_to(phi_for_eval.parameters())
|
282 |
+
else:
|
283 |
+
phi_for_eval = phi
|
284 |
+
|
285 |
+
phi_for_eval.eval()
|
286 |
+
|
287 |
+
# Extract the pseudo tokens for the CIRR validation set using Phi
|
288 |
+
cirr_val_pseudo_tokens, cirr_val_ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi_for_eval,
|
289 |
+
cirr_relative_val_dataset, args)
|
290 |
+
cirr_val_pseudo_tokens = cirr_val_pseudo_tokens.to(accelerator.device)
|
291 |
+
|
292 |
+
# Compute the CIRR validation metrics
|
293 |
+
cirr_results_dict = cirr_compute_val_metrics(cirr_relative_val_dataset, text_encoder,
|
294 |
+
cirr_val_index_features, cirr_val_index_names,
|
295 |
+
cirr_val_ref_names_list, cirr_val_pseudo_tokens)
|
296 |
+
check_list = ['cirr_recall_at1', 'cirr_recall_at5', 'cirr_recall_at10', 'cirr_recall_at50']
|
297 |
+
for check_key in check_list:
|
298 |
+
accelerator.log({f"validate/{check_key}": cirr_results_dict[check_key]}, step=global_step)
|
299 |
+
print(json.dumps(cirr_results_dict, indent=4))
|
300 |
+
|
301 |
+
# Save the best model.
|
302 |
+
if args.checkpointing_steps:
|
303 |
+
if cirr_results_dict['cirr_recall_at1'] > best_recall:
|
304 |
+
best_recall = cirr_results_dict['cirr_recall_at1']
|
305 |
+
logger.info(f"best model saving... step: {global_step}")
|
306 |
+
save_phi("phi_best", global_step, accelerator.unwrap_model(phi), args.output_dir)
|
307 |
+
|
308 |
+
phi.train()
|
309 |
+
|
310 |
+
if global_step >= args.max_train_steps:
|
311 |
+
break
|
312 |
+
|
313 |
+
|
314 |
+
if __name__ == '__main__':
|
315 |
+
args = parse_args()
|
316 |
+
|
317 |
+
train_phi(args)
|
utils.py
ADDED
@@ -0,0 +1,182 @@
1 |
+
from typing import Optional, Tuple, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from clip.model import CLIP
|
6 |
+
from transformers import CLIPVisionModelWithProjection
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from data_utils import collate_fn
|
12 |
+
from models import Phi
|
13 |
+
|
14 |
+
|
15 |
+
if torch.cuda.is_available():
|
16 |
+
device = torch.device("cuda")
|
17 |
+
dtype = torch.float16
|
18 |
+
else:
|
19 |
+
device = torch.device("cpu")
|
20 |
+
dtype = torch.float32
|
21 |
+
|
22 |
+
|
23 |
+
@torch.no_grad()
|
24 |
+
def extract_image_features(dataset: Dataset, clip_model: CLIPVisionModelWithProjection, batch_size: Optional[int] = 32,
|
25 |
+
num_workers: Optional[int] = 10) -> Tuple[torch.Tensor, List[str]]:
|
26 |
+
"""
|
27 |
+
Extracts image features from a dataset using a CLIP model.
|
28 |
+
"""
|
29 |
+
# Create data loader
|
30 |
+
loader = DataLoader(dataset=dataset, batch_size=batch_size,
|
31 |
+
num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)
|
32 |
+
|
33 |
+
index_features = []
|
34 |
+
index_names = []
|
35 |
+
try:
|
36 |
+
print(f"extracting image features {dataset.__class__.__name__} - {dataset.split}")
|
37 |
+
except Exception as e:
|
38 |
+
pass
|
39 |
+
|
40 |
+
# Extract features
|
41 |
+
for batch in tqdm(loader):
|
42 |
+
images = batch.get('image')
|
43 |
+
names = batch.get('image_name')
|
44 |
+
if images is None:
|
45 |
+
images = batch.get('reference_image')
|
46 |
+
if names is None:
|
47 |
+
names = batch.get('reference_name')
|
48 |
+
|
49 |
+
images = images.to(clip_model.device)
|
50 |
+
with torch.no_grad():
|
51 |
+
batch_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds #.encode_image(images)
|
52 |
+
index_features.append(batch_features.cpu())
|
53 |
+
index_names.extend(names)
|
54 |
+
|
55 |
+
index_features = torch.vstack(index_features)
|
56 |
+
return index_features, index_names
|
57 |
+
|
58 |
+
|
59 |
+
def contrastive_loss(v1: torch.Tensor, v2: torch.Tensor, temperature: float) -> torch.Tensor:
|
60 |
+
# Based on https://github.com/NVlabs/PALAVRA/blob/main/utils/nv.py
|
61 |
+
v1 = F.normalize(v1, dim=1)
|
62 |
+
v2 = F.normalize(v2, dim=1)
|
63 |
+
|
64 |
+
numerator = torch.exp(torch.diag(torch.inner(v1, v2)) / temperature)
|
65 |
+
numerator = torch.cat((numerator, numerator), 0)
|
66 |
+
joint_vector = torch.cat((v1, v2), 0)
|
67 |
+
pairs_product = torch.exp(torch.mm(joint_vector, joint_vector.t()) / temperature)
|
68 |
+
denominator = torch.sum(pairs_product - pairs_product * torch.eye(joint_vector.shape[0]).to(device), 0)
|
69 |
+
|
70 |
+
loss = -torch.mean(torch.log(numerator / denominator))
|
71 |
+
|
72 |
+
return loss
|
73 |
+
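For reference, the loss above takes two aligned batches of embeddings plus a temperature; a toy call with random values:
# Toy call (random embeddings, illustrative temperature)
v1 = torch.randn(8, 512, device=device)
v2 = torch.randn(8, 512, device=device)
print(contrastive_loss(v1, v2, temperature=0.07).item())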
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def extract_pseudo_tokens_with_phi(clip_model: CLIPVisionModelWithProjection, phi: Phi, dataset: Dataset, args) -> Tuple[torch.Tensor, List[str]]:
|
77 |
+
"""
|
78 |
+
Extracts pseudo tokens from a dataset using a CLIP model and a phi model
|
79 |
+
"""
|
80 |
+
data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
|
81 |
+
collate_fn=collate_fn)
|
82 |
+
predicted_tokens = []
|
83 |
+
names_list = []
|
84 |
+
print(f"Extracting tokens using phi model")
|
85 |
+
for batch in tqdm(data_loader):
|
86 |
+
images = batch.get('image')
|
87 |
+
names = batch.get('image_name')
|
88 |
+
if images is None:
|
89 |
+
images = batch.get('reference_image')
|
90 |
+
if names is None:
|
91 |
+
names = batch.get('reference_name')
|
92 |
+
|
93 |
+
images = images.to(device)
|
94 |
+
image_features = clip_model(pixel_values=images.half()).image_embeds
|
95 |
+
if args.l2_normalize:
|
96 |
+
image_features = F.normalize(image_features, dim=-1)
|
97 |
+
batch_predicted_tokens = phi(image_features)
|
98 |
+
predicted_tokens.append(batch_predicted_tokens.cpu())
|
99 |
+
names_list.extend(names)
|
100 |
+
|
101 |
+
predicted_tokens = torch.vstack(predicted_tokens)
|
102 |
+
return predicted_tokens, names_list
|
103 |
+
|
104 |
+
|
105 |
+
@torch.no_grad()
|
106 |
+
def extract_image_features_with_names(clip_model: CLIPVisionModelWithProjection, dataset: Dataset) -> Tuple[torch.Tensor, List[str]]:
|
107 |
+
"""
|
108 |
+
Extracts image features from a dataset using a CLIP model
|
109 |
+
"""
|
110 |
+
data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
|
111 |
+
collate_fn=collate_fn)
|
112 |
+
predicted_tokens = []
|
113 |
+
names_list = []
|
114 |
+
print(f"Extracting tokens using phi model")
|
115 |
+
for batch in tqdm(data_loader):
|
116 |
+
images = batch.get('image')
|
117 |
+
names = batch.get('image_name')
|
118 |
+
if images is None:
|
119 |
+
images = batch.get('reference_image')
|
120 |
+
if names is None:
|
121 |
+
names = batch.get('reference_name')
|
122 |
+
|
123 |
+
images = images.to(device)
|
124 |
+
image_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds
|
125 |
+
|
126 |
+
#batch_predicted_tokens = phi(image_features)
|
127 |
+
batch_predicted_tokens = image_features
|
128 |
+
predicted_tokens.append(batch_predicted_tokens.cpu())
|
129 |
+
names_list.extend(names)
|
130 |
+
|
131 |
+
predicted_tokens = torch.vstack(predicted_tokens)
|
132 |
+
return predicted_tokens, names_list
|
133 |
+
|
134 |
+
class CustomTensorDataset(Dataset):
|
135 |
+
"""
|
136 |
+
Custom Tensor Dataset which yields image_features and image_names
|
137 |
+
"""
|
138 |
+
|
139 |
+
def __init__(self, images: torch.Tensor, names: torch.Tensor):
|
140 |
+
self.images = images
|
141 |
+
self.names = names
|
142 |
+
|
143 |
+
def __getitem__(self, index) -> dict:
|
144 |
+
return {'image': self.images[index],
|
145 |
+
'image_name': self.names[index]
|
146 |
+
}
|
147 |
+
|
148 |
+
def __len__(self):
|
149 |
+
return len(self.images)
|
150 |
+
|
151 |
+
|
152 |
+
def get_templates():
|
153 |
+
"""
|
154 |
+
Return a list of templates
|
155 |
+
Same templates as in PALAVRA: https://arxiv.org/abs/2204.01694
|
156 |
+
"""
|
157 |
+
return [
|
158 |
+
"This is a photo of a {}",
|
159 |
+
"This photo contains a {}",
|
160 |
+
"A photo of a {}",
|
161 |
+
"This is an illustration of a {}",
|
162 |
+
"This illustration contains a {}",
|
163 |
+
"An illustrations of a {}",
|
164 |
+
"This is a sketch of a {}",
|
165 |
+
"This sketch contains a {}",
|
166 |
+
"A sketch of a {}",
|
167 |
+
"This is a diagram of a {}",
|
168 |
+
"This diagram contains a {}",
|
169 |
+
"A diagram of a {}",
|
170 |
+
"A {}",
|
171 |
+
"We see a {}",
|
172 |
+
"{}",
|
173 |
+
"We see a {} in this photo",
|
174 |
+
"We see a {} in this image",
|
175 |
+
"We see a {} in this illustration",
|
176 |
+
"We see a {} photo",
|
177 |
+
"We see a {} image",
|
178 |
+
"We see a {} illustration",
|
179 |
+
"{} photo",
|
180 |
+
"{} image",
|
181 |
+
"{} illustration",
|
182 |
+
]
|
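The templates are plain `str.format` patterns; for example:
# Illustrative: expand every template for a single concept word
captions = [t.format('dog') for t in get_templates()]
print(captions[0])  # 'This is a photo of a dog'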
validate.py
ADDED
@@ -0,0 +1,650 @@
1 |
+
import json
|
2 |
+
import pickle
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
from typing import List, Dict, Tuple
|
5 |
+
|
6 |
+
import clip
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from clip.model import CLIP
|
11 |
+
from transformers import CLIPTextModelWithProjection
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from data_utils import collate_fn, PROJECT_ROOT, targetpad_transform
|
17 |
+
from loader import FashionIQDataset, CIRRDataset, CIRCODataset
|
18 |
+
from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
|
19 |
+
from models import build_text_encoder, Phi, PIC2WORD
|
20 |
+
from utils import extract_image_features, device, extract_pseudo_tokens_with_phi
|
21 |
+
|
22 |
+
torch.multiprocessing.set_sharing_strategy('file_system')
|
23 |
+
|
24 |
+
|
25 |
+
@torch.no_grad()
|
26 |
+
def fiq_generate_val_predictions(clip_model, relative_val_dataset: Dataset, ref_names_list: List[str],
|
27 |
+
pseudo_tokens: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
|
28 |
+
"""
|
29 |
+
Generates feature predictions for the validation set of Fashion IQ.
|
30 |
+
"""
|
31 |
+
|
32 |
+
# Create data loader
|
33 |
+
relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
|
34 |
+
pin_memory=False, collate_fn=collate_fn, shuffle=False)
|
35 |
+
|
36 |
+
predicted_features_list = []
|
37 |
+
target_names_list = []
|
38 |
+
|
39 |
+
# Compute features
|
40 |
+
for batch in tqdm(relative_val_loader):
|
41 |
+
reference_names = batch['reference_name']
|
42 |
+
target_names = batch['target_name']
|
43 |
+
relative_captions = batch['relative_captions']
|
44 |
+
|
45 |
+
flattened_captions: list = np.array(relative_captions).T.flatten().tolist()
|
46 |
+
input_captions = [
|
47 |
+
f"{flattened_captions[i].strip('.?, ')} and {flattened_captions[i + 1].strip('.?, ')}" for
|
48 |
+
i in range(0, len(flattened_captions), 2)]
|
49 |
+
input_captions_reversed = [
|
50 |
+
f"{flattened_captions[i + 1].strip('.?, ')} and {flattened_captions[i].strip('.?, ')}" for
|
51 |
+
i in range(0, len(flattened_captions), 2)]
|
52 |
+
|
53 |
+
input_captions = [
|
54 |
+
f"a photo of $ that {in_cap}" for in_cap in input_captions]
|
55 |
+
batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
|
56 |
+
tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
|
57 |
+
text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
|
58 |
+
|
59 |
+
input_captions_reversed = [
|
60 |
+
f"a photo of $ that {in_cap}" for in_cap in input_captions_reversed]
|
61 |
+
tokenized_input_captions_reversed = clip.tokenize(input_captions_reversed, context_length=77).to(device)
|
62 |
+
text_features_reversed = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions_reversed,
|
63 |
+
batch_tokens)
|
64 |
+
|
65 |
+
predicted_features = F.normalize((F.normalize(text_features) + F.normalize(text_features_reversed)) / 2)
|
66 |
+
# predicted_features = F.normalize((text_features + text_features_reversed) / 2)
|
67 |
+
|
68 |
+
predicted_features_list.append(predicted_features)
|
69 |
+
target_names_list.extend(target_names)
|
70 |
+
|
71 |
+
predicted_features = torch.vstack(predicted_features_list)
|
72 |
+
return predicted_features, target_names_list
|
73 |
+
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def fiq_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
|
77 |
+
index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
|
78 |
+
-> Dict[str, float]:
|
79 |
+
"""
|
80 |
+
Compute the retrieval metrics on the FashionIQ validation set given the dataset, pseudo tokens and the reference names
|
81 |
+
"""
|
82 |
+
|
83 |
+
# Generate the predicted features
|
84 |
+
predicted_features, target_names = fiq_generate_val_predictions(clip_model, relative_val_dataset, ref_names_list,
|
85 |
+
pseudo_tokens)
|
86 |
+
|
87 |
+
# Move the features to the device
|
88 |
+
index_features = index_features.to(device)
|
89 |
+
predicted_features = predicted_features.to(device)
|
90 |
+
|
91 |
+
# Normalize the features
|
92 |
+
index_features = F.normalize(index_features.float())
|
93 |
+
|
94 |
+
# Compute the distances
|
95 |
+
distances = 1 - predicted_features @ index_features.T
|
96 |
+
sorted_indices = torch.argsort(distances, dim=-1).cpu()
|
97 |
+
sorted_index_names = np.array(index_names)[sorted_indices]
|
98 |
+
|
99 |
+
# Check if the target names are in the top 10 and top 50
|
100 |
+
labels = torch.tensor(
|
101 |
+
sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1))
|
102 |
+
assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
|
103 |
+
|
104 |
+
# Compute the metrics
|
105 |
+
recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
|
106 |
+
recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
|
107 |
+
|
108 |
+
return {'fiq_recall_at10': recall_at10,
|
109 |
+
'fiq_recall_at50': recall_at50}
|
110 |
+
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def fiq_val_retrieval(dataset_path: str, dress_type: str, image_encoder, text_encoder, ref_names_list: List[str],
|
114 |
+
pseudo_tokens: torch.Tensor, preprocess: callable) -> Dict[str, float]:
|
115 |
+
"""
|
116 |
+
Compute the retrieval metrics on the FashionIQ validation set given the pseudo tokens and the reference names
|
117 |
+
"""
|
118 |
+
# Load the model
|
119 |
+
#clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
|
120 |
+
#clip_model = clip_model.float().eval().requires_grad_(False)
|
121 |
+
|
122 |
+
# Extract the index features
|
123 |
+
classic_val_dataset = FashionIQDataset(dataset_path, 'val', [dress_type], 'classic', preprocess)
|
124 |
+
index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)
|
125 |
+
|
126 |
+
# Define the relative dataset
|
127 |
+
relative_val_dataset = FashionIQDataset(dataset_path, 'val', [dress_type], 'relative', preprocess)
|
128 |
+
|
129 |
+
return fiq_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names, ref_names_list,
|
130 |
+
pseudo_tokens)
|
131 |
+
|
132 |
+
|
133 |
+
@torch.no_grad()
|
134 |
+
def cirr_generate_val_predictions(clip_model: CLIPTextModelWithProjection, relative_val_dataset: Dataset, ref_names_list: List[str],
|
135 |
+
pseudo_tokens: torch.Tensor) -> \
|
136 |
+
Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
|
137 |
+
"""
|
138 |
+
Generates feature predictions for the validation set of CIRR
|
139 |
+
"""
|
140 |
+
|
141 |
+
# Define the dataloader
|
142 |
+
relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
|
143 |
+
pin_memory=False, collate_fn=collate_fn)
|
144 |
+
predicted_features_list = []
|
145 |
+
target_names_list = []
|
146 |
+
group_members_list = []
|
147 |
+
reference_names_list = []
|
148 |
+
|
149 |
+
for batch in tqdm(relative_val_loader):
|
150 |
+
reference_names = batch['reference_name']
|
151 |
+
target_names = batch['target_name']
|
152 |
+
relative_captions = batch['relative_caption']
|
153 |
+
group_members = batch['group_members']
|
154 |
+
|
155 |
+
group_members = np.array(group_members).T.tolist()
|
156 |
+
|
157 |
+
input_captions = [
|
158 |
+
f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]
|
159 |
+
|
160 |
+
batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
|
161 |
+
tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
|
162 |
+
text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
|
163 |
+
|
164 |
+
predicted_features = F.normalize(text_features)
|
165 |
+
|
166 |
+
predicted_features_list.append(predicted_features)
|
167 |
+
target_names_list.extend(target_names)
|
168 |
+
group_members_list.extend(group_members)
|
169 |
+
reference_names_list.extend(reference_names)
|
170 |
+
|
171 |
+
predicted_features = torch.vstack(predicted_features_list)
|
172 |
+
|
173 |
+
return predicted_features, reference_names_list, target_names_list, group_members_list
|
174 |
+
|
175 |
+
|
176 |
+
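# Note (illustrative): every relative caption is wrapped in the fixed template
# "a photo of $ that {relative_caption}"; the "$" placeholder is later replaced by the
# reference image's pseudo token inside encode_with_pseudo_tokens_HF. For example, a
# (hypothetical) caption "is sleeveless and red" becomes
# "a photo of $ that is sleeveless and red".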
@torch.no_grad()
def cirr_generate_val_predictions_with_phi(clip_model: CLIPTextModelWithProjection, phi, relative_val_dataset: Dataset, ref_names_list: List[str],
                                           image_features: torch.Tensor) -> \
        Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
    """
    Generates features predictions for the validation set of CIRR
    """

    # Define the dataloader
    relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
                                     pin_memory=False, collate_fn=collate_fn)
    predicted_features_list = []
    target_names_list = []
    group_members_list = []
    reference_names_list = []

    for batch in tqdm(relative_val_loader):
        reference_names = batch['reference_name']
        target_names = batch['target_name']
        relative_captions = batch['relative_caption']
        group_members = batch['group_members']

        group_members = np.array(group_members).T.tolist()

        input_captions = [
            f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]

        # we need to make batch_tokens with selected_image_features
        selected_image_features = torch.vstack([image_features[ref_names_list.index(ref)] for ref in reference_names])
        tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
        context = clip_model.text_model.embeddings.token_embedding(tokenized_input_captions) + clip_model.text_model.embeddings.position_embedding(clip_model.text_model.embeddings.position_ids)
        batch_tokens = phi(selected_image_features, context)
        #batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
        text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)

        predicted_features = F.normalize(text_features)

        predicted_features_list.append(predicted_features)
        target_names_list.extend(target_names)
        group_members_list.extend(group_members)
        reference_names_list.extend(reference_names)

    predicted_features = torch.vstack(predicted_features_list)

    return predicted_features, reference_names_list, target_names_list, group_members_list


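# Note (illustrative): in the *_with_phi variant above, no pre-computed pseudo tokens
# are used. phi maps each reference image's CLIP image feature (together with the
# caption's token/position embeddings passed as context) to a token embedding, and
# that embedding is what replaces the "$" placeholder in the caption.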
@torch.no_grad()
def cirr_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
                             index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
        -> Dict[str, float]:
    """
    Compute the retrieval metrics on the CIRR validation set given the dataset, pseudo tokens and the reference names
    """

    # Generate the predicted features
    predicted_features, reference_names, target_names, group_members = \
        cirr_generate_val_predictions(clip_model, relative_val_dataset, ref_names_list, pseudo_tokens)

    index_features = index_features.to(device)
    predicted_features = predicted_features.to(device)

    # Normalize the index features
    index_features = F.normalize(index_features, dim=-1).float()
    predicted_features = predicted_features.float()

    # Compute the distances and sort the results
    distances = 1 - predicted_features @ index_features.T
    sorted_indices = torch.argsort(distances, dim=-1).cpu()
    sorted_index_names = np.array(index_names)[sorted_indices]

    # Delete the reference image from the results
    reference_mask = torch.tensor(
        sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
    sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
                                                                    sorted_index_names.shape[1] - 1)
    # Compute the ground-truth labels wrt the predictions
    labels = torch.tensor(
        sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))

    # Compute the subset predictions and ground-truth labels
    group_members = np.array(group_members)
    group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
    group_labels = labels[group_mask].reshape(labels.shape[0], -1)

    assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
    assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())

    # Compute the metrics
    recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
    recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
    recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
    recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
    group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
    group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
    group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100

    return {
        'cirr_recall_at1': recall_at1,
        'cirr_recall_at5': recall_at5,
        'cirr_recall_at10': recall_at10,
        'cirr_recall_at50': recall_at50,
        'cirr_group_recall_at1': group_recall_at1,
        'cirr_group_recall_at2': group_recall_at2,
        'cirr_group_recall_at3': group_recall_at3,
    }


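# Note (illustrative): cirr_recall_at{1,5,10,50} are computed over the full index
# (with the reference image removed from each ranking), while
# cirr_group_recall_at{1,2,3} first restrict the ranking to the query's CIRR subset
# (its group_members) and then check whether the target appears in the top K.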
@torch.no_grad()
def cirr_compute_val_metrics_with_phi(relative_val_dataset: Dataset, clip_model: CLIPTextModelWithProjection, phi, index_features: torch.Tensor,
                                      index_names: List[str], ref_names_list: List[str], image_features: torch.Tensor) \
        -> Dict[str, float]:
    """
    Compute the retrieval metrics on the CIRR validation set given the dataset, pseudo tokens and the reference names
    """

    # Generate the predicted features
    predicted_features, reference_names, target_names, group_members = \
        cirr_generate_val_predictions_with_phi(clip_model, phi, relative_val_dataset, ref_names_list, image_features)

    index_features = index_features.to(device)
    predicted_features = predicted_features.to(device)

    # Normalize the index features
    index_features = F.normalize(index_features, dim=-1).float()
    predicted_features = predicted_features.float()

    # Compute the distances and sort the results
    distances = 1 - predicted_features @ index_features.T
    sorted_indices = torch.argsort(distances, dim=-1).cpu()
    sorted_index_names = np.array(index_names)[sorted_indices]

    # Delete the reference image from the results
    reference_mask = torch.tensor(
        sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
    sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
                                                                    sorted_index_names.shape[1] - 1)
    # Compute the ground-truth labels wrt the predictions
    labels = torch.tensor(
        sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))

    # Compute the subset predictions and ground-truth labels
    group_members = np.array(group_members)
    group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
    group_labels = labels[group_mask].reshape(labels.shape[0], -1)

    assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
    assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())

    # Compute the metrics
    recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
    recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
    recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
    recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
    group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
    group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
    group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100

    return {
        'cirr_recall_at1': recall_at1,
        'cirr_recall_at5': recall_at5,
        'cirr_recall_at10': recall_at10,
        'cirr_recall_at50': recall_at50,
        'cirr_group_recall_at1': group_recall_at1,
        'cirr_group_recall_at2': group_recall_at2,
        'cirr_group_recall_at3': group_recall_at3,
    }


@torch.no_grad()
def cirr_val_retrieval(dataset_path: str, image_encoder, text_encoder, ref_names_list: list, pseudo_tokens: torch.Tensor,
                       preprocess: callable) -> Dict[str, float]:
    """
    Compute the retrieval metrics on the CIRR validation set given the pseudo tokens and the reference names
    """

    # Load the model
    #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
    #clip_model = clip_model.float().eval().requires_grad_(False)

    # Extract the index features
    classic_val_dataset = CIRRDataset(dataset_path, 'val', 'classic', preprocess)
    index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)

    # Define the relative validation dataset
    relative_val_dataset = CIRRDataset(dataset_path, 'val', 'relative', preprocess)

    return cirr_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names,
                                    ref_names_list, pseudo_tokens)


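# Usage sketch (illustrative; the path and variables are placeholders prepared as in main() below):
#   cirr_metrics = cirr_val_retrieval('./CIRR', image_encoder, text_encoder,
#                                     ref_names_list, pseudo_tokens, preprocess)
#   print(cirr_metrics['cirr_recall_at1'])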
@torch.no_grad()
def circo_generate_val_predictions(clip_model, relative_val_dataset: Dataset, ref_names_list: List[str],
                                   pseudo_tokens: torch.Tensor) -> Tuple[torch.Tensor, List[str], list]:
    """
    Generates features predictions for the validation set of CIRCO
    """

    # Create the data loader
    relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
                                     pin_memory=False, collate_fn=collate_fn, shuffle=False)

    predicted_features_list = []
    target_names_list = []
    gts_img_ids_list = []

    # Compute the features
    for batch in tqdm(relative_val_loader):
        reference_names = batch['reference_name']
        target_names = batch['target_name']
        relative_captions = batch['relative_caption']
        gt_img_ids = batch['gt_img_ids']

        gt_img_ids = np.array(gt_img_ids).T.tolist()
        input_captions = [f"a photo of $ that {caption}" for caption in relative_captions]
        batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
        tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
        text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
        predicted_features = F.normalize(text_features)

        predicted_features_list.append(predicted_features)
        target_names_list.extend(target_names)
        gts_img_ids_list.extend(gt_img_ids)

    predicted_features = torch.vstack(predicted_features_list)

    return predicted_features, target_names_list, gts_img_ids_list


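# Note (illustrative): each CIRCO query can have a different number of ground-truth
# images, so the batched 'gt_img_ids' lists come back padded with empty strings by the
# collate step; circo_compute_val_metrics below strips that padding before scoring.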
@torch.no_grad()
def circo_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
                              index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
        -> Dict[str, float]:
    """
    Compute the retrieval metrics on the CIRCO validation set given the dataset, pseudo tokens and the reference names
    """

    # Generate the predicted features
    predicted_features, target_names, gts_img_ids = circo_generate_val_predictions(clip_model, relative_val_dataset,
                                                                                   ref_names_list, pseudo_tokens)
    ap_at5 = []
    ap_at10 = []
    ap_at25 = []
    ap_at50 = []

    recall_at5 = []
    recall_at10 = []
    recall_at25 = []
    recall_at50 = []

    # Move the features to the device
    index_features = index_features.to(device)
    predicted_features = predicted_features.to(device)

    # Normalize the features
    index_features = F.normalize(index_features.float())

    for predicted_feature, target_name, gt_img_ids in tqdm(zip(predicted_features, target_names, gts_img_ids)):
        gt_img_ids = np.array(gt_img_ids)[
            np.array(gt_img_ids) != '']  # remove trailing empty strings added for collate_fn
        similarity = predicted_feature @ index_features.T
        sorted_indices = torch.topk(similarity, dim=-1, k=50).indices.cpu()
        sorted_index_names = np.array(index_names)[sorted_indices]
        map_labels = torch.tensor(np.isin(sorted_index_names, gt_img_ids), dtype=torch.uint8)
        precisions = torch.cumsum(map_labels, dim=0) * map_labels  # Consider only positions corresponding to GTs
        precisions = precisions / torch.arange(1, map_labels.shape[0] + 1)  # Compute precision for each position

        ap_at5.append(float(torch.sum(precisions[:5]) / min(len(gt_img_ids), 5)))
        ap_at10.append(float(torch.sum(precisions[:10]) / min(len(gt_img_ids), 10)))
        ap_at25.append(float(torch.sum(precisions[:25]) / min(len(gt_img_ids), 25)))
        ap_at50.append(float(torch.sum(precisions[:50]) / min(len(gt_img_ids), 50)))

        assert target_name == gt_img_ids[0], f"Target name not in GTs {target_name} {gt_img_ids}"
        single_gt_labels = torch.tensor(sorted_index_names == target_name)
        recall_at5.append(float(torch.sum(single_gt_labels[:5])))
        recall_at10.append(float(torch.sum(single_gt_labels[:10])))
        recall_at25.append(float(torch.sum(single_gt_labels[:25])))
        recall_at50.append(float(torch.sum(single_gt_labels[:50])))

    map_at5 = np.mean(ap_at5) * 100
    map_at10 = np.mean(ap_at10) * 100
    map_at25 = np.mean(ap_at25) * 100
    map_at50 = np.mean(ap_at50) * 100
    recall_at5 = np.mean(recall_at5) * 100
    recall_at10 = np.mean(recall_at10) * 100
    recall_at25 = np.mean(recall_at25) * 100
    recall_at50 = np.mean(recall_at50) * 100

    return {
        'circo_map_at5': map_at5,
        'circo_map_at10': map_at10,
        'circo_map_at25': map_at25,
        'circo_map_at50': map_at50,
        'circo_recall_at5': recall_at5,
        'circo_recall_at10': recall_at10,
        'circo_recall_at25': recall_at25,
        'circo_recall_at50': recall_at50,
    }


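# Worked example (illustrative numbers): if a query has 2 ground-truth images ranked at
# positions 1 and 4, the per-position precisions kept above are 1/1 and 2/4, so
#   AP@5 = (1.0 + 0.5) / min(2, 5) = 0.75
# and circo_map_at5 is the mean of these AP@5 values over all validation queries.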
@torch.no_grad()
def circo_val_retrieval(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str], pseudo_tokens: torch.Tensor,
                        preprocess: callable) -> Dict[str, float]:
    """
    Compute the retrieval metrics on the CIRCO validation set given the pseudo tokens and the reference names
    """
    # Load the model
    #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
    #clip_model = clip_model.float().eval().requires_grad_(False)

    # Extract the index features
    classic_val_dataset = CIRCODataset(dataset_path, 'val', 'classic', preprocess)
    index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)

    # Define the relative validation dataset
    relative_val_dataset = CIRCODataset(dataset_path, 'val', 'relative', preprocess)

    return circo_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names, ref_names_list,
                                     pseudo_tokens)


def main():
    parser = ArgumentParser()
    parser.add_argument("--exp-name", type=str, help="Experiment to evaluate")
    parser.add_argument("--eval-type", type=str, choices=['oti', 'phi', 'searle', 'searle-xl', 'pic2word'], required=True,
                        help="If 'oti' evaluate directly using the inverted oti pseudo tokens, "
                             "if 'phi' predicts the pseudo tokens using the phi network, "
                             "if 'searle' uses the pre-trained SEARLE model to predict the pseudo tokens, "
                             "if 'searle-xl' uses the pre-trained SEARLE-XL model to predict the pseudo tokens, "
                             "if 'pic2word' uses the pre-trained PIC2WORD model to predict the pseudo tokens"
                        )
    parser.add_argument("--dataset", type=str, required=True, choices=['cirr', 'fashioniq', 'circo'],
                        help="Dataset to use")
    parser.add_argument("--dataset-path", type=str, help="Path to the dataset", required=True)

    parser.add_argument("--preprocess-type", default="clip", type=str, choices=['clip', 'targetpad'],
                        help="Preprocess pipeline to use")
    parser.add_argument("--phi-checkpoint-name", type=str,
                        help="Phi checkpoint to use, needed when using phi, e.g. 'phi_20.pt'")
    parser.add_argument("--clip_model_name", default="giga", type=str)
    parser.add_argument("--cache_dir", default="./hf_models", type=str)

    parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")

    args = parser.parse_args()

    #if args.eval_type in ['phi', 'oti'] and args.exp_name is None:
    #    raise ValueError("Experiment name is required when using phi or oti evaluation type")
    if args.eval_type == 'phi' and args.phi_checkpoint_name is None:
        raise ValueError("Phi checkpoint name is required when using phi evaluation type")

    if args.eval_type == 'oti':
        experiment_path = PROJECT_ROOT / 'data' / "oti_pseudo_tokens" / args.dataset.lower() / 'val' / args.exp_name
        if not experiment_path.exists():
            raise ValueError(f"Experiment {args.exp_name} not found")

        with open(experiment_path / 'hyperparameters.json') as f:
            hyperparameters = json.load(f)

        pseudo_tokens = torch.load(experiment_path / 'ema_oti_pseudo_tokens.pt', map_location=device)
        with open(experiment_path / 'image_names.pkl', 'rb') as f:
            ref_names_list = pickle.load(f)

        clip_model_name = hyperparameters['clip_model_name']
        clip_model, clip_preprocess = clip.load(clip_model_name, device='cpu', jit=False)

        if args.preprocess_type == 'targetpad':
            print('Target pad preprocess pipeline is used')
            preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
        elif args.preprocess_type == 'clip':
            print('CLIP preprocess pipeline is used')
            preprocess = clip_preprocess
        else:
            raise ValueError("Preprocess type not supported")

    elif args.eval_type in ['phi', 'searle', 'searle-xl', 'pic2word']:
        if args.eval_type == 'phi':
            args.mixed_precision = 'fp16'
            image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)

            phi = Phi(input_dim=text_encoder.config.projection_dim,
                      hidden_dim=text_encoder.config.projection_dim * 4,
                      output_dim=text_encoder.config.hidden_size, dropout=0.5).to(device)

            phi.load_state_dict(
                torch.load(args.phi_checkpoint_name, map_location=device)[phi.__class__.__name__])

            phi = phi.eval()

        elif args.eval_type == 'pic2word':
            args.mixed_precision = 'fp16'
            image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
            phi = PIC2WORD(embed_dim=text_encoder.config.projection_dim,
                           output_dim=text_encoder.config.hidden_size,
                           ).to(device)
            sd = torch.load(args.phi_checkpoint_name, map_location=device)['state_dict_img2text']
            sd = {k[len('module.'):]: v for k, v in sd.items()}
            phi.load_state_dict(sd)
            phi = phi.eval()

        else:  # searle or searle-xl
            if args.eval_type == 'searle':
                clip_model_name = 'ViT-B/32'
            else:  # args.eval_type == 'searle-xl':
                clip_model_name = 'ViT-L/14'
            phi, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
                                    backbone=clip_model_name)
            phi = phi.to(device).eval()
            clip_model, clip_preprocess = clip.load(clip_model_name, device=device, jit=False)

        if args.preprocess_type == 'targetpad':
            print('Target pad preprocess pipeline is used')
            preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
        elif args.preprocess_type == 'clip':
            print('CLIP preprocess pipeline is used')
            preprocess = clip_preprocess
        else:
            raise ValueError("Preprocess type not supported")

        if args.dataset.lower() == 'fashioniq':
            relative_val_dataset = FashionIQDataset(args.dataset_path, 'val', ['dress', 'toptee', 'shirt'],
                                                    'relative', preprocess, no_duplicates=True)
        elif args.dataset.lower() == 'cirr':
            relative_val_dataset = CIRRDataset(args.dataset_path, 'val', 'relative', preprocess,
                                               no_duplicates=True)
        elif args.dataset.lower() == 'circo':
            relative_val_dataset = CIRCODataset(args.dataset_path, 'val', 'relative', preprocess)
        else:
            raise ValueError("Dataset not supported")

        #clip_model = clip_model.float().to(device)
        image_encoder = image_encoder.float().to(device)
        text_encoder = text_encoder.float().to(device)
        pseudo_tokens, ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi, relative_val_dataset, args)
        pseudo_tokens = pseudo_tokens.to(device)
    else:
        raise ValueError("Eval type not supported")

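    # Note (illustrative): for the 'phi', 'pic2word', 'searle' and 'searle-xl' paths the
    # pseudo tokens are predicted on the fly by extract_pseudo_tokens_with_phi above;
    # only the 'oti' path loads pre-inverted tokens from disk.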
print(f"Eval type = {args.eval_type} \t exp name = {args.exp_name} \t")
|
618 |
+
if args.dataset.lower() == 'fashioniq':
|
619 |
+
recalls_at10 = []
|
620 |
+
recalls_at50 = []
|
621 |
+
for dress_type in ['shirt', 'dress', 'toptee']:
|
622 |
+
fiq_metrics = fiq_val_retrieval(args.dataset_path, dress_type, image_encoder, text_encoder, ref_names_list,
|
623 |
+
pseudo_tokens, preprocess)
|
624 |
+
recalls_at10.append(fiq_metrics['fiq_recall_at10'])
|
625 |
+
recalls_at50.append(fiq_metrics['fiq_recall_at50'])
|
626 |
+
|
627 |
+
for k, v in fiq_metrics.items():
|
628 |
+
print(f"{dress_type}_{k} = {v:.2f}")
|
629 |
+
print("\n")
|
630 |
+
|
631 |
+
print(f"average_fiq_recall_at10 = {np.mean(recalls_at10):.2f}")
|
632 |
+
print(f"average_fiq_recall_at50 = {np.mean(recalls_at50):.2f}")
|
633 |
+
|
634 |
+
elif args.dataset.lower() == 'cirr':
|
635 |
+
cirr_metrics = cirr_val_retrieval(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
|
636 |
+
preprocess)
|
637 |
+
|
638 |
+
for k, v in cirr_metrics.items():
|
639 |
+
print(f"{k} = {v:.2f}")
|
640 |
+
|
641 |
+
elif args.dataset.lower() == 'circo':
|
642 |
+
circo_metrics = circo_val_retrieval(args.dataset_path, clip_model_name, ref_names_list, pseudo_tokens,
|
643 |
+
preprocess)
|
644 |
+
|
645 |
+
for k, v in circo_metrics.items():
|
646 |
+
print(f"{k} = {v:.2f}")
|
647 |
+
|
648 |
+
|
649 |
+
if __name__ == '__main__':
|
650 |
+
main()
|
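# Example invocation (illustrative; paths and checkpoint names are placeholders):
#   python validate.py --eval-type phi --dataset cirr --dataset-path ./CIRR \
#       --phi-checkpoint-name ./checkpoints/phi_best.pt --clip_model_name large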