Spaces:
Runtime error
Runtime error
Li
commited on
Commit
•
5282eae
1
Parent(s):
6eb2745
init
Browse files- app.py +225 -0
- open_flamingo/.github/workflows/black.yml +10 -0
- open_flamingo/.gitignore +149 -0
- open_flamingo/HISTORY.md +3 -0
- open_flamingo/LICENSE +21 -0
- open_flamingo/MODEL_CARD.md +44 -0
- open_flamingo/Makefile +19 -0
- open_flamingo/README.md +233 -0
- open_flamingo/TERMS_AND_CONDITIONS.md +15 -0
- open_flamingo/docs/flamingo.png +0 -0
- open_flamingo/environment.yml +10 -0
- open_flamingo/open_flamingo/__init__.py +2 -0
- open_flamingo/open_flamingo/eval/__init__.py +1 -0
- open_flamingo/open_flamingo/eval/classification.py +147 -0
- open_flamingo/open_flamingo/eval/coco_metric.py +23 -0
- open_flamingo/open_flamingo/eval/eval_datasets.py +138 -0
- open_flamingo/open_flamingo/eval/evaluate.py +1094 -0
- open_flamingo/open_flamingo/eval/evaluate2.py +1113 -0
- open_flamingo/open_flamingo/eval/imagenet_utils.py +1007 -0
- open_flamingo/open_flamingo/eval/ok_vqa_utils.py +213 -0
- open_flamingo/open_flamingo/eval/vqa_metric.py +594 -0
- open_flamingo/open_flamingo/src/__init__.py +0 -0
- open_flamingo/open_flamingo/src/factory.py +278 -0
- open_flamingo/open_flamingo/src/flamingo.py +236 -0
- open_flamingo/open_flamingo/src/flamingo_lm.py +203 -0
- open_flamingo/open_flamingo/src/helpers.py +263 -0
- open_flamingo/open_flamingo/src/utils.py +31 -0
- open_flamingo/open_flamingo/train/__init__.py +1 -0
- open_flamingo/open_flamingo/train/data.deprecated.py +812 -0
- open_flamingo/open_flamingo/train/data2.py +573 -0
- open_flamingo/open_flamingo/train/distributed.py +128 -0
- open_flamingo/open_flamingo/train/train.py +587 -0
- open_flamingo/open_flamingo/train/train_utils.py +371 -0
- open_flamingo/requirements-dev.txt +5 -0
- open_flamingo/requirements.txt +16 -0
- open_flamingo/setup.py +58 -0
- open_flamingo/tests/test_flamingo_model.py +77 -0
- open_flamingo/tools/check_refcoco.py +14 -0
- open_flamingo/tools/convert_mmc4_to_wds.py +124 -0
- open_flamingo/tools/make_gqa_val.py +0 -0
- open_flamingo/tools/make_mmc4_global_table.py +31 -0
- open_flamingo/tools/make_soft_link.py +26 -0
- open_flamingo/tools/make_soft_link_blip2_data.py +30 -0
- open_flamingo/tools/make_soft_link_laion.py +23 -0
- open_flamingo/tools/make_vqav2_ft_dataset.py +24 -0
- open_flamingo/tools/prepare_mini_blip2_dataset.py +178 -0
- open_flamingo/tools/prepare_pile.py +31 -0
app.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.system("cd open_flamingo && pip install .")
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
8 |
+
import string
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
import torch
|
14 |
+
from PIL import Image
|
15 |
+
# from huggingface_hub import hf_hub_download, login
|
16 |
+
|
17 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
18 |
+
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
|
19 |
+
"ViT-L-14",
|
20 |
+
"datacomp_xl_s13b_b90k",
|
21 |
+
"facebook/opt-350m",
|
22 |
+
"facebook/opt-350m",
|
23 |
+
add_visual_grounding=True,
|
24 |
+
location_token_num=1000,
|
25 |
+
add_visual_token = True,
|
26 |
+
use_format_v2 = True,
|
27 |
+
)
|
28 |
+
|
29 |
+
checkpoint_path = hf_hub_download("chendl/mm", "checkpoint_opt350m.pt")
|
30 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
31 |
+
model_state_dict = {}
|
32 |
+
for key in checkpoint["model_state_dict"].keys():
|
33 |
+
model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
|
34 |
+
if "vision_encoder.logit_scale"in model_state_dict:
|
35 |
+
# previous checkpoint has some unnecessary weights
|
36 |
+
del model_state_dict["vision_encoder.logit_scale"]
|
37 |
+
del model_state_dict["vision_encoder.visual.proj"]
|
38 |
+
del model_state_dict["vision_encoder.visual.ln_post.weight"]
|
39 |
+
del model_state_dict["vision_encoder.visual.ln_post.bias"]
|
40 |
+
flamingo.load_state_dict(model_state_dict, strict=True)
|
41 |
+
|
42 |
+
|
43 |
+
def generate(
|
44 |
+
idx,
|
45 |
+
image,
|
46 |
+
text,
|
47 |
+
tsvfile,
|
48 |
+
vis_embed_size=256,
|
49 |
+
rank=0,
|
50 |
+
world_size=1,
|
51 |
+
):
|
52 |
+
if image is None:
|
53 |
+
raise gr.Error("Please upload an image.")
|
54 |
+
flamingo.eval().cuda()
|
55 |
+
loc_token_ids = []
|
56 |
+
for i in range(1000):
|
57 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
58 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
59 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
60 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
61 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
62 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
63 |
+
all_ids = set(range(flamingo.lang_encoder.lm_head.out_features))
|
64 |
+
bad_words_ids = list(all_ids - set(loc_token_ids))
|
65 |
+
bad_words_ids = [[b] for b in bad_words_ids]
|
66 |
+
min_loc_token_id = min(loc_token_ids)
|
67 |
+
max_loc_token_id = max(loc_token_ids)
|
68 |
+
image = Image.open(image).convert("RGB")
|
69 |
+
width = image.width
|
70 |
+
height = image.height
|
71 |
+
image = image.resize((224, 224))
|
72 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
73 |
+
prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}"]
|
74 |
+
encodings = tokenizer(
|
75 |
+
prompt,
|
76 |
+
padding="longest",
|
77 |
+
truncation=True,
|
78 |
+
return_tensors="pt",
|
79 |
+
max_length=2000,
|
80 |
+
)
|
81 |
+
input_ids = encodings["input_ids"]
|
82 |
+
attention_mask = encodings["attention_mask"]
|
83 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
84 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
85 |
+
image_nums = [1] * len(input_ids)
|
86 |
+
outputs = get_outputs(
|
87 |
+
model=flamingo,
|
88 |
+
batch_images=batch_images.cuda(),
|
89 |
+
attention_mask=attention_mask.cuda(),
|
90 |
+
max_generation_length=5,
|
91 |
+
min_generation_length=4,
|
92 |
+
num_beams=1,
|
93 |
+
length_penalty=1.0,
|
94 |
+
input_ids=input_ids.cuda(),
|
95 |
+
bad_words_ids=bad_words_ids,
|
96 |
+
image_start_index_list=image_start_index_list,
|
97 |
+
image_nums=image_nums,
|
98 |
+
)
|
99 |
+
box = []
|
100 |
+
for o in outputs[0]:
|
101 |
+
if o >= min_loc_token_id and o <= max_loc_token_id:
|
102 |
+
box.append(o.item() - min_loc_token_id)
|
103 |
+
if len(box) == 4:
|
104 |
+
break
|
105 |
+
# else:
|
106 |
+
# tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
|
107 |
+
# tqdm.write(f"prompt: {prompt}")
|
108 |
+
|
109 |
+
gen_text = tokenizer.batch_decode(outputs)
|
110 |
+
return (
|
111 |
+
f"Output:{gen_text}"
|
112 |
+
if idx != 2
|
113 |
+
else f"Question: {text.strip()} Answer: {gen_text}"
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
with gr.Blocks() as demo:
|
118 |
+
gr.Markdown(
|
119 |
+
"""
|
120 |
+
# 🦩 OpenFlamingo Demo
|
121 |
+
|
122 |
+
Blog posts: #1 [Announcing OpenFlamingo: An open-source framework for training vision-language models with in-context learning](https://laion.ai/blog/open-flamingo/) // #2 [OpenFlamingo v2: New Models and Enhanced Training Setup](https://laion.ai/blog/open-flamingo-v2/)
|
123 |
+
|
124 |
+
GitHub: [open_flamingo](https://github.com/mlfoundations/open_flamingo)
|
125 |
+
|
126 |
+
In this demo we showcase the in-context learning capabilities of the OpenFlamingo-9B model, a large multimodal model trained on top of mpt-7b. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
|
127 |
+
The model is trained on an interleaved mixture of text and images and is able to generate text conditioned on sequences of images/text. To safeguard against harmful generations, we detect toxic text in the model output and reject it. However, we understand that this is not a perfect solution and we encourage you to use this demo responsibly. If you find that the model is generating harmful text, please report it using this [form](https://forms.gle/StbcPvyyW2p3Pc7z6).
|
128 |
+
"""
|
129 |
+
)
|
130 |
+
|
131 |
+
with gr.Accordion("See terms and conditions"):
|
132 |
+
gr.Markdown("""**Please read the following information carefully before proceeding.**
|
133 |
+
|
134 |
+
[OpenFlamingo-9B](https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b) is a **research prototype** that aims to enable users to interact with AI through both language and images. AI agents equipped with both language and visual understanding can be useful on a larger variety of tasks compared to models that communicate solely via language. By releasing an open-source research prototype, we hope to help the research community better understand the risks and limitations of modern visual-language AI models and accelerate the development of safer and more reliable methods.
|
135 |
+
**Limitations.** OpenFlamingo-9B is built on top of the [MPT-7B](https://huggingface.co/mosaicml/mpt-7b) large language model developed by Together.xyz. Large language models are trained on mostly unfiltered internet data, and have been shown to be able to produce toxic, unethical, inaccurate, and harmful content. On top of this, OpenFlamingo’s ability to support visual inputs creates additional risks, since it can be used in a wider variety of applications; image+text models may carry additional risks specific to multimodality. Please use discretion when assessing the accuracy or appropriateness of the model’s outputs, and be mindful before sharing its results.
|
136 |
+
**Privacy and data collection.** This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
|
137 |
+
|
138 |
+
|
139 |
+
with gr.Tab("📷 Image Captioning"):
|
140 |
+
with gr.Row():
|
141 |
+
|
142 |
+
|
143 |
+
query_image = gr.Image(type="pil")
|
144 |
+
with gr.Row():
|
145 |
+
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
146 |
+
text_output = gr.Textbox(value="Output:", label="Model output")
|
147 |
+
|
148 |
+
run_btn = gr.Button("Run model")
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
def on_click_fn(img,text): return generate(0, img, text)
|
153 |
+
|
154 |
+
run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output])
|
155 |
+
|
156 |
+
with gr.Tab("🦓 Grounding"):
|
157 |
+
with gr.Row():
|
158 |
+
query_image = gr.Image(type="pil")
|
159 |
+
with gr.Row():
|
160 |
+
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
161 |
+
text_output = gr.Textbox(value="Output:", label="Model output")
|
162 |
+
|
163 |
+
run_btn = gr.Button("Run model")
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
def on_click_fn(img,text): return generate(0, img, text)
|
168 |
+
|
169 |
+
run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output])
|
170 |
+
|
171 |
+
with gr.Tab("🔢 Counting objects"):
|
172 |
+
with gr.Row():
|
173 |
+
query_image = gr.Image(type="pil")
|
174 |
+
with gr.Row():
|
175 |
+
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
176 |
+
text_output = gr.Textbox(value="Output:", label="Model output")
|
177 |
+
|
178 |
+
run_btn = gr.Button("Run model")
|
179 |
+
|
180 |
+
|
181 |
+
def on_click_fn(img,text): return generate(0, img, text)
|
182 |
+
|
183 |
+
|
184 |
+
run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
|
185 |
+
|
186 |
+
with gr.Tab("🕵️ Visual Question Answering"):
|
187 |
+
with gr.Row():
|
188 |
+
query_image = gr.Image(type="pil")
|
189 |
+
with gr.Row():
|
190 |
+
question = gr.Textbox(lines=1, label="Question")
|
191 |
+
text_output = gr.Textbox(value="Output:", label="Model output")
|
192 |
+
|
193 |
+
run_btn = gr.Button("Run model")
|
194 |
+
|
195 |
+
|
196 |
+
def on_click_fn(img, txt): return generate(2, img, txt)
|
197 |
+
|
198 |
+
|
199 |
+
run_btn.click(
|
200 |
+
on_click_fn, inputs=[query_image, question], outputs=[text_output]
|
201 |
+
)
|
202 |
+
|
203 |
+
with gr.Tab("🌎 Custom"):
|
204 |
+
gr.Markdown(
|
205 |
+
"""### Customize the demonstration by uploading your own images and text samples.
|
206 |
+
### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
|
207 |
+
)
|
208 |
+
with gr.Row():
|
209 |
+
query_image = gr.Image(type="pil")
|
210 |
+
with gr.Row():
|
211 |
+
question = gr.Textbox(lines=1, label="Question")
|
212 |
+
text_output = gr.Textbox(value="Output:", label="Model output")
|
213 |
+
|
214 |
+
run_btn = gr.Button("Run model")
|
215 |
+
|
216 |
+
|
217 |
+
def on_click_fn(img, txt): return generate(2, img, txt)
|
218 |
+
|
219 |
+
|
220 |
+
run_btn.click(
|
221 |
+
on_click_fn, inputs=[query_image, question], outputs=[text_output]
|
222 |
+
)
|
223 |
+
|
224 |
+
demo.queue(concurrency_count=1)
|
225 |
+
demo.launch()
|
open_flamingo/.github/workflows/black.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Lint
|
2 |
+
|
3 |
+
on: [push, pull_request]
|
4 |
+
|
5 |
+
jobs:
|
6 |
+
lint:
|
7 |
+
runs-on: ubuntu-latest
|
8 |
+
steps:
|
9 |
+
- uses: actions/checkout@v2
|
10 |
+
- uses: psf/black@stable
|
open_flamingo/.gitignore
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pt
|
2 |
+
GRiT/
|
3 |
+
temp/
|
4 |
+
eval_*.sh
|
5 |
+
*.json
|
6 |
+
eval_results/
|
7 |
+
pycocoevalcap/
|
8 |
+
checkpoints_align_*
|
9 |
+
segment-anything/
|
10 |
+
|
11 |
+
wandb/
|
12 |
+
checkpoints*/
|
13 |
+
|
14 |
+
# Byte-compiled / optimized / DLL files
|
15 |
+
__pycache__/
|
16 |
+
*.py[cod]
|
17 |
+
*$py.class
|
18 |
+
|
19 |
+
# C extensions
|
20 |
+
*.so
|
21 |
+
|
22 |
+
# Distribution / packaging
|
23 |
+
.Python
|
24 |
+
build/
|
25 |
+
develop-eggs/
|
26 |
+
dist/
|
27 |
+
downloads/
|
28 |
+
eggs/
|
29 |
+
.eggs/
|
30 |
+
lib/
|
31 |
+
lib64/
|
32 |
+
parts/
|
33 |
+
sdist/
|
34 |
+
var/
|
35 |
+
wheels/
|
36 |
+
pip-wheel-metadata/
|
37 |
+
share/python-wheels/
|
38 |
+
*.egg-info/
|
39 |
+
.installed.cfg
|
40 |
+
*.egg
|
41 |
+
MANIFEST
|
42 |
+
|
43 |
+
# PyInstaller
|
44 |
+
# Usually these files are written by a python script from a template
|
45 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
46 |
+
*.manifest
|
47 |
+
*.spec
|
48 |
+
|
49 |
+
# Installer logs
|
50 |
+
pip-log.txt
|
51 |
+
pip-delete-this-directory.txt
|
52 |
+
|
53 |
+
# Unit test / coverage reports
|
54 |
+
htmlcov/
|
55 |
+
.tox/
|
56 |
+
.nox/
|
57 |
+
.coverage
|
58 |
+
.coverage.*
|
59 |
+
.cache
|
60 |
+
nosetests.xml
|
61 |
+
coverage.xml
|
62 |
+
*.cover
|
63 |
+
*.py,cover
|
64 |
+
.hypothesis/
|
65 |
+
.pytest_cache/
|
66 |
+
|
67 |
+
# Translations
|
68 |
+
*.mo
|
69 |
+
*.pot
|
70 |
+
|
71 |
+
# Django stuff:
|
72 |
+
*.log
|
73 |
+
local_settings.py
|
74 |
+
db.sqlite3
|
75 |
+
db.sqlite3-journal
|
76 |
+
|
77 |
+
# Flask stuff:
|
78 |
+
instance/
|
79 |
+
.webassets-cache
|
80 |
+
|
81 |
+
# Scrapy stuff:
|
82 |
+
.scrapy
|
83 |
+
|
84 |
+
# Sphinx documentation
|
85 |
+
docs/_build/
|
86 |
+
|
87 |
+
# PyBuilder
|
88 |
+
target/
|
89 |
+
|
90 |
+
# Jupyter Notebook
|
91 |
+
.ipynb_checkpoints
|
92 |
+
|
93 |
+
# IPython
|
94 |
+
profile_default/
|
95 |
+
ipython_config.py
|
96 |
+
|
97 |
+
# pyenv
|
98 |
+
.python-version
|
99 |
+
|
100 |
+
# pipenv
|
101 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
102 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
103 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
104 |
+
# install all needed dependencies.
|
105 |
+
#Pipfile.lock
|
106 |
+
|
107 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
108 |
+
__pypackages__/
|
109 |
+
|
110 |
+
# Celery stuff
|
111 |
+
celerybeat-schedule
|
112 |
+
celerybeat.pid
|
113 |
+
|
114 |
+
# SageMath parsed files
|
115 |
+
*.sage.py
|
116 |
+
|
117 |
+
# Environments
|
118 |
+
.env
|
119 |
+
.venv
|
120 |
+
env/
|
121 |
+
venv/
|
122 |
+
ENV/
|
123 |
+
env.bak/
|
124 |
+
venv.bak/
|
125 |
+
|
126 |
+
# Pycharm project settings
|
127 |
+
.idea
|
128 |
+
|
129 |
+
# Spyder project settings
|
130 |
+
.spyderproject
|
131 |
+
.spyproject
|
132 |
+
|
133 |
+
# Rope project settings
|
134 |
+
.ropeproject
|
135 |
+
|
136 |
+
# mkdocs documentation
|
137 |
+
/site
|
138 |
+
|
139 |
+
# mypy
|
140 |
+
.mypy_cache/
|
141 |
+
.dmypy.json
|
142 |
+
dmypy.json
|
143 |
+
|
144 |
+
*.out
|
145 |
+
src/wandb
|
146 |
+
wandb
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
open_flamingo/HISTORY.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
## 1.0.0
|
2 |
+
|
3 |
+
* it works
|
open_flamingo/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
open_flamingo/MODEL_CARD.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: en
|
3 |
+
datasets:
|
4 |
+
- laion2b
|
5 |
+
---
|
6 |
+
|
7 |
+
# OpenFlamingo-9B
|
8 |
+
|
9 |
+
[Blog post]() | [Code](https://github.com/mlfoundations/open_flamingo) | [Demo](https://7164d2142d11.ngrok.app)
|
10 |
+
|
11 |
+
OpenFlamingo is an open source implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) models.
|
12 |
+
OpenFlamingo-9B is built off of [CLIP ViT-L/14](https://huggingface.co/openai/clip-vit-large-patch14) and [LLaMA-7B](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/).
|
13 |
+
|
14 |
+
|
15 |
+
## Model Details
|
16 |
+
We freeze the pretrained vision encoder and language model, and then we train connecting Perceiver modules and cross-attention layers, following the original Flamingo paper.
|
17 |
+
|
18 |
+
Our training data is a mixture of [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and a large interleaved image-text dataset called Multimodal C4, which will be released soon.
|
19 |
+
|
20 |
+
The current model is an early checkpoint of an ongoing effort. This checkpoint has seen 5 million interleaved image-text examples from Multimodal C4 and 10 million samples from LAION 2B.
|
21 |
+
|
22 |
+
## Uses
|
23 |
+
OpenFlamingo-9B is intended to be used **for academic research purposes only.** Commercial use is prohibited, in line with LLaMA's non-commercial license.
|
24 |
+
|
25 |
+
### Bias, Risks, and Limitations
|
26 |
+
This model may generate inaccurate or offensive outputs, reflecting biases in its training data and pretrained priors.
|
27 |
+
|
28 |
+
In an effort to mitigate current potential biases and harms, we have deployed a text content filter on model outputs in the OpenFlamingo demo. We continue to red-team the model to understand and improve its safety.
|
29 |
+
|
30 |
+
## Evaluation
|
31 |
+
We've evaluated this checkpoint on the validation sets for two vision-language tasks: COCO captioning and VQAv2. Results are displayed below.
|
32 |
+
|
33 |
+
**COCO (CIDEr)**
|
34 |
+
|
35 |
+
|0-shot|4-shot|8-shot|16-shot|32-shot|
|
36 |
+
|--|--|--|--|--|
|
37 |
+
|65.52|74.28|79.26|81.84|84.52|
|
38 |
+
|
39 |
+
|
40 |
+
**VQAv2 (VQA accuracy)**
|
41 |
+
|
42 |
+
|0-shot|4-shot|8-shot|16-shot|32-shot|
|
43 |
+
|---|---|---|---|---|
|
44 |
+
|43.55|44.05|47.5|48.87|50.34|
|
open_flamingo/Makefile
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
install: ## [Local development] Upgrade pip, install requirements, install package.
|
2 |
+
python -m pip install -U pip
|
3 |
+
python -m pip install -e .
|
4 |
+
|
5 |
+
install-dev: ## [Local development] Install test requirements
|
6 |
+
python -m pip install -r requirements-test.txt
|
7 |
+
|
8 |
+
lint: ## [Local development] Run mypy, pylint and black
|
9 |
+
python -m mypy open_flamingo
|
10 |
+
python -m pylint open_flamingo
|
11 |
+
python -m black --check -l 120 open_flamingo
|
12 |
+
|
13 |
+
black: ## [Local development] Auto-format python code using black
|
14 |
+
python -m black -l 120 .
|
15 |
+
|
16 |
+
.PHONY: help
|
17 |
+
|
18 |
+
help: # Run `make help` to get help on the make commands
|
19 |
+
@grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
open_flamingo/README.md
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🦩 OpenFlamingo
|
2 |
+
|
3 |
+
[![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo)
|
4 |
+
|
5 |
+
[Blog post](https://laion.ai/blog/open-flamingo/) | Paper (coming soon)
|
6 |
+
|
7 |
+
Welcome to our open source version of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) model! In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. We also provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) trained on a new Multimodal C4 dataset (coming soon). Please refer to our blog post for more details.
|
8 |
+
|
9 |
+
This repo is still under development, and we hope to release better performing and larger OpenFlamingo models soon. If you have any questions, please feel free to open an issue. We also welcome contributions!
|
10 |
+
|
11 |
+
# Table of Contents
|
12 |
+
- [Installation](#installation)
|
13 |
+
- [Approach](#approach)
|
14 |
+
* [Model architecture](#model-architecture)
|
15 |
+
- [Usage](#usage)
|
16 |
+
* [Initializing an OpenFlamingo model](#initializing-an-openflamingo-model)
|
17 |
+
* [Generating text](#generating-text)
|
18 |
+
- [Training](#training)
|
19 |
+
* [Dataset](#dataset)
|
20 |
+
- [Evaluation](#evaluation)
|
21 |
+
- [Future plans](#future-plans)
|
22 |
+
- [Team](#team)
|
23 |
+
- [Acknowledgments](#acknowledgments)
|
24 |
+
- [Citing](#citing)
|
25 |
+
|
26 |
+
# Installation
|
27 |
+
|
28 |
+
To install the package in an existing environment, run
|
29 |
+
```
|
30 |
+
pip install open-flamingo
|
31 |
+
```
|
32 |
+
|
33 |
+
or to create a conda environment for running OpenFlamingo, run
|
34 |
+
```
|
35 |
+
conda env create -f environment.yml
|
36 |
+
```
|
37 |
+
|
38 |
+
# Usage
|
39 |
+
We provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) using a CLIP ViT-Large vision encoder and a LLaMA-7B language model. In general, we support any [CLIP vision encoder](https://huggingface.co/models?search=clip). For the language model, we support [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models.
|
40 |
+
|
41 |
+
#### NOTE: To use LLaMA models, you will need to install the latest version of transformers via
|
42 |
+
```
|
43 |
+
pip install git+https://github.com/huggingface/transformers
|
44 |
+
```
|
45 |
+
Use this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) for converting LLaMA weights to HuggingFace format.
|
46 |
+
|
47 |
+
## Initializing an OpenFlamingo model
|
48 |
+
``` python
|
49 |
+
from open_flamingo import create_model_and_transforms
|
50 |
+
|
51 |
+
model, image_processor, tokenizer = create_model_and_transforms(
|
52 |
+
clip_vision_encoder_path="ViT-L-14",
|
53 |
+
clip_vision_encoder_pretrained="openai",
|
54 |
+
lang_encoder_path="<path to llama weights in HuggingFace format>",
|
55 |
+
tokenizer_path="<path to llama tokenizer in HuggingFace format>",
|
56 |
+
cross_attn_every_n_layers=4
|
57 |
+
)
|
58 |
+
|
59 |
+
# grab model checkpoint from huggingface hub
|
60 |
+
from huggingface_hub import hf_hub_download
|
61 |
+
import torch
|
62 |
+
|
63 |
+
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt")
|
64 |
+
model.load_state_dict(torch.load(checkpoint_path), strict=False)
|
65 |
+
```
|
66 |
+
|
67 |
+
## Generating text
|
68 |
+
Here is an example of generating text conditioned on interleaved images/text, in this case we will do few-shot image captioning.
|
69 |
+
|
70 |
+
``` python
|
71 |
+
from PIL import Image
|
72 |
+
import requests
|
73 |
+
|
74 |
+
"""
|
75 |
+
Step 1: Load images
|
76 |
+
"""
|
77 |
+
demo_image_one = Image.open(
|
78 |
+
requests.get(
|
79 |
+
"http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
|
80 |
+
).raw
|
81 |
+
)
|
82 |
+
|
83 |
+
demo_image_two = Image.open(
|
84 |
+
requests.get(
|
85 |
+
"http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
|
86 |
+
stream=True
|
87 |
+
).raw
|
88 |
+
)
|
89 |
+
|
90 |
+
query_image = Image.open(
|
91 |
+
requests.get(
|
92 |
+
"http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
|
93 |
+
stream=True
|
94 |
+
).raw
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
"""
|
99 |
+
Step 2: Preprocessing images
|
100 |
+
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
|
101 |
+
batch_size x num_media x num_frames x channels x height x width.
|
102 |
+
In this case batch_size = 1, num_media = 3, num_frames = 1
|
103 |
+
(this will always be one expect for video which we don't support yet),
|
104 |
+
channels = 3, height = 224, width = 224.
|
105 |
+
"""
|
106 |
+
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
|
107 |
+
vision_x = torch.cat(vision_x, dim=0)
|
108 |
+
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
|
109 |
+
|
110 |
+
"""
|
111 |
+
Step 3: Preprocessing text
|
112 |
+
Details: In the text we expect an <|#image#|> special token to indicate where an image is.
|
113 |
+
We also expect an <|endofchunk|> special token to indicate the end of the text
|
114 |
+
portion associated with an image.
|
115 |
+
"""
|
116 |
+
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
|
117 |
+
lang_x = tokenizer(
|
118 |
+
["<|#image#|>An image of two cats.<|endofchunk|><|#image#|>An image of a bathroom sink.<|endofchunk|><|#image#|>An image of"],
|
119 |
+
return_tensors="pt",
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
"""
|
124 |
+
Step 4: Generate text
|
125 |
+
"""
|
126 |
+
generated_text = model.generate(
|
127 |
+
vision_x=vision_x,
|
128 |
+
lang_x=lang_x["input_ids"],
|
129 |
+
attention_mask=lang_x["attention_mask"],
|
130 |
+
max_new_tokens=20,
|
131 |
+
num_beams=3,
|
132 |
+
)
|
133 |
+
|
134 |
+
print("Generated text: ", tokenizer.decode(generated_text[0]))
|
135 |
+
```
|
136 |
+
|
137 |
+
# Approach
|
138 |
+
OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context training.
|
139 |
+
|
140 |
+
## Model architecture
|
141 |
+
OpenFlamingo seeks to fuse a pretrained vision encoder and a language model using cross attention layers. The model architecture is shown below.
|
142 |
+
|
143 |
+
![OpenFlamingo architecture](docs/flamingo.png)
|
144 |
+
Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)
|
145 |
+
|
146 |
+
# Training
|
147 |
+
To train a model, modify the following example command, which uses OPT 1.3B as an example LM:
|
148 |
+
```
|
149 |
+
torchrun --nnodes=1 --nproc_per_node=4 train.py \
|
150 |
+
--run_name flamingo3B \
|
151 |
+
--lm_path facebook/opt-1.3b \
|
152 |
+
--tokenizer_path facebook/opt-1.3b \
|
153 |
+
--dataset_resampled \
|
154 |
+
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
|
155 |
+
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
|
156 |
+
--batch_size_mmc4 4 \
|
157 |
+
--batch_size_laion 8 \
|
158 |
+
--train_num_samples_mmc4 125000 \
|
159 |
+
--train_num_samples_laion 250000 \
|
160 |
+
--loss_multiplier_laion 0.2 \
|
161 |
+
--workers=6 \
|
162 |
+
--num_epochs 250 \
|
163 |
+
--lr_scheduler constant \
|
164 |
+
--warmup_steps 5000 \
|
165 |
+
--use_media_placement_augmentation \
|
166 |
+
--mmc4_textsim_threshold 30
|
167 |
+
```
|
168 |
+
|
169 |
+
## Dataset
|
170 |
+
We expect all our training datasets to be [WebDataset](https://github.com/webdataset/webdataset) shards.
|
171 |
+
We train our models on the [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and Multimodal C4 (coming soon) datasets. By default the LAION 2B dataset is in WebDataset format if it is downloaded using the [img2dataset tool](https://github.com/rom1504/img2dataset) and Multimodal C4 comes packaged in the WebDataset format.
|
172 |
+
|
173 |
+
|
174 |
+
# Evaluation
|
175 |
+
We currently support running evaluations on [COCO](https://cocodataset.org/#home), [VQAv2](https://visualqa.org/index.html), [OKVQA](https://okvqa.allenai.org), [Flickr30k](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset), and [ImageNet](https://image-net.org/index.php). Note that currently these evaluations are ran in validation mode (as specified in the Flamingo paper). We will be adding support for running evaluations in test mode in the future.
|
176 |
+
|
177 |
+
Before evaluating the model, you will need to install the coco evaluation package by running the following command:
|
178 |
+
```
|
179 |
+
pip install pycocoevalcap
|
180 |
+
```
|
181 |
+
|
182 |
+
To run evaluations on OKVQA you will need to run the following command:
|
183 |
+
```
|
184 |
+
import nltk
|
185 |
+
nltk.download('wordnet')
|
186 |
+
```
|
187 |
+
|
188 |
+
To evaluate the model, run the script at `open_flamingo/scripts/run_eval.sh`
|
189 |
+
|
190 |
+
# Future plans
|
191 |
+
- [ ] Add support for video input
|
192 |
+
- [ ] Release better performing and larger OpenFlamingo models
|
193 |
+
- [ ] Expand our evaluation suite
|
194 |
+
- [ ] Add support for FSDP training
|
195 |
+
|
196 |
+
# Team
|
197 |
+
|
198 |
+
OpenFlamingo is developed by:
|
199 |
+
|
200 |
+
[Anas Awadalla](https://anas-awadalla.streamlit.app/), [Irena Gao](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/).
|
201 |
+
|
202 |
+
The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google.
|
203 |
+
|
204 |
+
# Acknowledgments
|
205 |
+
This code is based on Lucidrains' [flamingo implementation](https://github.com/lucidrains/flamingo-pytorch) and David Hansmair's [flamingo-mini repo](https://github.com/dhansmair/flamingo-mini). Thank you for making your code public! We also thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) team as we use their data loading code and take inspiration from their library design.
|
206 |
+
|
207 |
+
We would also like to thank [Jean-Baptiste Alayrac](https://www.jbalayrac.com) and [Antoine Miech](https://antoine77340.github.io) for their advice, [Rohan Taori](https://www.rohantaori.com/), [Nicholas Schiefer](https://nicholasschiefer.com/), [Deep Ganguli](https://hai.stanford.edu/people/deep-ganguli), [Thomas Liao](https://thomasliao.com/), [Tatsunori Hashimoto](https://thashim.github.io/), and [Nicholas Carlini](https://nicholas.carlini.com/) for their help with assessing the safety risks of our release, and to [Stability AI](https://stability.ai) for providing us with compute resources to train these models.
|
208 |
+
|
209 |
+
# Citing
|
210 |
+
If you found this repository useful, please consider citing:
|
211 |
+
|
212 |
+
```
|
213 |
+
@software{anas_awadalla_2023_7733589,
|
214 |
+
author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig},
|
215 |
+
title = {OpenFlamingo},
|
216 |
+
month = mar,
|
217 |
+
year = 2023,
|
218 |
+
publisher = {Zenodo},
|
219 |
+
version = {v0.1.1},
|
220 |
+
doi = {10.5281/zenodo.7733589},
|
221 |
+
url = {https://doi.org/10.5281/zenodo.7733589}
|
222 |
+
}
|
223 |
+
```
|
224 |
+
|
225 |
+
```
|
226 |
+
@article{Alayrac2022FlamingoAV,
|
227 |
+
title={Flamingo: a Visual Language Model for Few-Shot Learning},
|
228 |
+
author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
|
229 |
+
journal={ArXiv},
|
230 |
+
year={2022},
|
231 |
+
volume={abs/2204.14198}
|
232 |
+
}
|
233 |
+
```
|
open_flamingo/TERMS_AND_CONDITIONS.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**Please read the following information carefully before proceeding.**
|
2 |
+
|
3 |
+
OpenFlamingo is a **research prototype** that aims to enable users to interact with AI through both language and images. AI agents equipped with both language and visual understanding can be useful on a larger variety of tasks compared to models that communicate solely via language. By releasing an open-source research prototype, we hope to help the research community better understand the risks and limitations of modern visual-language AI models and accelerate the development of safer and more reliable methods.
|
4 |
+
|
5 |
+
- [ ] I understand that OpenFlamingo is a research prototype and I will only use it for non-commercial research purposes.
|
6 |
+
|
7 |
+
**Limitations.** OpenFlamingo is built on top of the LLaMA large language model developed by Meta AI. Large language models, including LLaMA, are trained on mostly unfiltered internet data, and have been shown to be able to produce toxic, unethical, inaccurate, and harmful content. On top of this, OpenFlamingo’s ability to support visual inputs creates additional risks, since it can be used in a wider variety of applications; image+text models may carry additional risks specific to multimodality. Please use discretion when assessing the accuracy or appropriateness of the model’s outputs, and be mindful before sharing its results.
|
8 |
+
|
9 |
+
- [ ] I understand that OpenFlamingo may produce unintended, inappropriate, offensive, and/or inaccurate results. I agree to take full responsibility for any use of the OpenFlamingo outputs that I generate.
|
10 |
+
|
11 |
+
**Privacy and data collection.** This demo does NOT store any personal information on its users, and it does NOT store user queries.
|
12 |
+
|
13 |
+
**Licensing.** As OpenFlamingo is built on top of the LLaMA large language model from Meta AI, the LLaMA license agreement (as documented in the Meta request form) also applies.
|
14 |
+
|
15 |
+
- [ ] I have read and agree to the terms of the LLaMA license agreement.
|
open_flamingo/docs/flamingo.png
ADDED
open_flamingo/environment.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: mm
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- python=3.9
|
6 |
+
- conda-forge::openjdk
|
7 |
+
- pip
|
8 |
+
- pip:
|
9 |
+
- -r requirements.txt
|
10 |
+
- -e .
|
open_flamingo/open_flamingo/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .src.flamingo import Flamingo
|
2 |
+
from .src.factory import create_model_and_transforms
|
open_flamingo/open_flamingo/eval/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
open_flamingo/open_flamingo/eval/classification.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Sequence, Tuple
|
2 |
+
import re
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def postprocess_classification_generation(predictions) -> str:
|
8 |
+
return re.split("Prompt|Completion", predictions, 1)[0]
|
9 |
+
|
10 |
+
|
11 |
+
def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
|
12 |
+
"""Compute the accuracy of a sequence of predictions."""
|
13 |
+
|
14 |
+
def _preprocess_fn(s):
|
15 |
+
"""Function to preprocess both targets and predictions."""
|
16 |
+
return s.lower()
|
17 |
+
|
18 |
+
is_correct = [
|
19 |
+
_preprocess_fn(x["prediction"]) == _preprocess_fn(x["class_label"])
|
20 |
+
for x in predictions
|
21 |
+
]
|
22 |
+
|
23 |
+
return np.mean(is_correct).item()
|
24 |
+
|
25 |
+
|
26 |
+
def compute_shifted_logits_and_labels(
|
27 |
+
logits: torch.Tensor, encodings, tokenizer, eoc_token_id
|
28 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
29 |
+
"""Helper function to compute shifted logits and labels.
|
30 |
+
|
31 |
+
This allows for straightforward computation of the loss on shift_logits
|
32 |
+
and shift_labels such that the nth element of logits computes the n-1th
|
33 |
+
element of the original labels (in the outputs, the nth element of logits
|
34 |
+
corresponds to the nth element of the labels).
|
35 |
+
|
36 |
+
Elements in shift_labels that correspond to inputs are masked with values
|
37 |
+
of -100 (by default in hf, loss is only computed on token IDs >= 0).
|
38 |
+
|
39 |
+
Returns: tuple containing two elements:
|
40 |
+
shift_logits: a float Tensor of shape [batch_size, seq_len - 1].
|
41 |
+
shift_labels: an integer Tensor of shape [batch_size, seq_len - 1]
|
42 |
+
"""
|
43 |
+
|
44 |
+
labels = encodings["input_ids"].clone()
|
45 |
+
|
46 |
+
# convert padding and EOC tokens to -100 so they are ignored in loss
|
47 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
48 |
+
labels[labels == eoc_token_id] = -100
|
49 |
+
|
50 |
+
# Convert all tokens in prefix until separator to -100 so they are
|
51 |
+
# ignored in loss
|
52 |
+
for idx in range(len(labels)):
|
53 |
+
# Find the location of the last token of prefix *from right*,
|
54 |
+
# since the first non-padding token of the sequence will also be
|
55 |
+
# eos_token (because bos_token and eos_token are the same for
|
56 |
+
# the tokenizer).
|
57 |
+
end_of_prefix = -labels[idx].tolist()[::-1].index(tokenizer.eos_token_id) - 1
|
58 |
+
labels[idx, : end_of_prefix + 1] = -100
|
59 |
+
|
60 |
+
# Shift so that tokens < n predict n. The shifted tensors both have
|
61 |
+
# shape [batch_size, seq_len - 1].
|
62 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
63 |
+
shift_labels = labels[..., 1:].contiguous()
|
64 |
+
|
65 |
+
return shift_logits, shift_labels
|
66 |
+
|
67 |
+
|
68 |
+
def compute_per_sample_probs(
|
69 |
+
encodings, tokenizer, logits: torch.Tensor, eoc_token_id
|
70 |
+
) -> torch.Tensor:
|
71 |
+
"""Helper function to compute per-sample probability of the input sequence.
|
72 |
+
|
73 |
+
Assumes <eos token> is used to separate inputs from targets in the
|
74 |
+
prompt text
|
75 |
+
"""
|
76 |
+
shift_logits, shift_labels = compute_shifted_logits_and_labels(
|
77 |
+
logits, encodings, tokenizer, eoc_token_id
|
78 |
+
)
|
79 |
+
|
80 |
+
# Tuple of tensors for unmasked label tokens. The first element of the
|
81 |
+
# tuple contains the batch indices; the second element contains the
|
82 |
+
# sequence indices.
|
83 |
+
unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
|
84 |
+
# Tensor where the i^th element is the token_id corresponding to the i^th
|
85 |
+
# element of unmasked_indices
|
86 |
+
unmasked_token_ids = shift_labels[unmasked_indices]
|
87 |
+
|
88 |
+
# 3d tensor of [batch_idx, sequence_position, token_id] for unmasked tokens.
|
89 |
+
target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
|
90 |
+
target_idxs = target_idxs.to(shift_logits.device)
|
91 |
+
|
92 |
+
# Sanity check that every element in batch has at least one unmasked
|
93 |
+
# target token
|
94 |
+
assert torch.all(
|
95 |
+
torch.bincount(target_idxs[:, 0]) != 0
|
96 |
+
), "At least one element in batch has no unmasked target tokens."
|
97 |
+
|
98 |
+
# Renormalize over tokens to make sure they are proper probabilities via
|
99 |
+
# softmax over the token dimension.
|
100 |
+
shift_probs = torch.nn.functional.softmax(shift_logits, 2)
|
101 |
+
|
102 |
+
# Compute the probability of the target sequence (as the product of the
|
103 |
+
# probability of the individual tokens in the sequence).
|
104 |
+
target_probs = torch.ones(len(shift_labels), device=shift_logits.device)
|
105 |
+
for i, j, k in target_idxs:
|
106 |
+
target_probs[i] *= shift_probs[i, j, k]
|
107 |
+
|
108 |
+
return target_probs
|
109 |
+
|
110 |
+
|
111 |
+
def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
|
112 |
+
"""Helper function to compute per-sample classification loss.
|
113 |
+
|
114 |
+
Assumes <eos token> is used to separate inputs from targets in the
|
115 |
+
prompt text
|
116 |
+
"""
|
117 |
+
shift_logits, shift_labels = compute_shifted_logits_and_labels(
|
118 |
+
logits, encodings, tokenizer, eoc_token_id
|
119 |
+
)
|
120 |
+
|
121 |
+
device = shift_logits.device
|
122 |
+
|
123 |
+
# Loss is computed token-wise, on Tensors of shape
|
124 |
+
# [batch_size * (seq_len - 1), vocab_size]
|
125 |
+
# and returns a loss tensor of shape
|
126 |
+
# [batch_size * (seq_len - 1)]. Most of the tokens will be masked
|
127 |
+
# in this computation.
|
128 |
+
loss = torch.nn.functional.cross_entropy(
|
129 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
130 |
+
shift_labels.view(-1).to(device),
|
131 |
+
reduction="none",
|
132 |
+
)
|
133 |
+
|
134 |
+
# Reshape to [batch_size, seq_len - 1]
|
135 |
+
loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()
|
136 |
+
|
137 |
+
# loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
|
138 |
+
# that should be ignored in the loss.
|
139 |
+
loss_mask = (shift_labels != -100).int().cpu()
|
140 |
+
|
141 |
+
loss *= loss_mask
|
142 |
+
|
143 |
+
# Compute per-element loss : sum loss over all (unmasked) tokens and
|
144 |
+
# divide by number of variable tokens to obtain tensor of
|
145 |
+
# shape [batch_size,]
|
146 |
+
loss = loss.sum(dim=1) / (shift_labels != -100).sum(dim=1).float()
|
147 |
+
return loss
|
open_flamingo/open_flamingo/eval/coco_metric.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pycocoevalcap.eval import COCOEvalCap
|
2 |
+
from pycocotools.coco import COCO
|
3 |
+
import json
|
4 |
+
|
5 |
+
|
6 |
+
def compute_cider(
|
7 |
+
result_path,
|
8 |
+
annotations_path="/data/yfcc-tmp/data/mscoco/annotations/captions_train2017.json",
|
9 |
+
):
|
10 |
+
# create coco object and coco_result object
|
11 |
+
coco = COCO(annotations_path)
|
12 |
+
coco_result = coco.loadRes(result_path)
|
13 |
+
|
14 |
+
# create coco_eval object by taking coco and coco_result
|
15 |
+
coco_eval = COCOEvalCap(coco, coco_result)
|
16 |
+
coco_eval.params["image_id"] = coco_result.getImgIds()
|
17 |
+
coco_eval.evaluate()
|
18 |
+
|
19 |
+
return coco_eval.eval
|
20 |
+
|
21 |
+
|
22 |
+
def postprocess_captioning_generation(predictions):
|
23 |
+
return predictions
|
open_flamingo/open_flamingo/eval/eval_datasets.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision.datasets import ImageFolder
|
7 |
+
|
8 |
+
from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
|
9 |
+
|
10 |
+
|
11 |
+
class COCOFlickrDataset(Dataset):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
image_dir_path,
|
15 |
+
annotations_path,
|
16 |
+
is_flickr=False,
|
17 |
+
):
|
18 |
+
self.image_dir_path = image_dir_path
|
19 |
+
self.annotations = json.load(open(annotations_path))["annotations"]
|
20 |
+
self.is_flickr = is_flickr
|
21 |
+
|
22 |
+
def __len__(self):
|
23 |
+
return len(self.annotations)
|
24 |
+
|
25 |
+
def get_img_path(self, idx):
|
26 |
+
if self.is_flickr:
|
27 |
+
return f"{self.image_dir_path}/{self.annotations[idx]['image_id']}.jpg"
|
28 |
+
else:
|
29 |
+
return f"{self.image_dir_path}/{self.annotations[idx]['image_id']:012d}.jpg"
|
30 |
+
|
31 |
+
def __getitem__(self, idx):
|
32 |
+
image = Image.open(self.get_img_path(idx))
|
33 |
+
caption = self.annotations[idx]["caption"]
|
34 |
+
return {
|
35 |
+
"image": image,
|
36 |
+
"caption": caption,
|
37 |
+
"image_id": self.annotations[idx]["image_id"],
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
class VQADataset(Dataset):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
|
45 |
+
question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
|
46 |
+
annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
|
47 |
+
vqa_dataset="vqa",
|
48 |
+
):
|
49 |
+
self.questions = json.load(open(question_path, "r"))["questions"]
|
50 |
+
self.answers = json.load(open(annotations_path, "r"))["annotations"]
|
51 |
+
self.image_dir_path = image_dir_path
|
52 |
+
self.vqa_dataset = vqa_dataset
|
53 |
+
|
54 |
+
def __len__(self):
|
55 |
+
return len(self.questions)
|
56 |
+
|
57 |
+
def get_img_path(self, question):
|
58 |
+
if self.vqa_dataset == "vqa":
|
59 |
+
return os.path.join(
|
60 |
+
self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
|
61 |
+
)
|
62 |
+
elif self.vqa_dataset == "ok_vqa":
|
63 |
+
return os.path.join(
|
64 |
+
self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
raise Exception(f"Unknown VQA dataset {self.vqa_dataset}")
|
68 |
+
|
69 |
+
def __getitem__(self, idx):
|
70 |
+
question = self.questions[idx]
|
71 |
+
answers = self.answers[idx]
|
72 |
+
img_path = self.get_img_path(question)
|
73 |
+
image = Image.open(img_path)
|
74 |
+
return {
|
75 |
+
"image": image,
|
76 |
+
"question": question["question"],
|
77 |
+
"answers": [a["answer"] for a in answers["answers"]],
|
78 |
+
"question_id": question["question_id"],
|
79 |
+
}
|
80 |
+
|
81 |
+
|
82 |
+
class ImageNetDataset(ImageFolder):
|
83 |
+
"""Class to represent the ImageNet1k dataset."""
|
84 |
+
|
85 |
+
def __init__(self, root, **kwargs):
|
86 |
+
super().__init__(root=root, **kwargs)
|
87 |
+
|
88 |
+
def __getitem__(self, idx):
|
89 |
+
sample, target = super().__getitem__(idx)
|
90 |
+
target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target]
|
91 |
+
return {
|
92 |
+
"image": sample,
|
93 |
+
"class_id": target, # numeric ID of the ImageNet class
|
94 |
+
"class_name": target_label, # human-readable name of ImageNet class
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
class GQADataset(Dataset):
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
image_dir_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/images",
|
102 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/val_balanced_questions.json",
|
103 |
+
):
|
104 |
+
annotations = json.load(open(annotations_path))
|
105 |
+
self.questions = []
|
106 |
+
self.answers = []
|
107 |
+
self.image_paths = []
|
108 |
+
self.question_ids = []
|
109 |
+
for anno_id in annotations:
|
110 |
+
question = annotations[anno_id]["question"]
|
111 |
+
imageId = annotations[anno_id]["imageId"]
|
112 |
+
answer = annotations[anno_id]["answer"]
|
113 |
+
self.questions.append(question)
|
114 |
+
self.answers.append(answer)
|
115 |
+
self.image_paths.append(os.path.join(image_dir_path, "{}.jpg".format(imageId)))
|
116 |
+
self.question_ids.append(anno_id)
|
117 |
+
self.vqa_dataset = "gqa"
|
118 |
+
|
119 |
+
def __len__(self):
|
120 |
+
return len(self.questions)
|
121 |
+
|
122 |
+
def __getitem__(self, idx):
|
123 |
+
question = self.questions[idx]
|
124 |
+
question_id = self.question_ids[idx]
|
125 |
+
answer = self.answers[idx]
|
126 |
+
img_path = self.image_paths[idx]
|
127 |
+
image = Image.open(img_path)
|
128 |
+
return {
|
129 |
+
"image": image,
|
130 |
+
"question": question,
|
131 |
+
"answers": answer,
|
132 |
+
"question_id": question_id,
|
133 |
+
}
|
134 |
+
|
135 |
+
if __name__ == "__main__":
|
136 |
+
gqa_dataset = GQADataset()
|
137 |
+
for sample in gqa_dataset:
|
138 |
+
print(sample)
|
open_flamingo/open_flamingo/eval/evaluate.py
ADDED
@@ -0,0 +1,1094 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from math import ceil
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import uuid
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Callable
|
9 |
+
import time
|
10 |
+
|
11 |
+
import more_itertools
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
15 |
+
from eval_datasets import VQADataset, GQADataset
|
16 |
+
from tqdm import tqdm
|
17 |
+
from collections import Counter
|
18 |
+
|
19 |
+
from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
|
20 |
+
from open_flamingo.eval.classification import (
|
21 |
+
compute_per_sample_probs,
|
22 |
+
compute_per_sample_loss,
|
23 |
+
)
|
24 |
+
from open_flamingo.eval.imagenet_utils import (
|
25 |
+
openai_imagenet_classnames,
|
26 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL,
|
27 |
+
)
|
28 |
+
|
29 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
30 |
+
from PIL import Image
|
31 |
+
from io import BytesIO
|
32 |
+
import base64
|
33 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
34 |
+
import string
|
35 |
+
from lavis.datasets.builders import load_dataset
|
36 |
+
|
37 |
+
|
38 |
+
def get_iou(box1, box2):
|
39 |
+
# box1 and box2 should be in the format [x1, y1, x2, y2]
|
40 |
+
intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
|
41 |
+
max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
|
42 |
+
area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
43 |
+
area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
44 |
+
union = area_box1 + area_box2 - intersection
|
45 |
+
iou = intersection / union if union > 0 else 0
|
46 |
+
return iou
|
47 |
+
|
48 |
+
def expand2square(pil_img, background_color):
|
49 |
+
width, height = pil_img.size
|
50 |
+
if width == height:
|
51 |
+
return pil_img
|
52 |
+
elif width > height:
|
53 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
54 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
55 |
+
return result
|
56 |
+
else:
|
57 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
58 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
59 |
+
return result
|
60 |
+
|
61 |
+
parser = argparse.ArgumentParser()
|
62 |
+
parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
|
63 |
+
parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
|
64 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
65 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
66 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
67 |
+
parser.add_argument(
|
68 |
+
"--results_file", type=str, default=None, help="JSON file to save results"
|
69 |
+
)
|
70 |
+
|
71 |
+
# Trial arguments
|
72 |
+
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
|
73 |
+
parser.add_argument(
|
74 |
+
"--num_trials",
|
75 |
+
type=int,
|
76 |
+
default=1,
|
77 |
+
help="Number of trials to run for each shot using different demonstrations",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--trial_seeds",
|
81 |
+
nargs="+",
|
82 |
+
default=[0],
|
83 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
|
87 |
+
)
|
88 |
+
|
89 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
90 |
+
|
91 |
+
# Per-dataset evaluation flags
|
92 |
+
parser.add_argument(
|
93 |
+
"--eval_coco",
|
94 |
+
action="store_true",
|
95 |
+
default=False,
|
96 |
+
help="Whether to evaluate on COCO.",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--eval_vqav2",
|
100 |
+
action="store_true",
|
101 |
+
default=False,
|
102 |
+
help="Whether to evaluate on VQAV2.",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--eval_ok_vqa",
|
106 |
+
action="store_true",
|
107 |
+
default=False,
|
108 |
+
help="Whether to evaluate on OK-VQA.",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--eval_imagenet",
|
112 |
+
action="store_true",
|
113 |
+
default=False,
|
114 |
+
help="Whether to evaluate on ImageNet.",
|
115 |
+
)
|
116 |
+
|
117 |
+
parser.add_argument(
|
118 |
+
"--eval_flickr30",
|
119 |
+
action="store_true",
|
120 |
+
default=False,
|
121 |
+
help="Whether to evaluate on Flickr30.",
|
122 |
+
)
|
123 |
+
|
124 |
+
parser.add_argument(
|
125 |
+
"--eval_refcoco",
|
126 |
+
action="store_true",
|
127 |
+
default=False,
|
128 |
+
help="Whether to evaluate on RefCOCO.",
|
129 |
+
)
|
130 |
+
|
131 |
+
# Dataset arguments
|
132 |
+
|
133 |
+
## Flickr30 Dataset
|
134 |
+
parser.add_argument(
|
135 |
+
"--flickr_image_dir_path",
|
136 |
+
type=str,
|
137 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
138 |
+
default=None,
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--flickr_annotations_json_path",
|
142 |
+
type=str,
|
143 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
144 |
+
default=None,
|
145 |
+
)
|
146 |
+
|
147 |
+
## COCO Dataset
|
148 |
+
parser.add_argument(
|
149 |
+
"--coco_image_dir_path",
|
150 |
+
type=str,
|
151 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
152 |
+
default=None,
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--coco_annotations_json_path",
|
156 |
+
type=str,
|
157 |
+
default=None,
|
158 |
+
)
|
159 |
+
|
160 |
+
## VQAV2 Dataset
|
161 |
+
parser.add_argument(
|
162 |
+
"--vqav2_image_dir_path",
|
163 |
+
type=str,
|
164 |
+
default=None,
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--vqav2_questions_json_path",
|
168 |
+
type=str,
|
169 |
+
default=None,
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--vqav2_annotations_json_path",
|
173 |
+
type=str,
|
174 |
+
default=None,
|
175 |
+
)
|
176 |
+
|
177 |
+
## OK-VQA Dataset
|
178 |
+
parser.add_argument(
|
179 |
+
"--ok_vqa_image_dir_path",
|
180 |
+
type=str,
|
181 |
+
help="Path to the vqav2/train2014 directory.",
|
182 |
+
default=None,
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--ok_vqa_questions_json_path",
|
186 |
+
type=str,
|
187 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
188 |
+
default=None,
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--ok_vqa_annotations_json_path",
|
192 |
+
type=str,
|
193 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
194 |
+
default=None,
|
195 |
+
)
|
196 |
+
|
197 |
+
## Imagenet dataset
|
198 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
199 |
+
|
200 |
+
## RefCOCO dataset
|
201 |
+
parser.add_argument("--refcoco_tsvfile", type=str, default=None)
|
202 |
+
|
203 |
+
parser.add_argument(
|
204 |
+
"--add_visual_grounding",
|
205 |
+
default=False,
|
206 |
+
action="store_true",
|
207 |
+
)
|
208 |
+
parser.add_argument(
|
209 |
+
"--location_token_num",
|
210 |
+
default=1000,
|
211 |
+
type=int,
|
212 |
+
)
|
213 |
+
# distributed training
|
214 |
+
parser.add_argument(
|
215 |
+
"--dist-url",
|
216 |
+
default="env://",
|
217 |
+
type=str,
|
218 |
+
help="url used to set up distributed training",
|
219 |
+
)
|
220 |
+
parser.add_argument(
|
221 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"--horovod",
|
225 |
+
default=False,
|
226 |
+
action="store_true",
|
227 |
+
help="Use horovod for distributed training.",
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--no-set-device-rank",
|
231 |
+
default=False,
|
232 |
+
action="store_true",
|
233 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
234 |
+
)
|
235 |
+
parser.add_argument(
|
236 |
+
"--dist",
|
237 |
+
default=False,
|
238 |
+
action="store_true",
|
239 |
+
)
|
240 |
+
parser.add_argument(
|
241 |
+
"--lora",
|
242 |
+
default=False,
|
243 |
+
action="store_true",
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--lora_r",
|
247 |
+
default=16,
|
248 |
+
type=int,
|
249 |
+
required=False,
|
250 |
+
)
|
251 |
+
parser.add_argument(
|
252 |
+
"--legacy",
|
253 |
+
default=False,
|
254 |
+
action="store_true",
|
255 |
+
)
|
256 |
+
parser.add_argument(
|
257 |
+
"--special",
|
258 |
+
default=False,
|
259 |
+
action="store_true",
|
260 |
+
)
|
261 |
+
parser.add_argument(
|
262 |
+
"--id",
|
263 |
+
default=0,
|
264 |
+
type=int,
|
265 |
+
required=False,
|
266 |
+
)
|
267 |
+
|
268 |
+
parser.add_argument(
|
269 |
+
"--eval_gqa",
|
270 |
+
default=False,
|
271 |
+
action="store_true",
|
272 |
+
)
|
273 |
+
parser.add_argument(
|
274 |
+
"--use_sam",
|
275 |
+
default=None,
|
276 |
+
type=str,
|
277 |
+
required=False,
|
278 |
+
)
|
279 |
+
parser.add_argument(
|
280 |
+
"--add_visual_token",
|
281 |
+
default=False,
|
282 |
+
action="store_true",
|
283 |
+
)
|
284 |
+
parser.add_argument(
|
285 |
+
"--use_format_v2",
|
286 |
+
default=False,
|
287 |
+
action="store_true",
|
288 |
+
)
|
289 |
+
|
290 |
+
|
291 |
+
class OKVQAPostProcess():
|
292 |
+
def __init__(self):
|
293 |
+
self._lemmatizer = None
|
294 |
+
|
295 |
+
def _lemmatize(self, answers):
|
296 |
+
def apply(answer):
|
297 |
+
doc = self.lemmatizer(answer)
|
298 |
+
|
299 |
+
words = []
|
300 |
+
for token in doc:
|
301 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
302 |
+
words.append(token.lemma_)
|
303 |
+
else:
|
304 |
+
words.append(token.text)
|
305 |
+
answer = " ".join(words)
|
306 |
+
|
307 |
+
return answer
|
308 |
+
|
309 |
+
return [apply(answer) for answer in answers]
|
310 |
+
|
311 |
+
@property
|
312 |
+
def lemmatizer(self):
|
313 |
+
if self._lemmatizer is None:
|
314 |
+
try:
|
315 |
+
import spacy
|
316 |
+
|
317 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
318 |
+
except ImportError:
|
319 |
+
logging.error(
|
320 |
+
"""
|
321 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
322 |
+
python -m spacy download en_core_web_sm
|
323 |
+
OR
|
324 |
+
import spacy.cli
|
325 |
+
spacy.cli.download("en_core_web_sm")
|
326 |
+
"""
|
327 |
+
)
|
328 |
+
exit(1)
|
329 |
+
|
330 |
+
return self._lemmatizer
|
331 |
+
|
332 |
+
|
333 |
+
def main():
|
334 |
+
args = parser.parse_args()
|
335 |
+
if args.dist:
|
336 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
337 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
338 |
+
device_id = init_distributed_device(args)
|
339 |
+
else:
|
340 |
+
args.rank = 0
|
341 |
+
args.world_size = 1
|
342 |
+
print(f"rank: {args.rank} world_size: {args.world_size}")
|
343 |
+
|
344 |
+
if "sam" in args.checkpoint_path:
|
345 |
+
args.use_sam = "vit_l"
|
346 |
+
if (
|
347 |
+
"ground" in args.checkpoint_path or
|
348 |
+
"all" in args.checkpoint_path or
|
349 |
+
"sam" in args.checkpoint_path
|
350 |
+
):
|
351 |
+
args.add_visual_grounding = True
|
352 |
+
|
353 |
+
if "visual" in args.checkpoint_path:
|
354 |
+
args.add_visual_token = True
|
355 |
+
if "formatV2" in args.checkpoint_path:
|
356 |
+
args.use_format_v2 = True
|
357 |
+
|
358 |
+
# load model
|
359 |
+
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
|
360 |
+
args.vision_encoder_path,
|
361 |
+
args.vision_encoder_pretrained,
|
362 |
+
args.lm_path,
|
363 |
+
args.lm_tokenizer_path,
|
364 |
+
add_visual_grounding=args.add_visual_grounding,
|
365 |
+
location_token_num=args.location_token_num,
|
366 |
+
lora=args.lora,
|
367 |
+
lora_r=args.lora_r,
|
368 |
+
use_sam=args.use_sam,
|
369 |
+
add_visual_token=args.add_visual_token,
|
370 |
+
use_format_v2=args.use_format_v2,
|
371 |
+
)
|
372 |
+
flamingo.use_format_v2 = args.use_format_v2
|
373 |
+
if args.special:
|
374 |
+
flamingo.special = True
|
375 |
+
else:
|
376 |
+
flamingo.special = False
|
377 |
+
if args.legacy:
|
378 |
+
flamingo.legacy = True
|
379 |
+
print("use legacy evaluation")
|
380 |
+
flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
|
381 |
+
flamingo.expr_name = args.checkpoint_path.split("/")[-2]
|
382 |
+
if args.rank == 0:
|
383 |
+
print("legacy", True if hasattr(flamingo, "legacy") else False)
|
384 |
+
print("step:", flamingo.step_num)
|
385 |
+
print("expr:", flamingo.expr_name)
|
386 |
+
print("use format v2:", flamingo.use_format_v2)
|
387 |
+
print(args)
|
388 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
389 |
+
model_state_dict = {}
|
390 |
+
for key in checkpoint["model_state_dict"].keys():
|
391 |
+
model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
|
392 |
+
if "vision_encoder.logit_scale"in model_state_dict:
|
393 |
+
# previous checkpoint has some unnecessary weights
|
394 |
+
del model_state_dict["vision_encoder.logit_scale"]
|
395 |
+
del model_state_dict["vision_encoder.visual.proj"]
|
396 |
+
del model_state_dict["vision_encoder.visual.ln_post.weight"]
|
397 |
+
del model_state_dict["vision_encoder.visual.ln_post.bias"]
|
398 |
+
flamingo.load_state_dict(model_state_dict, strict=True)
|
399 |
+
results = defaultdict(list)
|
400 |
+
if args.eval_coco:
|
401 |
+
print("Evaluating on COCO...")
|
402 |
+
for shot in args.shots:
|
403 |
+
scores = []
|
404 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
405 |
+
cider_score = evaluate_coco_flickr(
|
406 |
+
model=flamingo,
|
407 |
+
tokenizer=tokenizer,
|
408 |
+
image_processor=image_processor,
|
409 |
+
batch_size=args.batch_size,
|
410 |
+
image_dir_path=args.coco_image_dir_path,
|
411 |
+
annotations_json_path=args.coco_annotations_json_path,
|
412 |
+
device=args.device,
|
413 |
+
seed=seed,
|
414 |
+
vis_embed_size=vis_embed_size,
|
415 |
+
rank=args.rank,
|
416 |
+
world_size=args.world_size,
|
417 |
+
id=args.id,
|
418 |
+
)
|
419 |
+
print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
|
420 |
+
scores.append(cider_score)
|
421 |
+
print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
|
422 |
+
results["coco"].append(
|
423 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
424 |
+
)
|
425 |
+
|
426 |
+
if args.eval_ok_vqa:
|
427 |
+
print("Evaluating on OK-VQA...")
|
428 |
+
for shot in args.shots:
|
429 |
+
scores = []
|
430 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
431 |
+
ok_vqa_score = evaluate_vqa(
|
432 |
+
model=flamingo,
|
433 |
+
tokenizer=tokenizer,
|
434 |
+
image_processor=image_processor,
|
435 |
+
batch_size=args.batch_size,
|
436 |
+
image_dir_path=args.ok_vqa_image_dir_path,
|
437 |
+
questions_json_path=args.ok_vqa_questions_json_path,
|
438 |
+
annotations_json_path=args.ok_vqa_annotations_json_path,
|
439 |
+
vqa_dataset="ok_vqa",
|
440 |
+
vis_embed_size=vis_embed_size,
|
441 |
+
rank=args.rank,
|
442 |
+
world_size=args.world_size,
|
443 |
+
id=args.id,
|
444 |
+
)
|
445 |
+
results["ok_vqa"].append(
|
446 |
+
{"shots": shot, "score": ok_vqa_score}
|
447 |
+
)
|
448 |
+
|
449 |
+
if args.eval_vqav2:
|
450 |
+
print("Evaluating on VQAv2...")
|
451 |
+
for shot in args.shots:
|
452 |
+
scores = []
|
453 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
454 |
+
vqa_score = evaluate_vqa(
|
455 |
+
model=flamingo,
|
456 |
+
tokenizer=tokenizer,
|
457 |
+
image_processor=image_processor,
|
458 |
+
batch_size=args.batch_size,
|
459 |
+
image_dir_path=args.vqav2_image_dir_path,
|
460 |
+
questions_json_path=args.vqav2_questions_json_path,
|
461 |
+
annotations_json_path=args.vqav2_annotations_json_path,
|
462 |
+
vqa_dataset="vqa",
|
463 |
+
vis_embed_size=vis_embed_size,
|
464 |
+
rank=args.rank,
|
465 |
+
world_size=args.world_size,
|
466 |
+
id=args.id,
|
467 |
+
)
|
468 |
+
results["vqav2"].append(
|
469 |
+
{"shots": shot, "score": vqa_score}
|
470 |
+
)
|
471 |
+
|
472 |
+
if args.eval_gqa:
|
473 |
+
print("Evaluating on GQA...")
|
474 |
+
for shot in args.shots:
|
475 |
+
scores = []
|
476 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
477 |
+
vqa_score = evaluate_vqa(
|
478 |
+
model=flamingo,
|
479 |
+
tokenizer=tokenizer,
|
480 |
+
image_processor=image_processor,
|
481 |
+
batch_size=args.batch_size,
|
482 |
+
vqa_dataset="gqa",
|
483 |
+
vis_embed_size=vis_embed_size,
|
484 |
+
rank=args.rank,
|
485 |
+
world_size=args.world_size,
|
486 |
+
id=args.id,
|
487 |
+
)
|
488 |
+
results["gqa"].append(
|
489 |
+
{"shots": shot, "score": vqa_score}
|
490 |
+
)
|
491 |
+
|
492 |
+
if args.eval_imagenet:
|
493 |
+
print("Evaluating on ImageNet...")
|
494 |
+
for shot in args.shots:
|
495 |
+
scores = []
|
496 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
497 |
+
imagenet_score = evaluate_imagenet(
|
498 |
+
model=flamingo,
|
499 |
+
tokenizer=tokenizer,
|
500 |
+
image_processor=image_processor,
|
501 |
+
batch_size=args.batch_size,
|
502 |
+
num_samples=args.num_samples,
|
503 |
+
num_shots=shot,
|
504 |
+
device=args.device,
|
505 |
+
seed=seed,
|
506 |
+
imagenet_root=args.imagenet_root,
|
507 |
+
)
|
508 |
+
print(
|
509 |
+
f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
|
510 |
+
)
|
511 |
+
scores.append(imagenet_score)
|
512 |
+
print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
|
513 |
+
results["imagenet"].append(
|
514 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
515 |
+
)
|
516 |
+
|
517 |
+
if args.eval_refcoco:
|
518 |
+
print("Evaluating on RefCOCO...")
|
519 |
+
refcoco_score = evaluate_refcoco(
|
520 |
+
model=flamingo,
|
521 |
+
tokenizer=tokenizer,
|
522 |
+
image_processor=image_processor,
|
523 |
+
batch_size=args.batch_size,
|
524 |
+
device=args.device,
|
525 |
+
tsvfile=args.refcoco_tsvfile,
|
526 |
+
vis_embed_size=vis_embed_size,
|
527 |
+
rank=args.rank,
|
528 |
+
world_size=args.world_size,
|
529 |
+
id=args.id,
|
530 |
+
)
|
531 |
+
results["refcoco"].append(
|
532 |
+
{"score": refcoco_score}
|
533 |
+
)
|
534 |
+
|
535 |
+
def prepare_batch_images(batch, image_processor):
|
536 |
+
batch_images = None
|
537 |
+
for b in batch:
|
538 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
539 |
+
if batch_images is None:
|
540 |
+
batch_images = b_image
|
541 |
+
else:
|
542 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
543 |
+
return batch_images
|
544 |
+
|
545 |
+
def get_outputs(
|
546 |
+
model,
|
547 |
+
batch_images,
|
548 |
+
attention_mask,
|
549 |
+
max_generation_length,
|
550 |
+
min_generation_length,
|
551 |
+
num_beams,
|
552 |
+
length_penalty,
|
553 |
+
input_ids,
|
554 |
+
image_start_index_list=None,
|
555 |
+
image_nums=None,
|
556 |
+
bad_words_ids=None,
|
557 |
+
):
|
558 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
559 |
+
outputs = model.generate(
|
560 |
+
batch_images,
|
561 |
+
input_ids,
|
562 |
+
attention_mask=attention_mask,
|
563 |
+
max_new_tokens=max_generation_length,
|
564 |
+
min_length=min_generation_length,
|
565 |
+
num_beams=num_beams,
|
566 |
+
length_penalty=length_penalty,
|
567 |
+
image_start_index_list=image_start_index_list,
|
568 |
+
image_nums=image_nums,
|
569 |
+
bad_words_ids=bad_words_ids,
|
570 |
+
)
|
571 |
+
|
572 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
573 |
+
return outputs
|
574 |
+
|
575 |
+
|
576 |
+
def evaluate_coco_flickr(
|
577 |
+
model,
|
578 |
+
tokenizer,
|
579 |
+
image_processor,
|
580 |
+
batch_size,
|
581 |
+
image_dir_path,
|
582 |
+
annotations_json_path,
|
583 |
+
seed=42,
|
584 |
+
max_generation_length=20,
|
585 |
+
num_beams=1,
|
586 |
+
length_penalty=-2.0,
|
587 |
+
device=-1,
|
588 |
+
is_flickr=False,
|
589 |
+
vis_embed_size=None,
|
590 |
+
rank=0,
|
591 |
+
world_size=1,
|
592 |
+
id=0,
|
593 |
+
):
|
594 |
+
"""Evaluate a model on COCO dataset.
|
595 |
+
|
596 |
+
Args:
|
597 |
+
model (nn.Module): model to evaluate
|
598 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
599 |
+
image_processor : image processor for the model
|
600 |
+
batch_size (int): batch size
|
601 |
+
image_dir_path (str, optional): path to the directory containing the images.
|
602 |
+
annotations_json_path (str, optional): path to the json file containing the annotations.
|
603 |
+
seed (int, optional): seed for random number generator. Defaults to 42.
|
604 |
+
max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
|
605 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
606 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
607 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
|
608 |
+
query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
|
609 |
+
num_shots (int, optional): number of in-context samples to use. Defaults to 8.
|
610 |
+
device (int, optional): device to use. Defaults to -1.
|
611 |
+
num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
|
612 |
+
is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
|
613 |
+
|
614 |
+
Returns:
|
615 |
+
float: CIDEr score
|
616 |
+
|
617 |
+
"""
|
618 |
+
# eval_dataset = COCOFlickrDataset(
|
619 |
+
# image_dir_path=image_dir_path,
|
620 |
+
# annotations_path=annotations_json_path,
|
621 |
+
# is_flickr=is_flickr,
|
622 |
+
# )
|
623 |
+
coco_dataset = load_dataset("coco_caption")
|
624 |
+
eval_dataset = coco_dataset["test"]
|
625 |
+
|
626 |
+
|
627 |
+
model.eval().cuda()
|
628 |
+
predictions = defaultdict()
|
629 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
630 |
+
# if "peft" in lang_encoder_name:
|
631 |
+
# lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
632 |
+
try:
|
633 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
634 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
635 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
636 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
637 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
638 |
+
except:
|
639 |
+
pass
|
640 |
+
|
641 |
+
def get_prompt(sample):
|
642 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
643 |
+
|
644 |
+
tokenizer.padding_side = "left"
|
645 |
+
cnt = 0
|
646 |
+
if world_size > 1:
|
647 |
+
torch.distributed.barrier()
|
648 |
+
desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
|
649 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
650 |
+
tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
|
651 |
+
)):
|
652 |
+
if ii % world_size != rank:
|
653 |
+
continue
|
654 |
+
cnt += len(batch)
|
655 |
+
batch_images = prepare_batch_images(
|
656 |
+
batch=batch,
|
657 |
+
image_processor=image_processor,
|
658 |
+
).cuda()
|
659 |
+
batch_text = [get_prompt(s) for s in batch]
|
660 |
+
encodings = tokenizer(
|
661 |
+
batch_text,
|
662 |
+
padding="longest",
|
663 |
+
truncation=True,
|
664 |
+
return_tensors="pt",
|
665 |
+
max_length=2000,
|
666 |
+
)
|
667 |
+
input_ids = encodings["input_ids"].cuda()
|
668 |
+
attention_mask = encodings["attention_mask"].cuda()
|
669 |
+
skip_special_tokens = False
|
670 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
671 |
+
if rank == 0:
|
672 |
+
tqdm.write("use legacy model")
|
673 |
+
skip_special_tokens = True
|
674 |
+
for i in range(len(input_ids)):
|
675 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
676 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
677 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
678 |
+
input_ids[i, media_token_index] = pad_token_id
|
679 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
680 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
681 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
682 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
683 |
+
image_nums = [1] * len(input_ids)
|
684 |
+
if "llama" in lang_encoder_name:
|
685 |
+
attention_mask[input_ids == 0] = 0
|
686 |
+
outputs = get_outputs(
|
687 |
+
model=model,
|
688 |
+
batch_images=batch_images,
|
689 |
+
attention_mask=attention_mask,
|
690 |
+
max_generation_length=30,
|
691 |
+
min_generation_length=8,
|
692 |
+
num_beams=5,
|
693 |
+
length_penalty=0,
|
694 |
+
input_ids=input_ids,
|
695 |
+
image_start_index_list=image_start_index_list,
|
696 |
+
image_nums=image_nums,
|
697 |
+
)
|
698 |
+
new_predictions = [
|
699 |
+
postprocess_captioning_generation(out).replace('"', "")
|
700 |
+
for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
701 |
+
]
|
702 |
+
# if rank == 0:
|
703 |
+
# tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
|
704 |
+
|
705 |
+
for i, sample in enumerate(batch):
|
706 |
+
predictions[int(sample["image_id"])] = {
|
707 |
+
"caption": new_predictions[i],
|
708 |
+
}
|
709 |
+
results_path = (
|
710 |
+
f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
|
711 |
+
if is_flickr
|
712 |
+
else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
|
713 |
+
)
|
714 |
+
with open(results_path, "w") as f:
|
715 |
+
f.write(
|
716 |
+
json.dumps(
|
717 |
+
[
|
718 |
+
{"image_id": k, "caption": predictions[k]["caption"]}
|
719 |
+
for k in predictions
|
720 |
+
],
|
721 |
+
indent=2,
|
722 |
+
)
|
723 |
+
)
|
724 |
+
print("save to", results_path)
|
725 |
+
del predictions
|
726 |
+
time.sleep(10)
|
727 |
+
if world_size > 1:
|
728 |
+
torch.distributed.barrier()
|
729 |
+
if rank == 0:
|
730 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
731 |
+
predictions = []
|
732 |
+
for rank_i in range(world_size):
|
733 |
+
part_results_path = (
|
734 |
+
f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
735 |
+
if is_flickr
|
736 |
+
else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
737 |
+
)
|
738 |
+
print("load", part_results_path)
|
739 |
+
predictions.extend(json.load(open(part_results_path)))
|
740 |
+
os.remove(part_results_path)
|
741 |
+
print("num:", len(predictions))
|
742 |
+
results_path = (
|
743 |
+
f"flickrresults_{lang_encoder_name}.json"
|
744 |
+
if is_flickr
|
745 |
+
else f"cocoresults_{lang_encoder_name}.json"
|
746 |
+
)
|
747 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
748 |
+
|
749 |
+
metrics = compute_cider(
|
750 |
+
result_path=results_path,
|
751 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
|
752 |
+
)
|
753 |
+
os.makedirs("eval_results", exist_ok=True)
|
754 |
+
acc = metrics["CIDEr"]
|
755 |
+
with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
756 |
+
f.write(json.dumps(predictions, indent=2))
|
757 |
+
|
758 |
+
# delete the temporary file
|
759 |
+
os.remove(results_path)
|
760 |
+
else:
|
761 |
+
metrics = {}
|
762 |
+
metrics["CIDEr"] = 0.0
|
763 |
+
|
764 |
+
return metrics["CIDEr"]
|
765 |
+
|
766 |
+
|
767 |
+
def evaluate_vqa(
|
768 |
+
model,
|
769 |
+
tokenizer,
|
770 |
+
image_processor,
|
771 |
+
batch_size,
|
772 |
+
image_dir_path=None,
|
773 |
+
questions_json_path=None,
|
774 |
+
annotations_json_path=None,
|
775 |
+
vqa_dataset="vqa",
|
776 |
+
vis_embed_size=None,
|
777 |
+
rank=0,
|
778 |
+
world_size=1,
|
779 |
+
id=0,
|
780 |
+
):
|
781 |
+
"""
|
782 |
+
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
|
783 |
+
|
784 |
+
Args:
|
785 |
+
model (nn.Module): model to evaluate
|
786 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
787 |
+
image_processor : image processor for the model
|
788 |
+
batch_size (int): batch size
|
789 |
+
image_dir_path (str): path to image directory
|
790 |
+
questions_json_path (str): path to questions json file
|
791 |
+
annotations_json_path (str): path to annotations json file
|
792 |
+
seed (int, optional): random seed. Defaults to 42.
|
793 |
+
max_generation_length (int, optional): max generation length. Defaults to 5.
|
794 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
795 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
796 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
|
797 |
+
query_set_size (int, optional): size of the query set. Defaults to 2048.
|
798 |
+
num_shots (int, optional): number of shots to use. Defaults to 8.
|
799 |
+
device (int, optional): device to use. Defaults to -1 (cpu).
|
800 |
+
num_workers (int, optional): number of workers to use. Defaults to 4.
|
801 |
+
vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
|
802 |
+
Returns:
|
803 |
+
float: accuracy score
|
804 |
+
"""
|
805 |
+
if world_size > 1:
|
806 |
+
torch.distributed.barrier()
|
807 |
+
if vqa_dataset == "gqa":
|
808 |
+
eval_dataset = GQADataset()
|
809 |
+
else:
|
810 |
+
eval_dataset = VQADataset(
|
811 |
+
image_dir_path=image_dir_path,
|
812 |
+
question_path=questions_json_path,
|
813 |
+
annotations_path=annotations_json_path,
|
814 |
+
vqa_dataset=vqa_dataset,
|
815 |
+
)
|
816 |
+
postprocessor = OKVQAPostProcess()
|
817 |
+
try:
|
818 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
819 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
820 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
821 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
822 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
823 |
+
except:
|
824 |
+
pass
|
825 |
+
def get_prompt(sample):
|
826 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
|
827 |
+
# return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
828 |
+
|
829 |
+
model.eval().cuda()
|
830 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
831 |
+
if "peft" in lang_encoder_name:
|
832 |
+
lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
833 |
+
predictions = []
|
834 |
+
tokenizer.padding_side = "left"
|
835 |
+
if world_size > 1:
|
836 |
+
torch.distributed.barrier()
|
837 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
838 |
+
tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
|
839 |
+
)):
|
840 |
+
if ii % world_size != rank:
|
841 |
+
continue
|
842 |
+
batch_images = prepare_batch_images(
|
843 |
+
batch=batch,
|
844 |
+
image_processor=image_processor,
|
845 |
+
).cuda()
|
846 |
+
batch_text = [get_prompt(s) for s in batch]
|
847 |
+
encodings = tokenizer(
|
848 |
+
batch_text,
|
849 |
+
return_tensors="pt",
|
850 |
+
padding="longest",
|
851 |
+
truncation=True,
|
852 |
+
max_length=2000,
|
853 |
+
)
|
854 |
+
input_ids = encodings["input_ids"].cuda()
|
855 |
+
attention_mask = encodings["attention_mask"].cuda()
|
856 |
+
skip_special_tokens = True
|
857 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
858 |
+
if rank == 0:
|
859 |
+
tqdm.write("use legacy model")
|
860 |
+
for i in range(len(input_ids)):
|
861 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
862 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
863 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
864 |
+
input_ids[i, media_token_index] = pad_token_id
|
865 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
866 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
867 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
868 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
869 |
+
image_nums = [1] * len(input_ids)
|
870 |
+
if "llama" in lang_encoder_name:
|
871 |
+
attention_mask[input_ids == 0] = 0
|
872 |
+
outputs = get_outputs(
|
873 |
+
model=model,
|
874 |
+
batch_images=batch_images,
|
875 |
+
attention_mask=attention_mask,
|
876 |
+
max_generation_length=10,
|
877 |
+
min_generation_length=1,
|
878 |
+
num_beams=5,
|
879 |
+
length_penalty=0,
|
880 |
+
input_ids=input_ids,
|
881 |
+
image_start_index_list=image_start_index_list,
|
882 |
+
image_nums=image_nums,
|
883 |
+
)
|
884 |
+
# postprocess begin
|
885 |
+
new_predictions = [
|
886 |
+
out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
|
887 |
+
]
|
888 |
+
if vqa_dataset == "ok_vqa":
|
889 |
+
new_predictions = postprocessor._lemmatize(new_predictions)
|
890 |
+
if model.special:
|
891 |
+
for i in range(len(new_predictions)):
|
892 |
+
for answer, _ in Counter(batch[i]['answers']).most_common():
|
893 |
+
if answer in new_predictions[i]:
|
894 |
+
new_predictions[i] = answer
|
895 |
+
break
|
896 |
+
if "cant" in new_predictions[i] and "no" == answer:
|
897 |
+
new_predictions[i] = answer
|
898 |
+
break
|
899 |
+
if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
|
900 |
+
new_predictions[i] = answer
|
901 |
+
break
|
902 |
+
|
903 |
+
# if rank == 0:
|
904 |
+
# tqdm.write(f"{image_nums} {image_start_index_list}")
|
905 |
+
# for i in range(1):
|
906 |
+
# tqdm.write(f"ID: {batch[i]['question_id']} | gt QA: {batch[i]['question']} {Counter(batch[i]['answers']).most_common()}")
|
907 |
+
# tqdm.write("prompt: " + tokenizer.decode(input_ids[i]))
|
908 |
+
# tqdm.write("model output: " + new_predictions[i])
|
909 |
+
|
910 |
+
predictions.extend(
|
911 |
+
[
|
912 |
+
{"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
|
913 |
+
for p, sample in zip(new_predictions, batch)
|
914 |
+
]
|
915 |
+
)
|
916 |
+
with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
|
917 |
+
f.write(json.dumps(predictions))
|
918 |
+
print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
|
919 |
+
|
920 |
+
time.sleep(10)
|
921 |
+
if world_size > 1:
|
922 |
+
torch.distributed.barrier()
|
923 |
+
if rank == 0:
|
924 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
925 |
+
predictions = []
|
926 |
+
for rank_i in range(world_size):
|
927 |
+
print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
928 |
+
predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
|
929 |
+
os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
930 |
+
print("num:", len(predictions))
|
931 |
+
# save the predictions to a temporary file
|
932 |
+
random_uuid = str(uuid.uuid4())
|
933 |
+
with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
|
934 |
+
f.write(json.dumps(predictions, indent=4))
|
935 |
+
|
936 |
+
if vqa_dataset == "gqa":
|
937 |
+
acc = compute_gqa_accuracy(predictions)
|
938 |
+
else:
|
939 |
+
acc = compute_vqa_accuracy(
|
940 |
+
f"{vqa_dataset}results_{random_uuid}.json",
|
941 |
+
questions_json_path,
|
942 |
+
annotations_json_path,
|
943 |
+
vqa_dataset=vqa_dataset,
|
944 |
+
)
|
945 |
+
print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
|
946 |
+
os.makedirs("eval_results", exist_ok=True)
|
947 |
+
with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
948 |
+
f.write(json.dumps(predictions, indent=2))
|
949 |
+
|
950 |
+
# delete the temporary file
|
951 |
+
os.remove(f"{vqa_dataset}results_{random_uuid}.json")
|
952 |
+
else:
|
953 |
+
time.sleep(5)
|
954 |
+
acc = 0.0
|
955 |
+
if world_size > 1:
|
956 |
+
torch.distributed.barrier()
|
957 |
+
return acc
|
958 |
+
|
959 |
+
|
960 |
+
def evaluate_refcoco(
|
961 |
+
model,
|
962 |
+
tokenizer,
|
963 |
+
image_processor,
|
964 |
+
batch_size,
|
965 |
+
tsvfile,
|
966 |
+
max_generation_length=20,
|
967 |
+
num_beams=3,
|
968 |
+
length_penalty=-2.0,
|
969 |
+
device=-1,
|
970 |
+
vis_embed_size=None,
|
971 |
+
rank=0,
|
972 |
+
world_size=1,
|
973 |
+
id=0,
|
974 |
+
):
|
975 |
+
model.eval().cuda()
|
976 |
+
loc_token_ids = []
|
977 |
+
for i in range(1000):
|
978 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
979 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
980 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
981 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
982 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
983 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
984 |
+
all_ids = set(range(model.lang_encoder.lm_head.out_features))
|
985 |
+
bad_words_ids = list(all_ids - set(loc_token_ids))
|
986 |
+
bad_words_ids = [[b] for b in bad_words_ids]
|
987 |
+
min_loc_token_id = min(loc_token_ids)
|
988 |
+
max_loc_token_id = max(loc_token_ids)
|
989 |
+
total = 0
|
990 |
+
correct = 0
|
991 |
+
ious = []
|
992 |
+
if "refcocog" in tsvfile:
|
993 |
+
dataset_name = "refcocog"
|
994 |
+
elif "refcocoplus" in tsvfile:
|
995 |
+
dataset_name = "refcocoplus"
|
996 |
+
else:
|
997 |
+
dataset_name = "refcoco"
|
998 |
+
with open(tsvfile, "r") as f:
|
999 |
+
lines = f.readlines()
|
1000 |
+
pbar = tqdm(lines, disable=(rank != 0))
|
1001 |
+
for ii, line in enumerate(pbar):
|
1002 |
+
if ii % world_size != rank:
|
1003 |
+
continue
|
1004 |
+
total += 1
|
1005 |
+
line = line.rstrip()
|
1006 |
+
uniq_id, image_id, text, region_coord, image = line.split("\t")
|
1007 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
1008 |
+
gt_box = np.array(list(map(float, region_coord.split(","))))
|
1009 |
+
width = image.width
|
1010 |
+
height = image.height
|
1011 |
+
image = image.resize((224, 224))
|
1012 |
+
gt_box /= np.array([width, height, width, height])
|
1013 |
+
# extra = abs(height - width) // 2
|
1014 |
+
# if width > height:
|
1015 |
+
# gt_box += np.array([0, extra, 0, extra])
|
1016 |
+
# else:
|
1017 |
+
# gt_box += np.array([extra, 0, extra, 0])
|
1018 |
+
# size = max(width, height)
|
1019 |
+
# gt_box /= np.array([size, size, size, size])
|
1020 |
+
# image = expand2square(image, (255, 255, 255))
|
1021 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1022 |
+
if model.use_format_v2:
|
1023 |
+
prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text.rstrip('.')}<|#loc#|>"]
|
1024 |
+
else:
|
1025 |
+
prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}"]
|
1026 |
+
encodings = tokenizer(
|
1027 |
+
prompt,
|
1028 |
+
padding="longest",
|
1029 |
+
truncation=True,
|
1030 |
+
return_tensors="pt",
|
1031 |
+
max_length=2000,
|
1032 |
+
)
|
1033 |
+
input_ids = encodings["input_ids"]
|
1034 |
+
attention_mask = encodings["attention_mask"]
|
1035 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1036 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1037 |
+
image_nums = [1] * len(input_ids)
|
1038 |
+
outputs = get_outputs(
|
1039 |
+
model=model,
|
1040 |
+
batch_images=batch_images.cuda(),
|
1041 |
+
attention_mask=attention_mask.cuda(),
|
1042 |
+
max_generation_length=5,
|
1043 |
+
min_generation_length=4,
|
1044 |
+
num_beams=1,
|
1045 |
+
length_penalty=1.0,
|
1046 |
+
input_ids=input_ids.cuda(),
|
1047 |
+
bad_words_ids=bad_words_ids,
|
1048 |
+
image_start_index_list=image_start_index_list,
|
1049 |
+
image_nums=image_nums,
|
1050 |
+
)
|
1051 |
+
box = []
|
1052 |
+
for o in outputs[0]:
|
1053 |
+
if o >= min_loc_token_id and o <= max_loc_token_id:
|
1054 |
+
box.append(o.item() - min_loc_token_id)
|
1055 |
+
if len(box) == 4:
|
1056 |
+
break
|
1057 |
+
iou = 0
|
1058 |
+
if len(box) == 4:
|
1059 |
+
box = np.array(box).astype(float) / 1000
|
1060 |
+
iou = get_iou(box, gt_box)
|
1061 |
+
ious.append(iou)
|
1062 |
+
else:
|
1063 |
+
tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
|
1064 |
+
tqdm.write(f"prompt: {prompt}")
|
1065 |
+
if iou >= 0.5:
|
1066 |
+
correct += 1
|
1067 |
+
pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
|
1068 |
+
|
1069 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1070 |
+
f.write(json.dumps([total, correct]))
|
1071 |
+
if world_size > 1:
|
1072 |
+
torch.distributed.barrier()
|
1073 |
+
if rank == 0:
|
1074 |
+
total = 0
|
1075 |
+
correct = 0
|
1076 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1077 |
+
for rank_i in range(world_size):
|
1078 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1079 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1080 |
+
total += total_part
|
1081 |
+
correct += correct_part
|
1082 |
+
score = correct / total
|
1083 |
+
print("score:", score)
|
1084 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1085 |
+
pass
|
1086 |
+
else:
|
1087 |
+
score = 0.0
|
1088 |
+
if world_size > 1:
|
1089 |
+
torch.distributed.barrier()
|
1090 |
+
return score
|
1091 |
+
|
1092 |
+
|
1093 |
+
if __name__ == "__main__":
|
1094 |
+
main()
|
open_flamingo/open_flamingo/eval/evaluate2.py
ADDED
@@ -0,0 +1,1113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from math import ceil
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import uuid
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Callable
|
9 |
+
import time
|
10 |
+
|
11 |
+
import more_itertools
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
15 |
+
from eval_datasets import VQADataset, GQADataset
|
16 |
+
from tqdm import tqdm
|
17 |
+
from collections import Counter
|
18 |
+
|
19 |
+
from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
|
20 |
+
from open_flamingo.eval.classification import (
|
21 |
+
compute_per_sample_probs,
|
22 |
+
compute_per_sample_loss,
|
23 |
+
)
|
24 |
+
from open_flamingo.eval.imagenet_utils import (
|
25 |
+
openai_imagenet_classnames,
|
26 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL,
|
27 |
+
)
|
28 |
+
|
29 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
30 |
+
from PIL import Image
|
31 |
+
from io import BytesIO
|
32 |
+
import base64
|
33 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
34 |
+
import string
|
35 |
+
from lavis.datasets.builders import load_dataset
|
36 |
+
|
37 |
+
|
38 |
+
def get_iou(box1, box2):
|
39 |
+
# box1 and box2 should be in the format [x1, y1, x2, y2]
|
40 |
+
intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
|
41 |
+
max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
|
42 |
+
area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
43 |
+
area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
44 |
+
union = area_box1 + area_box2 - intersection
|
45 |
+
iou = intersection / union if union > 0 else 0
|
46 |
+
return iou
|
47 |
+
|
48 |
+
def expand2square(pil_img, background_color):
|
49 |
+
width, height = pil_img.size
|
50 |
+
if width == height:
|
51 |
+
return pil_img
|
52 |
+
elif width > height:
|
53 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
54 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
55 |
+
return result
|
56 |
+
else:
|
57 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
58 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
59 |
+
return result
|
60 |
+
|
61 |
+
parser = argparse.ArgumentParser()
|
62 |
+
parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
|
63 |
+
parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
|
64 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
65 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
66 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
67 |
+
parser.add_argument(
|
68 |
+
"--results_file", type=str, default=None, help="JSON file to save results"
|
69 |
+
)
|
70 |
+
|
71 |
+
# Trial arguments
|
72 |
+
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
|
73 |
+
parser.add_argument(
|
74 |
+
"--num_trials",
|
75 |
+
type=int,
|
76 |
+
default=1,
|
77 |
+
help="Number of trials to run for each shot using different demonstrations",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--trial_seeds",
|
81 |
+
nargs="+",
|
82 |
+
default=[0],
|
83 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
|
87 |
+
)
|
88 |
+
|
89 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
90 |
+
|
91 |
+
# Per-dataset evaluation flags
|
92 |
+
parser.add_argument(
|
93 |
+
"--eval_coco",
|
94 |
+
action="store_true",
|
95 |
+
default=False,
|
96 |
+
help="Whether to evaluate on COCO.",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--eval_vqav2",
|
100 |
+
action="store_true",
|
101 |
+
default=False,
|
102 |
+
help="Whether to evaluate on VQAV2.",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--eval_ok_vqa",
|
106 |
+
action="store_true",
|
107 |
+
default=False,
|
108 |
+
help="Whether to evaluate on OK-VQA.",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--eval_imagenet",
|
112 |
+
action="store_true",
|
113 |
+
default=False,
|
114 |
+
help="Whether to evaluate on ImageNet.",
|
115 |
+
)
|
116 |
+
|
117 |
+
parser.add_argument(
|
118 |
+
"--eval_flickr30",
|
119 |
+
action="store_true",
|
120 |
+
default=False,
|
121 |
+
help="Whether to evaluate on Flickr30.",
|
122 |
+
)
|
123 |
+
|
124 |
+
parser.add_argument(
|
125 |
+
"--eval_refcoco",
|
126 |
+
action="store_true",
|
127 |
+
default=False,
|
128 |
+
help="Whether to evaluate on RefCOCO.",
|
129 |
+
)
|
130 |
+
|
131 |
+
# Dataset arguments
|
132 |
+
|
133 |
+
## Flickr30 Dataset
|
134 |
+
parser.add_argument(
|
135 |
+
"--flickr_image_dir_path",
|
136 |
+
type=str,
|
137 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
138 |
+
default=None,
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--flickr_annotations_json_path",
|
142 |
+
type=str,
|
143 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
144 |
+
default=None,
|
145 |
+
)
|
146 |
+
|
147 |
+
## COCO Dataset
|
148 |
+
parser.add_argument(
|
149 |
+
"--coco_image_dir_path",
|
150 |
+
type=str,
|
151 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
152 |
+
default=None,
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--coco_annotations_json_path",
|
156 |
+
type=str,
|
157 |
+
default=None,
|
158 |
+
)
|
159 |
+
|
160 |
+
## VQAV2 Dataset
|
161 |
+
parser.add_argument(
|
162 |
+
"--vqav2_image_dir_path",
|
163 |
+
type=str,
|
164 |
+
default=None,
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--vqav2_questions_json_path",
|
168 |
+
type=str,
|
169 |
+
default=None,
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--vqav2_annotations_json_path",
|
173 |
+
type=str,
|
174 |
+
default=None,
|
175 |
+
)
|
176 |
+
|
177 |
+
## OK-VQA Dataset
|
178 |
+
parser.add_argument(
|
179 |
+
"--ok_vqa_image_dir_path",
|
180 |
+
type=str,
|
181 |
+
help="Path to the vqav2/train2014 directory.",
|
182 |
+
default=None,
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--ok_vqa_questions_json_path",
|
186 |
+
type=str,
|
187 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
188 |
+
default=None,
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--ok_vqa_annotations_json_path",
|
192 |
+
type=str,
|
193 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
194 |
+
default=None,
|
195 |
+
)
|
196 |
+
|
197 |
+
## Imagenet dataset
|
198 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
199 |
+
|
200 |
+
## RefCOCO dataset
|
201 |
+
parser.add_argument("--refcoco_tsvfile", type=str, default=None)
|
202 |
+
|
203 |
+
parser.add_argument(
|
204 |
+
"--add_visual_grounding",
|
205 |
+
default=False,
|
206 |
+
action="store_true",
|
207 |
+
)
|
208 |
+
parser.add_argument(
|
209 |
+
"--location_token_num",
|
210 |
+
default=1000,
|
211 |
+
type=int,
|
212 |
+
)
|
213 |
+
# distributed training
|
214 |
+
parser.add_argument(
|
215 |
+
"--dist-url",
|
216 |
+
default="env://",
|
217 |
+
type=str,
|
218 |
+
help="url used to set up distributed training",
|
219 |
+
)
|
220 |
+
parser.add_argument(
|
221 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"--horovod",
|
225 |
+
default=False,
|
226 |
+
action="store_true",
|
227 |
+
help="Use horovod for distributed training.",
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--no-set-device-rank",
|
231 |
+
default=False,
|
232 |
+
action="store_true",
|
233 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
234 |
+
)
|
235 |
+
parser.add_argument(
|
236 |
+
"--dist",
|
237 |
+
default=False,
|
238 |
+
action="store_true",
|
239 |
+
)
|
240 |
+
parser.add_argument(
|
241 |
+
"--lora",
|
242 |
+
default=False,
|
243 |
+
action="store_true",
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--lora_r",
|
247 |
+
default=16,
|
248 |
+
type=int,
|
249 |
+
required=False,
|
250 |
+
)
|
251 |
+
parser.add_argument(
|
252 |
+
"--legacy",
|
253 |
+
default=False,
|
254 |
+
action="store_true",
|
255 |
+
)
|
256 |
+
parser.add_argument(
|
257 |
+
"--special",
|
258 |
+
default=False,
|
259 |
+
action="store_true",
|
260 |
+
)
|
261 |
+
parser.add_argument(
|
262 |
+
"--id",
|
263 |
+
default=0,
|
264 |
+
type=int,
|
265 |
+
required=False,
|
266 |
+
)
|
267 |
+
|
268 |
+
parser.add_argument(
|
269 |
+
"--eval_gqa",
|
270 |
+
default=False,
|
271 |
+
action="store_true",
|
272 |
+
)
|
273 |
+
parser.add_argument(
|
274 |
+
"--use_sam",
|
275 |
+
default=None,
|
276 |
+
type=str,
|
277 |
+
required=False,
|
278 |
+
)
|
279 |
+
parser.add_argument(
|
280 |
+
"--add_visual_token",
|
281 |
+
default=False,
|
282 |
+
action="store_true",
|
283 |
+
)
|
284 |
+
parser.add_argument(
|
285 |
+
"--use_format_v2",
|
286 |
+
default=False,
|
287 |
+
action="store_true",
|
288 |
+
)
|
289 |
+
|
290 |
+
|
291 |
+
class OKVQAPostProcess():
|
292 |
+
def __init__(self):
|
293 |
+
self._lemmatizer = None
|
294 |
+
|
295 |
+
def _lemmatize(self, answers):
|
296 |
+
def apply(answer):
|
297 |
+
doc = self.lemmatizer(answer)
|
298 |
+
|
299 |
+
words = []
|
300 |
+
for token in doc:
|
301 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
302 |
+
words.append(token.lemma_)
|
303 |
+
else:
|
304 |
+
words.append(token.text)
|
305 |
+
answer = " ".join(words)
|
306 |
+
|
307 |
+
return answer
|
308 |
+
|
309 |
+
return [apply(answer) for answer in answers]
|
310 |
+
|
311 |
+
@property
|
312 |
+
def lemmatizer(self):
|
313 |
+
if self._lemmatizer is None:
|
314 |
+
try:
|
315 |
+
import spacy
|
316 |
+
|
317 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
318 |
+
except ImportError:
|
319 |
+
logging.error(
|
320 |
+
"""
|
321 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
322 |
+
python -m spacy download en_core_web_sm
|
323 |
+
OR
|
324 |
+
import spacy.cli
|
325 |
+
spacy.cli.download("en_core_web_sm")
|
326 |
+
"""
|
327 |
+
)
|
328 |
+
exit(1)
|
329 |
+
|
330 |
+
return self._lemmatizer
|
331 |
+
|
332 |
+
|
333 |
+
def main():
|
334 |
+
args = parser.parse_args()
|
335 |
+
if args.dist:
|
336 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
337 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
338 |
+
device_id = init_distributed_device(args)
|
339 |
+
else:
|
340 |
+
args.rank = 0
|
341 |
+
args.world_size = 1
|
342 |
+
print(f"rank: {args.rank} world_size: {args.world_size}")
|
343 |
+
|
344 |
+
if "sam" in args.checkpoint_path:
|
345 |
+
args.use_sam = "vit_l"
|
346 |
+
if (
|
347 |
+
"ground" in args.checkpoint_path or
|
348 |
+
"all" in args.checkpoint_path or
|
349 |
+
"sam" in args.checkpoint_path
|
350 |
+
):
|
351 |
+
args.add_visual_grounding = True
|
352 |
+
|
353 |
+
if "visual" in args.checkpoint_path:
|
354 |
+
args.add_visual_token = True
|
355 |
+
if "formatV2" in args.checkpoint_path:
|
356 |
+
args.use_format_v2 = True
|
357 |
+
|
358 |
+
# load model
|
359 |
+
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
|
360 |
+
args.vision_encoder_path,
|
361 |
+
args.vision_encoder_pretrained,
|
362 |
+
args.lm_path,
|
363 |
+
args.lm_tokenizer_path,
|
364 |
+
add_visual_grounding=args.add_visual_grounding,
|
365 |
+
location_token_num=args.location_token_num,
|
366 |
+
lora=args.lora,
|
367 |
+
lora_r=args.lora_r,
|
368 |
+
use_sam=args.use_sam,
|
369 |
+
add_visual_token=args.add_visual_token,
|
370 |
+
use_format_v2=args.use_format_v2,
|
371 |
+
)
|
372 |
+
flamingo.use_format_v2 = args.use_format_v2
|
373 |
+
if args.special:
|
374 |
+
flamingo.special = True
|
375 |
+
else:
|
376 |
+
flamingo.special = False
|
377 |
+
if args.legacy:
|
378 |
+
flamingo.legacy = True
|
379 |
+
print("use legacy evaluation")
|
380 |
+
flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
|
381 |
+
flamingo.expr_name = args.checkpoint_path.split("/")[-2]
|
382 |
+
if args.rank == 0:
|
383 |
+
print("legacy", True if hasattr(flamingo, "legacy") else False)
|
384 |
+
print("step:", flamingo.step_num)
|
385 |
+
print("expr:", flamingo.expr_name)
|
386 |
+
print("use format v2:", flamingo.use_format_v2)
|
387 |
+
print(args)
|
388 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
389 |
+
model_state_dict = {}
|
390 |
+
for key in checkpoint["model_state_dict"].keys():
|
391 |
+
model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
|
392 |
+
if "vision_encoder.logit_scale"in model_state_dict:
|
393 |
+
# previous checkpoint has some unnecessary weights
|
394 |
+
del model_state_dict["vision_encoder.logit_scale"]
|
395 |
+
del model_state_dict["vision_encoder.visual.proj"]
|
396 |
+
del model_state_dict["vision_encoder.visual.ln_post.weight"]
|
397 |
+
del model_state_dict["vision_encoder.visual.ln_post.bias"]
|
398 |
+
flamingo.load_state_dict(model_state_dict, strict=True)
|
399 |
+
results = defaultdict(list)
|
400 |
+
if args.eval_coco:
|
401 |
+
print("Evaluating on COCO...")
|
402 |
+
for shot in args.shots:
|
403 |
+
scores = []
|
404 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
405 |
+
cider_score = evaluate_coco_flickr(
|
406 |
+
model=flamingo,
|
407 |
+
tokenizer=tokenizer,
|
408 |
+
image_processor=image_processor,
|
409 |
+
batch_size=args.batch_size,
|
410 |
+
image_dir_path=args.coco_image_dir_path,
|
411 |
+
annotations_json_path=args.coco_annotations_json_path,
|
412 |
+
device=args.device,
|
413 |
+
seed=seed,
|
414 |
+
vis_embed_size=vis_embed_size,
|
415 |
+
rank=args.rank,
|
416 |
+
world_size=args.world_size,
|
417 |
+
id=args.id,
|
418 |
+
)
|
419 |
+
print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
|
420 |
+
scores.append(cider_score)
|
421 |
+
print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
|
422 |
+
results["coco"].append(
|
423 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
424 |
+
)
|
425 |
+
|
426 |
+
if args.eval_ok_vqa:
|
427 |
+
print("Evaluating on OK-VQA...")
|
428 |
+
for shot in args.shots:
|
429 |
+
scores = []
|
430 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
431 |
+
ok_vqa_score = evaluate_vqa(
|
432 |
+
model=flamingo,
|
433 |
+
tokenizer=tokenizer,
|
434 |
+
image_processor=image_processor,
|
435 |
+
batch_size=args.batch_size,
|
436 |
+
image_dir_path=args.ok_vqa_image_dir_path,
|
437 |
+
questions_json_path=args.ok_vqa_questions_json_path,
|
438 |
+
annotations_json_path=args.ok_vqa_annotations_json_path,
|
439 |
+
vqa_dataset="ok_vqa",
|
440 |
+
vis_embed_size=vis_embed_size,
|
441 |
+
rank=args.rank,
|
442 |
+
world_size=args.world_size,
|
443 |
+
id=args.id,
|
444 |
+
)
|
445 |
+
results["ok_vqa"].append(
|
446 |
+
{"shots": shot, "score": ok_vqa_score}
|
447 |
+
)
|
448 |
+
|
449 |
+
if args.eval_vqav2:
|
450 |
+
print("Evaluating on VQAv2...")
|
451 |
+
for shot in args.shots:
|
452 |
+
scores = []
|
453 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
454 |
+
vqa_score = evaluate_vqa(
|
455 |
+
model=flamingo,
|
456 |
+
tokenizer=tokenizer,
|
457 |
+
image_processor=image_processor,
|
458 |
+
batch_size=args.batch_size,
|
459 |
+
image_dir_path=args.vqav2_image_dir_path,
|
460 |
+
questions_json_path=args.vqav2_questions_json_path,
|
461 |
+
annotations_json_path=args.vqav2_annotations_json_path,
|
462 |
+
vqa_dataset="vqa",
|
463 |
+
vis_embed_size=vis_embed_size,
|
464 |
+
rank=args.rank,
|
465 |
+
world_size=args.world_size,
|
466 |
+
id=args.id,
|
467 |
+
)
|
468 |
+
results["vqav2"].append(
|
469 |
+
{"shots": shot, "score": vqa_score}
|
470 |
+
)
|
471 |
+
|
472 |
+
if args.eval_gqa:
|
473 |
+
print("Evaluating on GQA...")
|
474 |
+
for shot in args.shots:
|
475 |
+
scores = []
|
476 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
477 |
+
vqa_score = evaluate_vqa(
|
478 |
+
model=flamingo,
|
479 |
+
tokenizer=tokenizer,
|
480 |
+
image_processor=image_processor,
|
481 |
+
batch_size=args.batch_size,
|
482 |
+
vqa_dataset="gqa",
|
483 |
+
vis_embed_size=vis_embed_size,
|
484 |
+
rank=args.rank,
|
485 |
+
world_size=args.world_size,
|
486 |
+
id=args.id,
|
487 |
+
)
|
488 |
+
results["gqa"].append(
|
489 |
+
{"shots": shot, "score": vqa_score}
|
490 |
+
)
|
491 |
+
|
492 |
+
if args.eval_imagenet:
|
493 |
+
print("Evaluating on ImageNet...")
|
494 |
+
for shot in args.shots:
|
495 |
+
scores = []
|
496 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
497 |
+
imagenet_score = evaluate_imagenet(
|
498 |
+
model=flamingo,
|
499 |
+
tokenizer=tokenizer,
|
500 |
+
image_processor=image_processor,
|
501 |
+
batch_size=args.batch_size,
|
502 |
+
num_samples=args.num_samples,
|
503 |
+
num_shots=shot,
|
504 |
+
device=args.device,
|
505 |
+
seed=seed,
|
506 |
+
imagenet_root=args.imagenet_root,
|
507 |
+
)
|
508 |
+
print(
|
509 |
+
f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
|
510 |
+
)
|
511 |
+
scores.append(imagenet_score)
|
512 |
+
print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
|
513 |
+
results["imagenet"].append(
|
514 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
515 |
+
)
|
516 |
+
|
517 |
+
if args.eval_refcoco:
|
518 |
+
print("Evaluating on RefCOCO...")
|
519 |
+
refcoco_score = evaluate_refcoco(
|
520 |
+
model=flamingo,
|
521 |
+
tokenizer=tokenizer,
|
522 |
+
image_processor=image_processor,
|
523 |
+
batch_size=args.batch_size,
|
524 |
+
device=args.device,
|
525 |
+
tsvfile=args.refcoco_tsvfile,
|
526 |
+
vis_embed_size=vis_embed_size,
|
527 |
+
rank=args.rank,
|
528 |
+
world_size=args.world_size,
|
529 |
+
id=args.id,
|
530 |
+
)
|
531 |
+
results["refcoco"].append(
|
532 |
+
{"score": refcoco_score}
|
533 |
+
)
|
534 |
+
|
535 |
+
def prepare_batch_images(batch, image_processor):
|
536 |
+
batch_images = None
|
537 |
+
for b in batch:
|
538 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
539 |
+
if batch_images is None:
|
540 |
+
batch_images = b_image
|
541 |
+
else:
|
542 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
543 |
+
return batch_images
|
544 |
+
|
545 |
+
def get_outputs(
|
546 |
+
model,
|
547 |
+
batch_images,
|
548 |
+
attention_mask,
|
549 |
+
max_generation_length,
|
550 |
+
min_generation_length,
|
551 |
+
num_beams,
|
552 |
+
length_penalty,
|
553 |
+
input_ids,
|
554 |
+
image_start_index_list=None,
|
555 |
+
image_nums=None,
|
556 |
+
bad_words_ids=None,
|
557 |
+
add_grounding_to_prompt=False,
|
558 |
+
):
|
559 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
560 |
+
outputs = model.generate(
|
561 |
+
batch_images,
|
562 |
+
input_ids,
|
563 |
+
attention_mask=attention_mask,
|
564 |
+
max_new_tokens=max_generation_length,
|
565 |
+
min_length=min_generation_length,
|
566 |
+
num_beams=num_beams,
|
567 |
+
length_penalty=length_penalty,
|
568 |
+
image_start_index_list=image_start_index_list,
|
569 |
+
image_nums=image_nums,
|
570 |
+
bad_words_ids=bad_words_ids,
|
571 |
+
add_grounding_to_prompt=add_grounding_to_prompt,
|
572 |
+
)
|
573 |
+
|
574 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
575 |
+
return outputs
|
576 |
+
|
577 |
+
|
578 |
+
def evaluate_coco_flickr(
|
579 |
+
model,
|
580 |
+
tokenizer,
|
581 |
+
image_processor,
|
582 |
+
batch_size,
|
583 |
+
image_dir_path,
|
584 |
+
annotations_json_path,
|
585 |
+
seed=42,
|
586 |
+
max_generation_length=20,
|
587 |
+
num_beams=1,
|
588 |
+
length_penalty=-2.0,
|
589 |
+
device=-1,
|
590 |
+
is_flickr=False,
|
591 |
+
vis_embed_size=None,
|
592 |
+
rank=0,
|
593 |
+
world_size=1,
|
594 |
+
id=0,
|
595 |
+
):
|
596 |
+
"""Evaluate a model on COCO dataset.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
model (nn.Module): model to evaluate
|
600 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
601 |
+
image_processor : image processor for the model
|
602 |
+
batch_size (int): batch size
|
603 |
+
image_dir_path (str, optional): path to the directory containing the images.
|
604 |
+
annotations_json_path (str, optional): path to the json file containing the annotations.
|
605 |
+
seed (int, optional): seed for random number generator. Defaults to 42.
|
606 |
+
max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
|
607 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
608 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
609 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
|
610 |
+
query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
|
611 |
+
num_shots (int, optional): number of in-context samples to use. Defaults to 8.
|
612 |
+
device (int, optional): device to use. Defaults to -1.
|
613 |
+
num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
|
614 |
+
is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
|
615 |
+
|
616 |
+
Returns:
|
617 |
+
float: CIDEr score
|
618 |
+
|
619 |
+
"""
|
620 |
+
# eval_dataset = COCOFlickrDataset(
|
621 |
+
# image_dir_path=image_dir_path,
|
622 |
+
# annotations_path=annotations_json_path,
|
623 |
+
# is_flickr=is_flickr,
|
624 |
+
# )
|
625 |
+
coco_dataset = load_dataset("coco_caption")
|
626 |
+
eval_dataset = coco_dataset["test"]
|
627 |
+
|
628 |
+
|
629 |
+
model.eval().cuda()
|
630 |
+
predictions = defaultdict()
|
631 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
632 |
+
# if "peft" in lang_encoder_name:
|
633 |
+
# lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
634 |
+
try:
|
635 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
636 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
637 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
638 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
639 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
640 |
+
except:
|
641 |
+
pass
|
642 |
+
|
643 |
+
def get_prompt(sample):
|
644 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
645 |
+
|
646 |
+
tokenizer.padding_side = "left"
|
647 |
+
cnt = 0
|
648 |
+
if world_size > 1:
|
649 |
+
torch.distributed.barrier()
|
650 |
+
desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
|
651 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
652 |
+
tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
|
653 |
+
)):
|
654 |
+
if ii % world_size != rank:
|
655 |
+
continue
|
656 |
+
cnt += len(batch)
|
657 |
+
batch_images = prepare_batch_images(
|
658 |
+
batch=batch,
|
659 |
+
image_processor=image_processor,
|
660 |
+
).cuda()
|
661 |
+
batch_text = [get_prompt(s) for s in batch]
|
662 |
+
encodings = tokenizer(
|
663 |
+
batch_text,
|
664 |
+
padding="longest",
|
665 |
+
truncation=True,
|
666 |
+
return_tensors="pt",
|
667 |
+
max_length=2000,
|
668 |
+
)
|
669 |
+
input_ids = encodings["input_ids"].cuda()
|
670 |
+
attention_mask = encodings["attention_mask"].cuda()
|
671 |
+
skip_special_tokens = False
|
672 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
673 |
+
if rank == 0:
|
674 |
+
tqdm.write("use legacy model")
|
675 |
+
skip_special_tokens = True
|
676 |
+
for i in range(len(input_ids)):
|
677 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
678 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
679 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
680 |
+
input_ids[i, media_token_index] = pad_token_id
|
681 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
682 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
683 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
684 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
685 |
+
image_nums = [1] * len(input_ids)
|
686 |
+
if "llama" in lang_encoder_name:
|
687 |
+
attention_mask[input_ids == 0] = 0
|
688 |
+
outputs = get_outputs(
|
689 |
+
model=model,
|
690 |
+
batch_images=batch_images,
|
691 |
+
attention_mask=attention_mask,
|
692 |
+
max_generation_length=30,
|
693 |
+
min_generation_length=8,
|
694 |
+
num_beams=5,
|
695 |
+
length_penalty=0,
|
696 |
+
input_ids=input_ids,
|
697 |
+
image_start_index_list=image_start_index_list,
|
698 |
+
image_nums=image_nums,
|
699 |
+
)
|
700 |
+
new_predictions = [
|
701 |
+
postprocess_captioning_generation(out).replace('"', "")
|
702 |
+
for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
703 |
+
]
|
704 |
+
# if rank == 0:
|
705 |
+
# tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
|
706 |
+
|
707 |
+
for i, sample in enumerate(batch):
|
708 |
+
predictions[int(sample["image_id"])] = {
|
709 |
+
"caption": new_predictions[i],
|
710 |
+
}
|
711 |
+
results_path = (
|
712 |
+
f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
|
713 |
+
if is_flickr
|
714 |
+
else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
|
715 |
+
)
|
716 |
+
with open(results_path, "w") as f:
|
717 |
+
f.write(
|
718 |
+
json.dumps(
|
719 |
+
[
|
720 |
+
{"image_id": k, "caption": predictions[k]["caption"]}
|
721 |
+
for k in predictions
|
722 |
+
],
|
723 |
+
indent=2,
|
724 |
+
)
|
725 |
+
)
|
726 |
+
print("save to", results_path)
|
727 |
+
del predictions
|
728 |
+
time.sleep(10)
|
729 |
+
if world_size > 1:
|
730 |
+
torch.distributed.barrier()
|
731 |
+
if rank == 0:
|
732 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
733 |
+
predictions = []
|
734 |
+
for rank_i in range(world_size):
|
735 |
+
part_results_path = (
|
736 |
+
f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
737 |
+
if is_flickr
|
738 |
+
else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
739 |
+
)
|
740 |
+
print("load", part_results_path)
|
741 |
+
predictions.extend(json.load(open(part_results_path)))
|
742 |
+
os.remove(part_results_path)
|
743 |
+
print("num:", len(predictions))
|
744 |
+
results_path = (
|
745 |
+
f"flickrresults_{lang_encoder_name}.json"
|
746 |
+
if is_flickr
|
747 |
+
else f"cocoresults_{lang_encoder_name}.json"
|
748 |
+
)
|
749 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
750 |
+
|
751 |
+
metrics = compute_cider(
|
752 |
+
result_path=results_path,
|
753 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
|
754 |
+
)
|
755 |
+
os.makedirs("eval_results", exist_ok=True)
|
756 |
+
acc = metrics["CIDEr"]
|
757 |
+
with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
758 |
+
f.write(json.dumps(predictions, indent=2))
|
759 |
+
|
760 |
+
# delete the temporary file
|
761 |
+
os.remove(results_path)
|
762 |
+
else:
|
763 |
+
metrics = {}
|
764 |
+
metrics["CIDEr"] = 0.0
|
765 |
+
|
766 |
+
return metrics["CIDEr"]
|
767 |
+
|
768 |
+
|
769 |
+
def evaluate_vqa(
|
770 |
+
model,
|
771 |
+
tokenizer,
|
772 |
+
image_processor,
|
773 |
+
batch_size,
|
774 |
+
image_dir_path=None,
|
775 |
+
questions_json_path=None,
|
776 |
+
annotations_json_path=None,
|
777 |
+
vqa_dataset="vqa",
|
778 |
+
vis_embed_size=None,
|
779 |
+
rank=0,
|
780 |
+
world_size=1,
|
781 |
+
id=0,
|
782 |
+
):
|
783 |
+
"""
|
784 |
+
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
|
785 |
+
|
786 |
+
Args:
|
787 |
+
model (nn.Module): model to evaluate
|
788 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
789 |
+
image_processor : image processor for the model
|
790 |
+
batch_size (int): batch size
|
791 |
+
image_dir_path (str): path to image directory
|
792 |
+
questions_json_path (str): path to questions json file
|
793 |
+
annotations_json_path (str): path to annotations json file
|
794 |
+
seed (int, optional): random seed. Defaults to 42.
|
795 |
+
max_generation_length (int, optional): max generation length. Defaults to 5.
|
796 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
797 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
798 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
|
799 |
+
query_set_size (int, optional): size of the query set. Defaults to 2048.
|
800 |
+
num_shots (int, optional): number of shots to use. Defaults to 8.
|
801 |
+
device (int, optional): device to use. Defaults to -1 (cpu).
|
802 |
+
num_workers (int, optional): number of workers to use. Defaults to 4.
|
803 |
+
vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
|
804 |
+
Returns:
|
805 |
+
float: accuracy score
|
806 |
+
"""
|
807 |
+
if world_size > 1:
|
808 |
+
torch.distributed.barrier()
|
809 |
+
if vqa_dataset == "gqa":
|
810 |
+
eval_dataset = GQADataset()
|
811 |
+
else:
|
812 |
+
eval_dataset = VQADataset(
|
813 |
+
image_dir_path=image_dir_path,
|
814 |
+
question_path=questions_json_path,
|
815 |
+
annotations_path=annotations_json_path,
|
816 |
+
vqa_dataset=vqa_dataset,
|
817 |
+
)
|
818 |
+
postprocessor = OKVQAPostProcess()
|
819 |
+
try:
|
820 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
821 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
822 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
823 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
824 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
825 |
+
except:
|
826 |
+
pass
|
827 |
+
def get_prompt(sample):
|
828 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
|
829 |
+
# return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
830 |
+
|
831 |
+
model.eval().cuda()
|
832 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
833 |
+
if "peft" in lang_encoder_name:
|
834 |
+
lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
835 |
+
predictions = []
|
836 |
+
tokenizer.padding_side = "left"
|
837 |
+
if world_size > 1:
|
838 |
+
torch.distributed.barrier()
|
839 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
840 |
+
tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
|
841 |
+
)):
|
842 |
+
|
843 |
+
# 393372004
|
844 |
+
# 262227001
|
845 |
+
# 262161007
|
846 |
+
if int(batch[0]['question_id']) != 393372004:
|
847 |
+
continue
|
848 |
+
|
849 |
+
print(batch[0]["question"])
|
850 |
+
batch[0]["question"] = "Is the girl on the left smiling?"
|
851 |
+
|
852 |
+
if ii % world_size != rank:
|
853 |
+
continue
|
854 |
+
batch_images = prepare_batch_images(
|
855 |
+
batch=batch,
|
856 |
+
image_processor=image_processor,
|
857 |
+
).cuda()
|
858 |
+
batch_text = [get_prompt(s) for s in batch]
|
859 |
+
encodings = tokenizer(
|
860 |
+
batch_text,
|
861 |
+
return_tensors="pt",
|
862 |
+
padding="longest",
|
863 |
+
truncation=True,
|
864 |
+
max_length=2000,
|
865 |
+
)
|
866 |
+
input_ids = encodings["input_ids"].cuda()
|
867 |
+
attention_mask = encodings["attention_mask"].cuda()
|
868 |
+
skip_special_tokens = True
|
869 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
870 |
+
if rank == 0:
|
871 |
+
tqdm.write("use legacy model")
|
872 |
+
for i in range(len(input_ids)):
|
873 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
874 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
875 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
876 |
+
input_ids[i, media_token_index] = pad_token_id
|
877 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
878 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
879 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
880 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
881 |
+
image_nums = [1] * len(input_ids)
|
882 |
+
if "llama" in lang_encoder_name:
|
883 |
+
attention_mask[input_ids == 0] = 0
|
884 |
+
|
885 |
+
for i in range(262, input_ids.shape[-1]-2):
|
886 |
+
print("-"*80)
|
887 |
+
print(tokenizer.batch_decode(input_ids[..., 258:i]))
|
888 |
+
outputs = get_outputs(
|
889 |
+
model=model,
|
890 |
+
batch_images=batch_images,
|
891 |
+
attention_mask=attention_mask[..., :i],
|
892 |
+
max_generation_length=1,
|
893 |
+
min_generation_length=1,
|
894 |
+
num_beams=1,
|
895 |
+
length_penalty=0,
|
896 |
+
input_ids=input_ids[..., :i],
|
897 |
+
image_start_index_list=image_start_index_list,
|
898 |
+
image_nums=image_nums,
|
899 |
+
add_grounding_to_prompt=True,
|
900 |
+
)
|
901 |
+
batch[0]["image"].save(f"{batch[0]['question_id']}.jpg")
|
902 |
+
exit()
|
903 |
+
# postprocess begin
|
904 |
+
new_predictions = [
|
905 |
+
out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
|
906 |
+
]
|
907 |
+
if vqa_dataset == "ok_vqa":
|
908 |
+
new_predictions = postprocessor._lemmatize(new_predictions)
|
909 |
+
if model.special:
|
910 |
+
for i in range(len(new_predictions)):
|
911 |
+
for answer, _ in Counter(batch[i]['answers']).most_common():
|
912 |
+
if answer in new_predictions[i]:
|
913 |
+
new_predictions[i] = answer
|
914 |
+
break
|
915 |
+
if "cant" in new_predictions[i] and "no" == answer:
|
916 |
+
new_predictions[i] = answer
|
917 |
+
break
|
918 |
+
if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
|
919 |
+
new_predictions[i] = answer
|
920 |
+
break
|
921 |
+
|
922 |
+
# if rank == 0:
|
923 |
+
# tqdm.write(f"{image_nums} {image_start_index_list}")
|
924 |
+
# for i in range(1):
|
925 |
+
# tqdm.write(f"ID: {batch[i]['question_id']} | gt QA: {batch[i]['question']} {Counter(batch[i]['answers']).most_common()}")
|
926 |
+
# tqdm.write("prompt: " + tokenizer.decode(input_ids[i]))
|
927 |
+
# tqdm.write("model output: " + new_predictions[i])
|
928 |
+
|
929 |
+
predictions.extend(
|
930 |
+
[
|
931 |
+
{"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
|
932 |
+
for p, sample in zip(new_predictions, batch)
|
933 |
+
]
|
934 |
+
)
|
935 |
+
with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
|
936 |
+
f.write(json.dumps(predictions))
|
937 |
+
print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
|
938 |
+
|
939 |
+
time.sleep(10)
|
940 |
+
if world_size > 1:
|
941 |
+
torch.distributed.barrier()
|
942 |
+
if rank == 0:
|
943 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
944 |
+
predictions = []
|
945 |
+
for rank_i in range(world_size):
|
946 |
+
print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
947 |
+
predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
|
948 |
+
os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
949 |
+
print("num:", len(predictions))
|
950 |
+
# save the predictions to a temporary file
|
951 |
+
random_uuid = str(uuid.uuid4())
|
952 |
+
with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
|
953 |
+
f.write(json.dumps(predictions, indent=4))
|
954 |
+
|
955 |
+
if vqa_dataset == "gqa":
|
956 |
+
acc = compute_gqa_accuracy(predictions)
|
957 |
+
else:
|
958 |
+
acc = compute_vqa_accuracy(
|
959 |
+
f"{vqa_dataset}results_{random_uuid}.json",
|
960 |
+
questions_json_path,
|
961 |
+
annotations_json_path,
|
962 |
+
vqa_dataset=vqa_dataset,
|
963 |
+
)
|
964 |
+
print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
|
965 |
+
os.makedirs("eval_results", exist_ok=True)
|
966 |
+
with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
967 |
+
f.write(json.dumps(predictions, indent=2))
|
968 |
+
|
969 |
+
# delete the temporary file
|
970 |
+
os.remove(f"{vqa_dataset}results_{random_uuid}.json")
|
971 |
+
else:
|
972 |
+
time.sleep(5)
|
973 |
+
acc = 0.0
|
974 |
+
if world_size > 1:
|
975 |
+
torch.distributed.barrier()
|
976 |
+
return acc
|
977 |
+
|
978 |
+
|
979 |
+
def evaluate_refcoco(
|
980 |
+
model,
|
981 |
+
tokenizer,
|
982 |
+
image_processor,
|
983 |
+
batch_size,
|
984 |
+
tsvfile,
|
985 |
+
max_generation_length=20,
|
986 |
+
num_beams=3,
|
987 |
+
length_penalty=-2.0,
|
988 |
+
device=-1,
|
989 |
+
vis_embed_size=None,
|
990 |
+
rank=0,
|
991 |
+
world_size=1,
|
992 |
+
id=0,
|
993 |
+
):
|
994 |
+
model.eval().cuda()
|
995 |
+
loc_token_ids = []
|
996 |
+
for i in range(1000):
|
997 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
998 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
999 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
1000 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1001 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
1002 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
1003 |
+
all_ids = set(range(model.lang_encoder.lm_head.out_features))
|
1004 |
+
bad_words_ids = list(all_ids - set(loc_token_ids))
|
1005 |
+
bad_words_ids = [[b] for b in bad_words_ids]
|
1006 |
+
min_loc_token_id = min(loc_token_ids)
|
1007 |
+
max_loc_token_id = max(loc_token_ids)
|
1008 |
+
total = 0
|
1009 |
+
correct = 0
|
1010 |
+
ious = []
|
1011 |
+
if "refcocog" in tsvfile:
|
1012 |
+
dataset_name = "refcocog"
|
1013 |
+
elif "refcocoplus" in tsvfile:
|
1014 |
+
dataset_name = "refcocoplus"
|
1015 |
+
else:
|
1016 |
+
dataset_name = "refcoco"
|
1017 |
+
with open(tsvfile, "r") as f:
|
1018 |
+
lines = f.readlines()
|
1019 |
+
pbar = tqdm(lines, disable=(rank != 0))
|
1020 |
+
for ii, line in enumerate(pbar):
|
1021 |
+
if ii % world_size != rank:
|
1022 |
+
continue
|
1023 |
+
total += 1
|
1024 |
+
line = line.rstrip()
|
1025 |
+
uniq_id, image_id, text, region_coord, image = line.split("\t")
|
1026 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
1027 |
+
gt_box = np.array(list(map(float, region_coord.split(","))))
|
1028 |
+
width = image.width
|
1029 |
+
height = image.height
|
1030 |
+
image = image.resize((224, 224))
|
1031 |
+
gt_box /= np.array([width, height, width, height])
|
1032 |
+
# extra = abs(height - width) // 2
|
1033 |
+
# if width > height:
|
1034 |
+
# gt_box += np.array([0, extra, 0, extra])
|
1035 |
+
# else:
|
1036 |
+
# gt_box += np.array([extra, 0, extra, 0])
|
1037 |
+
# size = max(width, height)
|
1038 |
+
# gt_box /= np.array([size, size, size, size])
|
1039 |
+
# image = expand2square(image, (255, 255, 255))
|
1040 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1041 |
+
if model.use_format_v2:
|
1042 |
+
prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text.rstrip('.')}<|#loc#|>"]
|
1043 |
+
else:
|
1044 |
+
prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}"]
|
1045 |
+
encodings = tokenizer(
|
1046 |
+
prompt,
|
1047 |
+
padding="longest",
|
1048 |
+
truncation=True,
|
1049 |
+
return_tensors="pt",
|
1050 |
+
max_length=2000,
|
1051 |
+
)
|
1052 |
+
input_ids = encodings["input_ids"]
|
1053 |
+
attention_mask = encodings["attention_mask"]
|
1054 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1055 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1056 |
+
image_nums = [1] * len(input_ids)
|
1057 |
+
outputs = get_outputs(
|
1058 |
+
model=model,
|
1059 |
+
batch_images=batch_images.cuda(),
|
1060 |
+
attention_mask=attention_mask.cuda(),
|
1061 |
+
max_generation_length=5,
|
1062 |
+
min_generation_length=4,
|
1063 |
+
num_beams=1,
|
1064 |
+
length_penalty=1.0,
|
1065 |
+
input_ids=input_ids.cuda(),
|
1066 |
+
bad_words_ids=bad_words_ids,
|
1067 |
+
image_start_index_list=image_start_index_list,
|
1068 |
+
image_nums=image_nums,
|
1069 |
+
)
|
1070 |
+
box = []
|
1071 |
+
for o in outputs[0]:
|
1072 |
+
if o >= min_loc_token_id and o <= max_loc_token_id:
|
1073 |
+
box.append(o.item() - min_loc_token_id)
|
1074 |
+
if len(box) == 4:
|
1075 |
+
break
|
1076 |
+
iou = 0
|
1077 |
+
if len(box) == 4:
|
1078 |
+
box = np.array(box).astype(float) / 1000
|
1079 |
+
iou = get_iou(box, gt_box)
|
1080 |
+
ious.append(iou)
|
1081 |
+
else:
|
1082 |
+
tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
|
1083 |
+
tqdm.write(f"prompt: {prompt}")
|
1084 |
+
if iou >= 0.5:
|
1085 |
+
correct += 1
|
1086 |
+
pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
|
1087 |
+
|
1088 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1089 |
+
f.write(json.dumps([total, correct]))
|
1090 |
+
if world_size > 1:
|
1091 |
+
torch.distributed.barrier()
|
1092 |
+
if rank == 0:
|
1093 |
+
total = 0
|
1094 |
+
correct = 0
|
1095 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1096 |
+
for rank_i in range(world_size):
|
1097 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1098 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1099 |
+
total += total_part
|
1100 |
+
correct += correct_part
|
1101 |
+
score = correct / total
|
1102 |
+
print("score:", score)
|
1103 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1104 |
+
pass
|
1105 |
+
else:
|
1106 |
+
score = 0.0
|
1107 |
+
if world_size > 1:
|
1108 |
+
torch.distributed.barrier()
|
1109 |
+
return score
|
1110 |
+
|
1111 |
+
|
1112 |
+
if __name__ == "__main__":
|
1113 |
+
main()
|
open_flamingo/open_flamingo/eval/imagenet_utils.py
ADDED
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# classnames via https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/imagenet_classnames.py#L1
|
2 |
+
openai_imagenet_classnames = [
|
3 |
+
"tench",
|
4 |
+
"goldfish",
|
5 |
+
"great white shark",
|
6 |
+
"tiger shark",
|
7 |
+
"hammerhead shark",
|
8 |
+
"electric ray",
|
9 |
+
"stingray",
|
10 |
+
"rooster",
|
11 |
+
"hen",
|
12 |
+
"ostrich",
|
13 |
+
"brambling",
|
14 |
+
"goldfinch",
|
15 |
+
"house finch",
|
16 |
+
"junco",
|
17 |
+
"indigo bunting",
|
18 |
+
"American robin",
|
19 |
+
"bulbul",
|
20 |
+
"jay",
|
21 |
+
"magpie",
|
22 |
+
"chickadee",
|
23 |
+
"American dipper",
|
24 |
+
"kite (bird of prey)",
|
25 |
+
"bald eagle",
|
26 |
+
"vulture",
|
27 |
+
"great grey owl",
|
28 |
+
"fire salamander",
|
29 |
+
"smooth newt",
|
30 |
+
"newt",
|
31 |
+
"spotted salamander",
|
32 |
+
"axolotl",
|
33 |
+
"American bullfrog",
|
34 |
+
"tree frog",
|
35 |
+
"tailed frog",
|
36 |
+
"loggerhead sea turtle",
|
37 |
+
"leatherback sea turtle",
|
38 |
+
"mud turtle",
|
39 |
+
"terrapin",
|
40 |
+
"box turtle",
|
41 |
+
"banded gecko",
|
42 |
+
"green iguana",
|
43 |
+
"Carolina anole",
|
44 |
+
"desert grassland whiptail lizard",
|
45 |
+
"agama",
|
46 |
+
"frilled-necked lizard",
|
47 |
+
"alligator lizard",
|
48 |
+
"Gila monster",
|
49 |
+
"European green lizard",
|
50 |
+
"chameleon",
|
51 |
+
"Komodo dragon",
|
52 |
+
"Nile crocodile",
|
53 |
+
"American alligator",
|
54 |
+
"triceratops",
|
55 |
+
"worm snake",
|
56 |
+
"ring-necked snake",
|
57 |
+
"eastern hog-nosed snake",
|
58 |
+
"smooth green snake",
|
59 |
+
"kingsnake",
|
60 |
+
"garter snake",
|
61 |
+
"water snake",
|
62 |
+
"vine snake",
|
63 |
+
"night snake",
|
64 |
+
"boa constrictor",
|
65 |
+
"African rock python",
|
66 |
+
"Indian cobra",
|
67 |
+
"green mamba",
|
68 |
+
"sea snake",
|
69 |
+
"Saharan horned viper",
|
70 |
+
"eastern diamondback rattlesnake",
|
71 |
+
"sidewinder rattlesnake",
|
72 |
+
"trilobite",
|
73 |
+
"harvestman",
|
74 |
+
"scorpion",
|
75 |
+
"yellow garden spider",
|
76 |
+
"barn spider",
|
77 |
+
"European garden spider",
|
78 |
+
"southern black widow",
|
79 |
+
"tarantula",
|
80 |
+
"wolf spider",
|
81 |
+
"tick",
|
82 |
+
"centipede",
|
83 |
+
"black grouse",
|
84 |
+
"ptarmigan",
|
85 |
+
"ruffed grouse",
|
86 |
+
"prairie grouse",
|
87 |
+
"peafowl",
|
88 |
+
"quail",
|
89 |
+
"partridge",
|
90 |
+
"african grey parrot",
|
91 |
+
"macaw",
|
92 |
+
"sulphur-crested cockatoo",
|
93 |
+
"lorikeet",
|
94 |
+
"coucal",
|
95 |
+
"bee eater",
|
96 |
+
"hornbill",
|
97 |
+
"hummingbird",
|
98 |
+
"jacamar",
|
99 |
+
"toucan",
|
100 |
+
"duck",
|
101 |
+
"red-breasted merganser",
|
102 |
+
"goose",
|
103 |
+
"black swan",
|
104 |
+
"tusker",
|
105 |
+
"echidna",
|
106 |
+
"platypus",
|
107 |
+
"wallaby",
|
108 |
+
"koala",
|
109 |
+
"wombat",
|
110 |
+
"jellyfish",
|
111 |
+
"sea anemone",
|
112 |
+
"brain coral",
|
113 |
+
"flatworm",
|
114 |
+
"nematode",
|
115 |
+
"conch",
|
116 |
+
"snail",
|
117 |
+
"slug",
|
118 |
+
"sea slug",
|
119 |
+
"chiton",
|
120 |
+
"chambered nautilus",
|
121 |
+
"Dungeness crab",
|
122 |
+
"rock crab",
|
123 |
+
"fiddler crab",
|
124 |
+
"red king crab",
|
125 |
+
"American lobster",
|
126 |
+
"spiny lobster",
|
127 |
+
"crayfish",
|
128 |
+
"hermit crab",
|
129 |
+
"isopod",
|
130 |
+
"white stork",
|
131 |
+
"black stork",
|
132 |
+
"spoonbill",
|
133 |
+
"flamingo",
|
134 |
+
"little blue heron",
|
135 |
+
"great egret",
|
136 |
+
"bittern bird",
|
137 |
+
"crane bird",
|
138 |
+
"limpkin",
|
139 |
+
"common gallinule",
|
140 |
+
"American coot",
|
141 |
+
"bustard",
|
142 |
+
"ruddy turnstone",
|
143 |
+
"dunlin",
|
144 |
+
"common redshank",
|
145 |
+
"dowitcher",
|
146 |
+
"oystercatcher",
|
147 |
+
"pelican",
|
148 |
+
"king penguin",
|
149 |
+
"albatross",
|
150 |
+
"grey whale",
|
151 |
+
"killer whale",
|
152 |
+
"dugong",
|
153 |
+
"sea lion",
|
154 |
+
"Chihuahua",
|
155 |
+
"Japanese Chin",
|
156 |
+
"Maltese",
|
157 |
+
"Pekingese",
|
158 |
+
"Shih Tzu",
|
159 |
+
"King Charles Spaniel",
|
160 |
+
"Papillon",
|
161 |
+
"toy terrier",
|
162 |
+
"Rhodesian Ridgeback",
|
163 |
+
"Afghan Hound",
|
164 |
+
"Basset Hound",
|
165 |
+
"Beagle",
|
166 |
+
"Bloodhound",
|
167 |
+
"Bluetick Coonhound",
|
168 |
+
"Black and Tan Coonhound",
|
169 |
+
"Treeing Walker Coonhound",
|
170 |
+
"English foxhound",
|
171 |
+
"Redbone Coonhound",
|
172 |
+
"borzoi",
|
173 |
+
"Irish Wolfhound",
|
174 |
+
"Italian Greyhound",
|
175 |
+
"Whippet",
|
176 |
+
"Ibizan Hound",
|
177 |
+
"Norwegian Elkhound",
|
178 |
+
"Otterhound",
|
179 |
+
"Saluki",
|
180 |
+
"Scottish Deerhound",
|
181 |
+
"Weimaraner",
|
182 |
+
"Staffordshire Bull Terrier",
|
183 |
+
"American Staffordshire Terrier",
|
184 |
+
"Bedlington Terrier",
|
185 |
+
"Border Terrier",
|
186 |
+
"Kerry Blue Terrier",
|
187 |
+
"Irish Terrier",
|
188 |
+
"Norfolk Terrier",
|
189 |
+
"Norwich Terrier",
|
190 |
+
"Yorkshire Terrier",
|
191 |
+
"Wire Fox Terrier",
|
192 |
+
"Lakeland Terrier",
|
193 |
+
"Sealyham Terrier",
|
194 |
+
"Airedale Terrier",
|
195 |
+
"Cairn Terrier",
|
196 |
+
"Australian Terrier",
|
197 |
+
"Dandie Dinmont Terrier",
|
198 |
+
"Boston Terrier",
|
199 |
+
"Miniature Schnauzer",
|
200 |
+
"Giant Schnauzer",
|
201 |
+
"Standard Schnauzer",
|
202 |
+
"Scottish Terrier",
|
203 |
+
"Tibetan Terrier",
|
204 |
+
"Australian Silky Terrier",
|
205 |
+
"Soft-coated Wheaten Terrier",
|
206 |
+
"West Highland White Terrier",
|
207 |
+
"Lhasa Apso",
|
208 |
+
"Flat-Coated Retriever",
|
209 |
+
"Curly-coated Retriever",
|
210 |
+
"Golden Retriever",
|
211 |
+
"Labrador Retriever",
|
212 |
+
"Chesapeake Bay Retriever",
|
213 |
+
"German Shorthaired Pointer",
|
214 |
+
"Vizsla",
|
215 |
+
"English Setter",
|
216 |
+
"Irish Setter",
|
217 |
+
"Gordon Setter",
|
218 |
+
"Brittany dog",
|
219 |
+
"Clumber Spaniel",
|
220 |
+
"English Springer Spaniel",
|
221 |
+
"Welsh Springer Spaniel",
|
222 |
+
"Cocker Spaniel",
|
223 |
+
"Sussex Spaniel",
|
224 |
+
"Irish Water Spaniel",
|
225 |
+
"Kuvasz",
|
226 |
+
"Schipperke",
|
227 |
+
"Groenendael dog",
|
228 |
+
"Malinois",
|
229 |
+
"Briard",
|
230 |
+
"Australian Kelpie",
|
231 |
+
"Komondor",
|
232 |
+
"Old English Sheepdog",
|
233 |
+
"Shetland Sheepdog",
|
234 |
+
"collie",
|
235 |
+
"Border Collie",
|
236 |
+
"Bouvier des Flandres dog",
|
237 |
+
"Rottweiler",
|
238 |
+
"German Shepherd Dog",
|
239 |
+
"Dobermann",
|
240 |
+
"Miniature Pinscher",
|
241 |
+
"Greater Swiss Mountain Dog",
|
242 |
+
"Bernese Mountain Dog",
|
243 |
+
"Appenzeller Sennenhund",
|
244 |
+
"Entlebucher Sennenhund",
|
245 |
+
"Boxer",
|
246 |
+
"Bullmastiff",
|
247 |
+
"Tibetan Mastiff",
|
248 |
+
"French Bulldog",
|
249 |
+
"Great Dane",
|
250 |
+
"St. Bernard",
|
251 |
+
"husky",
|
252 |
+
"Alaskan Malamute",
|
253 |
+
"Siberian Husky",
|
254 |
+
"Dalmatian",
|
255 |
+
"Affenpinscher",
|
256 |
+
"Basenji",
|
257 |
+
"pug",
|
258 |
+
"Leonberger",
|
259 |
+
"Newfoundland dog",
|
260 |
+
"Great Pyrenees dog",
|
261 |
+
"Samoyed",
|
262 |
+
"Pomeranian",
|
263 |
+
"Chow Chow",
|
264 |
+
"Keeshond",
|
265 |
+
"brussels griffon",
|
266 |
+
"Pembroke Welsh Corgi",
|
267 |
+
"Cardigan Welsh Corgi",
|
268 |
+
"Toy Poodle",
|
269 |
+
"Miniature Poodle",
|
270 |
+
"Standard Poodle",
|
271 |
+
"Mexican hairless dog (xoloitzcuintli)",
|
272 |
+
"grey wolf",
|
273 |
+
"Alaskan tundra wolf",
|
274 |
+
"red wolf or maned wolf",
|
275 |
+
"coyote",
|
276 |
+
"dingo",
|
277 |
+
"dhole",
|
278 |
+
"African wild dog",
|
279 |
+
"hyena",
|
280 |
+
"red fox",
|
281 |
+
"kit fox",
|
282 |
+
"Arctic fox",
|
283 |
+
"grey fox",
|
284 |
+
"tabby cat",
|
285 |
+
"tiger cat",
|
286 |
+
"Persian cat",
|
287 |
+
"Siamese cat",
|
288 |
+
"Egyptian Mau",
|
289 |
+
"cougar",
|
290 |
+
"lynx",
|
291 |
+
"leopard",
|
292 |
+
"snow leopard",
|
293 |
+
"jaguar",
|
294 |
+
"lion",
|
295 |
+
"tiger",
|
296 |
+
"cheetah",
|
297 |
+
"brown bear",
|
298 |
+
"American black bear",
|
299 |
+
"polar bear",
|
300 |
+
"sloth bear",
|
301 |
+
"mongoose",
|
302 |
+
"meerkat",
|
303 |
+
"tiger beetle",
|
304 |
+
"ladybug",
|
305 |
+
"ground beetle",
|
306 |
+
"longhorn beetle",
|
307 |
+
"leaf beetle",
|
308 |
+
"dung beetle",
|
309 |
+
"rhinoceros beetle",
|
310 |
+
"weevil",
|
311 |
+
"fly",
|
312 |
+
"bee",
|
313 |
+
"ant",
|
314 |
+
"grasshopper",
|
315 |
+
"cricket insect",
|
316 |
+
"stick insect",
|
317 |
+
"cockroach",
|
318 |
+
"praying mantis",
|
319 |
+
"cicada",
|
320 |
+
"leafhopper",
|
321 |
+
"lacewing",
|
322 |
+
"dragonfly",
|
323 |
+
"damselfly",
|
324 |
+
"red admiral butterfly",
|
325 |
+
"ringlet butterfly",
|
326 |
+
"monarch butterfly",
|
327 |
+
"small white butterfly",
|
328 |
+
"sulphur butterfly",
|
329 |
+
"gossamer-winged butterfly",
|
330 |
+
"starfish",
|
331 |
+
"sea urchin",
|
332 |
+
"sea cucumber",
|
333 |
+
"cottontail rabbit",
|
334 |
+
"hare",
|
335 |
+
"Angora rabbit",
|
336 |
+
"hamster",
|
337 |
+
"porcupine",
|
338 |
+
"fox squirrel",
|
339 |
+
"marmot",
|
340 |
+
"beaver",
|
341 |
+
"guinea pig",
|
342 |
+
"common sorrel horse",
|
343 |
+
"zebra",
|
344 |
+
"pig",
|
345 |
+
"wild boar",
|
346 |
+
"warthog",
|
347 |
+
"hippopotamus",
|
348 |
+
"ox",
|
349 |
+
"water buffalo",
|
350 |
+
"bison",
|
351 |
+
"ram (adult male sheep)",
|
352 |
+
"bighorn sheep",
|
353 |
+
"Alpine ibex",
|
354 |
+
"hartebeest",
|
355 |
+
"impala (antelope)",
|
356 |
+
"gazelle",
|
357 |
+
"arabian camel",
|
358 |
+
"llama",
|
359 |
+
"weasel",
|
360 |
+
"mink",
|
361 |
+
"European polecat",
|
362 |
+
"black-footed ferret",
|
363 |
+
"otter",
|
364 |
+
"skunk",
|
365 |
+
"badger",
|
366 |
+
"armadillo",
|
367 |
+
"three-toed sloth",
|
368 |
+
"orangutan",
|
369 |
+
"gorilla",
|
370 |
+
"chimpanzee",
|
371 |
+
"gibbon",
|
372 |
+
"siamang",
|
373 |
+
"guenon",
|
374 |
+
"patas monkey",
|
375 |
+
"baboon",
|
376 |
+
"macaque",
|
377 |
+
"langur",
|
378 |
+
"black-and-white colobus",
|
379 |
+
"proboscis monkey",
|
380 |
+
"marmoset",
|
381 |
+
"white-headed capuchin",
|
382 |
+
"howler monkey",
|
383 |
+
"titi monkey",
|
384 |
+
"Geoffroy's spider monkey",
|
385 |
+
"common squirrel monkey",
|
386 |
+
"ring-tailed lemur",
|
387 |
+
"indri",
|
388 |
+
"Asian elephant",
|
389 |
+
"African bush elephant",
|
390 |
+
"red panda",
|
391 |
+
"giant panda",
|
392 |
+
"snoek fish",
|
393 |
+
"eel",
|
394 |
+
"silver salmon",
|
395 |
+
"rock beauty fish",
|
396 |
+
"clownfish",
|
397 |
+
"sturgeon",
|
398 |
+
"gar fish",
|
399 |
+
"lionfish",
|
400 |
+
"pufferfish",
|
401 |
+
"abacus",
|
402 |
+
"abaya",
|
403 |
+
"academic gown",
|
404 |
+
"accordion",
|
405 |
+
"acoustic guitar",
|
406 |
+
"aircraft carrier",
|
407 |
+
"airliner",
|
408 |
+
"airship",
|
409 |
+
"altar",
|
410 |
+
"ambulance",
|
411 |
+
"amphibious vehicle",
|
412 |
+
"analog clock",
|
413 |
+
"apiary",
|
414 |
+
"apron",
|
415 |
+
"trash can",
|
416 |
+
"assault rifle",
|
417 |
+
"backpack",
|
418 |
+
"bakery",
|
419 |
+
"balance beam",
|
420 |
+
"balloon",
|
421 |
+
"ballpoint pen",
|
422 |
+
"Band-Aid",
|
423 |
+
"banjo",
|
424 |
+
"baluster / handrail",
|
425 |
+
"barbell",
|
426 |
+
"barber chair",
|
427 |
+
"barbershop",
|
428 |
+
"barn",
|
429 |
+
"barometer",
|
430 |
+
"barrel",
|
431 |
+
"wheelbarrow",
|
432 |
+
"baseball",
|
433 |
+
"basketball",
|
434 |
+
"bassinet",
|
435 |
+
"bassoon",
|
436 |
+
"swimming cap",
|
437 |
+
"bath towel",
|
438 |
+
"bathtub",
|
439 |
+
"station wagon",
|
440 |
+
"lighthouse",
|
441 |
+
"beaker",
|
442 |
+
"military hat (bearskin or shako)",
|
443 |
+
"beer bottle",
|
444 |
+
"beer glass",
|
445 |
+
"bell tower",
|
446 |
+
"baby bib",
|
447 |
+
"tandem bicycle",
|
448 |
+
"bikini",
|
449 |
+
"ring binder",
|
450 |
+
"binoculars",
|
451 |
+
"birdhouse",
|
452 |
+
"boathouse",
|
453 |
+
"bobsleigh",
|
454 |
+
"bolo tie",
|
455 |
+
"poke bonnet",
|
456 |
+
"bookcase",
|
457 |
+
"bookstore",
|
458 |
+
"bottle cap",
|
459 |
+
"hunting bow",
|
460 |
+
"bow tie",
|
461 |
+
"brass memorial plaque",
|
462 |
+
"bra",
|
463 |
+
"breakwater",
|
464 |
+
"breastplate",
|
465 |
+
"broom",
|
466 |
+
"bucket",
|
467 |
+
"buckle",
|
468 |
+
"bulletproof vest",
|
469 |
+
"high-speed train",
|
470 |
+
"butcher shop",
|
471 |
+
"taxicab",
|
472 |
+
"cauldron",
|
473 |
+
"candle",
|
474 |
+
"cannon",
|
475 |
+
"canoe",
|
476 |
+
"can opener",
|
477 |
+
"cardigan",
|
478 |
+
"car mirror",
|
479 |
+
"carousel",
|
480 |
+
"tool kit",
|
481 |
+
"cardboard box / carton",
|
482 |
+
"car wheel",
|
483 |
+
"automated teller machine",
|
484 |
+
"cassette",
|
485 |
+
"cassette player",
|
486 |
+
"castle",
|
487 |
+
"catamaran",
|
488 |
+
"CD player",
|
489 |
+
"cello",
|
490 |
+
"mobile phone",
|
491 |
+
"chain",
|
492 |
+
"chain-link fence",
|
493 |
+
"chain mail",
|
494 |
+
"chainsaw",
|
495 |
+
"storage chest",
|
496 |
+
"chiffonier",
|
497 |
+
"bell or wind chime",
|
498 |
+
"china cabinet",
|
499 |
+
"Christmas stocking",
|
500 |
+
"church",
|
501 |
+
"movie theater",
|
502 |
+
"cleaver",
|
503 |
+
"cliff dwelling",
|
504 |
+
"cloak",
|
505 |
+
"clogs",
|
506 |
+
"cocktail shaker",
|
507 |
+
"coffee mug",
|
508 |
+
"coffeemaker",
|
509 |
+
"spiral or coil",
|
510 |
+
"combination lock",
|
511 |
+
"computer keyboard",
|
512 |
+
"candy store",
|
513 |
+
"container ship",
|
514 |
+
"convertible",
|
515 |
+
"corkscrew",
|
516 |
+
"cornet",
|
517 |
+
"cowboy boot",
|
518 |
+
"cowboy hat",
|
519 |
+
"cradle",
|
520 |
+
"construction crane",
|
521 |
+
"crash helmet",
|
522 |
+
"crate",
|
523 |
+
"infant bed",
|
524 |
+
"Crock Pot",
|
525 |
+
"croquet ball",
|
526 |
+
"crutch",
|
527 |
+
"cuirass",
|
528 |
+
"dam",
|
529 |
+
"desk",
|
530 |
+
"desktop computer",
|
531 |
+
"rotary dial telephone",
|
532 |
+
"diaper",
|
533 |
+
"digital clock",
|
534 |
+
"digital watch",
|
535 |
+
"dining table",
|
536 |
+
"dishcloth",
|
537 |
+
"dishwasher",
|
538 |
+
"disc brake",
|
539 |
+
"dock",
|
540 |
+
"dog sled",
|
541 |
+
"dome",
|
542 |
+
"doormat",
|
543 |
+
"drilling rig",
|
544 |
+
"drum",
|
545 |
+
"drumstick",
|
546 |
+
"dumbbell",
|
547 |
+
"Dutch oven",
|
548 |
+
"electric fan",
|
549 |
+
"electric guitar",
|
550 |
+
"electric locomotive",
|
551 |
+
"entertainment center",
|
552 |
+
"envelope",
|
553 |
+
"espresso machine",
|
554 |
+
"face powder",
|
555 |
+
"feather boa",
|
556 |
+
"filing cabinet",
|
557 |
+
"fireboat",
|
558 |
+
"fire truck",
|
559 |
+
"fire screen",
|
560 |
+
"flagpole",
|
561 |
+
"flute",
|
562 |
+
"folding chair",
|
563 |
+
"football helmet",
|
564 |
+
"forklift",
|
565 |
+
"fountain",
|
566 |
+
"fountain pen",
|
567 |
+
"four-poster bed",
|
568 |
+
"freight car",
|
569 |
+
"French horn",
|
570 |
+
"frying pan",
|
571 |
+
"fur coat",
|
572 |
+
"garbage truck",
|
573 |
+
"gas mask or respirator",
|
574 |
+
"gas pump",
|
575 |
+
"goblet",
|
576 |
+
"go-kart",
|
577 |
+
"golf ball",
|
578 |
+
"golf cart",
|
579 |
+
"gondola",
|
580 |
+
"gong",
|
581 |
+
"gown",
|
582 |
+
"grand piano",
|
583 |
+
"greenhouse",
|
584 |
+
"radiator grille",
|
585 |
+
"grocery store",
|
586 |
+
"guillotine",
|
587 |
+
"hair clip",
|
588 |
+
"hair spray",
|
589 |
+
"half-track",
|
590 |
+
"hammer",
|
591 |
+
"hamper",
|
592 |
+
"hair dryer",
|
593 |
+
"hand-held computer",
|
594 |
+
"handkerchief",
|
595 |
+
"hard disk drive",
|
596 |
+
"harmonica",
|
597 |
+
"harp",
|
598 |
+
"combine harvester",
|
599 |
+
"hatchet",
|
600 |
+
"holster",
|
601 |
+
"home theater",
|
602 |
+
"honeycomb",
|
603 |
+
"hook",
|
604 |
+
"hoop skirt",
|
605 |
+
"gymnastic horizontal bar",
|
606 |
+
"horse-drawn vehicle",
|
607 |
+
"hourglass",
|
608 |
+
"iPod",
|
609 |
+
"clothes iron",
|
610 |
+
"carved pumpkin",
|
611 |
+
"jeans",
|
612 |
+
"jeep",
|
613 |
+
"T-shirt",
|
614 |
+
"jigsaw puzzle",
|
615 |
+
"rickshaw",
|
616 |
+
"joystick",
|
617 |
+
"kimono",
|
618 |
+
"knee pad",
|
619 |
+
"knot",
|
620 |
+
"lab coat",
|
621 |
+
"ladle",
|
622 |
+
"lampshade",
|
623 |
+
"laptop computer",
|
624 |
+
"lawn mower",
|
625 |
+
"lens cap",
|
626 |
+
"letter opener",
|
627 |
+
"library",
|
628 |
+
"lifeboat",
|
629 |
+
"lighter",
|
630 |
+
"limousine",
|
631 |
+
"ocean liner",
|
632 |
+
"lipstick",
|
633 |
+
"slip-on shoe",
|
634 |
+
"lotion",
|
635 |
+
"music speaker",
|
636 |
+
"loupe magnifying glass",
|
637 |
+
"sawmill",
|
638 |
+
"magnetic compass",
|
639 |
+
"messenger bag",
|
640 |
+
"mailbox",
|
641 |
+
"tights",
|
642 |
+
"one-piece bathing suit",
|
643 |
+
"manhole cover",
|
644 |
+
"maraca",
|
645 |
+
"marimba",
|
646 |
+
"mask",
|
647 |
+
"matchstick",
|
648 |
+
"maypole",
|
649 |
+
"maze",
|
650 |
+
"measuring cup",
|
651 |
+
"medicine cabinet",
|
652 |
+
"megalith",
|
653 |
+
"microphone",
|
654 |
+
"microwave oven",
|
655 |
+
"military uniform",
|
656 |
+
"milk can",
|
657 |
+
"minibus",
|
658 |
+
"miniskirt",
|
659 |
+
"minivan",
|
660 |
+
"missile",
|
661 |
+
"mitten",
|
662 |
+
"mixing bowl",
|
663 |
+
"mobile home",
|
664 |
+
"ford model t",
|
665 |
+
"modem",
|
666 |
+
"monastery",
|
667 |
+
"monitor",
|
668 |
+
"moped",
|
669 |
+
"mortar and pestle",
|
670 |
+
"graduation cap",
|
671 |
+
"mosque",
|
672 |
+
"mosquito net",
|
673 |
+
"vespa",
|
674 |
+
"mountain bike",
|
675 |
+
"tent",
|
676 |
+
"computer mouse",
|
677 |
+
"mousetrap",
|
678 |
+
"moving van",
|
679 |
+
"muzzle",
|
680 |
+
"metal nail",
|
681 |
+
"neck brace",
|
682 |
+
"necklace",
|
683 |
+
"baby pacifier",
|
684 |
+
"notebook computer",
|
685 |
+
"obelisk",
|
686 |
+
"oboe",
|
687 |
+
"ocarina",
|
688 |
+
"odometer",
|
689 |
+
"oil filter",
|
690 |
+
"pipe organ",
|
691 |
+
"oscilloscope",
|
692 |
+
"overskirt",
|
693 |
+
"bullock cart",
|
694 |
+
"oxygen mask",
|
695 |
+
"product packet / packaging",
|
696 |
+
"paddle",
|
697 |
+
"paddle wheel",
|
698 |
+
"padlock",
|
699 |
+
"paintbrush",
|
700 |
+
"pajamas",
|
701 |
+
"palace",
|
702 |
+
"pan flute",
|
703 |
+
"paper towel",
|
704 |
+
"parachute",
|
705 |
+
"parallel bars",
|
706 |
+
"park bench",
|
707 |
+
"parking meter",
|
708 |
+
"railroad car",
|
709 |
+
"patio",
|
710 |
+
"payphone",
|
711 |
+
"pedestal",
|
712 |
+
"pencil case",
|
713 |
+
"pencil sharpener",
|
714 |
+
"perfume",
|
715 |
+
"Petri dish",
|
716 |
+
"photocopier",
|
717 |
+
"plectrum",
|
718 |
+
"Pickelhaube",
|
719 |
+
"picket fence",
|
720 |
+
"pickup truck",
|
721 |
+
"pier",
|
722 |
+
"piggy bank",
|
723 |
+
"pill bottle",
|
724 |
+
"pillow",
|
725 |
+
"ping-pong ball",
|
726 |
+
"pinwheel",
|
727 |
+
"pirate ship",
|
728 |
+
"drink pitcher",
|
729 |
+
"block plane",
|
730 |
+
"planetarium",
|
731 |
+
"plastic bag",
|
732 |
+
"plate rack",
|
733 |
+
"farm plow",
|
734 |
+
"plunger",
|
735 |
+
"Polaroid camera",
|
736 |
+
"pole",
|
737 |
+
"police van",
|
738 |
+
"poncho",
|
739 |
+
"pool table",
|
740 |
+
"soda bottle",
|
741 |
+
"plant pot",
|
742 |
+
"potter's wheel",
|
743 |
+
"power drill",
|
744 |
+
"prayer rug",
|
745 |
+
"printer",
|
746 |
+
"prison",
|
747 |
+
"missile",
|
748 |
+
"projector",
|
749 |
+
"hockey puck",
|
750 |
+
"punching bag",
|
751 |
+
"purse",
|
752 |
+
"quill",
|
753 |
+
"quilt",
|
754 |
+
"race car",
|
755 |
+
"racket",
|
756 |
+
"radiator",
|
757 |
+
"radio",
|
758 |
+
"radio telescope",
|
759 |
+
"rain barrel",
|
760 |
+
"recreational vehicle",
|
761 |
+
"fishing casting reel",
|
762 |
+
"reflex camera",
|
763 |
+
"refrigerator",
|
764 |
+
"remote control",
|
765 |
+
"restaurant",
|
766 |
+
"revolver",
|
767 |
+
"rifle",
|
768 |
+
"rocking chair",
|
769 |
+
"rotisserie",
|
770 |
+
"eraser",
|
771 |
+
"rugby ball",
|
772 |
+
"ruler measuring stick",
|
773 |
+
"sneaker",
|
774 |
+
"safe",
|
775 |
+
"safety pin",
|
776 |
+
"salt shaker",
|
777 |
+
"sandal",
|
778 |
+
"sarong",
|
779 |
+
"saxophone",
|
780 |
+
"scabbard",
|
781 |
+
"weighing scale",
|
782 |
+
"school bus",
|
783 |
+
"schooner",
|
784 |
+
"scoreboard",
|
785 |
+
"CRT monitor",
|
786 |
+
"screw",
|
787 |
+
"screwdriver",
|
788 |
+
"seat belt",
|
789 |
+
"sewing machine",
|
790 |
+
"shield",
|
791 |
+
"shoe store",
|
792 |
+
"shoji screen / room divider",
|
793 |
+
"shopping basket",
|
794 |
+
"shopping cart",
|
795 |
+
"shovel",
|
796 |
+
"shower cap",
|
797 |
+
"shower curtain",
|
798 |
+
"ski",
|
799 |
+
"balaclava ski mask",
|
800 |
+
"sleeping bag",
|
801 |
+
"slide rule",
|
802 |
+
"sliding door",
|
803 |
+
"slot machine",
|
804 |
+
"snorkel",
|
805 |
+
"snowmobile",
|
806 |
+
"snowplow",
|
807 |
+
"soap dispenser",
|
808 |
+
"soccer ball",
|
809 |
+
"sock",
|
810 |
+
"solar thermal collector",
|
811 |
+
"sombrero",
|
812 |
+
"soup bowl",
|
813 |
+
"keyboard space bar",
|
814 |
+
"space heater",
|
815 |
+
"space shuttle",
|
816 |
+
"spatula",
|
817 |
+
"motorboat",
|
818 |
+
"spider web",
|
819 |
+
"spindle",
|
820 |
+
"sports car",
|
821 |
+
"spotlight",
|
822 |
+
"stage",
|
823 |
+
"steam locomotive",
|
824 |
+
"through arch bridge",
|
825 |
+
"steel drum",
|
826 |
+
"stethoscope",
|
827 |
+
"scarf",
|
828 |
+
"stone wall",
|
829 |
+
"stopwatch",
|
830 |
+
"stove",
|
831 |
+
"strainer",
|
832 |
+
"tram",
|
833 |
+
"stretcher",
|
834 |
+
"couch",
|
835 |
+
"stupa",
|
836 |
+
"submarine",
|
837 |
+
"suit",
|
838 |
+
"sundial",
|
839 |
+
"sunglasses",
|
840 |
+
"sunglasses",
|
841 |
+
"sunscreen",
|
842 |
+
"suspension bridge",
|
843 |
+
"mop",
|
844 |
+
"sweatshirt",
|
845 |
+
"swim trunks / shorts",
|
846 |
+
"swing",
|
847 |
+
"electrical switch",
|
848 |
+
"syringe",
|
849 |
+
"table lamp",
|
850 |
+
"tank",
|
851 |
+
"tape player",
|
852 |
+
"teapot",
|
853 |
+
"teddy bear",
|
854 |
+
"television",
|
855 |
+
"tennis ball",
|
856 |
+
"thatched roof",
|
857 |
+
"front curtain",
|
858 |
+
"thimble",
|
859 |
+
"threshing machine",
|
860 |
+
"throne",
|
861 |
+
"tile roof",
|
862 |
+
"toaster",
|
863 |
+
"tobacco shop",
|
864 |
+
"toilet seat",
|
865 |
+
"torch",
|
866 |
+
"totem pole",
|
867 |
+
"tow truck",
|
868 |
+
"toy store",
|
869 |
+
"tractor",
|
870 |
+
"semi-trailer truck",
|
871 |
+
"tray",
|
872 |
+
"trench coat",
|
873 |
+
"tricycle",
|
874 |
+
"trimaran",
|
875 |
+
"tripod",
|
876 |
+
"triumphal arch",
|
877 |
+
"trolleybus",
|
878 |
+
"trombone",
|
879 |
+
"hot tub",
|
880 |
+
"turnstile",
|
881 |
+
"typewriter keyboard",
|
882 |
+
"umbrella",
|
883 |
+
"unicycle",
|
884 |
+
"upright piano",
|
885 |
+
"vacuum cleaner",
|
886 |
+
"vase",
|
887 |
+
"vaulted or arched ceiling",
|
888 |
+
"velvet fabric",
|
889 |
+
"vending machine",
|
890 |
+
"vestment",
|
891 |
+
"viaduct",
|
892 |
+
"violin",
|
893 |
+
"volleyball",
|
894 |
+
"waffle iron",
|
895 |
+
"wall clock",
|
896 |
+
"wallet",
|
897 |
+
"wardrobe",
|
898 |
+
"military aircraft",
|
899 |
+
"sink",
|
900 |
+
"washing machine",
|
901 |
+
"water bottle",
|
902 |
+
"water jug",
|
903 |
+
"water tower",
|
904 |
+
"whiskey jug",
|
905 |
+
"whistle",
|
906 |
+
"hair wig",
|
907 |
+
"window screen",
|
908 |
+
"window shade",
|
909 |
+
"Windsor tie",
|
910 |
+
"wine bottle",
|
911 |
+
"airplane wing",
|
912 |
+
"wok",
|
913 |
+
"wooden spoon",
|
914 |
+
"wool",
|
915 |
+
"split-rail fence",
|
916 |
+
"shipwreck",
|
917 |
+
"sailboat",
|
918 |
+
"yurt",
|
919 |
+
"website",
|
920 |
+
"comic book",
|
921 |
+
"crossword",
|
922 |
+
"traffic or street sign",
|
923 |
+
"traffic light",
|
924 |
+
"dust jacket",
|
925 |
+
"menu",
|
926 |
+
"plate",
|
927 |
+
"guacamole",
|
928 |
+
"consomme",
|
929 |
+
"hot pot",
|
930 |
+
"trifle",
|
931 |
+
"ice cream",
|
932 |
+
"popsicle",
|
933 |
+
"baguette",
|
934 |
+
"bagel",
|
935 |
+
"pretzel",
|
936 |
+
"cheeseburger",
|
937 |
+
"hot dog",
|
938 |
+
"mashed potatoes",
|
939 |
+
"cabbage",
|
940 |
+
"broccoli",
|
941 |
+
"cauliflower",
|
942 |
+
"zucchini",
|
943 |
+
"spaghetti squash",
|
944 |
+
"acorn squash",
|
945 |
+
"butternut squash",
|
946 |
+
"cucumber",
|
947 |
+
"artichoke",
|
948 |
+
"bell pepper",
|
949 |
+
"cardoon",
|
950 |
+
"mushroom",
|
951 |
+
"Granny Smith apple",
|
952 |
+
"strawberry",
|
953 |
+
"orange",
|
954 |
+
"lemon",
|
955 |
+
"fig",
|
956 |
+
"pineapple",
|
957 |
+
"banana",
|
958 |
+
"jackfruit",
|
959 |
+
"cherimoya (custard apple)",
|
960 |
+
"pomegranate",
|
961 |
+
"hay",
|
962 |
+
"carbonara",
|
963 |
+
"chocolate syrup",
|
964 |
+
"dough",
|
965 |
+
"meatloaf",
|
966 |
+
"pizza",
|
967 |
+
"pot pie",
|
968 |
+
"burrito",
|
969 |
+
"red wine",
|
970 |
+
"espresso",
|
971 |
+
"tea cup",
|
972 |
+
"eggnog",
|
973 |
+
"mountain",
|
974 |
+
"bubble",
|
975 |
+
"cliff",
|
976 |
+
"coral reef",
|
977 |
+
"geyser",
|
978 |
+
"lakeshore",
|
979 |
+
"promontory",
|
980 |
+
"sandbar",
|
981 |
+
"beach",
|
982 |
+
"valley",
|
983 |
+
"volcano",
|
984 |
+
"baseball player",
|
985 |
+
"bridegroom",
|
986 |
+
"scuba diver",
|
987 |
+
"rapeseed",
|
988 |
+
"daisy",
|
989 |
+
"yellow lady's slipper",
|
990 |
+
"corn",
|
991 |
+
"acorn",
|
992 |
+
"rose hip",
|
993 |
+
"horse chestnut seed",
|
994 |
+
"coral fungus",
|
995 |
+
"agaric",
|
996 |
+
"gyromitra",
|
997 |
+
"stinkhorn mushroom",
|
998 |
+
"earth star fungus",
|
999 |
+
"hen of the woods mushroom",
|
1000 |
+
"bolete",
|
1001 |
+
"corn cob",
|
1002 |
+
"toilet paper",
|
1003 |
+
]
|
1004 |
+
# Maps numeric class ids to labels
|
1005 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL = dict(
|
1006 |
+
zip(range(len(openai_imagenet_classnames)), openai_imagenet_classnames)
|
1007 |
+
)
|
open_flamingo/open_flamingo/eval/ok_vqa_utils.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Those are manual mapping that are not caught by our stemming rules or would
|
2 |
+
# would be done incorrectly by our automatic stemming rule. In details,
|
3 |
+
# the keys of the _MANUAL_MATCHES dict contains the original word and the value
|
4 |
+
# contains the transformation of the word expected by the OKVQA stemming rule.
|
5 |
+
# These manual rules were found by checking the `raw_answers` and the `answers`
|
6 |
+
# fields of the released OKVQA dataset and checking all things that were not
|
7 |
+
# properly mapped by our automatic rules. In particular some of the mapping
|
8 |
+
# are sometimes constant, e.g. christmas -> christmas which was incorrectly
|
9 |
+
# singularized by our inflection.singularize.
|
10 |
+
import re
|
11 |
+
import nltk
|
12 |
+
from nltk.corpus.reader import VERB
|
13 |
+
import inflection
|
14 |
+
|
15 |
+
_MANUAL_MATCHES = {
|
16 |
+
"police": "police",
|
17 |
+
"las": "las",
|
18 |
+
"vegas": "vegas",
|
19 |
+
"yes": "yes",
|
20 |
+
"jeans": "jean",
|
21 |
+
"hell's": "hell",
|
22 |
+
"domino's": "domino",
|
23 |
+
"morning": "morn",
|
24 |
+
"clothes": "cloth",
|
25 |
+
"are": "are",
|
26 |
+
"riding": "ride",
|
27 |
+
"leaves": "leaf",
|
28 |
+
"dangerous": "danger",
|
29 |
+
"clothing": "cloth",
|
30 |
+
"texting": "text",
|
31 |
+
"kiting": "kite",
|
32 |
+
"firefighters": "firefight",
|
33 |
+
"ties": "tie",
|
34 |
+
"married": "married",
|
35 |
+
"teething": "teeth",
|
36 |
+
"gloves": "glove",
|
37 |
+
"tennis": "tennis",
|
38 |
+
"dining": "dine",
|
39 |
+
"directions": "direct",
|
40 |
+
"waves": "wave",
|
41 |
+
"christmas": "christmas",
|
42 |
+
"drives": "drive",
|
43 |
+
"pudding": "pud",
|
44 |
+
"coding": "code",
|
45 |
+
"plating": "plate",
|
46 |
+
"quantas": "quanta",
|
47 |
+
"hornes": "horn",
|
48 |
+
"graves": "grave",
|
49 |
+
"mating": "mate",
|
50 |
+
"paned": "pane",
|
51 |
+
"alertness": "alert",
|
52 |
+
"sunbathing": "sunbath",
|
53 |
+
"tenning": "ten",
|
54 |
+
"wetness": "wet",
|
55 |
+
"urinating": "urine",
|
56 |
+
"sickness": "sick",
|
57 |
+
"braves": "brave",
|
58 |
+
"firefighting": "firefight",
|
59 |
+
"lenses": "lens",
|
60 |
+
"reflections": "reflect",
|
61 |
+
"backpackers": "backpack",
|
62 |
+
"eatting": "eat",
|
63 |
+
"designers": "design",
|
64 |
+
"curiousity": "curious",
|
65 |
+
"playfulness": "play",
|
66 |
+
"blindness": "blind",
|
67 |
+
"hawke": "hawk",
|
68 |
+
"tomatoe": "tomato",
|
69 |
+
"rodeoing": "rodeo",
|
70 |
+
"brightness": "bright",
|
71 |
+
"circuses": "circus",
|
72 |
+
"skateboarders": "skateboard",
|
73 |
+
"staring": "stare",
|
74 |
+
"electronics": "electron",
|
75 |
+
"electicity": "elect",
|
76 |
+
"mountainous": "mountain",
|
77 |
+
"socializing": "social",
|
78 |
+
"hamburgers": "hamburg",
|
79 |
+
"caves": "cave",
|
80 |
+
"transitions": "transit",
|
81 |
+
"wading": "wade",
|
82 |
+
"creame": "cream",
|
83 |
+
"toileting": "toilet",
|
84 |
+
"sautee": "saute",
|
85 |
+
"buildings": "build",
|
86 |
+
"belongings": "belong",
|
87 |
+
"stockings": "stock",
|
88 |
+
"walle": "wall",
|
89 |
+
"cumulis": "cumuli",
|
90 |
+
"travelers": "travel",
|
91 |
+
"conducter": "conduct",
|
92 |
+
"browsing": "brows",
|
93 |
+
"pooping": "poop",
|
94 |
+
"haircutting": "haircut",
|
95 |
+
"toppings": "top",
|
96 |
+
"hearding": "heard",
|
97 |
+
"sunblocker": "sunblock",
|
98 |
+
"bases": "base",
|
99 |
+
"markings": "mark",
|
100 |
+
"mopeds": "mope",
|
101 |
+
"kindergartener": "kindergarten",
|
102 |
+
"pies": "pie",
|
103 |
+
"scrapbooking": "scrapbook",
|
104 |
+
"couponing": "coupon",
|
105 |
+
"meetings": "meet",
|
106 |
+
"elevators": "elev",
|
107 |
+
"lowes": "low",
|
108 |
+
"men's": "men",
|
109 |
+
"childrens": "children",
|
110 |
+
"shelves": "shelve",
|
111 |
+
"paintings": "paint",
|
112 |
+
"raines": "rain",
|
113 |
+
"paring": "pare",
|
114 |
+
"expressions": "express",
|
115 |
+
"routes": "rout",
|
116 |
+
"pease": "peas",
|
117 |
+
"vastness": "vast",
|
118 |
+
"awning": "awn",
|
119 |
+
"boy's": "boy",
|
120 |
+
"drunkenness": "drunken",
|
121 |
+
"teasing": "teas",
|
122 |
+
"conferences": "confer",
|
123 |
+
"ripeness": "ripe",
|
124 |
+
"suspenders": "suspend",
|
125 |
+
"earnings": "earn",
|
126 |
+
"reporters": "report",
|
127 |
+
"kid's": "kid",
|
128 |
+
"containers": "contain",
|
129 |
+
"corgie": "corgi",
|
130 |
+
"porche": "porch",
|
131 |
+
"microwaves": "microwave",
|
132 |
+
"batter's": "batter",
|
133 |
+
"sadness": "sad",
|
134 |
+
"apartments": "apart",
|
135 |
+
"oxygenize": "oxygen",
|
136 |
+
"striping": "stripe",
|
137 |
+
"purring": "pure",
|
138 |
+
"professionals": "profession",
|
139 |
+
"piping": "pipe",
|
140 |
+
"farmer's": "farmer",
|
141 |
+
"potatoe": "potato",
|
142 |
+
"emirates": "emir",
|
143 |
+
"womens": "women",
|
144 |
+
"veteran's": "veteran",
|
145 |
+
"wilderness": "wilder",
|
146 |
+
"propellers": "propel",
|
147 |
+
"alpes": "alp",
|
148 |
+
"charioteering": "chariot",
|
149 |
+
"swining": "swine",
|
150 |
+
"illness": "ill",
|
151 |
+
"crepte": "crept",
|
152 |
+
"adhesives": "adhesive",
|
153 |
+
"regent's": "regent",
|
154 |
+
"decorations": "decor",
|
155 |
+
"rabbies": "rabbi",
|
156 |
+
"overseas": "oversea",
|
157 |
+
"travellers": "travel",
|
158 |
+
"casings": "case",
|
159 |
+
"smugness": "smug",
|
160 |
+
"doves": "dove",
|
161 |
+
"nationals": "nation",
|
162 |
+
"mustange": "mustang",
|
163 |
+
"ringe": "ring",
|
164 |
+
"gondoliere": "gondolier",
|
165 |
+
"vacationing": "vacate",
|
166 |
+
"reminders": "remind",
|
167 |
+
"baldness": "bald",
|
168 |
+
"settings": "set",
|
169 |
+
"glaced": "glace",
|
170 |
+
"coniferous": "conifer",
|
171 |
+
"revelations": "revel",
|
172 |
+
"personals": "person",
|
173 |
+
"daughter's": "daughter",
|
174 |
+
"badness": "bad",
|
175 |
+
"projections": "project",
|
176 |
+
"polarizing": "polar",
|
177 |
+
"vandalizers": "vandal",
|
178 |
+
"minerals": "miner",
|
179 |
+
"protesters": "protest",
|
180 |
+
"controllers": "control",
|
181 |
+
"weddings": "wed",
|
182 |
+
"sometimes": "sometime",
|
183 |
+
"earing": "ear",
|
184 |
+
}
|
185 |
+
|
186 |
+
|
187 |
+
class OKVQAStemmer:
|
188 |
+
"""Stemmer to match OKVQA v1.1 procedure."""
|
189 |
+
|
190 |
+
def __init__(self):
|
191 |
+
self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer()
|
192 |
+
|
193 |
+
def stem(self, input_string):
|
194 |
+
"""Apply stemming."""
|
195 |
+
word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string))
|
196 |
+
stemmed_words = []
|
197 |
+
for w, p in word_and_pos:
|
198 |
+
if w in _MANUAL_MATCHES:
|
199 |
+
w = _MANUAL_MATCHES[w]
|
200 |
+
elif w.endswith("ing"):
|
201 |
+
w = self._wordnet_lemmatizer.lemmatize(w, VERB)
|
202 |
+
elif p.startswith("NNS") or p.startswith("NNPS"):
|
203 |
+
w = inflection.singularize(w)
|
204 |
+
stemmed_words.append(w)
|
205 |
+
return " ".join(stemmed_words)
|
206 |
+
|
207 |
+
|
208 |
+
stemmer = OKVQAStemmer()
|
209 |
+
|
210 |
+
|
211 |
+
def postprocess_ok_vqa_generation(prediction) -> str:
|
212 |
+
prediction_stem = stemmer.stem(prediction)
|
213 |
+
return prediction_stem
|
open_flamingo/open_flamingo/eval/vqa_metric.py
ADDED
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
|
9 |
+
# Interface for accessing the VQA dataset.
|
10 |
+
|
11 |
+
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
12 |
+
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
|
13 |
+
|
14 |
+
# The following functions are defined:
|
15 |
+
# VQA - VQA class that loads VQA annotation file and prepares data structures.
|
16 |
+
# getQuesIds - Get question ids that satisfy given filter conditions.
|
17 |
+
# getImgIds - Get image ids that satisfy given filter conditions.
|
18 |
+
# loadQA - Load questions and answers with the specified question ids.
|
19 |
+
# showQA - Display the specified questions and answers.
|
20 |
+
# loadRes - Load result file and create result object.
|
21 |
+
|
22 |
+
# Help on each function can be accessed by: "help(COCO.function)"
|
23 |
+
|
24 |
+
|
25 |
+
class VQA:
|
26 |
+
def __init__(self, annotation_file=None, question_file=None):
|
27 |
+
"""
|
28 |
+
Constructor of VQA helper class for reading and visualizing questions and answers.
|
29 |
+
:param annotation_file (str): location of VQA annotation file
|
30 |
+
:return:
|
31 |
+
"""
|
32 |
+
# load dataset
|
33 |
+
self.dataset = {}
|
34 |
+
self.questions = {}
|
35 |
+
self.qa = {}
|
36 |
+
self.qqa = {}
|
37 |
+
self.imgToQA = {}
|
38 |
+
if not annotation_file == None and not question_file == None:
|
39 |
+
print("loading VQA annotations and questions into memory...")
|
40 |
+
time_t = datetime.datetime.utcnow()
|
41 |
+
dataset = json.load(open(annotation_file, "r"))
|
42 |
+
questions = json.load(open(question_file, "r"))
|
43 |
+
print(datetime.datetime.utcnow() - time_t)
|
44 |
+
self.dataset = dataset
|
45 |
+
self.questions = questions
|
46 |
+
self.createIndex()
|
47 |
+
|
48 |
+
def createIndex(self):
|
49 |
+
# create index
|
50 |
+
print("creating index...")
|
51 |
+
imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
|
52 |
+
qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
|
53 |
+
qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
|
54 |
+
for ann in self.dataset["annotations"]:
|
55 |
+
imgToQA[ann["image_id"]] += [ann]
|
56 |
+
qa[ann["question_id"]] = ann
|
57 |
+
for ques in self.questions["questions"]:
|
58 |
+
qqa[ques["question_id"]] = ques
|
59 |
+
print("index created!")
|
60 |
+
|
61 |
+
# create class members
|
62 |
+
self.qa = qa
|
63 |
+
self.qqa = qqa
|
64 |
+
self.imgToQA = imgToQA
|
65 |
+
|
66 |
+
def info(self):
|
67 |
+
"""
|
68 |
+
Print information about the VQA annotation file.
|
69 |
+
:return:
|
70 |
+
"""
|
71 |
+
for key, value in self.dataset["info"].items():
|
72 |
+
print("%s: %s" % (key, value))
|
73 |
+
|
74 |
+
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
|
75 |
+
"""
|
76 |
+
Get question ids that satisfy given filter conditions. default skips that filter
|
77 |
+
:param imgIds (int array) : get question ids for given imgs
|
78 |
+
quesTypes (str array) : get question ids for given question types
|
79 |
+
ansTypes (str array) : get question ids for given answer types
|
80 |
+
:return: ids (int array) : integer array of question ids
|
81 |
+
"""
|
82 |
+
imgIds = imgIds if type(imgIds) == list else [imgIds]
|
83 |
+
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
84 |
+
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
85 |
+
|
86 |
+
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
|
87 |
+
anns = self.dataset["annotations"]
|
88 |
+
else:
|
89 |
+
if not len(imgIds) == 0:
|
90 |
+
anns = sum(
|
91 |
+
[self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
|
92 |
+
[],
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
anns = self.dataset["annotations"]
|
96 |
+
anns = (
|
97 |
+
anns
|
98 |
+
if len(quesTypes) == 0
|
99 |
+
else [ann for ann in anns if ann["question_type"] in quesTypes]
|
100 |
+
)
|
101 |
+
anns = (
|
102 |
+
anns
|
103 |
+
if len(ansTypes) == 0
|
104 |
+
else [ann for ann in anns if ann["answer_type"] in ansTypes]
|
105 |
+
)
|
106 |
+
ids = [ann["question_id"] for ann in anns]
|
107 |
+
return ids
|
108 |
+
|
109 |
+
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
|
110 |
+
"""
|
111 |
+
Get image ids that satisfy given filter conditions. default skips that filter
|
112 |
+
:param quesIds (int array) : get image ids for given question ids
|
113 |
+
quesTypes (str array) : get image ids for given question types
|
114 |
+
ansTypes (str array) : get image ids for given answer types
|
115 |
+
:return: ids (int array) : integer array of image ids
|
116 |
+
"""
|
117 |
+
quesIds = quesIds if type(quesIds) == list else [quesIds]
|
118 |
+
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
119 |
+
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
120 |
+
|
121 |
+
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
|
122 |
+
anns = self.dataset["annotations"]
|
123 |
+
else:
|
124 |
+
if not len(quesIds) == 0:
|
125 |
+
anns = sum(
|
126 |
+
[self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
anns = self.dataset["annotations"]
|
130 |
+
anns = (
|
131 |
+
anns
|
132 |
+
if len(quesTypes) == 0
|
133 |
+
else [ann for ann in anns if ann["question_type"] in quesTypes]
|
134 |
+
)
|
135 |
+
anns = (
|
136 |
+
anns
|
137 |
+
if len(ansTypes) == 0
|
138 |
+
else [ann for ann in anns if ann["answer_type"] in ansTypes]
|
139 |
+
)
|
140 |
+
ids = [ann["image_id"] for ann in anns]
|
141 |
+
return ids
|
142 |
+
|
143 |
+
def loadQA(self, ids=[]):
|
144 |
+
"""
|
145 |
+
Load questions and answers with the specified question ids.
|
146 |
+
:param ids (int array) : integer ids specifying question ids
|
147 |
+
:return: qa (object array) : loaded qa objects
|
148 |
+
"""
|
149 |
+
if type(ids) == list:
|
150 |
+
return [self.qa[id] for id in ids]
|
151 |
+
elif type(ids) == int:
|
152 |
+
return [self.qa[ids]]
|
153 |
+
|
154 |
+
def showQA(self, anns):
|
155 |
+
"""
|
156 |
+
Display the specified annotations.
|
157 |
+
:param anns (array of object): annotations to display
|
158 |
+
:return: None
|
159 |
+
"""
|
160 |
+
if len(anns) == 0:
|
161 |
+
return 0
|
162 |
+
for ann in anns:
|
163 |
+
quesId = ann["question_id"]
|
164 |
+
print("Question: %s" % (self.qqa[quesId]["question"]))
|
165 |
+
for ans in ann["answers"]:
|
166 |
+
print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
|
167 |
+
|
168 |
+
def loadRes(self, resFile, quesFile):
|
169 |
+
"""
|
170 |
+
Load result file and return a result object.
|
171 |
+
:param resFile (str) : file name of result file
|
172 |
+
:return: res (obj) : result api object
|
173 |
+
"""
|
174 |
+
res = VQA()
|
175 |
+
res.questions = json.load(open(quesFile))
|
176 |
+
res.dataset["info"] = copy.deepcopy(self.questions["info"])
|
177 |
+
res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
|
178 |
+
res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
|
179 |
+
res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
|
180 |
+
res.dataset["license"] = copy.deepcopy(self.questions["license"])
|
181 |
+
|
182 |
+
print("Loading and preparing results... ")
|
183 |
+
time_t = datetime.datetime.utcnow()
|
184 |
+
anns = json.load(open(resFile))
|
185 |
+
assert type(anns) == list, "results is not an array of objects"
|
186 |
+
annsQuesIds = [ann["question_id"] for ann in anns]
|
187 |
+
# print set of question ids that do not have corresponding annotations
|
188 |
+
|
189 |
+
# assert set(annsQuesIds) == set(self.getQuesIds()), \
|
190 |
+
# 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
|
191 |
+
for ann in anns:
|
192 |
+
quesId = ann["question_id"]
|
193 |
+
if res.dataset["task_type"] == "Multiple Choice":
|
194 |
+
assert (
|
195 |
+
ann["answer"] in self.qqa[quesId]["multiple_choices"]
|
196 |
+
), "predicted answer is not one of the multiple choices"
|
197 |
+
qaAnn = self.qa[quesId]
|
198 |
+
ann["image_id"] = qaAnn["image_id"]
|
199 |
+
ann["question_type"] = qaAnn["question_type"]
|
200 |
+
ann["answer_type"] = qaAnn["answer_type"]
|
201 |
+
print(
|
202 |
+
"DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
|
203 |
+
)
|
204 |
+
|
205 |
+
res.dataset["annotations"] = anns
|
206 |
+
res.createIndex()
|
207 |
+
return res
|
208 |
+
|
209 |
+
|
210 |
+
class VQAEval:
|
211 |
+
def __init__(self, vqa=None, vqaRes=None, n=2):
|
212 |
+
self.n = n
|
213 |
+
self.accuracy = {}
|
214 |
+
self.evalQA = {}
|
215 |
+
self.evalQuesType = {}
|
216 |
+
self.evalAnsType = {}
|
217 |
+
self.vqa = vqa
|
218 |
+
self.vqaRes = vqaRes
|
219 |
+
if vqaRes is not None:
|
220 |
+
self.params = {"question_id": vqaRes.getQuesIds()}
|
221 |
+
self.contractions = {
|
222 |
+
"aint": "ain't",
|
223 |
+
"arent": "aren't",
|
224 |
+
"cant": "can't",
|
225 |
+
"couldve": "could've",
|
226 |
+
"couldnt": "couldn't",
|
227 |
+
"couldn'tve": "couldn't've",
|
228 |
+
"couldnt've": "couldn't've",
|
229 |
+
"didnt": "didn't",
|
230 |
+
"doesnt": "doesn't",
|
231 |
+
"dont": "don't",
|
232 |
+
"hadnt": "hadn't",
|
233 |
+
"hadnt've": "hadn't've",
|
234 |
+
"hadn'tve": "hadn't've",
|
235 |
+
"hasnt": "hasn't",
|
236 |
+
"havent": "haven't",
|
237 |
+
"hed": "he'd",
|
238 |
+
"hed've": "he'd've",
|
239 |
+
"he'dve": "he'd've",
|
240 |
+
"hes": "he's",
|
241 |
+
"howd": "how'd",
|
242 |
+
"howll": "how'll",
|
243 |
+
"hows": "how's",
|
244 |
+
"Id've": "I'd've",
|
245 |
+
"I'dve": "I'd've",
|
246 |
+
"Im": "I'm",
|
247 |
+
"Ive": "I've",
|
248 |
+
"isnt": "isn't",
|
249 |
+
"itd": "it'd",
|
250 |
+
"itd've": "it'd've",
|
251 |
+
"it'dve": "it'd've",
|
252 |
+
"itll": "it'll",
|
253 |
+
"let's": "let's",
|
254 |
+
"maam": "ma'am",
|
255 |
+
"mightnt": "mightn't",
|
256 |
+
"mightnt've": "mightn't've",
|
257 |
+
"mightn'tve": "mightn't've",
|
258 |
+
"mightve": "might've",
|
259 |
+
"mustnt": "mustn't",
|
260 |
+
"mustve": "must've",
|
261 |
+
"neednt": "needn't",
|
262 |
+
"notve": "not've",
|
263 |
+
"oclock": "o'clock",
|
264 |
+
"oughtnt": "oughtn't",
|
265 |
+
"ow's'at": "'ow's'at",
|
266 |
+
"'ows'at": "'ow's'at",
|
267 |
+
"'ow'sat": "'ow's'at",
|
268 |
+
"shant": "shan't",
|
269 |
+
"shed've": "she'd've",
|
270 |
+
"she'dve": "she'd've",
|
271 |
+
"she's": "she's",
|
272 |
+
"shouldve": "should've",
|
273 |
+
"shouldnt": "shouldn't",
|
274 |
+
"shouldnt've": "shouldn't've",
|
275 |
+
"shouldn'tve": "shouldn't've",
|
276 |
+
"somebody'd": "somebodyd",
|
277 |
+
"somebodyd've": "somebody'd've",
|
278 |
+
"somebody'dve": "somebody'd've",
|
279 |
+
"somebodyll": "somebody'll",
|
280 |
+
"somebodys": "somebody's",
|
281 |
+
"someoned": "someone'd",
|
282 |
+
"someoned've": "someone'd've",
|
283 |
+
"someone'dve": "someone'd've",
|
284 |
+
"someonell": "someone'll",
|
285 |
+
"someones": "someone's",
|
286 |
+
"somethingd": "something'd",
|
287 |
+
"somethingd've": "something'd've",
|
288 |
+
"something'dve": "something'd've",
|
289 |
+
"somethingll": "something'll",
|
290 |
+
"thats": "that's",
|
291 |
+
"thered": "there'd",
|
292 |
+
"thered've": "there'd've",
|
293 |
+
"there'dve": "there'd've",
|
294 |
+
"therere": "there're",
|
295 |
+
"theres": "there's",
|
296 |
+
"theyd": "they'd",
|
297 |
+
"theyd've": "they'd've",
|
298 |
+
"they'dve": "they'd've",
|
299 |
+
"theyll": "they'll",
|
300 |
+
"theyre": "they're",
|
301 |
+
"theyve": "they've",
|
302 |
+
"twas": "'twas",
|
303 |
+
"wasnt": "wasn't",
|
304 |
+
"wed've": "we'd've",
|
305 |
+
"we'dve": "we'd've",
|
306 |
+
"weve": "we've",
|
307 |
+
"werent": "weren't",
|
308 |
+
"whatll": "what'll",
|
309 |
+
"whatre": "what're",
|
310 |
+
"whats": "what's",
|
311 |
+
"whatve": "what've",
|
312 |
+
"whens": "when's",
|
313 |
+
"whered": "where'd",
|
314 |
+
"wheres": "where's",
|
315 |
+
"whereve": "where've",
|
316 |
+
"whod": "who'd",
|
317 |
+
"whod've": "who'd've",
|
318 |
+
"who'dve": "who'd've",
|
319 |
+
"wholl": "who'll",
|
320 |
+
"whos": "who's",
|
321 |
+
"whove": "who've",
|
322 |
+
"whyll": "why'll",
|
323 |
+
"whyre": "why're",
|
324 |
+
"whys": "why's",
|
325 |
+
"wont": "won't",
|
326 |
+
"wouldve": "would've",
|
327 |
+
"wouldnt": "wouldn't",
|
328 |
+
"wouldnt've": "wouldn't've",
|
329 |
+
"wouldn'tve": "wouldn't've",
|
330 |
+
"yall": "y'all",
|
331 |
+
"yall'll": "y'all'll",
|
332 |
+
"y'allll": "y'all'll",
|
333 |
+
"yall'd've": "y'all'd've",
|
334 |
+
"y'alld've": "y'all'd've",
|
335 |
+
"y'all'dve": "y'all'd've",
|
336 |
+
"youd": "you'd",
|
337 |
+
"youd've": "you'd've",
|
338 |
+
"you'dve": "you'd've",
|
339 |
+
"youll": "you'll",
|
340 |
+
"youre": "you're",
|
341 |
+
"youve": "you've",
|
342 |
+
}
|
343 |
+
self.manualMap = {
|
344 |
+
"none": "0",
|
345 |
+
"zero": "0",
|
346 |
+
"one": "1",
|
347 |
+
"two": "2",
|
348 |
+
"three": "3",
|
349 |
+
"four": "4",
|
350 |
+
"five": "5",
|
351 |
+
"six": "6",
|
352 |
+
"seven": "7",
|
353 |
+
"eight": "8",
|
354 |
+
"nine": "9",
|
355 |
+
"ten": "10",
|
356 |
+
}
|
357 |
+
self.articles = ["a", "an", "the"]
|
358 |
+
|
359 |
+
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
360 |
+
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
361 |
+
self.punct = [
|
362 |
+
";",
|
363 |
+
r"/",
|
364 |
+
"[",
|
365 |
+
"]",
|
366 |
+
'"',
|
367 |
+
"{",
|
368 |
+
"}",
|
369 |
+
"(",
|
370 |
+
")",
|
371 |
+
"=",
|
372 |
+
"+",
|
373 |
+
"\\",
|
374 |
+
"_",
|
375 |
+
"-",
|
376 |
+
">",
|
377 |
+
"<",
|
378 |
+
"@",
|
379 |
+
"`",
|
380 |
+
",",
|
381 |
+
"?",
|
382 |
+
"!",
|
383 |
+
]
|
384 |
+
|
385 |
+
def evaluate(self, quesIds=None):
|
386 |
+
if quesIds == None:
|
387 |
+
quesIds = [quesId for quesId in self.params["question_id"]]
|
388 |
+
gts = {}
|
389 |
+
res = {}
|
390 |
+
for quesId in quesIds:
|
391 |
+
gts[quesId] = self.vqa.qa[quesId]
|
392 |
+
res[quesId] = self.vqaRes.qa[quesId]
|
393 |
+
|
394 |
+
# =================================================
|
395 |
+
# Compute accuracy
|
396 |
+
# =================================================
|
397 |
+
accQA = []
|
398 |
+
accQuesType = {}
|
399 |
+
accAnsType = {}
|
400 |
+
print("computing accuracy")
|
401 |
+
step = 0
|
402 |
+
for quesId in quesIds:
|
403 |
+
for ansDic in gts[quesId]["answers"]:
|
404 |
+
ansDic["answer"] = ansDic["answer"].replace("\n", " ")
|
405 |
+
ansDic["answer"] = ansDic["answer"].replace("\t", " ")
|
406 |
+
ansDic["answer"] = ansDic["answer"].strip()
|
407 |
+
resAns = res[quesId]["answer"]
|
408 |
+
resAns = resAns.replace("\n", " ")
|
409 |
+
resAns = resAns.replace("\t", " ")
|
410 |
+
resAns = resAns.strip()
|
411 |
+
gtAcc = []
|
412 |
+
gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
|
413 |
+
|
414 |
+
if len(set(gtAnswers)) > 1:
|
415 |
+
for ansDic in gts[quesId]["answers"]:
|
416 |
+
ansDic["answer"] = self.processPunctuation(ansDic["answer"])
|
417 |
+
ansDic["answer"] = self.processDigitArticle(ansDic["answer"])
|
418 |
+
resAns = self.processPunctuation(resAns)
|
419 |
+
resAns = self.processDigitArticle(resAns)
|
420 |
+
|
421 |
+
for gtAnsDatum in gts[quesId]["answers"]:
|
422 |
+
otherGTAns = [
|
423 |
+
item for item in gts[quesId]["answers"] if item != gtAnsDatum
|
424 |
+
]
|
425 |
+
matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
|
426 |
+
acc = min(1, float(len(matchingAns)) / 3)
|
427 |
+
gtAcc.append(acc)
|
428 |
+
quesType = gts[quesId]["question_type"]
|
429 |
+
ansType = gts[quesId]["answer_type"]
|
430 |
+
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
|
431 |
+
accQA.append(avgGTAcc)
|
432 |
+
if quesType not in accQuesType:
|
433 |
+
accQuesType[quesType] = []
|
434 |
+
accQuesType[quesType].append(avgGTAcc)
|
435 |
+
if ansType not in accAnsType:
|
436 |
+
accAnsType[ansType] = []
|
437 |
+
accAnsType[ansType].append(avgGTAcc)
|
438 |
+
self.setEvalQA(quesId, avgGTAcc)
|
439 |
+
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
440 |
+
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
441 |
+
if step % 100 == 0:
|
442 |
+
self.updateProgress(step / float(len(quesIds)))
|
443 |
+
step = step + 1
|
444 |
+
|
445 |
+
self.setAccuracy(accQA, accQuesType, accAnsType)
|
446 |
+
print("Done computing accuracy")
|
447 |
+
|
448 |
+
def processPunctuation(self, inText):
|
449 |
+
outText = inText
|
450 |
+
for p in self.punct:
|
451 |
+
if (p + " " in inText or " " + p in inText) or (
|
452 |
+
re.search(self.commaStrip, inText) != None
|
453 |
+
):
|
454 |
+
outText = outText.replace(p, "")
|
455 |
+
else:
|
456 |
+
outText = outText.replace(p, " ")
|
457 |
+
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
458 |
+
return outText
|
459 |
+
|
460 |
+
def processDigitArticle(self, inText):
|
461 |
+
outText = []
|
462 |
+
tempText = inText.lower().split()
|
463 |
+
for word in tempText:
|
464 |
+
word = self.manualMap.setdefault(word, word)
|
465 |
+
if word not in self.articles:
|
466 |
+
outText.append(word)
|
467 |
+
else:
|
468 |
+
pass
|
469 |
+
for wordId, word in enumerate(outText):
|
470 |
+
if word in self.contractions:
|
471 |
+
outText[wordId] = self.contractions[word]
|
472 |
+
outText = " ".join(outText)
|
473 |
+
return outText
|
474 |
+
|
475 |
+
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
476 |
+
self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
|
477 |
+
self.accuracy["perQuestionType"] = {
|
478 |
+
quesType: round(
|
479 |
+
100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
|
480 |
+
self.n,
|
481 |
+
)
|
482 |
+
for quesType in accQuesType
|
483 |
+
}
|
484 |
+
self.accuracy["perAnswerType"] = {
|
485 |
+
ansType: round(
|
486 |
+
100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
|
487 |
+
)
|
488 |
+
for ansType in accAnsType
|
489 |
+
}
|
490 |
+
|
491 |
+
def setEvalQA(self, quesId, acc):
|
492 |
+
self.evalQA[quesId] = round(100 * acc, self.n)
|
493 |
+
|
494 |
+
def setEvalQuesType(self, quesId, quesType, acc):
|
495 |
+
if quesType not in self.evalQuesType:
|
496 |
+
self.evalQuesType[quesType] = {}
|
497 |
+
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
|
498 |
+
|
499 |
+
def setEvalAnsType(self, quesId, ansType, acc):
|
500 |
+
if ansType not in self.evalAnsType:
|
501 |
+
self.evalAnsType[ansType] = {}
|
502 |
+
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
|
503 |
+
|
504 |
+
def updateProgress(self, progress):
|
505 |
+
barLength = 20
|
506 |
+
status = ""
|
507 |
+
if isinstance(progress, int):
|
508 |
+
progress = float(progress)
|
509 |
+
if not isinstance(progress, float):
|
510 |
+
progress = 0
|
511 |
+
status = "error: progress var must be float\r\n"
|
512 |
+
if progress < 0:
|
513 |
+
progress = 0
|
514 |
+
status = "Halt...\r\n"
|
515 |
+
if progress >= 1:
|
516 |
+
progress = 1
|
517 |
+
status = "Done...\r\n"
|
518 |
+
block = int(round(barLength * progress))
|
519 |
+
text = "\rFinshed Percent: [{0}] {1}% {2}".format(
|
520 |
+
"#" * block + "-" * (barLength - block), int(progress * 100), status
|
521 |
+
)
|
522 |
+
sys.stdout.write(text)
|
523 |
+
sys.stdout.flush()
|
524 |
+
|
525 |
+
|
526 |
+
def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path, vqa_dataset):
|
527 |
+
"""Compute the VQA accuracy metric.
|
528 |
+
|
529 |
+
Args:
|
530 |
+
predictions (List): list of predictions
|
531 |
+
ground_truth (List[List]): list of all possible ground truth answers
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
float: VQA accuracy
|
535 |
+
"""
|
536 |
+
# coding: utf-8
|
537 |
+
# dataDir = data_dir
|
538 |
+
|
539 |
+
# set up file names and paths
|
540 |
+
# versionType = 'v2_' # this should be '' when using VQA v2.0 dataset
|
541 |
+
# 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
|
542 |
+
# taskType = 'OpenEnded'
|
543 |
+
# 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
|
544 |
+
# dataType = 'mscoco'
|
545 |
+
# dataSubType = 'train2014'
|
546 |
+
# annFile = '%s/%s%s_%s_annotations.json' % (
|
547 |
+
# dataDir, versionType, dataType, dataSubType)
|
548 |
+
# quesFile = '%s/%s%s_%s_%s_questions.json' % (
|
549 |
+
# dataDir, versionType, taskType, dataType, dataSubType)
|
550 |
+
# imgDir = '%s/%s/%s/' % (dataDir, dataType, dataSubType)
|
551 |
+
# resultType = res_file_name
|
552 |
+
# fileTypes = ['results', 'accuracy',
|
553 |
+
# 'evalQA', 'evalQuesType', 'evalAnsType']
|
554 |
+
|
555 |
+
# An example result json file has been provided in './Results' folder.
|
556 |
+
|
557 |
+
# [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType,
|
558 |
+
# resultType, fileType) for fileType in fileTypes]
|
559 |
+
|
560 |
+
# create vqa object and vqaRes object
|
561 |
+
vqa = VQA(annotation_json_path, question_json_path)
|
562 |
+
vqaRes = vqa.loadRes(result_json_path, question_json_path)
|
563 |
+
|
564 |
+
# create vqaEval object by taking vqa and vqaRes
|
565 |
+
# n is precision of accuracy (number of places after decimal), default is 2
|
566 |
+
vqaEval = VQAEval(vqa, vqaRes, n=2)
|
567 |
+
|
568 |
+
# evaluate results
|
569 |
+
"""
|
570 |
+
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
|
571 |
+
By default it uses all the question ids in annotation file
|
572 |
+
"""
|
573 |
+
vqaEval.evaluate()
|
574 |
+
|
575 |
+
return vqaEval.accuracy["overall"]
|
576 |
+
|
577 |
+
|
578 |
+
def postprocess_vqa_generation(predictions):
|
579 |
+
return re.split("Question|Answer", predictions, 1)[0]
|
580 |
+
|
581 |
+
|
582 |
+
def compute_gqa_accuracy(results):
|
583 |
+
acc = []
|
584 |
+
vqa_tool = VQAEval()
|
585 |
+
|
586 |
+
for res in results:
|
587 |
+
gt_ans = res["answers"]
|
588 |
+
pred = res["answer"]
|
589 |
+
pred = vqa_tool.processPunctuation(pred)
|
590 |
+
pred = vqa_tool.processDigitArticle(pred)
|
591 |
+
vqa_acc = 1 if pred == gt_ans else 0
|
592 |
+
acc.append(vqa_acc)
|
593 |
+
accuracy = sum(acc) / len(acc)
|
594 |
+
return accuracy
|
open_flamingo/open_flamingo/src/__init__.py
ADDED
File without changes
|
open_flamingo/open_flamingo/src/factory.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer,AutoConfig
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .flamingo import Flamingo
|
6 |
+
from .flamingo_lm import FlamingoLMMixin
|
7 |
+
from .utils import extend_instance
|
8 |
+
import logging
|
9 |
+
import random
|
10 |
+
import time
|
11 |
+
|
12 |
+
def create_model_and_transforms(
|
13 |
+
clip_vision_encoder_path: str,
|
14 |
+
clip_vision_encoder_pretrained: str,
|
15 |
+
lang_encoder_path: str,
|
16 |
+
tokenizer_path: str,
|
17 |
+
use_local_files: bool = False,
|
18 |
+
decoder_layers_attr_name: str = None,
|
19 |
+
location_token_num: int = 1000,
|
20 |
+
checkpoint_activations: bool = False,
|
21 |
+
freeze_vision_encoder: bool = False,
|
22 |
+
add_visual_grounding: bool = False,
|
23 |
+
lora: bool = False,
|
24 |
+
lora_r: int = 16,
|
25 |
+
fix_ffn: bool = False,
|
26 |
+
add_visual_token: bool = False,
|
27 |
+
use_format_v2: bool = False,
|
28 |
+
use_sam: str = None,
|
29 |
+
**flamingo_kwargs,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
|
33 |
+
Appends special tokens to the tokenizer and freezes backbones.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
|
37 |
+
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
|
38 |
+
lang_encoder_path (str): path to pretrained language encoder
|
39 |
+
tokenizer_path (str): path to pretrained tokenizer
|
40 |
+
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
|
41 |
+
use_local_files (bool, optional): whether to use local files. Defaults to False.
|
42 |
+
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
|
43 |
+
Returns:
|
44 |
+
Flamingo: Flamingo model from pretrained vision and language encoders
|
45 |
+
Image processor: Pipeline to preprocess input images
|
46 |
+
Tokenizer: A tokenizer for the language model
|
47 |
+
"""
|
48 |
+
if use_sam is None:
|
49 |
+
no_success = True
|
50 |
+
while no_success:
|
51 |
+
try:
|
52 |
+
vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
|
53 |
+
clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
|
54 |
+
)
|
55 |
+
no_success = False
|
56 |
+
except:
|
57 |
+
logging.info("retry creating vision_encoder")
|
58 |
+
time.sleep(random.random() * 5)
|
59 |
+
|
60 |
+
# set the vision encoder to output the visual features
|
61 |
+
vision_encoder.visual.output_tokens = True
|
62 |
+
# delete text encoder part
|
63 |
+
del vision_encoder.transformer
|
64 |
+
del vision_encoder.text_projection
|
65 |
+
del vision_encoder.token_embedding
|
66 |
+
del vision_encoder.ln_final
|
67 |
+
del vision_encoder.positional_embedding
|
68 |
+
del vision_encoder.logit_scale
|
69 |
+
vision_encoder.visual.proj = None
|
70 |
+
vision_encoder.visual.ln_post = torch.nn.Identity()
|
71 |
+
else:
|
72 |
+
from segment_anything import SamPredictor, sam_model_registry
|
73 |
+
assert use_sam == "vit_l"
|
74 |
+
sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth")
|
75 |
+
del sam.prompt_encoder
|
76 |
+
del sam.mask_decoder
|
77 |
+
sam.image_encoder.neck = torch.nn.Identity()
|
78 |
+
vision_encoder = sam.image_encoder
|
79 |
+
from open_clip.transform import image_transform
|
80 |
+
image_processor = image_transform(
|
81 |
+
256,
|
82 |
+
is_train=False,
|
83 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
84 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
85 |
+
)
|
86 |
+
|
87 |
+
if "llama" in tokenizer_path:
|
88 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
89 |
+
text_tokenizer = LlamaTokenizer.from_pretrained(
|
90 |
+
tokenizer_path, local_files_only=use_local_files
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
text_tokenizer = AutoTokenizer.from_pretrained(
|
94 |
+
tokenizer_path, local_files_only=use_local_files
|
95 |
+
)
|
96 |
+
# add location special tokens to the tokenizer
|
97 |
+
location_token = ["<|#obj#|>", "<|#endofobj#|>"]
|
98 |
+
if use_format_v2:
|
99 |
+
location_token.append("<|#loc#|>")
|
100 |
+
location_token.append("<|#endofloc#|>")
|
101 |
+
for i in range(location_token_num):
|
102 |
+
location_token.append(f"<loc_{i}>")
|
103 |
+
# add Flamingo special tokens to the tokenizer
|
104 |
+
additional_special_tokens = ["<|endofchunk|>", "<|#image#|>", "<|#endofimage#|>"]
|
105 |
+
if add_visual_grounding:
|
106 |
+
additional_special_tokens += location_token
|
107 |
+
if add_visual_token:
|
108 |
+
additional_special_tokens += ["<|#visual#|>"]
|
109 |
+
text_tokenizer.add_special_tokens(
|
110 |
+
{"additional_special_tokens": additional_special_tokens}
|
111 |
+
)
|
112 |
+
if text_tokenizer.pad_token is None:
|
113 |
+
# Issue: GPT models don't have a pad token, which we use to
|
114 |
+
# modify labels for the loss.
|
115 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
116 |
+
if "llama" in lang_encoder_path:
|
117 |
+
lang_encoder = LlamaForCausalLM.from_pretrained(
|
118 |
+
lang_encoder_path, local_files_only=use_local_files
|
119 |
+
)
|
120 |
+
else:
|
121 |
+
# lang_encoder = AutoModelForCausalLM.from_pretrained(
|
122 |
+
# lang_encoder_path, local_files_only=use_local_files
|
123 |
+
# )
|
124 |
+
config_lang = AutoConfig.from_pretrained(lang_encoder_path)
|
125 |
+
lang_encoder = AutoModelForCausalLM.from_config(config_lang)
|
126 |
+
extend_instance(lang_encoder, FlamingoLMMixin)
|
127 |
+
|
128 |
+
if decoder_layers_attr_name is None:
|
129 |
+
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
130 |
+
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
131 |
+
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
132 |
+
lang_encoder_name = lang_encoder.__class__.__name__.lower()
|
133 |
+
if checkpoint_activations:
|
134 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
135 |
+
if use_sam is None:
|
136 |
+
for i in range(len(vision_encoder.visual.transformer.resblocks)):
|
137 |
+
vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper(
|
138 |
+
vision_encoder.visual.transformer.resblocks[i],
|
139 |
+
offload_to_cpu=False,
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
for i in range(len(vision_encoder.blocks)):
|
143 |
+
vision_encoder.blocks[i] = checkpoint_wrapper(
|
144 |
+
vision_encoder.blocks[i],
|
145 |
+
offload_to_cpu=False,
|
146 |
+
)
|
147 |
+
if "opt" in lang_encoder_name:
|
148 |
+
for i in range(len(lang_encoder.model.decoder.layers)):
|
149 |
+
lang_encoder.model.decoder.layers[i] = checkpoint_wrapper(
|
150 |
+
lang_encoder.model.decoder.layers[i],
|
151 |
+
offload_to_cpu=False,
|
152 |
+
)
|
153 |
+
elif "codegen" in lang_encoder_name:
|
154 |
+
for i in range(len(lang_encoder.transformer.h)):
|
155 |
+
lang_encoder.transformer.h[i] = checkpoint_wrapper(
|
156 |
+
lang_encoder.transformer.h[i],
|
157 |
+
offload_to_cpu=False,
|
158 |
+
)
|
159 |
+
elif "llama" in lang_encoder_name:
|
160 |
+
for i in range(len(lang_encoder.model.layers)):
|
161 |
+
lang_encoder.model.layers[i] = checkpoint_wrapper(
|
162 |
+
lang_encoder.model.layers[i],
|
163 |
+
offload_to_cpu=False,
|
164 |
+
)
|
165 |
+
else:
|
166 |
+
raise ValueError(f"unknown model {lang_encoder_name}")
|
167 |
+
if use_sam is None:
|
168 |
+
vis_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"]
|
169 |
+
image_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["image_size"]
|
170 |
+
patch_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["patch_size"]
|
171 |
+
else:
|
172 |
+
vis_dim = 1024
|
173 |
+
image_size = 256
|
174 |
+
patch_size = 16
|
175 |
+
assert image_size % patch_size == 0
|
176 |
+
vis_embed_size = (image_size // patch_size) ** 2
|
177 |
+
|
178 |
+
if lora:
|
179 |
+
raise NotImplementedError
|
180 |
+
from peft import LoraConfig, TaskType
|
181 |
+
from peft import get_peft_model
|
182 |
+
if "codegen" in lang_encoder_name:
|
183 |
+
lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
|
184 |
+
elif "opt" in lang_encoder_name:
|
185 |
+
lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj", "fc1", "fc2"]
|
186 |
+
elif "llama" in lang_encoder_name:
|
187 |
+
lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
|
188 |
+
else:
|
189 |
+
raise NotImplementedError
|
190 |
+
lang_peft_config = LoraConfig(
|
191 |
+
task_type="CAUSAL_LM",
|
192 |
+
r=lora_r, lora_alpha=lora_r,
|
193 |
+
target_modules=lang_target_modules,
|
194 |
+
lora_dropout=0.05, bias="none",
|
195 |
+
)
|
196 |
+
lang_encoder = get_peft_model(lang_encoder, lang_peft_config)
|
197 |
+
if "codegen" in lang_encoder_name:
|
198 |
+
raise NotImplementedError
|
199 |
+
elif "opt" in lang_encoder_name:
|
200 |
+
# lang_encoder.base_model.model.lm_head.requires_grad_(True)
|
201 |
+
# lang_encoder.base_model.model.model.decoder.embed_tokens.requires_grad_(True)
|
202 |
+
# lang_encoder.base_model.model.model.decoder.embed_positions.requires_grad_(True)
|
203 |
+
def activate_grad(m):
|
204 |
+
if not hasattr(m, "lora_A") and hasattr(m, "weight") and not m.weight.requires_grad:
|
205 |
+
print(m, "reactivate grad")
|
206 |
+
m.requires_grad_(True)
|
207 |
+
lang_encoder.base_model.model.apply(activate_grad)
|
208 |
+
elif "llama" in lang_encoder_name:
|
209 |
+
def activate_grad(m):
|
210 |
+
if not hasattr(m, "lora_A") and hasattr(m, "weight") and not m.weight.requires_grad:
|
211 |
+
print(m, "reactivate grad")
|
212 |
+
m.requires_grad_(True)
|
213 |
+
lang_encoder.base_model.model.apply(activate_grad)
|
214 |
+
else:
|
215 |
+
raise NotImplementedError
|
216 |
+
lang_encoder.print_trainable_parameters()
|
217 |
+
|
218 |
+
if fix_ffn:
|
219 |
+
if "opt" in lang_encoder_name:
|
220 |
+
for i in range(len(lang_encoder.model.decoder.layers)):
|
221 |
+
lang_encoder.model.decoder.layers[i].requires_grad_(False)
|
222 |
+
lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True)
|
223 |
+
else:
|
224 |
+
raise NotImplementedError
|
225 |
+
|
226 |
+
loc_token_ids = []
|
227 |
+
for i in range(location_token_num):
|
228 |
+
loc_token_ids.append(int(text_tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
229 |
+
min_loc_token_id = min(loc_token_ids)
|
230 |
+
max_loc_token_id = max(loc_token_ids)
|
231 |
+
lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size)
|
232 |
+
model = Flamingo(
|
233 |
+
vision_encoder=vision_encoder,
|
234 |
+
lang_encoder=lang_encoder,
|
235 |
+
eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1],
|
236 |
+
media_token_id=text_tokenizer.encode("<|#image#|>")[-1],
|
237 |
+
image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1],
|
238 |
+
visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None,
|
239 |
+
min_loc_token_id=min_loc_token_id,
|
240 |
+
vis_dim=vis_dim,
|
241 |
+
vis_embed_size=vis_embed_size,
|
242 |
+
lang_dim=lang_dim,
|
243 |
+
add_visual_token=add_visual_token,
|
244 |
+
**flamingo_kwargs,
|
245 |
+
)
|
246 |
+
|
247 |
+
if freeze_vision_encoder:
|
248 |
+
print("freeze vision encoder")
|
249 |
+
model.vision_encoder.requires_grad_(False)
|
250 |
+
|
251 |
+
print(
|
252 |
+
f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
|
253 |
+
)
|
254 |
+
|
255 |
+
return model, image_processor, text_tokenizer, vis_embed_size
|
256 |
+
|
257 |
+
|
258 |
+
def _infer_decoder_layers_attr_name(model):
|
259 |
+
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
|
260 |
+
if k.lower() in model.__class__.__name__.lower():
|
261 |
+
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
|
262 |
+
|
263 |
+
raise ValueError(
|
264 |
+
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
|
265 |
+
)
|
266 |
+
|
267 |
+
|
268 |
+
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
|
269 |
+
"opt": "model.decoder.layers",
|
270 |
+
"gptneo": "transformer.h",
|
271 |
+
"gptj": "transformer.h",
|
272 |
+
"gpt-j": "transformer.h",
|
273 |
+
"pythia": "gpt_neox.layers",
|
274 |
+
"llama": "model.layers",
|
275 |
+
"llamaforcausallm": "model.layers",
|
276 |
+
"gpt2": "transformer.h",
|
277 |
+
"codegen": "transformer.h",
|
278 |
+
}
|
open_flamingo/open_flamingo/src/flamingo.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class Flamingo(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
vision_encoder: nn.Module,
|
9 |
+
lang_encoder: nn.Module,
|
10 |
+
eoc_token_id: int,
|
11 |
+
media_token_id: int,
|
12 |
+
image_end_token_id: int,
|
13 |
+
visual_token_id: int,
|
14 |
+
min_loc_token_id: int,
|
15 |
+
vis_dim: int,
|
16 |
+
vis_embed_size: int,
|
17 |
+
lang_dim: int,
|
18 |
+
use_media_placement_augmentation: bool = False,
|
19 |
+
add_visual_token: bool = False,
|
20 |
+
):
|
21 |
+
"""
|
22 |
+
Args:
|
23 |
+
vision_encoder (nn.Module): HF CLIPModel
|
24 |
+
lang_encoder (nn.Module): HF causal language model
|
25 |
+
eoc_token_id (int): Token id for <|endofchunk|>
|
26 |
+
media_token_id (int): Token id for <|#image#|>
|
27 |
+
vis_dim (int): Dimension of the visual features.
|
28 |
+
Visual features are projected to match this shape along the last dimension.
|
29 |
+
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
|
30 |
+
use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
|
31 |
+
"""
|
32 |
+
super().__init__()
|
33 |
+
self.image_end_token_id = image_end_token_id
|
34 |
+
self.eoc_token_id = eoc_token_id
|
35 |
+
self.media_token_id = media_token_id
|
36 |
+
self.use_media_placement_augmentation = use_media_placement_augmentation
|
37 |
+
self.vis_dim = vis_dim
|
38 |
+
self.lang_dim = lang_dim
|
39 |
+
self.vis_proj = nn.Linear(self.vis_dim, self.lang_dim)
|
40 |
+
self.vision_encoder = vision_encoder
|
41 |
+
self.num_positions = vis_embed_size
|
42 |
+
self.lang_encoder = lang_encoder
|
43 |
+
self.lang_encoder.init_flamingo(
|
44 |
+
media_token_id=media_token_id,
|
45 |
+
use_media_placement_augmentation=self.use_media_placement_augmentation,
|
46 |
+
)
|
47 |
+
first_layer = self.lang_encoder._get_decoder_layers()[0]
|
48 |
+
first_layer.add_visual_token = add_visual_token
|
49 |
+
first_layer.visual_token_id = visual_token_id
|
50 |
+
first_layer.media_token_id = media_token_id
|
51 |
+
first_layer.min_loc_token_id = min_loc_token_id
|
52 |
+
|
53 |
+
|
54 |
+
def forward(
|
55 |
+
self,
|
56 |
+
vision_x: torch.Tensor,
|
57 |
+
lang_x: torch.Tensor,
|
58 |
+
attention_mask: torch.Tensor = None,
|
59 |
+
labels: torch.Tensor = None,
|
60 |
+
use_cached_vision_x: bool = False,
|
61 |
+
clear_conditioned_layers: bool = True,
|
62 |
+
past_key_values=None,
|
63 |
+
use_cache: bool = False,
|
64 |
+
image_nums=None,
|
65 |
+
image_start_index_list=None,
|
66 |
+
added_bbox_list=None,
|
67 |
+
added_visual_token_idx_list=None,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Forward pass of Flamingo.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
vision_x (torch.Tensor): Vision input
|
74 |
+
shape (B, T_img, F, C, H, W) with F=1
|
75 |
+
lang_x (torch.Tensor): Language input ids
|
76 |
+
shape (B, T_txt)
|
77 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
78 |
+
labels (torch.Tensor, optional): Labels. Defaults to None.
|
79 |
+
clear_conditioned_layers: if True, clear the conditioned layers
|
80 |
+
once the foward pass is completed. Set this to false if the
|
81 |
+
same set of images will be reused in another subsequent
|
82 |
+
forward pass.
|
83 |
+
past_key_values: pre-computed values to pass to language model.
|
84 |
+
See past_key_values documentation in Hugging Face
|
85 |
+
CausalLM models.
|
86 |
+
use_cache: whether to use cached key values. See use_cache
|
87 |
+
documentation in Hugging Face CausalLM models.
|
88 |
+
"""
|
89 |
+
|
90 |
+
if use_cached_vision_x:
|
91 |
+
# Case: use cached; vision_x should be cached and other
|
92 |
+
# vision-related inputs should not be provided.
|
93 |
+
assert (
|
94 |
+
vision_x is None
|
95 |
+
), "Expect vision_x to be None when use_cached_vision_x is True."
|
96 |
+
assert self.lang_encoder.is_conditioned()
|
97 |
+
|
98 |
+
else:
|
99 |
+
# Case: do not use caching (i.e. this is a standard forward pass);
|
100 |
+
self._encode_vision_x(
|
101 |
+
vision_x=vision_x,
|
102 |
+
image_nums=image_nums,
|
103 |
+
image_start_index_list=image_start_index_list,
|
104 |
+
added_bbox_list=added_bbox_list,
|
105 |
+
added_visual_token_idx_list=added_visual_token_idx_list,
|
106 |
+
)
|
107 |
+
output = self.lang_encoder(
|
108 |
+
input_ids=lang_x,
|
109 |
+
attention_mask=attention_mask,
|
110 |
+
labels=labels,
|
111 |
+
past_key_values=past_key_values,
|
112 |
+
use_cache=use_cache,
|
113 |
+
)
|
114 |
+
if vision_x is None:
|
115 |
+
output['loss'][0] += 0.0 * self.vis_proj(self.vision_encoder.visual(torch.randn(1, 3, 224, 224, device=lang_x.device, dtype=output['loss'].dtype))[1]).mean()
|
116 |
+
|
117 |
+
if clear_conditioned_layers:
|
118 |
+
self.lang_encoder.clear_conditioned_layers()
|
119 |
+
|
120 |
+
return output
|
121 |
+
|
122 |
+
def generate(
|
123 |
+
self,
|
124 |
+
vision_x: torch.Tensor,
|
125 |
+
lang_x: torch.Tensor,
|
126 |
+
attention_mask: torch.Tensor = None,
|
127 |
+
num_beams=1,
|
128 |
+
max_new_tokens=None,
|
129 |
+
temperature=1.0,
|
130 |
+
top_k=0,
|
131 |
+
top_p=1.0,
|
132 |
+
no_repeat_ngram_size=0,
|
133 |
+
prefix_allowed_tokens_fn=None,
|
134 |
+
length_penalty=1.0,
|
135 |
+
num_return_sequences=1,
|
136 |
+
do_sample=False,
|
137 |
+
early_stopping=False,
|
138 |
+
bad_words_ids=None,
|
139 |
+
force_words_ids=None,
|
140 |
+
image_start_index_list=None,
|
141 |
+
image_nums=None,
|
142 |
+
min_length=None,
|
143 |
+
output_scores=False,
|
144 |
+
add_grounding_to_prompt=False,
|
145 |
+
):
|
146 |
+
"""
|
147 |
+
Generate text conditioned on vision and language inputs.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
vision_x (torch.Tensor): Vision input
|
151 |
+
shape (B, T_img, F, C, H, W)
|
152 |
+
images in the same chunk are collated along T_img, and frames are collated along F
|
153 |
+
currently only F=1 is supported (single-frame videos)
|
154 |
+
lang_x (torch.Tensor): Language input
|
155 |
+
shape (B, T_txt)
|
156 |
+
max_length (int, optional): Maximum length of the output. Defaults to None.
|
157 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
158 |
+
num_beams (int, optional): Number of beams. Defaults to 1.
|
159 |
+
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
|
160 |
+
temperature (float, optional): Temperature. Defaults to 1.0.
|
161 |
+
top_k (int, optional): Top k. Defaults to 0.
|
162 |
+
top_p (float, optional): Top p. Defaults to 1.0.
|
163 |
+
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
|
164 |
+
length_penalty (float, optional): Length penalty. Defaults to 1.0.
|
165 |
+
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
|
166 |
+
do_sample (bool, optional): Do sample. Defaults to False.
|
167 |
+
early_stopping (bool, optional): Early stopping. Defaults to False.
|
168 |
+
Returns:
|
169 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
170 |
+
"""
|
171 |
+
if num_beams > 1:
|
172 |
+
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
|
173 |
+
image_start_index_list = torch.tensor(image_start_index_list).repeat_interleave(num_beams, dim=0).tolist()
|
174 |
+
image_nums = torch.tensor(image_nums).repeat_interleave(num_beams, dim=0).tolist()
|
175 |
+
|
176 |
+
self._encode_vision_x(vision_x=vision_x, image_nums=image_nums, image_start_index_list=image_start_index_list, num_beams=num_beams)
|
177 |
+
|
178 |
+
if add_grounding_to_prompt:
|
179 |
+
# raise NotImplementedError
|
180 |
+
from transformers import LogitsProcessor, LogitsProcessorList
|
181 |
+
class GroundingLogits(LogitsProcessor):
|
182 |
+
def __call__(self, input_ids, scores):
|
183 |
+
print("<|#loc#|> token (id: 50270) score:", scores[0, 50270].item(), round(scores[0, 50270].item()/scores.max().item(), 2))
|
184 |
+
print("max min mean:", scores.max().item(), scores.min().item(), scores.mean().item())
|
185 |
+
print("max prob token id:", scores[0].argmax().item())
|
186 |
+
# print(input_ids.shape, scores[0, 50270].item(), scores.max(), scores.min(), scores.mean(), scores[0].argmax())
|
187 |
+
return scores
|
188 |
+
logits_processor = LogitsProcessorList([GroundingLogits()])
|
189 |
+
else:
|
190 |
+
logits_processor = None
|
191 |
+
|
192 |
+
output = self.lang_encoder.generate(
|
193 |
+
input_ids=lang_x,
|
194 |
+
attention_mask=attention_mask,
|
195 |
+
eos_token_id=self.eoc_token_id,
|
196 |
+
num_beams=num_beams,
|
197 |
+
max_new_tokens=max_new_tokens,
|
198 |
+
min_length=min_length,
|
199 |
+
length_penalty=length_penalty,
|
200 |
+
return_dict_in_generate=output_scores,
|
201 |
+
output_scores=output_scores,
|
202 |
+
bad_words_ids=bad_words_ids,
|
203 |
+
logits_processor=logits_processor,
|
204 |
+
)
|
205 |
+
self.lang_encoder.clear_conditioned_layers()
|
206 |
+
return output
|
207 |
+
|
208 |
+
def _encode_vision_x(self, vision_x: torch.Tensor, image_nums=None, image_start_index_list=None, added_bbox_list=None, added_visual_token_idx_list=None, num_beams=None):
|
209 |
+
"""
|
210 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
211 |
+
Args:
|
212 |
+
vision_x (torch.Tensor): Vision input
|
213 |
+
shape (B, T_img, F, C, H, W)
|
214 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
215 |
+
Currently only F=1 is supported (single-frame videos)
|
216 |
+
|
217 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
218 |
+
"""
|
219 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
220 |
+
b, T, F = vision_x.shape[:3]
|
221 |
+
assert F == 1, "Only single frame supported"
|
222 |
+
|
223 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
224 |
+
if hasattr(self.vision_encoder, "visual"):
|
225 |
+
vision_x = self.vision_encoder.visual(vision_x)[1]
|
226 |
+
else:
|
227 |
+
vision_x = self.vision_encoder(vision_x).flatten(2).permute(0, 2, 1)
|
228 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
229 |
+
|
230 |
+
vision_x = vision_x.mean(2)
|
231 |
+
# vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
|
232 |
+
# vision_x = self.vis_proj(vision_x) + self.vis_position_embedding(self.vis_position_ids).unsqueeze(0)
|
233 |
+
vision_x = self.vis_proj(vision_x).squeeze(1)
|
234 |
+
|
235 |
+
first_layer = self.lang_encoder._get_decoder_layers()[0]
|
236 |
+
first_layer.condition_vis_x(vision_x, image_nums, image_start_index_list, added_bbox_list, added_visual_token_idx_list, num_beams=num_beams)
|
open_flamingo/open_flamingo/src/flamingo_lm.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from .helpers import GatedCrossAttentionBlock
|
7 |
+
from .utils import getattr_recursive, setattr_recursive
|
8 |
+
|
9 |
+
|
10 |
+
class FlamingoLayer(nn.Module):
|
11 |
+
def __init__(self, decoder_layer):
|
12 |
+
super().__init__()
|
13 |
+
self.decoder_layer = decoder_layer
|
14 |
+
self.vis_x = None
|
15 |
+
self.image_nums = None
|
16 |
+
self.image_start_index_list = None
|
17 |
+
self.media_locations = None
|
18 |
+
self.add_visual_token = False
|
19 |
+
self.input_ids = None
|
20 |
+
|
21 |
+
def is_conditioned(self) -> bool:
|
22 |
+
"""Check whether the layer is conditioned."""
|
23 |
+
return self.vis_x is not None
|
24 |
+
|
25 |
+
# Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
|
26 |
+
def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, added_bbox_list=None, added_visual_token_idx_list=None, num_beams=None):
|
27 |
+
self.vis_x = vis_x
|
28 |
+
self.image_nums = image_nums
|
29 |
+
self.image_start_index_list = image_start_index_list
|
30 |
+
self.added_bbox_list = added_bbox_list
|
31 |
+
self.added_visual_token_idx_list = added_visual_token_idx_list
|
32 |
+
self.num_beams = num_beams
|
33 |
+
self.input_ids = None
|
34 |
+
|
35 |
+
def condition_media_locations(self, media_locations):
|
36 |
+
self.media_locations = media_locations
|
37 |
+
|
38 |
+
def condition_attend_previous(self, attend_previous):
|
39 |
+
self.attend_previous = attend_previous
|
40 |
+
|
41 |
+
def forward(
|
42 |
+
self,
|
43 |
+
hidden_states, # alignment with hugging face name
|
44 |
+
attention_mask=None,
|
45 |
+
**decoder_layer_kwargs,
|
46 |
+
):
|
47 |
+
if self.media_locations is None:
|
48 |
+
raise ValueError("media_locations must be conditioned before forward pass")
|
49 |
+
|
50 |
+
if self.vis_x is not None:
|
51 |
+
if self.training:
|
52 |
+
single_length = self.vis_x.shape[-2]
|
53 |
+
image_nums = self.image_nums
|
54 |
+
image_start_index_list = self.image_start_index_list
|
55 |
+
image_nums = [0] + np.cumsum(image_nums).tolist()
|
56 |
+
for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
|
57 |
+
for index in start_indices:
|
58 |
+
if image_num_begin < image_num_end:
|
59 |
+
hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
|
60 |
+
image_num_begin += 1
|
61 |
+
if self.add_visual_token:
|
62 |
+
visual_token_position = (decoder_layer_kwargs["input_ids"] == self.visual_token_id).nonzero()
|
63 |
+
input_ids = decoder_layer_kwargs["input_ids"]
|
64 |
+
prev_batch_idx = -1
|
65 |
+
media_idx = []
|
66 |
+
cnt = 0
|
67 |
+
for batch_idx, idx in visual_token_position:
|
68 |
+
batch_idx = batch_idx.item()
|
69 |
+
idx = idx.item()
|
70 |
+
if batch_idx != prev_batch_idx:
|
71 |
+
prev_batch_idx = batch_idx
|
72 |
+
this_input_ids = input_ids[batch_idx]
|
73 |
+
cnt += len(media_idx)
|
74 |
+
media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
|
75 |
+
for i in range(len(media_idx)):
|
76 |
+
if i == len(media_idx) - 1 or idx > media_idx[i] and idx < media_idx[i+1]:
|
77 |
+
break
|
78 |
+
image_index = cnt + i
|
79 |
+
size = int(self.vis_x[image_index].shape[0] ** 0.5)
|
80 |
+
image_feature = self.vis_x[image_index].reshape(size, size, -1)
|
81 |
+
bbox = this_input_ids[idx-4:idx] - self.min_loc_token_id
|
82 |
+
region_xyxy = (bbox / 1000 * size)
|
83 |
+
try:
|
84 |
+
x1, y1, x2, y2 = region_xyxy.long().tolist()
|
85 |
+
except:
|
86 |
+
print(region_xyxy.long().tolist(), media_idx, prev_batch_idx, batch_idx, idx, cnt, image_index)
|
87 |
+
raise ValueError("something wrong")
|
88 |
+
# print("region_xyxy", region_xyxy.long().tolist(), "media_idx", media_idx, "batch_idx", batch_idx, "idx", idx, "cnt", cnt, "image_index", image_index)
|
89 |
+
visual_token = image_feature[y1:y2+1, x1:x2+1].reshape(-1, image_feature.shape[-1]).mean(0)
|
90 |
+
hidden_states[batch_idx, idx] = visual_token
|
91 |
+
|
92 |
+
elif not self.training:
|
93 |
+
if self.add_visual_token:
|
94 |
+
if self.input_ids is None:
|
95 |
+
self.input_ids = decoder_layer_kwargs["input_ids"]
|
96 |
+
else:
|
97 |
+
self.input_ids = torch.cat([self.input_ids, decoder_layer_kwargs["input_ids"]], dim=-1)
|
98 |
+
visual_token_position = (self.input_ids[..., -1] == self.visual_token_id).nonzero().reshape(-1)
|
99 |
+
# print("input_ids:", self.input_ids.shape)
|
100 |
+
if len(visual_token_position) != 0:
|
101 |
+
for batch_idx in visual_token_position:
|
102 |
+
batch_idx = batch_idx.item()
|
103 |
+
image_index = batch_idx
|
104 |
+
this_input_ids = self.input_ids[batch_idx]
|
105 |
+
size = int(self.vis_x[image_index].shape[0] ** 0.5)
|
106 |
+
image_feature = self.vis_x[image_index].reshape(size, size, -1)
|
107 |
+
bbox = this_input_ids[-5:-1] - self.min_loc_token_id
|
108 |
+
region_xyxy = (bbox / 1000 * size)
|
109 |
+
x1, y1, x2, y2 = region_xyxy.long().tolist()
|
110 |
+
# print("region_xyxy", region_xyxy.long().tolist(), "batch_idx", batch_idx, "image_index", image_index)
|
111 |
+
visual_token = image_feature[y1:y2+1, x1:x2+1].reshape(-1, image_feature.shape[-1]).mean(0)
|
112 |
+
hidden_states[batch_idx, 0] = visual_token
|
113 |
+
|
114 |
+
|
115 |
+
if not decoder_layer_kwargs["use_cache"]:
|
116 |
+
raise NotImplementedError
|
117 |
+
if (
|
118 |
+
("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None) or
|
119 |
+
("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None)
|
120 |
+
):
|
121 |
+
single_length = self.vis_x.shape[-2]
|
122 |
+
image_nums = self.image_nums
|
123 |
+
image_start_index_list = self.image_start_index_list
|
124 |
+
image_nums = [0] + np.cumsum(image_nums).tolist()
|
125 |
+
for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
|
126 |
+
for index in start_indices:
|
127 |
+
if image_num_begin < image_num_end:
|
128 |
+
hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
|
129 |
+
image_num_begin += 1
|
130 |
+
hidden_states = self.decoder_layer(
|
131 |
+
hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
|
132 |
+
)
|
133 |
+
return hidden_states
|
134 |
+
|
135 |
+
|
136 |
+
class FlamingoLMMixin(nn.Module):
|
137 |
+
"""
|
138 |
+
Mixin to add cross-attention layers to a language model.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
142 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
143 |
+
|
144 |
+
def _get_decoder_layers(self):
|
145 |
+
return getattr_recursive(self, self.decoder_layers_attr_name)
|
146 |
+
|
147 |
+
def _set_decoder_layers(self, value):
|
148 |
+
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
149 |
+
|
150 |
+
def init_flamingo(
|
151 |
+
self,
|
152 |
+
media_token_id,
|
153 |
+
use_media_placement_augmentation,
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
157 |
+
"""
|
158 |
+
self._set_decoder_layers(
|
159 |
+
nn.ModuleList(
|
160 |
+
[FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
|
161 |
+
)
|
162 |
+
)
|
163 |
+
self.media_token_id = media_token_id
|
164 |
+
self.use_media_placement_augmentation = use_media_placement_augmentation
|
165 |
+
self.initialized_flamingo = True
|
166 |
+
|
167 |
+
def forward(self, *input, **kwargs):
|
168 |
+
"""Condition the Flamingo layers on the media locations before forward()"""
|
169 |
+
if not self.initialized_flamingo:
|
170 |
+
raise ValueError(
|
171 |
+
"Flamingo layers are not initialized. Please call `init_flamingo` first."
|
172 |
+
)
|
173 |
+
|
174 |
+
input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
|
175 |
+
media_locations = input_ids == self.media_token_id
|
176 |
+
attend_previous = (
|
177 |
+
(random.random() < 0.5) if self.use_media_placement_augmentation else True
|
178 |
+
)
|
179 |
+
|
180 |
+
if (
|
181 |
+
"gpt2" in self.__class__.__name__.lower()
|
182 |
+
or "codegen" in self.__class__.__name__.lower()
|
183 |
+
):
|
184 |
+
for layer in self.transformer.h:
|
185 |
+
layer.condition_media_locations(media_locations)
|
186 |
+
layer.condition_attend_previous(attend_previous)
|
187 |
+
else:
|
188 |
+
for layer in self.get_decoder().layers:
|
189 |
+
layer.condition_media_locations(media_locations)
|
190 |
+
layer.condition_attend_previous(attend_previous)
|
191 |
+
return super().forward(
|
192 |
+
*input, **kwargs
|
193 |
+
) # Call the other parent's forward method
|
194 |
+
|
195 |
+
def is_conditioned(self) -> bool:
|
196 |
+
"""Check whether all decoder layers are already conditioned."""
|
197 |
+
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
198 |
+
|
199 |
+
def clear_conditioned_layers(self):
|
200 |
+
for layer in self._get_decoder_layers():
|
201 |
+
layer.condition_vis_x(None)
|
202 |
+
layer.condition_media_locations(None)
|
203 |
+
layer.condition_attend_previous(None)
|
open_flamingo/open_flamingo/src/helpers.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from https://github.com/lucidrains/flamingo-pytorch
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from einops_exts import rearrange_many
|
8 |
+
from torch import einsum, nn
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def FeedForward(dim, mult=4):
|
16 |
+
inner_dim = int(dim * mult)
|
17 |
+
return nn.Sequential(
|
18 |
+
nn.LayerNorm(dim),
|
19 |
+
nn.Linear(dim, inner_dim, bias=False),
|
20 |
+
nn.GELU(),
|
21 |
+
nn.Linear(inner_dim, dim, bias=False),
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class PerceiverAttention(nn.Module):
|
26 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
27 |
+
super().__init__()
|
28 |
+
self.scale = dim_head**-0.5
|
29 |
+
self.heads = heads
|
30 |
+
inner_dim = dim_head * heads
|
31 |
+
|
32 |
+
self.norm_media = nn.LayerNorm(dim)
|
33 |
+
self.norm_latents = nn.LayerNorm(dim)
|
34 |
+
|
35 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
36 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
37 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
38 |
+
|
39 |
+
def forward(self, x, latents):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
x (torch.Tensor): image features
|
43 |
+
shape (b, T, n1, D)
|
44 |
+
latent (torch.Tensor): latent features
|
45 |
+
shape (b, T, n2, D)
|
46 |
+
"""
|
47 |
+
x = self.norm_media(x)
|
48 |
+
latents = self.norm_latents(latents)
|
49 |
+
|
50 |
+
h = self.heads
|
51 |
+
|
52 |
+
q = self.to_q(latents)
|
53 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
54 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
55 |
+
q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
|
56 |
+
q = q * self.scale
|
57 |
+
|
58 |
+
# attention
|
59 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
60 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
61 |
+
attn = sim.softmax(dim=-1)
|
62 |
+
|
63 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
64 |
+
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
|
65 |
+
return self.to_out(out)
|
66 |
+
|
67 |
+
|
68 |
+
class PerceiverResampler(nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
*,
|
72 |
+
dim,
|
73 |
+
depth=6,
|
74 |
+
dim_head=64,
|
75 |
+
heads=8,
|
76 |
+
num_latents=64,
|
77 |
+
max_num_media=None,
|
78 |
+
max_num_frames=None,
|
79 |
+
ff_mult=4,
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
assert False, "Do not use PerceiverResampler"
|
83 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
84 |
+
self.frame_embs = (
|
85 |
+
nn.Parameter(torch.randn(max_num_frames, dim))
|
86 |
+
if exists(max_num_frames)
|
87 |
+
else None
|
88 |
+
)
|
89 |
+
self.media_time_embs = (
|
90 |
+
nn.Parameter(torch.randn(max_num_media, 1, dim))
|
91 |
+
if exists(max_num_media)
|
92 |
+
else None
|
93 |
+
)
|
94 |
+
|
95 |
+
self.layers = nn.ModuleList([])
|
96 |
+
for _ in range(depth):
|
97 |
+
self.layers.append(
|
98 |
+
nn.ModuleList(
|
99 |
+
[
|
100 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
101 |
+
FeedForward(dim=dim, mult=ff_mult),
|
102 |
+
]
|
103 |
+
)
|
104 |
+
)
|
105 |
+
|
106 |
+
self.norm = nn.LayerNorm(dim)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
"""
|
110 |
+
Args:
|
111 |
+
x (torch.Tensor): image features
|
112 |
+
shape (b, T, F, v, D)
|
113 |
+
Returns:
|
114 |
+
shape (b, T, n, D) where n is self.num_latents
|
115 |
+
"""
|
116 |
+
b, T, F, v = x.shape[:4]
|
117 |
+
|
118 |
+
# frame and media time embeddings
|
119 |
+
if exists(self.frame_embs):
|
120 |
+
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
|
121 |
+
x = x + frame_embs
|
122 |
+
x = rearrange(
|
123 |
+
x, "b T F v d -> b T (F v) d"
|
124 |
+
) # flatten the frame and spatial dimensions
|
125 |
+
if exists(self.media_time_embs):
|
126 |
+
x = x + self.media_time_embs[:T]
|
127 |
+
|
128 |
+
# blocks
|
129 |
+
latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
|
130 |
+
for attn, ff in self.layers:
|
131 |
+
latents = attn(x, latents) + latents
|
132 |
+
latents = ff(latents) + latents
|
133 |
+
return self.norm(latents)
|
134 |
+
|
135 |
+
|
136 |
+
# gated cross attention
|
137 |
+
|
138 |
+
|
139 |
+
class MaskedCrossAttention(nn.Module):
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
*,
|
143 |
+
dim,
|
144 |
+
dim_visual,
|
145 |
+
dim_head=64,
|
146 |
+
heads=8,
|
147 |
+
only_attend_immediate_media=True,
|
148 |
+
):
|
149 |
+
super().__init__()
|
150 |
+
self.scale = dim_head**-0.5
|
151 |
+
self.heads = heads
|
152 |
+
inner_dim = dim_head * heads
|
153 |
+
|
154 |
+
self.norm = nn.LayerNorm(dim)
|
155 |
+
|
156 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
157 |
+
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
|
158 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
159 |
+
|
160 |
+
# whether for text to only attend to immediate preceding image, or all previous images
|
161 |
+
self.only_attend_immediate_media = only_attend_immediate_media
|
162 |
+
|
163 |
+
def forward(self, x, media, media_locations=None, attend_previous=True):
|
164 |
+
"""
|
165 |
+
Args:
|
166 |
+
x (torch.Tensor): text features
|
167 |
+
shape (B, T_txt, D_txt)
|
168 |
+
media (torch.Tensor): image features
|
169 |
+
shape (B, T_img, n, D_img) where n is the dim of the latents
|
170 |
+
media_locations: boolean mask identifying the media tokens in x
|
171 |
+
shape (B, T_txt)
|
172 |
+
attend_previous: bool
|
173 |
+
If false, ignores immediately preceding image and starts attending when following image
|
174 |
+
"""
|
175 |
+
assert attend_previous, "text must attend to the image that before it"
|
176 |
+
|
177 |
+
_, T_img, n = media.shape[:3]
|
178 |
+
h = self.heads
|
179 |
+
|
180 |
+
x = self.norm(x)
|
181 |
+
|
182 |
+
q = self.to_q(x)
|
183 |
+
media = rearrange(media, "b t n d -> b (t n) d")
|
184 |
+
|
185 |
+
k, v = self.to_kv(media).chunk(2, dim=-1)
|
186 |
+
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
|
187 |
+
|
188 |
+
q = q * self.scale
|
189 |
+
|
190 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
191 |
+
|
192 |
+
if exists(media_locations):
|
193 |
+
# at each boolean of True, increment the time counter (relative to media time)
|
194 |
+
text_time = media_locations.cumsum(dim=-1)
|
195 |
+
media_time = torch.arange(T_img, device=x.device) + 1
|
196 |
+
|
197 |
+
if not attend_previous:
|
198 |
+
text_time[~media_locations] += 1
|
199 |
+
# make sure max is still the number of images in the sequence
|
200 |
+
text_time[
|
201 |
+
text_time
|
202 |
+
> repeat(
|
203 |
+
torch.count_nonzero(media_locations, dim=1),
|
204 |
+
"b -> b i",
|
205 |
+
i=text_time.shape[1],
|
206 |
+
)
|
207 |
+
] = 0
|
208 |
+
|
209 |
+
# text time must equal media time if only attending to most immediate image
|
210 |
+
# otherwise, as long as text time is greater than media time (if attending to all previous images / media)
|
211 |
+
mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
|
212 |
+
|
213 |
+
text_to_media_mask = mask_op(
|
214 |
+
rearrange(text_time, "b i -> b 1 i 1"),
|
215 |
+
repeat(media_time, "j -> 1 1 1 (j n)", n=n),
|
216 |
+
)
|
217 |
+
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
|
218 |
+
|
219 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
220 |
+
attn = sim.softmax(dim=-1)
|
221 |
+
|
222 |
+
if exists(media_locations) and self.only_attend_immediate_media:
|
223 |
+
# any text without a preceding media needs to have attention zeroed out
|
224 |
+
text_without_media_mask = text_time == 0
|
225 |
+
text_without_media_mask = rearrange(
|
226 |
+
text_without_media_mask, "b i -> b 1 i 1"
|
227 |
+
)
|
228 |
+
attn = attn.masked_fill(text_without_media_mask, 0.0)
|
229 |
+
|
230 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
231 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
232 |
+
return self.to_out(out)
|
233 |
+
|
234 |
+
|
235 |
+
class GatedCrossAttentionBlock(nn.Module):
|
236 |
+
def __init__(
|
237 |
+
self,
|
238 |
+
*,
|
239 |
+
dim,
|
240 |
+
dim_visual,
|
241 |
+
dim_head=64,
|
242 |
+
heads=8,
|
243 |
+
ff_mult=4,
|
244 |
+
only_attend_immediate_media=True,
|
245 |
+
):
|
246 |
+
super().__init__()
|
247 |
+
self.attn = MaskedCrossAttention(
|
248 |
+
dim=dim,
|
249 |
+
dim_visual=dim_visual,
|
250 |
+
dim_head=dim_head,
|
251 |
+
heads=heads,
|
252 |
+
only_attend_immediate_media=only_attend_immediate_media,
|
253 |
+
)
|
254 |
+
|
255 |
+
def forward(
|
256 |
+
self,
|
257 |
+
x,
|
258 |
+
media,
|
259 |
+
media_locations=None,
|
260 |
+
attend_previous=True,
|
261 |
+
):
|
262 |
+
x = self.attn(x, media, media_locations=media_locations, attend_previous=attend_previous) + x
|
263 |
+
return x
|
open_flamingo/open_flamingo/src/utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def extend_instance(obj, mixin):
|
2 |
+
"""Apply mixins to a class instance after creation"""
|
3 |
+
base_cls = obj.__class__
|
4 |
+
base_cls_name = obj.__class__.__name__
|
5 |
+
obj.__class__ = type(
|
6 |
+
base_cls_name, (mixin, base_cls), {}
|
7 |
+
) # mixin needs to go first for our forward() logic to work
|
8 |
+
|
9 |
+
|
10 |
+
def getattr_recursive(obj, att):
|
11 |
+
"""
|
12 |
+
Return nested attribute of obj
|
13 |
+
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
14 |
+
"""
|
15 |
+
if att == "":
|
16 |
+
return obj
|
17 |
+
i = att.find(".")
|
18 |
+
if i < 0:
|
19 |
+
return getattr(obj, att)
|
20 |
+
else:
|
21 |
+
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
22 |
+
|
23 |
+
|
24 |
+
def setattr_recursive(obj, att, val):
|
25 |
+
"""
|
26 |
+
Set nested attribute of obj
|
27 |
+
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
28 |
+
"""
|
29 |
+
if "." in att:
|
30 |
+
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
31 |
+
setattr(obj, att.split(".")[-1], val)
|
open_flamingo/open_flamingo/train/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
open_flamingo/open_flamingo/train/data.deprecated.py
ADDED
@@ -0,0 +1,812 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import functools
|
3 |
+
import io
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import sys
|
10 |
+
import tarfile
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from multiprocessing import Value
|
13 |
+
|
14 |
+
import braceexpand
|
15 |
+
import torch
|
16 |
+
import torchvision
|
17 |
+
import webdataset as wds
|
18 |
+
from PIL import Image
|
19 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
20 |
+
from torch.utils.data.distributed import DistributedSampler
|
21 |
+
from webdataset.filters import _shuffle
|
22 |
+
from webdataset.tariterators import (
|
23 |
+
base_plus_ext,
|
24 |
+
tar_file_expander,
|
25 |
+
url_opener,
|
26 |
+
valid_sample,
|
27 |
+
)
|
28 |
+
try:
|
29 |
+
from groundingdino.demo.caption_grounder import caption_grounder
|
30 |
+
from groundingdino.demo.inference_on_laion import add_loc_to_text
|
31 |
+
except:
|
32 |
+
pass
|
33 |
+
|
34 |
+
Image.MAX_IMAGE_PIXELS = 1000000000
|
35 |
+
MAX_NUM_TOKENS = 256
|
36 |
+
MAX_NUM_IMAGES = 5
|
37 |
+
TINY_IMAGE_SIZE_THRESHOLD = 1
|
38 |
+
N_CHANNELS = 3
|
39 |
+
INTERLEAVED_IMAGE_SIZE = 224
|
40 |
+
|
41 |
+
try:
|
42 |
+
import horovod.torch as hvd
|
43 |
+
except ImportError:
|
44 |
+
hvd = None
|
45 |
+
|
46 |
+
|
47 |
+
class SharedEpoch:
|
48 |
+
def __init__(self, epoch: int = 0):
|
49 |
+
self.shared_epoch = Value("i", epoch)
|
50 |
+
|
51 |
+
def set_value(self, epoch):
|
52 |
+
self.shared_epoch.value = epoch
|
53 |
+
|
54 |
+
def get_value(self):
|
55 |
+
return self.shared_epoch.value
|
56 |
+
|
57 |
+
|
58 |
+
@dataclass
|
59 |
+
class DataInfo:
|
60 |
+
dataloader: DataLoader
|
61 |
+
sampler: DistributedSampler = None
|
62 |
+
shared_epoch: SharedEpoch = None
|
63 |
+
|
64 |
+
def set_epoch(self, epoch):
|
65 |
+
if self.shared_epoch is not None:
|
66 |
+
self.shared_epoch.set_value(epoch)
|
67 |
+
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
|
68 |
+
self.sampler.set_epoch(epoch)
|
69 |
+
|
70 |
+
|
71 |
+
def get_dataset_size(shards):
|
72 |
+
shards_list = list(braceexpand.braceexpand(shards))
|
73 |
+
# shards_list = shards # ?
|
74 |
+
dir_path = os.path.dirname(shards[0])
|
75 |
+
sizes_filename = os.path.join(dir_path, "sizes.json")
|
76 |
+
len_filename = os.path.join(dir_path, "__len__")
|
77 |
+
if os.path.exists(sizes_filename):
|
78 |
+
sizes = json.load(open(sizes_filename, "r"))
|
79 |
+
total_size = sum(
|
80 |
+
[
|
81 |
+
int(sizes[os.path.basename(shard)])
|
82 |
+
if os.path.basename(shard) in sizes
|
83 |
+
else 0
|
84 |
+
for shard in shards_list
|
85 |
+
]
|
86 |
+
)
|
87 |
+
elif os.path.exists(len_filename):
|
88 |
+
# FIXME this used to be eval(open(...)) but that seemed rather unsafe
|
89 |
+
total_size = ast.literal_eval(open(len_filename, "r").read())
|
90 |
+
else:
|
91 |
+
total_size = None # num samples undefined
|
92 |
+
# some common dataset sizes (at time of authors last download)
|
93 |
+
# CC3M (train): 2905954
|
94 |
+
# CC12M: 10968539
|
95 |
+
# LAION-400M: 407332084
|
96 |
+
# LAION-2B (english): 2170337258
|
97 |
+
num_shards = len(shards_list)
|
98 |
+
return total_size, num_shards
|
99 |
+
|
100 |
+
|
101 |
+
def count_samples(dataloader):
|
102 |
+
os.environ["WDS_EPOCH"] = "0"
|
103 |
+
n_elements, n_batches = 0, 0
|
104 |
+
for images, texts in dataloader:
|
105 |
+
n_batches += 1
|
106 |
+
n_elements += len(images)
|
107 |
+
assert len(images) == len(texts)
|
108 |
+
return n_elements, n_batches
|
109 |
+
|
110 |
+
|
111 |
+
def filter_no_caption_or_no_image(sample):
|
112 |
+
return ("txt" in sample) and (
|
113 |
+
"png" in sample or "jpg" in sample or "jpeg" in sample
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
def log_and_continue(exn):
|
118 |
+
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
|
119 |
+
if "No images in sample" in str(exn) or "Only one image in sample" in str(
|
120 |
+
exn
|
121 |
+
): # Avoid spamming logs with these
|
122 |
+
return True
|
123 |
+
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
|
124 |
+
return True
|
125 |
+
|
126 |
+
|
127 |
+
def group_by_keys_nothrow(
|
128 |
+
data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
|
129 |
+
):
|
130 |
+
"""Return function over iterator that groups key, value pairs into samples.
|
131 |
+
|
132 |
+
:param keys: function that splits the key into key and extension (base_plus_ext)
|
133 |
+
:param lcase: convert suffixes to lower case (Default value = True)
|
134 |
+
"""
|
135 |
+
current_sample = None
|
136 |
+
for filesample in data:
|
137 |
+
assert isinstance(filesample, dict)
|
138 |
+
fname, value = filesample["fname"], filesample["data"]
|
139 |
+
prefix, suffix = keys(fname)
|
140 |
+
if prefix is None:
|
141 |
+
continue
|
142 |
+
if lcase:
|
143 |
+
suffix = suffix.lower()
|
144 |
+
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for
|
145 |
+
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next
|
146 |
+
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset
|
147 |
+
if (
|
148 |
+
current_sample is None
|
149 |
+
or prefix != current_sample["__key__"]
|
150 |
+
or suffix in current_sample
|
151 |
+
):
|
152 |
+
if valid_sample(current_sample):
|
153 |
+
yield current_sample
|
154 |
+
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
155 |
+
if suffixes is None or suffix in suffixes:
|
156 |
+
current_sample[suffix] = value
|
157 |
+
if valid_sample(current_sample):
|
158 |
+
yield current_sample
|
159 |
+
|
160 |
+
|
161 |
+
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
|
162 |
+
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
|
163 |
+
streams = url_opener(src, handler=handler)
|
164 |
+
files = tar_file_expander(streams, handler=handler)
|
165 |
+
samples = group_by_keys_nothrow(files, handler=handler)
|
166 |
+
return samples
|
167 |
+
|
168 |
+
|
169 |
+
def pytorch_worker_seed(increment=0):
|
170 |
+
"""get dataloader worker seed from pytorch"""
|
171 |
+
worker_info = get_worker_info()
|
172 |
+
if worker_info is not None:
|
173 |
+
# favour using the seed already created for pytorch dataloader workers if it exists
|
174 |
+
seed = worker_info.seed
|
175 |
+
if increment:
|
176 |
+
# space out seed increments so they can't overlap across workers in different iterations
|
177 |
+
seed += increment * max(1, worker_info.num_workers)
|
178 |
+
return seed
|
179 |
+
# fallback to wds rank based seed
|
180 |
+
return wds.utils.pytorch_worker_seed()
|
181 |
+
|
182 |
+
|
183 |
+
_SHARD_SHUFFLE_SIZE = 2000
|
184 |
+
_SHARD_SHUFFLE_INITIAL = 500
|
185 |
+
_SAMPLE_SHUFFLE_SIZE = 5000
|
186 |
+
_SAMPLE_SHUFFLE_INITIAL = 1000
|
187 |
+
|
188 |
+
|
189 |
+
class detshuffle2(wds.PipelineStage):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
bufsize=1000,
|
193 |
+
initial=100,
|
194 |
+
seed=0,
|
195 |
+
epoch=-1,
|
196 |
+
):
|
197 |
+
self.bufsize = bufsize
|
198 |
+
self.initial = initial
|
199 |
+
self.seed = seed
|
200 |
+
self.epoch = epoch
|
201 |
+
|
202 |
+
def run(self, src):
|
203 |
+
if isinstance(self.epoch, SharedEpoch):
|
204 |
+
epoch = self.epoch.get_value()
|
205 |
+
else:
|
206 |
+
# NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
|
207 |
+
# situation as different workers may wrap at different times (or not at all).
|
208 |
+
self.epoch += 1
|
209 |
+
epoch = self.epoch
|
210 |
+
rng = random.Random()
|
211 |
+
if self.seed < 0:
|
212 |
+
# If seed is negative, we use the worker's seed, this will be different across all nodes/workers
|
213 |
+
seed = pytorch_worker_seed(epoch)
|
214 |
+
else:
|
215 |
+
# This seed to be deterministic AND the same across all nodes/workers in each epoch
|
216 |
+
seed = self.seed + epoch
|
217 |
+
rng.seed(seed)
|
218 |
+
return _shuffle(src, self.bufsize, self.initial, rng)
|
219 |
+
|
220 |
+
|
221 |
+
class ResampledShards2(IterableDataset):
|
222 |
+
"""An iterable dataset yielding a list of urls."""
|
223 |
+
|
224 |
+
def __init__(
|
225 |
+
self,
|
226 |
+
urls,
|
227 |
+
nshards=sys.maxsize,
|
228 |
+
worker_seed=None,
|
229 |
+
deterministic=False,
|
230 |
+
epoch=-1,
|
231 |
+
):
|
232 |
+
"""Sample shards from the shard list with replacement.
|
233 |
+
:param urls: a list of URLs as a Python list or brace notation string
|
234 |
+
"""
|
235 |
+
super().__init__()
|
236 |
+
urls = wds.shardlists.expand_urls(urls)
|
237 |
+
self.urls = urls
|
238 |
+
assert isinstance(self.urls[0], str)
|
239 |
+
self.nshards = nshards
|
240 |
+
self.rng = random.Random()
|
241 |
+
self.worker_seed = worker_seed
|
242 |
+
self.deterministic = deterministic
|
243 |
+
self.epoch = epoch
|
244 |
+
|
245 |
+
def __iter__(self):
|
246 |
+
"""Return an iterator over the shards."""
|
247 |
+
if isinstance(self.epoch, SharedEpoch):
|
248 |
+
epoch = self.epoch.get_value()
|
249 |
+
else:
|
250 |
+
# NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
|
251 |
+
# situation as different workers may wrap at different times (or not at all).
|
252 |
+
self.epoch += 1
|
253 |
+
epoch = self.epoch
|
254 |
+
|
255 |
+
if self.deterministic:
|
256 |
+
# reset seed w/ epoch if deterministic
|
257 |
+
if self.worker_seed is None:
|
258 |
+
# pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
|
259 |
+
seed = pytorch_worker_seed(epoch)
|
260 |
+
else:
|
261 |
+
seed = self.worker_seed() + epoch
|
262 |
+
self.rng.seed(seed)
|
263 |
+
for _ in range(self.nshards):
|
264 |
+
yield dict(url=self.rng.choice(self.urls))
|
265 |
+
|
266 |
+
|
267 |
+
def preprocess_image(sample, image_processor):
|
268 |
+
image = [image_processor(s).unsqueeze(0) for s in sample]
|
269 |
+
image = torch.cat(image, dim=0)
|
270 |
+
# apply random horizontal flip and color jitter
|
271 |
+
image = torchvision.transforms.RandomHorizontalFlip(p=0.5)(image)
|
272 |
+
image = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.3)(image)
|
273 |
+
return image
|
274 |
+
|
275 |
+
|
276 |
+
def preprocess_text(sample, tokenizer, max_length=512):
|
277 |
+
tokenizer.padding_side = "right"
|
278 |
+
sample = [
|
279 |
+
(f"<|#image#|><|#endofimage#|>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
|
280 |
+
]
|
281 |
+
text = tokenizer(
|
282 |
+
sample,
|
283 |
+
max_length=max_length,
|
284 |
+
padding="longest",
|
285 |
+
truncation="only_first",
|
286 |
+
return_tensors="pt",
|
287 |
+
)
|
288 |
+
return text["input_ids"], text["attention_mask"]
|
289 |
+
|
290 |
+
|
291 |
+
def preprocess_encoded_text(sample, tokenizer):
|
292 |
+
sample = [s.decode("utf-8") for s in sample]
|
293 |
+
return preprocess_text(sample, tokenizer, max_length=768)
|
294 |
+
|
295 |
+
|
296 |
+
def preprocess_ground_caption(sample, image_processor, tokenizer, generator, image_size):
|
297 |
+
texts = []
|
298 |
+
images, captions, logits_filts, boxes_filts = sample
|
299 |
+
for cap, logits_filt, boxes_filt in zip(captions, logits_filts, boxes_filts):
|
300 |
+
boxes_filt, pred_phrases = generator.postprocess(logits_filt, boxes_filt, generator.ground_model, cap, generator.text_threshold, generator.box_threshold, with_logits=True)
|
301 |
+
caption_with_loc = add_loc_to_text(boxes_filt, pred_phrases, cap)
|
302 |
+
texts.append(caption_with_loc)
|
303 |
+
del boxes_filt
|
304 |
+
del pred_phrases
|
305 |
+
image = preprocess_image(images, image_processor=image_processor)
|
306 |
+
sample = texts
|
307 |
+
input_ids, attention_mask = preprocess_text(sample, tokenizer)
|
308 |
+
return image, (input_ids, attention_mask)
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
MIN_KB = 10
|
313 |
+
MAX_NUM_IMAGES = 5
|
314 |
+
|
315 |
+
|
316 |
+
def preprocess_interleaved(sample, tokenizer, clip_processor, sim_threshold):
|
317 |
+
info = json.loads(sample[0])
|
318 |
+
tar_file_obj = io.BytesIO(sample[1])
|
319 |
+
image_tar = tarfile.open(fileobj=tar_file_obj)
|
320 |
+
sentences = info["text_list"]
|
321 |
+
|
322 |
+
images, image_idxs = [], []
|
323 |
+
for image_path, sim in zip(info["image_info"], info["similarity_matrix"]):
|
324 |
+
# pick one image per sentence
|
325 |
+
if info["image_info"][image_path]["matched_text_index"] in image_idxs:
|
326 |
+
continue
|
327 |
+
rawbytes = image_tar.extractfile(
|
328 |
+
os.path.join(image_tar.getnames()[0], image_path)
|
329 |
+
).read()
|
330 |
+
|
331 |
+
# filter to images >= 10KB
|
332 |
+
if len(rawbytes) // 1000 <= MIN_KB:
|
333 |
+
continue
|
334 |
+
if sim[info["image_info"][image_path]["matched_text_index"]] < sim_threshold:
|
335 |
+
continue
|
336 |
+
image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
|
337 |
+
|
338 |
+
images.append(image)
|
339 |
+
image_idxs.append(info["image_info"][image_path]["matched_text_index"])
|
340 |
+
|
341 |
+
if len(images) == 0:
|
342 |
+
raise ValueError("No images in sample")
|
343 |
+
|
344 |
+
# filter out images that are exact duplicates
|
345 |
+
images_tensors = preprocess_image(images, clip_processor)
|
346 |
+
keep_ixs = range(min(len(images_tensors), MAX_NUM_IMAGES))
|
347 |
+
images_tensors = images_tensors[keep_ixs]
|
348 |
+
image_idxs = [image_idxs[ix] for ix in keep_ixs]
|
349 |
+
|
350 |
+
# pad to 5 images
|
351 |
+
if len(images_tensors) < MAX_NUM_IMAGES:
|
352 |
+
zero_padding = torch.zeros(
|
353 |
+
(MAX_NUM_IMAGES - len(images_tensors), 3, 224, 224), dtype=torch.float
|
354 |
+
)
|
355 |
+
images_tensors = torch.cat((images_tensors, zero_padding), dim=0)
|
356 |
+
|
357 |
+
# add in <|#image#|> and <eoc> tokens
|
358 |
+
# eoc after sentence = "sentence loss"
|
359 |
+
for ix in image_idxs:
|
360 |
+
sentences[ix] = f"<|endofchunk|><|#image#|>{sentences[ix]}"
|
361 |
+
|
362 |
+
text = " ".join(sentences)
|
363 |
+
text = text.replace("<|endofchunk|>", "", 1) # but remove first eoc
|
364 |
+
# whitespace cleanup
|
365 |
+
text = (
|
366 |
+
text.replace(" <|endofchunk|>", "<|endofchunk|>")
|
367 |
+
.replace("<|#image#|> ", "<|#image#|>")
|
368 |
+
.replace(" <|#image#|>", "<|#image#|>")
|
369 |
+
)
|
370 |
+
text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
|
371 |
+
tokenizer.padding_side = "right"
|
372 |
+
text_tensor = tokenizer(
|
373 |
+
text, max_length=256, truncation=True, padding="max_length", return_tensors="pt"
|
374 |
+
)
|
375 |
+
|
376 |
+
# reject sequences with too few images (after truncation)
|
377 |
+
num_images = torch.count_nonzero(
|
378 |
+
text_tensor["input_ids"]
|
379 |
+
== tokenizer.additional_special_tokens_ids[
|
380 |
+
tokenizer.additional_special_tokens.index("<|#image#|>")
|
381 |
+
]
|
382 |
+
)
|
383 |
+
|
384 |
+
if num_images == 0:
|
385 |
+
raise ValueError("No images in sample")
|
386 |
+
elif (
|
387 |
+
num_images == 1 and random.random() <= 0.5
|
388 |
+
): # 50% chance of keeping single image samples
|
389 |
+
raise ValueError("Only one image in sample")
|
390 |
+
|
391 |
+
return (
|
392 |
+
images_tensors,
|
393 |
+
(text_tensor["input_ids"], text_tensor["attention_mask"]),
|
394 |
+
)
|
395 |
+
|
396 |
+
|
397 |
+
def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
398 |
+
input_shards = args.mmc4_shards
|
399 |
+
assert input_shards is not None
|
400 |
+
resampled = getattr(args, "dataset_resampled", False)
|
401 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
402 |
+
|
403 |
+
num_samples, num_shards = get_dataset_size(input_shards)
|
404 |
+
num_samples = None
|
405 |
+
if not num_samples:
|
406 |
+
num_samples = args.train_num_samples_mmc4
|
407 |
+
if not num_samples:
|
408 |
+
raise RuntimeError(
|
409 |
+
"Currently, number of dataset samples must be specified for training dataset. "
|
410 |
+
"Please specify via `--train-num-samples` if no dataset length info present."
|
411 |
+
)
|
412 |
+
|
413 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
414 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
415 |
+
if resampled:
|
416 |
+
pipeline = [
|
417 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
|
418 |
+
]
|
419 |
+
else:
|
420 |
+
pipeline = [wds.SimpleShardList(input_shards)]
|
421 |
+
|
422 |
+
preprocess_fn = functools.partial(
|
423 |
+
preprocess_interleaved,
|
424 |
+
clip_processor=image_processor,
|
425 |
+
tokenizer=tokenizer,
|
426 |
+
sim_threshold=args.mmc4_textsim_threshold,
|
427 |
+
)
|
428 |
+
|
429 |
+
# at this point we have an iterator over all the shards
|
430 |
+
if not resampled:
|
431 |
+
pipeline.extend(
|
432 |
+
[
|
433 |
+
detshuffle2(
|
434 |
+
bufsize=_SHARD_SHUFFLE_SIZE,
|
435 |
+
initial=_SHARD_SHUFFLE_INITIAL,
|
436 |
+
seed=args.seed,
|
437 |
+
epoch=shared_epoch,
|
438 |
+
),
|
439 |
+
wds.split_by_node,
|
440 |
+
wds.split_by_worker,
|
441 |
+
]
|
442 |
+
)
|
443 |
+
pipeline.extend(
|
444 |
+
[
|
445 |
+
# at this point, we have an iterator over the shards assigned to each worker at each node
|
446 |
+
# wds.tarfile_to_samples(handler=log_and_continue),
|
447 |
+
tarfile_to_samples_nothrow,
|
448 |
+
wds.shuffle(
|
449 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
450 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
451 |
+
),
|
452 |
+
]
|
453 |
+
)
|
454 |
+
|
455 |
+
pipeline.extend(
|
456 |
+
[
|
457 |
+
wds.to_tuple("json", "tar", handler=log_and_continue),
|
458 |
+
wds.map(preprocess_fn, handler=log_and_continue),
|
459 |
+
wds.batched(args.batch_size_mmc4, partial=False),
|
460 |
+
]
|
461 |
+
)
|
462 |
+
|
463 |
+
dataset = wds.DataPipeline(*pipeline)
|
464 |
+
if not resampled:
|
465 |
+
assert (
|
466 |
+
num_shards >= args.workers * args.world_size
|
467 |
+
), "number of shards must be >= total workers"
|
468 |
+
# roll over and repeat a few samples to get same number of full batches on each node
|
469 |
+
round_fn = math.floor if floor else math.ceil
|
470 |
+
global_batch_size = args.batch_size_mmc4 * args.world_size
|
471 |
+
num_batches = round_fn(num_samples / global_batch_size)
|
472 |
+
num_workers = max(1, args.workers)
|
473 |
+
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
|
474 |
+
num_batches = num_worker_batches * num_workers
|
475 |
+
num_samples = num_batches * global_batch_size
|
476 |
+
# each worker is iterating over this
|
477 |
+
dataset = dataset.with_epoch(num_worker_batches)
|
478 |
+
|
479 |
+
dataloader = wds.WebLoader(
|
480 |
+
dataset,
|
481 |
+
batch_size=None,
|
482 |
+
shuffle=False,
|
483 |
+
num_workers=args.workers,
|
484 |
+
persistent_workers=True,
|
485 |
+
)
|
486 |
+
|
487 |
+
# add meta-data to dataloader instance for convenience
|
488 |
+
dataloader.num_batches = num_batches
|
489 |
+
dataloader.num_samples = num_samples
|
490 |
+
|
491 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
492 |
+
|
493 |
+
|
494 |
+
def get_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
495 |
+
input_shards = args.laion_shards
|
496 |
+
assert input_shards is not None
|
497 |
+
resampled = getattr(args, "dataset_resampled", False)
|
498 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
499 |
+
|
500 |
+
num_samples, num_shards = get_dataset_size(input_shards)
|
501 |
+
num_samples = None
|
502 |
+
if not num_samples:
|
503 |
+
num_samples = args.train_num_samples_laion
|
504 |
+
if not num_samples:
|
505 |
+
raise RuntimeError(
|
506 |
+
"Currently, number of dataset samples must be specified for training dataset. "
|
507 |
+
"Please specify via `--train-num-samples` if no dataset length info present."
|
508 |
+
)
|
509 |
+
|
510 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
511 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
512 |
+
if resampled:
|
513 |
+
pipeline = [
|
514 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
|
515 |
+
]
|
516 |
+
else:
|
517 |
+
pipeline = [wds.SimpleShardList(input_shards)]
|
518 |
+
|
519 |
+
# create two preprocess functions that take in the passed in image_processor and tokenizer
|
520 |
+
preprocess_image_fn = functools.partial(
|
521 |
+
preprocess_image, image_processor=image_processor
|
522 |
+
)
|
523 |
+
preprocess_text_fn = functools.partial(preprocess_text, tokenizer=tokenizer)
|
524 |
+
|
525 |
+
# at this point we have an iterator over all the shards
|
526 |
+
if not resampled:
|
527 |
+
pipeline.extend(
|
528 |
+
[
|
529 |
+
detshuffle2(
|
530 |
+
bufsize=_SHARD_SHUFFLE_SIZE,
|
531 |
+
initial=_SHARD_SHUFFLE_INITIAL,
|
532 |
+
seed=args.seed,
|
533 |
+
epoch=shared_epoch,
|
534 |
+
),
|
535 |
+
wds.split_by_node,
|
536 |
+
wds.split_by_worker,
|
537 |
+
]
|
538 |
+
)
|
539 |
+
pipeline.extend(
|
540 |
+
[
|
541 |
+
# at this point, we have an iterator over the shards assigned to each worker at each node
|
542 |
+
# wds.tarfile_to_samples(handler=log_and_continue),
|
543 |
+
tarfile_to_samples_nothrow,
|
544 |
+
wds.shuffle(
|
545 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
546 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
547 |
+
),
|
548 |
+
]
|
549 |
+
)
|
550 |
+
|
551 |
+
pipeline.extend(
|
552 |
+
[
|
553 |
+
wds.select(filter_no_caption_or_no_image),
|
554 |
+
wds.decode("pilrgb", handler=log_and_continue),
|
555 |
+
wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
|
556 |
+
wds.batched(args.batch_size_laion, partial=False),
|
557 |
+
wds.map_tuple(
|
558 |
+
preprocess_image_fn, preprocess_text_fn, handler=log_and_continue
|
559 |
+
),
|
560 |
+
]
|
561 |
+
)
|
562 |
+
|
563 |
+
dataset = wds.DataPipeline(*pipeline)
|
564 |
+
if not resampled:
|
565 |
+
assert (
|
566 |
+
num_shards >= args.workers * args.world_size
|
567 |
+
), "number of shards must be >= total workers"
|
568 |
+
# roll over and repeat a few samples to get same number of full batches on each node
|
569 |
+
round_fn = math.floor if floor else math.ceil
|
570 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
571 |
+
num_batches = round_fn(num_samples / global_batch_size)
|
572 |
+
num_workers = max(1, args.workers)
|
573 |
+
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
|
574 |
+
num_batches = num_worker_batches * num_workers
|
575 |
+
num_samples = num_batches * global_batch_size
|
576 |
+
# each worker is iterating over this
|
577 |
+
dataset = dataset.with_epoch(num_worker_batches)
|
578 |
+
|
579 |
+
dataloader = wds.WebLoader(
|
580 |
+
dataset,
|
581 |
+
batch_size=None,
|
582 |
+
shuffle=False,
|
583 |
+
num_workers=args.workers,
|
584 |
+
persistent_workers=True,
|
585 |
+
)
|
586 |
+
|
587 |
+
# add meta-data to dataloader instance for convenience
|
588 |
+
dataloader.num_batches = num_batches
|
589 |
+
dataloader.num_samples = num_samples
|
590 |
+
|
591 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
592 |
+
|
593 |
+
|
594 |
+
def get_pile_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
595 |
+
input_shards = args.mmc4_shards
|
596 |
+
assert input_shards is not None
|
597 |
+
resampled = getattr(args, "dataset_resampled", False)
|
598 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
599 |
+
|
600 |
+
num_samples, num_shards = get_dataset_size(input_shards)
|
601 |
+
num_samples = None
|
602 |
+
if not num_samples:
|
603 |
+
num_samples = args.train_num_samples_mmc4
|
604 |
+
if not num_samples:
|
605 |
+
raise RuntimeError(
|
606 |
+
"Currently, number of dataset samples must be specified for training dataset. "
|
607 |
+
"Please specify via `--train-num-samples` if no dataset length info present."
|
608 |
+
)
|
609 |
+
|
610 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
611 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
612 |
+
if resampled:
|
613 |
+
pipeline = [
|
614 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
|
615 |
+
]
|
616 |
+
else:
|
617 |
+
pipeline = [wds.SimpleShardList(input_shards)]
|
618 |
+
|
619 |
+
preprocess_text_fn = functools.partial(preprocess_encoded_text, tokenizer=tokenizer)
|
620 |
+
|
621 |
+
# at this point we have an iterator over all the shards
|
622 |
+
if not resampled:
|
623 |
+
pipeline.extend(
|
624 |
+
[
|
625 |
+
detshuffle2(
|
626 |
+
bufsize=_SHARD_SHUFFLE_SIZE,
|
627 |
+
initial=_SHARD_SHUFFLE_INITIAL,
|
628 |
+
seed=args.seed,
|
629 |
+
epoch=shared_epoch,
|
630 |
+
),
|
631 |
+
wds.split_by_node,
|
632 |
+
wds.split_by_worker,
|
633 |
+
]
|
634 |
+
)
|
635 |
+
pipeline.extend(
|
636 |
+
[
|
637 |
+
# at this point, we have an iterator over the shards assigned to each worker at each node
|
638 |
+
# wds.tarfile_to_samples(handler=log_and_continue),
|
639 |
+
tarfile_to_samples_nothrow,
|
640 |
+
wds.shuffle(
|
641 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
642 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
643 |
+
),
|
644 |
+
]
|
645 |
+
)
|
646 |
+
|
647 |
+
pipeline.extend(
|
648 |
+
[
|
649 |
+
wds.to_tuple("txt", handler=log_and_continue),
|
650 |
+
wds.batched(args.batch_size_mmc4, partial=False),
|
651 |
+
wds.map_tuple(
|
652 |
+
preprocess_text_fn, handler=log_and_continue
|
653 |
+
),
|
654 |
+
]
|
655 |
+
)
|
656 |
+
|
657 |
+
dataset = wds.DataPipeline(*pipeline)
|
658 |
+
if not resampled:
|
659 |
+
assert (
|
660 |
+
num_shards >= args.workers * args.world_size
|
661 |
+
), "number of shards must be >= total workers"
|
662 |
+
# roll over and repeat a few samples to get same number of full batches on each node
|
663 |
+
round_fn = math.floor if floor else math.ceil
|
664 |
+
global_batch_size = args.batch_size_mmc4 * args.world_size
|
665 |
+
num_batches = round_fn(num_samples / global_batch_size)
|
666 |
+
num_workers = max(1, args.workers)
|
667 |
+
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
|
668 |
+
num_batches = num_worker_batches * num_workers
|
669 |
+
num_samples = num_batches * global_batch_size
|
670 |
+
# each worker is iterating over this
|
671 |
+
dataset = dataset.with_epoch(num_worker_batches)
|
672 |
+
|
673 |
+
dataloader = wds.WebLoader(
|
674 |
+
dataset,
|
675 |
+
batch_size=None,
|
676 |
+
shuffle=False,
|
677 |
+
num_workers=args.workers,
|
678 |
+
persistent_workers=True,
|
679 |
+
)
|
680 |
+
|
681 |
+
# add meta-data to dataloader instance for convenience
|
682 |
+
dataloader.num_batches = num_batches
|
683 |
+
dataloader.num_samples = num_samples
|
684 |
+
|
685 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
686 |
+
|
687 |
+
|
688 |
+
# FIXME:
|
689 |
+
# modify /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/webdataset/filters.py, line 433
|
690 |
+
# combine_tensors=True to combine_tensors=False
|
691 |
+
def get_ground_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
692 |
+
input_shards = args.laion_shards
|
693 |
+
assert input_shards is not None
|
694 |
+
resampled = getattr(args, "dataset_resampled", False)
|
695 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
696 |
+
|
697 |
+
num_samples, num_shards = get_dataset_size(input_shards)
|
698 |
+
num_samples = None
|
699 |
+
if not num_samples:
|
700 |
+
num_samples = args.train_num_samples_laion
|
701 |
+
if not num_samples:
|
702 |
+
raise RuntimeError(
|
703 |
+
"Currently, number of dataset samples must be specified for training dataset. "
|
704 |
+
"Please specify via `--train-num-samples` if no dataset length info present."
|
705 |
+
)
|
706 |
+
|
707 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
708 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
709 |
+
if resampled:
|
710 |
+
pipeline = [
|
711 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
|
712 |
+
]
|
713 |
+
else:
|
714 |
+
pipeline = [wds.SimpleShardList(input_shards)]
|
715 |
+
|
716 |
+
# create the preprocess function that take in the passed in image_processor and tokenizer
|
717 |
+
generator = caption_grounder(
|
718 |
+
config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
719 |
+
checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
720 |
+
cpu_only=True,
|
721 |
+
)
|
722 |
+
image_size = image_processor.transforms[0].size
|
723 |
+
preprocess_ground_caption_fn = functools.partial(
|
724 |
+
preprocess_ground_caption, image_processor=image_processor,
|
725 |
+
tokenizer=tokenizer, generator=generator, image_size=image_size,
|
726 |
+
)
|
727 |
+
|
728 |
+
# at this point we have an iterator over all the shards
|
729 |
+
if not resampled:
|
730 |
+
pipeline.extend(
|
731 |
+
[
|
732 |
+
detshuffle2(
|
733 |
+
bufsize=_SHARD_SHUFFLE_SIZE,
|
734 |
+
initial=_SHARD_SHUFFLE_INITIAL,
|
735 |
+
seed=args.seed,
|
736 |
+
epoch=shared_epoch,
|
737 |
+
),
|
738 |
+
wds.split_by_node,
|
739 |
+
wds.split_by_worker,
|
740 |
+
]
|
741 |
+
)
|
742 |
+
pipeline.extend(
|
743 |
+
[
|
744 |
+
# at this point, we have an iterator over the shards assigned to each worker at each node
|
745 |
+
# wds.tarfile_to_samples(handler=log_and_continue),
|
746 |
+
tarfile_to_samples_nothrow,
|
747 |
+
wds.shuffle(
|
748 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
749 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
750 |
+
),
|
751 |
+
]
|
752 |
+
)
|
753 |
+
|
754 |
+
pipeline.extend(
|
755 |
+
[
|
756 |
+
wds.select(filter_no_caption_or_no_image),
|
757 |
+
wds.decode("pilrgb", handler=log_and_continue),
|
758 |
+
wds.to_tuple("jpg;png;jpeg", "txt", "logits.pyd", "boxes.pyd", handler=log_and_continue),
|
759 |
+
wds.batched(args.batch_size_laion, partial=False),
|
760 |
+
wds.map(
|
761 |
+
preprocess_ground_caption_fn, handler=log_and_continue
|
762 |
+
),
|
763 |
+
]
|
764 |
+
)
|
765 |
+
|
766 |
+
dataset = wds.DataPipeline(*pipeline)
|
767 |
+
if not resampled:
|
768 |
+
assert (
|
769 |
+
num_shards >= args.workers * args.world_size
|
770 |
+
), "number of shards must be >= total workers"
|
771 |
+
# roll over and repeat a few samples to get same number of full batches on each node
|
772 |
+
round_fn = math.floor if floor else math.ceil
|
773 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
774 |
+
num_batches = round_fn(num_samples / global_batch_size)
|
775 |
+
num_workers = max(1, args.workers) # FIXME: single worker
|
776 |
+
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
|
777 |
+
num_batches = num_worker_batches * num_workers
|
778 |
+
num_samples = num_batches * global_batch_size
|
779 |
+
# each worker is iterating over this
|
780 |
+
dataset = dataset.with_epoch(num_worker_batches)
|
781 |
+
|
782 |
+
dataloader = wds.WebLoader(
|
783 |
+
dataset,
|
784 |
+
batch_size=None,
|
785 |
+
shuffle=False,
|
786 |
+
num_workers=0,
|
787 |
+
)
|
788 |
+
|
789 |
+
# add meta-data to dataloader instance for convenience
|
790 |
+
dataloader.num_batches = num_batches
|
791 |
+
dataloader.num_samples = num_samples
|
792 |
+
|
793 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
794 |
+
|
795 |
+
|
796 |
+
def get_dataset_fn(dataset_type):
|
797 |
+
if dataset_type == "image_text":
|
798 |
+
return get_laion_dataset
|
799 |
+
elif dataset_type == "mmc4":
|
800 |
+
return get_mmc4_dataset
|
801 |
+
elif dataset_type == "pile":
|
802 |
+
return get_pile_dataset
|
803 |
+
elif dataset_type == "ground_image_text":
|
804 |
+
return get_ground_laion_dataset
|
805 |
+
else:
|
806 |
+
raise ValueError(f"Unsupported dataset type: {dataset_type}")
|
807 |
+
|
808 |
+
|
809 |
+
def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
|
810 |
+
return get_dataset_fn(dataset_type)(
|
811 |
+
args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
|
812 |
+
)
|
open_flamingo/open_flamingo/train/data2.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import functools
|
3 |
+
import io
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import sys
|
10 |
+
import tarfile
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from multiprocessing import Value
|
13 |
+
import time
|
14 |
+
|
15 |
+
import braceexpand
|
16 |
+
import torch
|
17 |
+
import torchvision
|
18 |
+
import webdataset as wds
|
19 |
+
from PIL import Image
|
20 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
21 |
+
from torch.utils.data.distributed import DistributedSampler
|
22 |
+
from webdataset.filters import _shuffle
|
23 |
+
from webdataset.tariterators import (
|
24 |
+
base_plus_ext,
|
25 |
+
tar_file_expander,
|
26 |
+
url_opener,
|
27 |
+
valid_sample,
|
28 |
+
)
|
29 |
+
try:
|
30 |
+
from groundingdino.demo.caption_grounder import caption_grounder
|
31 |
+
from groundingdino.demo.inference_on_laion import add_loc_to_text
|
32 |
+
except:
|
33 |
+
pass
|
34 |
+
|
35 |
+
Image.MAX_IMAGE_PIXELS = 1000000000
|
36 |
+
LAION2B_NUM_SAMPLE = 1500000000
|
37 |
+
VQAV2_TRAIN_NUM_SAMPLE = 1828467
|
38 |
+
|
39 |
+
try:
|
40 |
+
import horovod.torch as hvd
|
41 |
+
except ImportError:
|
42 |
+
hvd = None
|
43 |
+
|
44 |
+
|
45 |
+
class ConcatDataset(IterableDataset):
|
46 |
+
def __init__(
|
47 |
+
self, dataset, max_length,
|
48 |
+
delimiter_id, pad_id=None, media_id=None, endofmedia_id=None,
|
49 |
+
image_embedding_size=-2, single=False,
|
50 |
+
):
|
51 |
+
self.dataset = dataset
|
52 |
+
self.max_length = max_length
|
53 |
+
self.delimiter_id = torch.ones(1,1).long() * delimiter_id
|
54 |
+
if pad_id is not None:
|
55 |
+
self.pad_id = int(pad_id)
|
56 |
+
if media_id is not None:
|
57 |
+
self.media_id = torch.ones(1,1).long() * int(media_id)
|
58 |
+
if endofmedia_id is not None:
|
59 |
+
self.endofmedia_id = torch.ones(1,1).long() * int(endofmedia_id)
|
60 |
+
if image_embedding_size > 0:
|
61 |
+
logging.info(f"image_embedding_size: {image_embedding_size}")
|
62 |
+
self.image_embedding_size = image_embedding_size + 2
|
63 |
+
self.single = single
|
64 |
+
|
65 |
+
def __iter__(self):
|
66 |
+
while True:
|
67 |
+
input_ids_list = []
|
68 |
+
attention_mask_list = []
|
69 |
+
image_list = []
|
70 |
+
image_start_index_list = []
|
71 |
+
added_bbox_list = []
|
72 |
+
cnt = 0
|
73 |
+
while cnt < self.max_length:
|
74 |
+
sample = next(self.dataset)
|
75 |
+
if len(sample) == 4:
|
76 |
+
image = sample[0].unsqueeze(0)
|
77 |
+
input_ids = sample[1]
|
78 |
+
attention_mask = sample[2]
|
79 |
+
added_bbox = sample[3]
|
80 |
+
image_list.append(image)
|
81 |
+
added_bbox_list.append(added_bbox)
|
82 |
+
else:
|
83 |
+
sample = sample[0]
|
84 |
+
input_ids = sample[0]
|
85 |
+
attention_mask = sample[1]
|
86 |
+
input_ids_list.append(input_ids)
|
87 |
+
attention_mask_list.append(attention_mask)
|
88 |
+
cnt += input_ids.shape[-1]
|
89 |
+
if self.single:
|
90 |
+
break
|
91 |
+
input_ids = torch.cat(input_ids_list, dim=-1)[0]
|
92 |
+
attention_mask = torch.cat(attention_mask_list, dim=-1)[0]
|
93 |
+
if not self.single:
|
94 |
+
input_ids = input_ids[:self.max_length]
|
95 |
+
attention_mask = attention_mask[:self.max_length]
|
96 |
+
if len(image_list) != 0:
|
97 |
+
images = torch.cat(image_list, dim=0)
|
98 |
+
image_begin = (input_ids == self.media_id[0,0]).nonzero().view(-1)
|
99 |
+
image_end = (input_ids == self.endofmedia_id[0,0]).nonzero().view(-1)
|
100 |
+
images = images[:len(image_end)]
|
101 |
+
added_bbox_list = added_bbox_list[:len(image_end)]
|
102 |
+
if len(image_begin) != len(image_end):
|
103 |
+
assert len(image_begin) == len(image_end) + 1
|
104 |
+
input_ids[image_begin[-1]:] = self.pad_id
|
105 |
+
attention_mask[image_begin[-1]:] = 0
|
106 |
+
image_begin = image_begin[:-1]
|
107 |
+
image_start_index_list = (image_begin + 1).tolist()
|
108 |
+
yield images, len(images), image_start_index_list, input_ids, attention_mask, added_bbox_list
|
109 |
+
else:
|
110 |
+
yield input_ids, attention_mask
|
111 |
+
|
112 |
+
|
113 |
+
class SharedEpoch:
|
114 |
+
def __init__(self, epoch: int = 0):
|
115 |
+
self.shared_epoch = Value("i", epoch)
|
116 |
+
|
117 |
+
def set_value(self, epoch):
|
118 |
+
self.shared_epoch.value = epoch
|
119 |
+
|
120 |
+
def get_value(self):
|
121 |
+
return self.shared_epoch.value
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass
|
125 |
+
class DataInfo:
|
126 |
+
dataloader: DataLoader
|
127 |
+
sampler: DistributedSampler = None
|
128 |
+
shared_epoch: SharedEpoch = None
|
129 |
+
|
130 |
+
def set_epoch(self, epoch):
|
131 |
+
if self.shared_epoch is not None:
|
132 |
+
self.shared_epoch.set_value(epoch)
|
133 |
+
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
|
134 |
+
self.sampler.set_epoch(epoch)
|
135 |
+
|
136 |
+
|
137 |
+
def filter_no_caption_or_no_image(sample):
|
138 |
+
return ("txt" in sample) and (
|
139 |
+
"png" in sample or "jpg" in sample or "jpeg" in sample
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def log_and_continue(exn):
|
144 |
+
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
|
145 |
+
if "No images in sample" in str(exn) or "Only one image in sample" in str(
|
146 |
+
exn
|
147 |
+
): # Avoid spamming logs with these
|
148 |
+
return True
|
149 |
+
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
|
150 |
+
return True
|
151 |
+
|
152 |
+
|
153 |
+
def group_by_keys_nothrow(
|
154 |
+
data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
|
155 |
+
):
|
156 |
+
"""Return function over iterator that groups key, value pairs into samples.
|
157 |
+
|
158 |
+
:param keys: function that splits the key into key and extension (base_plus_ext)
|
159 |
+
:param lcase: convert suffixes to lower case (Default value = True)
|
160 |
+
"""
|
161 |
+
current_sample = None
|
162 |
+
for filesample in data:
|
163 |
+
assert isinstance(filesample, dict)
|
164 |
+
fname, value = filesample["fname"], filesample["data"]
|
165 |
+
prefix, suffix = keys(fname)
|
166 |
+
if prefix is None:
|
167 |
+
continue
|
168 |
+
if lcase:
|
169 |
+
suffix = suffix.lower()
|
170 |
+
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for
|
171 |
+
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next
|
172 |
+
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset
|
173 |
+
if (
|
174 |
+
current_sample is None
|
175 |
+
or prefix != current_sample["__key__"]
|
176 |
+
or suffix in current_sample
|
177 |
+
):
|
178 |
+
if valid_sample(current_sample):
|
179 |
+
yield current_sample
|
180 |
+
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
181 |
+
if suffixes is None or suffix in suffixes:
|
182 |
+
current_sample[suffix] = value
|
183 |
+
if valid_sample(current_sample):
|
184 |
+
yield current_sample
|
185 |
+
|
186 |
+
|
187 |
+
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
|
188 |
+
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
|
189 |
+
streams = url_opener(src, handler=handler)
|
190 |
+
files = tar_file_expander(streams, handler=handler)
|
191 |
+
samples = group_by_keys_nothrow(files, handler=handler)
|
192 |
+
return samples
|
193 |
+
|
194 |
+
|
195 |
+
def pytorch_worker_seed(increment=0):
|
196 |
+
"""get dataloader worker seed from pytorch"""
|
197 |
+
worker_info = get_worker_info()
|
198 |
+
if worker_info is not None:
|
199 |
+
# favour using the seed already created for pytorch dataloader workers if it exists
|
200 |
+
seed = worker_info.seed
|
201 |
+
if increment:
|
202 |
+
# space out seed increments so they can't overlap across workers in different iterations
|
203 |
+
seed += increment * max(1, worker_info.num_workers)
|
204 |
+
return seed
|
205 |
+
# fallback to wds rank based seed
|
206 |
+
return wds.utils.pytorch_worker_seed()
|
207 |
+
|
208 |
+
|
209 |
+
_SHARD_SHUFFLE_SIZE = 2000
|
210 |
+
_SHARD_SHUFFLE_INITIAL = 500
|
211 |
+
_SAMPLE_SHUFFLE_SIZE = 5000
|
212 |
+
_SAMPLE_SHUFFLE_INITIAL = 1000
|
213 |
+
|
214 |
+
|
215 |
+
class ResampledShards2(IterableDataset):
|
216 |
+
"""An iterable dataset yielding a list of urls."""
|
217 |
+
|
218 |
+
def __init__(
|
219 |
+
self,
|
220 |
+
urls,
|
221 |
+
nshards=sys.maxsize,
|
222 |
+
worker_seed=None,
|
223 |
+
deterministic=False,
|
224 |
+
epoch=-1,
|
225 |
+
):
|
226 |
+
"""Sample shards from the shard list with replacement.
|
227 |
+
:param urls: a list of URLs as a Python list or brace notation string
|
228 |
+
"""
|
229 |
+
super().__init__()
|
230 |
+
urls = wds.shardlists.expand_urls(urls)
|
231 |
+
self.urls = urls
|
232 |
+
assert isinstance(self.urls[0], str)
|
233 |
+
self.nshards = nshards
|
234 |
+
self.rng = random.Random()
|
235 |
+
self.worker_seed = worker_seed
|
236 |
+
self.deterministic = deterministic
|
237 |
+
self.epoch = epoch
|
238 |
+
|
239 |
+
def __iter__(self):
|
240 |
+
"""Return an iterator over the shards."""
|
241 |
+
if isinstance(self.epoch, SharedEpoch):
|
242 |
+
epoch = self.epoch.get_value()
|
243 |
+
else:
|
244 |
+
# NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
|
245 |
+
# situation as different workers may wrap at different times (or not at all).
|
246 |
+
self.epoch += 1
|
247 |
+
epoch = self.epoch
|
248 |
+
|
249 |
+
if self.deterministic:
|
250 |
+
# reset seed w/ epoch if deterministic
|
251 |
+
if self.worker_seed is None:
|
252 |
+
# pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
|
253 |
+
seed = pytorch_worker_seed(epoch)
|
254 |
+
else:
|
255 |
+
seed = self.worker_seed() + epoch
|
256 |
+
seed = seed + int(time.time())
|
257 |
+
self.rng.seed(seed)
|
258 |
+
# logging.info(f"epoch: {epoch} seed: {seed}")
|
259 |
+
self.rng.shuffle(self.urls)
|
260 |
+
# logging.info(f"{len(self.urls)} | {self.urls[:2]}")
|
261 |
+
for url in self.urls:
|
262 |
+
# logging.info(f"{seed}: {url}")
|
263 |
+
yield dict(url=url)
|
264 |
+
|
265 |
+
|
266 |
+
def preprocess_image(sample, image_processor):
|
267 |
+
image = image_processor(sample)
|
268 |
+
return image
|
269 |
+
|
270 |
+
|
271 |
+
def preprocess_text(sample, tokenizer, max_length, single=False):
|
272 |
+
if not single:
|
273 |
+
text = tokenizer(sample.strip(), return_tensors="pt", max_length=max_length, truncation=True)
|
274 |
+
else:
|
275 |
+
text = tokenizer(sample.strip(), return_tensors="pt", padding="max_length", truncation=True, max_length=320)
|
276 |
+
return text["input_ids"], text["attention_mask"]
|
277 |
+
|
278 |
+
|
279 |
+
def preprocess_encoded_text(sample, tokenizer, max_length):
|
280 |
+
sample = sample.decode("utf-8")
|
281 |
+
return preprocess_text(sample, tokenizer, max_length=max_length)
|
282 |
+
|
283 |
+
|
284 |
+
def preprocess_ground_caption(sample, image_processor, tokenizer, image_embedding_size, generator, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
|
285 |
+
assert max_length is not None
|
286 |
+
assert not single, "single is not supported for preprocess_ground_caption"
|
287 |
+
image, caption, logits_filt, boxes_filt = sample
|
288 |
+
image = preprocess_image(image, image_processor=image_processor)
|
289 |
+
added_bbox = []
|
290 |
+
if (prob_ground != 0 and random.random() <= prob_ground) or prob_ground == 1.0:
|
291 |
+
boxes_filt, pred_phrases = generator.postprocess(logits_filt, boxes_filt, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
|
292 |
+
caption, added_bbox = add_loc_to_text(boxes_filt, pred_phrases, caption, use_format_v2, add_visual_token)
|
293 |
+
caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
|
294 |
+
input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
|
295 |
+
return image, input_ids, attention_mask, added_bbox
|
296 |
+
|
297 |
+
|
298 |
+
def preprocess_caption(sample, image_processor, tokenizer, image_embedding_size, max_length, single=False):
|
299 |
+
image, caption = sample
|
300 |
+
caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
|
301 |
+
image = preprocess_image(image, image_processor=image_processor)
|
302 |
+
input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=single)
|
303 |
+
return image, input_ids, attention_mask
|
304 |
+
|
305 |
+
|
306 |
+
def get_pile_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
307 |
+
input_shards = args.pile_shards
|
308 |
+
assert input_shards is not None
|
309 |
+
resampled = getattr(args, "dataset_resampled", False)
|
310 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
311 |
+
|
312 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
313 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
314 |
+
preprocess_text_fn = functools.partial(preprocess_encoded_text, tokenizer=tokenizer, max_length=args.max_length)
|
315 |
+
pipeline = [
|
316 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
317 |
+
tarfile_to_samples_nothrow,
|
318 |
+
wds.shuffle(
|
319 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
320 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
321 |
+
),
|
322 |
+
wds.to_tuple("txt", handler=log_and_continue),
|
323 |
+
wds.map_tuple(
|
324 |
+
preprocess_text_fn, handler=log_and_continue
|
325 |
+
),
|
326 |
+
]
|
327 |
+
# with_epoch(sys.maxsize) will give us an infinite sample stream
|
328 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
329 |
+
delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
|
330 |
+
dataset = ConcatDataset(iter(dataset), max_length=args.max_length, delimiter_id=delimiter_id)
|
331 |
+
|
332 |
+
|
333 |
+
def text_collate_fn(items):
|
334 |
+
try:
|
335 |
+
input_ids = torch.cat([x[0].unsqueeze(0) for x in items], dim=0)
|
336 |
+
attention_mask = torch.cat([x[1].unsqueeze(0) for x in items], dim=0)
|
337 |
+
return input_ids, attention_mask
|
338 |
+
except:
|
339 |
+
return None, None
|
340 |
+
|
341 |
+
dataloader = wds.WebLoader(
|
342 |
+
dataset,
|
343 |
+
batch_size=args.batch_size_pile,
|
344 |
+
shuffle=False,
|
345 |
+
num_workers=args.workers,
|
346 |
+
persistent_workers=False,
|
347 |
+
collate_fn=text_collate_fn,
|
348 |
+
)
|
349 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
350 |
+
|
351 |
+
|
352 |
+
# FIXME:
|
353 |
+
# modify /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/webdataset/filters.py, line 433
|
354 |
+
# combine_tensors=True to combine_tensors=False
|
355 |
+
def get_ground_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
356 |
+
input_shards = args.laion_shards
|
357 |
+
assert input_shards is not None
|
358 |
+
resampled = getattr(args, "dataset_resampled", False)
|
359 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
360 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
361 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
362 |
+
generator = caption_grounder(
|
363 |
+
config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
364 |
+
checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
365 |
+
cpu_only=True,
|
366 |
+
)
|
367 |
+
preprocess_ground_caption_fn = functools.partial(
|
368 |
+
preprocess_ground_caption, image_processor=image_processor, tokenizer=tokenizer,
|
369 |
+
image_embedding_size=args.vis_embed_size, single=args.single, generator=generator,
|
370 |
+
prob_ground=args.prob_ground, use_format_v2=args.use_format_v2,
|
371 |
+
add_visual_token=args.add_visual_token, max_length=args.max_length,
|
372 |
+
)
|
373 |
+
pipeline = [
|
374 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
375 |
+
tarfile_to_samples_nothrow,
|
376 |
+
wds.shuffle(
|
377 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
378 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
379 |
+
),
|
380 |
+
wds.select(filter_no_caption_or_no_image),
|
381 |
+
wds.decode("pilrgb", handler=log_and_continue),
|
382 |
+
wds.to_tuple("jpg;png;jpeg", "txt", "logits.pyd", "boxes.pyd", handler=log_and_continue),
|
383 |
+
wds.map(
|
384 |
+
preprocess_ground_caption_fn, handler=log_and_continue
|
385 |
+
),
|
386 |
+
]
|
387 |
+
|
388 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
389 |
+
|
390 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
391 |
+
delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
|
392 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
393 |
+
dataset = ConcatDataset(
|
394 |
+
iter(dataset), max_length=args.max_length,
|
395 |
+
delimiter_id=delimiter_id,
|
396 |
+
pad_id=tokenizer.pad_token_id,
|
397 |
+
media_id=media_token_id,
|
398 |
+
endofmedia_id=endofmedia_token_id,
|
399 |
+
image_embedding_size=args.vis_embed_size,
|
400 |
+
single=args.single,
|
401 |
+
)
|
402 |
+
|
403 |
+
def image_collate_fn(items):
|
404 |
+
images = torch.cat([x[0] for x in items], dim=0)
|
405 |
+
image_nums = [x[1] for x in items]
|
406 |
+
image_start_index_list = [x[2] for x in items]
|
407 |
+
input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
|
408 |
+
attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
|
409 |
+
added_bbox_list = [x[5] for x in items]
|
410 |
+
return images, image_nums, image_start_index_list, input_ids, attention_mask, added_bbox_list
|
411 |
+
|
412 |
+
dataloader = wds.WebLoader(
|
413 |
+
dataset,
|
414 |
+
batch_size=args.batch_size_laion,
|
415 |
+
shuffle=False,
|
416 |
+
num_workers=args.workers,
|
417 |
+
persistent_workers=False,
|
418 |
+
collate_fn=image_collate_fn,
|
419 |
+
)
|
420 |
+
round_fn = math.floor if floor else math.ceil
|
421 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
422 |
+
num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
|
423 |
+
dataloader.num_batches = num_batches
|
424 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
425 |
+
|
426 |
+
|
427 |
+
def get_image_text_pair_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
428 |
+
input_shards = args.laion_shards
|
429 |
+
assert input_shards is not None
|
430 |
+
resampled = getattr(args, "dataset_resampled", False)
|
431 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
432 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
433 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
434 |
+
preprocess_caption_fn = functools.partial(
|
435 |
+
preprocess_caption, image_processor=image_processor, tokenizer=tokenizer,
|
436 |
+
image_embedding_size=args.vis_embed_size, single=args.single,
|
437 |
+
max_length=args.max_length,
|
438 |
+
)
|
439 |
+
pipeline = [
|
440 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
441 |
+
tarfile_to_samples_nothrow,
|
442 |
+
wds.shuffle(
|
443 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
444 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
445 |
+
),
|
446 |
+
wds.select(filter_no_caption_or_no_image),
|
447 |
+
wds.decode("pilrgb", handler=log_and_continue),
|
448 |
+
wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
|
449 |
+
wds.map(
|
450 |
+
preprocess_caption_fn, handler=log_and_continue
|
451 |
+
),
|
452 |
+
]
|
453 |
+
|
454 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
455 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
456 |
+
delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
|
457 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
458 |
+
dataset = ConcatDataset(
|
459 |
+
iter(dataset), max_length=args.max_length,
|
460 |
+
delimiter_id=delimiter_id,
|
461 |
+
pad_id=tokenizer.pad_token_id,
|
462 |
+
media_id=media_token_id,
|
463 |
+
endofmedia_id=endofmedia_token_id,
|
464 |
+
image_embedding_size=args.vis_embed_size,
|
465 |
+
single=args.single,
|
466 |
+
)
|
467 |
+
|
468 |
+
def image_collate_fn(items):
|
469 |
+
images = torch.cat([x[0] for x in items], dim=0)
|
470 |
+
image_nums = [x[1] for x in items]
|
471 |
+
image_start_index_list = [x[2] for x in items]
|
472 |
+
input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
|
473 |
+
attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
|
474 |
+
return images, image_nums, image_start_index_list, input_ids, attention_mask
|
475 |
+
|
476 |
+
dataloader = wds.WebLoader(
|
477 |
+
dataset,
|
478 |
+
batch_size=args.batch_size_laion,
|
479 |
+
shuffle=False,
|
480 |
+
num_workers=args.workers,
|
481 |
+
persistent_workers=False,
|
482 |
+
collate_fn=image_collate_fn,
|
483 |
+
)
|
484 |
+
round_fn = math.floor if floor else math.ceil
|
485 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
486 |
+
num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
|
487 |
+
dataloader.num_batches = num_batches
|
488 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
489 |
+
|
490 |
+
|
491 |
+
def get_vqav2_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
492 |
+
input_shards = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vqav2_train_wds/{000000..000182}.tar"
|
493 |
+
assert input_shards is not None
|
494 |
+
resampled = getattr(args, "dataset_resampled", False)
|
495 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
496 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
497 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
498 |
+
preprocess_caption_fn = functools.partial(
|
499 |
+
preprocess_caption, image_processor=image_processor, tokenizer=tokenizer,
|
500 |
+
image_embedding_size=args.vis_embed_size, single=True,
|
501 |
+
max_length=args.max_length,
|
502 |
+
)
|
503 |
+
pipeline = [
|
504 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
505 |
+
tarfile_to_samples_nothrow,
|
506 |
+
wds.shuffle(
|
507 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
508 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
509 |
+
),
|
510 |
+
wds.select(filter_no_caption_or_no_image),
|
511 |
+
wds.decode("pilrgb", handler=log_and_continue),
|
512 |
+
wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
|
513 |
+
wds.map(
|
514 |
+
preprocess_caption_fn, handler=log_and_continue
|
515 |
+
),
|
516 |
+
]
|
517 |
+
|
518 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
519 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
520 |
+
delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
|
521 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
522 |
+
dataset = ConcatDataset(
|
523 |
+
iter(dataset), max_length=args.max_length,
|
524 |
+
delimiter_id=delimiter_id,
|
525 |
+
pad_id=tokenizer.pad_token_id,
|
526 |
+
media_id=media_token_id,
|
527 |
+
endofmedia_id=endofmedia_token_id,
|
528 |
+
image_embedding_size=args.vis_embed_size,
|
529 |
+
single=True,
|
530 |
+
)
|
531 |
+
|
532 |
+
def image_collate_fn(items):
|
533 |
+
images = torch.cat([x[0] for x in items], dim=0)
|
534 |
+
image_nums = [x[1] for x in items]
|
535 |
+
image_start_index_list = [x[2] for x in items]
|
536 |
+
input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
|
537 |
+
attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
|
538 |
+
return images, image_nums, image_start_index_list, input_ids, attention_mask
|
539 |
+
|
540 |
+
dataloader = wds.WebLoader(
|
541 |
+
dataset,
|
542 |
+
batch_size=args.batch_size_laion,
|
543 |
+
shuffle=False,
|
544 |
+
num_workers=args.workers,
|
545 |
+
persistent_workers=False,
|
546 |
+
collate_fn=image_collate_fn,
|
547 |
+
)
|
548 |
+
round_fn = math.floor if floor else math.ceil
|
549 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
550 |
+
num_batches = round_fn(VQAV2_TRAIN_NUM_SAMPLE / global_batch_size)
|
551 |
+
dataloader.num_batches = num_batches
|
552 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
553 |
+
|
554 |
+
|
555 |
+
def get_dataset_fn(dataset_type):
|
556 |
+
if dataset_type == "mmc4":
|
557 |
+
raise NotImplementedError
|
558 |
+
elif dataset_type == "pile":
|
559 |
+
return get_pile_dataset
|
560 |
+
elif dataset_type == "ground_image_text":
|
561 |
+
return get_ground_laion_dataset
|
562 |
+
elif dataset_type == "image_text":
|
563 |
+
return get_image_text_pair_dataset
|
564 |
+
elif dataset_type == "vqav2":
|
565 |
+
return get_vqav2_dataset
|
566 |
+
else:
|
567 |
+
raise ValueError(f"Unsupported dataset type: {dataset_type}")
|
568 |
+
|
569 |
+
|
570 |
+
def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
|
571 |
+
return get_dataset_fn(dataset_type)(
|
572 |
+
args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
|
573 |
+
)
|
open_flamingo/open_flamingo/train/distributed.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
import horovod.torch as hvd
|
7 |
+
except ImportError:
|
8 |
+
hvd = None
|
9 |
+
|
10 |
+
|
11 |
+
def is_global_master(args):
|
12 |
+
return args.rank == 0
|
13 |
+
|
14 |
+
|
15 |
+
def is_local_master(args):
|
16 |
+
return args.local_rank == 0
|
17 |
+
|
18 |
+
|
19 |
+
def is_master(args, local=False):
|
20 |
+
return is_local_master(args) if local else is_global_master(args)
|
21 |
+
|
22 |
+
|
23 |
+
def is_using_horovod():
|
24 |
+
# NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
|
25 |
+
# Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
|
26 |
+
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
|
27 |
+
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
|
28 |
+
if all([var in os.environ for var in ompi_vars]) or all(
|
29 |
+
[var in os.environ for var in pmi_vars]
|
30 |
+
):
|
31 |
+
return True
|
32 |
+
else:
|
33 |
+
return False
|
34 |
+
|
35 |
+
|
36 |
+
def is_using_distributed():
|
37 |
+
if "WORLD_SIZE" in os.environ:
|
38 |
+
return int(os.environ["WORLD_SIZE"]) > 1
|
39 |
+
if "SLURM_NTASKS" in os.environ:
|
40 |
+
return int(os.environ["SLURM_NTASKS"]) > 1
|
41 |
+
return False
|
42 |
+
|
43 |
+
|
44 |
+
def world_info_from_env():
|
45 |
+
local_rank = 0
|
46 |
+
for v in (
|
47 |
+
"LOCAL_RANK",
|
48 |
+
"MPI_LOCALRANKID",
|
49 |
+
"SLURM_LOCALID",
|
50 |
+
"OMPI_COMM_WORLD_LOCAL_RANK",
|
51 |
+
):
|
52 |
+
if v in os.environ:
|
53 |
+
local_rank = int(os.environ[v])
|
54 |
+
break
|
55 |
+
global_rank = 0
|
56 |
+
for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
|
57 |
+
if v in os.environ:
|
58 |
+
global_rank = int(os.environ[v])
|
59 |
+
break
|
60 |
+
world_size = 1
|
61 |
+
for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
|
62 |
+
if v in os.environ:
|
63 |
+
world_size = int(os.environ[v])
|
64 |
+
break
|
65 |
+
|
66 |
+
return local_rank, global_rank, world_size
|
67 |
+
|
68 |
+
|
69 |
+
def init_distributed_device(args):
|
70 |
+
# Distributed training = training on more than one GPU.
|
71 |
+
# Works in both single and multi-node scenarios.
|
72 |
+
args.distributed = False
|
73 |
+
args.world_size = 1
|
74 |
+
args.rank = 0 # global rank
|
75 |
+
args.local_rank = 0
|
76 |
+
if args.horovod:
|
77 |
+
assert hvd is not None, "Horovod is not installed"
|
78 |
+
hvd.init()
|
79 |
+
args.local_rank = int(hvd.local_rank())
|
80 |
+
args.rank = hvd.rank()
|
81 |
+
args.world_size = hvd.size()
|
82 |
+
args.distributed = True
|
83 |
+
os.environ["LOCAL_RANK"] = str(args.local_rank)
|
84 |
+
os.environ["RANK"] = str(args.rank)
|
85 |
+
os.environ["WORLD_SIZE"] = str(args.world_size)
|
86 |
+
elif is_using_distributed():
|
87 |
+
if "SLURM_PROCID" in os.environ:
|
88 |
+
# DDP via SLURM
|
89 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
90 |
+
# SLURM var -> torch.distributed vars in case needed
|
91 |
+
os.environ["LOCAL_RANK"] = str(args.local_rank)
|
92 |
+
os.environ["RANK"] = str(args.rank)
|
93 |
+
os.environ["WORLD_SIZE"] = str(args.world_size)
|
94 |
+
torch.distributed.init_process_group(
|
95 |
+
backend=args.dist_backend,
|
96 |
+
init_method=args.dist_url,
|
97 |
+
world_size=args.world_size,
|
98 |
+
rank=args.rank,
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
# DDP via torchrun, torch.distributed.launch
|
102 |
+
args.local_rank, _, _ = world_info_from_env()
|
103 |
+
torch.distributed.init_process_group(
|
104 |
+
backend=args.dist_backend, init_method=args.dist_url
|
105 |
+
)
|
106 |
+
args.world_size = torch.distributed.get_world_size()
|
107 |
+
args.rank = torch.distributed.get_rank()
|
108 |
+
args.distributed = True
|
109 |
+
else:
|
110 |
+
# needed to run on single gpu
|
111 |
+
torch.distributed.init_process_group(
|
112 |
+
backend=args.dist_backend,
|
113 |
+
init_method=args.dist_url,
|
114 |
+
world_size=1,
|
115 |
+
rank=0,
|
116 |
+
)
|
117 |
+
|
118 |
+
if torch.cuda.is_available():
|
119 |
+
if args.distributed and not args.no_set_device_rank:
|
120 |
+
device = "cuda:%d" % args.local_rank
|
121 |
+
else:
|
122 |
+
device = "cuda:0"
|
123 |
+
torch.cuda.set_device(device)
|
124 |
+
else:
|
125 |
+
device = "cpu"
|
126 |
+
args.device = device
|
127 |
+
device = torch.device(device)
|
128 |
+
return device
|
open_flamingo/open_flamingo/train/train.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Main training script """
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import copy
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import functools
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import wandb
|
13 |
+
from data2 import get_data
|
14 |
+
from distributed import init_distributed_device, world_info_from_env
|
15 |
+
from torch.distributed.fsdp import (
|
16 |
+
FullyShardedDataParallel as FSDP,
|
17 |
+
MixedPrecision,
|
18 |
+
BackwardPrefetch,
|
19 |
+
ShardingStrategy,
|
20 |
+
FullStateDictConfig,
|
21 |
+
CPUOffload,
|
22 |
+
StateDictType,
|
23 |
+
)
|
24 |
+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
25 |
+
from torch.distributed.fsdp.wrap import (
|
26 |
+
transformer_auto_wrap_policy,
|
27 |
+
enable_wrap,
|
28 |
+
wrap,
|
29 |
+
)
|
30 |
+
|
31 |
+
from train_utils import train_one_epoch
|
32 |
+
from transformers import (
|
33 |
+
get_constant_schedule_with_warmup,
|
34 |
+
get_cosine_schedule_with_warmup,
|
35 |
+
get_linear_schedule_with_warmup,
|
36 |
+
)
|
37 |
+
|
38 |
+
from open_flamingo import create_model_and_transforms
|
39 |
+
from torch.utils.tensorboard import SummaryWriter
|
40 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
41 |
+
from torch.cuda.amp import GradScaler
|
42 |
+
from torch.distributed.optim import ZeroRedundancyOptimizer
|
43 |
+
import warnings
|
44 |
+
warnings.filterwarnings("ignore")
|
45 |
+
|
46 |
+
class FakeDataloader:
|
47 |
+
def __iter__(self):
|
48 |
+
return self
|
49 |
+
|
50 |
+
def __next__(self):
|
51 |
+
return None
|
52 |
+
|
53 |
+
def random_seed(seed=42, rank=0):
|
54 |
+
torch.manual_seed(seed + rank)
|
55 |
+
np.random.seed(seed + rank)
|
56 |
+
random.seed(seed + rank)
|
57 |
+
|
58 |
+
|
59 |
+
def get_grouped_params(model, args):
|
60 |
+
params_with_wd, params_without_wd = [], []
|
61 |
+
|
62 |
+
def apply_decay(x):
|
63 |
+
x = x.lower()
|
64 |
+
return "norm" not in x and "bias" not in x and "embed" not in x and "wte" not in x
|
65 |
+
|
66 |
+
for n, p in model.named_parameters():
|
67 |
+
# if p.requires_grad:
|
68 |
+
if apply_decay(n):
|
69 |
+
# print("with wd", n)
|
70 |
+
params_with_wd.append(p)
|
71 |
+
else:
|
72 |
+
# print("without wd", n)
|
73 |
+
params_without_wd.append(p)
|
74 |
+
return [
|
75 |
+
{"params": params_with_wd, "weight_decay": args.weight_decay},
|
76 |
+
{"params": params_without_wd, "weight_decay": 0.0},
|
77 |
+
]
|
78 |
+
|
79 |
+
|
80 |
+
def lambda_policy_fn(module):
|
81 |
+
if (
|
82 |
+
len(list(module.named_children())) == 0
|
83 |
+
and getattr(module, "weight", None) is not None
|
84 |
+
and module.weight.requires_grad
|
85 |
+
):
|
86 |
+
return True
|
87 |
+
return False
|
88 |
+
|
89 |
+
|
90 |
+
def lambda_auto_wrap_policy(
|
91 |
+
module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
|
92 |
+
) -> bool:
|
93 |
+
"""
|
94 |
+
A convenient auto wrap policy to wrap submodules based on an arbitrary user
|
95 |
+
function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as
|
96 |
+
a `wrapper_cls` unit.
|
97 |
+
|
98 |
+
Return if a module should be wrapped during auto wrapping.
|
99 |
+
|
100 |
+
The first three parameters are required by :func:`_recursive_wrap`.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
module (nn.Module): Current module being considered.
|
104 |
+
recurse (bool): If ``False``, then this function must decide whether
|
105 |
+
``module`` should be wrapped as an FSDP instance or not. If
|
106 |
+
``True``, then the function is still recursing down the module
|
107 |
+
tree as a part of the DFS.
|
108 |
+
nonwrapped_numel (int): Parameter numel not yet wrapped.
|
109 |
+
|
110 |
+
lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
|
111 |
+
this module will be wrapped.
|
112 |
+
"""
|
113 |
+
if recurse:
|
114 |
+
return True # always recurse
|
115 |
+
return lambda_fn(module)
|
116 |
+
|
117 |
+
|
118 |
+
def main():
|
119 |
+
parser = argparse.ArgumentParser()
|
120 |
+
parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
|
121 |
+
parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
|
122 |
+
parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
|
123 |
+
parser.add_argument(
|
124 |
+
"--tokenizer_path",
|
125 |
+
default="facebook/opt-1.3b",
|
126 |
+
type=str,
|
127 |
+
help="path to tokenizer",
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--run_name",
|
131 |
+
type=str,
|
132 |
+
default="openflamingo3B",
|
133 |
+
help="used to name saving directory and wandb run",
|
134 |
+
)
|
135 |
+
parser.add_argument("--use_media_placement_augmentation", action="store_true")
|
136 |
+
parser.add_argument("--offline", action="store_true")
|
137 |
+
parser.add_argument("--num_steps", type=int, default=300000)
|
138 |
+
parser.add_argument(
|
139 |
+
"--logging_steps", type=int, default=10, help="log loss every n steps"
|
140 |
+
)
|
141 |
+
# Sum of gradient optimization batch size
|
142 |
+
parser.add_argument("--batch_size_mmc4", type=int, default=128)
|
143 |
+
parser.add_argument("--batch_size_laion", type=int, default=128)
|
144 |
+
parser.add_argument("--batch_size_pile", type=int, default=128)
|
145 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
146 |
+
parser.add_argument(
|
147 |
+
"--resume_from_checkpoint",
|
148 |
+
type=str,
|
149 |
+
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
|
150 |
+
default=None,
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--delete_previous_checkpoint",
|
154 |
+
action="store_true",
|
155 |
+
help="delete previous checkpoint when saving new checkpoint",
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--laion_shards",
|
159 |
+
type=str,
|
160 |
+
help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--mmc4_shards",
|
164 |
+
type=str,
|
165 |
+
help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--pile_shards",
|
169 |
+
type=str,
|
170 |
+
default=None,
|
171 |
+
help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
172 |
+
)
|
173 |
+
parser.add_argument("--seed", type=int, default=42)
|
174 |
+
parser.add_argument("--learning_rate", default=1e-4, type=float)
|
175 |
+
parser.add_argument(
|
176 |
+
"--lr_scheduler",
|
177 |
+
default="constant",
|
178 |
+
type=str,
|
179 |
+
help="constant, linear, or cosine",
|
180 |
+
)
|
181 |
+
parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
|
182 |
+
parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
|
183 |
+
parser.add_argument("--loss_multiplier_pile", type=float, default=1.0)
|
184 |
+
parser.add_argument("--warmup_steps", default=5000, type=int)
|
185 |
+
parser.add_argument("--weight_decay", default=0.1, type=float)
|
186 |
+
parser.add_argument(
|
187 |
+
"--precision",
|
188 |
+
choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
|
189 |
+
default="fp32",
|
190 |
+
help="Floating point precision.",
|
191 |
+
)
|
192 |
+
# data args
|
193 |
+
parser.add_argument("--workers", type=int, default=1)
|
194 |
+
parser.add_argument("--dataset_resampled", action="store_true")
|
195 |
+
# distributed training args
|
196 |
+
parser.add_argument(
|
197 |
+
"--dist-url",
|
198 |
+
default="env://",
|
199 |
+
type=str,
|
200 |
+
help="url used to set up distributed training",
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
204 |
+
)
|
205 |
+
parser.add_argument(
|
206 |
+
"--horovod",
|
207 |
+
default=False,
|
208 |
+
action="store_true",
|
209 |
+
help="Use horovod for distributed training.",
|
210 |
+
)
|
211 |
+
parser.add_argument(
|
212 |
+
"--no-set-device-rank",
|
213 |
+
default=False,
|
214 |
+
action="store_true",
|
215 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
216 |
+
)
|
217 |
+
# wandb args
|
218 |
+
parser.add_argument("--report_to_wandb", default=False, action="store_true")
|
219 |
+
parser.add_argument(
|
220 |
+
"--wandb_project",
|
221 |
+
type=str,
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"--wandb_entity",
|
225 |
+
type=str,
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
"--save_checkpoints_to_wandb",
|
229 |
+
default=False,
|
230 |
+
action="store_true",
|
231 |
+
help="save checkpoints to wandb",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--checkpoint_activations",
|
235 |
+
default=False,
|
236 |
+
action="store_true",
|
237 |
+
)
|
238 |
+
parser.add_argument(
|
239 |
+
"--freeze_vision_encoder",
|
240 |
+
default=False,
|
241 |
+
action="store_true",
|
242 |
+
)
|
243 |
+
parser.add_argument(
|
244 |
+
"--mmc4_textsim_threshold",
|
245 |
+
default=30,
|
246 |
+
type=float,
|
247 |
+
help="threshold for filtering images in mmc4 based on image-text similarity",
|
248 |
+
)
|
249 |
+
parser.add_argument(
|
250 |
+
"--add_visual_grounding",
|
251 |
+
default=False,
|
252 |
+
action="store_true",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--location_token_num",
|
256 |
+
default=1000,
|
257 |
+
type=int,
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
"--vis_embed_size",
|
261 |
+
type=int,
|
262 |
+
required=False,
|
263 |
+
)
|
264 |
+
parser.add_argument(
|
265 |
+
"--save_interval",
|
266 |
+
default=1000,
|
267 |
+
type=int,
|
268 |
+
required=False,
|
269 |
+
)
|
270 |
+
parser.add_argument(
|
271 |
+
"--skip_delete_pattern",
|
272 |
+
default=1500,
|
273 |
+
type=int,
|
274 |
+
required=False,
|
275 |
+
)
|
276 |
+
parser.add_argument(
|
277 |
+
"--ddp",
|
278 |
+
default=False,
|
279 |
+
action="store_true",
|
280 |
+
)
|
281 |
+
parser.add_argument(
|
282 |
+
"--pile_freq",
|
283 |
+
default=1,
|
284 |
+
type=int,
|
285 |
+
required=False,
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"--restart",
|
289 |
+
default=False,
|
290 |
+
action="store_true",
|
291 |
+
)
|
292 |
+
parser.add_argument(
|
293 |
+
"--lora",
|
294 |
+
default=False,
|
295 |
+
action="store_true",
|
296 |
+
)
|
297 |
+
parser.add_argument(
|
298 |
+
"--lora_r",
|
299 |
+
default=16,
|
300 |
+
type=int,
|
301 |
+
required=False,
|
302 |
+
)
|
303 |
+
parser.add_argument(
|
304 |
+
"--single",
|
305 |
+
default=False,
|
306 |
+
action="store_true",
|
307 |
+
)
|
308 |
+
|
309 |
+
# Finetune
|
310 |
+
parser.add_argument(
|
311 |
+
"--vqav2",
|
312 |
+
default=False,
|
313 |
+
action="store_true",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--fix-ffn",
|
317 |
+
default=False,
|
318 |
+
action="store_true",
|
319 |
+
)
|
320 |
+
parser.add_argument(
|
321 |
+
"--prob_ground",
|
322 |
+
default=1.0,
|
323 |
+
type=float,
|
324 |
+
required=False,
|
325 |
+
)
|
326 |
+
parser.add_argument(
|
327 |
+
"--optimizer",
|
328 |
+
default="adamw",
|
329 |
+
type=str,
|
330 |
+
required=False,
|
331 |
+
)
|
332 |
+
parser.add_argument(
|
333 |
+
"--add_visual_token",
|
334 |
+
default=False,
|
335 |
+
action="store_true",
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--use_format_v2",
|
339 |
+
default=False,
|
340 |
+
action="store_true",
|
341 |
+
)
|
342 |
+
parser.add_argument(
|
343 |
+
"--use_sam",
|
344 |
+
default=None,
|
345 |
+
type=str,
|
346 |
+
required=False,
|
347 |
+
)
|
348 |
+
parser.add_argument(
|
349 |
+
"--max-length",
|
350 |
+
default=608,
|
351 |
+
type=int,
|
352 |
+
required=False,
|
353 |
+
)
|
354 |
+
parser.add_argument(
|
355 |
+
"--image-size",
|
356 |
+
default=256,
|
357 |
+
type=int,
|
358 |
+
required=False,
|
359 |
+
)
|
360 |
+
|
361 |
+
|
362 |
+
args = parser.parse_args()
|
363 |
+
assert not args.use_media_placement_augmentation, "Do not enable use_media_placement_augmentation"
|
364 |
+
|
365 |
+
if args.offline:
|
366 |
+
os.environ["WANDB_MODE"] = "offline"
|
367 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
368 |
+
|
369 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
370 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
371 |
+
device_id = init_distributed_device(args)
|
372 |
+
|
373 |
+
random_seed(args.seed)
|
374 |
+
|
375 |
+
model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
|
376 |
+
args.vision_encoder_path,
|
377 |
+
args.vision_encoder_pretrained,
|
378 |
+
args.lm_path,
|
379 |
+
args.tokenizer_path if args.tokenizer_path else args.lm_path,
|
380 |
+
use_local_files=args.offline,
|
381 |
+
use_media_placement_augmentation=args.use_media_placement_augmentation,
|
382 |
+
checkpoint_activations=args.checkpoint_activations,
|
383 |
+
freeze_vision_encoder=args.freeze_vision_encoder,
|
384 |
+
add_visual_grounding=args.add_visual_grounding,
|
385 |
+
location_token_num=args.location_token_num,
|
386 |
+
lora=args.lora,
|
387 |
+
lora_r=args.lora_r,
|
388 |
+
fix_ffn=args.fix_ffn,
|
389 |
+
add_visual_token=args.add_visual_token,
|
390 |
+
use_format_v2=args.use_format_v2,
|
391 |
+
use_sam=args.use_sam,
|
392 |
+
)
|
393 |
+
if args.rank == 0:
|
394 |
+
print(args)
|
395 |
+
print(image_processor)
|
396 |
+
|
397 |
+
random_seed(args.seed, args.rank)
|
398 |
+
|
399 |
+
if args.rank == 0 and args.report_to_wandb:
|
400 |
+
wandb.init(
|
401 |
+
project=args.wandb_project,
|
402 |
+
entity=args.wandb_entity,
|
403 |
+
name=args.run_name,
|
404 |
+
config=vars(args),
|
405 |
+
)
|
406 |
+
|
407 |
+
device_id = args.rank % torch.cuda.device_count()
|
408 |
+
if args.ddp:
|
409 |
+
print("use ddp mode")
|
410 |
+
model = model.to(device_id)
|
411 |
+
model = DDP(model)
|
412 |
+
else:
|
413 |
+
fpSixteen = MixedPrecision(
|
414 |
+
param_dtype=torch.float16,
|
415 |
+
# Gradient communication precision.
|
416 |
+
reduce_dtype=torch.float16,
|
417 |
+
# Buffer precision.
|
418 |
+
buffer_dtype=torch.float16,
|
419 |
+
)
|
420 |
+
# from transformers.models.opt.modeling_opt import OPTDecoderLayer
|
421 |
+
from open_clip.transformer import ResidualAttentionBlock
|
422 |
+
from open_flamingo.src.flamingo_lm import FlamingoLayer
|
423 |
+
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
|
424 |
+
from segment_anything.modeling.image_encoder import Block
|
425 |
+
transformer_layer_cls=[
|
426 |
+
FlamingoLayer,
|
427 |
+
ResidualAttentionBlock,
|
428 |
+
Block,
|
429 |
+
]
|
430 |
+
if args.fix_ffn:
|
431 |
+
transformer_layer_cls.append(OPTAttention)
|
432 |
+
auto_wrap_policy = functools.partial(
|
433 |
+
transformer_auto_wrap_policy,
|
434 |
+
transformer_layer_cls=transformer_layer_cls,
|
435 |
+
)
|
436 |
+
if args.lora:
|
437 |
+
from torch.distributed.fsdp.wrap import _or_policy
|
438 |
+
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
|
439 |
+
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
|
440 |
+
ignored_modules = [model.vision_encoder]
|
441 |
+
# ignored_modules = None
|
442 |
+
else:
|
443 |
+
ignored_modules = None
|
444 |
+
model = FSDP(
|
445 |
+
model,
|
446 |
+
auto_wrap_policy=auto_wrap_policy,
|
447 |
+
mixed_precision=fpSixteen,
|
448 |
+
device_id=torch.cuda.current_device(),
|
449 |
+
ignored_modules=ignored_modules,
|
450 |
+
)
|
451 |
+
model = model.to(device_id)
|
452 |
+
|
453 |
+
|
454 |
+
pile_dataset = None
|
455 |
+
if args.vqav2:
|
456 |
+
laion_dataset = get_data(args, image_processor, tokenizer, "vqav2")
|
457 |
+
else:
|
458 |
+
if args.pile_shards is not None:
|
459 |
+
pile_dataset = get_data(args, image_processor, tokenizer, "pile")
|
460 |
+
if args.add_visual_grounding:
|
461 |
+
laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
|
462 |
+
else:
|
463 |
+
laion_dataset = get_data(args, image_processor, tokenizer, "image_text")
|
464 |
+
|
465 |
+
|
466 |
+
optim_groups = get_grouped_params(model, args)
|
467 |
+
# optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
|
468 |
+
if args.ddp:
|
469 |
+
raise NotImplementedError
|
470 |
+
optimizer = ZeroRedundancyOptimizer(
|
471 |
+
optim_groups,
|
472 |
+
optimizer_class=torch.optim.AdamW,
|
473 |
+
lr=args.learning_rate,
|
474 |
+
parameters_as_bucket_view=True,
|
475 |
+
)
|
476 |
+
else:
|
477 |
+
if args.optimizer == "adamw":
|
478 |
+
print("use adamw")
|
479 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
|
480 |
+
elif args.optimizer == "sgd":
|
481 |
+
print("use sgd...")
|
482 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
|
483 |
+
else:
|
484 |
+
raise NotImplementedError
|
485 |
+
|
486 |
+
total_training_steps = args.num_steps
|
487 |
+
|
488 |
+
if args.rank == 0:
|
489 |
+
print(f"Total training steps: {total_training_steps}")
|
490 |
+
|
491 |
+
if args.lr_scheduler == "linear":
|
492 |
+
lr_scheduler = get_linear_schedule_with_warmup(
|
493 |
+
optimizer,
|
494 |
+
num_warmup_steps=args.warmup_steps,
|
495 |
+
num_training_steps=total_training_steps,
|
496 |
+
)
|
497 |
+
elif args.lr_scheduler == "cosine":
|
498 |
+
lr_scheduler = get_cosine_schedule_with_warmup(
|
499 |
+
optimizer,
|
500 |
+
num_warmup_steps=args.warmup_steps,
|
501 |
+
num_training_steps=total_training_steps,
|
502 |
+
)
|
503 |
+
else:
|
504 |
+
lr_scheduler = get_constant_schedule_with_warmup(
|
505 |
+
optimizer, num_warmup_steps=args.warmup_steps
|
506 |
+
)
|
507 |
+
if args.ddp:
|
508 |
+
scaler = GradScaler()
|
509 |
+
else:
|
510 |
+
scaler = ShardedGradScaler()
|
511 |
+
total_laion_token = 0
|
512 |
+
total_pile_token = 0
|
513 |
+
total_laion_sample = 0
|
514 |
+
total_step = 0
|
515 |
+
|
516 |
+
# check if a checkpoint exists for this run
|
517 |
+
if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
|
518 |
+
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
|
519 |
+
if len(checkpoint_list) == 0:
|
520 |
+
print(f"Found no checkpoints for run {args.run_name}.")
|
521 |
+
else:
|
522 |
+
args.resume_from_checkpoint = sorted(
|
523 |
+
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
|
524 |
+
)[-1]
|
525 |
+
print(
|
526 |
+
f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}."
|
527 |
+
)
|
528 |
+
if args.resume_from_checkpoint is not None:
|
529 |
+
if args.rank == 0:
|
530 |
+
print(f"Loading checkpoint from {args.resume_from_checkpoint}")
|
531 |
+
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
|
532 |
+
if args.ddp:
|
533 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
|
534 |
+
sharded_osd = checkpoint['optimizer_state_dict']
|
535 |
+
else:
|
536 |
+
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
|
537 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
538 |
+
sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups)
|
539 |
+
if not args.restart and not args.vqav2:
|
540 |
+
# optimizer.load_state_dict(sharded_osd)
|
541 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
|
542 |
+
scaler.load_state_dict(checkpoint["scaler_state_dict"])
|
543 |
+
total_laion_token = checkpoint.get("total_laion_token", 0)
|
544 |
+
total_pile_token = checkpoint.get("total_pile_token", 0)
|
545 |
+
total_laion_sample = checkpoint.get("total_laion_sample", 0)
|
546 |
+
total_step = checkpoint.get("total_step", 0)
|
547 |
+
else:
|
548 |
+
print("restart training / finetuning. only load model weight...")
|
549 |
+
del checkpoint
|
550 |
+
torch.cuda.empty_cache()
|
551 |
+
torch.distributed.barrier()
|
552 |
+
|
553 |
+
model.train()
|
554 |
+
if args.rank == 0:
|
555 |
+
if not os.path.exists(args.run_name):
|
556 |
+
os.makedirs(args.run_name)
|
557 |
+
writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
|
558 |
+
else:
|
559 |
+
writer = None
|
560 |
+
|
561 |
+
laion_dataset.set_epoch(total_step)
|
562 |
+
laion_loader = laion_dataset.dataloader
|
563 |
+
if pile_dataset is not None:
|
564 |
+
pile_dataset.set_epoch(total_step)
|
565 |
+
pile_loader = pile_dataset.dataloader
|
566 |
+
else:
|
567 |
+
pile_loader = FakeDataloader()
|
568 |
+
train_one_epoch(
|
569 |
+
args=args,
|
570 |
+
model=model,
|
571 |
+
tokenizer=tokenizer,
|
572 |
+
optimizer=optimizer,
|
573 |
+
lr_scheduler=lr_scheduler,
|
574 |
+
laion_loader=laion_loader,
|
575 |
+
pile_loader=pile_loader,
|
576 |
+
device_id=device_id,
|
577 |
+
writer=writer,
|
578 |
+
scaler=scaler,
|
579 |
+
optim_groups=optim_groups,
|
580 |
+
total_laion_token=total_laion_token,
|
581 |
+
total_pile_token=total_pile_token,
|
582 |
+
total_laion_sample=total_laion_sample,
|
583 |
+
total_step=total_step,
|
584 |
+
)
|
585 |
+
|
586 |
+
if __name__ == "__main__":
|
587 |
+
main()
|
open_flamingo/open_flamingo/train/train_utils.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from contextlib import suppress
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from tqdm import tqdm
|
7 |
+
import datetime
|
8 |
+
import os
|
9 |
+
import gc
|
10 |
+
from torch.distributed.fsdp import (
|
11 |
+
FullyShardedDataParallel as FSDP,
|
12 |
+
MixedPrecision,
|
13 |
+
BackwardPrefetch,
|
14 |
+
ShardingStrategy,
|
15 |
+
FullStateDictConfig,
|
16 |
+
StateDictType,
|
17 |
+
)
|
18 |
+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
19 |
+
from torch.distributed.fsdp.wrap import (
|
20 |
+
transformer_auto_wrap_policy,
|
21 |
+
enable_wrap,
|
22 |
+
wrap,
|
23 |
+
)
|
24 |
+
|
25 |
+
from torch.utils.tensorboard import SummaryWriter
|
26 |
+
import logging
|
27 |
+
logging.basicConfig(
|
28 |
+
level=logging.INFO,
|
29 |
+
format='%(asctime)s %(message)s',
|
30 |
+
datefmt='%m/%d %I:%M:%S',
|
31 |
+
)
|
32 |
+
|
33 |
+
def get_cast_dtype(precision: str):
|
34 |
+
cast_dtype = None
|
35 |
+
if precision == "bf16":
|
36 |
+
cast_dtype = torch.bfloat16
|
37 |
+
elif precision == "fp16":
|
38 |
+
cast_dtype = torch.float16
|
39 |
+
return cast_dtype
|
40 |
+
|
41 |
+
|
42 |
+
def get_autocast(precision):
|
43 |
+
if precision == "amp_fp16":
|
44 |
+
return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
|
45 |
+
elif precision == "amp_bfloat16" or precision == "amp_bf16":
|
46 |
+
# amp_bfloat16 is more stable than amp float16 for clip training
|
47 |
+
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
|
48 |
+
else:
|
49 |
+
return suppress
|
50 |
+
|
51 |
+
|
52 |
+
def get_sync(model, flag):
|
53 |
+
if flag:
|
54 |
+
return suppress
|
55 |
+
else:
|
56 |
+
return lambda: model.no_sync()
|
57 |
+
|
58 |
+
def get_iou(box1, box2):
|
59 |
+
# box1 and box2 should be in the format [x1, y1, x2, y2]
|
60 |
+
intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
|
61 |
+
max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
|
62 |
+
area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
63 |
+
area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
64 |
+
union = area_box1 + area_box2 - intersection
|
65 |
+
iou = intersection / union if union > 0 else 0
|
66 |
+
return iou
|
67 |
+
|
68 |
+
|
69 |
+
def train_one_epoch(
|
70 |
+
args,
|
71 |
+
model,
|
72 |
+
laion_loader,
|
73 |
+
pile_loader,
|
74 |
+
tokenizer,
|
75 |
+
optimizer,
|
76 |
+
lr_scheduler,
|
77 |
+
device_id,
|
78 |
+
writer: SummaryWriter,
|
79 |
+
optim_groups,
|
80 |
+
scaler,
|
81 |
+
total_laion_token: int,
|
82 |
+
total_pile_token: int,
|
83 |
+
total_laion_sample: int,
|
84 |
+
total_step: int,
|
85 |
+
):
|
86 |
+
world_size = torch.distributed.get_world_size()
|
87 |
+
autocast = get_autocast(args.precision)
|
88 |
+
cast_dtype = get_cast_dtype(args.precision)
|
89 |
+
|
90 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
91 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)[
|
92 |
+
"input_ids"
|
93 |
+
][-1]
|
94 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
95 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
96 |
+
if args.add_visual_grounding:
|
97 |
+
obj_token_id = tokenizer("<|#obj#|>", add_special_tokens=False)["input_ids"][-1]
|
98 |
+
endofobj_token_id = tokenizer("<|#endofobj#|>", add_special_tokens=False)["input_ids"][-1]
|
99 |
+
loc_token_ids = []
|
100 |
+
for i in range(args.location_token_num):
|
101 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
102 |
+
min_loc_token_id = min(loc_token_ids)
|
103 |
+
max_loc_token_id = max(loc_token_ids)
|
104 |
+
|
105 |
+
# if args.vqav2:
|
106 |
+
# split_token_id = tokenizer(" Answer", add_special_tokens=False)["input_ids"][-1]
|
107 |
+
if args.rank == 0:
|
108 |
+
logging.info(f"train from: {total_step} step")
|
109 |
+
model.train()
|
110 |
+
# loop through dataloader
|
111 |
+
last_logging_step = total_step
|
112 |
+
last_save_step = total_step
|
113 |
+
for num_steps, (batch_laion, batch_pile) in tqdm(
|
114 |
+
enumerate(zip(laion_loader, pile_loader)),
|
115 |
+
disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
|
116 |
+
total=args.num_steps * args.gradient_accumulation_steps,
|
117 |
+
initial=total_step * args.gradient_accumulation_steps,
|
118 |
+
):
|
119 |
+
#### LAION FORWARD PASS ####
|
120 |
+
images = (
|
121 |
+
batch_laion[0]
|
122 |
+
.to(device_id, dtype=cast_dtype, non_blocking=True)
|
123 |
+
.unsqueeze(1)
|
124 |
+
.unsqueeze(1)
|
125 |
+
)
|
126 |
+
image_nums = batch_laion[1]
|
127 |
+
image_start_index_list = batch_laion[2]
|
128 |
+
|
129 |
+
# TODO: OPT model: input_ids is not started with </s> while input_ids2 is?
|
130 |
+
input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
|
131 |
+
attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
|
132 |
+
# added_bbox_list = batch_laion[5] # list object
|
133 |
+
total_laion_token += int(attention_mask.sum().long()) * world_size
|
134 |
+
total_laion_sample += sum(image_nums) * world_size
|
135 |
+
|
136 |
+
added_visual_token_idx_list = []
|
137 |
+
added_bbox_list = []
|
138 |
+
|
139 |
+
labels = input_ids.clone()
|
140 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
141 |
+
labels[:, 0] = -100
|
142 |
+
labels.to(device_id)
|
143 |
+
current_laion_num = input_ids.shape[0]
|
144 |
+
|
145 |
+
#### PILE FORWARD PASS ####
|
146 |
+
if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
|
147 |
+
input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
|
148 |
+
attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
|
149 |
+
input_length = input_ids.shape[-1]
|
150 |
+
|
151 |
+
input_ids2 = torch.cat([input_ids2, torch.ones((input_ids2.shape[0], input_length - input_ids2.shape[1]), device=input_ids2.device, dtype=input_ids2.dtype) * tokenizer.pad_token_id], dim=-1)
|
152 |
+
attention_mask2 = torch.cat([attention_mask2, torch.zeros((attention_mask2.shape[0], input_length - attention_mask2.shape[1]), device=attention_mask2.device, dtype=attention_mask2.dtype)], dim=-1)
|
153 |
+
|
154 |
+
labels2 = input_ids2.clone()
|
155 |
+
labels2[labels2 == tokenizer.pad_token_id] = -100
|
156 |
+
labels2[:, 0] = -100
|
157 |
+
labels2.to(device_id)
|
158 |
+
|
159 |
+
if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
|
160 |
+
image_nums = image_nums + [0] * len(input_ids2)
|
161 |
+
image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
|
162 |
+
input_ids = torch.cat([input_ids, input_ids2], dim=0)
|
163 |
+
attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
|
164 |
+
added_bbox_list = added_bbox_list + [[]] * len(input_ids2)
|
165 |
+
added_visual_token_idx_list = added_visual_token_idx_list + [[]] * len(input_ids2)
|
166 |
+
labels = torch.cat([labels, labels2], dim=0)
|
167 |
+
total_pile_token += int(attention_mask2.sum().long()) * world_size
|
168 |
+
else:
|
169 |
+
del input_ids2
|
170 |
+
del attention_mask2
|
171 |
+
del labels2
|
172 |
+
# if args.vqav2:
|
173 |
+
# for i in range(len(input_ids)):
|
174 |
+
# answer_start_idx = (input_ids[i] == 31652).nonzero()[0,0] + 2
|
175 |
+
# labels[i, :answer_start_idx] = -100
|
176 |
+
|
177 |
+
update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
|
178 |
+
# do_sync = get_sync(model, update_flag)
|
179 |
+
with autocast():
|
180 |
+
# modify:
|
181 |
+
# /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
|
182 |
+
# /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
|
183 |
+
# CrossEntropyLoss(reduction="none")
|
184 |
+
outputs = model(
|
185 |
+
vision_x=images,
|
186 |
+
lang_x=input_ids,
|
187 |
+
attention_mask=attention_mask,
|
188 |
+
labels=labels,
|
189 |
+
image_nums=image_nums,
|
190 |
+
image_start_index_list=image_start_index_list,
|
191 |
+
added_bbox_list=added_bbox_list,
|
192 |
+
added_visual_token_idx_list=added_visual_token_idx_list,
|
193 |
+
)
|
194 |
+
loss_total = outputs.loss.reshape(labels.shape[0], -1)
|
195 |
+
iou = 0.0
|
196 |
+
loss_loc = 0.0
|
197 |
+
if args.add_visual_grounding:
|
198 |
+
with torch.no_grad():
|
199 |
+
shift_logits = outputs.logits.detach().cpu()[..., :-1, :].contiguous()
|
200 |
+
shift_labels = labels.detach().cpu()[..., 1:].contiguous()
|
201 |
+
loc_mask = (shift_labels >= min_loc_token_id) * (shift_labels <= max_loc_token_id)
|
202 |
+
if loc_mask.any():
|
203 |
+
loss_loc = loss_total[loc_mask].detach().mean().item()
|
204 |
+
try:
|
205 |
+
loc_labels = (shift_labels[loc_mask] - min_loc_token_id)
|
206 |
+
loc_logits = shift_logits[loc_mask]
|
207 |
+
loc_logits = loc_logits[:, min_loc_token_id:max_loc_token_id+1]
|
208 |
+
loc_preds = loc_logits.argmax(-1)
|
209 |
+
loc_labels = loc_labels.reshape(-1, 4).numpy()
|
210 |
+
loc_preds = loc_preds.reshape(-1, 4).numpy()
|
211 |
+
ious = []
|
212 |
+
for box1, box2 in zip(loc_preds, loc_labels):
|
213 |
+
ious.append(get_iou(box1, box2))
|
214 |
+
iou = np.mean(ious)
|
215 |
+
except:
|
216 |
+
pass
|
217 |
+
loss_caption = loss_total[:current_laion_num][~loc_mask[:current_laion_num]].detach()
|
218 |
+
loss_caption = loss_caption.sum() / (loss_caption != 0).sum()
|
219 |
+
loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
|
220 |
+
loss_laion = loss_sample[:current_laion_num].mean()
|
221 |
+
if not args.add_visual_grounding:
|
222 |
+
loss_caption = loss_laion
|
223 |
+
divided_loss_laion = loss_laion / args.gradient_accumulation_steps
|
224 |
+
if current_laion_num != loss_sample.shape[0]:
|
225 |
+
loss_pile = loss_sample[current_laion_num:].mean()
|
226 |
+
else:
|
227 |
+
loss_pile = torch.tensor(0.0).cuda()
|
228 |
+
divided_loss_pile = loss_pile / args.gradient_accumulation_steps
|
229 |
+
loss = (
|
230 |
+
divided_loss_laion * args.loss_multiplier_laion +
|
231 |
+
divided_loss_pile * args.loss_multiplier_pile
|
232 |
+
)
|
233 |
+
|
234 |
+
scaler.scale(loss).backward()
|
235 |
+
|
236 |
+
# for logging only
|
237 |
+
loss = (
|
238 |
+
loss_laion * args.loss_multiplier_laion
|
239 |
+
+ loss_pile * args.loss_multiplier_pile
|
240 |
+
).detach()
|
241 |
+
|
242 |
+
if torch.isnan(loss_laion) or torch.isnan(loss_pile):
|
243 |
+
logging.error(f"NaN!!! laion: {loss_laion.item()} pile: {loss_pile.item()}")
|
244 |
+
torch.distributed.barrier()
|
245 |
+
exit(0)
|
246 |
+
|
247 |
+
# step optimizer and log
|
248 |
+
if update_flag:
|
249 |
+
#### MASK GRADIENTS FOR EMBEDDINGS ####
|
250 |
+
# Note (anas): Do not apply weight decay to embeddings as it will break this function.
|
251 |
+
# ! not an important point
|
252 |
+
if args.ddp:
|
253 |
+
def mask_embedding(m):
|
254 |
+
if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
|
255 |
+
zero_mask = torch.zeros_like(m.weight.grad)
|
256 |
+
zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
|
257 |
+
zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
|
258 |
+
zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
|
259 |
+
if args.add_visual_grounding:
|
260 |
+
zero_mask[loc_token_ids] = torch.ones_like(zero_mask[loc_token_ids])
|
261 |
+
zero_mask[obj_token_id] = torch.ones_like(zero_mask[obj_token_id])
|
262 |
+
zero_mask[endofobj_token_id] = torch.ones_like(zero_mask[endofobj_token_id])
|
263 |
+
m.weight.grad = m.weight.grad * zero_mask
|
264 |
+
model.apply(mask_embedding)
|
265 |
+
total_step += 1
|
266 |
+
scaler.unscale_(optimizer)
|
267 |
+
if args.ddp:
|
268 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
269 |
+
else:
|
270 |
+
model.clip_grad_norm_(1.0)
|
271 |
+
scaler.step(optimizer)
|
272 |
+
scaler.update()
|
273 |
+
lr_scheduler.step()
|
274 |
+
optimizer.zero_grad()
|
275 |
+
# https://github.com/facebookresearch/fairscale/issues/627
|
276 |
+
model.zero_grad(set_to_none=True)
|
277 |
+
|
278 |
+
if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
|
279 |
+
last_logging_step = total_step
|
280 |
+
global_step = total_step
|
281 |
+
lr = optimizer.param_groups[0]["lr"]
|
282 |
+
writer.add_scalar("lr", lr, global_step)
|
283 |
+
writer.add_scalar("scale", scaler.get_scale(), global_step)
|
284 |
+
writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
|
285 |
+
writer.add_scalar("loss_laion", loss_caption.item(), global_step)
|
286 |
+
writer.add_scalar("loss_pile", loss_pile.item(), global_step)
|
287 |
+
writer.add_scalar("loss_loc", loss_loc, global_step)
|
288 |
+
writer.add_scalar("loss", loss.item(), global_step)
|
289 |
+
writer.add_scalar("iou", iou, global_step)
|
290 |
+
|
291 |
+
global_sample_num = total_laion_sample
|
292 |
+
writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
|
293 |
+
writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
|
294 |
+
writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
|
295 |
+
writer.add_scalar("loss_loc_vs_sample_num", loss_loc, global_sample_num)
|
296 |
+
writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
|
297 |
+
writer.add_scalar("iou_vs_sample_num", iou, global_sample_num)
|
298 |
+
writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)
|
299 |
+
|
300 |
+
writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
|
301 |
+
writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
|
302 |
+
writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
|
303 |
+
writer.add_scalar("loss_loc_vs_token", loss_loc, total_laion_token)
|
304 |
+
writer.add_scalar("iou_vs_token", iou, total_laion_token)
|
305 |
+
|
306 |
+
total_token = total_laion_token + total_pile_token
|
307 |
+
writer.add_scalar("sample_num", global_sample_num, global_step)
|
308 |
+
writer.add_scalar("total_laion_token", total_laion_token, global_step)
|
309 |
+
writer.add_scalar("total_pile_token", total_pile_token, global_step)
|
310 |
+
writer.add_scalar("total_token", total_token, global_step)
|
311 |
+
logging.info(
|
312 |
+
f"[{global_step}][{total_laion_sample}][{total_token}]. total: {loss.item():.3f} // glaion: {loss_laion.item():.3f} // laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // loc: {loss_loc:.3f} // iou: {iou:.4f} lr: {lr:.2e} // scale: {scaler.get_scale()}"
|
313 |
+
)
|
314 |
+
|
315 |
+
if total_step % args.save_interval == 0 and total_step != last_save_step:
|
316 |
+
last_save_step = total_step
|
317 |
+
torch.distributed.barrier()
|
318 |
+
if args.ddp:
|
319 |
+
cpu_state = model.state_dict()
|
320 |
+
optimizer.consolidate_state_dict(0)
|
321 |
+
if args.rank == 0:
|
322 |
+
optimizer_state = optimizer.state_dict()
|
323 |
+
else:
|
324 |
+
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
325 |
+
with FSDP.state_dict_type(
|
326 |
+
model, StateDictType.FULL_STATE_DICT, save_policy
|
327 |
+
):
|
328 |
+
cpu_state = model.state_dict()
|
329 |
+
torch.distributed.barrier()
|
330 |
+
# https://pytorch.org/docs/1.12/fsdp.html
|
331 |
+
# need to pass optim_groups as optim_input
|
332 |
+
optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)
|
333 |
+
if args.rank == 0:
|
334 |
+
checkpoint_dict = {
|
335 |
+
"model_state_dict": cpu_state,
|
336 |
+
"optimizer_state_dict": optimizer_state,
|
337 |
+
"lr_scheduler_state_dict": lr_scheduler.state_dict(),
|
338 |
+
"scaler_state_dict": scaler.state_dict(),
|
339 |
+
"total_pile_token": total_pile_token,
|
340 |
+
"total_laion_token": total_laion_token,
|
341 |
+
"total_laion_sample": total_laion_sample,
|
342 |
+
"total_step": total_step,
|
343 |
+
}
|
344 |
+
logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
|
345 |
+
torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
|
346 |
+
del checkpoint_dict
|
347 |
+
if args.delete_previous_checkpoint and total_step-args.save_interval > 0 and (total_step-args.save_interval) % args.skip_delete_pattern != 0:
|
348 |
+
try:
|
349 |
+
os.remove(f"{args.run_name}/checkpoint_{total_step-args.save_interval}.pt")
|
350 |
+
except:
|
351 |
+
pass
|
352 |
+
torch.distributed.barrier()
|
353 |
+
|
354 |
+
|
355 |
+
class AverageMeter(object):
|
356 |
+
"""Computes and stores the average and current value"""
|
357 |
+
|
358 |
+
def __init__(self):
|
359 |
+
self.reset()
|
360 |
+
|
361 |
+
def reset(self):
|
362 |
+
self.val = 0
|
363 |
+
self.avg = 0
|
364 |
+
self.sum = 0
|
365 |
+
self.count = 0
|
366 |
+
|
367 |
+
def update(self, val, n=1):
|
368 |
+
self.val = val
|
369 |
+
self.sum += val * n
|
370 |
+
self.count += n
|
371 |
+
self.avg = self.sum / self.count
|
open_flamingo/requirements-dev.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
black
|
2 |
+
mypy
|
3 |
+
pylint
|
4 |
+
pytest
|
5 |
+
requests
|
open_flamingo/requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops
|
2 |
+
einops-exts
|
3 |
+
transformers
|
4 |
+
torch
|
5 |
+
torchvision
|
6 |
+
pillow
|
7 |
+
more-itertools
|
8 |
+
datasets
|
9 |
+
braceexpand
|
10 |
+
webdataset
|
11 |
+
wandb
|
12 |
+
nltk
|
13 |
+
scipy
|
14 |
+
inflection
|
15 |
+
sentencepiece
|
16 |
+
open_clip_torch
|
open_flamingo/setup.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from setuptools import find_packages, setup
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file:
|
7 |
+
long_description = file.read()
|
8 |
+
|
9 |
+
# TODO: This is a hack to get around the fact that we can't read the requirements.txt file, we should fix this.
|
10 |
+
# def _read_reqs(relpath):
|
11 |
+
# fullpath = os.path.join(Path(__file__).parent, relpath)
|
12 |
+
# with open(fullpath) as f:
|
13 |
+
# return [
|
14 |
+
# s.strip()
|
15 |
+
# for s in f.readlines()
|
16 |
+
# if (s.strip() and not s.startswith("#"))
|
17 |
+
# ]
|
18 |
+
|
19 |
+
REQUIREMENTS = [
|
20 |
+
"einops",
|
21 |
+
"einops-exts",
|
22 |
+
"transformers",
|
23 |
+
"torch",
|
24 |
+
"torchvision",
|
25 |
+
"pillow",
|
26 |
+
"more-itertools",
|
27 |
+
"datasets",
|
28 |
+
"braceexpand",
|
29 |
+
"webdataset",
|
30 |
+
"wandb",
|
31 |
+
"nltk",
|
32 |
+
"scipy",
|
33 |
+
"inflection",
|
34 |
+
"sentencepiece",
|
35 |
+
"open_clip_torch",
|
36 |
+
"opencv-python"
|
37 |
+
]
|
38 |
+
|
39 |
+
setup(
|
40 |
+
name="open_flamingo",
|
41 |
+
packages=find_packages(),
|
42 |
+
include_package_data=True,
|
43 |
+
version="0.0.2",
|
44 |
+
license="MIT",
|
45 |
+
description="An open-source framework for training large multimodal models",
|
46 |
+
long_description=long_description,
|
47 |
+
long_description_content_type="text/markdown",
|
48 |
+
data_files=[(".", ["README.md"])],
|
49 |
+
keywords=["machine learning"],
|
50 |
+
install_requires=REQUIREMENTS,
|
51 |
+
classifiers=[
|
52 |
+
"Development Status :: 4 - Beta",
|
53 |
+
"Intended Audience :: Developers",
|
54 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
55 |
+
"License :: OSI Approved :: MIT License",
|
56 |
+
"Programming Language :: Python :: 3.9",
|
57 |
+
],
|
58 |
+
)
|
open_flamingo/tests/test_flamingo_model.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import unittest
|
2 |
+
|
3 |
+
# import requests
|
4 |
+
# from PIL import Image
|
5 |
+
|
6 |
+
# from open_flamingo import create_model_and_transforms
|
7 |
+
|
8 |
+
|
9 |
+
# class TestFlamingoModel(unittest.TestCase):
|
10 |
+
# def test_forward_pass(self):
|
11 |
+
# model, image_processor, tokenizer = create_model_and_transforms(
|
12 |
+
# clip_vision_encoder_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
|
13 |
+
# clip_processor_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
|
14 |
+
# lang_encoder_path="hf-internal-testing/tiny-random-OPTModel",
|
15 |
+
# tokenizer_path="hf-internal-testing/tiny-random-OPTModel",
|
16 |
+
# )
|
17 |
+
|
18 |
+
# image = Image.open(
|
19 |
+
# requests.get(
|
20 |
+
# "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
|
21 |
+
# ).raw
|
22 |
+
# )
|
23 |
+
# vis_x = image_processor(images=[image, image], return_tensors="pt")[
|
24 |
+
# "pixel_values"
|
25 |
+
# ]
|
26 |
+
# vis_x = vis_x.unsqueeze(1).unsqueeze(1)
|
27 |
+
# lang_x = tokenizer(
|
28 |
+
# ["<|#image#|> A dog", "<|#image#|> A cat"],
|
29 |
+
# max_length=10,
|
30 |
+
# padding=True,
|
31 |
+
# truncation=True,
|
32 |
+
# return_tensors="pt",
|
33 |
+
# )
|
34 |
+
|
35 |
+
# # try batched forward pass
|
36 |
+
# model(vis_x, lang_x["input_ids"], attention_mask=lang_x["attention_mask"])
|
37 |
+
|
38 |
+
# def test_generate(self):
|
39 |
+
# model, image_processor, tokenizer = create_model_and_transforms(
|
40 |
+
# clip_vision_encoder_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
|
41 |
+
# clip_processor_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
|
42 |
+
# lang_encoder_path="hf-internal-testing/tiny-random-OPTModel",
|
43 |
+
# tokenizer_path="hf-internal-testing/tiny-random-OPTModel",
|
44 |
+
# )
|
45 |
+
|
46 |
+
# tokenizer.padding_side = (
|
47 |
+
# "left" # we want to pad on the left side for generation
|
48 |
+
# )
|
49 |
+
|
50 |
+
# image = Image.open(
|
51 |
+
# requests.get(
|
52 |
+
# "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
|
53 |
+
# ).raw
|
54 |
+
# )
|
55 |
+
# vis_x = image_processor(images=[image, image], return_tensors="pt")[
|
56 |
+
# "pixel_values"
|
57 |
+
# ]
|
58 |
+
# vis_x = vis_x.unsqueeze(1).unsqueeze(1)
|
59 |
+
# lang_x = tokenizer(
|
60 |
+
# ["<|#image#|> A dog", "<|#image#|> A cat <|endofchunk|>"],
|
61 |
+
# max_length=10,
|
62 |
+
# padding=True,
|
63 |
+
# truncation=True,
|
64 |
+
# return_tensors="pt",
|
65 |
+
# )
|
66 |
+
|
67 |
+
# # try batched generation
|
68 |
+
# model.generate(
|
69 |
+
# vis_x,
|
70 |
+
# lang_x["input_ids"],
|
71 |
+
# attention_mask=lang_x["attention_mask"],
|
72 |
+
# max_new_tokens=20,
|
73 |
+
# )
|
74 |
+
|
75 |
+
|
76 |
+
# if __name__ == "__main__":
|
77 |
+
# unittest.main()
|
open_flamingo/tools/check_refcoco.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
import numpy as np
|
4 |
+
import sys
|
5 |
+
|
6 |
+
if __name__ == "__main__":
|
7 |
+
captions = []
|
8 |
+
with open(sys.argv[1]) as f:
|
9 |
+
for line in tqdm(f):
|
10 |
+
line = line.rstrip().split("\t")
|
11 |
+
caption = line[2]
|
12 |
+
captions.append(caption)
|
13 |
+
lengths = [len(c.split(" ")) for c in captions]
|
14 |
+
print(np.mean(lengths))
|
open_flamingo/tools/convert_mmc4_to_wds.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import base64
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import tarfile
|
6 |
+
import uuid
|
7 |
+
import zipfile
|
8 |
+
import time
|
9 |
+
|
10 |
+
import braceexpand
|
11 |
+
import webdataset as wds
|
12 |
+
from tqdm import tqdm
|
13 |
+
from tqdm.contrib.concurrent import process_map
|
14 |
+
|
15 |
+
arg_parser = argparse.ArgumentParser()
|
16 |
+
arg_parser.add_argument("--output_dir", type=str)
|
17 |
+
arg_parser.add_argument(
|
18 |
+
"--image_shards",
|
19 |
+
type=str,
|
20 |
+
help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
|
21 |
+
)
|
22 |
+
arg_parser.add_argument(
|
23 |
+
"--doc_shards",
|
24 |
+
type=str,
|
25 |
+
help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
|
26 |
+
)
|
27 |
+
arg_parser.add_argument(
|
28 |
+
"--thread",
|
29 |
+
type=int,
|
30 |
+
default=128,
|
31 |
+
)
|
32 |
+
args = arg_parser.parse_args()
|
33 |
+
|
34 |
+
def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
|
35 |
+
txt_to_filename_dict = {}
|
36 |
+
dataset = wds.WebDataset(image_shards).decode("pil").to_tuple("txt", "json")
|
37 |
+
for data in tqdm(dataset, disable=disable_tqdm):
|
38 |
+
txt = data[0].split(".")[0]
|
39 |
+
txt_to_filename_dict[txt] = data[1]['key']
|
40 |
+
return txt_to_filename_dict
|
41 |
+
|
42 |
+
|
43 |
+
def single_thread(args):
|
44 |
+
i = args["i"]
|
45 |
+
output_dir = args["output_dir"]
|
46 |
+
doc_shards = args["doc_shards"]
|
47 |
+
image_shards = args["image_shards"]
|
48 |
+
if i == 0:
|
49 |
+
tqdm.write(f"output_dir: {output_dir}")
|
50 |
+
tqdm.write(f"doc_shards: {doc_shards[:5]}")
|
51 |
+
tqdm.write(f"image_shards: {image_shards[:5]}")
|
52 |
+
with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
|
53 |
+
sink.verbose = False
|
54 |
+
for doc_shard, image_shard in tqdm(zip(doc_shards, image_shards), disable=(i != 0), total=len(doc_shards)):
|
55 |
+
# txt_to_filename_dict = get_txt_to_filename_dict(image_shard, disable_tqdm=(i != 0))
|
56 |
+
# image_tar = tarfile.open(image_shard)
|
57 |
+
# Open the ZIP archive and extract the JSON file
|
58 |
+
with zipfile.ZipFile(doc_shard, "r") as zip_file:
|
59 |
+
# Assumes the JSON file is the first file in the archive
|
60 |
+
json_filename = zip_file.namelist()[0]
|
61 |
+
with zip_file.open(json_filename, "r") as json_file:
|
62 |
+
pbar = tqdm(json_file, disable=True)
|
63 |
+
total_num = 0
|
64 |
+
exist_num = 0
|
65 |
+
for sample_data in pbar:
|
66 |
+
# get image names from json
|
67 |
+
sample_data = json.loads(sample_data)
|
68 |
+
image_info = sample_data["image_info"]
|
69 |
+
image_names = [image["image_name"] for image in image_info]
|
70 |
+
|
71 |
+
# Add each image to the tar file
|
72 |
+
for img_idx, image_name in enumerate(image_names):
|
73 |
+
total_num += 1
|
74 |
+
try:
|
75 |
+
image = image_tar.extractfile(txt_to_filename_dict[image_name.split(".")[0]]+".jpg")
|
76 |
+
# convert to base64
|
77 |
+
image_bytes = image.read()
|
78 |
+
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
79 |
+
exist_num += 1
|
80 |
+
except:
|
81 |
+
tqdm.write(f"{image_name.split('.')[0]}")
|
82 |
+
image_base64 = "null"
|
83 |
+
sample_data["image_info"][img_idx][
|
84 |
+
"image_base64"
|
85 |
+
] = image_base64
|
86 |
+
|
87 |
+
key_str = uuid.uuid4().hex
|
88 |
+
sink.write({"__key__": key_str, "json": sample_data})
|
89 |
+
pbar.set_description(f"{exist_num/total_num:.2f}")
|
90 |
+
# image_tar.close()
|
91 |
+
|
92 |
+
|
93 |
+
def main():
|
94 |
+
timestamp = int(time.time())
|
95 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
96 |
+
os.makedirs(os.path.join(args.output_dir, str(timestamp)), exist_ok=True)
|
97 |
+
tasks = []
|
98 |
+
for i in range(args.thread):
|
99 |
+
thread_dir = os.path.join(args.output_dir, str(timestamp), str(i))
|
100 |
+
os.makedirs(thread_dir, exist_ok=True)
|
101 |
+
tasks.append({
|
102 |
+
"i": i,
|
103 |
+
"output_dir": thread_dir,
|
104 |
+
"doc_shards": [],
|
105 |
+
"image_shards": [],
|
106 |
+
})
|
107 |
+
|
108 |
+
doc_shards = list(braceexpand.braceexpand(args.doc_shards))
|
109 |
+
image_shards = list(braceexpand.braceexpand(args.image_shards))
|
110 |
+
|
111 |
+
assert len(doc_shards) == len(
|
112 |
+
image_shards
|
113 |
+
), "Each doc shards must have a corresponding image shard"
|
114 |
+
|
115 |
+
for i, (doc_shard, image_shard) in enumerate(zip(doc_shards, image_shards)):
|
116 |
+
tasks[i % args.thread]["doc_shards"].append(doc_shard)
|
117 |
+
tasks[i % args.thread]["image_shards"].append(image_shard)
|
118 |
+
|
119 |
+
# assert len(tasks) == args.thread
|
120 |
+
# process_map(single_thread, tasks, max_workers=args.thread, disable=True)
|
121 |
+
single_thread(tasks[0])
|
122 |
+
|
123 |
+
if __name__ == "__main__":
|
124 |
+
main()
|
open_flamingo/tools/make_gqa_val.py
ADDED
File without changes
|
open_flamingo/tools/make_mmc4_global_table.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import webdataset as wds
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
from tqdm import tqdm
|
5 |
+
from tqdm.contrib.concurrent import process_map
|
6 |
+
import pickle as pkl
|
7 |
+
|
8 |
+
|
9 |
+
def single_thread(filename):
|
10 |
+
id_table = {}
|
11 |
+
dataset = wds.WebDataset(filename).decode().to_tuple("json")
|
12 |
+
for data in dataset:
|
13 |
+
data = data[0]
|
14 |
+
image_id = data["caption"].split(".")[0]
|
15 |
+
image_key = data["key"]
|
16 |
+
tarfile = os.path.basename(filename)
|
17 |
+
if image_id not in id_table:
|
18 |
+
id_table[image_id] = [tarfile, image_key]
|
19 |
+
return id_table
|
20 |
+
|
21 |
+
if __name__ == "__main__":
|
22 |
+
filenames = sorted(glob.glob("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/mmc4/images/*.tar"))[:16000]
|
23 |
+
print("start from", filenames[0])
|
24 |
+
print("to", filenames[-1])
|
25 |
+
id_tables = process_map(single_thread, filenames, max_workers=64)
|
26 |
+
id_table = {}
|
27 |
+
for table in tqdm(id_tables):
|
28 |
+
id_table.update(table)
|
29 |
+
print("total unique image:", len(id_table))
|
30 |
+
pkl.dump(id_table, open("mmc4_id_table.pkl", "wb"))
|
31 |
+
print("DONE")
|
open_flamingo/tools/make_soft_link.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import glob
|
4 |
+
import random
|
5 |
+
|
6 |
+
DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw"
|
7 |
+
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_mini_dataset_full_karpathy"
|
8 |
+
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
os.makedirs(OUT_DIR, exist_ok=True)
|
12 |
+
cc3m_tars = glob.glob(os.path.join(DIR, "cc3m", "cc3m_*", "*.tar"))
|
13 |
+
cc12m_tars = glob.glob(os.path.join(DIR, "cc12m", "tars", "*.tar"))
|
14 |
+
coco_tars = glob.glob(os.path.join(DIR, "karpathy_coco_wds_full", "*.tar"))
|
15 |
+
vg_tars = glob.glob(os.path.join(DIR, "vg_wds_full", "*.tar"))
|
16 |
+
tars = []
|
17 |
+
tars.extend(cc3m_tars)
|
18 |
+
tars.extend(cc12m_tars)
|
19 |
+
tars.extend(coco_tars)
|
20 |
+
tars.extend(vg_tars)
|
21 |
+
random.shuffle(tars)
|
22 |
+
for i, tar in enumerate(tars):
|
23 |
+
dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
|
24 |
+
print(tar, dst)
|
25 |
+
os.symlink(tar, dst)
|
26 |
+
|
open_flamingo/tools/make_soft_link_blip2_data.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import glob
|
4 |
+
import random
|
5 |
+
from pprint import pprint
|
6 |
+
|
7 |
+
DIR_COCO_VG = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw"
|
8 |
+
DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining"
|
9 |
+
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_ground"
|
10 |
+
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
os.makedirs(OUT_DIR, exist_ok=True)
|
14 |
+
ccs_tars = glob.glob(os.path.join(DIR, "ccs_synthetic_filtered_large_ground", "*.tar"))
|
15 |
+
coco_tars = glob.glob(os.path.join(DIR_COCO_VG, "karpathy_coco_wds_full_ground", "*.tar"))
|
16 |
+
vg_tars = glob.glob(os.path.join(DIR_COCO_VG, "vg_wds_full_ground", "*.tar"))
|
17 |
+
laion_part_tars = glob.glob(os.path.join(DIR, "laion_synthetic_filtered_large", "all_ground", "*.tar"))
|
18 |
+
tars = []
|
19 |
+
tars.extend(ccs_tars)
|
20 |
+
for _ in range(5):
|
21 |
+
tars.extend(coco_tars)
|
22 |
+
tars.extend(vg_tars)
|
23 |
+
tars.extend(laion_part_tars)
|
24 |
+
random.shuffle(tars)
|
25 |
+
print(len(tars))
|
26 |
+
pprint(tars[:20])
|
27 |
+
for i, tar in enumerate(tars):
|
28 |
+
dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
|
29 |
+
# print(tar, dst)
|
30 |
+
os.symlink(tar, dst)
|
open_flamingo/tools/make_soft_link_laion.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import glob
|
4 |
+
import random
|
5 |
+
from pprint import pprint
|
6 |
+
|
7 |
+
DIR_COCO_VG = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw"
|
8 |
+
DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/"
|
9 |
+
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/laion_synthetic_filtered_large/all"
|
10 |
+
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
os.makedirs(OUT_DIR, exist_ok=True)
|
14 |
+
tars = []
|
15 |
+
for i in range(10):
|
16 |
+
laion_part_tars = glob.glob(os.path.join(DIR, "laion_synthetic_filtered_large", f"part{i}", "*.tar"))
|
17 |
+
tars.extend(laion_part_tars)
|
18 |
+
print(len(tars))
|
19 |
+
pprint(tars[:20])
|
20 |
+
for i, tar in enumerate(tars):
|
21 |
+
dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
|
22 |
+
# print(tar, dst)
|
23 |
+
os.symlink(tar, dst)
|
open_flamingo/tools/make_vqav2_ft_dataset.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import webdataset as wds
|
2 |
+
import os
|
3 |
+
from tqdm import tqdm
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
import base64
|
7 |
+
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vqav2_train_wds"
|
8 |
+
TOTAL = 1828467
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
with wds.ShardWriter(os.path.join(OUT_DIR, "%06d.tar"), maxcount=10000) as sink:
|
12 |
+
sink.verbose = False
|
13 |
+
f = open("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vqav2_ofa/vqa_train.tsv")
|
14 |
+
for data in tqdm(f, total=TOTAL):
|
15 |
+
data = data.rstrip().split("\t")
|
16 |
+
id1 = data[0]
|
17 |
+
id2 = data[1]
|
18 |
+
question = data[2]
|
19 |
+
answer = data[3].split("|!+")[-1]
|
20 |
+
image = data[5]
|
21 |
+
id3 = data[6]
|
22 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
23 |
+
caption = f"Question: {question.strip()} Answer: {answer.strip()}"
|
24 |
+
sink.write({"__key__": f"vqav2_{id1}_{id2}_{id3}", "jpg": image, "txt": caption})
|
open_flamingo/tools/prepare_mini_blip2_dataset.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import webdataset as wds
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
from tqdm import tqdm
|
5 |
+
import orjson as json
|
6 |
+
import itertools
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
from typing import List
|
10 |
+
|
11 |
+
class Generator():
|
12 |
+
def __init__(self, dataset_name):
|
13 |
+
self.dataset_name = dataset_name
|
14 |
+
self.is_end = False
|
15 |
+
|
16 |
+
class CC3MGenerator(Generator):
|
17 |
+
def __init__(self, root: str, dataset_name="cc3m"):
|
18 |
+
super().__init__(dataset_name=dataset_name)
|
19 |
+
self.tars = glob.glob(os.path.join(root, "cc3m_*", "*.tar"))
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
return 3000000
|
23 |
+
|
24 |
+
def __iter__(self):
|
25 |
+
for tar in self.tars:
|
26 |
+
dataset = wds.WebDataset(tar).decode("pilrgb").to_tuple("jpg;png;jpeg", "txt")
|
27 |
+
for data in dataset:
|
28 |
+
yield [self.dataset_name] + list(data)
|
29 |
+
self.is_end = True
|
30 |
+
|
31 |
+
class CC12MGenerator(CC3MGenerator):
|
32 |
+
def __init__(self, root: str):
|
33 |
+
super().__init__(root, "cc12m")
|
34 |
+
self.tars = glob.glob(os.path.join(root, "*.tar"))
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return 12000000
|
38 |
+
|
39 |
+
class COCOGenerator(Generator):
|
40 |
+
def __init__(self, anno: str, image_dir):
|
41 |
+
super().__init__(dataset_name="coco")
|
42 |
+
data = json.loads(open(anno).read())
|
43 |
+
self.annotations = data["annotations"]
|
44 |
+
self.image_id_to_filename = {}
|
45 |
+
for image in data["images"]:
|
46 |
+
file_name = image["file_name"]
|
47 |
+
image_id = image["id"]
|
48 |
+
self.image_id_to_filename[image_id] = os.path.join(image_dir, file_name)
|
49 |
+
|
50 |
+
def __len__(self):
|
51 |
+
return len(self.annotations)
|
52 |
+
|
53 |
+
def __iter__(self):
|
54 |
+
for anno in self.annotations:
|
55 |
+
image_id = anno["image_id"]
|
56 |
+
caption = anno["caption"]
|
57 |
+
try:
|
58 |
+
image = Image.open(self.image_id_to_filename[image_id])
|
59 |
+
except:
|
60 |
+
continue
|
61 |
+
yield [self.dataset_name, image, caption]
|
62 |
+
self.is_end = True
|
63 |
+
|
64 |
+
|
65 |
+
class KarpathyCOCOGenerator(Generator):
|
66 |
+
def __init__(self, anno="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/tools/coco_karpathy_train.json", image_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco/images"):
|
67 |
+
super().__init__(dataset_name="coco")
|
68 |
+
data = json.loads(open(anno).read())
|
69 |
+
self.annotations = data
|
70 |
+
self.image_id_to_filename = {}
|
71 |
+
for d in data:
|
72 |
+
self.image_id_to_filename[d["image_id"]] = os.path.join(image_dir, d["image"])
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.annotations)
|
76 |
+
|
77 |
+
def __iter__(self):
|
78 |
+
for anno in self.annotations:
|
79 |
+
image_id = anno["image_id"]
|
80 |
+
caption = anno["caption"]
|
81 |
+
try:
|
82 |
+
image = Image.open(self.image_id_to_filename[image_id])
|
83 |
+
except:
|
84 |
+
print(self.image_id_to_filename[image_id])
|
85 |
+
yield [self.dataset_name, image, caption]
|
86 |
+
self.is_end = True
|
87 |
+
|
88 |
+
|
89 |
+
class VisualGenomeGenerator(Generator):
|
90 |
+
def __init__(self, root: str):
|
91 |
+
super().__init__(dataset_name="vg")
|
92 |
+
data = json.loads(open(os.path.join(root, "region_descriptions.json")).read())
|
93 |
+
image_data = json.loads(open(os.path.join(root, "image_data.json")).read())
|
94 |
+
self.image_id_to_filename = {}
|
95 |
+
self.image_id_to_wh = {}
|
96 |
+
for image in image_data:
|
97 |
+
image_id = image["image_id"]
|
98 |
+
subfolder, filename = image['url'].split("/")[-2:]
|
99 |
+
self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
|
100 |
+
self.image_id_to_wh[image_id] = (image["width"], image["height"])
|
101 |
+
self.regions = []
|
102 |
+
total = 0
|
103 |
+
total_image = 0
|
104 |
+
used_image = 0
|
105 |
+
for xx in data:
|
106 |
+
total_image += 1
|
107 |
+
flag = False
|
108 |
+
for region in xx['regions']:
|
109 |
+
total += 1
|
110 |
+
region_w = int(region["width"])
|
111 |
+
region_h = int(region["height"])
|
112 |
+
image_w = self.image_id_to_wh[region["image_id"]][0]
|
113 |
+
image_h = self.image_id_to_wh[region["image_id"]][1]
|
114 |
+
if region_w * region_h < (image_w * image_h) * 0.2:
|
115 |
+
continue
|
116 |
+
self.regions.append(region)
|
117 |
+
flag = True
|
118 |
+
if flag:
|
119 |
+
used_image += 1
|
120 |
+
print("valid region", len(self.regions), total, len(self.regions) / total)
|
121 |
+
print("valid image", used_image, total_image, used_image / total_image)
|
122 |
+
|
123 |
+
def __len__(self):
|
124 |
+
return len(self.regions)
|
125 |
+
|
126 |
+
def __iter__(self):
|
127 |
+
for region in self.regions:
|
128 |
+
image_id = region["image_id"]
|
129 |
+
phrase = region["phrase"]
|
130 |
+
try:
|
131 |
+
image = Image.open(self.image_id_to_filename[image_id])
|
132 |
+
except:
|
133 |
+
continue
|
134 |
+
yield [self.dataset_name, image, phrase]
|
135 |
+
self.is_end = True
|
136 |
+
|
137 |
+
class ShuffleGenerator():
|
138 |
+
def __init__(self, generators: List[Generator], p: List[int]):
|
139 |
+
self.generators = generators
|
140 |
+
self.p = list(np.array(p) / sum(p))
|
141 |
+
self.ids = list(range(len(self.generators)))
|
142 |
+
print("rebalance", self.ids, self.p)
|
143 |
+
|
144 |
+
def __len__(self):
|
145 |
+
return sum([len(g) for g in self.generators])
|
146 |
+
|
147 |
+
def __iter__(self):
|
148 |
+
while True:
|
149 |
+
if len(self.ids) == 0:
|
150 |
+
break
|
151 |
+
id = np.random.choice(self.ids, p=self.p)
|
152 |
+
gen = self.generators[id]
|
153 |
+
if gen.is_end:
|
154 |
+
print(gen.dataset_name, "is all done")
|
155 |
+
del self.ids[id]
|
156 |
+
del self.p[id]
|
157 |
+
self.p = list(np.array(self.p) / sum(p))
|
158 |
+
print("rebalance", self.ids, self.p)
|
159 |
+
else:
|
160 |
+
return iter(gen)
|
161 |
+
|
162 |
+
|
163 |
+
if __name__ == "__main__":
|
164 |
+
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/karpathy_coco_wds_full"
|
165 |
+
os.makedirs(OUT_DIR, exist_ok=True)
|
166 |
+
# cc3m_generator = CC3MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc3m")
|
167 |
+
# cc12m_generator = CC12MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc12m/tars")
|
168 |
+
coco_generator = KarpathyCOCOGenerator()
|
169 |
+
# visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
|
170 |
+
# generators = [cc3m_generator, cc12m_generator, coco_generator, visual_genome_generator]
|
171 |
+
# p = [len(generator) for generator in generators]
|
172 |
+
# dataset = ShuffleGenerator(generators, p)
|
173 |
+
|
174 |
+
with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=8500) as sink:
|
175 |
+
sink.verbose = False
|
176 |
+
for i, data in enumerate(tqdm(coco_generator)):
|
177 |
+
dataset_name, image, caption = data
|
178 |
+
sink.write({"__key__": f"{dataset_name}_{i}", "jpg": image, "txt": caption})
|
open_flamingo/tools/prepare_pile.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import os
|
3 |
+
from tqdm import tqdm
|
4 |
+
import webdataset as wds
|
5 |
+
import json
|
6 |
+
|
7 |
+
DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/the_pile/all/train"
|
8 |
+
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/the_pile"
|
9 |
+
SAMPLE_PER_SHARD = 100000
|
10 |
+
|
11 |
+
if __name__ == "__main__":
|
12 |
+
os.makedirs(OUT_DIR)
|
13 |
+
print("load dataset...")
|
14 |
+
pile = datasets.load_from_disk(DATASET_ROOT)
|
15 |
+
total_num = pile.num_rows
|
16 |
+
print("total num:", total_num)
|
17 |
+
num = 0
|
18 |
+
pbar = tqdm(total=total_num)
|
19 |
+
with wds.ShardWriter(OUT_DIR+"/%05d.tar", maxcount=SAMPLE_PER_SHARD, encoder=False) as sink:
|
20 |
+
for sample in pile.iter(4096):
|
21 |
+
for text, meta in zip(sample["text"], sample["meta"]):
|
22 |
+
pbar.update(1)
|
23 |
+
if meta.get("pile_set_name", None) == "Github":
|
24 |
+
continue
|
25 |
+
num += 1
|
26 |
+
sink.write({
|
27 |
+
'__key__': str(num),
|
28 |
+
'txt': text.encode("utf-8"),
|
29 |
+
'json': json.dumps(meta, indent=4).encode("utf-8"),
|
30 |
+
})
|
31 |
+
print(f"{num} out of {total_num} is written")
|