OmkarThawakar committed on
Commit
ed00004
1 Parent(s): 7baf9f3

initial commit

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .gitignore +13 -0
  2. app.py +208 -0
  3. configs/data/cirr.yaml +22 -0
  4. configs/data/fashioniq-base.yaml +28 -0
  5. configs/data/fashioniq-dress.yaml +4 -0
  6. configs/data/fashioniq-shirt.yaml +4 -0
  7. configs/data/fashioniq-toptee.yaml +4 -0
  8. configs/data/webvid-covr.yaml +26 -0
  9. configs/data/webvid-covr_rule-based.yaml +26 -0
  10. configs/experiment/cirr.yaml +13 -0
  11. configs/experiment/covr_hard-negatives.yaml +6 -0
  12. configs/experiment/covr_iterate-triplets.yaml +14 -0
  13. configs/experiment/covr_late-fusion.yaml +12 -0
  14. configs/experiment/covr_middle-emb.yaml +14 -0
  15. configs/experiment/covr_only-text.yaml +8 -0
  16. configs/experiment/covr_only-visual.yaml +20 -0
  17. configs/experiment/covr_random-frame.yaml +10 -0
  18. configs/experiment/covr_rule-based.yaml +8 -0
  19. configs/experiment/fiq-dress.yaml +17 -0
  20. configs/experiment/fiq-shirt.yaml +17 -0
  21. configs/experiment/fiq-toptee.yaml +17 -0
  22. configs/machine/default.yaml +16 -0
  23. configs/machine/server.yaml +8 -0
  24. configs/med_config.json +21 -0
  25. configs/model/blip-large.yaml +15 -0
  26. configs/model/blip-large_text.yaml +15 -0
  27. configs/model/blip-large_visual.yaml +15 -0
  28. configs/model/ckpt/blip-l-coco.yaml +3 -0
  29. configs/model/ckpt/cirr-gt.yaml +3 -0
  30. configs/model/ckpt/cirr_ft-covr+gt.yaml +3 -0
  31. configs/model/ckpt/webvid-covr.yaml +3 -0
  32. configs/model/loss/cross_entropy.yaml +2 -0
  33. configs/model/loss/hn_nce.yaml +5 -0
  34. configs/model/optimizer/adamw.yaml +5 -0
  35. configs/model/scheduler/cosine.yaml +6 -0
  36. configs/model/scheduler/step.yaml +5 -0
  37. configs/test.yaml +27 -0
  38. configs/test/all.yaml +6 -0
  39. configs/test/cirr.yaml +15 -0
  40. configs/test/fashioniq-dress.yaml +18 -0
  41. configs/test/fashioniq-shirt.yaml +18 -0
  42. configs/test/fashioniq-toptee.yaml +18 -0
  43. configs/test/fashioniq.yaml +4 -0
  44. configs/test/main.yaml +3 -0
  45. configs/test/webvid-covr.yaml +20 -0
  46. configs/test/webvid-covr_text.yaml +20 -0
  47. configs/test/webvid-covr_visual.yaml +20 -0
  48. configs/train.yaml +33 -0
  49. configs/trainer/cpu.yaml +5 -0
  50. configs/trainer/ddp.yaml +12 -0
.gitignore ADDED
@@ -0,0 +1,13 @@
+ outputs/
+ datasets/sidechef/images
+ datasets/sidechef/sample_images
+ datasets/sidechef/my_tags.json
+ datasets/sidechef/tag_categories.json
+ datasets/sidechef/tags.json
+ launching
+ annotation/
+ .vscode/
+ bert-base-uncased/
+ delete*
+ __pycache__/
+ env/
app.py ADDED
@@ -0,0 +1,208 @@
+ import pandas as pd
+ import json
+ from PIL import Image
+ import numpy as np
+
+ import os
+ from pathlib import Path
+
+ import torch
+ import torch.nn.functional as F
+
+ # from src.data.embs import ImageDataset
+ from src.model.blip_embs import blip_embs
+ from src.data.transforms import transform_test
+
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ import gradio as gr
+
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_groq import ChatGroq
+
+
+ # Read the Groq API key from the environment instead of hard-coding it
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
+
+ # Initialize LLM
+ llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, max_tokens=1024, max_retries=2)
+
+ # QA system prompt and chain
+ qa_system_prompt = """
+ Prompt:
+ You are a highly intelligent assistant. Use the following context to answer user questions. Analyze the data carefully and generate a clear, concise, and informative response to the user's question based on this data.
+
+ Response Guidelines:
+ - Use only the information provided in the data to answer the question.
+ - Ensure the answer is accurate and directly related to the question.
+ - If the data is insufficient to answer the question, politely apologise and tell the user that there is insufficient data available to answer their question.
+ - Provide the response in a conversational yet professional tone.
+
+ Context:
+ {context}
+ """
+ qa_prompt = ChatPromptTemplate.from_messages(
+     [
+         ("system", qa_system_prompt),
+         ("human", "{input}")
+     ]
+ )
+
+ question_answer_chain = qa_prompt | llm | StrOutputParser()
+
+
+ class StoppingCriteriaSub(StoppingCriteria):
+
+     def __init__(self, stops=[], encounters=1):
+         super().__init__()
+         self.stops = stops
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+         # Stop generation as soon as the last tokens match one of the stop sequences
+         for stop in self.stops:
+             if torch.all(input_ids[:, -len(stop):] == stop).item():
+                 return True
+
+         return False
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def get_blip_config(model="base"):
+     config = dict()
+     if model == "base":
+         config[
+             "pretrained"
+         ] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
+         config["vit"] = "base"
+         config["batch_size_train"] = 32
+         config["batch_size_test"] = 16
+         config["vit_grad_ckpt"] = True
+         config["vit_ckpt_layer"] = 4
+         config["init_lr"] = 1e-5
+     elif model == "large":
+         config[
+             "pretrained"
+         ] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
+         config["vit"] = "large"
+         config["batch_size_train"] = 16
+         config["batch_size_test"] = 32
+         config["vit_grad_ckpt"] = True
+         config["vit_ckpt_layer"] = 12
+         config["init_lr"] = 5e-6
+
+     config["image_size"] = 384
+     config["queue_size"] = 57600
+     config["alpha"] = 0.4
+     config["k_test"] = 256
+     config["negative_all_rank"] = True
+
+     return config
+
+
+ print("Creating model")
+ config = get_blip_config("large")
+
+ model = blip_embs(
+     pretrained=config["pretrained"],
+     image_size=config["image_size"],
+     vit=config["vit"],
+     vit_grad_ckpt=config["vit_grad_ckpt"],
+     vit_ckpt_layer=config["vit_ckpt_layer"],
+     queue_size=config["queue_size"],
+     negative_all_rank=config["negative_all_rank"],
+ )
+
+ model = model.to(device)
+ model.eval()
+ print("Model Loaded !")
+ print("="*50)
+
+ transform = transform_test(384)
+
+ print("Loading Data")
+ df = pd.read_json("datasets/sidechef/my_recipes.json")
+
+ print("Loading Target Embedding")
+ tar_img_feats = []
+ for _id in df["id_"].tolist():
+     tar_img_feats.append(torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0))
+
+ tar_img_feats = torch.cat(tar_img_feats, dim=0)
+
+
+ class Chat:
+
+     def __init__(self, model, transform, dataframe, tar_img_feats, device=device, stopping_criteria=None):
+         self.device = device
+         self.model = model
+         self.transform = transform
+         self.df = dataframe
+         self.tar_img_feats = tar_img_feats
+         self.img_feats = None
+         self.target_recipe = None
+         self.messages = []
+
+         if stopping_criteria is not None:
+             self.stopping_criteria = stopping_criteria
+         else:
+             stop_words_ids = [torch.tensor([2]).to(self.device)]
+             self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+     def encode_image(self, image):
+         img = Image.fromarray(image).convert("RGB")
+         img = self.transform(img).unsqueeze(0)
+         img = img.to(self.device)
+         img_embs = self.model.visual_encoder(img)
+         img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()
+
+         self.img_feats = img_feats
+
+         self.get_target(self.img_feats, self.tar_img_feats)
+
+     def get_target(self, img_feats, tar_img_feats):
+         # Rank the precomputed recipe embeddings by similarity to the query image
+         score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
+         index = np.argsort(score)[::-1][0] + 1
+         self.target_recipe = self.df.iloc[index]
+
+     def ask(self):
+         return json.dumps(self.target_recipe.to_json())
+
+
+
+ chat = Chat(model, transform, df, tar_img_feats)
+ print("Chat Initialized !")
+
+
+ custom_css = """
+ .primary{
+     background-color: #4CAF50; /* Green */
+ }
+ """
+
+
+ def respond_to_user(image, message):
+     # Retrieve the closest recipe for the uploaded image, then answer the query with the LLM
+     chat = Chat(model, transform, df, tar_img_feats)
+     chat.encode_image(image)
+     data = chat.ask()
+     formatted_input = {
+         'input': message,
+         'context': data
+     }
+     try:
+         response = question_answer_chain.invoke(formatted_input)
+     except Exception:
+         response = "An error occurred while processing your request."
+     return response
+
+ iface = gr.Interface(
+     fn=respond_to_user,
+     inputs=[gr.Image(), gr.Textbox(label="Ask Query")],
+     outputs=gr.Textbox(label="Nutrition-GPT"),
+     title="Nutrition-GPT Demo",
+     description="Upload a food image and ask queries!",
+     css=".component-12 {background-color: red}",
+ )
+
+ iface.launch()
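The app above chains BLIP image retrieval (encode the uploaded photo, take the nearest precomputed recipe embedding) with a Groq-hosted LLM that answers over the matched recipe. A minimal sketch of exercising that flow without the Gradio UI is shown below; it assumes the module-level objects from app.py (model, transform, df, tar_img_feats, question_answer_chain, Chat) are already in scope, and the image path is purely hypothetical.

```python
import numpy as np
from PIL import Image

# Hypothetical query image; Gradio normally supplies this as a numpy array.
image = np.array(Image.open("some_food_photo.jpg"))

chat = Chat(model, transform, df, tar_img_feats)
chat.encode_image(image)       # BLIP image embedding + nearest-recipe lookup
recipe_json = chat.ask()       # matched recipe, serialized as JSON

answer = question_answer_chain.invoke({
    "input": "How many calories are in this dish?",
    "context": recipe_json,    # recipe passed to the LLM as grounding context
})
print(answer)
```

`respond_to_user` in app.py wraps exactly these steps for the Gradio interface.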
configs/data/cirr.yaml ADDED
@@ -0,0 +1,22 @@
+ dataname: cirr
+ _target_: src.data.cirr.CIRRDataModule
+
+ # Paths
+ dataset_dir: ${paths.datasets_dir}/CIRR
+
+ batch_size: ${machine.batch_size}
+ num_workers: ${machine.num_workers}
+
+ annotation:
+   train: ${paths.work_dir}/annotation/cirr/cap.rc2.train.json
+   val: ${paths.work_dir}/annotation/cirr/cap.rc2.val.json
+
+ img_dirs:
+   train: ${data.dataset_dir}/images/train
+   val: ${data.dataset_dir}/images/dev
+
+ emb_dirs:
+   train: ${data.dataset_dir}/blip-embs-large/train
+   val: ${data.dataset_dir}/blip-embs-large/dev
+
+ image_size: 384
configs/data/fashioniq-base.yaml ADDED
@@ -0,0 +1,28 @@
+ dataname: fashioniq-${data.category}
+ _target_: src.data.fashioniq.FashionIQDataModule
+
+ # Paths
+ dataset_dir: ${paths.datasets_dir}/fashion-iq
+
+ batch_size: ${machine.batch_size}
+ num_workers: ${machine.num_workers}
+
+ annotation:
+   train: ${paths.work_dir}/annotation/fashion-iq/cap.${data.category}.train.json
+   val: ${paths.work_dir}/annotation/fashion-iq/cap.${data.category}.val.json
+
+ targets:
+   train: ${paths.work_dir}/annotation/fashion-iq/split.${data.category}.train.json
+   val: ${paths.work_dir}/annotation/fashion-iq/split.${data.category}.val.json
+
+ img_dirs:
+   train: ${data.dataset_dir}/images/
+   val: ${data.dataset_dir}/images/
+
+ emb_dirs:
+   train: ${data.dataset_dir}/blip-embs-large/
+   val: ${data.dataset_dir}/blip-embs-large/
+
+ image_size: 384
+
+ category: ???
configs/data/fashioniq-dress.yaml ADDED
@@ -0,0 +1,4 @@
+ defaults:
+   - fashioniq-base.yaml
+
+ category: dress
configs/data/fashioniq-shirt.yaml ADDED
@@ -0,0 +1,4 @@
+ defaults:
+   - fashioniq-base.yaml
+
+ category: shirt
configs/data/fashioniq-toptee.yaml ADDED
@@ -0,0 +1,4 @@
+ defaults:
+   - fashioniq-base.yaml
+
+ category: toptee
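Each variant file only sets `category`, which fills the mandatory `category: ???` of fashioniq-base.yaml through Hydra's defaults list. A small sketch of checking that composition with Hydra's compose API follows; it is not part of the commit, and assumes it is run from the repository root with the `configs/` layout shown above.

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose only the data group; interpolations such as ${machine.batch_size}
# stay unresolved because the rest of the config tree is not loaded here.
with initialize(version_base=None, config_path="configs/data"):
    cfg = compose(config_name="fashioniq-dress")

print(cfg.category)                         # -> dress (overrides the ??? placeholder)
print(OmegaConf.to_yaml(cfg, resolve=False))
```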
configs/data/webvid-covr.yaml ADDED
@@ -0,0 +1,26 @@
+ dataname: webvid-covr
+ _target_: src.data.webvid_covr.WebVidCoVRDataModule
+
+ image_size: 384
+ iterate: "pth2"
+ vid_query_method: middle
+ vid_frames: 1
+ emb_pool: query
+
+ # Paths
+ dataset_dir: ${paths.datasets_dir}/WebVid
+
+ batch_size: ${machine.batch_size}
+ num_workers: ${machine.num_workers}
+
+ annotation:
+   train: ${paths.work_dir}/annotation/webvid-covr/webvid2m-covr_train.csv
+   val: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_val.csv
+
+ vid_dirs:
+   train: ${data.dataset_dir}/2M/train
+   val: ${data.dataset_dir}/8M/train
+
+ emb_dirs:
+   train: ${data.dataset_dir}/2M/blip-vid-embs-${model.model.vit}-all
+   val: ${data.dataset_dir}/8M/blip-vid-embs-${model.model.vit}-all
configs/data/webvid-covr_rule-based.yaml ADDED
@@ -0,0 +1,26 @@
+ dataname: webvid-covr-rule-based
+ _target_: src.data.webvid_covr_rulebased.WebVidCoVRDataModuleRuleBased
+
+ image_size: 384
+ iterate: "pth2"
+ vid_query_method: middle
+ vid_frames: 1
+ emb_pool: query
+
+ # Paths
+ dataset_dir: ${paths.datasets_dir}/WebVid
+
+ batch_size: ${machine.batch_size}
+ num_workers: ${machine.num_workers}
+
+ annotation:
+   train: ${paths.work_dir}/annotation/webvid-covr/webvid2m-covr_train.csv
+   val: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_val.csv
+
+ vid_dirs:
+   train: ${data.dataset_dir}/2M/train
+   val: ${data.dataset_dir}/8M/train
+
+ emb_dirs:
+   train: ${data.dataset_dir}/2M/blip-vid-embs-${model.model.vit}-all
+   val: ${data.dataset_dir}/8M/blip-vid-embs-${model.model.vit}-all
configs/experiment/cirr.yaml ADDED
@@ -0,0 +1,13 @@
+ # @package _global_
+
+ defaults:
+   - override /data: cirr.yaml
+   - override /test: cirr.yaml
+   # - override /model/ckpt: webvid-covr.yaml
+
+ model:
+   optimizer:
+     lr: 1e-4
+
+ trainer:
+   max_epochs: 6
configs/experiment/covr_hard-negatives.yaml ADDED
@@ -0,0 +1,6 @@
+ # @package _global_
+
+ defaults:
+   - override /data: webvid-covr.yaml
+   - override /test: main.yaml
+   - override /model/loss: cross_entropy
configs/experiment/covr_iterate-triplets.yaml ADDED
@@ -0,0 +1,14 @@
+ # @package _global_
+
+ defaults:
+   - override /data: webvid-covr.yaml
+   - override /test: main.yaml
+
+ run_name: "iterate-triplets"
+
+ data:
+   iterate: "triplets"
+
+ trainer:
+   max_epochs: 2
+   print_interval: 1
configs/experiment/covr_late-fusion.yaml ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+
+ defaults:
+   - override /data: webvid-covr.yaml
+   - override /model: blip-large_late-fusion.yaml
+   - override /test: webvid-covr_late-fusion.yaml
+
+ val: False
+
+ model:
+   optimizer:
+     lr: 1e-4
configs/experiment/covr_middle-emb.yaml ADDED
@@ -0,0 +1,14 @@
+ # @package _global_
+
+ run_name: "middle_emb"
+
+ defaults:
+   - override /data: webvid-covr.yaml
+   - override /test: webvid-covr.yaml
+
+ data:
+   emb_pool: "middle"
+
+ test:
+   webvid_covr:
+     emb_pool: "middle"
configs/experiment/covr_only-text.yaml ADDED
@@ -0,0 +1,8 @@
+ # @package _global_
+
+ defaults:
+   - override /data: webvid-covr.yaml
+   - override /test: webvid-covr_text.yaml
+   - override /model: blip-large_text.yaml
+
+ val: False
configs/experiment/covr_only-visual.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /data: webvid-covr.yaml
+   - override /test: webvid-covr_visual.yaml
+   - override /model: blip-large_visual.yaml
+
+ val: False
+
+ run_name: only-visual
+
+ machine:
+   batch_size: 64 # We have to reduce the batch size because we are training the ViT
+
+ model:
+   optimizer:
+     lr: 0.125e-4 # We have to reduce the learning rate because we are reducing the batch size
+
+ data:
+   emb_pool: mean
configs/experiment/covr_random-frame.yaml ADDED
@@ -0,0 +1,10 @@
+ # @package _global_
+
+ run_name: "random-frame"
+
+ defaults:
+   - override /data: webvid-covr.yaml
+   - override /test: webvid-covr.yaml
+
+ data:
+   vid_query_method: "random"
configs/experiment/covr_rule-based.yaml ADDED
@@ -0,0 +1,8 @@
+ # @package _global_
+
+ defaults:
+   - override /data: webvid-covr_rule-based.yaml
+   - override /test: main.yaml
+
+ trainer:
+   print_interval: 2
configs/experiment/fiq-dress.yaml ADDED
@@ -0,0 +1,17 @@
+ # @package _global_
+
+ defaults:
+   - override /data: fashioniq-dress.yaml
+   - override /test: fashioniq-dress.yaml
+   - override /model/ckpt: webvid-covr.yaml
+
+ model:
+   optimizer:
+     lr: 1e-4
+
+ machine:
+   batch_size: 256
+
+ trainer:
+   max_epochs: 6
+   print_interval: 2
configs/experiment/fiq-shirt.yaml ADDED
@@ -0,0 +1,17 @@
+ # @package _global_
+
+ defaults:
+   - override /data: fashioniq-shirt.yaml
+   - override /test: fashioniq-shirt.yaml
+   - override /model/ckpt: webvid-covr.yaml
+
+ model:
+   optimizer:
+     lr: 1e-4
+
+ machine:
+   batch_size: 256
+
+ trainer:
+   max_epochs: 6
+   print_interval: 2
configs/experiment/fiq-toptee.yaml ADDED
@@ -0,0 +1,17 @@
+ # @package _global_
+
+ defaults:
+   - override /data: fashioniq-toptee.yaml
+   - override /test: fashioniq-toptee.yaml
+   - override /model/ckpt: webvid-covr.yaml
+
+ model:
+   optimizer:
+     lr: 1e-4
+
+ machine:
+   batch_size: 256
+
+ trainer:
+   max_epochs: 6
+   print_interval: 2
configs/machine/default.yaml ADDED
@@ -0,0 +1,16 @@
+ # path to root directory
+ root_dir: "."
+
+ # path to working directory
+ work_dir: ${hydra:runtime.cwd}
+
+ # path to output directory, created dynamically by hydra
+ # path generation pattern is specified in `configs/hydra/default.yaml`
+ # use it to store all files generated during the run, like ckpts and metrics
+ output_dir: ${hydra:runtime.output_dir}
+
+ # path to dataset directory
+ datasets_dir: ${hydra:runtime.cwd}/datasets/
+
+ # path to logging directory
+ log_dir: ${paths.root_dir}/logs/
configs/machine/server.yaml ADDED
@@ -0,0 +1,8 @@
+ name: server
+
+ # attributes specific to this machine
+ batch_size: 512
+ num_workers: 8
+
+ defaults:
+   - default@paths
configs/med_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30524,
+   "encoder_width": 768,
+   "add_cross_attention": true
+ }
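med_config.json is the BERT configuration for BLIP's multimodal text encoder; the repo's own src/model code is what actually consumes it. As a quick, hedged sanity check, the file also parses as a standard transformers BertConfig, with the BLIP-specific `encoder_width` carried along as an extra attribute:

```python
from transformers import BertConfig

med_config = BertConfig.from_json_file("configs/med_config.json")
print(med_config.hidden_size)          # 768
print(med_config.add_cross_attention)  # True -> the text encoder can attend to image features
print(med_config.encoder_width)        # 768, stored as an extra (BLIP-specific) attribute
```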
configs/model/blip-large.yaml ADDED
@@ -0,0 +1,15 @@
+ modelname: blip-large
+ _target_: src.model.blip_cir.blip_cir
+
+ ckpt_path: ${model.ckpt.path}
+
+ model:
+   _target_: src.model.blip_cir.BLIPCir
+   med_config: ${paths.work_dir}/configs/med_config.json
+   image_size: ${data.image_size}
+   vit: "large"
+   vit_grad_ckpt: True
+   vit_ckpt_layer: 12
+   embed_dim: 256
+   train_vit: False
+   loss: ${model.loss}
configs/model/blip-large_text.yaml ADDED
@@ -0,0 +1,15 @@
+ modelname: blip-large-text
+ _target_: src.model.blip_cir_text.blip_cir_text
+
+ ckpt_path: ${model.ckpt.path}
+
+ model:
+   _target_: src.model.blip_cir_text.BLIPCirTextOnly
+   med_config: ${paths.work_dir}/configs/med_config.json
+   image_size: ${data.image_size}
+   vit: "large"
+   vit_grad_ckpt: True
+   vit_ckpt_layer: 12
+   embed_dim: 256
+   train_vit: False
+   loss: ${model.loss}
configs/model/blip-large_visual.yaml ADDED
@@ -0,0 +1,15 @@
+ modelname: blip-large-visual
+ _target_: src.model.blip_cir_visual.blip_cir_visual
+
+ ckpt_path: ${model.ckpt.path}
+
+ model:
+   _target_: src.model.blip_cir_visual.BLIPCirVisualOnly
+   med_config: ${paths.work_dir}/configs/med_config.json
+   image_size: ${data.image_size}
+   vit: "large"
+   vit_grad_ckpt: True
+   vit_ckpt_layer: 12
+   embed_dim: 256
+   train_vit: True
+   loss: ${model.loss}
configs/model/ckpt/blip-l-coco.yaml ADDED
@@ -0,0 +1,3 @@
+ name: blip-l-coco
+
+ path: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
configs/model/ckpt/cirr-gt.yaml ADDED
@@ -0,0 +1,3 @@
+ name: cirr-gt
+
+ path: ${paths.work_dir}/outputs/cirr/blip-large/blip-l-coco/tv-False_loss-hnnce_lr-1e-05/base/ckpt_4.ckpt
configs/model/ckpt/cirr_ft-covr+gt.yaml ADDED
@@ -0,0 +1,3 @@
+ name: cirr_ft-covr+gt
+
+ path: ${paths.work_dir}/outputs/cirr/blip-large/webvid-covr/tv-False_loss-hnnce_lr-0.0001/base/ckpt_5.ckpt
configs/model/ckpt/webvid-covr.yaml ADDED
@@ -0,0 +1,3 @@
+ name: webvid-covr
+
+ path: ${paths.work_dir}/outputs/webvid-covr/blip-large/blip-l-coco/tv-False_loss-hnnce_lr-1e-05/good/ckpt_4.ckpt
configs/model/loss/cross_entropy.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: src.model.loss.CrossEntropyLoss
+ name: ce
configs/model/loss/hn_nce.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: src.model.loss.HardNegativeNCE
+ name: hnnce
+
+ alpha: 1
+ beta: 0.5
configs/model/optimizer/adamw.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: torch.optim.AdamW
+ _partial_: true
+
+ lr: 1e-05
+ weight_decay: 0.05
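`_partial_: true` tells Hydra to instantiate this config as a `functools.partial`, so the training code can supply the model parameters later. A sketch of how such a config is typically consumed is below; the repo's actual trainer code is not among the 50 files shown, so the stand-in model is an assumption, while the Hydra and torch calls themselves are standard.

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Same content as configs/model/optimizer/adamw.yaml
opt_cfg = OmegaConf.create({
    "_target_": "torch.optim.AdamW",
    "_partial_": True,
    "lr": 1e-05,
    "weight_decay": 0.05,
})

model = torch.nn.Linear(256, 256)          # stand-in model for illustration
optimizer_factory = instantiate(opt_cfg)   # functools.partial(torch.optim.AdamW, lr=..., weight_decay=...)
optimizer = optimizer_factory(model.parameters())
print(optimizer)
```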
configs/model/scheduler/cosine.yaml ADDED
@@ -0,0 +1,6 @@
+ _target_: src.tools.scheduler.CosineSchedule
+
+ init_lr: ${model.optimizer.lr}
+ min_lr: 0
+ decay_rate: ${model.optimizer.weight_decay}
+ max_epochs: ${trainer.max_epochs}
configs/model/scheduler/step.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: src.tools.scheduler.StepSchedule
+
+ init_lr: ${model.optimizer.lr}
+ decay_rate: ${model.optimizer.weight_decay}
+ min_lr: 0
configs/test.yaml ADDED
@@ -0,0 +1,27 @@
+ hydra:
+   run:
+     dir: outputs/test/${model.modelname}/${model.ckpt.name}/${run_name}
+   job: # automatically go to the job folder (needed for hydra > 1.2 with new behavior)
+     chdir: true
+
+ # Global configurations shared between different modules
+ run_name: base
+
+ seed: 1234
+ logger_level: INFO
+
+ # Composing nested configs with defaults
+ defaults:
+   - _self_
+   - data: cirr
+   - test: all
+   - machine: server
+   - trainer: gpu
+   - model: blip-large
+   - model/ckpt: blip-l-coco
+   - model/loss: hn_nce
+   - trainer/logger: none
+
+   - experiment: null
+
+ paths: ${machine.paths}
configs/test/all.yaml ADDED
@@ -0,0 +1,6 @@
+ defaults:
+   - cirr.yaml
+   - webvid-covr.yaml
+   - fashioniq-dress.yaml
+   - fashioniq-shirt.yaml
+   - fashioniq-toptee.yaml
configs/test/cirr.yaml ADDED
@@ -0,0 +1,15 @@
+ cirr:
+   dataname: cirr
+   _target_: src.data.cirr.CIRRTestDataModule
+
+   test:
+     _target_: src.test.cirr.TestCirr
+
+   batch_size: ${machine.batch_size}
+   num_workers: ${machine.num_workers}
+
+   annotation: ${paths.work_dir}/annotation/cirr/cap.rc2.test1.json
+   img_dirs: ${paths.datasets_dir}/CIRR/images/test1
+   emb_dirs: ${paths.datasets_dir}/CIRR/blip-embs-large/test1
+
+   image_size: 384
configs/test/fashioniq-dress.yaml ADDED
@@ -0,0 +1,18 @@
+ fashioniq-dress:
+   dataname: fashioniq-dress
+   _target_: src.data.fashioniq.FashionIQTestDataModule
+
+   batch_size: ${machine.batch_size}
+   num_workers: ${machine.num_workers}
+
+   annotation: ${paths.work_dir}/annotation/fashion-iq/cap.dress.val.json
+   targets: ${paths.work_dir}/annotation/fashion-iq/split.dress.val.json
+
+   img_dirs: ${paths.datasets_dir}/fashion-iq/images/
+   emb_dirs: ${paths.datasets_dir}/fashion-iq/blip-embs-large/
+
+   image_size: 384
+
+   test:
+     _target_: src.test.fashioniq.TestFashionIQ
+     category: dress
configs/test/fashioniq-shirt.yaml ADDED
@@ -0,0 +1,18 @@
+ fashioniq-shirt:
+   dataname: fashioniq-shirt
+   _target_: src.data.fashioniq.FashionIQTestDataModule
+
+   batch_size: ${machine.batch_size}
+   num_workers: ${machine.num_workers}
+
+   annotation: ${paths.work_dir}/annotation/fashion-iq/cap.shirt.val.json
+   targets: ${paths.work_dir}/annotation/fashion-iq/split.shirt.val.json
+
+   img_dirs: ${paths.datasets_dir}/fashion-iq/images/
+   emb_dirs: ${paths.datasets_dir}/fashion-iq/blip-embs-large/
+
+   image_size: 384
+
+   test:
+     _target_: src.test.fashioniq.TestFashionIQ
+     category: shirt
configs/test/fashioniq-toptee.yaml ADDED
@@ -0,0 +1,18 @@
+ fashioniq-toptee:
+   dataname: fashioniq-toptee
+   _target_: src.data.fashioniq.FashionIQTestDataModule
+
+   batch_size: ${machine.batch_size}
+   num_workers: ${machine.num_workers}
+
+   annotation: ${paths.work_dir}/annotation/fashion-iq/cap.toptee.val.json
+   targets: ${paths.work_dir}/annotation/fashion-iq/split.toptee.val.json
+
+   img_dirs: ${paths.datasets_dir}/fashion-iq/images/
+   emb_dirs: ${paths.datasets_dir}/fashion-iq/blip-embs-large/
+
+   image_size: 384
+
+   test:
+     _target_: src.test.fashioniq.TestFashionIQ
+     category: toptee
configs/test/fashioniq.yaml ADDED
@@ -0,0 +1,4 @@
+ defaults:
+   - fashioniq-dress.yaml
+   - fashioniq-shirt.yaml
+   - fashioniq-toptee.yaml
configs/test/main.yaml ADDED
@@ -0,0 +1,3 @@
+ defaults:
+   - cirr.yaml
+   - webvid-covr.yaml
configs/test/webvid-covr.yaml ADDED
@@ -0,0 +1,20 @@
+ webvid_covr:
+   dataname: webvid-covr
+   _target_: src.data.webvid_covr.WebVidCoVRTestDataModule
+
+   image_size: 384
+
+   vid_query_method: middle
+   vid_frames: 1
+   emb_pool: query
+
+   batch_size: ${machine.batch_size}
+   num_workers: ${machine.num_workers}
+
+   # Paths
+   annotation: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_test.csv
+   vid_dirs: ${paths.datasets_dir}/WebVid/8M/train
+   emb_dirs: ${paths.datasets_dir}/WebVid/8M/blip-vid-embs-${model.model.vit}-all
+
+   test:
+     _target_: src.test.webvid_covr.TestWebVidCoVR
configs/test/webvid-covr_text.yaml ADDED
@@ -0,0 +1,20 @@
+ webvid_covr_text:
+   dataname: webvid-covr_text
+   _target_: src.data.webvid_covr.WebVidCoVRTestDataModule
+
+   image_size: 384
+
+   vid_query_method: middle
+   vid_frames: 1
+   emb_pool: query
+
+   batch_size: ${machine.batch_size}
+   num_workers: ${machine.num_workers}
+
+   # Paths
+   annotation: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_test.csv
+   vid_dirs: ${paths.datasets_dir}/WebVid/8M/train
+   emb_dirs: ${paths.datasets_dir}/WebVid/8M/blip-vid-embs-${model.model.vit}-all
+
+   test:
+     _target_: src.test.webvid_covr_exp.TestWebVidCoVRTextOnly
configs/test/webvid-covr_visual.yaml ADDED
@@ -0,0 +1,20 @@
+ webvid_covr_visual:
+   dataname: webvid-covr_visual
+   _target_: src.data.webvid_covr.WebVidCoVRTestDataModule
+
+   image_size: 384
+
+   vid_query_method: middle
+   vid_frames: 1
+   emb_pool: mean
+
+   batch_size: ${machine.batch_size}
+   num_workers: ${machine.num_workers}
+
+   # Paths
+   annotation: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_test.csv
+   vid_dirs: ${paths.datasets_dir}/WebVid/8M/train
+   emb_dirs: ${paths.datasets_dir}/WebVid/8M/blip-vid-embs-${model.model.vit}-all
+
+   test:
+     _target_: src.test.webvid_covr_exp.TestWebVidCoVRVisualOnly
configs/train.yaml ADDED
@@ -0,0 +1,33 @@
+ hydra:
+   run:
+     dir: outputs/${data.dataname}/${model.modelname}/${model.ckpt.name}/${experiment}/${run_name}
+   job: # automatically go to the job folder (needed for hydra > 1.2 with new behavior)
+     chdir: true
+
+ # Global configurations shared between different modules
+ experiment: tv-${model.model.train_vit}_loss-${model.model.loss.name}_lr-${model.optimizer.lr}
+ run_name: base
+
+ seed: 1234
+ logger_level: INFO
+
+ # Composing nested configs with defaults
+ defaults:
+   - _self_
+   - data: webvid-covr
+   - machine: server
+   - trainer: gpu
+   - test: all
+   - trainer/logger: csv
+   - model: blip-large
+   - model/optimizer: adamw
+   - model/scheduler: cosine
+   - model/loss: hn_nce
+   - model/ckpt: blip-l-coco
+
+   - experiment: null
+
+ paths: ${machine.paths}
+
+ # Flag to validate at the end of every epoch
+ val: True
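train.yaml is the primary Hydra config: the defaults list pulls in the data, machine, trainer, model, loss, and checkpoint groups above, and `hydra.run.dir` derives the output folder from those choices. The actual training script is not among the 50 files shown in this view, so the entry point below is a hypothetical sketch; only the `@hydra.main` usage and the config layout are taken from the commit.

```python
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # hydra.run.dir resolves to
    # outputs/<dataname>/<modelname>/<ckpt name>/<experiment>/<run_name>
    print(OmegaConf.to_yaml(cfg, resolve=False))

if __name__ == "__main__":
    main()
```

An experiment would then be selected with overrides along the lines of `python train.py experiment=fiq-dress model/ckpt=webvid-covr` (script name assumed).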
configs/trainer/cpu.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default.yaml
+
+ accelerator: cpu
+ devices: 1
configs/trainer/ddp.yaml ADDED
@@ -0,0 +1,12 @@
+ defaults:
+   - default.yaml
+
+ strategy: ddp
+
+ accelerator: gpu
+ devices: 4
+ num_nodes: 1
+
+ fabric:
+   num_nodes: ${trainer.num_nodes}
+   strategy: ${trainer.strategy}
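The `fabric` block mirrors the trainer's strategy and node count. The repo's actual trainer wiring is not part of this 50-file view, so the following is only an illustrative sketch of how the ddp.yaml values could map onto Lightning Fabric:

```python
from lightning.fabric import Fabric

fabric = Fabric(
    accelerator="gpu",   # trainer.accelerator
    devices=4,           # trainer.devices
    num_nodes=1,         # trainer.num_nodes / fabric.num_nodes
    strategy="ddp",      # trainer.strategy / fabric.strategy
)
fabric.launch()
```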