Geonmo committed on
Commit
cacafc1
•
1 Parent(s): 3f0fd05

initial commit

Browse files
Files changed (12)
  1. README.md +1 -1
  2. app.py +220 -0
  3. data_utils.py +67 -0
  4. encode_with_pseudo_tokens.py +54 -0
  5. eval_templates.py +70 -0
  6. generate_test_submission.py +363 -0
  7. loader.py +632 -0
  8. models.py +192 -0
  9. requirements.txt +8 -0
  10. train_phi.py +317 -0
  11. utils.py +182 -0
  12. validate.py +650 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: LinCIR
3
- emoji: 🐨
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
 
1
  ---
2
  title: LinCIR
3
+ emoji: 📚
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,220 @@
1
+ '''
2
+ LinCIR
3
+ Copyright (c) 2023-present NAVER Corp.
4
+ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
5
+ '''
6
+ import os
7
+ import time
8
+ from argparse import ArgumentParser
9
+
10
+ import numpy as np
11
+ import torch
12
+ import gradio as gr
13
+ from clip_retrieval.clip_client import ClipClient
14
+
15
+ from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
16
+ from models import build_text_encoder, Phi, PIC2WORD
17
+
18
+ import transformers
19
+ from huggingface_hub import hf_hub_url, cached_download
20
+
21
+
22
+ def parse_args():
23
+ parser = ArgumentParser()
24
+ parser.add_argument("--lincir_ckpt_path", default=None, type=str,
25
+ help="Path to the trained LinCIR (phi) checkpoint")
26
+ parser.add_argument("--pic2word_ckpt_path", default=None, type=str)
27
+ parser.add_argument("--cache_dir", default="./hf_models", type=str,
28
+ help="Path to model cache folder")
29
+ parser.add_argument("--clip_model_name", default="large", type=str,
30
+ help="CLIP model to use, e.g. 'large', 'huge', 'giga'")
31
+ parser.add_argument("--mixed_precision", default="fp16", type=str)
32
+ parser.add_argument("--test_fps", action="store_true")
33
+ args = parser.parse_args()
34
+ return args
35
+
36
+
37
+ def load_models(args):
38
+ if torch.cuda.is_available():
39
+ device = 'cuda:0'
40
+ dtype = torch.float16
41
+ else:
42
+ device = 'cpu'
43
+ dtype = torch.float32
44
+
45
+ clip_vision_model, clip_preprocess, clip_text_model, tokenizer = build_text_encoder(args)
46
+
47
+ tokenizer.add_special_tokens({'additional_special_tokens':["[$]"]}) # 49408
48
+
49
+ # ours
50
+ phi = Phi(input_dim=clip_text_model.config.projection_dim,
51
+ hidden_dim=clip_text_model.config.projection_dim * 4,
52
+ output_dim=clip_text_model.config.hidden_size, dropout=0.0)
53
+ phi.eval()
54
+
55
+ # searle
56
+ phi_searle, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
57
+ backbone='ViT-L/14')
58
+ phi_searle.eval()
59
+
60
+ # pic2word
61
+ phi_pic2word = PIC2WORD(embed_dim=clip_text_model.config.projection_dim,
62
+ output_dim=clip_text_model.config.hidden_size)
63
+ phi_pic2word.eval()
64
+
65
+ clip_vision_model.to(device, dtype=dtype)
66
+ clip_text_model.to(device, dtype=dtype)
67
+
68
+ if not args.test_fps:
69
+ # download and load the pretrained state dicts
70
+ if not os.path.exists('./pretrained_models/lincir_large.pt'):
71
+ model_file_url = hf_hub_url(repo_id='navervision/zeroshot-cir-models', filename='lincir_large.pt')
72
+ cached_download(model_file_url, cache_dir='./pretrained_models', force_filename='lincir_large.pt')
73
+ state_dict = torch.load('./pretrained_models/lincir_large.pt', map_location=device)
74
+ phi.load_state_dict(state_dict['Phi'])
75
+
76
+ if not os.path.exists('./pretrained_models/pic2word_large.pt'):
77
+ model_file_url = hf_hub_url(repo_id='navervision/zeroshot-cir-models', filename='pic2word_large.pt')
78
+ cached_download(model_file_url, cache_dir='./pretrained_models', force_filename='pic2word_large.pt')
79
+ sd = torch.load('./pretrained_models/pic2word_large.pt', map_location=device)['state_dict_img2text']
80
+ sd = {k[len('module.'):]: v for k, v in sd.items()}
81
+ phi_pic2word.load_state_dict(sd)
82
+
83
+ phi.to(device, dtype=dtype)
84
+ phi_searle.to(device, dtype=dtype)
85
+ phi_pic2word.to(device, dtype=dtype)
86
+
87
+ decoder = None
88
+
89
+ return {'clip_vision_model': clip_vision_model,
90
+ 'clip_preprocess': clip_preprocess,
91
+ 'clip_text_model': clip_text_model,
92
+ 'tokenizer': tokenizer,
93
+ 'phi': phi,
94
+ 'phi_searle': phi_searle,
95
+ 'phi_pic2word': phi_pic2word,
96
+ 'decoder': decoder,
97
+ 'device': device,
98
+ 'dtype': dtype,
99
+ 'clip_model_name': args.clip_model_name,
100
+ }
101
+
102
+
103
+ def predict(images, input_text, model_name):
104
+ start_time = time.time()
105
+ input_images = model_dict['clip_preprocess'](images, return_tensors='pt')['pixel_values'].to(model_dict['device'])
106
+ input_text = input_text.replace('$', '[$]')
107
+ input_tokens = model_dict['tokenizer'](text=input_text, return_tensors='pt', padding='max_length', truncation=True)['input_ids'].to(model_dict['device'])
108
+ input_tokens = torch.where(input_tokens == 49408,
109
+ torch.ones_like(input_tokens) * 259,
110
+ input_tokens)
111
+ image_features = model_dict['clip_vision_model'](pixel_values=input_images.to(model_dict['dtype'])).image_embeds
112
+ clip_image_time = time.time() - start_time
113
+
114
+ start_time = time.time()
115
+ if model_name == 'lincir':
116
+ estimated_token_embeddings = model_dict['phi'](image_features)
117
+ elif model_name == 'searle':
118
+ estimated_token_embeddings = model_dict['phi_searle'](image_features)
119
+ else: # model_name == 'pic2word'
120
+ estimated_token_embeddings = model_dict['phi_pic2word'](image_features)
121
+ phi_time = time.time() - start_time
122
+
123
+ start_time = time.time()
124
+ text_embeddings, text_last_hidden_states = encode_with_pseudo_tokens_HF(model_dict['clip_text_model'], input_tokens, estimated_token_embeddings, return_last_states=True)
125
+ clip_text_time = time.time() - start_time
126
+
127
+ start_time = time.time()
128
+ results = client.query(embedding_input=text_embeddings[0].tolist())
129
+ retrieval_time = time.time() - start_time
130
+
131
+ output = ''
132
+
133
+ for idx, result in enumerate(results):
134
+ image_url = result['url']
135
+ output += f'![image]({image_url})\n'
136
+
137
+ time_output = {'CLIP visual extractor': clip_image_time,
138
+ 'CLIP textual extractor': clip_text_time,
139
+ 'Phi projection': phi_time,
140
+ 'CLIP retrieval': retrieval_time,
141
+ }
142
+ setup_output = {'device': model_dict['device'],
143
+ 'dtype': model_dict['dtype'],
144
+ 'Phi': model_name,
145
+ 'CLIP': model_dict['clip_model_name'],
146
+ }
147
+
148
+ return {'time': time_output, 'setup': setup_output}, output
149
+
150
+
151
+ def test_fps(batch_size=1):
152
+ dummy_images = torch.rand([batch_size, 3, 224, 224])
153
+
154
+ todo_list = ['phi', 'phi_pic2word']
155
+
156
+ input_tokens = model_dict['tokenizer'](text=['a photo of $1 with flowers'] * batch_size, return_tensors='pt', padding='max_length', truncation=True)['input_ids'].to(model_dict['device'])
157
+ input_tokens = torch.where(input_tokens == 49409,
158
+ torch.ones_like(input_tokens) * 259,
159
+ input_tokens)
160
+
161
+ for model_name in todo_list:
162
+ time_array = []
163
+ n_repeat = 100
164
+ for _ in range(n_repeat):
165
+ start_time = time.time()
166
+ image_features = model_dict['clip_vision_model'](pixel_values=dummy_images.to(model_dict['clip_vision_model'].device, dtype=model_dict['clip_vision_model'].dtype)).image_embeds
167
+ token_embeddings = model_dict[model_name](image_features)
168
+ text_embeddings = encode_with_pseudo_tokens_HF(model_dict['clip_text_model'], input_tokens, token_embeddings)
169
+ end_time = time.time()
170
+ if _ > 5:
171
+ time_array.append(end_time - start_time)
172
+ print(f"{model_name}: {np.mean(time_array):.4f}")
173
+
174
+
175
+ if __name__ == '__main__':
176
+ args = parse_args()
177
+
178
+ global model_dict, client
179
+
180
+ model_dict = load_models(args)
181
+
182
+ if args.test_fps:
183
+ # check FPS of all models.
184
+ test_fps(1)
185
+ exit()
186
+
187
+
188
+ client = ClipClient(url="https://knn.laion.ai/knn-service",
189
+ indice_name="laion5B-H-14" if args.clip_model_name == "huge" else "laion5B-L-14",
190
+ )
191
+
192
+ title = 'Zeroshot CIR demo'
193
+
194
+ md_title = f'''# {title}
195
+ [LinCIR](https://arxiv.org/abs/2312.01998): Language-only Training of Zero-shot Composed Image Retrieval
196
+ [SEARLE](https://arxiv.org/abs/2303.15247): Zero-shot Composed Image Retrieval with Textual Inversion
197
+ [Pic2Word](https://arxiv.org/abs/2302.03084): Mapping Pictures to Words for Zero-shot Composed Image Retrieval
198
+
199
+ The k-NN index used for the retrieval results is built over the entire LAION-5B image set. This is made possible thanks to the great work of [rom1504](https://github.com/rom1504/clip-retrieval).
200
+ '''
201
+
202
+ with gr.Blocks(title=title) as demo:
203
+ gr.Markdown(md_title)
204
+ with gr.Row():
205
+ with gr.Column():
206
+ with gr.Row():
207
+ image_source = gr.Image(type='pil', label='image1')
208
+ model_name = gr.Radio(['lincir', 'searle', 'pic2word'], label='Phi model', value='lincir')
209
+ text_input = gr.Textbox(value='', label='Input text guidance. Special token is $')
210
+ submit_button = gr.Button('Submit')
211
+ gr.Examples([["example1.jpg", "$, pencil sketch", 'lincir']], inputs=[image_source, text_input, model_name])
212
+ with gr.Column():
213
+ json_output = gr.JSON(label='Processing time')
214
+ md_output = gr.Markdown(label='Output')
215
+
216
+ submit_button.click(predict, inputs=[image_source, text_input, model_name], outputs=[json_output, md_output])
217
+
218
+ demo.queue()
219
+
220
+ demo.launch()
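For context on the retrieval backend wired up above: the demo sends a composed text embedding to the public LAION-5B k-NN service through clip-retrieval. The minimal sketch below (not part of the commit) queries the same service and index with a plain text prompt instead of the composed embedding that `predict()` sends via `embedding_input`; the prompt string is only an illustration.

```python
# Minimal sketch: query the same LAION-5B knn service that app.py uses, with a
# plain text prompt instead of the composed embedding sent via embedding_input.
from clip_retrieval.clip_client import ClipClient

client = ClipClient(url="https://knn.laion.ai/knn-service",
                    indice_name="laion5B-L-14")  # same index app.py picks for the 'large' CLIP model
results = client.query(text="a watercolor painting of a koala")  # illustrative prompt
for result in results[:3]:
    print(result["url"], result.get("similarity"))
```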
data_utils.py ADDED
@@ -0,0 +1,67 @@
1
+ from pathlib import Path
2
+
3
+ import PIL
4
+ import torch
5
+ import torchvision.transforms.functional as FT
6
+ from torch.utils.data import Dataset
7
+ from torchvision.transforms import Compose, CenterCrop, ToTensor, Normalize, Resize
8
+ from torchvision.transforms import InterpolationMode
9
+
10
+ PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
11
+
12
+
13
+ def _convert_image_to_rgb(image):
14
+ return image.convert("RGB")
15
+
16
+
17
+ def collate_fn(batch):
18
+ '''
19
+ Function which discards None images in a batch when using a torch DataLoader
20
+ :param batch: input_batch
21
+ :return: output_batch = input_batch - None_values
22
+ '''
23
+ batch = list(filter(lambda x: x is not None, batch))
24
+ return torch.utils.data.dataloader.default_collate(batch)
25
+
26
+
27
+ class TargetPad:
28
+ """
29
+ If an image aspect ratio is above a target ratio, pad the image to match the target ratio.
30
+ For more details see Baldrati et al. 'Effective conditioned and composed image retrieval combining clip-based features.' Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022).
31
+ """
32
+
33
+ def __init__(self, target_ratio: float, size: int):
34
+ """
35
+ :param target_ratio: target ratio
36
+ :param size: preprocessing output dimension
37
+ """
38
+ self.size = size
39
+ self.target_ratio = target_ratio
40
+
41
+ def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
42
+ w, h = image.size
43
+ actual_ratio = max(w, h) / min(w, h)
44
+ if actual_ratio < self.target_ratio: # check if the ratio is above or below the target ratio
45
+ return image
46
+ scaled_max_wh = max(w, h) / self.target_ratio # rescale the pad to match the target ratio
47
+ hp = max(int((scaled_max_wh - w) / 2), 0)
48
+ vp = max(int((scaled_max_wh - h) / 2), 0)
49
+ padding = [hp, vp, hp, vp]
50
+ return FT.pad(image, padding, 0, 'constant')
51
+
52
+
53
+ def targetpad_transform(target_ratio: float, dim: int) -> Compose:
54
+ """
55
+ CLIP-like preprocessing transform computed after using TargetPad pad
56
+ :param target_ratio: target ratio for TargetPad
57
+ :param dim: image output dimension
58
+ :return: CLIP-like torchvision Compose transform
59
+ """
60
+ return Compose([
61
+ TargetPad(target_ratio, dim),
62
+ Resize(dim, interpolation=InterpolationMode.BICUBIC),
63
+ CenterCrop(dim),
64
+ _convert_image_to_rgb,
65
+ ToTensor(),
66
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
67
+ ])
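As a quick usage note, the transform returned by `targetpad_transform` can be applied directly to a PIL image. A minimal sketch follows; the file name is a placeholder, and the 1.25 ratio simply matches the value used in this repo's evaluation scripts.

```python
# Minimal usage sketch: build the TargetPad-based CLIP preprocessing defined above
# and apply it to a single image. The file name is a placeholder.
import PIL.Image
from data_utils import targetpad_transform

preprocess = targetpad_transform(target_ratio=1.25, dim=224)  # 1.25 is the ratio used in the eval scripts
image = PIL.Image.open("example1.jpg")                        # example1.jpg is the demo image referenced in app.py
tensor = preprocess(image)                                    # torch.Tensor of shape [3, 224, 224]
print(tensor.shape)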
encode_with_pseudo_tokens.py ADDED
@@ -0,0 +1,54 @@
1
+ '''
2
+ LinCIR
3
+ Copyright (c) 2023-present NAVER Corp.
4
+ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
5
+ '''
6
+ import torch
7
+ from clip.model import CLIP
8
+ from transformers import CLIPTextModelWithProjection
9
+
10
+
11
+ def _make_causal_mask(
12
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
13
+ ):
14
+ """
15
+ Make causal mask used for bi-directional self-attention.
16
+ Copy-paste from https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/clip/modeling_clip.py#L679-L693
17
+ """
18
+ bsz, tgt_len = input_ids_shape
19
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
20
+ mask_cond = torch.arange(mask.size(-1), device=device)
21
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
22
+ mask = mask.to(dtype)
23
+
24
+ if past_key_values_length > 0:
25
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
26
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
27
+
28
+
29
+ def encode_with_pseudo_tokens_HF(clip_model: CLIPTextModelWithProjection, text: torch.Tensor, pseudo_tokens: torch.Tensor,
30
+ num_tokens=1, return_last_states=False) -> torch.Tensor:
31
+ x = clip_model.text_model.embeddings.token_embedding(text).type(clip_model.dtype) # [batch_size, n_ctx, d_model]
32
+ x = torch.where(text.unsqueeze(-1) == 259,
33
+ pseudo_tokens.unsqueeze(1).type(clip_model.dtype),
34
+ x)
35
+ x = x + clip_model.text_model.embeddings.position_embedding(clip_model.text_model.embeddings.position_ids)
36
+ _causal_attention_mask = _make_causal_mask(text.shape, x.dtype, device=x.device)
37
+ x = clip_model.text_model.encoder(inputs_embeds=x,
38
+ attention_mask=None,
39
+ causal_attention_mask=_causal_attention_mask,
40
+ output_attentions=False,
41
+ output_hidden_states=False,
42
+ return_dict=False)
43
+ x = x[0]
44
+ x_last = clip_model.text_model.final_layer_norm(x)
45
+ x = x_last[torch.arange(x_last.shape[0], device=x_last.device),
46
+ text.to(dtype=torch.int, device=x_last.device).argmax(dim=-1),
47
+ ]
48
+ if hasattr(clip_model, 'text_projection'):
49
+ x = clip_model.text_projection(x)
50
+
51
+ if return_last_states:
52
+ return x, x_last
53
+ else:
54
+ return x
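To illustrate how a pseudo token is spliced into the text encoder, here is a minimal sketch. It assumes the openai/clip-vit-large-patch14 checkpoint stands in for the 'large' backbone that `models.build_text_encoder` actually loads, and it uses a random vector in place of the Phi output.

```python
# Minimal sketch: splice a (random) pseudo token into the CLIP text tower via
# encode_with_pseudo_tokens_HF. The checkpoint name is an assumption standing in
# for the 'large' backbone built by models.build_text_encoder.
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF

name = "openai/clip-vit-large-patch14"
text_model = CLIPTextModelWithProjection.from_pretrained(name).eval()
tokenizer = CLIPTokenizer.from_pretrained(name)
tokenizer.add_special_tokens({'additional_special_tokens': ['[$]']})  # id 49408, as in app.py

ids = tokenizer(text='a photo of [$] that is red', return_tensors='pt',
                padding='max_length', truncation=True)['input_ids']
ids = torch.where(ids == 49408, torch.full_like(ids, 259), ids)  # 259 is the id the function replaces

pseudo_token = torch.randn(1, text_model.config.hidden_size)  # normally Phi(image_features)
with torch.no_grad():
    text_features = encode_with_pseudo_tokens_HF(text_model, ids, pseudo_token)
print(text_features.shape)  # [1, projection_dim]
```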
eval_templates.py ADDED
@@ -0,0 +1,70 @@
1
+ '''
2
+ LinCIR
3
+ Copyright (c) 2023-present NAVER Corp.
4
+ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
5
+ '''
6
+ templates = [
7
+ lambda caption: f"a photo of $ that {caption}",
8
+ lambda caption: f"$ that {caption}",
9
+ lambda caption: f"$ with {caption}",
10
+ lambda caption: f"$ , {caption}",
11
+ lambda caption: f"$ adapted to {caption}",
12
+ lambda caption: f"$ modified by {caption}",
13
+ lambda caption: f"$ in response to {caption}",
14
+ lambda caption: f"$ transformed by {caption}",
15
+ lambda caption: f"$ influenced by {caption}",
16
+ lambda caption: f"Retrieval of $ using feedback {caption}",
17
+ lambda caption: f"$ guided by {caption}",
18
+ lambda caption: f"$ adjusted to {caption}",
19
+ lambda caption: f"$ in alignment with {caption}",
20
+ lambda caption: f"$ in correspondence to {caption}",
21
+ lambda caption: f"$ refined with {caption}",
22
+ lambda caption: f"$ as directed by {caption}",
23
+ lambda caption: f"$ evolved from {caption}",
24
+ lambda caption: f"$ inspired by {caption}",
25
+ lambda caption: f"$ with adjustments from {caption}",
26
+ lambda caption: f"$ in consideration of {caption}",
27
+ lambda caption: f"$ , taking into account {caption}",
28
+ lambda caption: f"$ as influenced by the query {caption}",
29
+ lambda caption: f"$ reshaped by {caption}",
30
+ lambda caption: f"$ curated based on {caption}",
31
+ lambda caption: f"$ showcasing {caption}",
32
+ lambda caption: f"An instance of $ where {caption}",
33
+ lambda caption: f"$ highlighting {caption}",
34
+ lambda caption: f"A depiction of $ exhibiting {caption}",
35
+ lambda caption: f"$ as exemplified by {caption}",
36
+ lambda caption: f"$ demonstrating {caption}",
37
+ lambda caption: f"An illustration of $ portraying {caption}",
38
+ lambda caption: f"$ in the context of {caption}",
39
+ lambda caption: f"$ as influenced by {caption}",
40
+ lambda caption: f"$ characterized by {caption}",
41
+ lambda caption: f"$ : An exploration of {caption}",
42
+ lambda caption: f"A presentation of $ underlined by {caption}",
43
+ lambda caption: f"A manifestation of $ reflecting {caption}",
44
+ lambda caption: f"$ in light of {caption}",
45
+ lambda caption: f"$ as a testament to {caption}",
46
+ lambda caption: f"$ intertwined with {caption}",
47
+ lambda caption: f"$ complemented by {caption}",
48
+ lambda caption: f"$ juxtaposed with {caption}",
49
+ lambda caption: f"A representation of $ in relation to {caption}",
50
+ lambda caption: f"$ that {caption}",
51
+ lambda caption: f"$ which {caption}",
52
+ lambda caption: f"$ where it {caption}",
53
+ lambda caption: f"Discover $ that {caption}",
54
+ lambda caption: f"Retrieve $ that {caption}",
55
+ lambda caption: f"Search for $ that {caption}",
56
+ lambda caption: f"Identify $ which {caption}",
57
+ lambda caption: f"Highlight $ that {caption}",
58
+ lambda caption: f"Present $ where it {caption}",
59
+ lambda caption: f"Showcase $ that {caption}",
60
+ lambda caption: f"Explore $ which {caption}",
61
+ lambda caption: f"Find $ that {caption}",
62
+ lambda caption: f"Source $ which {caption}",
63
+ lambda caption: f"View $ where it {caption}",
64
+ lambda caption: f"Examine $ that {caption}",
65
+ lambda caption: f"Analyze $ which {caption}",
66
+ lambda caption: f"Observe $ that {caption}",
67
+ lambda caption: f"Report $ which {caption}",
68
+ lambda caption: f"See $ where it {caption}",
69
+ lambda caption: f"Document $ that {caption}"
70
+ ]
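The entries above are plain callables that wrap a relative caption around the '$' placeholder; a minimal usage sketch:

```python
# Minimal usage sketch: each template is a callable taking a relative caption.
from eval_templates import templates

caption = "has a red roof"
print(templates[0](caption))   # -> "a photo of $ that has a red roof"
print(len(templates))          # number of prompt variants in the list
```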
generate_test_submission.py ADDED
@@ -0,0 +1,363 @@
1
+ import os
2
+ import json
3
+ import pickle
4
+ from argparse import ArgumentParser
5
+ from typing import List, Tuple, Dict
6
+
7
+ import clip
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from clip.model import CLIP
12
+ from torch.utils.data import DataLoader
13
+ from tqdm import tqdm
14
+
15
+ from data_utils import PROJECT_ROOT, targetpad_transform
16
+ from loader import CIRRDataset, CIRCODataset
17
+ from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
18
+ from models import build_text_encoder, Phi, PIC2WORD
19
+ from utils import extract_image_features, device, collate_fn, extract_pseudo_tokens_with_phi
20
+
21
+
22
+ @torch.no_grad()
23
+ def cirr_generate_test_submission_file(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str],
24
+ pseudo_tokens: torch.Tensor, preprocess: callable, submission_name: str) -> None:
25
+ """
26
+ Generate the test submission file for the CIRR dataset given the pseudo tokens
27
+ """
28
+
29
+ # Load the CLIP model
30
+ #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
31
+ #clip_model = clip_model.float().eval()
32
+
33
+ # Compute the index features
34
+ classic_test_dataset = CIRRDataset(dataset_path, 'test1', 'classic', preprocess)
35
+ index_features, index_names = extract_image_features(classic_test_dataset, image_encoder)
36
+
37
+ relative_test_dataset = CIRRDataset(dataset_path, 'test1', 'relative', preprocess)
38
+
39
+ # Get the predictions dicts
40
+ pairid_to_retrieved_images, pairid_to_group_retrieved_images = \
41
+ cirr_generate_test_dicts(relative_test_dataset, text_encoder, index_features, index_names,
42
+ ref_names_list, pseudo_tokens)
43
+
44
+ submission = {
45
+ 'version': 'rc2',
46
+ 'metric': 'recall'
47
+ }
48
+ group_submission = {
49
+ 'version': 'rc2',
50
+ 'metric': 'recall_subset'
51
+ }
52
+
53
+ submission.update(pairid_to_retrieved_images)
54
+ group_submission.update(pairid_to_group_retrieved_images)
55
+
56
+ submissions_folder_path = os.path.join('./submission', 'cirr')
57
+ os.makedirs(submissions_folder_path, exist_ok=True)
58
+
59
+ with open(os.path.join(submissions_folder_path, f"{submission_name}.json"), 'w+') as file:
60
+ json.dump(submission, file, sort_keys=True)
61
+
62
+ with open(os.path.join(submissions_folder_path, f"subset_{submission_name}.json"), 'w+') as file:
63
+ json.dump(group_submission, file, sort_keys=True)
64
+
65
+
66
+ def cirr_generate_test_dicts(relative_test_dataset: CIRRDataset, clip_model, index_features: torch.Tensor,
67
+ index_names: List[str], ref_names_list: List[str], pseudo_tokens: List[str]) \
68
+ -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
69
+ """
70
+ Generate the test submission dicts for the CIRR dataset given the pseudo tokens
71
+ """
72
+
73
+ # Get the predicted features
74
+ predicted_features, reference_names, pairs_id, group_members = \
75
+ cirr_generate_test_predictions(clip_model, relative_test_dataset, ref_names_list, pseudo_tokens)
76
+
77
+ print(f"Compute CIRR prediction dicts")
78
+
79
+ # Normalize the index features
80
+ index_features = index_features.to(device)
81
+ index_features = F.normalize(index_features, dim=-1).float()
82
+
83
+ # Compute the distances and sort the results
84
+ distances = 1 - predicted_features @ index_features.T
85
+ sorted_indices = torch.argsort(distances, dim=-1).cpu()
86
+ sorted_index_names = np.array(index_names)[sorted_indices]
87
+
88
+ # Delete the reference image from the results
89
+ reference_mask = torch.tensor(
90
+ sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(sorted_index_names),
91
+ -1))
92
+ sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
93
+ sorted_index_names.shape[1] - 1)
94
+ # Compute the subset predictions
95
+ group_members = np.array(group_members)
96
+ group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
97
+ sorted_group_names = sorted_index_names[group_mask].reshape(sorted_index_names.shape[0], -1)
98
+
99
+ # Generate prediction dicts
100
+ pairid_to_retrieved_images = {str(int(pair_id)): prediction[:50].tolist() for (pair_id, prediction) in
101
+ zip(pairs_id, sorted_index_names)}
102
+ pairid_to_group_retrieved_images = {str(int(pair_id)): prediction[:3].tolist() for (pair_id, prediction) in
103
+ zip(pairs_id, sorted_group_names)}
104
+
105
+ return pairid_to_retrieved_images, pairid_to_group_retrieved_images
106
+
107
+
108
+ def cirr_generate_test_predictions(clip_model, relative_test_dataset: CIRRDataset, ref_names_list: List[str],
109
+ pseudo_tokens: torch.Tensor) -> \
110
+ Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
111
+ """
112
+ Generate the test prediction features for the CIRR dataset given the pseudo tokens
113
+ """
114
+
115
+ # Create the test dataloader
116
+ relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, num_workers=10,
117
+ pin_memory=False)
118
+
119
+ predicted_features_list = []
120
+ reference_names_list = []
121
+ pair_id_list = []
122
+ group_members_list = []
123
+
124
+ # Compute the predictions
125
+ for batch in tqdm(relative_test_loader):
126
+ reference_names = batch['reference_name']
127
+ pairs_id = batch['pair_id']
128
+ relative_captions = batch['relative_caption']
129
+ group_members = batch['group_members']
130
+
131
+ group_members = np.array(group_members).T.tolist()
132
+
133
+ input_captions = [
134
+ f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]
135
+
136
+ batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
137
+ tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
138
+ text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
139
+
140
+ predicted_features = F.normalize(text_features)
141
+
142
+ predicted_features_list.append(predicted_features)
143
+ reference_names_list.extend(reference_names)
144
+ pair_id_list.extend(pairs_id)
145
+ group_members_list.extend(group_members)
146
+
147
+ predicted_features = torch.vstack(predicted_features_list)
148
+
149
+ return predicted_features, reference_names_list, pair_id_list, group_members_list
150
+
151
+
152
+ @torch.no_grad()
153
+ def circo_generate_test_submission_file(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str],
154
+ pseudo_tokens: torch.Tensor, preprocess: callable,
155
+ submission_name: str) -> None:
156
+ """
157
+ Generate the test submission file for the CIRCO dataset given the pseudo tokens
158
+ """
159
+
160
+ # Load the CLIP model
161
+ #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
162
+ #clip_model = clip_model.float().eval().requires_grad_(False)
163
+
164
+ # Compute the index features
165
+ classic_test_dataset = CIRCODataset(dataset_path, 'test', 'classic', preprocess)
166
+ index_features, index_names = extract_image_features(classic_test_dataset, image_encoder)
167
+
168
+ relative_test_dataset = CIRCODataset(dataset_path, 'test', 'relative', preprocess)
169
+
170
+ # Get the predictions dict
171
+ queryid_to_retrieved_images = circo_generate_test_dict(relative_test_dataset, text_encoder, index_features,
172
+ index_names, ref_names_list, pseudo_tokens)
173
+
174
+ submissions_folder_path = os.path.join('./submission', 'circo')
175
+ os.makedirs(submissions_folder_path, exist_ok=True)
176
+
177
+ with open(os.path.join(submissions_folder_path, f"{submission_name}.json"), 'w+') as file:
178
+ json.dump(queryid_to_retrieved_images, file, sort_keys=True)
179
+
180
+
181
+ def circo_generate_test_predictions(clip_model, relative_test_dataset: CIRCODataset, ref_names_list: List[str],
182
+ pseudo_tokens: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
183
+ """
184
+ Generate the test prediction features for the CIRCO dataset given the pseudo tokens
185
+ """
186
+
187
+ # Create the test dataloader
188
+ relative_test_loader = DataLoader(dataset=relative_test_dataset, batch_size=32, num_workers=10,
189
+ pin_memory=False, collate_fn=collate_fn, shuffle=False)
190
+
191
+ predicted_features_list = []
192
+ query_ids_list = []
193
+
194
+ # Compute the predictions
195
+ for batch in tqdm(relative_test_loader):
196
+ reference_names = batch['reference_name']
197
+ relative_captions = batch['relative_caption']
198
+ query_ids = batch['query_id']
199
+
200
+ input_captions = [f"a photo of $ that {caption}" for caption in relative_captions]
201
+ batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
202
+ tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
203
+ text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
204
+ predicted_features = F.normalize(text_features)
205
+
206
+ predicted_features_list.append(predicted_features)
207
+ query_ids_list.extend(query_ids)
208
+
209
+ predicted_features = torch.vstack(predicted_features_list)
210
+ return predicted_features, query_ids_list
211
+
212
+
213
+ def circo_generate_test_dict(relative_test_dataset: CIRCODataset, clip_model, index_features: torch.Tensor,
214
+ index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
215
+ -> Dict[str, List[str]]:
216
+ """
217
+ Generate the test submission dicts for the CIRCO dataset given the pseudo tokens
218
+ """
219
+
220
+ # Get the predicted features
221
+ predicted_features, query_ids = circo_generate_test_predictions(clip_model, relative_test_dataset,
222
+ ref_names_list, pseudo_tokens)
223
+
224
+ # Normalize the features
225
+ index_features = index_features.float().to(device)
226
+ index_features = F.normalize(index_features, dim=-1)
227
+
228
+ # Compute the similarity
229
+ similarity = predicted_features @ index_features.T
230
+ sorted_indices = torch.topk(similarity, dim=-1, k=50).indices.cpu()
231
+ sorted_index_names = np.array(index_names)[sorted_indices]
232
+
233
+ # Generate prediction dicts
234
+ queryid_to_retrieved_images = {query_id: query_sorted_names[:50].tolist() for
235
+ (query_id, query_sorted_names) in zip(query_ids, sorted_index_names)}
236
+
237
+ return queryid_to_retrieved_images
238
+
239
+
240
+ def main():
241
+ parser = ArgumentParser()
242
+ parser.add_argument("--submission-name", type=str, required=True, help="Filename of the generated submission file")
243
+ parser.add_argument("--exp-name", type=str, help="Experiment to evaluate")
244
+ parser.add_argument("--dataset", type=str, required=True, choices=['cirr', 'circo'], help="Dataset to use")
245
+ parser.add_argument("--dataset-path", type=str, help="Path to the dataset", required=True)
246
+ parser.add_argument("--eval-type", type=str, choices=['oti', 'phi', 'searle', 'searle-xl', 'pic2word'], required=True,
247
+ help="If 'oti' evaluates directly using the inverted oti pseudo tokens, "
248
+ "if 'phi' predicts the pseudo tokens using the phi network, "
249
+ "if 'searle' uses the pre-trained SEARLE model to predict the pseudo tokens, "
250
+ "if 'searle-xl' uses the pre-trained SEARLE-XL model to predict the pseudo tokens")
251
+
252
+ parser.add_argument("--preprocess-type", default="clip", type=str, choices=['clip', 'targetpad'],
253
+ help="Preprocess pipeline to use")
254
+ parser.add_argument("--phi-checkpoint-name", type=str,
255
+ help="Phi checkpoint to use, needed when using phi, e.g. 'phi_20.pt'")
256
+
257
+ parser.add_argument("--clip_model_name", default="giga", type=str)
258
+ parser.add_argument("--cache_dir", default="./hf_models", type=str)
259
+
260
+ parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")
261
+
262
+ args = parser.parse_args()
263
+
264
+ if args.eval_type == 'oti':
265
+ experiment_path = PROJECT_ROOT / 'data' / "oti_pseudo_tokens" / args.dataset.lower() / 'test' / args.exp_name
266
+
267
+ with open(experiment_path / 'hyperparameters.json') as f:
268
+ hyperparameters = json.load(f)
269
+
270
+ pseudo_tokens = torch.load(experiment_path / 'ema_oti_pseudo_tokens.pt', map_location=device)
271
+ with open(experiment_path / 'image_names.pkl', 'rb') as f:
272
+ ref_names_list = pickle.load(f)
273
+
274
+ clip_model_name = hyperparameters['clip_model_name']
275
+ clip_model, clip_preprocess = clip.load(clip_model_name, device='cpu', jit=False)
276
+
277
+ if args.preprocess_type == 'targetpad':
278
+ print('Target pad preprocess pipeline is used')
279
+ preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
280
+ elif args.preprocess_type == 'clip':
281
+ print('CLIP preprocess pipeline is used')
282
+ preprocess = clip_preprocess
283
+ else:
284
+ raise ValueError("Preprocess type not supported")
285
+
286
+
287
+ elif args.eval_type in ['phi', 'searle', 'searle-xl', 'pic2word']:
288
+ if args.eval_type == 'phi':
289
+ args.mixed_precision = 'fp16'
290
+ image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
291
+
292
+ phi = Phi(input_dim=text_encoder.config.projection_dim,
293
+ hidden_dim=text_encoder.config.projection_dim * 4,
294
+ output_dim=text_encoder.config.hidden_size, dropout=0.5).to(
295
+ device)
296
+
297
+ phi.load_state_dict(
298
+ torch.load(args.phi_checkpoint_name, map_location=device)[
299
+ phi.__class__.__name__])
300
+ phi = phi.eval()
301
+
302
+ elif args.eval_type == 'pic2word':
303
+ args.mixed_precision = 'fp16'
304
+ image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
305
+
306
+ phi = PIC2WORD(embed_dim=text_encoder.config.projection_dim,
307
+ output_dim=text_encoder.config.hidden_size,
308
+ ).to(device)
309
+ sd = torch.load(args.phi_checkpoint_name, map_location=device)['state_dict_img2text']
310
+ sd = {k[len('module.'):]: v for k, v in sd.items()}
311
+ phi.load_state_dict(sd)
312
+ phi = phi.eval()
313
+
314
+ else: # searle or searle-xl
315
+ if args.eval_type == 'searle':
316
+ clip_model_name = 'ViT-B/32'
317
+ else: # args.eval_type == 'searle-xl':
318
+ clip_model_name = 'ViT-L/14'
319
+ phi, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
320
+ backbone=clip_model_name)
321
+
322
+ phi = phi.to(device).eval()
323
+ clip_model, clip_preprocess = clip.load(clip_model_name, device=device, jit=False)
324
+
325
+ if args.preprocess_type == 'targetpad':
326
+ print('Target pad preprocess pipeline is used')
327
+ preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
328
+ elif args.preprocess_type == 'clip':
329
+ print('CLIP preprocess pipeline is used')
330
+ preprocess = clip_preprocess
331
+ else:
332
+ raise ValueError("Preprocess type not supported")
333
+
334
+ if args.dataset.lower() == 'cirr':
335
+ relative_test_dataset = CIRRDataset(args.dataset_path, 'test', 'relative', preprocess, no_duplicates=True)
336
+ elif args.dataset.lower() == 'circo':
337
+ relative_test_dataset = CIRCODataset(args.dataset_path, 'test', 'relative', preprocess)
338
+ else:
339
+ raise ValueError("Dataset not supported")
340
+
341
+ #clip_model = clip_model.float().to(device)
342
+ image_encoder = image_encoder.float().to(device)
343
+ text_encoder = text_encoder.float().to(device)
344
+ pseudo_tokens, ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi, relative_test_dataset, args)
345
+ pseudo_tokens = pseudo_tokens.to(device)
346
+ else:
347
+ raise ValueError("Eval type not supported")
348
+
349
+ print(f"Eval type = {args.eval_type} \t exp name = {args.exp_name} \t")
350
+
351
+ if args.dataset == 'cirr':
352
+ cirr_generate_test_submission_file(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
353
+ preprocess, args.submission_name)
354
+ elif args.dataset == 'circo':
355
+ circo_generate_test_submission_file(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
356
+ preprocess, args.submission_name)
357
+
358
+ else:
359
+ raise ValueError("Dataset not supported")
360
+
361
+
362
+ if __name__ == '__main__':
363
+ main()
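The core of the CIRCO submission step above is a cosine-similarity ranking followed by a top-50 cut, as in `circo_generate_test_dict`. A self-contained sketch with random placeholder features and made-up gallery/query names:

```python
# Self-contained sketch of the ranking step in circo_generate_test_dict, using
# random placeholder features and made-up gallery/query names.
import numpy as np
import torch
import torch.nn.functional as F

index_names = [f"img_{i}" for i in range(1000)]               # placeholder gallery
index_features = F.normalize(torch.randn(1000, 768), dim=-1)  # normally CLIP image embeddings
predicted_features = F.normalize(torch.randn(4, 768), dim=-1) # normally composed text embeddings

similarity = predicted_features @ index_features.T
sorted_indices = torch.topk(similarity, dim=-1, k=50).indices.cpu()
sorted_index_names = np.array(index_names)[sorted_indices]

query_ids = ["q0", "q1", "q2", "q3"]
queryid_to_retrieved_images = {qid: names[:50].tolist()
                               for qid, names in zip(query_ids, sorted_index_names)}
print(len(queryid_to_retrieved_images["q0"]))  # 50
```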
loader.py ADDED
@@ -0,0 +1,632 @@
1
+ '''
2
+ LinCIR
3
+ Copyright (c) 2023-present NAVER Corp.
4
+ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
5
+ '''
6
+ import os
7
+ import functools
8
+ import glob
9
+ import random
10
+ import json
11
+ from pathlib import Path
12
+ from typing import List, Optional, Union, Dict, Literal
13
+ import PIL
14
+ import PIL.Image
15
+ import torch
16
+ from torch.utils.data import Dataset
17
+ import webdataset as wds
18
+ import spacy
19
+ import numpy as np
20
+ import sng_parser
21
+ import datasets
22
+
23
+
24
+ def extract_keywords(spacy_nlp, caption):
25
+ candidates = []
26
+ nlp_caption = caption
27
+
28
+ doc = spacy_nlp(nlp_caption)
29
+
30
+ tmp = ''
31
+ for word in doc:
32
+ if word.pos_ == 'ADJ':
33
+ if tmp == '':
34
+ tmp += word.text
35
+ else:
36
+ tmp += ' ' + word.text
37
+ elif word.pos_ == 'NOUN' or word.pos_ == 'PROPN':
38
+ if tmp == '':
39
+ tmp += word.text
40
+ else:
41
+ tmp += ' ' + word.text
42
+ else:
43
+ if tmp != '':
44
+ candidates.append(tmp)
45
+ tmp = ''
46
+ if tmp != '':
47
+ candidates.append(tmp)
48
+
49
+ candidates = list(set(candidates))
50
+
51
+ return candidates
52
+
53
+
54
+ def extract_keywords_spacy(spacy_nlp, caption):
55
+ sequences = []
56
+ current_sequence = []
57
+ doc = spacy_nlp(caption)
58
+ for token in doc:
59
+ # Check if the token is a noun, proper noun, or adjective
60
+ if token.pos_ in ['NOUN', 'PROPN', 'ADJ', 'DET']:
61
+ current_sequence.append(token.text)
62
+ else:
63
+ # If we encounter a token that's not one of the desired POS and current_sequence is not empty
64
+ if current_sequence:
65
+ sequences.append(" ".join(current_sequence))
66
+ current_sequence = []
67
+
68
+ # Adding any remaining sequence after the loop
69
+ if current_sequence:
70
+ sequences.append(" ".join(current_sequence))
71
+
72
+ return sequences
73
+
74
+
75
+ def extract_sng(caption):
76
+ graph = sng_parser.parse(caption)
77
+ entities = [x['head'] for i, x in enumerate(graph['entities'])]
78
+ relations = [{'subject': entities[x['subject']], 'object': entities[x['object']], 'relation': x['relation']} for x in graph['relations']]
79
+ return entities, relations
80
+
81
+
82
+ def clean_caption(caption, tokenizer):
83
+ if caption is None:
84
+ caption = ''
85
+ if '<PERSON>' in caption: # to handle with GCC12M
86
+ caption = caption.replace('<PERSON>', 'person')
87
+ caption = caption.lower().replace('$', '').strip()
88
+ tokens = tokenizer.encode(caption, padding='longest', return_tensors='pt')
89
+ if tokens.shape[1] > 77:
90
+ caption = tokenizer.batch_decode(tokens[:,1:76])[0]
91
+ return caption
92
+
93
+
94
+ def preprocess_precomputed_base(sample, spacy_nlp, keywords_list, tokenizer):
95
+ '''
96
+ 'image_feature.npy','json'
97
+ '''
98
+ image_feature, image_feature_giga, meta = sample
99
+
100
+ caption = clean_caption(meta['source_caption'], tokenizer)
101
+
102
+ keywords = ['']
103
+ try:
104
+ keywords = extract_keywords_spacy(spacy_nlp, caption)
105
+ except Exception as e:
106
+ #print(e)
107
+ pass
108
+
109
+ # for keywords
110
+ indicator = 1
111
+ replaced_caption = caption
112
+ for keyword in keywords:
113
+ if keyword != '' and keyword in caption:
114
+ replaced_caption = replaced_caption.replace(keyword, '[$]')
115
+ else:
116
+ tmp_keywords = caption.split(' ')
117
+ if len(tmp_keywords) > 0:
118
+ selected_keywords = random.sample(tmp_keywords, k=min(int(len(tmp_keywords) * 1.0), 1))
119
+ for selected_keyword in selected_keywords:
120
+ replaced_caption = replaced_caption.replace(selected_keyword, '[$]')
121
+ else:
122
+ replaced_caption = f'a photo of [$] that {caption}'
123
+ indicator = 0
124
+ break
125
+
126
+ token_dict = tokenizer(text=caption, return_tensors='pt', padding='max_length', truncation=True)
127
+ tokens, attention_mask = token_dict['input_ids'][0], token_dict['attention_mask'][0]
128
+
129
+ replaced_token_dict = tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
130
+ replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
131
+
132
+ replaced_tokens = torch.where(replaced_tokens == 49408,
133
+ torch.ones_like(replaced_tokens) * 259,
134
+ replaced_tokens)
135
+
136
+ if 259 not in replaced_tokens:
137
+ replaced_caption = 'a photo of [$]'
138
+ replaced_token_dict = tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
139
+ replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
140
+
141
+ replaced_tokens = torch.where(replaced_tokens == 49408,
142
+ torch.ones_like(replaced_tokens) * 259,
143
+ replaced_tokens)
144
+ indicator = 0
145
+
146
+ new_sample = [tokens, replaced_tokens, indicator]
147
+
148
+ return tuple(new_sample)
149
+
150
+
151
+ class CaptionDataset(Dataset):
152
+ def __init__(self, captions, tokenizer, spacy_nlp):
153
+ self.captions = captions
154
+ self.tokenizer = tokenizer
155
+ self.spacy_nlp = spacy_nlp
156
+
157
+ def __len__(self):
158
+ return len(self.captions)
159
+
160
+ def __getitem__(self, idx):
161
+ caption = self.captions[idx]
162
+
163
+ caption = clean_caption(caption, self.tokenizer)
164
+
165
+ keywords = [""]
166
+ try:
167
+ keywords = extract_keywords_spacy(self.spacy_nlp, caption)
168
+ except Exception as e:
169
+ #print(e)
170
+ pass
171
+
172
+ # for keywords
173
+ indicator = 1
174
+ replaced_caption = caption
175
+
176
+ if len(keywords) == 0:
177
+ keywords = [""]
178
+
179
+ for keyword in keywords:
180
+ if keyword != '' and keyword in caption:
181
+ replaced_caption = replaced_caption.replace(keyword, '[$]')
182
+ else:
183
+ tmp_keywords = caption.split(' ')
184
+ if len(tmp_keywords) > 0:
185
+ selected_keywords = random.sample(tmp_keywords, k=min(int(len(tmp_keywords) * 1.0), 1))
186
+ for selected_keyword in selected_keywords:
187
+ replaced_caption = replaced_caption.replace(selected_keyword, '[$]')
188
+ else:
189
+ replaced_caption = f'a photo of [$] that {caption}'
190
+ indicator = 0
191
+ break
192
+
193
+ token_dict = self.tokenizer(text=caption, return_tensors='pt', padding='max_length', truncation=True)
194
+ tokens, attention_mask = token_dict['input_ids'][0], token_dict['attention_mask'][0]
195
+
196
+ replaced_token_dict = self.tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
197
+ replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
198
+
199
+ replaced_tokens = torch.where(replaced_tokens == 49408,
200
+ torch.ones_like(replaced_tokens) * 259,
201
+ replaced_tokens)
202
+
203
+ if 259 not in replaced_tokens:
204
+ replaced_caption = 'a photo of [$]'
205
+ replaced_token_dict = self.tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
206
+ replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
207
+
208
+ replaced_tokens = torch.where(replaced_tokens == 49408,
209
+ torch.ones_like(replaced_tokens) * 259,
210
+ replaced_tokens)
211
+ indicator = 0
212
+
213
+ return tokens, replaced_tokens, indicator
214
+
215
+
216
+ def build_loader(args, tokenizer, accelerator):
217
+ data_names = {'dataset1': 'dangne/gcc_caption_only',
218
+ 'dataset2': 'FredZhang7/stable-diffusion-prompts-2.47M',
219
+ 'dataset3': 'Geonmo/midjourney-prompts-only',
220
+ }
221
+
222
+ for k, v in data_names.items():
223
+ if not os.path.exists(os.path.join('./datasets', k)):
224
+ if accelerator.is_main_process:
225
+ print('Downloading captions is required')
226
+ db = datasets.load_dataset(v, cache_dir=os.path.join('./datasets', k))
227
+
228
+ captions = []
229
+ for k, v in data_names.items():
230
+ db = datasets.load_dataset(v, cache_dir=os.path.join('./datasets', k))
231
+ captions += db['train']['text']
232
+
233
+ dataset = CaptionDataset(captions, tokenizer, spacy.load('en_core_web_sm'))
234
+ data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True, shuffle=True)
235
+
236
+ return data_loader
237
+
238
+
239
+ class FashionIQDataset(Dataset):
240
+ """
241
+ Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
242
+ FashionIQ dataset class for PyTorch.
243
+ The dataset can be used in 'relative' or 'classic' mode:
244
+ - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
245
+ - In 'relative' mode the dataset yields a dict with keys:
246
+ - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions'] when
247
+ split in ['train', 'val']
248
+ - ['reference_image', 'reference_name', 'relative_captions'] when split == test
249
+ """
250
+
251
+ def __init__(self, dataset_path: Union[Path, str], split: Literal['train', 'val', 'test'], dress_types: List[str],
252
+ mode: Literal['relative', 'classic'], preprocess: callable, no_duplicates: Optional[bool] = False):
253
+ """
254
+ :param dataset_path: path to the FashionIQ dataset
255
+ :param split: dataset split, should be in ['train', 'val', 'test']
256
+ :param dress_types: list of fashionIQ categories, each category should be in ['dress', 'shirt', 'toptee']
257
+ :param mode: dataset mode, should be in ['relative', 'classic']:
258
+ - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
259
+ - In 'relative' mode the dataset yields a dict with keys:
260
+ - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions']
261
+ when split in ['train', 'val']
262
+ - ['reference_image', 'reference_name', 'relative_captions'] when split == test
263
+ :param preprocess: function which preprocesses the image
264
+ :param no_duplicates: if True, the dataset will not yield duplicate images in relative mode, does not affect classic mode
265
+ """
266
+ dataset_path = Path(dataset_path)
267
+ self.dataset_path = dataset_path
268
+ self.mode = mode
269
+ self.dress_types = dress_types
270
+ self.split = split
271
+ self.no_duplicates = no_duplicates
272
+
273
+ # Validate the inputs
274
+ if mode not in ['relative', 'classic']:
275
+ raise ValueError("mode should be in ['relative', 'classic']")
276
+ if split not in ['test', 'train', 'val']:
277
+ raise ValueError("split should be in ['test', 'train', 'val']")
278
+ for dress_type in dress_types:
279
+ if dress_type not in ['dress', 'shirt', 'toptee']:
280
+ raise ValueError("dress_type should be in ['dress', 'shirt', 'toptee']")
281
+
282
+ self.preprocess = preprocess
283
+
284
+ # get triplets made by (reference_image, target_image, a pair of relative captions)
285
+ self.triplets: List[dict] = []
286
+ for dress_type in dress_types:
287
+ with open(dataset_path / 'captions' / f'cap.{dress_type}.{split}.json') as f:
288
+ self.triplets.extend(json.load(f))
289
+
290
+ # Remove duplicates from triplets
291
+ if self.no_duplicates:
292
+ seen = set()
293
+ new_triplets = []
294
+ for triplet in self.triplets:
295
+ if triplet['candidate'] not in seen:
296
+ seen.add(triplet['candidate'])
297
+ new_triplets.append(triplet)
298
+ self.triplets = new_triplets
299
+
300
+ # get the image names
301
+ self.image_names: list = []
302
+ for dress_type in dress_types:
303
+ with open(dataset_path / 'image_splits' / f'split.{dress_type}.{split}.json') as f:
304
+ self.image_names.extend(json.load(f))
305
+
306
+ print(f"FashionIQ {split} - {dress_types} dataset in {mode} mode initialized")
307
+
308
+ def __getitem__(self, index) -> dict:
309
+ try:
310
+ if self.mode == 'relative':
311
+ relative_captions = self.triplets[index]['captions']
312
+ reference_name = self.triplets[index]['candidate']
313
+
314
+ if self.split in ['train', 'val']:
315
+ reference_image_path = self.dataset_path / 'images' / f"{reference_name}.jpg"
316
+ reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
317
+ target_name = self.triplets[index]['target']
318
+ target_image_path = self.dataset_path / 'images' / f"{target_name}.jpg"
319
+ target_image = self.preprocess(PIL.Image.open(target_image_path), return_tensors='pt')['pixel_values'][0]
320
+
321
+ return {
322
+ 'reference_image': reference_image,
323
+ 'reference_name': reference_name,
324
+ 'target_image': target_image,
325
+ 'target_name': target_name,
326
+ 'relative_captions': relative_captions
327
+ }
328
+
329
+ elif self.split == 'test':
330
+ reference_image_path = self.dataset_path / 'images' / f"{reference_name}.jpg"
331
+ reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
332
+
333
+ return {
334
+ 'reference_image': reference_image,
335
+ 'reference_name': reference_name,
336
+ 'relative_captions': relative_captions
337
+ }
338
+
339
+ elif self.mode == 'classic':
340
+ image_name = self.image_names[index]
341
+ image_path = self.dataset_path / 'images' / f"{image_name}.jpg"
342
+ image = self.preprocess(PIL.Image.open(image_path), return_tensors='pt')['pixel_values'][0]
343
+
344
+ return {
345
+ 'image': image,
346
+ 'image_name': image_name
347
+ }
348
+
349
+ else:
350
+ raise ValueError("mode should be in ['relative', 'classic']")
351
+ except Exception as e:
352
+ print(f"Exception: {e}")
353
+
354
+ def __len__(self):
355
+ if self.mode == 'relative':
356
+ return len(self.triplets)
357
+ elif self.mode == 'classic':
358
+ return len(self.image_names)
359
+ else:
360
+ raise ValueError("mode should be in ['relative', 'classic']")
361
+
362
+
363
+ class CIRRDataset(Dataset):
364
+ """
365
+ Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
366
+ CIRR dataset class for PyTorch dataloader.
367
+ The dataset can be used in 'relative' or 'classic' mode:
368
+ - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
369
+ - In 'relative' mode the dataset yields a dict with keys:
370
+ - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_caption', 'group_members']
371
+ when split in ['train', 'val']
372
+ - ['reference_image', 'reference_name', 'relative_caption', 'group_members', 'pair_id'] when split == test
373
+ """
374
+
375
+ def __init__(self, dataset_path: Union[Path, str], split: Literal['train', 'val', 'test'],
376
+ mode: Literal['relative', 'classic'], preprocess: callable, no_duplicates: Optional[bool] = False):
377
+ """
378
+ :param dataset_path: path to the CIRR dataset
379
+ :param split: dataset split, should be in ['train', 'val', 'test']
380
+ :param mode: dataset mode, should be in ['relative', 'classic']:
381
+ - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
382
+ - In 'relative' mode the dataset yields a dict with keys:
383
+ - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_caption',
384
+ 'group_members'] when split in ['train', 'val']
385
+ - ['reference_image', 'reference_name', 'relative_caption', 'group_members', 'pair_id'] when split == test
386
+ :param preprocess: function which preprocesses the image
387
+ :param no_duplicates: if True, the dataset will not yield duplicate images in relative mode, does not affect classic mode
388
+ """
389
+ dataset_path = Path(dataset_path)
390
+ self.dataset_path = dataset_path
391
+ self.preprocess = preprocess
392
+ self.mode = mode
393
+ self.split = split
394
+ self.no_duplicates = no_duplicates
395
+
396
+ if split == "test":
397
+ split = "test1"
398
+ self.split = "test1"
399
+
400
+ # Validate inputs
401
+ if split not in ['test1', 'train', 'val']:
402
+ raise ValueError("split should be in ['test1', 'train', 'val']")
403
+ if mode not in ['relative', 'classic']:
404
+ raise ValueError("mode should be in ['relative', 'classic']")
405
+
406
+ # get triplets made by (reference_image, target_image, relative caption)
407
+ with open(dataset_path / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f:
408
+ self.triplets = json.load(f)
409
+
410
+ # Remove duplicates from triplets
411
+ if self.no_duplicates:
412
+ seen = set()
413
+ new_triplets = []
414
+ for triplet in self.triplets:
415
+ if triplet['reference'] not in seen:
416
+ seen.add(triplet['reference'])
417
+ new_triplets.append(triplet)
418
+ self.triplets = new_triplets
419
+
420
+ # get a mapping from image name to relative path
421
+ with open(dataset_path / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f:
422
+ self.name_to_relpath = json.load(f)
423
+
424
+ print(f"CIRR {split} dataset in {mode} mode initialized")
425
+
426
+ def __getitem__(self, index) -> dict:
427
+ try:
428
+ if self.mode == 'relative':
429
+ group_members = self.triplets[index]['img_set']['members']
430
+ reference_name = self.triplets[index]['reference']
431
+ relative_caption = self.triplets[index]['caption']
432
+
433
+ if self.split in ['train', 'val']:
434
+ reference_image_path = self.dataset_path / self.name_to_relpath[reference_name]
435
+ reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
436
+ target_hard_name = self.triplets[index]['target_hard']
437
+ target_image_path = self.dataset_path / self.name_to_relpath[target_hard_name]
438
+ target_image = self.preprocess(PIL.Image.open(target_image_path), return_tensors='pt')['pixel_values'][0]
439
+
440
+ return {
441
+ 'reference_image': reference_image,
442
+ 'reference_name': reference_name,
443
+ 'target_image': target_image,
444
+ 'target_name': target_hard_name,
445
+ 'relative_caption': relative_caption,
446
+ 'group_members': group_members
447
+ }
448
+
449
+ elif self.split == 'test1':
450
+ pair_id = self.triplets[index]['pairid']
451
+ reference_image_path = self.dataset_path / self.name_to_relpath[reference_name]
452
+ reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
453
+ return {
454
+ 'reference_image': reference_image,
455
+ 'reference_name': reference_name,
456
+ 'relative_caption': relative_caption,
457
+ 'group_members': group_members,
458
+ 'pair_id': pair_id
459
+ }
460
+
461
+ elif self.mode == 'classic':
462
+ image_name = list(self.name_to_relpath.keys())[index]
463
+ image_path = self.dataset_path / self.name_to_relpath[image_name]
464
+ im = PIL.Image.open(image_path)
465
+ image = self.preprocess(im, return_tensors='pt')['pixel_values'][0]
466
+
467
+ return {
468
+ 'image': image,
469
+ 'image_name': image_name
470
+ }
471
+
472
+ else:
473
+ raise ValueError("mode should be in ['relative', 'classic']")
474
+
475
+ except Exception as e:
476
+ print(f"Exception: {e}")
477
+
478
+ def __len__(self):
479
+ if self.mode == 'relative':
480
+ return len(self.triplets)
481
+ elif self.mode == 'classic':
482
+ return len(self.name_to_relpath)
483
+ else:
484
+ raise ValueError("mode should be in ['relative', 'classic']")
485
+
486
+
487
+ class CIRCODataset(Dataset):
488
+ """
489
+ Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
490
+ CIRCO dataset class for PyTorch.
491
+ The dataset can be used in 'relative' or 'classic' mode:
492
+ - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
493
+ - In 'relative' mode the dataset yields a dict with keys:
494
+ - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions', 'shared_concept',
495
+ 'gt_img_ids', 'query_id'] when split == 'val'
496
+ - ['reference_image', 'reference_name', 'relative_captions', 'shared_concept', 'query_id'] when split == test
497
+ """
498
+
499
+ def __init__(self, dataset_path: Union[str, Path], split: Literal['val', 'test'],
500
+ mode: Literal['relative', 'classic'], preprocess: callable):
501
+ """
502
+ Args:
503
+ dataset_path (Union[str, Path]): path to CIRCO dataset
504
+ split (str): dataset split, should be in ['test', 'val']
505
+ mode (str): dataset mode, should be in ['relative', 'classic']
506
+ preprocess (callable): function which preprocesses the image
507
+ """
508
+
509
+ # Set dataset paths and configurations
510
+ dataset_path = Path(dataset_path)
511
+ self.mode = mode
512
+ self.split = split
513
+ self.preprocess = preprocess
514
+ self.data_path = dataset_path
515
+
516
+ # Ensure input arguments are valid
517
+ if mode not in ['relative', 'classic']:
518
+ raise ValueError("mode should be in ['relative', 'classic']")
519
+ if split not in ['test', 'val']:
520
+ raise ValueError("split should be in ['test', 'val']")
521
+
522
+ # Load COCO images information
523
+ with open(dataset_path / 'COCO2017_unlabeled' / "annotations" / "image_info_unlabeled2017.json", "r") as f:
524
+ imgs_info = json.load(f)
525
+
526
+ self.img_paths = [dataset_path / 'COCO2017_unlabeled' / "unlabeled2017" / img_info["file_name"] for img_info in
527
+ imgs_info["images"]]
528
+ self.img_ids = [img_info["id"] for img_info in imgs_info["images"]]
529
+ self.img_ids_indexes_map = {str(img_id): i for i, img_id in enumerate(self.img_ids)}
530
+
531
+ # get CIRCO annotations
532
+ with open(dataset_path / 'annotations' / f'{split}.json', "r") as f:
533
+ self.annotations: List[dict] = json.load(f)
534
+
535
+ # Get maximum number of ground truth images (for padding when loading the images)
536
+ self.max_num_gts = 23 # Maximum number of ground truth images
537
+
538
+ print(f"CIRCODataset {split} dataset in {mode} mode initialized")
539
+
540
+ def get_target_img_ids(self, index) -> Dict[str, int]:
541
+ """
542
+ Returns the id of the target image and ground truth images for a given query
543
+
544
+ Args:
545
+ index (int): id of the query
546
+
547
+ Returns:
548
+ Dict[str, int]: dictionary containing target image id and a list of ground truth image ids
549
+ """
550
+
551
+ return {
552
+ 'target_img_id': self.annotations[index]['target_img_id'],
553
+ 'gt_img_ids': self.annotations[index]['gt_img_ids']
554
+ }
555
+
556
+ def __getitem__(self, index) -> dict:
557
+ """
558
+ Returns a specific item from the dataset based on the index.
559
+
560
+ In 'classic' mode, the dataset yields a dictionary with the following keys: [image, image_name]
561
+ In 'relative' mode, the dataset yields dictionaries with the following keys:
562
+ - [reference_image, reference_name, target_image, target_name, relative_caption, shared_concept, gt_img_ids,
563
+ query_id]
564
+ if split == 'val'
565
+ - [reference_image, reference_name, relative_caption, shared_concept, query_id] if split == 'test'
566
+ """
567
+
568
+ if self.mode == 'relative':
569
+ # Get the query id
570
+ query_id = str(self.annotations[index]['id'])
571
+
572
+ # Get relative caption and shared concept
573
+ relative_caption = self.annotations[index]['relative_caption']
574
+ shared_concept = self.annotations[index]['shared_concept']
575
+
576
+ # Get the reference image
577
+ reference_img_id = str(self.annotations[index]['reference_img_id'])
578
+ reference_img_path = self.img_paths[self.img_ids_indexes_map[reference_img_id]]
579
+ reference_img = self.preprocess(PIL.Image.open(reference_img_path), return_tensors='pt')['pixel_values'][0]
580
+
581
+ if self.split == 'val':
582
+ # Get the target image and ground truth images
583
+ target_img_id = str(self.annotations[index]['target_img_id'])
584
+ gt_img_ids = [str(x) for x in self.annotations[index]['gt_img_ids']]
585
+ target_img_path = self.img_paths[self.img_ids_indexes_map[target_img_id]]
586
+ target_img = self.preprocess(PIL.Image.open(target_img_path), return_tensors='pt')['pixel_values'][0]
587
+
588
+ # Pad ground truth image IDs with empty strings (up to max_num_gts) so collate_fn sees equal-length lists
589
+ gt_img_ids += [''] * (self.max_num_gts - len(gt_img_ids))
590
+
591
+ return {
592
+ 'reference_image': reference_img,
593
+ 'reference_name': reference_img_id,
594
+ 'target_image': target_img,
595
+ 'target_name': target_img_id,
596
+ 'relative_caption': relative_caption,
597
+ 'shared_concept': shared_concept,
598
+ 'gt_img_ids': gt_img_ids,
599
+ 'query_id': query_id,
600
+ }
601
+
602
+ elif self.split == 'test':
603
+ return {
604
+ 'reference_image': reference_img,
605
+ 'reference_name': reference_img_id,
606
+ 'relative_caption': relative_caption,
607
+ 'shared_concept': shared_concept,
608
+ 'query_id': query_id,
609
+ }
610
+
611
+ elif self.mode == 'classic':
612
+ # Get image ID and image path
613
+ img_id = str(self.img_ids[index])
614
+ img_path = self.img_paths[index]
615
+
616
+ # Preprocess image and return
617
+ img = self.preprocess(PIL.Image.open(img_path), return_tensors='pt')['pixel_values'][0]
618
+ return {
619
+ 'image': img,
620
+ 'image_name': img_id
621
+ }
622
+
623
+ def __len__(self):
624
+ """
625
+ Returns the length of the dataset.
626
+ """
627
+ if self.mode == 'relative':
628
+ return len(self.annotations)
629
+ elif self.mode == 'classic':
630
+ return len(self.img_ids)
631
+ else:
632
+ raise ValueError("mode should be in ['relative', 'classic']")
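For orientation, a minimal sketch (not part of the committed files) of driving CIRCODataset above in 'classic' mode to build a retrieval index. The dataset root is a placeholder, and the specific CLIPImageProcessor checkpoint is an assumption; any preprocess callable with the same interface works.

```python
from torch.utils.data import DataLoader
from transformers import CLIPImageProcessor

from loader import CIRCODataset

# Placeholder path; the CIRCO annotations and COCO2017_unlabeled images must live under it.
preprocess = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
dataset = CIRCODataset('/path/to/CIRCO', split='val', mode='classic', preprocess=preprocess)
loader = DataLoader(dataset, batch_size=32, num_workers=4)

for batch in loader:
    images = batch['image']          # (B, 3, 224, 224) float tensor
    image_ids = batch['image_name']  # list of COCO image ids as strings
    break
```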
models.py ADDED
@@ -0,0 +1,192 @@
1
+ '''
2
+ LinCIR
3
+ Copyright (c) 2023-present NAVER Corp.
4
+ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
5
+ '''
6
+ import copy
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor, CLIPTokenizer
10
+
11
+
12
+ def build_text_encoder(args):
13
+ clip_model_dict = {'base32': 'openai/clip-vit-base-patch32',
14
+ 'base': 'openai/clip-vit-base-patch16',
15
+ 'large': 'openai/clip-vit-large-patch14',
16
+ 'huge': 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
17
+ 'giga': 'Geonmo/CLIP-Giga-config-fixed',
18
+ 'meta-large': 'facebook/metaclip-l14-fullcc2.5b',
19
+ 'meta-huge': 'facebook/metaclip-h14-fullcc2.5b',
20
+ }
21
+
22
+ clip_preprocess = CLIPImageProcessor(crop_size={'height': 224, 'width': 224},
23
+ do_center_crop=True,
24
+ do_convert_rgb=True,
25
+ do_normalize=True,
26
+ do_rescale=True,
27
+ do_resize=True,
28
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
29
+ image_std=[0.26862954, 0.26130258, 0.27577711],
30
+ resample=3,
31
+ size={'shortest_edge': 224},
32
+ )
33
+
34
+ clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name], torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32, cache_dir=args.cache_dir)
35
+
36
+ clip_text_model = CLIPTextModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name], torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32, cache_dir=args.cache_dir)
37
+
38
+ tokenizer = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer_2', cache_dir=args.cache_dir)
39
+ tokenizer.add_special_tokens({'additional_special_tokens':["[$]"]}) # NOTE: 49408
40
+
41
+ return clip_vision_model, clip_preprocess, clip_text_model, tokenizer
42
+
43
+
44
+ class Phi(nn.Module):
45
+ """
46
+ Textual Inversion Phi network.
47
+ Takes as input the visual features of an image and outputs the pseudo-word embedding.
48
+ Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/phi.py
49
+ """
50
+
51
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float):
52
+ super().__init__()
53
+ self.layers = nn.Sequential(
54
+ nn.Linear(input_dim, hidden_dim),
55
+ nn.GELU(),
56
+ nn.Dropout(p=dropout),
57
+ nn.Linear(hidden_dim, hidden_dim),
58
+ nn.GELU(),
59
+ nn.Dropout(p=dropout),
60
+ nn.Linear(hidden_dim, output_dim),
61
+ )
62
+
63
+ def forward(self, x):
64
+ #x = F.normalize(x, dim=-1)
65
+ return self.layers(x)
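A small sketch of what Phi consumes and produces, assuming the 'large' CLIP backbone where both the projection dimension and the token hidden size are 768 (other backbones differ); not part of the committed files.

```python
import torch
from models import Phi

# Dimensions assume CLIP ViT-L/14: projection_dim = hidden_size = 768.
phi = Phi(input_dim=768, hidden_dim=768 * 4, output_dim=768, dropout=0.5)
phi.eval()

image_embeds = torch.randn(4, 768)      # CLIP features in the shared projection space
with torch.no_grad():
    pseudo_tokens = phi(image_embeds)   # (4, 768) pseudo-word token embeddings
```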
66
+
67
+
68
+ class EMAModel:
69
+ """
70
+ Exponential Moving Average of model weights
71
+ """
72
+
73
+ def __init__(self, parameters, decay=0.9999):
74
+ parameters = list(parameters)
75
+ self.shadow_params = [p.clone().detach() for p in parameters]
76
+
77
+ self.collected_params = None
78
+
79
+ self.decay = decay
80
+ self.optimization_step = 0
81
+
82
+ @torch.no_grad()
83
+ def step(self, parameters):
84
+ parameters = list(parameters)
85
+
86
+ self.optimization_step += 1
87
+
88
+ # Compute the decay factor for the exponential moving average.
89
+ value = (1 + self.optimization_step) / (10 + self.optimization_step)
90
+ one_minus_decay = 1 - min(self.decay, value)
91
+
92
+ for s_param, param in zip(self.shadow_params, parameters):
93
+ if param.requires_grad:
94
+ s_param.sub_(one_minus_decay * (s_param - param))
95
+ else:
96
+ s_param.copy_(param)
97
+
98
+ torch.cuda.empty_cache()
99
+
100
+ def copy_to(self, parameters) -> None:
101
+ """
102
+ Copy current averaged parameters into given collection of parameters.
103
+ Args:
104
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
105
+ updated with the stored moving averages. If `None`, the
106
+ parameters with which this `ExponentialMovingAverage` was
107
+ initialized will be used.
108
+ """
109
+ parameters = list(parameters)
110
+ for s_param, param in zip(self.shadow_params, parameters):
111
+ param.data.copy_(s_param.data)
112
+
113
+ def to(self, device=None, dtype=None) -> None:
114
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
115
+ Args:
116
+ device: like `device` argument to `torch.Tensor.to`
117
+ """
118
+ # .to() on the tensors handles None correctly
119
+ self.shadow_params = [
120
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
121
+ for p in self.shadow_params
122
+ ]
123
+
124
+ def state_dict(self) -> dict:
125
+ r"""
126
+ Returns the state of the ExponentialMovingAverage as a dict.
127
+ This method is used by accelerate during checkpointing to save the ema state dict.
128
+ """
129
+ # Following PyTorch conventions, references to tensors are returned:
130
+ # "returns a reference to the state and not its copy!" -
131
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
132
+ return {
133
+ "decay": self.decay,
134
+ "optimization_step": self.optimization_step,
135
+ "shadow_params": self.shadow_params,
136
+ "collected_params": self.collected_params,
137
+ }
138
+
139
+ def load_state_dict(self, state_dict: dict) -> None:
140
+ r"""
141
+ Loads the ExponentialMovingAverage state.
142
+ This method is used by accelerate during checkpointing to save the ema state dict.
143
+ Args:
144
+ state_dict (dict): EMA state. Should be an object returned
145
+ from a call to :meth:`state_dict`.
146
+ """
147
+ # deepcopy, to be consistent with module API
148
+ state_dict = copy.deepcopy(state_dict)
149
+
150
+ self.decay = state_dict["decay"]
151
+ if self.decay < 0.0 or self.decay > 1.0:
152
+ raise ValueError("Decay must be between 0 and 1")
153
+
154
+ self.optimization_step = state_dict["optimization_step"]
155
+ if not isinstance(self.optimization_step, int):
156
+ raise ValueError("Invalid optimization_step")
157
+
158
+ self.shadow_params = state_dict["shadow_params"]
159
+ if not isinstance(self.shadow_params, list):
160
+ raise ValueError("shadow_params must be a list")
161
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
162
+ raise ValueError("shadow_params must all be Tensors")
163
+
164
+ self.collected_params = state_dict["collected_params"]
165
+ if self.collected_params is not None:
166
+ if not isinstance(self.collected_params, list):
167
+ raise ValueError("collected_params must be a list")
168
+ if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
169
+ raise ValueError("collected_params must all be Tensors")
170
+ if len(self.collected_params) != len(self.shadow_params):
171
+ raise ValueError("collected_params and shadow_params must have the same length")
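A hedged sketch of how EMAModel is intended to be driven, with a toy objective and optimizer standing in for the real training loop:

```python
import torch
from models import Phi, EMAModel

phi = Phi(input_dim=768, hidden_dim=3072, output_dim=768, dropout=0.0)
ema = EMAModel(phi.parameters(), decay=0.9999)
optimizer = torch.optim.AdamW(phi.parameters(), lr=1e-4)

for _ in range(3):
    loss = phi(torch.randn(8, 768)).pow(2).mean()   # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(phi.parameters())   # shadow params track the warmup-aware moving average

# copy_to() overwrites the given parameters in place, so it is normally applied
# to a deepcopy of the model before evaluation (as train_phi.py does).
ema.copy_to(phi.parameters())
```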
172
+
173
+
174
+ class PIC2WORD(nn.Module):
175
+ def __init__(self, embed_dim=512, middle_dim=512, output_dim=512, n_layer=2, dropout=0.1):
176
+ super().__init__()
177
+ self.fc_out = nn.Linear(middle_dim, output_dim)
178
+ layers = []
179
+ dim = embed_dim
180
+ for _ in range(n_layer):
181
+ block = []
182
+ block.append(nn.Linear(dim, middle_dim))
183
+ block.append(nn.Dropout(dropout))
184
+ block.append(nn.ReLU())
185
+ dim = middle_dim
186
+ layers.append(nn.Sequential(*block))
187
+ self.layers = nn.Sequential(*layers)
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ for layer in self.layers:
191
+ x = layer(x)
192
+ return self.fc_out(x)
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ numpy
2
+ torch
3
+ transformers
4
+ diffusers
5
+ accelerate
6
+ datasets
7
+ spacy
8
+ clip-retrieval
train_phi.py ADDED
@@ -0,0 +1,317 @@
1
+ '''
2
+ LinCIR
3
+ Copyright (c) 2023-present NAVER Corp.
4
+ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
5
+ '''
6
+ import json
7
+ import os
8
+ import pickle
9
+ import random
10
+ import math
11
+ from argparse import ArgumentParser
12
+ from pathlib import Path
13
+ from typing import Literal, Tuple, Dict, List, Set
14
+ import logging
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from tqdm import tqdm
20
+
21
+ from loader import build_loader, CIRRDataset
22
+ from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
23
+ from models import build_text_encoder, Phi, EMAModel
24
+ from utils import extract_image_features, extract_pseudo_tokens_with_phi
25
+ from validate import cirr_compute_val_metrics
26
+
27
+ import transformers
28
+ from transformers import get_scheduler
29
+ from accelerate import Accelerator, DeepSpeedPlugin
30
+ from accelerate.logging import get_logger
31
+ from accelerate.utils import set_seed
32
+ from accelerate.state import AcceleratorState
34
+
35
+
36
+ logger = get_logger(__name__)
37
+
38
+
39
+ def parse_args():
40
+ parser = ArgumentParser()
41
+
42
+ parser.add_argument("--output_dir", default="trained_models", type=str,
43
+ help="The output directory where the model predictions and checkpoints will be written")
44
+ parser.add_argument("--logging_dir", default="logs", type=str, help="tensorboard logs will be saved here")
45
+ parser.add_argument("--cache_dir", default="./hf_models", type=str,
46
+ help="Path to model cache folder")
47
+ parser.add_argument("--report_to", default="tensorboard", type=str, help="")
48
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
49
+
50
+ parser.add_argument("--clip_model_name", default="giga", type=str,
51
+ help="CLIP model to use, e.g 'large', 'giga'")
52
+ parser.add_argument("--cirr_dataset_path", type=str, help="Path to CIRR dataset", required=True)
53
+ parser.add_argument("--keywords_path", type=str, help="Path to keywords json file")
54
+ parser.add_argument("--resume", default=None, type=str, help="Path to pretrained ckpt")
55
+
56
+ parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
57
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
58
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
59
+ help="")
60
+ parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
61
+ parser.add_argument("--max_train_steps", type=int, default=50000, help="Total number of training steps to perform")
62
+ parser.add_argument("--phi_dropout", default=0.5, type=float, help="Dropout probability for the phi network")
63
+ parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")
64
+ parser.add_argument("--batch_size", default=256, type=int, help="Phi training batch size")
65
+ parser.add_argument("--num_workers", default=10, type=int, help="Number of workers")
66
+ parser.add_argument("--learning_rate", default=1e-4, type=float, help="Learning rate")
67
+ parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
68
+ parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Number of updates steps to accumulate before performing a backward/update pass")
69
+ parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.")
70
+ parser.add_argument("--mixed_precision", default=None, type=str, choices=["no", "fp16", "bf16"], help="mixed precision")
71
+ parser.add_argument("--validation_steps", default=1, type=int, help="Validation frequency expressed in epochs")
72
+ parser.add_argument("--checkpointing_steps", default=None, type=int, help="Save a checkpoint of the training state every X updates")
73
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
74
+
75
+ parser.add_argument("--seed", type=int, default=None, help="seed for reproducibility")
76
+
77
+ args = parser.parse_args()
78
+
79
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
80
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
81
+ args.local_rank = env_local_rank
82
+
83
+ return args
84
+
85
+
86
+ def save_phi(name: str, cur_epoch: int, model_to_save: Phi, training_path: Path) -> None:
87
+ """
88
+ Save the weights of Phi during training
89
+ """
90
+ models_path = os.path.join(training_path, "checkpoints")
91
+ os.makedirs(models_path, exist_ok=True)
92
+ model_name = model_to_save.__class__.__name__
93
+ torch.save({
94
+ 'epoch': cur_epoch,
95
+ model_name: model_to_save.state_dict(),
96
+ }, os.path.join(models_path, f'{name}.pt'))
97
+
98
+
99
+ def train_phi(args):
100
+ # We use pre-extracted CLIP image features for training, so the image encoder is only needed for validation.
101
+
102
+ ### init accelerator here
103
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
104
+ accelerator = Accelerator(
105
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
106
+ mixed_precision=args.mixed_precision,
107
+ log_with=args.report_to,
108
+ project_dir=logging_dir,
109
+ )
110
+
111
+ os.makedirs(args.output_dir, exist_ok=True)
112
+
113
+ logging.basicConfig(
114
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
115
+ datefmt="%m/%d/%Y %H:%M:%S",
116
+ level=logging.INFO,
117
+ )
118
+ logger.info(accelerator.state, main_process_only=False)
119
+
120
+ if accelerator.is_local_main_process:
121
+ transformers.utils.logging.set_verbosity_info()
122
+ else:
123
+ transformers.utils.logging.set_verbosity_error()
124
+
125
+ if args.seed is not None:
126
+ set_seed(args.seed)
127
+
128
+ ### Define the text encoder from clip
129
+ image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
130
+
131
+ ### Define the phi model
132
+ phi = Phi(input_dim=text_encoder.config.projection_dim,
133
+ hidden_dim=text_encoder.config.projection_dim * 4,
134
+ output_dim=text_encoder.config.hidden_size, dropout=args.phi_dropout)
135
+
136
+ if args.resume:
137
+ phi.load_state_dict(
138
+ torch.load(args.resume, map_location=accelerator.device)[
139
+ phi.__class__.__name__])
140
+
141
+
142
+ ### GPU handling
143
+ weight_dtype = torch.float32
144
+ if accelerator.mixed_precision == "fp16":
145
+ weight_dtype = torch.float16
146
+ elif accelerator.mixed_precision == "bf16":
147
+ weight_dtype = torch.bfloat16
148
+
149
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
150
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
151
+
152
+ image_encoder.requires_grad_(False)
153
+ text_encoder.requires_grad_(False)
154
+
155
+ if args.use_ema:
156
+ import copy
157
+ ema_phi = copy.deepcopy(phi)
158
+ ema_phi = EMAModel(ema_phi.parameters())
159
+ ema_phi.to(accelerator.device, dtype=weight_dtype)
160
+
161
+ ### Define the train datasets
162
+ print('pytorch loader')
163
+ train_dataset = build_loader(args, tokenizer, accelerator)
164
+
165
+ ## evaluator
166
+ if accelerator.is_main_process:
167
+ ## Define CIRR validation set
168
+ cirr_relative_val_dataset = CIRRDataset(args.cirr_dataset_path, 'val', 'relative', clip_preprocess)
169
+ cirr_classic_val_dataset = CIRRDataset(args.cirr_dataset_path, 'val', 'classic', clip_preprocess)
170
+
171
+ # Extract the features for the CIRR validation set
172
+ cirr_val_index_features, cirr_val_index_names = extract_image_features(cirr_classic_val_dataset, image_encoder)
173
+
174
+ # Define the optimizer, the loss and the grad scaler
175
+ if args.use_8bit_adam:
176
+ try:
177
+ import bitsandbytes as bnb
178
+ except ImportError:
179
+ raise ImportError(
180
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
181
+ )
182
+ optimizer_cls = bnb.optim.AdamW8bit
183
+ else:
184
+ optimizer_cls = torch.optim.AdamW
185
+
186
+ optimizer = optimizer_cls(phi.parameters(),
187
+ lr=args.learning_rate,
188
+ weight_decay=args.weight_decay)
189
+
190
+ lr_scheduler = get_scheduler(
191
+ args.lr_scheduler,
192
+ optimizer=optimizer,
193
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes,
194
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps * accelerator.num_processes,
195
+ )
196
+
197
+ phi, optimizer, lr_scheduler, train_dataset = accelerator.prepare(
198
+ phi, optimizer, lr_scheduler, train_dataset
199
+ )
200
+
201
+ if accelerator.is_main_process:
202
+ accelerator.init_trackers("zeroshot-cir", config=vars(args))
203
+
204
+ # Start with the training loop
205
+ total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
206
+
207
+ logger.info("***** Running training *****")
208
+ logger.info(f" Instantaneous batch size per device = {args.batch_size}")
209
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
210
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
211
+ logger.info(f" Total steps = {args.max_train_steps}")
212
+
213
+ phi.train()
214
+
215
+ train_loss = 0.0
216
+ global_step = 0
217
+ best_recall = -1
218
+
219
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
220
+ progress_bar.set_description("Steps")
221
+
222
+ while True:
223
+ for idx, (original_tokens, replaced_tokens, indicators) in enumerate(train_dataset):
224
+ original_tokens = original_tokens.to(accelerator.device)
225
+ replaced_tokens = replaced_tokens.to(accelerator.device)
226
+
227
+ org = text_encoder(input_ids=original_tokens)
228
+ original_text_embeddings, original_last_hidden_states = org.text_embeds, org.last_hidden_state
229
+ input_features = original_text_embeddings.clone()
230
+ input_features += 1.0 * torch.rand(input_features.shape[0], device=input_features.device).unsqueeze(-1) * torch.randn(input_features.shape, device=input_features.device)
231
+
232
+ # normalize test
233
+ if args.l2_normalize:
234
+ input_features = F.normalize(input_features, dim=-1)
235
+ #################
236
+
237
+ estimated_token_embeddings = phi(input_features)
238
+
239
+ replaced_text_embeddings, replaced_last_hidden_states = encode_with_pseudo_tokens_HF(text_encoder, replaced_tokens, estimated_token_embeddings, return_last_states=True)
240
+
241
+ loss = F.mse_loss(replaced_text_embeddings.float(), original_text_embeddings.float(), reduction="mean")
242
+
243
+ avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
244
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
245
+
246
+ # Backpropagation
247
+ accelerator.backward(loss)
248
+ if accelerator.sync_gradients and args.max_grad_norm is not None:
249
+ accelerator.clip_grad_norm_(phi.parameters(), args.max_grad_norm)
250
+ optimizer.step()
251
+ lr_scheduler.step()
252
+ optimizer.zero_grad()
253
+
254
+ if accelerator.sync_gradients:
255
+ if args.use_ema:
256
+ ema_phi.step(phi.module.parameters())
257
+ progress_bar.update(1)
258
+ global_step += 1
259
+ accelerator.log({"train/train_loss": train_loss}, step=global_step)
260
+ train_loss = 0.0
261
+
262
+ accelerator.log({'train/lr': lr_scheduler.get_last_lr()[0]}, step=global_step)
263
+ accelerator.log({'train/preproc_rate': torch.sum(indicators).item() / len(indicators)}, step=global_step)
264
+ if args.checkpointing_steps and global_step % args.checkpointing_steps == 0:
265
+ if accelerator.is_main_process:
266
+ logger.info(f"model saving... step: {global_step}")
267
+ save_phi(f"phi_{global_step:09}", global_step, accelerator.unwrap_model(phi), args.output_dir)
268
+ save_phi(f"phi_latest", global_step, accelerator.unwrap_model(phi), args.output_dir)
269
+ if args.use_ema:
270
+ phi_for_saving = copy.deepcopy(accelerator.unwrap_model(phi))
271
+ ema_phi.copy_to(phi_for_saving.parameters())
272
+ save_phi(f"ema_phi_{global_step:09}", global_step, phi_for_saving, args.output_dir)
273
+ save_phi(f"ema_phi_latest", global_step, phi_for_saving, args.output_dir)
274
+
275
+ if global_step % args.validation_steps == 0 or global_step == 50:
276
+ if accelerator.is_main_process:
277
+ logger.info(f"evaluate model... step: {global_step}")
278
+
279
+ if args.use_ema:
280
+ phi_for_eval = copy.deepcopy(accelerator.unwrap_model(phi))
281
+ ema_phi.copy_to(phi_for_eval.parameters())
282
+ else:
283
+ phi_for_eval = phi
284
+
285
+ phi_for_eval.eval()
286
+
287
+ # Extract the pseudo tokens for the CIRR validation set using Phi
288
+ cirr_val_pseudo_tokens, cirr_val_ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi_for_eval,
289
+ cirr_relative_val_dataset, args)
290
+ cirr_val_pseudo_tokens = cirr_val_pseudo_tokens.to(accelerator.device)
291
+
292
+ # Compute the CIRR validation metrics
293
+ cirr_results_dict = cirr_compute_val_metrics(cirr_relative_val_dataset, text_encoder,
294
+ cirr_val_index_features, cirr_val_index_names,
295
+ cirr_val_ref_names_list, cirr_val_pseudo_tokens)
296
+ check_list = ['cirr_recall_at1', 'cirr_recall_at5', 'cirr_recall_at10', 'cirr_recall_at50']
297
+ for check_key in check_list:
298
+ accelerator.log({f"validate/{check_key}": cirr_results_dict[check_key]}, step=global_step)
299
+ print(json.dumps(cirr_results_dict, indent=4))
300
+
301
+ # Save the best model.
302
+ if args.checkpointing_steps:
303
+ if cirr_results_dict['cirr_recall_at1'] > best_recall:
304
+ best_recall = cirr_results_dict['cirr_recall_at1']
305
+ logger.info(f"best model saving... step: {global_step}")
306
+ save_phi("phi_best", global_step, accelerator.unwrap_model(phi), args.output_dir)
307
+
308
+ phi.train()
309
+
310
+ if global_step >= args.max_train_steps:
311
+ return  # a bare break would only exit the inner loop and the while-loop would restart it
312
+
313
+
314
+ if __name__ == '__main__':
315
+ args = parse_args()
316
+
317
+ train_phi(args)
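The loop above trains phi from language alone: a frozen CLIP text embedding is perturbed with per-sample Gaussian noise, mapped by phi to a pseudo token that fills the [$] slot, re-encoded, and regressed back onto the original embedding with an MSE loss. A standalone sketch of just the noise-injection step, using a hypothetical batch of text embeddings:

```python
import torch

text_embeds = torch.randn(8, 768)            # stand-in for frozen CLIP text projections
scale = torch.rand(text_embeds.size(0), 1)   # per-sample noise magnitude in [0, 1)
noisy = text_embeds + 1.0 * scale * torch.randn_like(text_embeds)
```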
utils.py ADDED
@@ -0,0 +1,182 @@
1
+ from typing import Optional, Tuple, List
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from clip.model import CLIP
6
+ from transformers import CLIPVisionModelWithProjection
7
+ from torch.utils.data import DataLoader
8
+ from torch.utils.data import Dataset
9
+ from tqdm import tqdm
10
+
11
+ from data_utils import collate_fn
12
+ from models import Phi
13
+
14
+
15
+ if torch.cuda.is_available():
16
+ device = torch.device("cuda")
17
+ dtype = torch.float16
18
+ else:
19
+ device = torch.device("cpu")
20
+ dtype = torch.float32
21
+
22
+
23
+ @torch.no_grad()
24
+ def extract_image_features(dataset: Dataset, clip_model: CLIPVisionModelWithProjection, batch_size: Optional[int] = 32,
25
+ num_workers: Optional[int] = 10) -> Tuple[torch.Tensor, List[str]]:
26
+ """
27
+ Extracts image features from a dataset using a CLIP model.
28
+ """
29
+ # Create data loader
30
+ loader = DataLoader(dataset=dataset, batch_size=batch_size,
31
+ num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)
32
+
33
+ index_features = []
34
+ index_names = []
35
+ try:
36
+ print(f"extracting image features {dataset.__class__.__name__} - {dataset.split}")
37
+ except Exception as e:
38
+ pass
39
+
40
+ # Extract features
41
+ for batch in tqdm(loader):
42
+ images = batch.get('image')
43
+ names = batch.get('image_name')
44
+ if images is None:
45
+ images = batch.get('reference_image')
46
+ if names is None:
47
+ names = batch.get('reference_name')
48
+
49
+ images = images.to(clip_model.device)
50
+ with torch.no_grad():
51
+ batch_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds #.encode_image(images)
52
+ index_features.append(batch_features.cpu())
53
+ index_names.extend(names)
54
+
55
+ index_features = torch.vstack(index_features)
56
+ return index_features, index_names
57
+
58
+
59
+ def contrastive_loss(v1: torch.Tensor, v2: torch.Tensor, temperature: float) -> torch.Tensor:
60
+ # Based on https://github.com/NVlabs/PALAVRA/blob/main/utils/nv.py
61
+ v1 = F.normalize(v1, dim=1)
62
+ v2 = F.normalize(v2, dim=1)
63
+
64
+ numerator = torch.exp(torch.diag(torch.inner(v1, v2)) / temperature)
65
+ numerator = torch.cat((numerator, numerator), 0)
66
+ joint_vector = torch.cat((v1, v2), 0)
67
+ pairs_product = torch.exp(torch.mm(joint_vector, joint_vector.t()) / temperature)
68
+ denominator = torch.sum(pairs_product - pairs_product * torch.eye(joint_vector.shape[0]).to(device), 0)
69
+
70
+ loss = -torch.mean(torch.log(numerator / denominator))
71
+
72
+ return loss
73
+
74
+
75
+ @torch.no_grad()
76
+ def extract_pseudo_tokens_with_phi(clip_model: CLIPVisionModelWithProjection, phi: Phi, dataset: Dataset, args) -> Tuple[torch.Tensor, List[str]]:
77
+ """
78
+ Extracts pseudo tokens from a dataset using a CLIP model and a phi model
79
+ """
80
+ data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
81
+ collate_fn=collate_fn)
82
+ predicted_tokens = []
83
+ names_list = []
84
+ print(f"Extracting tokens using phi model")
85
+ for batch in tqdm(data_loader):
86
+ images = batch.get('image')
87
+ names = batch.get('image_name')
88
+ if images is None:
89
+ images = batch.get('reference_image')
90
+ if names is None:
91
+ names = batch.get('reference_name')
92
+
93
+ images = images.to(device)
94
+ image_features = clip_model(pixel_values=images.half()).image_embeds
95
+ if args.l2_normalize:
96
+ image_features = F.normalize(image_features, dim=-1)
97
+ batch_predicted_tokens = phi(image_features)
98
+ predicted_tokens.append(batch_predicted_tokens.cpu())
99
+ names_list.extend(names)
100
+
101
+ predicted_tokens = torch.vstack(predicted_tokens)
102
+ return predicted_tokens, names_list
103
+
104
+
105
+ @torch.no_grad()
106
+ def extract_image_features_with_names(clip_model: CLIPVisionModelWithProjection, dataset: Dataset) -> Tuple[torch.Tensor, List[str]]:
107
+ """
108
+ Extracts image features from a dataset using a CLIP model
109
+ """
110
+ data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
111
+ collate_fn=collate_fn)
112
+ predicted_tokens = []
113
+ names_list = []
114
+ print("Extracting image features")
115
+ for batch in tqdm(data_loader):
116
+ images = batch.get('image')
117
+ names = batch.get('image_name')
118
+ if images is None:
119
+ images = batch.get('reference_image')
120
+ if names is None:
121
+ names = batch.get('reference_name')
122
+
123
+ images = images.to(device)
124
+ image_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds
125
+
126
+ #batch_predicted_tokens = phi(image_features)
127
+ batch_predicted_tokens = image_features
128
+ predicted_tokens.append(batch_predicted_tokens.cpu())
129
+ names_list.extend(names)
130
+
131
+ predicted_tokens = torch.vstack(predicted_tokens)
132
+ return predicted_tokens, names_list
133
+
134
+ class CustomTensorDataset(Dataset):
135
+ """
136
+ Custom Tensor Dataset which yields image_features and image_names
137
+ """
138
+
139
+ def __init__(self, images: torch.Tensor, names: torch.Tensor):
140
+ self.images = images
141
+ self.names = names
142
+
143
+ def __getitem__(self, index) -> dict:
144
+ return {'image': self.images[index],
145
+ 'image_name': self.names[index]
146
+ }
147
+
148
+ def __len__(self):
149
+ return len(self.images)
150
+
151
+
152
+ def get_templates():
153
+ """
154
+ Return a list of templates
155
+ Same templates as in PALAVRA: https://arxiv.org/abs/2204.01694
156
+ """
157
+ return [
158
+ "This is a photo of a {}",
159
+ "This photo contains a {}",
160
+ "A photo of a {}",
161
+ "This is an illustration of a {}",
162
+ "This illustration contains a {}",
163
+ "An illustrations of a {}",
164
+ "This is a sketch of a {}",
165
+ "This sketch contains a {}",
166
+ "A sketch of a {}",
167
+ "This is a diagram of a {}",
168
+ "This diagram contains a {}",
169
+ "A diagram of a {}",
170
+ "A {}",
171
+ "We see a {}",
172
+ "{}",
173
+ "We see a {} in this photo",
174
+ "We see a {} in this image",
175
+ "We see a {} in this illustration",
176
+ "We see a {} photo",
177
+ "We see a {} image",
178
+ "We see a {} illustration",
179
+ "{} photo",
180
+ "{} image",
181
+ "{} illustration",
182
+ ]
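get_templates() returns format strings with a single {} slot; filling them with a (hypothetical) concept word looks like this:

```python
from utils import get_templates

captions = [template.format("sneaker") for template in get_templates()]
print(captions[0])   # "This is a photo of a sneaker"
```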
validate.py ADDED
@@ -0,0 +1,650 @@
1
+ import json
2
+ import pickle
3
+ from argparse import ArgumentParser
4
+ from typing import List, Dict, Tuple
5
+
6
+ import clip
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from clip.model import CLIP
11
+ from transformers import CLIPTextModelWithProjection
12
+ from torch.utils.data import DataLoader
13
+ from torch.utils.data import Dataset
14
+ from tqdm import tqdm
15
+
16
+ from data_utils import collate_fn, PROJECT_ROOT, targetpad_transform
17
+ from loader import FashionIQDataset, CIRRDataset, CIRCODataset
18
+ from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
19
+ from models import build_text_encoder, Phi, PIC2WORD
20
+ from utils import extract_image_features, device, extract_pseudo_tokens_with_phi
21
+
22
+ torch.multiprocessing.set_sharing_strategy('file_system')
23
+
24
+
25
+ @torch.no_grad()
26
+ def fiq_generate_val_predictions(clip_model, relative_val_dataset: Dataset, ref_names_list: List[str],
27
+ pseudo_tokens: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
28
+ """
29
+ Generates feature predictions for the validation set of Fashion IQ.
30
+ """
31
+
32
+ # Create data loader
33
+ relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
34
+ pin_memory=False, collate_fn=collate_fn, shuffle=False)
35
+
36
+ predicted_features_list = []
37
+ target_names_list = []
38
+
39
+ # Compute features
40
+ for batch in tqdm(relative_val_loader):
41
+ reference_names = batch['reference_name']
42
+ target_names = batch['target_name']
43
+ relative_captions = batch['relative_captions']
44
+
45
+ flattened_captions: list = np.array(relative_captions).T.flatten().tolist()
46
+ input_captions = [
47
+ f"{flattened_captions[i].strip('.?, ')} and {flattened_captions[i + 1].strip('.?, ')}" for
48
+ i in range(0, len(flattened_captions), 2)]
49
+ input_captions_reversed = [
50
+ f"{flattened_captions[i + 1].strip('.?, ')} and {flattened_captions[i].strip('.?, ')}" for
51
+ i in range(0, len(flattened_captions), 2)]
52
+
53
+ input_captions = [
54
+ f"a photo of $ that {in_cap}" for in_cap in input_captions]
55
+ batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
56
+ tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
57
+ text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
58
+
59
+ input_captions_reversed = [
60
+ f"a photo of $ that {in_cap}" for in_cap in input_captions_reversed]
61
+ tokenized_input_captions_reversed = clip.tokenize(input_captions_reversed, context_length=77).to(device)
62
+ text_features_reversed = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions_reversed,
63
+ batch_tokens)
64
+
65
+ predicted_features = F.normalize((F.normalize(text_features) + F.normalize(text_features_reversed)) / 2)
66
+ # predicted_features = F.normalize((text_features + text_features_reversed) / 2)
67
+
68
+ predicted_features_list.append(predicted_features)
69
+ target_names_list.extend(target_names)
70
+
71
+ predicted_features = torch.vstack(predicted_features_list)
72
+ return predicted_features, target_names_list
73
+
74
+
75
+ @torch.no_grad()
76
+ def fiq_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
77
+ index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
78
+ -> Dict[str, float]:
79
+ """
80
+ Compute the retrieval metrics on the FashionIQ validation set given the dataset, pseudo tokens and the reference names
81
+ """
82
+
83
+ # Generate the predicted features
84
+ predicted_features, target_names = fiq_generate_val_predictions(clip_model, relative_val_dataset, ref_names_list,
85
+ pseudo_tokens)
86
+
87
+ # Move the features to the device
88
+ index_features = index_features.to(device)
89
+ predicted_features = predicted_features.to(device)
90
+
91
+ # Normalize the features
92
+ index_features = F.normalize(index_features.float())
93
+
94
+ # Compute the distances
95
+ distances = 1 - predicted_features @ index_features.T
96
+ sorted_indices = torch.argsort(distances, dim=-1).cpu()
97
+ sorted_index_names = np.array(index_names)[sorted_indices]
98
+
99
+ # Check if the target names are in the top 10 and top 50
100
+ labels = torch.tensor(
101
+ sorted_index_names == np.repeat(np.array(target_names), len(index_names)).reshape(len(target_names), -1))
102
+ assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
103
+
104
+ # Compute the metrics
105
+ recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
106
+ recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
107
+
108
+ return {'fiq_recall_at10': recall_at10,
109
+ 'fiq_recall_at50': recall_at50}
110
+
111
+
112
+ @torch.no_grad()
113
+ def fiq_val_retrieval(dataset_path: str, dress_type: str, image_encoder, text_encoder, ref_names_list: List[str],
114
+ pseudo_tokens: torch.Tensor, preprocess: callable) -> Dict[str, float]:
115
+ """
116
+ Compute the retrieval metrics on the FashionIQ validation set given the pseudo tokens and the reference names
117
+ """
118
+ # Load the model
119
+ #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
120
+ #clip_model = clip_model.float().eval().requires_grad_(False)
121
+
122
+ # Extract the index features
123
+ classic_val_dataset = FashionIQDataset(dataset_path, 'val', [dress_type], 'classic', preprocess)
124
+ index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)
125
+
126
+ # Define the relative dataset
127
+ relative_val_dataset = FashionIQDataset(dataset_path, 'val', [dress_type], 'relative', preprocess)
128
+
129
+ return fiq_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names, ref_names_list,
130
+ pseudo_tokens)
131
+
132
+
133
+ @torch.no_grad()
134
+ def cirr_generate_val_predictions(clip_model: CLIPTextModelWithProjection, relative_val_dataset: Dataset, ref_names_list: List[str],
135
+ pseudo_tokens: torch.Tensor) -> \
136
+ Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
137
+ """
138
+ Generates feature predictions for the validation set of CIRR
139
+ """
140
+
141
+ # Define the dataloader
142
+ relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
143
+ pin_memory=False, collate_fn=collate_fn)
144
+ predicted_features_list = []
145
+ target_names_list = []
146
+ group_members_list = []
147
+ reference_names_list = []
148
+
149
+ for batch in tqdm(relative_val_loader):
150
+ reference_names = batch['reference_name']
151
+ target_names = batch['target_name']
152
+ relative_captions = batch['relative_caption']
153
+ group_members = batch['group_members']
154
+
155
+ group_members = np.array(group_members).T.tolist()
156
+
157
+ input_captions = [
158
+ f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]
159
+
160
+ batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
161
+ tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
162
+ text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
163
+
164
+ predicted_features = F.normalize(text_features)
165
+
166
+ predicted_features_list.append(predicted_features)
167
+ target_names_list.extend(target_names)
168
+ group_members_list.extend(group_members)
169
+ reference_names_list.extend(reference_names)
170
+
171
+ predicted_features = torch.vstack(predicted_features_list)
172
+
173
+ return predicted_features, reference_names_list, target_names_list, group_members_list
174
+
175
+
176
+ @torch.no_grad()
177
+ def cirr_generate_val_predictions_with_phi(clip_model: CLIPTextModelWithProjection, phi, relative_val_dataset: Dataset, ref_names_list: List[str],
178
+ image_features: torch.Tensor) -> \
179
+ Tuple[torch.Tensor, List[str], List[str], List[List[str]]]:
180
+ """
181
+ Generates feature predictions for the validation set of CIRR
182
+ """
183
+
184
+ # Define the dataloader
185
+ relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
186
+ pin_memory=False, collate_fn=collate_fn)
187
+ predicted_features_list = []
188
+ target_names_list = []
189
+ group_members_list = []
190
+ reference_names_list = []
191
+
192
+ for batch in tqdm(relative_val_loader):
193
+ reference_names = batch['reference_name']
194
+ target_names = batch['target_name']
195
+ relative_captions = batch['relative_caption']
196
+ group_members = batch['group_members']
197
+
198
+ group_members = np.array(group_members).T.tolist()
199
+
200
+ input_captions = [
201
+ f"a photo of $ that {rel_caption}" for rel_caption in relative_captions]
202
+
203
+ # we need to make batch_tokens with selected_image_features
204
+ selected_image_features = torch.vstack([image_features[ref_names_list.index(ref)] for ref in reference_names])
205
+ tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
206
+ context = clip_model.text_model.embeddings.token_embedding(tokenized_input_captions) + clip_model.text_model.embeddings.position_embedding(clip_model.text_model.embeddings.position_ids)
207
+ batch_tokens = phi(selected_image_features, context)
208
+ #batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
209
+ text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
210
+
211
+ predicted_features = F.normalize(text_features)
212
+
213
+ predicted_features_list.append(predicted_features)
214
+ target_names_list.extend(target_names)
215
+ group_members_list.extend(group_members)
216
+ reference_names_list.extend(reference_names)
217
+
218
+ predicted_features = torch.vstack(predicted_features_list)
219
+
220
+ return predicted_features, reference_names_list, target_names_list, group_members_list
221
+
222
+
223
+ @torch.no_grad()
224
+ def cirr_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
225
+ index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
226
+ -> Dict[str, float]:
227
+ """
228
+ Compute the retrieval metrics on the CIRR validation set given the dataset, pseudo tokens and the reference names
229
+ """
230
+
231
+ # Generate the predicted features
232
+ predicted_features, reference_names, target_names, group_members = \
233
+ cirr_generate_val_predictions(clip_model, relative_val_dataset, ref_names_list, pseudo_tokens)
234
+
235
+ index_features = index_features.to(device)
236
+ predicted_features = predicted_features.to(device)
237
+
238
+ # Normalize the index features
239
+ index_features = F.normalize(index_features, dim=-1).float()
240
+ predicted_features = predicted_features.float()
241
+
242
+ # Compute the distances and sort the results
243
+ distances = 1 - predicted_features @ index_features.T
244
+ sorted_indices = torch.argsort(distances, dim=-1).cpu()
245
+ sorted_index_names = np.array(index_names)[sorted_indices]
246
+
247
+ # Delete the reference image from the results
248
+ reference_mask = torch.tensor(
249
+ sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
250
+ sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
251
+ sorted_index_names.shape[1] - 1)
252
+ # Compute the ground-truth labels wrt the predictions
253
+ labels = torch.tensor(
254
+ sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
255
+
256
+ # Compute the subset predictions and ground-truth labels
257
+ group_members = np.array(group_members)
258
+ group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
259
+ group_labels = labels[group_mask].reshape(labels.shape[0], -1)
260
+
261
+ assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
262
+ assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
263
+
264
+ # Compute the metrics
265
+ recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
266
+ recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
267
+ recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
268
+ recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
269
+ group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
270
+ group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
271
+ group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
272
+
273
+ return {
274
+ 'cirr_recall_at1': recall_at1,
275
+ 'cirr_recall_at5': recall_at5,
276
+ 'cirr_recall_at10': recall_at10,
277
+ 'cirr_recall_at50': recall_at50,
278
+ 'cirr_group_recall_at1': group_recall_at1,
279
+ 'cirr_group_recall_at2': group_recall_at2,
280
+ 'cirr_group_recall_at3': group_recall_at3,
281
+ }
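The reference-image removal above relies on a boolean mask plus a reshape that drops exactly one column per row (each query's own reference image). A toy check with made-up names, not part of the committed files:

```python
import numpy as np
import torch

sorted_index_names = np.array([['a', 'b', 'c'],
                               ['b', 'c', 'a']])
reference_names = np.array(['b', 'a'])

# Same masking pattern as cirr_compute_val_metrics, on a 2x3 toy ranking.
reference_mask = torch.tensor(sorted_index_names != np.repeat(reference_names, 3).reshape(2, 3))
filtered = sorted_index_names[reference_mask.numpy()].reshape(2, 2)
# filtered == [['a', 'c'], ['b', 'c']]
```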
282
+
283
+
284
+ @torch.no_grad()
285
+ def cirr_compute_val_metrics_with_phi(relative_val_dataset: Dataset, clip_model: CLIPTextModelWithProjection, phi, index_features: torch.Tensor,
286
+ index_names: List[str], ref_names_list: List[str], image_features: torch.Tensor) \
287
+ -> Dict[str, float]:
288
+ """
289
+ Compute the retrieval metrics on the CIRR validation set given the dataset, pseudo tokens and the reference names
290
+ """
291
+
292
+ # Generate the predicted features
293
+ predicted_features, reference_names, target_names, group_members = \
294
+ cirr_generate_val_predictions_with_phi(clip_model, phi, relative_val_dataset, ref_names_list, image_features)
295
+
296
+ index_features = index_features.to(device)
297
+ predicted_features = predicted_features.to(device)
298
+
299
+ # Normalize the index features
300
+ index_features = F.normalize(index_features, dim=-1).float()
301
+ predicted_features = predicted_features.float()
302
+
303
+ # Compute the distances and sort the results
304
+ distances = 1 - predicted_features @ index_features.T
305
+ sorted_indices = torch.argsort(distances, dim=-1).cpu()
306
+ sorted_index_names = np.array(index_names)[sorted_indices]
307
+
308
+ # Delete the reference image from the results
309
+ reference_mask = torch.tensor(
310
+ sorted_index_names != np.repeat(np.array(reference_names), len(index_names)).reshape(len(target_names), -1))
311
+ sorted_index_names = sorted_index_names[reference_mask].reshape(sorted_index_names.shape[0],
312
+ sorted_index_names.shape[1] - 1)
313
+ # Compute the ground-truth labels wrt the predictions
314
+ labels = torch.tensor(
315
+ sorted_index_names == np.repeat(np.array(target_names), len(index_names) - 1).reshape(len(target_names), -1))
316
+
317
+ # Compute the subset predictions and ground-truth labels
318
+ group_members = np.array(group_members)
319
+ group_mask = (sorted_index_names[..., None] == group_members[:, None, :]).sum(-1).astype(bool)
320
+ group_labels = labels[group_mask].reshape(labels.shape[0], -1)
321
+
322
+ assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
323
+ assert torch.equal(torch.sum(group_labels, dim=-1).int(), torch.ones(len(target_names)).int())
324
+
325
+ # Compute the metrics
326
+ recall_at1 = (torch.sum(labels[:, :1]) / len(labels)).item() * 100
327
+ recall_at5 = (torch.sum(labels[:, :5]) / len(labels)).item() * 100
328
+ recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
329
+ recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
330
+ group_recall_at1 = (torch.sum(group_labels[:, :1]) / len(group_labels)).item() * 100
331
+ group_recall_at2 = (torch.sum(group_labels[:, :2]) / len(group_labels)).item() * 100
332
+ group_recall_at3 = (torch.sum(group_labels[:, :3]) / len(group_labels)).item() * 100
333
+
334
+ return {
335
+ 'cirr_recall_at1': recall_at1,
336
+ 'cirr_recall_at5': recall_at5,
337
+ 'cirr_recall_at10': recall_at10,
338
+ 'cirr_recall_at50': recall_at50,
339
+ 'cirr_group_recall_at1': group_recall_at1,
340
+ 'cirr_group_recall_at2': group_recall_at2,
341
+ 'cirr_group_recall_at3': group_recall_at3,
342
+ }
343
+
344
+
345
+ @torch.no_grad()
346
+ def cirr_val_retrieval(dataset_path: str, image_encoder, text_encoder, ref_names_list: list, pseudo_tokens: torch.Tensor,
347
+ preprocess: callable) -> Dict[str, float]:
348
+ """
349
+ Compute the retrieval metrics on the CIRR validation set given the pseudo tokens and the reference names
350
+ """
351
+
352
+ # Load the model
353
+ #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
354
+ #clip_model = clip_model.float().eval().requires_grad_(False)
355
+
356
+ # Extract the index features
357
+ classic_val_dataset = CIRRDataset(dataset_path, 'val', 'classic', preprocess)
358
+ index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)
359
+
360
+ # Define the relative validation dataset
361
+ relative_val_dataset = CIRRDataset(dataset_path, 'val', 'relative', preprocess)
362
+
363
+ return cirr_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names,
364
+ ref_names_list, pseudo_tokens)
365
+
366
+
367
+ @torch.no_grad()
368
+ def circo_generate_val_predictions(clip_model, relative_val_dataset: Dataset, ref_names_list: List[str],
369
+ pseudo_tokens: torch.Tensor) -> Tuple[
370
+ torch.Tensor, List[str], list]:
371
+ """
372
+ Generates feature predictions for the validation set of CIRCO
373
+ """
374
+
375
+ # Create the data loader
376
+ relative_val_loader = DataLoader(dataset=relative_val_dataset, batch_size=32, num_workers=10,
377
+ pin_memory=False, collate_fn=collate_fn, shuffle=False)
378
+
379
+ predicted_features_list = []
380
+ target_names_list = []
381
+ gts_img_ids_list = []
382
+
383
+ # Compute the features
384
+ for batch in tqdm(relative_val_loader):
385
+ reference_names = batch['reference_name']
386
+ target_names = batch['target_name']
387
+ relative_captions = batch['relative_caption']
388
+ gt_img_ids = batch['gt_img_ids']
389
+
390
+ gt_img_ids = np.array(gt_img_ids).T.tolist()
391
+ input_captions = [f"a photo of $ that {caption}" for caption in relative_captions]
392
+ batch_tokens = torch.vstack([pseudo_tokens[ref_names_list.index(ref)].unsqueeze(0) for ref in reference_names])
393
+ tokenized_input_captions = clip.tokenize(input_captions, context_length=77).to(device)
394
+ text_features = encode_with_pseudo_tokens_HF(clip_model, tokenized_input_captions, batch_tokens)
395
+ predicted_features = F.normalize(text_features)
396
+
397
+ predicted_features_list.append(predicted_features)
398
+ target_names_list.extend(target_names)
399
+ gts_img_ids_list.extend(gt_img_ids)
400
+
401
+ predicted_features = torch.vstack(predicted_features_list)
402
+
403
+ return predicted_features, target_names_list, gts_img_ids_list
404
+
405
+
406
+ @torch.no_grad()
407
+ def circo_compute_val_metrics(relative_val_dataset: Dataset, clip_model, index_features: torch.Tensor,
408
+ index_names: List[str], ref_names_list: List[str], pseudo_tokens: torch.Tensor) \
409
+ -> Dict[str, float]:
410
+ """
411
+ Compute the retrieval metrics on the CIRCO validation set given the dataset, pseudo tokens and the reference names
412
+ """
413
+
414
+ # Generate the predicted features
415
+ predicted_features, target_names, gts_img_ids = circo_generate_val_predictions(clip_model, relative_val_dataset,
416
+ ref_names_list, pseudo_tokens)
417
+ ap_at5 = []
418
+ ap_at10 = []
419
+ ap_at25 = []
420
+ ap_at50 = []
421
+
422
+ recall_at5 = []
423
+ recall_at10 = []
424
+ recall_at25 = []
425
+ recall_at50 = []
426
+
427
+ # Move the features to the device
428
+ index_features = index_features.to(device)
429
+ predicted_features = predicted_features.to(device)
430
+
431
+ # Normalize the features
432
+ index_features = F.normalize(index_features.float())
433
+
434
+ for predicted_feature, target_name, gt_img_ids in tqdm(zip(predicted_features, target_names, gts_img_ids)):
435
+ gt_img_ids = np.array(gt_img_ids)[
436
+ np.array(gt_img_ids) != ''] # remove trailing empty strings added for collate_fn
437
+ similarity = predicted_feature @ index_features.T
438
+ sorted_indices = torch.topk(similarity, dim=-1, k=50).indices.cpu()
439
+ sorted_index_names = np.array(index_names)[sorted_indices]
440
+ map_labels = torch.tensor(np.isin(sorted_index_names, gt_img_ids), dtype=torch.uint8)
441
+ precisions = torch.cumsum(map_labels, dim=0) * map_labels # Consider only positions corresponding to GTs
442
+ precisions = precisions / torch.arange(1, map_labels.shape[0] + 1) # Compute precision for each position
443
+
444
+ ap_at5.append(float(torch.sum(precisions[:5]) / min(len(gt_img_ids), 5)))
445
+ ap_at10.append(float(torch.sum(precisions[:10]) / min(len(gt_img_ids), 10)))
446
+ ap_at25.append(float(torch.sum(precisions[:25]) / min(len(gt_img_ids), 25)))
447
+ ap_at50.append(float(torch.sum(precisions[:50]) / min(len(gt_img_ids), 50)))
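+ # Illustrative toy example of the AP@k computation above: with gt_img_ids = ['a', 'b'] and
+ # top-5 retrieved names ['a', 'x', 'b', 'y', 'z'], map_labels[:5] = [1, 0, 1, 0, 0],
+ # precisions[:5] = [1/1, 0, 2/3, 0, 0], so ap_at5 = (1 + 2/3) / min(2, 5) ~= 0.83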
448
+
449
+ assert target_name == gt_img_ids[0], f"Target name {target_name} is not the first ground-truth image in {gt_img_ids}"
450
+ single_gt_labels = torch.tensor(sorted_index_names == target_name)
451
+ recall_at5.append(float(torch.sum(single_gt_labels[:5])))
452
+ recall_at10.append(float(torch.sum(single_gt_labels[:10])))
453
+ recall_at25.append(float(torch.sum(single_gt_labels[:25])))
454
+ recall_at50.append(float(torch.sum(single_gt_labels[:50])))
455
+
456
+ map_at5 = np.mean(ap_at5) * 100
457
+ map_at10 = np.mean(ap_at10) * 100
458
+ map_at25 = np.mean(ap_at25) * 100
459
+ map_at50 = np.mean(ap_at50) * 100
460
+ recall_at5 = np.mean(recall_at5) * 100
461
+ recall_at10 = np.mean(recall_at10) * 100
462
+ recall_at25 = np.mean(recall_at25) * 100
463
+ recall_at50 = np.mean(recall_at50) * 100
464
+
465
+ return {
466
+ 'circo_map_at5': map_at5,
467
+ 'circo_map_at10': map_at10,
468
+ 'circo_map_at25': map_at25,
469
+ 'circo_map_at50': map_at50,
470
+ 'circo_recall_at5': recall_at5,
471
+ 'circo_recall_at10': recall_at10,
472
+ 'circo_recall_at25': recall_at25,
473
+ 'circo_recall_at50': recall_at50,
474
+ }
475
+
476
+
477
+ @torch.no_grad()
478
+ def circo_val_retrieval(dataset_path: str, image_encoder, text_encoder, ref_names_list: List[str], pseudo_tokens: torch.Tensor,
479
+ preprocess: callable) -> Dict[str, float]:
480
+ """
481
+ Compute the retrieval metrics on the CIRCO validation set given the pseudo tokens and the reference names
482
+ """
483
+ # Load the model
484
+ #clip_model, _ = clip.load(clip_model_name, device=device, jit=False)
485
+ #clip_model = clip_model.float().eval().requires_grad_(False)
486
+
487
+ # Extract the index features
488
+ classic_val_dataset = CIRCODataset(dataset_path, 'val', 'classic', preprocess)
489
+ index_features, index_names = extract_image_features(classic_val_dataset, image_encoder)
490
+
491
+ # Define the relative validation dataset
492
+ relative_val_dataset = CIRCODataset(dataset_path, 'val', 'relative', preprocess)
493
+
494
+ return circo_compute_val_metrics(relative_val_dataset, text_encoder, index_features, index_names, ref_names_list,
495
+ pseudo_tokens)
496
+
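+ # Illustrative usage of circo_val_retrieval (a sketch, not executed here; the dataset path is a
+ # placeholder, and image_encoder, text_encoder, ref_names_list and pseudo_tokens are assumed to
+ # have been prepared as in main() below):
+ #   metrics = circo_val_retrieval('/path/to/CIRCO', image_encoder, text_encoder,
+ #                                 ref_names_list, pseudo_tokens, preprocess)
+ #   print(metrics['circo_map_at10'])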
497
+
498
+ def main():
499
+ parser = ArgumentParser()
500
+ parser.add_argument("--exp-name", type=str, help="Experiment to evaluate")
501
+ parser.add_argument("--eval-type", type=str, choices=['oti', 'phi', 'searle', 'searle-xl', 'pic2word'], required=True,
502
+ help="If 'oti' evaluate directly using the inverted oti pseudo tokens, "
503
+ "if 'phi' predicts the pseudo tokens using the phi network, "
504
+ "if 'searle' uses the pre-trained SEARLE model to predict the pseudo tokens, "
505
+ "if 'searle-xl' uses the pre-trained SEARLE-XL model to predict the pseudo tokens"
506
+ )
507
+ parser.add_argument("--dataset", type=str, required=True, choices=['cirr', 'fashioniq', 'circo'],
508
+ help="Dataset to use")
509
+ parser.add_argument("--dataset-path", type=str, help="Path to the dataset", required=True)
510
+
511
+ parser.add_argument("--preprocess-type", default="clip", type=str, choices=['clip', 'targetpad'],
512
+ help="Preprocess pipeline to use")
513
+ parser.add_argument("--phi-checkpoint-name", type=str,
514
+ help="Phi checkpoint to use, needed when using phi, e.g. 'phi_20.pt'")
515
+ parser.add_argument("--clip_model_name", default="giga", type=str)
516
+ parser.add_argument("--cache_dir", default="./hf_models", type=str)
517
+
518
+ parser.add_argument("--l2_normalize", action="store_true", help="Whether or not to use l2 normalization")
519
+
520
+ args = parser.parse_args()
521
+
522
+ #if args.eval_type in ['phi', 'oti'] and args.exp_name is None:
523
+ # raise ValueError("Experiment name is required when using phi or oti evaluation type")
524
+ if args.eval_type == 'phi' and args.phi_checkpoint_name is None:
525
+ raise ValueError("Phi checkpoint name is required when using phi evaluation type")
526
+
527
+ if args.eval_type == 'oti':
528
+ experiment_path = PROJECT_ROOT / 'data' / "oti_pseudo_tokens" / args.dataset.lower() / 'val' / args.exp_name
529
+ if not experiment_path.exists():
530
+ raise ValueError(f"Experiment {args.exp_name} not found")
531
+
532
+ with open(experiment_path / 'hyperparameters.json') as f:
533
+ hyperparameters = json.load(f)
534
+
535
+ pseudo_tokens = torch.load(experiment_path / 'ema_oti_pseudo_tokens.pt', map_location=device)
536
+ with open(experiment_path / 'image_names.pkl', 'rb') as f:
537
+ ref_names_list = pickle.load(f)
538
+
539
+ clip_model_name = hyperparameters['clip_model_name']
540
+ clip_model, clip_preprocess = clip.load(clip_model_name, device='cpu', jit=False)
541
+
542
+ if args.preprocess_type == 'targetpad':
543
+ print('Target pad preprocess pipeline is used')
544
+ preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
545
+ elif args.preprocess_type == 'clip':
546
+ print('CLIP preprocess pipeline is used')
547
+ preprocess = clip_preprocess
548
+ else:
549
+ raise ValueError("Preprocess type not supported")
550
+
551
+
552
+ elif args.eval_type in ['phi', 'searle', 'searle-xl', 'pic2word']:
553
+ if args.eval_type == 'phi':
554
+ args.mixed_precision = 'fp16'
555
+ image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
556
+
557
+ phi = Phi(input_dim=text_encoder.config.projection_dim,
558
+ hidden_dim=text_encoder.config.projection_dim * 4,
559
+ output_dim=text_encoder.config.hidden_size, dropout=0.5).to(
560
+ device)
561
+
562
+ phi.load_state_dict(
563
+ torch.load(args.phi_checkpoint_name, map_location=device)[
564
+ phi.__class__.__name__])
565
+
566
+ phi = phi.eval()
567
+
568
+ elif args.eval_type == 'pic2word':
569
+ args.mixed_precision = 'fp16'
570
+ image_encoder, clip_preprocess, text_encoder, tokenizer = build_text_encoder(args)
571
+ phi = PIC2WORD(embed_dim=text_encoder.config.projection_dim,
572
+ output_dim=text_encoder.config.hidden_size,
573
+ ).to(device)
574
+ sd = torch.load(args.phi_checkpoint_name, map_location=device)['state_dict_img2text']
575
+ sd = {k[len('module.'):]: v for k, v in sd.items()}
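+ # strip the 'module.' prefix (added when the checkpoint was saved from a DataParallel/DDP-wrapped model)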
576
+ phi.load_state_dict(sd)
577
+ phi = phi.eval()
578
+
579
+ else: # searle or searle-xl
580
+ if args.eval_type == 'searle':
581
+ clip_model_name = 'ViT-B/32'
582
+ else: # args.eval_type == 'searle-xl':
583
+ clip_model_name = 'ViT-L/14'
584
+ phi, _ = torch.hub.load(repo_or_dir='miccunifi/SEARLE', model='searle', source='github',
585
+ backbone=clip_model_name)
586
+ phi = phi.to(device).eval()
587
+ clip_model, clip_preprocess = clip.load(clip_model_name, device=device, jit=False)
588
+
589
+ if args.preprocess_type == 'targetpad':
590
+ print('Target pad preprocess pipeline is used')
591
+ preprocess = targetpad_transform(1.25, clip_model.visual.input_resolution)
592
+ elif args.preprocess_type == 'clip':
593
+ print('CLIP preprocess pipeline is used')
594
+ preprocess = clip_preprocess
595
+ else:
596
+ raise ValueError("Preprocess type not supported")
597
+
598
+ if args.dataset.lower() == 'fashioniq':
599
+ relative_val_dataset = FashionIQDataset(args.dataset_path, 'val', ['dress', 'toptee', 'shirt'],
600
+ 'relative', preprocess, no_duplicates=True)
601
+ elif args.dataset.lower() == 'cirr':
602
+ relative_val_dataset = CIRRDataset(args.dataset_path, 'val', 'relative', preprocess,
603
+ no_duplicates=True)
604
+ elif args.dataset.lower() == 'circo':
605
+ relative_val_dataset = CIRCODataset(args.dataset_path, 'val', 'relative', preprocess)
606
+ else:
607
+ raise ValueError("Dataset not supported")
608
+
609
+ #clip_model = clip_model.float().to(device)
610
+ image_encoder = image_encoder.float().to(device)
611
+ text_encoder = text_encoder.float().to(device)
612
+ pseudo_tokens, ref_names_list = extract_pseudo_tokens_with_phi(image_encoder, phi, relative_val_dataset, args)
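+ # map each reference image to a pseudo word token: CLIP image features are projected into the
+ # token-embedding space by phi (or by the SEARLE / Pic2Word networks loaded above)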
613
+ pseudo_tokens = pseudo_tokens.to(device)
614
+ else:
615
+ raise ValueError("Eval type not supported")
616
+
617
+ print(f"Eval type = {args.eval_type} \t exp name = {args.exp_name} \t")
618
+ if args.dataset.lower() == 'fashioniq':
619
+ recalls_at10 = []
620
+ recalls_at50 = []
621
+ for dress_type in ['shirt', 'dress', 'toptee']:
622
+ fiq_metrics = fiq_val_retrieval(args.dataset_path, dress_type, image_encoder, text_encoder, ref_names_list,
623
+ pseudo_tokens, preprocess)
624
+ recalls_at10.append(fiq_metrics['fiq_recall_at10'])
625
+ recalls_at50.append(fiq_metrics['fiq_recall_at50'])
626
+
627
+ for k, v in fiq_metrics.items():
628
+ print(f"{dress_type}_{k} = {v:.2f}")
629
+ print("\n")
630
+
631
+ print(f"average_fiq_recall_at10 = {np.mean(recalls_at10):.2f}")
632
+ print(f"average_fiq_recall_at50 = {np.mean(recalls_at50):.2f}")
633
+
634
+ elif args.dataset.lower() == 'cirr':
635
+ cirr_metrics = cirr_val_retrieval(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
636
+ preprocess)
637
+
638
+ for k, v in cirr_metrics.items():
639
+ print(f"{k} = {v:.2f}")
640
+
641
+ elif args.dataset.lower() == 'circo':
642
+ circo_metrics = circo_val_retrieval(args.dataset_path, image_encoder, text_encoder, ref_names_list, pseudo_tokens,
643
+ preprocess)
644
+
645
+ for k, v in circo_metrics.items():
646
+ print(f"{k} = {v:.2f}")
647
+
648
+
649
+ if __name__ == '__main__':
650
+ main()
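
Example invocation (an illustrative sketch only; it assumes this script is saved as validate.py, that a local copy of the CIRCO dataset and a trained phi checkpoint are available, and the paths below are placeholders to be replaced with your own):

python validate.py --eval-type phi --dataset circo --dataset-path /path/to/CIRCO --phi-checkpoint-name ./trained_models/phi_best.pt --clip_model_name large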