OmkarThawakar committed
Commit • ed00004
1 Parent(s): 7baf9f3

initial commit

This view is limited to 50 files because it contains too many changes.
- .gitignore +13 -0
- app.py +208 -0
- configs/data/cirr.yaml +22 -0
- configs/data/fashioniq-base.yaml +28 -0
- configs/data/fashioniq-dress.yaml +4 -0
- configs/data/fashioniq-shirt.yaml +4 -0
- configs/data/fashioniq-toptee.yaml +4 -0
- configs/data/webvid-covr.yaml +26 -0
- configs/data/webvid-covr_rule-based.yaml +26 -0
- configs/experiment/cirr.yaml +13 -0
- configs/experiment/covr_hard-negatives.yaml +6 -0
- configs/experiment/covr_iterate-triplets.yaml +14 -0
- configs/experiment/covr_late-fusion.yaml +12 -0
- configs/experiment/covr_middle-emb.yaml +14 -0
- configs/experiment/covr_only-text.yaml +8 -0
- configs/experiment/covr_only-visual.yaml +20 -0
- configs/experiment/covr_random-frame.yaml +10 -0
- configs/experiment/covr_rule-based.yaml +8 -0
- configs/experiment/fiq-dress.yaml +17 -0
- configs/experiment/fiq-shirt.yaml +17 -0
- configs/experiment/fiq-toptee.yaml +17 -0
- configs/machine/default.yaml +16 -0
- configs/machine/server.yaml +8 -0
- configs/med_config.json +21 -0
- configs/model/blip-large.yaml +15 -0
- configs/model/blip-large_text.yaml +15 -0
- configs/model/blip-large_visual.yaml +15 -0
- configs/model/ckpt/blip-l-coco.yaml +3 -0
- configs/model/ckpt/cirr-gt.yaml +3 -0
- configs/model/ckpt/cirr_ft-covr+gt.yaml +3 -0
- configs/model/ckpt/webvid-covr.yaml +3 -0
- configs/model/loss/cross_entropy.yaml +2 -0
- configs/model/loss/hn_nce.yaml +5 -0
- configs/model/optimizer/adamw.yaml +5 -0
- configs/model/scheduler/cosine.yaml +6 -0
- configs/model/scheduler/step.yaml +5 -0
- configs/test.yaml +27 -0
- configs/test/all.yaml +6 -0
- configs/test/cirr.yaml +15 -0
- configs/test/fashioniq-dress.yaml +18 -0
- configs/test/fashioniq-shirt.yaml +18 -0
- configs/test/fashioniq-toptee.yaml +18 -0
- configs/test/fashioniq.yaml +4 -0
- configs/test/main.yaml +3 -0
- configs/test/webvid-covr.yaml +20 -0
- configs/test/webvid-covr_text.yaml +20 -0
- configs/test/webvid-covr_visual.yaml +20 -0
- configs/train.yaml +33 -0
- configs/trainer/cpu.yaml +5 -0
- configs/trainer/ddp.yaml +12 -0
.gitignore
ADDED
@@ -0,0 +1,13 @@
+outputs/
+datasets/sidechef/images
+datasets/sidechef/sample_images
+datasets/sidechef/my_tags.json
+datasets/sidechef/tag_categories.json
+datasets/sidechef/tags.json
+launching
+annotation/
+.vscode/
+bert-base-uncased/
+delete*
+__pycache__/
+env/
app.py
ADDED
@@ -0,0 +1,208 @@
+import pandas as pd
+import json
+from PIL import Image
+import numpy as np
+
+import os
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+
+# from src.data.embs import ImageDataset
+from src.model.blip_embs import blip_embs
+from src.data.transforms import transform_test
+
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+import gradio as gr
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_groq import ChatGroq
+
+
+# GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+GROQ_API_KEY = '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'
+os.environ["GROQ_API_KEY"] = GROQ_API_KEY
+
+# Initialize LLM
+llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, max_tokens=1024, max_retries=2)
+
+# QA system prompt and chain
+qa_system_prompt = """
+Prompt:
+You are a highly intelligent assistant. Use the following context to answer user questions. Analyze the data carefully and generate a clear, concise, and informative response to the user's question based on this data.
+
+Response Guidelines:
+- Use only the information provided in the data to answer the question.
+- Ensure the answer is accurate and directly related to the question.
+- If the data is insufficient to answer the question, politely apologise and tell the user that there is insufficient data available to answer their question.
+- Provide the response in a conversational yet professional tone.
+
+Context:
+{context}
+"""
+qa_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", qa_system_prompt),
+        ("human", "{input}")
+    ]
+)
+
+question_answer_chain = qa_prompt | llm | StrOutputParser()
+
+
+class StoppingCriteriaSub(StoppingCriteria):
+
+    def __init__(self, stops=[], encounters=1):
+        super().__init__()
+        self.stops = stops
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        for stop in self.stops:
+            if torch.all(input_ids[:, -len(stop):] == stop).item():
+                return True
+
+        return False
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def get_blip_config(model="base"):
+    config = dict()
+    if model == "base":
+        config[
+            "pretrained"
+        ] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
+        config["vit"] = "base"
+        config["batch_size_train"] = 32
+        config["batch_size_test"] = 16
+        config["vit_grad_ckpt"] = True
+        config["vit_ckpt_layer"] = 4
+        config["init_lr"] = 1e-5
+    elif model == "large":
+        config[
+            "pretrained"
+        ] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
+        config["vit"] = "large"
+        config["batch_size_train"] = 16
+        config["batch_size_test"] = 32
+        config["vit_grad_ckpt"] = True
+        config["vit_ckpt_layer"] = 12
+        config["init_lr"] = 5e-6
+
+    config["image_size"] = 384
+    config["queue_size"] = 57600
+    config["alpha"] = 0.4
+    config["k_test"] = 256
+    config["negative_all_rank"] = True
+
+    return config
+
+
+print("Creating model")
+config = get_blip_config("large")
+
+model = blip_embs(
+    pretrained=config["pretrained"],
+    image_size=config["image_size"],
+    vit=config["vit"],
+    vit_grad_ckpt=config["vit_grad_ckpt"],
+    vit_ckpt_layer=config["vit_ckpt_layer"],
+    queue_size=config["queue_size"],
+    negative_all_rank=config["negative_all_rank"],
+)
+
+model = model.to(device)
+model.eval()
+print("Model loaded!")
+print("=" * 50)
+
+transform = transform_test(384)
+
+print("Loading data")
+df = pd.read_json("datasets/sidechef/my_recipes.json")
+
+print("Loading target embeddings")
+tar_img_feats = []
+for _id in df["id_"].tolist():
+    tar_img_feats.append(torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0))
+
+tar_img_feats = torch.cat(tar_img_feats, dim=0)
+
+
+class Chat:
+
+    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
+        self.device = device
+        self.model = model
+        self.transform = transform
+        self.df = dataframe
+        self.tar_img_feats = tar_img_feats
+        self.img_feats = None
+        self.target_recipe = None
+        self.messages = []
+
+        if stopping_criteria is not None:
+            self.stopping_criteria = stopping_criteria
+        else:
+            stop_words_ids = [torch.tensor([2]).to(self.device)]
+            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+    def encode_image(self, image_array):
+        # Gradio passes the uploaded image as a numpy array.
+        img = Image.fromarray(image_array).convert("RGB")
+        img = self.transform(img).unsqueeze(0)
+        img = img.to(self.device)
+        img_embs = self.model.visual_encoder(img)
+        img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()
+
+        self.img_feats = img_feats
+
+        self.get_target(self.img_feats, self.tar_img_feats)
+
+    def get_target(self, img_feats, tar_img_feats):
+        # Cosine similarity between the query image and every precomputed recipe embedding.
+        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
+        index = np.argsort(score)[::-1][0] + 1
+        self.target_recipe = df.iloc[index]
+
+    def ask(self):
+        return json.dumps(self.target_recipe.to_json())
+
+
+chat = Chat(model, transform, df, tar_img_feats, device=device)
+print("Chat initialized!")
+
+
+custom_css = """
+.primary{
+    background-color: #4CAF50; /* Green */
+}
+"""
+
+
+def respond_to_user(image, message):
+    # Retrieve the closest recipe for the uploaded image, then answer the query with the LLM.
+    chat = Chat(model, transform, df, tar_img_feats, device=device)
+    chat.encode_image(image)
+    data = chat.ask()
+    formatted_input = {
+        'input': message,
+        'context': data
+    }
+    try:
+        response = question_answer_chain.invoke(formatted_input)
+    except Exception:
+        response = "An error occurred while processing your request."
+    return response
+
+
+iface = gr.Interface(
+    fn=respond_to_user,
+    inputs=[gr.Image(), gr.Textbox(label="Ask Query")],
+    outputs=gr.Textbox(label="Nutrition-GPT"),
+    title="Nutrition-GPT Demo",
+    description="Upload a food image and ask queries!",
+    css=".component-12 {background-color: red}",
+)
+
+iface.launch()
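For readers skimming the diff: the retrieval step in app.py is a plain nearest-neighbour lookup in BLIP embedding space. The uploaded image is projected with vision_proj, L2-normalized, and matched against the precomputed recipe embeddings by a dot product, which equals cosine similarity for unit vectors. The snippet below is a minimal sketch of that step with random stand-in tensors (the 256-dim size mirrors embed_dim in the model configs); it is illustrative only and not part of the commit.

import torch
import torch.nn.functional as F

# Stand-ins for the real features: query_feat plays the role of the vision_proj
# output for the uploaded image, recipe_feats the stacked .pth embeddings
# loaded from datasets/sidechef/blip-embs-large/.
query_feat = F.normalize(torch.randn(1, 256), dim=-1)
recipe_feats = F.normalize(torch.randn(1000, 256), dim=-1)

# Both sides are L2-normalized, so the matrix product is cosine similarity.
scores = (query_feat @ recipe_feats.t()).squeeze(0)
best = scores.argmax().item()  # row index of the closest recipe
print(best)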
configs/data/cirr.yaml
ADDED
@@ -0,0 +1,22 @@
+dataname: cirr
+_target_: src.data.cirr.CIRRDataModule
+
+# Paths
+dataset_dir: ${paths.datasets_dir}/CIRR
+
+batch_size: ${machine.batch_size}
+num_workers: ${machine.num_workers}
+
+annotation:
+  train: ${paths.work_dir}/annotation/cirr/cap.rc2.train.json
+  val: ${paths.work_dir}/annotation/cirr/cap.rc2.val.json
+
+img_dirs:
+  train: ${data.dataset_dir}/images/train
+  val: ${data.dataset_dir}/images/dev
+
+emb_dirs:
+  train: ${data.dataset_dir}/blip-embs-large/train
+  val: ${data.dataset_dir}/blip-embs-large/dev
+
+image_size: 384
configs/data/fashioniq-base.yaml
ADDED
@@ -0,0 +1,28 @@
+dataname: fashioniq-${data.category}
+_target_: src.data.fashioniq.FashionIQDataModule
+
+# Paths
+dataset_dir: ${paths.datasets_dir}/fashion-iq
+
+batch_size: ${machine.batch_size}
+num_workers: ${machine.num_workers}
+
+annotation:
+  train: ${paths.work_dir}/annotation/fashion-iq/cap.${data.category}.train.json
+  val: ${paths.work_dir}/annotation/fashion-iq/cap.${data.category}.val.json
+
+targets:
+  train: ${paths.work_dir}/annotation/fashion-iq/split.${data.category}.train.json
+  val: ${paths.work_dir}/annotation/fashion-iq/split.${data.category}.val.json
+
+img_dirs:
+  train: ${data.dataset_dir}/images/
+  val: ${data.dataset_dir}/images/
+
+emb_dirs:
+  train: ${data.dataset_dir}/blip-embs-large/
+  val: ${data.dataset_dir}/blip-embs-large/
+
+image_size: 384
+
+category: ???
configs/data/fashioniq-dress.yaml
ADDED
@@ -0,0 +1,4 @@
+defaults:
+  - fashioniq-base.yaml
+
+category: dress
configs/data/fashioniq-shirt.yaml
ADDED
@@ -0,0 +1,4 @@
+defaults:
+  - fashioniq-base.yaml
+
+category: shirt
configs/data/fashioniq-toptee.yaml
ADDED
@@ -0,0 +1,4 @@
+defaults:
+  - fashioniq-base.yaml
+
+category: toptee
configs/data/webvid-covr.yaml
ADDED
@@ -0,0 +1,26 @@
+dataname: webvid-covr
+_target_: src.data.webvid_covr.WebVidCoVRDataModule
+
+image_size: 384
+iterate: "pth2"
+vid_query_method: middle
+vid_frames: 1
+emb_pool: query
+
+# Paths
+dataset_dir: ${paths.datasets_dir}/WebVid
+
+batch_size: ${machine.batch_size}
+num_workers: ${machine.num_workers}
+
+annotation:
+  train: ${paths.work_dir}/annotation/webvid-covr/webvid2m-covr_train.csv
+  val: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_val.csv
+
+vid_dirs:
+  train: ${data.dataset_dir}/2M/train
+  val: ${data.dataset_dir}/8M/train
+
+emb_dirs:
+  train: ${data.dataset_dir}/2M/blip-vid-embs-${model.model.vit}-all
+  val: ${data.dataset_dir}/8M/blip-vid-embs-${model.model.vit}-all
configs/data/webvid-covr_rule-based.yaml
ADDED
@@ -0,0 +1,26 @@
+dataname: webvid-covr-rule-based
+_target_: src.data.webvid_covr_rulebased.WebVidCoVRDataModuleRuleBased
+
+image_size: 384
+iterate: "pth2"
+vid_query_method: middle
+vid_frames: 1
+emb_pool: query
+
+# Paths
+dataset_dir: ${paths.datasets_dir}/WebVid
+
+batch_size: ${machine.batch_size}
+num_workers: ${machine.num_workers}
+
+annotation:
+  train: ${paths.work_dir}/annotation/webvid-covr/webvid2m-covr_train.csv
+  val: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_val.csv
+
+vid_dirs:
+  train: ${data.dataset_dir}/2M/train
+  val: ${data.dataset_dir}/8M/train
+
+emb_dirs:
+  train: ${data.dataset_dir}/2M/blip-vid-embs-${model.model.vit}-all
+  val: ${data.dataset_dir}/8M/blip-vid-embs-${model.model.vit}-all
configs/experiment/cirr.yaml
ADDED
@@ -0,0 +1,13 @@
+# @package _global_
+
+defaults:
+  - override /data: cirr.yaml
+  - override /test: cirr.yaml
+  # - override /model/ckpt: webvid-covr.yaml
+
+model:
+  optimizer:
+    lr: 1e-4
+
+trainer:
+  max_epochs: 6
configs/experiment/covr_hard-negatives.yaml
ADDED
@@ -0,0 +1,6 @@
+# @package _global_
+
+defaults:
+  - override /data: webvid-covr.yaml
+  - override /test: main.yaml
+  - override /model/loss: cross_entropy
configs/experiment/covr_iterate-triplets.yaml
ADDED
@@ -0,0 +1,14 @@
+# @package _global_
+
+defaults:
+  - override /data: webvid-covr.yaml
+  - override /test: main.yaml
+
+run_name: "iterate-triplets"
+
+data:
+  iterate: "triplets"
+
+trainer:
+  max_epochs: 2
+  print_interval: 1
configs/experiment/covr_late-fusion.yaml
ADDED
@@ -0,0 +1,12 @@
+# @package _global_
+
+defaults:
+  - override /data: webvid-covr.yaml
+  - override /model: blip-large_late-fusion.yaml
+  - override /test: webvid-covr_late-fusion.yaml
+
+val: False
+
+model:
+  optimizer:
+    lr: 1e-4
configs/experiment/covr_middle-emb.yaml
ADDED
@@ -0,0 +1,14 @@
+# @package _global_
+
+run_name: "middle_emb"
+
+defaults:
+  - override /data: webvid-covr.yaml
+  - override /test: webvid-covr.yaml
+
+data:
+  emb_pool: "middle"
+
+test:
+  webvid_covr:
+    emb_pool: "middle"
configs/experiment/covr_only-text.yaml
ADDED
@@ -0,0 +1,8 @@
+# @package _global_
+
+defaults:
+  - override /data: webvid-covr.yaml
+  - override /test: webvid-covr_text.yaml
+  - override /model: blip-large_text.yaml
+
+val: False
configs/experiment/covr_only-visual.yaml
ADDED
@@ -0,0 +1,20 @@
+# @package _global_
+
+defaults:
+  - override /data: webvid-covr.yaml
+  - override /test: webvid-covr_visual.yaml
+  - override /model: blip-large_visual.yaml
+
+val: False
+
+run_name: only-visual
+
+machine:
+  batch_size: 64 # We have to reduce the batch size because we are training the ViT
+
+model:
+  optimizer:
+    lr: 0.125e-4 # We have to reduce the learning rate because we are reducing the batch size
+
+data:
+  emb_pool: mean
configs/experiment/covr_random-frame.yaml
ADDED
@@ -0,0 +1,10 @@
+# @package _global_
+
+run_name: "random-frame"
+
+defaults:
+  - override /data: webvid-covr.yaml
+  - override /test: webvid-covr.yaml
+
+data:
+  vid_query_method: "random"
configs/experiment/covr_rule-based.yaml
ADDED
@@ -0,0 +1,8 @@
+# @package _global_
+
+defaults:
+  - override /data: webvid-covr_rule-based.yaml
+  - override /test: main.yaml
+
+trainer:
+  print_interval: 2
configs/experiment/fiq-dress.yaml
ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+
+defaults:
+  - override /data: fashioniq-dress.yaml
+  - override /test: fashioniq-dress.yaml
+  - override /model/ckpt: webvid-covr.yaml
+
+model:
+  optimizer:
+    lr: 1e-4
+
+machine:
+  batch_size: 256
+
+trainer:
+  max_epochs: 6
+  print_interval: 2
configs/experiment/fiq-shirt.yaml
ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+
+defaults:
+  - override /data: fashioniq-shirt.yaml
+  - override /test: fashioniq-shirt.yaml
+  - override /model/ckpt: webvid-covr.yaml
+
+model:
+  optimizer:
+    lr: 1e-4
+
+machine:
+  batch_size: 256
+
+trainer:
+  max_epochs: 6
+  print_interval: 2
configs/experiment/fiq-toptee.yaml
ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+
+defaults:
+  - override /data: fashioniq-toptee.yaml
+  - override /test: fashioniq-toptee.yaml
+  - override /model/ckpt: webvid-covr.yaml
+
+model:
+  optimizer:
+    lr: 1e-4
+
+machine:
+  batch_size: 256
+
+trainer:
+  max_epochs: 6
+  print_interval: 2
configs/machine/default.yaml
ADDED
@@ -0,0 +1,16 @@
+# path to root directory
+root_dir: "."
+
+# path to working directory
+work_dir: ${hydra:runtime.cwd}
+
+# path to output directory, created dynamically by hydra
+# path generation pattern is specified in `configs/hydra/default.yaml`
+# use it to store all files generated during the run, like ckpts and metrics
+output_dir: ${hydra:runtime.output_dir}
+
+# path to dataset directory
+datasets_dir: ${hydra:runtime.cwd}/datasets/
+
+# path to logging directory
+log_dir: ${paths.root_dir}/logs/
configs/machine/server.yaml
ADDED
@@ -0,0 +1,8 @@
+name: server
+
+# specific attributes to this machine
+batch_size: 512
+num_workers: 8
+
+defaults:
+  - default@paths
configs/med_config.json
ADDED
@@ -0,0 +1,21 @@
+{
+  "architectures": [
+    "BertModel"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "type_vocab_size": 2,
+  "vocab_size": 30524,
+  "encoder_width": 768,
+  "add_cross_attention": true
+}
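med_config.json is a standard BERT configuration (with cross-attention enabled and an enlarged vocabulary) that the BLIP text encoder consumes. As a quick sanity check, it can be loaded with the regular transformers config API; the sketch below assumes the repo root as the working directory and is not part of the commit.

from transformers import BertConfig

# Load the multimodal encoder (MED) config shipped with this Space.
med_config = BertConfig.from_json_file("configs/med_config.json")
print(med_config.vocab_size)           # 30524
print(med_config.add_cross_attention)  # True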
configs/model/blip-large.yaml
ADDED
@@ -0,0 +1,15 @@
+modelname: blip-large
+_target_: src.model.blip_cir.blip_cir
+
+ckpt_path: ${model.ckpt.path}
+
+model:
+  _target_: src.model.blip_cir.BLIPCir
+  med_config: ${paths.work_dir}/configs/med_config.json
+  image_size: ${data.image_size}
+  vit: "large"
+  vit_grad_ckpt: True
+  vit_ckpt_layer: 12
+  embed_dim: 256
+  train_vit: False
+  loss: ${model.loss}
configs/model/blip-large_text.yaml
ADDED
@@ -0,0 +1,15 @@
+modelname: blip-large-text
+_target_: src.model.blip_cir_text.blip_cir_text
+
+ckpt_path: ${model.ckpt.path}
+
+model:
+  _target_: src.model.blip_cir_text.BLIPCirTextOnly
+  med_config: ${paths.work_dir}/configs/med_config.json
+  image_size: ${data.image_size}
+  vit: "large"
+  vit_grad_ckpt: True
+  vit_ckpt_layer: 12
+  embed_dim: 256
+  train_vit: False
+  loss: ${model.loss}
configs/model/blip-large_visual.yaml
ADDED
@@ -0,0 +1,15 @@
+modelname: blip-large-visual
+_target_: src.model.blip_cir_visual.blip_cir_visual
+
+ckpt_path: ${model.ckpt.path}
+
+model:
+  _target_: src.model.blip_cir_visual.BLIPCirVisualOnly
+  med_config: ${paths.work_dir}/configs/med_config.json
+  image_size: ${data.image_size}
+  vit: "large"
+  vit_grad_ckpt: True
+  vit_ckpt_layer: 12
+  embed_dim: 256
+  train_vit: True
+  loss: ${model.loss}
configs/model/ckpt/blip-l-coco.yaml
ADDED
@@ -0,0 +1,3 @@
+name: blip-l-coco
+
+path: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
configs/model/ckpt/cirr-gt.yaml
ADDED
@@ -0,0 +1,3 @@
+name: cirr-gt
+
+path: ${paths.work_dir}/outputs/cirr/blip-large/blip-l-coco/tv-False_loss-hnnce_lr-1e-05/base/ckpt_4.ckpt
configs/model/ckpt/cirr_ft-covr+gt.yaml
ADDED
@@ -0,0 +1,3 @@
+name: cirr_ft-covr+gt
+
+path: ${paths.work_dir}/outputs/cirr/blip-large/webvid-covr/tv-False_loss-hnnce_lr-0.0001/base/ckpt_5.ckpt
configs/model/ckpt/webvid-covr.yaml
ADDED
@@ -0,0 +1,3 @@
+name: webvid-covr
+
+path: ${paths.work_dir}/outputs/webvid-covr/blip-large/blip-l-coco/tv-False_loss-hnnce_lr-1e-05/good/ckpt_4.ckpt
configs/model/loss/cross_entropy.yaml
ADDED
@@ -0,0 +1,2 @@
+_target_: src.model.loss.CrossEntropyLoss
+name: ce
configs/model/loss/hn_nce.yaml
ADDED
@@ -0,0 +1,5 @@
+_target_: src.model.loss.HardNegativeNCE
+name: hnnce
+
+alpha: 1
+beta: 0.5
configs/model/optimizer/adamw.yaml
ADDED
@@ -0,0 +1,5 @@
+_target_: torch.optim.AdamW
+_partial_: true
+
+lr: 1e-05
+weight_decay: 0.05
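The `_partial_: true` flag tells Hydra to instantiate this entry as a functools.partial, so the training code can pass the model parameters in later. Below is a minimal sketch of that pattern, with a throwaway nn.Linear standing in for the real model (illustrative, not part of the commit):

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirror configs/model/optimizer/adamw.yaml as an in-memory config.
opt_cfg = OmegaConf.create(
    {"_target_": "torch.optim.AdamW", "_partial_": True, "lr": 1e-05, "weight_decay": 0.05}
)
make_optimizer = instantiate(opt_cfg)  # functools.partial(AdamW, lr=1e-05, weight_decay=0.05)
optimizer = make_optimizer(torch.nn.Linear(4, 2).parameters())
print(optimizer)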
configs/model/scheduler/cosine.yaml
ADDED
@@ -0,0 +1,6 @@
+_target_: src.tools.scheduler.CosineSchedule
+
+init_lr: ${model.optimizer.lr}
+min_lr: 0
+decay_rate: ${model.optimizer.weight_decay}
+max_epochs: ${trainer.max_epochs}
configs/model/scheduler/step.yaml
ADDED
@@ -0,0 +1,5 @@
+_target_: src.tools.scheduler.StepSchedule
+
+init_lr: ${model.optimizer.lr}
+decay_rate: ${model.optimizer.weight_decay}
+min_lr: 0
configs/test.yaml
ADDED
@@ -0,0 +1,27 @@
+hydra:
+  run:
+    dir: outputs/test/${model.modelname}/${model.ckpt.name}/${run_name}
+  job: # automatically go to the job folder (needed for hydra > 1.2 with new behavior)
+    chdir: true
+
+# Global configurations shared between different modules
+run_name: base
+
+seed: 1234
+logger_level: INFO
+
+# Composing nested config with default
+defaults:
+  - _self_
+  - data: cirr
+  - test: all
+  - machine: server
+  - trainer: gpu
+  - model: blip-large
+  - model/ckpt: blip-l-coco
+  - model/loss: hn_nce
+  - trainer/logger: none
+
+  - experiment: null
+
+paths: ${machine.paths}
configs/test/all.yaml
ADDED
@@ -0,0 +1,6 @@
+defaults:
+  - cirr.yaml
+  - webvid-covr.yaml
+  - fashioniq-dress.yaml
+  - fashioniq-shirt.yaml
+  - fashioniq-toptee.yaml
configs/test/cirr.yaml
ADDED
@@ -0,0 +1,15 @@
+cirr:
+  dataname: cirr
+  _target_: src.data.cirr.CIRRTestDataModule
+
+  test:
+    _target_: src.test.cirr.TestCirr
+
+  batch_size: ${machine.batch_size}
+  num_workers: ${machine.num_workers}
+
+  annotation: ${paths.work_dir}/annotation/cirr/cap.rc2.test1.json
+  img_dirs: ${paths.datasets_dir}/CIRR/images/test1
+  emb_dirs: ${paths.datasets_dir}/CIRR/blip-embs-large/test1
+
+  image_size: 384
configs/test/fashioniq-dress.yaml
ADDED
@@ -0,0 +1,18 @@
+fashioniq-dress:
+  dataname: fashioniq-dress
+  _target_: src.data.fashioniq.FashionIQTestDataModule
+
+  batch_size: ${machine.batch_size}
+  num_workers: ${machine.num_workers}
+
+  annotation: ${paths.work_dir}/annotation/fashion-iq/cap.dress.val.json
+  targets: ${paths.work_dir}/annotation/fashion-iq/split.dress.val.json
+
+  img_dirs: ${paths.datasets_dir}/fashion-iq/images/
+  emb_dirs: ${paths.datasets_dir}/fashion-iq/blip-embs-large/
+
+  image_size: 384
+
+  test:
+    _target_: src.test.fashioniq.TestFashionIQ
+    category: dress
configs/test/fashioniq-shirt.yaml
ADDED
@@ -0,0 +1,18 @@
+fashioniq-shirt:
+  dataname: fashioniq-shirt
+  _target_: src.data.fashioniq.FashionIQTestDataModule
+
+  batch_size: ${machine.batch_size}
+  num_workers: ${machine.num_workers}
+
+  annotation: ${paths.work_dir}/annotation/fashion-iq/cap.shirt.val.json
+  targets: ${paths.work_dir}/annotation/fashion-iq/split.shirt.val.json
+
+  img_dirs: ${paths.datasets_dir}/fashion-iq/images/
+  emb_dirs: ${paths.datasets_dir}/fashion-iq/blip-embs-large/
+
+  image_size: 384
+
+  test:
+    _target_: src.test.fashioniq.TestFashionIQ
+    category: shirt
configs/test/fashioniq-toptee.yaml
ADDED
@@ -0,0 +1,18 @@
+fashioniq-toptee:
+  dataname: fashioniq-toptee
+  _target_: src.data.fashioniq.FashionIQTestDataModule
+
+  batch_size: ${machine.batch_size}
+  num_workers: ${machine.num_workers}
+
+  annotation: ${paths.work_dir}/annotation/fashion-iq/cap.toptee.val.json
+  targets: ${paths.work_dir}/annotation/fashion-iq/split.toptee.val.json
+
+  img_dirs: ${paths.datasets_dir}/fashion-iq/images/
+  emb_dirs: ${paths.datasets_dir}/fashion-iq/blip-embs-large/
+
+  image_size: 384
+
+  test:
+    _target_: src.test.fashioniq.TestFashionIQ
+    category: toptee
configs/test/fashioniq.yaml
ADDED
@@ -0,0 +1,4 @@
+defaults:
+  - fashioniq-dress.yaml
+  - fashioniq-shirt.yaml
+  - fashioniq-toptee.yaml
configs/test/main.yaml
ADDED
@@ -0,0 +1,3 @@
+defaults:
+  - cirr.yaml
+  - webvid-covr.yaml
configs/test/webvid-covr.yaml
ADDED
@@ -0,0 +1,20 @@
+webvid_covr:
+  dataname: webvid-covr
+  _target_: src.data.webvid_covr.WebVidCoVRTestDataModule
+
+  image_size: 384
+
+  vid_query_method: middle
+  vid_frames: 1
+  emb_pool: query
+
+  batch_size: ${machine.batch_size}
+  num_workers: ${machine.num_workers}
+
+  # Paths
+  annotation: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_test.csv
+  vid_dirs: ${paths.datasets_dir}/WebVid/8M/train
+  emb_dirs: ${paths.datasets_dir}/WebVid/8M/blip-vid-embs-${model.model.vit}-all
+
+  test:
+    _target_: src.test.webvid_covr.TestWebVidCoVR
configs/test/webvid-covr_text.yaml
ADDED
@@ -0,0 +1,20 @@
+webvid_covr_text:
+  dataname: webvid-covr_text
+  _target_: src.data.webvid_covr.WebVidCoVRTestDataModule
+
+  image_size: 384
+
+  vid_query_method: middle
+  vid_frames: 1
+  emb_pool: query
+
+  batch_size: ${machine.batch_size}
+  num_workers: ${machine.num_workers}
+
+  # Paths
+  annotation: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_test.csv
+  vid_dirs: ${paths.datasets_dir}/WebVid/8M/train
+  emb_dirs: ${paths.datasets_dir}/WebVid/8M/blip-vid-embs-${model.model.vit}-all
+
+  test:
+    _target_: src.test.webvid_covr_exp.TestWebVidCoVRTextOnly
configs/test/webvid-covr_visual.yaml
ADDED
@@ -0,0 +1,20 @@
+webvid_covr_visual:
+  dataname: webvid-covr_visual
+  _target_: src.data.webvid_covr.WebVidCoVRTestDataModule
+
+  image_size: 384
+
+  vid_query_method: middle
+  vid_frames: 1
+  emb_pool: mean
+
+  batch_size: ${machine.batch_size}
+  num_workers: ${machine.num_workers}
+
+  # Paths
+  annotation: ${paths.work_dir}/annotation/webvid-covr/webvid8m-covr_test.csv
+  vid_dirs: ${paths.datasets_dir}/WebVid/8M/train
+  emb_dirs: ${paths.datasets_dir}/WebVid/8M/blip-vid-embs-${model.model.vit}-all
+
+  test:
+    _target_: src.test.webvid_covr_exp.TestWebVidCoVRVisualOnly
configs/train.yaml
ADDED
@@ -0,0 +1,33 @@
+hydra:
+  run:
+    dir: outputs/${data.dataname}/${model.modelname}/${model.ckpt.name}/${experiment}/${run_name}
+  job: # automatically go to the job folder (needed for hydra > 1.2 with new behavior)
+    chdir: true
+
+# Global configurations shared between different modules
+experiment: tv-${model.model.train_vit}_loss-${model.model.loss.name}_lr-${model.optimizer.lr}
+run_name: base
+
+seed: 1234
+logger_level: INFO
+
+# Composing nested config with default
+defaults:
+  - _self_
+  - data: webvid-covr
+  - machine: server
+  - trainer: gpu
+  - test: all
+  - trainer/logger: csv
+  - model: blip-large
+  - model/optimizer: adamw
+  - model/scheduler: cosine
+  - model/loss: hn_nce
+  - model/ckpt: blip-l-coco
+
+  - experiment: null
+
+paths: ${machine.paths}
+
+# Flag to validate at the end of every epoch
+val: True
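train.yaml is the Hydra entry-point config: the defaults list pulls in one file from each group above, and the experiment/*.yaml files apply `# @package _global_` overrides on top of it. The sketch below shows how the composed result could be inspected with Hydra's compose API; it assumes the full configs/ tree is available next to the script (including groups such as trainer/gpu that fall outside this 50-file view) and is not part of the commit.

from hydra import compose, initialize

# Compose train.yaml with the CIRR experiment applied, roughly what a
# `python train.py experiment=cirr` launch would resolve before training starts.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=cirr"])
    print(cfg.data.dataname)       # cirr (from configs/data/cirr.yaml)
    print(cfg.model.optimizer.lr)  # 1e-4, set by configs/experiment/cirr.yaml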
configs/trainer/cpu.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - default.yaml
+
+accelerator: cpu
+devices: 1
configs/trainer/ddp.yaml
ADDED
@@ -0,0 +1,12 @@
+defaults:
+  - default.yaml
+
+strategy: ddp
+
+accelerator: gpu
+devices: 4
+num_nodes: 1
+
+fabric:
+  num_nodes: ${trainer.num_nodes}
+  strategy: ${trainer.strategy}