Spaces:
Runtime error
Runtime error
cwkuo
commited on
Commit
·
7962ed0
1
Parent(s):
fb92e97
implement gpt-k demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- .vscode/settings.json +6 -0
- README.md +1 -1
- app.py +387 -0
- conversation.py +364 -0
- examples/diamond_head.jpg +3 -0
- examples/horseshoe_bend.jpg +3 -0
- examples/mona_lisa.jpg +3 -0
- examples/mona_lisa_dog.jpg +3 -0
- examples/titanic.jpg +3 -0
- knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index +3 -0
- knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt +1 -0
- knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 +3 -0
- knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy +3 -0
- knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index +3 -0
- knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt +1 -0
- knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 +3 -0
- knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy +3 -0
- knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index +3 -0
- knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt +1 -0
- knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 +3 -0
- knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy +3 -0
- knowledge/__init__.py +2 -0
- knowledge/__pycache__/__init__.cpython-37.pyc +0 -0
- knowledge/__pycache__/__init__.cpython-38.pyc +0 -0
- knowledge/__pycache__/cluster.cpython-38.pyc +0 -0
- knowledge/__pycache__/dbscan.cpython-37.pyc +0 -0
- knowledge/__pycache__/dbscan.cpython-38.pyc +0 -0
- knowledge/__pycache__/image_crops_idx.cpython-38.pyc +0 -0
- knowledge/__pycache__/image_tokens_idx.cpython-38.pyc +0 -0
- knowledge/__pycache__/revive.cpython-38.pyc +0 -0
- knowledge/__pycache__/sentence_db.cpython-37.pyc +0 -0
- knowledge/__pycache__/sentence_db.cpython-38.pyc +0 -0
- knowledge/__pycache__/sentence_idx.cpython-37.pyc +0 -0
- knowledge/__pycache__/sentence_idx.cpython-38.pyc +0 -0
- knowledge/__pycache__/text_db.cpython-38.pyc +0 -0
- knowledge/__pycache__/utils.cpython-37.pyc +0 -0
- knowledge/__pycache__/utils.cpython-38.pyc +0 -0
- knowledge/__pycache__/vis_vocab.cpython-37.pyc +0 -0
- knowledge/__pycache__/wordnet.cpython-37.pyc +0 -0
- knowledge/cluster.py +178 -0
- knowledge/retrieve.py +327 -0
- knowledge/text_db.py +197 -0
- knowledge/transforms.py +52 -0
- knowledge/utils.py +127 -0
- model/.gitattributes +2 -0
- model/__init__.py +1 -0
- model/ckpt/mp_rank_00_model_states.pt +3 -0
- model/eva_vit.py +434 -0
- model/gptk-7b.yaml +25 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
.bin filter=lfs diff=lfs merge=lfs -text
|
37 |
+
.pt filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.hdf5 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.index filter=lfs diff=lfs merge=lfs -text
|
40 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.vscode/settings.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"[python]": {
|
3 |
+
"editor.defaultFormatter": "ms-python.autopep8"
|
4 |
+
},
|
5 |
+
"python.formatting.provider": "none"
|
6 |
+
}
|
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title: K
|
3 |
emoji: 🚀
|
4 |
colorFrom: green
|
5 |
colorTo: red
|
|
|
1 |
---
|
2 |
+
title: GPT-K
|
3 |
emoji: 🚀
|
4 |
colorFrom: green
|
5 |
colorTo: red
|
app.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import gradio as gr
|
7 |
+
import requests
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import open_clip
|
12 |
+
import faiss
|
13 |
+
from transformers import TextIteratorStreamer
|
14 |
+
from threading import Thread
|
15 |
+
|
16 |
+
from conversation import default_conversation, conv_templates, Conversation
|
17 |
+
from knowledge import TextDB
|
18 |
+
from knowledge.transforms import five_crop, nine_crop
|
19 |
+
from knowledge.utils import refine_cosine
|
20 |
+
from model import get_gptk_model, get_gptk_image_transform
|
21 |
+
|
22 |
+
|
23 |
+
no_change_btn = gr.Button.update()
|
24 |
+
enable_btn = gr.Button.update(interactive=True)
|
25 |
+
disable_btn = gr.Button.update(interactive=False)
|
26 |
+
knwl_none = (None, ) * 30
|
27 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
28 |
+
|
29 |
+
|
30 |
+
def violates_moderation(text):
|
31 |
+
"""
|
32 |
+
Check whether the text violates OpenAI moderation API.
|
33 |
+
"""
|
34 |
+
url = "https://api.openai.com/v1/moderations"
|
35 |
+
headers = {
|
36 |
+
"Content-Type": "application/json",
|
37 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]
|
38 |
+
}
|
39 |
+
text = text.replace("\n", "")
|
40 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
41 |
+
data = data.encode("utf-8")
|
42 |
+
try:
|
43 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
44 |
+
flagged = ret.json()["results"][0]["flagged"]
|
45 |
+
except requests.exceptions.RequestException as e:
|
46 |
+
flagged = False
|
47 |
+
except KeyError as e:
|
48 |
+
flagged = False
|
49 |
+
|
50 |
+
return flagged
|
51 |
+
|
52 |
+
|
53 |
+
def load_demo():
|
54 |
+
state = default_conversation.copy()
|
55 |
+
return (state, )
|
56 |
+
|
57 |
+
|
58 |
+
def regenerate(state: Conversation):
|
59 |
+
state.messages[-1][-1] = None
|
60 |
+
prev_human_msg = state.messages[-2]
|
61 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
62 |
+
prev_human_msg[1] = prev_human_msg[1][:2]
|
63 |
+
state.skip_next = False
|
64 |
+
|
65 |
+
return (state, state.to_gradio_chatbot(), "", None, disable_btn, disable_btn)
|
66 |
+
|
67 |
+
|
68 |
+
def clear_history():
|
69 |
+
state = default_conversation.copy()
|
70 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2 + knwl_none
|
71 |
+
|
72 |
+
|
73 |
+
def add_text(state: Conversation, text, image):
|
74 |
+
if len(text) <= 0 and image is None:
|
75 |
+
state.skip_next = True
|
76 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 2
|
77 |
+
|
78 |
+
if violates_moderation(text):
|
79 |
+
state.skip_next = True
|
80 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 2
|
81 |
+
|
82 |
+
text = (text, image)
|
83 |
+
if len(state.get_images(return_pil=True)) > 0:
|
84 |
+
state = default_conversation.copy()
|
85 |
+
state.append_message(state.roles[0], text)
|
86 |
+
state.append_message(state.roles[1], None)
|
87 |
+
state.skip_next = False
|
88 |
+
|
89 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
|
90 |
+
|
91 |
+
|
92 |
+
def search(image, pos, topk, knwl_db, knwl_idx):
|
93 |
+
with torch.cuda.amp.autocast():
|
94 |
+
image = query_trans(image).unsqueeze(0).to(device)
|
95 |
+
query = query_enc.encode_image(image, normalize=True)
|
96 |
+
query = query.cpu().numpy()
|
97 |
+
|
98 |
+
_, I = knwl_idx.search(query, 4*topk)
|
99 |
+
score, I = refine_cosine(knwl_db.feature, query, I, device, topk)
|
100 |
+
score, I = score.flatten(), I.flatten()
|
101 |
+
embd, text = knwl_db[I]
|
102 |
+
pos = np.full((topk, ), fill_value=pos)
|
103 |
+
|
104 |
+
query = torch.FloatTensor(query).unsqueeze(0).to(device)
|
105 |
+
embd = torch.FloatTensor(embd).unsqueeze(0).to(device)
|
106 |
+
pos = torch.LongTensor(pos).unsqueeze(0).to(device)
|
107 |
+
score = torch.FloatTensor(score).unsqueeze(0).to(device)
|
108 |
+
|
109 |
+
return query, embd, pos, score, text
|
110 |
+
|
111 |
+
|
112 |
+
def retrieve_knowledge(image):
|
113 |
+
knwl_embd = {}
|
114 |
+
knwl_text = {}
|
115 |
+
for query_type, topk_q in topk.items():
|
116 |
+
if topk_q == 0: continue
|
117 |
+
|
118 |
+
if query_type == "whole":
|
119 |
+
images = [image, ]
|
120 |
+
knwl_text[query_type] = {i: {} for i in range(1)}
|
121 |
+
elif query_type == "five":
|
122 |
+
images = five_crop(image)
|
123 |
+
knwl_text[query_type] = {i: {} for i in range(5)}
|
124 |
+
elif query_type == "nine":
|
125 |
+
images = nine_crop(image)
|
126 |
+
knwl_text[query_type] = {i: {} for i in range(9)}
|
127 |
+
else:
|
128 |
+
raise ValueError
|
129 |
+
|
130 |
+
knwl_embd[query_type] = {}
|
131 |
+
for knwl_type, (knwl_db_t, knwl_idx_t) in knwl_db.items():
|
132 |
+
query, embed, pos, score = [], [], [], []
|
133 |
+
for i, img in enumerate(images):
|
134 |
+
query_i, embed_i, pos_i, score_i, text_i = search(
|
135 |
+
img, i, topk_q, knwl_db_t, knwl_idx_t
|
136 |
+
)
|
137 |
+
query.append(query_i)
|
138 |
+
embed.append(embed_i)
|
139 |
+
pos.append(pos_i)
|
140 |
+
score.append(score_i)
|
141 |
+
knwl_text[query_type][i][knwl_type] = text_i
|
142 |
+
|
143 |
+
query = torch.cat(query, dim=1)
|
144 |
+
embed = torch.cat(embed, dim=1)
|
145 |
+
pos = torch.cat(pos, dim=1)
|
146 |
+
score = torch.cat(score, dim=1)
|
147 |
+
|
148 |
+
knwl_embd[query_type][knwl_type] = {
|
149 |
+
"embed": embed, "query": query, "pos": pos, "score": score
|
150 |
+
}
|
151 |
+
|
152 |
+
return knwl_embd, knwl_text
|
153 |
+
|
154 |
+
|
155 |
+
def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
|
156 |
+
if state.skip_next: # This generate call is skipped due to invalid inputs
|
157 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2 + knwl_none
|
158 |
+
return
|
159 |
+
|
160 |
+
if len(state.messages) == state.offset + 2: # First round of conversation
|
161 |
+
new_state = conv_templates["gptk"].copy()
|
162 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
163 |
+
new_state.append_message(new_state.roles[1], None)
|
164 |
+
state = new_state
|
165 |
+
|
166 |
+
# retrieve and visualize knowledge
|
167 |
+
image = state.get_images(return_pil=True)[0]
|
168 |
+
if bool(add_knwl):
|
169 |
+
knwl_embd, knwl = retrieve_knowledge(image)
|
170 |
+
knwl_img, knwl_txt, idx = [None, ] * 15, ["", ] * 15, 0
|
171 |
+
for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
|
172 |
+
if query_type == "whole":
|
173 |
+
images = [image, ]
|
174 |
+
elif query_type == "five":
|
175 |
+
images = five_crop(image)
|
176 |
+
elif query_type == "nine":
|
177 |
+
images = nine_crop(image)
|
178 |
+
|
179 |
+
for pos in range(knwl_pos):
|
180 |
+
try:
|
181 |
+
txt = ""
|
182 |
+
for k, v in knwl[query_type][str(pos)].items():
|
183 |
+
v = ", ".join([vi.replace("_", " ") for vi in v])
|
184 |
+
txt += f"**[{k.upper()}]:** {v}\n\n"
|
185 |
+
knwl_txt[idx] += txt
|
186 |
+
knwl_img[idx] = images[pos]
|
187 |
+
except KeyError:
|
188 |
+
pass
|
189 |
+
idx += 1
|
190 |
+
knwl_vis = tuple(knwl_img + knwl_txt)
|
191 |
+
else:
|
192 |
+
knwl_embd = None
|
193 |
+
knwl_vis = knwl_none
|
194 |
+
|
195 |
+
# generate output
|
196 |
+
prompt = state.get_prompt()
|
197 |
+
prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
|
198 |
+
image_pt = image_trans(image).to(device).unsqueeze(0)
|
199 |
+
samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
|
200 |
+
|
201 |
+
if bool(do_beam_search):
|
202 |
+
new_text = gptk_model.generate(
|
203 |
+
samples=samples,
|
204 |
+
use_nucleus_sampling=bool(do_sampling),
|
205 |
+
max_length=min(int(max_new_tokens), 1024),
|
206 |
+
top_p=float(top_p),
|
207 |
+
temperature=float(temperature),
|
208 |
+
auto_cast=True
|
209 |
+
)[0]
|
210 |
+
streamer = [new_text, ]
|
211 |
+
else:
|
212 |
+
streamer = TextIteratorStreamer(
|
213 |
+
gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
|
214 |
+
)
|
215 |
+
thread = Thread(
|
216 |
+
target=gptk_model.generate,
|
217 |
+
kwargs=dict(
|
218 |
+
samples=samples,
|
219 |
+
use_nucleus_sampling=bool(do_sampling),
|
220 |
+
max_length=min(int(max_new_tokens), 1024),
|
221 |
+
top_p=float(top_p),
|
222 |
+
temperature=float(temperature),
|
223 |
+
streamer=streamer,
|
224 |
+
num_beams=1,
|
225 |
+
auto_cast=True
|
226 |
+
)
|
227 |
+
)
|
228 |
+
thread.start()
|
229 |
+
|
230 |
+
generated_text = ""
|
231 |
+
for new_text in streamer:
|
232 |
+
generated_text += new_text
|
233 |
+
state.messages[-1][-1] = generated_text + "▌"
|
234 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2 + knwl_vis
|
235 |
+
time.sleep(0.03)
|
236 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
237 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2 + knwl_vis
|
238 |
+
|
239 |
+
|
240 |
+
title_markdown = ("""
|
241 |
+
# GPT-K: Knowledge Augmented Vision-and-Language Assistant
|
242 |
+
""")
|
243 |
+
|
244 |
+
tos_markdown = ("""
|
245 |
+
### Terms of use
|
246 |
+
By using this service, users are required to agree to the following terms:
|
247 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
248 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
249 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
250 |
+
""")
|
251 |
+
|
252 |
+
learn_more_markdown = ("""
|
253 |
+
### License
|
254 |
+
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
255 |
+
""")
|
256 |
+
|
257 |
+
|
258 |
+
def build_demo():
|
259 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
260 |
+
imagebox = gr.Image(type="pil")
|
261 |
+
state = gr.State()
|
262 |
+
|
263 |
+
with gr.Blocks(title="GPT-K", theme=gr.themes.Base()) as demo:
|
264 |
+
gr.Markdown(title_markdown)
|
265 |
+
with gr.Row():
|
266 |
+
with gr.Column(scale=3):
|
267 |
+
gr.Examples(examples=[
|
268 |
+
["examples/mona_lisa.jpg", "Discuss the historical impact and the significance of this painting in the art world."],
|
269 |
+
["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
|
270 |
+
["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
|
271 |
+
["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
|
272 |
+
["examples/titanic.jpg", "What happen in the scene in this movie?"],
|
273 |
+
], inputs=[imagebox, textbox])
|
274 |
+
|
275 |
+
imagebox.render()
|
276 |
+
textbox.render()
|
277 |
+
with gr.Column():
|
278 |
+
submit_btn = gr.Button(value="📝 Submit")
|
279 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
280 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
281 |
+
|
282 |
+
with gr.Accordion("Parameters", open=True):
|
283 |
+
with gr.Row():
|
284 |
+
add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
|
285 |
+
do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
|
286 |
+
do_beam_search = gr.Checkbox(value=False, interactive=True, label="Beam search")
|
287 |
+
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
288 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
289 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
290 |
+
|
291 |
+
with gr.Column(scale=6):
|
292 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
|
293 |
+
|
294 |
+
gr.Markdown("Retrieved Knowledge")
|
295 |
+
knwl_img, knwl_txt = [], []
|
296 |
+
for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
|
297 |
+
with gr.Tab(query_type):
|
298 |
+
for p in range(knwl_pos):
|
299 |
+
with gr.Tab(str(p)):
|
300 |
+
with gr.Row():
|
301 |
+
with gr.Column(scale=1):
|
302 |
+
knwl_img.append(gr.Image(type="pil", show_label=False, interactive=False))
|
303 |
+
with gr.Column(scale=7):
|
304 |
+
knwl_txt.append(gr.Markdown())
|
305 |
+
knwl_vis = knwl_img + knwl_txt
|
306 |
+
|
307 |
+
gr.Markdown(tos_markdown)
|
308 |
+
gr.Markdown(learn_more_markdown)
|
309 |
+
|
310 |
+
# Register listeners
|
311 |
+
btn_list = [regenerate_btn, clear_btn]
|
312 |
+
regenerate_btn.click(
|
313 |
+
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
314 |
+
).then(
|
315 |
+
generate,
|
316 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
317 |
+
[state, chatbot] + btn_list + knwl_vis
|
318 |
+
)
|
319 |
+
|
320 |
+
clear_btn.click(
|
321 |
+
clear_history, None, [state, chatbot, textbox, imagebox] + btn_list + knwl_vis
|
322 |
+
)
|
323 |
+
|
324 |
+
textbox.submit(
|
325 |
+
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
326 |
+
).then(
|
327 |
+
generate,
|
328 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
329 |
+
[state, chatbot] + btn_list + knwl_vis
|
330 |
+
)
|
331 |
+
|
332 |
+
submit_btn.click(
|
333 |
+
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
334 |
+
).then(
|
335 |
+
generate,
|
336 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
337 |
+
[state, chatbot] + btn_list + knwl_vis
|
338 |
+
)
|
339 |
+
|
340 |
+
demo.load(load_demo, None, [state, ])
|
341 |
+
|
342 |
+
return demo
|
343 |
+
|
344 |
+
|
345 |
+
def build_model():
|
346 |
+
if torch.cuda.is_available():
|
347 |
+
device = torch.device("cuda")
|
348 |
+
else:
|
349 |
+
device = torch.device("cpu")
|
350 |
+
|
351 |
+
query_enc, _, query_trans = open_clip.create_model_and_transforms(
|
352 |
+
"ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16'
|
353 |
+
)
|
354 |
+
query_enc = query_enc.to(device).eval()
|
355 |
+
|
356 |
+
def get_knwl(knowledge_db):
|
357 |
+
knwl_db = TextDB(Path(knowledge_db)/"knowledge_db.hdf5")
|
358 |
+
knwl_idx = faiss.read_index(str(Path(knowledge_db)/"faiss.index"))
|
359 |
+
knwl_idx.add(knwl_db.feature)
|
360 |
+
|
361 |
+
return knwl_db, knwl_idx
|
362 |
+
|
363 |
+
knwl_db = {
|
364 |
+
"obj": get_knwl('knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
|
365 |
+
"act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
|
366 |
+
"attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
|
367 |
+
}
|
368 |
+
d_knwl = knwl_db["obj"][0].feature.shape[-1]
|
369 |
+
|
370 |
+
_, image_trans = get_gptk_image_transform()
|
371 |
+
topk = {"whole": 60, "five": 24, "nine": 16}
|
372 |
+
gptk_model = get_gptk_model(d_knwl=d_knwl, topk=topk)
|
373 |
+
gptk_ckpt = "model/ckpt/mp_rank_00_model_states.pt"
|
374 |
+
gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
|
375 |
+
gptk_ckpt = {
|
376 |
+
".".join(k.split(".")[2:]): v
|
377 |
+
for k, v in gptk_ckpt["module"].items()
|
378 |
+
}
|
379 |
+
gptk_model.load_state_dict(gptk_ckpt)
|
380 |
+
gptk_model = gptk_model.to(device).eval()
|
381 |
+
|
382 |
+
return knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device
|
383 |
+
|
384 |
+
|
385 |
+
knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device = build_model()
|
386 |
+
demo = build_demo()
|
387 |
+
demo.queue().launch()
|
conversation.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
PLAIN = auto()
|
12 |
+
LLAMA_2 = auto()
|
13 |
+
|
14 |
+
|
15 |
+
@dataclasses.dataclass
|
16 |
+
class Conversation:
|
17 |
+
"""A class that keeps all conversation history."""
|
18 |
+
system: str
|
19 |
+
roles: List[str]
|
20 |
+
messages: List[List[str]]
|
21 |
+
offset: int
|
22 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
23 |
+
sep: str = "###"
|
24 |
+
sep2: str = None
|
25 |
+
version: str = "Unknown"
|
26 |
+
|
27 |
+
skip_next: bool = False
|
28 |
+
|
29 |
+
def get_prompt(self):
|
30 |
+
messages = self.messages
|
31 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
32 |
+
messages = self.messages.copy()
|
33 |
+
init_role, init_msg = messages[0].copy()
|
34 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
35 |
+
if 'mmtag' in self.version:
|
36 |
+
messages[0] = (init_role, init_msg)
|
37 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
38 |
+
messages.insert(1, (self.roles[1], "Received."))
|
39 |
+
else:
|
40 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
41 |
+
|
42 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
43 |
+
ret = self.system + self.sep
|
44 |
+
for role, message in messages:
|
45 |
+
if message:
|
46 |
+
if type(message) is tuple:
|
47 |
+
message, _, _ = message
|
48 |
+
ret += role + ": " + message + self.sep
|
49 |
+
else:
|
50 |
+
ret += role + ":"
|
51 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
52 |
+
seps = [self.sep, self.sep2]
|
53 |
+
ret = self.system + seps[0]
|
54 |
+
for i, (role, message) in enumerate(messages):
|
55 |
+
if message:
|
56 |
+
if type(message) is tuple:
|
57 |
+
message, _, _ = message
|
58 |
+
ret += role + ": " + message + seps[i % 2]
|
59 |
+
else:
|
60 |
+
ret += role + ":"
|
61 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
62 |
+
ret = self.system + self.sep
|
63 |
+
for role, message in messages:
|
64 |
+
if message:
|
65 |
+
if type(message) is tuple:
|
66 |
+
message, _, _ = message
|
67 |
+
ret += role + message + self.sep
|
68 |
+
else:
|
69 |
+
ret += role
|
70 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
71 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
72 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
73 |
+
ret = ""
|
74 |
+
|
75 |
+
for i, (role, message) in enumerate(messages):
|
76 |
+
if i == 0:
|
77 |
+
assert message, "first message should not be none"
|
78 |
+
assert role == self.roles[0], "first message should come from user"
|
79 |
+
if message:
|
80 |
+
if type(message) is tuple:
|
81 |
+
message, _, _ = message
|
82 |
+
if i == 0: message = wrap_sys(self.system) + message
|
83 |
+
if i % 2 == 0:
|
84 |
+
message = wrap_inst(message)
|
85 |
+
ret += self.sep + message
|
86 |
+
else:
|
87 |
+
ret += " " + message + " " + self.sep2
|
88 |
+
else:
|
89 |
+
ret += ""
|
90 |
+
ret = ret.lstrip(self.sep)
|
91 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
92 |
+
seps = [self.sep, self.sep2]
|
93 |
+
ret = self.system
|
94 |
+
for i, (role, message) in enumerate(messages):
|
95 |
+
if message:
|
96 |
+
if type(message) is tuple:
|
97 |
+
message, _, _ = message
|
98 |
+
ret += message + seps[i % 2]
|
99 |
+
else:
|
100 |
+
ret += ""
|
101 |
+
else:
|
102 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
103 |
+
|
104 |
+
return ret
|
105 |
+
|
106 |
+
def append_message(self, role, message):
|
107 |
+
self.messages.append([role, message])
|
108 |
+
|
109 |
+
def get_images(self, return_pil=False):
|
110 |
+
images = []
|
111 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
112 |
+
if i % 2 == 0:
|
113 |
+
if type(msg) is tuple:
|
114 |
+
image = msg[1].convert('RGB')
|
115 |
+
if return_pil:
|
116 |
+
images.append(image)
|
117 |
+
else:
|
118 |
+
import base64
|
119 |
+
from io import BytesIO
|
120 |
+
buffered = BytesIO()
|
121 |
+
image.save(buffered, format="PNG")
|
122 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
123 |
+
images.append(img_b64_str)
|
124 |
+
return images
|
125 |
+
|
126 |
+
def to_gradio_chatbot(self):
|
127 |
+
ret = []
|
128 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
129 |
+
if i % 2 == 0:
|
130 |
+
if type(msg) is tuple:
|
131 |
+
import base64
|
132 |
+
from io import BytesIO
|
133 |
+
msg, image = msg
|
134 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
135 |
+
aspect_ratio = max_hw / min_hw
|
136 |
+
max_len, min_len = 800, 400
|
137 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
138 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
139 |
+
W, H = image.size
|
140 |
+
if H > W:
|
141 |
+
H, W = longest_edge, shortest_edge
|
142 |
+
else:
|
143 |
+
H, W = shortest_edge, longest_edge
|
144 |
+
image = image.resize((W, H))
|
145 |
+
buffered = BytesIO()
|
146 |
+
image.save(buffered, format="JPEG")
|
147 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
148 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
149 |
+
ret.append([img_str, None])
|
150 |
+
msg = msg.replace('<image>', '').strip()
|
151 |
+
if len(msg) > 0:
|
152 |
+
ret.append([msg, None])
|
153 |
+
else:
|
154 |
+
ret.append([msg, None])
|
155 |
+
else:
|
156 |
+
ret[-1][-1] = msg
|
157 |
+
return ret
|
158 |
+
|
159 |
+
def copy(self):
|
160 |
+
return Conversation(
|
161 |
+
system=self.system,
|
162 |
+
roles=self.roles,
|
163 |
+
messages=[[x, y] for x, y in self.messages],
|
164 |
+
offset=self.offset,
|
165 |
+
sep_style=self.sep_style,
|
166 |
+
sep=self.sep,
|
167 |
+
sep2=self.sep2,
|
168 |
+
version=self.version)
|
169 |
+
|
170 |
+
def dict(self):
|
171 |
+
if len(self.get_images()) > 0:
|
172 |
+
return {
|
173 |
+
"system": self.system,
|
174 |
+
"roles": self.roles,
|
175 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
176 |
+
"offset": self.offset,
|
177 |
+
"sep": self.sep,
|
178 |
+
"sep2": self.sep2,
|
179 |
+
}
|
180 |
+
return {
|
181 |
+
"system": self.system,
|
182 |
+
"roles": self.roles,
|
183 |
+
"messages": self.messages,
|
184 |
+
"offset": self.offset,
|
185 |
+
"sep": self.sep,
|
186 |
+
"sep2": self.sep2,
|
187 |
+
}
|
188 |
+
|
189 |
+
|
190 |
+
conv_gptk = Conversation(
|
191 |
+
system="",
|
192 |
+
roles=("USER", "ASSISTANT"),
|
193 |
+
version="v1",
|
194 |
+
messages=(),
|
195 |
+
offset=0,
|
196 |
+
sep_style=SeparatorStyle.SINGLE,
|
197 |
+
sep=""
|
198 |
+
)
|
199 |
+
|
200 |
+
|
201 |
+
conv_vicuna_v0 = Conversation(
|
202 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
203 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
204 |
+
roles=("Human", "Assistant"),
|
205 |
+
messages=(
|
206 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
207 |
+
("Assistant",
|
208 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
209 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
210 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
211 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
212 |
+
"renewable and non-renewable energy sources:\n"
|
213 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
214 |
+
"energy sources are finite and will eventually run out.\n"
|
215 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
216 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
217 |
+
"and other negative effects.\n"
|
218 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
219 |
+
"have lower operational costs than non-renewable sources.\n"
|
220 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
221 |
+
"locations than non-renewable sources.\n"
|
222 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
223 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
224 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
225 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
226 |
+
),
|
227 |
+
offset=2,
|
228 |
+
sep_style=SeparatorStyle.SINGLE,
|
229 |
+
sep="###",
|
230 |
+
)
|
231 |
+
|
232 |
+
conv_vicuna_v1 = Conversation(
|
233 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
234 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
235 |
+
roles=("USER", "ASSISTANT"),
|
236 |
+
version="v1",
|
237 |
+
messages=(),
|
238 |
+
offset=0,
|
239 |
+
sep_style=SeparatorStyle.TWO,
|
240 |
+
sep=" ",
|
241 |
+
sep2="</s>",
|
242 |
+
)
|
243 |
+
|
244 |
+
conv_llama_2 = Conversation(
|
245 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
246 |
+
|
247 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
248 |
+
roles=("USER", "ASSISTANT"),
|
249 |
+
version="llama_v2",
|
250 |
+
messages=(),
|
251 |
+
offset=0,
|
252 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
253 |
+
sep="<s>",
|
254 |
+
sep2="</s>",
|
255 |
+
)
|
256 |
+
|
257 |
+
conv_llava_llama_2 = Conversation(
|
258 |
+
system="You are a helpful language and vision assistant. "
|
259 |
+
"You are able to understand the visual content that the user provides, "
|
260 |
+
"and assist the user with a variety of tasks using natural language.",
|
261 |
+
roles=("USER", "ASSISTANT"),
|
262 |
+
version="llama_v2",
|
263 |
+
messages=(),
|
264 |
+
offset=0,
|
265 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
266 |
+
sep="<s>",
|
267 |
+
sep2="</s>",
|
268 |
+
)
|
269 |
+
|
270 |
+
conv_mpt = Conversation(
|
271 |
+
system="""<|im_start|>system
|
272 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
273 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
274 |
+
version="mpt",
|
275 |
+
messages=(),
|
276 |
+
offset=0,
|
277 |
+
sep_style=SeparatorStyle.MPT,
|
278 |
+
sep="<|im_end|>",
|
279 |
+
)
|
280 |
+
|
281 |
+
conv_llava_plain = Conversation(
|
282 |
+
system="",
|
283 |
+
roles=("", ""),
|
284 |
+
messages=(
|
285 |
+
),
|
286 |
+
offset=0,
|
287 |
+
sep_style=SeparatorStyle.PLAIN,
|
288 |
+
sep="\n",
|
289 |
+
)
|
290 |
+
|
291 |
+
conv_llava_v0 = Conversation(
|
292 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
293 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
294 |
+
roles=("Human", "Assistant"),
|
295 |
+
messages=(
|
296 |
+
("Human", "Hi!"),
|
297 |
+
("Assistant", "Hi there! How can I help you today?")
|
298 |
+
),
|
299 |
+
offset=2,
|
300 |
+
sep_style=SeparatorStyle.SINGLE,
|
301 |
+
sep="###",
|
302 |
+
)
|
303 |
+
|
304 |
+
conv_llava_v0_mmtag = Conversation(
|
305 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
306 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
307 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
308 |
+
roles=("Human", "Assistant"),
|
309 |
+
messages=(
|
310 |
+
),
|
311 |
+
offset=0,
|
312 |
+
sep_style=SeparatorStyle.SINGLE,
|
313 |
+
sep="###",
|
314 |
+
version="v0_mmtag",
|
315 |
+
)
|
316 |
+
|
317 |
+
conv_llava_v1 = Conversation(
|
318 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
319 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
320 |
+
roles=("USER", "ASSISTANT"),
|
321 |
+
version="v1",
|
322 |
+
messages=(),
|
323 |
+
offset=0,
|
324 |
+
sep_style=SeparatorStyle.TWO,
|
325 |
+
sep=" ",
|
326 |
+
sep2="</s>",
|
327 |
+
)
|
328 |
+
|
329 |
+
conv_llava_v1_mmtag = Conversation(
|
330 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
331 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
332 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
333 |
+
roles=("USER", "ASSISTANT"),
|
334 |
+
messages=(),
|
335 |
+
offset=0,
|
336 |
+
sep_style=SeparatorStyle.TWO,
|
337 |
+
sep=" ",
|
338 |
+
sep2="</s>",
|
339 |
+
version="v1_mmtag",
|
340 |
+
)
|
341 |
+
|
342 |
+
default_conversation = conv_vicuna_v0
|
343 |
+
conv_templates = {
|
344 |
+
"default": conv_vicuna_v0,
|
345 |
+
"v0": conv_vicuna_v0,
|
346 |
+
"v1": conv_vicuna_v1,
|
347 |
+
"vicuna_v1": conv_vicuna_v1,
|
348 |
+
"llama_2": conv_llama_2,
|
349 |
+
"gptk": conv_gptk,
|
350 |
+
|
351 |
+
"plain": conv_llava_plain,
|
352 |
+
"v0_plain": conv_llava_plain,
|
353 |
+
"llava_v0": conv_llava_v0,
|
354 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
355 |
+
"llava_v1": conv_llava_v1,
|
356 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
357 |
+
"llava_llama_2": conv_llava_llama_2,
|
358 |
+
|
359 |
+
"mpt": conv_mpt,
|
360 |
+
}
|
361 |
+
|
362 |
+
|
363 |
+
if __name__ == "__main__":
|
364 |
+
print(default_conversation.get_prompt())
|
examples/diamond_head.jpg
ADDED
Git LFS Details
|
examples/horseshoe_bend.jpg
ADDED
Git LFS Details
|
examples/mona_lisa.jpg
ADDED
Git LFS Details
|
examples/mona_lisa_dog.jpg
ADDED
Git LFS Details
|
examples/titanic.jpg
ADDED
Git LFS Details
|
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fb05eb3ab8b8e775c1e10ab21a4f8d409b77a47ffacbc606050c2055bd78549a
|
3 |
+
size 45
|
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
n_samples = 148,620; n_clusters = 43,296; noise_ratio = 0.000%
|
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6281557260322cacbfbe58d710e3dd537e823d6d6565da7c9fea27e30ced5e31
|
3 |
+
size 166074480
|
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:881bf21972ffb9a9155d185282530a75a4ca4ffdb75c8a05d38dda901c0f366c
|
3 |
+
size 1189088
|
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1efe5c6accd575c85403aaeccaf24c6fb1cfff05bd6a0f1ecdbdbc0ce0a5befa
|
3 |
+
size 9093259
|
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
n_samples = 191,836; n_clusters = 77,073; noise_ratio = 0.000%
|
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51449f86c49a1651debdaf7ec1b4c1020db911785bd5f51e0766a4bfefe1897f
|
3 |
+
size 295832959
|
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6cf9a479e50595e52e593f961d4f3dcc822d9c0caf097fed3498a64c175f7e2c
|
3 |
+
size 1534816
|
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ece63b94bf3672252b77fbbf47a3070a378280ef3eafb682f99340fc74e1d096
|
3 |
+
size 18702475
|
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
n_samples = 770,808; n_clusters = 325,813; noise_ratio = 0.000%
|
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9dcca3e4560c724f42128b8d476dd28ad0305ad66125213050c7fec7715d6a8b
|
3 |
+
size 1251033850
|
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c79a747b0551e46056391ad988317604dd29a8905acb3167127550dcc6b90890
|
3 |
+
size 6166592
|
knowledge/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .text_db import TextDB
|
2 |
+
from .retrieve import *
|
knowledge/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (254 Bytes). View file
|
|
knowledge/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (254 Bytes). View file
|
|
knowledge/__pycache__/cluster.cpython-38.pyc
ADDED
Binary file (5.12 kB). View file
|
|
knowledge/__pycache__/dbscan.cpython-37.pyc
ADDED
Binary file (2.29 kB). View file
|
|
knowledge/__pycache__/dbscan.cpython-38.pyc
ADDED
Binary file (2.32 kB). View file
|
|
knowledge/__pycache__/image_crops_idx.cpython-38.pyc
ADDED
Binary file (10.8 kB). View file
|
|
knowledge/__pycache__/image_tokens_idx.cpython-38.pyc
ADDED
Binary file (7.7 kB). View file
|
|
knowledge/__pycache__/revive.cpython-38.pyc
ADDED
Binary file (2.19 kB). View file
|
|
knowledge/__pycache__/sentence_db.cpython-37.pyc
ADDED
Binary file (6.01 kB). View file
|
|
knowledge/__pycache__/sentence_db.cpython-38.pyc
ADDED
Binary file (6.39 kB). View file
|
|
knowledge/__pycache__/sentence_idx.cpython-37.pyc
ADDED
Binary file (9.12 kB). View file
|
|
knowledge/__pycache__/sentence_idx.cpython-38.pyc
ADDED
Binary file (9.75 kB). View file
|
|
knowledge/__pycache__/text_db.cpython-38.pyc
ADDED
Binary file (7.22 kB). View file
|
|
knowledge/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (3.05 kB). View file
|
|
knowledge/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.1 kB). View file
|
|
knowledge/__pycache__/vis_vocab.cpython-37.pyc
ADDED
Binary file (8.46 kB). View file
|
|
knowledge/__pycache__/wordnet.cpython-37.pyc
ADDED
Binary file (2.3 kB). View file
|
|
knowledge/cluster.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
import h5py
|
6 |
+
import time
|
7 |
+
|
8 |
+
import faiss
|
9 |
+
import torch
|
10 |
+
from pytorch_lightning import seed_everything
|
11 |
+
|
12 |
+
import sys
|
13 |
+
sys.path.append('.')
|
14 |
+
from knowledge.text_db import TextDB
|
15 |
+
from knowledge.utils import nn_search, build_faiss_index, refine_cosine
|
16 |
+
|
17 |
+
|
18 |
+
UNSEEN = -2
|
19 |
+
NOISE = -1
|
20 |
+
|
21 |
+
|
22 |
+
def dbscan(X, faiss_index, device, eps=0.1, min_points=1, k=2048, bs=512):
|
23 |
+
neighbors = []
|
24 |
+
N = (len(X) - 1) // bs + 1
|
25 |
+
for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
|
26 |
+
Xi = X[i*bs: (i+1)*bs]
|
27 |
+
_, I = faiss_index.search(Xi, k*2)
|
28 |
+
S, I = refine_cosine(X, Xi, I, device, k)
|
29 |
+
|
30 |
+
for sim, idx in zip(S, I):
|
31 |
+
dist = 1. - sim
|
32 |
+
neighbors.append(idx[dist < eps])
|
33 |
+
|
34 |
+
cluster_id = 0
|
35 |
+
n_points = len(X)
|
36 |
+
labels = np.array([
|
37 |
+
NOISE if len(neighbors[i]) < min_points else UNSEEN
|
38 |
+
for i in range(n_points)
|
39 |
+
])
|
40 |
+
|
41 |
+
with tqdm(total=n_points, dynamic_ncols=True, desc="DBSCAN clustering", mininterval=1.0) as pbar:
|
42 |
+
for i in range(n_points):
|
43 |
+
if labels[i] == UNSEEN:
|
44 |
+
seeds = np.array([i, ])
|
45 |
+
labels[seeds] = cluster_id
|
46 |
+
|
47 |
+
while len(seeds) > 0:
|
48 |
+
neighbor_seeds = set()
|
49 |
+
for s in seeds:
|
50 |
+
n = neighbors[s]
|
51 |
+
if len(n) > 0:
|
52 |
+
l = np.array(list(set(labels[n])))
|
53 |
+
l = l[np.logical_and(l >= 0, l != cluster_id)]
|
54 |
+
for li in l:
|
55 |
+
labels[labels == li] = cluster_id
|
56 |
+
|
57 |
+
n = n[labels[n] == UNSEEN]
|
58 |
+
neighbor_seeds.update(n)
|
59 |
+
|
60 |
+
seeds = np.array(list(neighbor_seeds))
|
61 |
+
if len(seeds) > 0:
|
62 |
+
assert np.all(labels[seeds] == UNSEEN)
|
63 |
+
labels[seeds] = cluster_id
|
64 |
+
|
65 |
+
cluster_id += 1
|
66 |
+
|
67 |
+
pbar.set_postfix(num_clusters=cluster_id)
|
68 |
+
pbar.update()
|
69 |
+
|
70 |
+
label_set = np.sort(list(set(labels)))
|
71 |
+
label_set = label_set[label_set >= 0]
|
72 |
+
labels_mapping = {l1: l2 for l2, l1 in enumerate(label_set)}
|
73 |
+
labels_mapping[-1] = -1
|
74 |
+
labels = np.array([labels_mapping[l] for l in labels])
|
75 |
+
|
76 |
+
return labels
|
77 |
+
|
78 |
+
|
79 |
+
def extract_clusters(feat, text, labels, faiss_index, device, k=128, bs=8192):
|
80 |
+
clusters = {}
|
81 |
+
for i, l in enumerate(tqdm(labels, dynamic_ncols=True, desc="Label each samples", mininterval=1.0)):
|
82 |
+
if l >= 0:
|
83 |
+
try:
|
84 |
+
clusters[l]["feat"] += feat[i].astype(np.float64)
|
85 |
+
clusters[l]["N"] += 1
|
86 |
+
except KeyError:
|
87 |
+
clusters[l] = {"feat": feat[i].astype(np.float64), "N": 1}
|
88 |
+
|
89 |
+
cc = []
|
90 |
+
for l in tqdm(list(clusters.keys()), dynamic_ncols=True, desc="Compute cluster centers", mininterval=1.0):
|
91 |
+
c = clusters[l]["feat"]/clusters[l]["N"]
|
92 |
+
cc.append(c.astype(np.float32))
|
93 |
+
cc = np.stack(cc)
|
94 |
+
cc /= np.linalg.norm(cc, keepdims=True, axis=-1)
|
95 |
+
|
96 |
+
idx = []
|
97 |
+
N = (len(cc) - 1) // bs + 1
|
98 |
+
for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
|
99 |
+
cc_i = cc[i*bs: (i+1)*bs]
|
100 |
+
_, I = faiss_index.search(cc_i, k)
|
101 |
+
_, I = refine_cosine(feat, cc_i, I, device, 1)
|
102 |
+
idx.append(I[:, 0])
|
103 |
+
idx = np.unique(np.concatenate(idx))
|
104 |
+
text = [text[i] for i in idx]
|
105 |
+
feat = np.stack([feat[i] for i in idx])
|
106 |
+
|
107 |
+
return feat, text
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
parser = argparse.ArgumentParser(description="Cluster knowledge database using DBSCAN")
|
112 |
+
parser.add_argument("--knowledge_db", type=str, required=True)
|
113 |
+
parser.add_argument("--seed", type=int, default=12345)
|
114 |
+
parser.add_argument("--eps", type=float, default=0.1)
|
115 |
+
parser.add_argument("--ms", type=int, default=1)
|
116 |
+
parser.add_argument("--ratio", type=float, default=None)
|
117 |
+
parser.add_argument("--device", type=int, default=None)
|
118 |
+
args = parser.parse_args()
|
119 |
+
|
120 |
+
# parse exp name
|
121 |
+
args.knowledge_db = Path(args.knowledge_db)
|
122 |
+
exp_name = args.knowledge_db.parent.name
|
123 |
+
exp_name += f"(dbscan)(eps-{args.eps})(ms-{args.ms})"
|
124 |
+
save_root = args.knowledge_db.parent.parent/exp_name
|
125 |
+
setattr(args, "save_root", save_root)
|
126 |
+
args.save_root.mkdir(parents=True, exist_ok=True)
|
127 |
+
|
128 |
+
args.device = torch.device("cuda", args.device) \
|
129 |
+
if args.device is not None else torch.device("cpu")
|
130 |
+
|
131 |
+
seed_everything(args.seed, workers=True)
|
132 |
+
print(args)
|
133 |
+
|
134 |
+
# load feature, text, and faiss index from knowledge db
|
135 |
+
knowledge_db = TextDB(args.knowledge_db)
|
136 |
+
feat = knowledge_db.feature.astype(np.float32)
|
137 |
+
text = knowledge_db.text
|
138 |
+
if args.ratio is not None:
|
139 |
+
N = int(len(feat) * args.ratio)
|
140 |
+
feat, text = feat[:N], text[:N]
|
141 |
+
faiss_index = faiss.read_index(str(args.knowledge_db.parent/"faiss.index"))
|
142 |
+
print("Add data to faiss index...", end="\r")
|
143 |
+
ts = time.time()
|
144 |
+
faiss_index.add(feat)
|
145 |
+
print(f"Add data to faiss index...done in {time.time() - ts:.2f} secs")
|
146 |
+
|
147 |
+
# DBSCAN clustering
|
148 |
+
labels_file = args.save_root/"labels.npy"
|
149 |
+
if labels_file.exists():
|
150 |
+
labels = np.load(labels_file)
|
151 |
+
else:
|
152 |
+
labels = dbscan(feat, faiss_index, args.device, args.eps, args.ms)
|
153 |
+
with open(labels_file, 'wb') as f:
|
154 |
+
np.save(f, labels)
|
155 |
+
|
156 |
+
# extract clusters
|
157 |
+
feat, text = extract_clusters(feat, text, labels, faiss_index, args.device)
|
158 |
+
with h5py.File(args.save_root/f"knowledge_db.hdf5", "w") as f:
|
159 |
+
bs = 65536
|
160 |
+
N = (len(feat) - 1) // bs + 1
|
161 |
+
for i in tqdm(range(N), dynamic_ncols=True, desc="Saving clustered DB", mininterval=1.0):
|
162 |
+
g = f.create_group(str(i))
|
163 |
+
g.create_dataset("feature", data=feat[i*bs: (i+1)*bs], compression="gzip")
|
164 |
+
g.create_dataset("text", data=text[i*bs: (i+1)*bs], compression="gzip")
|
165 |
+
|
166 |
+
# build faiss index for the clustered DB
|
167 |
+
index = build_faiss_index(feat, gpus=[args.device.index, ])
|
168 |
+
faiss.write_index(index, str(args.save_root/"faiss.index"))
|
169 |
+
|
170 |
+
# some stats
|
171 |
+
noise_ratio = np.sum(labels == -1) / len(labels)
|
172 |
+
n_clusters, n_samples = len(text), len(labels)
|
173 |
+
msg = f"n_samples = {n_samples:,}; n_clusters = {n_clusters:,}; noise_ratio = {noise_ratio*100:.3f}%\n"
|
174 |
+
with open(save_root/"info.txt", "w") as f:
|
175 |
+
f.write(msg)
|
176 |
+
print(msg)
|
177 |
+
|
178 |
+
|
knowledge/retrieve.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
import h5py
|
4 |
+
import time
|
5 |
+
import shutil
|
6 |
+
import numpy as np
|
7 |
+
import subprocess
|
8 |
+
import time
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
import faiss
|
12 |
+
import open_clip
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
from pytorch_lightning import callbacks
|
17 |
+
from pytorch_lightning import Trainer, LightningModule, seed_everything
|
18 |
+
|
19 |
+
import sys
|
20 |
+
sys.path.append('.')
|
21 |
+
from dataset import coco, cc, llava
|
22 |
+
from knowledge.utils import refine_cosine
|
23 |
+
from knowledge import text_db
|
24 |
+
from knowledge import TextDB
|
25 |
+
from train.utils import ExpName
|
26 |
+
|
27 |
+
|
28 |
+
class ImageCropsIdx:
|
29 |
+
def __init__(self, knowledge_idx, topk_w, topk_f, topk_n):
|
30 |
+
topk = {"whole": topk_w, "five": topk_f, "nine": topk_n}
|
31 |
+
self.topk = {k: v for k, v in topk.items() if v > 0}
|
32 |
+
|
33 |
+
self.knowledge_idx, self.fdim, self.file_hash = self.load(knowledge_idx, self.topk)
|
34 |
+
|
35 |
+
def load(self, knowledge_idx, topk):
|
36 |
+
with h5py.File(knowledge_idx, "r") as f:
|
37 |
+
fdim = f.attrs["fdim"]
|
38 |
+
file_hash = f.attrs["file_hash"]
|
39 |
+
|
40 |
+
knowledge_idx_ = {}
|
41 |
+
for i in tqdm(range(len(f)), desc="Load sentence idx", dynamic_ncols=True, mininterval=1.0):
|
42 |
+
knowledge_idx_[str(i)] = {"image_ids": f[f"{i}/image_ids"][:]}
|
43 |
+
for k, v in topk.items():
|
44 |
+
knowledge_idx_[str(i)][k] = {
|
45 |
+
"index": f[f"{i}/{k}/index"][:, :, :v],
|
46 |
+
"score": f[f"{i}/{k}/score"][:, :, :v],
|
47 |
+
"query": f[f"{i}/{k}/query"][:]
|
48 |
+
}
|
49 |
+
|
50 |
+
knowledge_idx = {}
|
51 |
+
for i in knowledge_idx_.keys():
|
52 |
+
for j, id in enumerate(knowledge_idx_[i]["image_ids"]):
|
53 |
+
knowledge_idx[id] = {}
|
54 |
+
for k in topk.keys():
|
55 |
+
knowledge_idx[id][k] = {
|
56 |
+
"index": knowledge_idx_[i][k]["index"][j],
|
57 |
+
"score": knowledge_idx_[i][k]["score"][j],
|
58 |
+
"query": knowledge_idx_[i][k]["query"][j],
|
59 |
+
}
|
60 |
+
|
61 |
+
return knowledge_idx, fdim, file_hash
|
62 |
+
|
63 |
+
def __getitem__(self, image_id):
|
64 |
+
return self.knowledge_idx[image_id]
|
65 |
+
|
66 |
+
|
67 |
+
class KnowAugImageCrops:
|
68 |
+
def __init__(self, knowledge_db: TextDB, knowledge_idx: ImageCropsIdx, return_txt=False):
|
69 |
+
self.knowledge_db = knowledge_db
|
70 |
+
self.knowledge_idx = knowledge_idx
|
71 |
+
assert knowledge_db.file_hash == knowledge_idx.file_hash
|
72 |
+
|
73 |
+
self.ncrop = {"whole": 1, "five": 5, "nine": 9}
|
74 |
+
self.topk = knowledge_idx.topk
|
75 |
+
self.fdim = knowledge_idx.fdim
|
76 |
+
|
77 |
+
self.return_txt = return_txt
|
78 |
+
|
79 |
+
def __call__(self, image_id):
|
80 |
+
ret = {}
|
81 |
+
for k in self.topk.keys():
|
82 |
+
ki = self.knowledge_idx[image_id][k]["index"].flatten()
|
83 |
+
ke, kt = self.knowledge_db[ki]
|
84 |
+
kq = self.knowledge_idx[image_id][k]["query"]
|
85 |
+
kp = np.tile(np.arange(self.ncrop[k])[:, None], (1, self.topk[k])).flatten()
|
86 |
+
ks = self.knowledge_idx[image_id][k]["score"].flatten()
|
87 |
+
|
88 |
+
ke = torch.FloatTensor(ke)
|
89 |
+
kq = torch.FloatTensor(kq)
|
90 |
+
kp = torch.LongTensor(kp)
|
91 |
+
ks = torch.FloatTensor(ks)
|
92 |
+
|
93 |
+
ret[k] = {"embed": ke, "query": kq, "pos": kp, "score": ks}
|
94 |
+
if self.return_txt:
|
95 |
+
ret[k]["text"] = kt
|
96 |
+
|
97 |
+
return ret
|
98 |
+
|
99 |
+
|
100 |
+
class KnowAugImageCropsCombined:
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
knwl_aug_obj: KnowAugImageCrops,
|
104 |
+
knwl_aug_attr: KnowAugImageCrops,
|
105 |
+
knwl_aug_act: KnowAugImageCrops
|
106 |
+
):
|
107 |
+
self.knwl_aug_obj = knwl_aug_obj
|
108 |
+
self.knwl_aug_act = knwl_aug_act
|
109 |
+
self.knwl_aug_attr = knwl_aug_attr
|
110 |
+
self.fdim = knwl_aug_obj.fdim
|
111 |
+
|
112 |
+
def __call__(self, image_id):
|
113 |
+
knwl_obj = self.knwl_aug_obj(image_id)
|
114 |
+
knwl_attr = self.knwl_aug_attr(image_id)
|
115 |
+
knwl_act = self.knwl_aug_act(image_id)
|
116 |
+
|
117 |
+
ret = {}
|
118 |
+
for k in knwl_obj.keys():
|
119 |
+
ret[k] = {
|
120 |
+
"obj": knwl_obj[k],
|
121 |
+
"attr": knwl_attr[k],
|
122 |
+
"act": knwl_act[k]
|
123 |
+
}
|
124 |
+
|
125 |
+
return ret
|
126 |
+
|
127 |
+
|
128 |
+
class ImageCropsIdxBuilder(LightningModule):
|
129 |
+
def __init__(self, args, model: open_clip.model.CLIP):
|
130 |
+
super().__init__()
|
131 |
+
|
132 |
+
self.args = args
|
133 |
+
self.save_root = args.save_root
|
134 |
+
self.k = args.k
|
135 |
+
self.model = model
|
136 |
+
|
137 |
+
def on_validation_epoch_start(self):
|
138 |
+
if self.global_rank == 0:
|
139 |
+
knowledge_db = TextDB(self.args.knowledge_db)
|
140 |
+
self.feature = knowledge_db.feature
|
141 |
+
self.text = knowledge_db.text
|
142 |
+
|
143 |
+
self.faiss_index = faiss.read_index(
|
144 |
+
str(Path(self.args.knowledge_db).parent/"faiss.index")
|
145 |
+
)
|
146 |
+
print("\nAdd data to faiss index...", end="\r")
|
147 |
+
ts = time.time()
|
148 |
+
self.faiss_index.add(self.feature)
|
149 |
+
print(f"Add data to faiss index...done in {time.time() - ts:.2f} secs")
|
150 |
+
|
151 |
+
with h5py.File(self.save_root/"knowledge_idx.hdf5", "a") as f:
|
152 |
+
f.attrs["fdim"] = self.feature.shape[-1]
|
153 |
+
f.attrs["file_hash"] = knowledge_db.file_hash
|
154 |
+
|
155 |
+
self.trainer.strategy.barrier()
|
156 |
+
|
157 |
+
def all_gather_object(self, data):
|
158 |
+
if self.trainer.world_size > 1:
|
159 |
+
gathered = [None for _ in range(self.trainer.world_size)]
|
160 |
+
dist.all_gather_object(gathered, data)
|
161 |
+
data = gathered
|
162 |
+
else:
|
163 |
+
data = [data, ]
|
164 |
+
|
165 |
+
return data
|
166 |
+
|
167 |
+
def broadcast_object(self, data, src_rank=0):
|
168 |
+
if self.trainer.world_size > 1:
|
169 |
+
if self.global_rank == src_rank:
|
170 |
+
data_list = [data, ] * self.trainer.world_size
|
171 |
+
else:
|
172 |
+
data_list = [None, ] * self.trainer.world_size
|
173 |
+
|
174 |
+
dist.broadcast_object_list(data_list, src=src_rank)
|
175 |
+
return data_list[0]
|
176 |
+
else:
|
177 |
+
return data
|
178 |
+
|
179 |
+
def search(self, images, topk):
|
180 |
+
query = self.model.encode_image(images, normalize=True)
|
181 |
+
query = query.cpu().numpy()
|
182 |
+
query = self.all_gather_object(query)
|
183 |
+
query = np.concatenate(query)
|
184 |
+
|
185 |
+
if self.global_rank == 0:
|
186 |
+
_, I = self.faiss_index.search(query, 4*topk)
|
187 |
+
S, I = refine_cosine(self.feature, query, I, self.device, topk)
|
188 |
+
else:
|
189 |
+
S = I = None
|
190 |
+
|
191 |
+
return S, I, query
|
192 |
+
|
193 |
+
def validation_step(self, batch, batch_idx):
|
194 |
+
orig_imgs, five_imgs, nine_imgs, ids = batch
|
195 |
+
|
196 |
+
ids = ids.cpu().numpy()
|
197 |
+
ids = np.concatenate(self.all_gather_object(ids))
|
198 |
+
|
199 |
+
S_w, I_w, Q_w = self.search(orig_imgs, topk=self.k)
|
200 |
+
|
201 |
+
S_f, I_f, Q_f = [], [], []
|
202 |
+
for i in range(five_imgs.shape[1]):
|
203 |
+
Si, Ii, Qi = self.search(five_imgs[:, i], topk=self.k)
|
204 |
+
S_f.append(Si)
|
205 |
+
I_f.append(Ii)
|
206 |
+
Q_f.append(Qi)
|
207 |
+
|
208 |
+
S_n, I_n, Q_n = [], [], []
|
209 |
+
for i in range(nine_imgs.shape[1]):
|
210 |
+
Si, Ii, Qi = self.search(nine_imgs[:, i], topk=self.k)
|
211 |
+
S_n.append(Si)
|
212 |
+
I_n.append(Ii)
|
213 |
+
Q_n.append(Qi)
|
214 |
+
|
215 |
+
if self.global_rank == 0:
|
216 |
+
S_w, I_w, Q_w = np.expand_dims(S_w, axis=1), np.expand_dims(I_w, axis=1), np.expand_dims(Q_w, axis=1)
|
217 |
+
S_f, I_f, Q_f = np.stack(S_f, axis=1), np.stack(I_f, axis=1), np.stack(Q_f, axis=1)
|
218 |
+
S_n, I_n, Q_n = np.stack(S_n, axis=1), np.stack(I_n, axis=1), np.stack(Q_n, axis=1)
|
219 |
+
|
220 |
+
with h5py.File(self.save_root/"knowledge_idx.hdf5", "a") as f:
|
221 |
+
g = f.create_group(str(batch_idx))
|
222 |
+
|
223 |
+
g.create_dataset("image_ids", data=ids.astype(np.int32), compression="gzip")
|
224 |
+
|
225 |
+
gw = g.create_group("whole")
|
226 |
+
gw.create_dataset("index", data=I_w.astype(np.int32), compression="gzip")
|
227 |
+
gw.create_dataset("score", data=S_w.astype(np.float32), compression="gzip")
|
228 |
+
gw.create_dataset("query", data=Q_w.astype(np.float32), compression="gzip")
|
229 |
+
|
230 |
+
gf = g.create_group("five")
|
231 |
+
gf.create_dataset("index", data=I_f.astype(np.int32), compression="gzip")
|
232 |
+
gf.create_dataset("score", data=S_f.astype(np.float32), compression="gzip")
|
233 |
+
gf.create_dataset("query", data=Q_f.astype(np.float32), compression="gzip")
|
234 |
+
|
235 |
+
gn = g.create_group("nine")
|
236 |
+
gn.create_dataset("index", data=I_n.astype(np.int32), compression="gzip")
|
237 |
+
gn.create_dataset("score", data=S_n.astype(np.float32), compression="gzip")
|
238 |
+
gn.create_dataset("query", data=Q_n.astype(np.float32), compression="gzip")
|
239 |
+
|
240 |
+
def on_validation_epoch_end(self):
|
241 |
+
if self.args.azcopy and self.global_rank == 0:
|
242 |
+
with open("azcopy/sas_output", "r") as f:
|
243 |
+
sas = f.readline()
|
244 |
+
sas_base, sas_key = sas.split("?")
|
245 |
+
sas = f"{sas_base}/knowledge_idx?{sas_key}"
|
246 |
+
|
247 |
+
cmd = ["azcopy/azcopy", "copy", str(self.args.save_root), sas, "--recursive=true"]
|
248 |
+
print(f"start copying data with command {cmd}")
|
249 |
+
ts = time.time()
|
250 |
+
subprocess.run(cmd)
|
251 |
+
print(f"done copying data in {time.time() - ts:.2f} secs")
|
252 |
+
|
253 |
+
|
254 |
+
def main(args):
|
255 |
+
model, _, trans_img = open_clip.create_model_and_transforms(
|
256 |
+
args.clip_model, pretrained=text_db.CLIP_MODELS[args.clip_model]
|
257 |
+
)
|
258 |
+
|
259 |
+
print("load query dataset...")
|
260 |
+
if "coco" in args.query:
|
261 |
+
dset = coco.COCOImageCrops(Path(f"data/{args.query}"), trans=trans_img)
|
262 |
+
collate_crops = coco.collate_coco_crops
|
263 |
+
elif args.query == "cc3m":
|
264 |
+
dset = cc.CC3MImageCrops(Path("data/cc3m_instruct"), trans=trans_img)
|
265 |
+
collate_crops = cc.collate_cc_crops
|
266 |
+
elif args.query == "llava":
|
267 |
+
dset = llava.LLaVAImageCrops(Path("data/llava_bench"), trans=trans_img)
|
268 |
+
collate_crops = llava.collate_llava_crops
|
269 |
+
else:
|
270 |
+
raise ValueError
|
271 |
+
loader = DataLoader(
|
272 |
+
dset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
|
273 |
+
drop_last=False, collate_fn=collate_crops
|
274 |
+
)
|
275 |
+
|
276 |
+
print("build model and trainer...")
|
277 |
+
pl_model = ImageCropsIdxBuilder(args, model)
|
278 |
+
model_summary = callbacks.RichModelSummary()
|
279 |
+
progress_bar = callbacks.TQDMProgressBar(args.refresh_rate)
|
280 |
+
trainer_callbacks = [model_summary, progress_bar]
|
281 |
+
trainer = Trainer(
|
282 |
+
sync_batchnorm=True,
|
283 |
+
precision=16,
|
284 |
+
accelerator='gpu',
|
285 |
+
devices=args.devices,
|
286 |
+
strategy="ddp",
|
287 |
+
default_root_dir=args.save_root,
|
288 |
+
callbacks=trainer_callbacks,
|
289 |
+
limit_val_batches=args.limit_val_batches
|
290 |
+
)
|
291 |
+
|
292 |
+
print("retrieve knowledge...")
|
293 |
+
trainer.validate(pl_model, dataloaders=loader)
|
294 |
+
|
295 |
+
|
296 |
+
if __name__ == "__main__":
|
297 |
+
parser = argparse.ArgumentParser(description='Knowledge retrieval using image crops')
|
298 |
+
parser = Trainer.add_argparse_args(parser)
|
299 |
+
parser.add_argument('--query', type=str, choices=["coco14", "coco17", "cc3m", "llava"], required=True)
|
300 |
+
parser.add_argument('--knowledge_db', type=str, required=True)
|
301 |
+
parser.add_argument('--k', type=int, default=128)
|
302 |
+
parser.add_argument("--bs", type=int, default=128)
|
303 |
+
parser.add_argument("--num_workers", type=int, default=7)
|
304 |
+
parser.add_argument("--seed", type=int, default=12345)
|
305 |
+
parser.add_argument("--refresh_rate", type=int, default=1)
|
306 |
+
parser.add_argument("--azcopy", action="store_true")
|
307 |
+
args = parser.parse_args()
|
308 |
+
|
309 |
+
# parse exp_name
|
310 |
+
exp_name = ExpName(f"(query-{args.query})")
|
311 |
+
exp_name += Path(args.knowledge_db).parent.name
|
312 |
+
if args.azcopy:
|
313 |
+
setattr(args, "save_root", Path("azcopy")/str(exp_name))
|
314 |
+
else:
|
315 |
+
setattr(args, "save_root", Path("output")/"knowledge_idx"/str(exp_name))
|
316 |
+
shutil.rmtree(args.save_root, ignore_errors=True)
|
317 |
+
args.save_root.mkdir(parents=True, exist_ok=True)
|
318 |
+
|
319 |
+
# parse model
|
320 |
+
model = exp_name.get("clip-model")[1:-1]
|
321 |
+
model = model[len("clip-model-"):]
|
322 |
+
assert model in text_db.CLIP_MODELS.keys()
|
323 |
+
setattr(args, "clip_model", model)
|
324 |
+
|
325 |
+
print(args)
|
326 |
+
seed_everything(args.seed, workers=True)
|
327 |
+
main(args)
|
knowledge/text_db.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import itertools
|
3 |
+
from pathlib import Path
|
4 |
+
import shutil
|
5 |
+
import h5py
|
6 |
+
import time
|
7 |
+
import subprocess
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import codecs
|
11 |
+
|
12 |
+
import open_clip
|
13 |
+
import faiss
|
14 |
+
import torch
|
15 |
+
import torch.distributed as dist
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
from pytorch_lightning import callbacks
|
18 |
+
from pytorch_lightning import Trainer, LightningModule, seed_everything
|
19 |
+
|
20 |
+
import sys
|
21 |
+
sys.path.append("./")
|
22 |
+
from dataset import cc, words
|
23 |
+
from knowledge.utils import file_hash, build_faiss_index
|
24 |
+
|
25 |
+
|
26 |
+
class TextDB:
|
27 |
+
def __init__(self, text_db):
|
28 |
+
self.feature, self.text = self.load(text_db)
|
29 |
+
self.file_hash = file_hash(text_db)
|
30 |
+
|
31 |
+
def load(self, text_db):
|
32 |
+
with h5py.File(text_db, 'r') as f:
|
33 |
+
db_size = 0
|
34 |
+
for i in range(len(f)):
|
35 |
+
db_size += len(f[f"{i}/feature"])
|
36 |
+
_, d = f[f"0/feature"].shape
|
37 |
+
|
38 |
+
with h5py.File(text_db, 'r') as f:
|
39 |
+
feature = np.zeros((db_size, d), dtype=np.float32)
|
40 |
+
text = []
|
41 |
+
N = 0
|
42 |
+
for i in tqdm(range(len(f)), desc="Load text DB", dynamic_ncols=True, mininterval=1.0):
|
43 |
+
fi = f[f"{i}/feature"][:]
|
44 |
+
feature[N:N+len(fi)] = fi
|
45 |
+
N += len(fi)
|
46 |
+
|
47 |
+
text.extend(f[f"{i}/text"][:])
|
48 |
+
text = [codecs.decode(t) for t in text]
|
49 |
+
|
50 |
+
return feature, text
|
51 |
+
|
52 |
+
def __getitem__(self, idx):
|
53 |
+
f = self.feature[idx]
|
54 |
+
|
55 |
+
try:
|
56 |
+
t = [self.text[i] for i in idx]
|
57 |
+
except TypeError:
|
58 |
+
t = self.text[idx]
|
59 |
+
|
60 |
+
return f, t
|
61 |
+
|
62 |
+
|
63 |
+
class TextDBBuilder(LightningModule):
|
64 |
+
def __init__(self, args, model: open_clip.model.CLIP):
|
65 |
+
super().__init__()
|
66 |
+
self.args = args
|
67 |
+
self.model = model
|
68 |
+
|
69 |
+
def validation_step(self, batch, batch_idx):
|
70 |
+
token, text = batch
|
71 |
+
feat = self.model.encode_text(token, normalize=True)
|
72 |
+
|
73 |
+
if self.trainer.world_size > 1:
|
74 |
+
text_gathered = [None for _ in range(self.trainer.world_size)]
|
75 |
+
dist.all_gather_object(text_gathered, text)
|
76 |
+
text = list(itertools.chain.from_iterable(text_gathered))
|
77 |
+
|
78 |
+
feat_gathered = [None for _ in range(self.trainer.world_size)]
|
79 |
+
dist.all_gather_object(feat_gathered, feat)
|
80 |
+
feat = torch.cat([x.to(self.device) for x in feat_gathered])
|
81 |
+
feat = feat.cpu().numpy()
|
82 |
+
|
83 |
+
if self.global_rank == 0:
|
84 |
+
with h5py.File(self.args.save_root/"knowledge_db.hdf5", "a") as f:
|
85 |
+
g = f.create_group(str(batch_idx))
|
86 |
+
g.create_dataset("feature", data=feat, compression="gzip")
|
87 |
+
g.create_dataset("text", data=text, compression="gzip")
|
88 |
+
|
89 |
+
def validation_epoch_end(self, outputs):
|
90 |
+
if self.global_rank == 0:
|
91 |
+
knowledge_db = TextDB(self.args.save_root/"knowledge_db.hdf5")
|
92 |
+
feat = knowledge_db.feature
|
93 |
+
|
94 |
+
if self.args.devices == "-1":
|
95 |
+
num_devices = torch.cuda.device_count()
|
96 |
+
devices = list(range(num_devices))
|
97 |
+
else:
|
98 |
+
devices = [int(x) for x in args.devices.split(",") if x]
|
99 |
+
print(f"CUDA devices: {devices}")
|
100 |
+
|
101 |
+
index = build_faiss_index(feat, gpus=devices)
|
102 |
+
faiss.write_index(index, str(self.args.save_root/"faiss.index"))
|
103 |
+
self.trainer.strategy.barrier()
|
104 |
+
|
105 |
+
if self.args.azcopy and self.global_rank == 0:
|
106 |
+
with open("azcopy/sas_output", "r") as f:
|
107 |
+
sas = f.readline()
|
108 |
+
sas_base, sas_key = sas.split("?")
|
109 |
+
sas = f"{sas_base}/knowledge_db?{sas_key}"
|
110 |
+
|
111 |
+
cmd = ["azcopy/azcopy", "copy", str(self.args.save_root), sas, "--recursive=true"]
|
112 |
+
print(f"start copying data with command {cmd}")
|
113 |
+
ts = time.time()
|
114 |
+
subprocess.run(cmd)
|
115 |
+
print(f"done copying data in {time.time() - ts:.2f} secs")
|
116 |
+
self.trainer.strategy.barrier()
|
117 |
+
|
118 |
+
|
119 |
+
DATASETS = {
|
120 |
+
"object": words.ObjsDataset,
|
121 |
+
"attribute": words.AttrsDataset,
|
122 |
+
"action": words.ActsDataset,
|
123 |
+
"cc3m": cc.CC3MTextDataset,
|
124 |
+
"cc12m": cc.CC12MTextDataset
|
125 |
+
}
|
126 |
+
|
127 |
+
|
128 |
+
def main(args):
|
129 |
+
model, _, _ = open_clip.create_model_and_transforms(
|
130 |
+
args.clip_model, pretrained=CLIP_MODELS[args.clip_model]
|
131 |
+
)
|
132 |
+
trans_txt = open_clip.get_tokenizer(args.clip_model)
|
133 |
+
|
134 |
+
print("load dataset...")
|
135 |
+
dset = DATASETS[args.dataset](Path(args.data_root), trans_txt)
|
136 |
+
loader = DataLoader(
|
137 |
+
dset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
|
138 |
+
drop_last=False, collate_fn=cc.collate_cc_txt
|
139 |
+
)
|
140 |
+
|
141 |
+
print("build model and trainer...")
|
142 |
+
pl_model = TextDBBuilder(args, model)
|
143 |
+
model_summary = callbacks.RichModelSummary()
|
144 |
+
progress_bar = callbacks.TQDMProgressBar(args.refresh_rate)
|
145 |
+
trainer_callbacks = [model_summary, progress_bar]
|
146 |
+
trainer = Trainer(
|
147 |
+
sync_batchnorm=True,
|
148 |
+
precision=16,
|
149 |
+
accelerator='gpu',
|
150 |
+
devices=args.devices,
|
151 |
+
strategy="ddp",
|
152 |
+
default_root_dir=args.save_root,
|
153 |
+
callbacks=trainer_callbacks,
|
154 |
+
limit_val_batches=args.limit_val_batches
|
155 |
+
)
|
156 |
+
|
157 |
+
print("compute textual features...")
|
158 |
+
trainer.validate(pl_model, dataloaders=loader)
|
159 |
+
|
160 |
+
|
161 |
+
CLIP_MODELS = {
|
162 |
+
'ViT-B-32': 'openai',
|
163 |
+
'ViT-B-16': 'openai',
|
164 |
+
'ViT-L-14': 'openai',
|
165 |
+
'ViT-g-14': 'laion2b_s34b_b88k',
|
166 |
+
'ViT-bigG-14': 'laion2b_s39b_b160k',
|
167 |
+
'convnext_xxlarge': 'laion2b_s34b_b82k_augreg_soup',
|
168 |
+
}
|
169 |
+
|
170 |
+
|
171 |
+
if __name__ == "__main__":
|
172 |
+
parser = argparse.ArgumentParser(description="Build knowledge database of words")
|
173 |
+
parser = Trainer.add_argparse_args(parser)
|
174 |
+
parser.add_argument(
|
175 |
+
"--dataset", type=str, required=True, choices=["object", "attribute", "action", "cc3m", "cc12m"]
|
176 |
+
)
|
177 |
+
parser.add_argument("--data_root", type=str, default="data/conceptnet/conceptnet-assertions-5.7.0.csv")
|
178 |
+
parser.add_argument("--clip_model", type=str, default="ViT-g-14", choices=CLIP_MODELS.keys())
|
179 |
+
parser.add_argument("--bs", type=int, default=2**10)
|
180 |
+
parser.add_argument("--num_workers", type=int, default=7)
|
181 |
+
parser.add_argument("--seed", type=int, default=12345)
|
182 |
+
parser.add_argument("--refresh_rate", type=int, default=1)
|
183 |
+
parser.add_argument("--azcopy", action="store_true")
|
184 |
+
args = parser.parse_args()
|
185 |
+
|
186 |
+
# feature dir
|
187 |
+
exp_name = f"(dataset-{args.dataset})(clip-model-{args.clip_model})"
|
188 |
+
if args.azcopy:
|
189 |
+
setattr(args, "save_root", Path("azcopy")/"knowledge_db"/exp_name)
|
190 |
+
else:
|
191 |
+
setattr(args, "save_root", Path("output")/"knowledge_db"/exp_name)
|
192 |
+
shutil.rmtree(args.save_root, ignore_errors=True)
|
193 |
+
args.save_root.mkdir(parents=True, exist_ok=True)
|
194 |
+
|
195 |
+
print(args)
|
196 |
+
seed_everything(args.seed, workers=True)
|
197 |
+
main(args)
|
knowledge/transforms.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
from torchvision.transforms import functional as F
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
def five_crop(image, ratio=0.6):
|
7 |
+
w, h = image.size
|
8 |
+
hw = (h*ratio, w*ratio)
|
9 |
+
|
10 |
+
return F.five_crop(image, hw)
|
11 |
+
|
12 |
+
def nine_crop(image, ratio=0.4):
|
13 |
+
w, h = image.size
|
14 |
+
|
15 |
+
t = (0, int((0.5-ratio/2)*h), int((1.0 - ratio)*h))
|
16 |
+
b = (int(ratio*h), int((0.5+ratio/2)*h), h)
|
17 |
+
l = (0, int((0.5-ratio/2)*w), int((1.0 - ratio)*w))
|
18 |
+
r = (int(ratio*w), int((0.5+ratio/2)*w), w)
|
19 |
+
h, w = list(zip(t, b)), list(zip(l, r))
|
20 |
+
|
21 |
+
images = []
|
22 |
+
for s in itertools.product(h, w):
|
23 |
+
h, w = s
|
24 |
+
top, left = h[0], w[0]
|
25 |
+
height, width = h[1]-h[0], w[1]-w[0]
|
26 |
+
images.append(F.crop(image, top, left, height, width))
|
27 |
+
|
28 |
+
return images
|
29 |
+
|
30 |
+
|
31 |
+
def pre_caption(caption, max_words=None):
|
32 |
+
# Ref: https://github.com/salesforce/LAVIS/blob/main/lavis/processors/blip_processors.py#L49-L68
|
33 |
+
caption = re.sub(
|
34 |
+
r"([.!\"()*#:;~])",
|
35 |
+
" ",
|
36 |
+
caption.lower(),
|
37 |
+
)
|
38 |
+
caption = re.sub(
|
39 |
+
r"\s{2,}",
|
40 |
+
" ",
|
41 |
+
caption,
|
42 |
+
)
|
43 |
+
caption = caption.rstrip("\n")
|
44 |
+
caption = caption.strip(" ")
|
45 |
+
|
46 |
+
# truncate caption
|
47 |
+
caption_words = caption.split(" ")
|
48 |
+
if max_words is not None and len(caption_words) > max_words:
|
49 |
+
caption = " ".join(caption_words[: max_words])
|
50 |
+
|
51 |
+
return caption
|
52 |
+
|
knowledge/utils.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
import numpy as np
|
3 |
+
import time
|
4 |
+
import math
|
5 |
+
import bisect
|
6 |
+
import hashlib
|
7 |
+
import faiss
|
8 |
+
from faiss import StandardGpuResources, index_cpu_to_gpu_multiple_py
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def file_hash(file):
|
13 |
+
# Ref: https://stackoverflow.com/a/59056837
|
14 |
+
with open(file, "rb") as f:
|
15 |
+
hash_fn = hashlib.blake2b()
|
16 |
+
chunk = f.read(8192)
|
17 |
+
while chunk:
|
18 |
+
hash_fn.update(chunk)
|
19 |
+
chunk = f.read(8192)
|
20 |
+
|
21 |
+
return hash_fn.hexdigest()
|
22 |
+
|
23 |
+
|
24 |
+
def build_faiss_index(x, gpus=None):
|
25 |
+
# Ref: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
|
26 |
+
# Ref: https://gist.github.com/mdouze/46d6bbbaabca0b9778fca37ed2bcccf6
|
27 |
+
|
28 |
+
N, dim = x.shape
|
29 |
+
secs = [2**i for i in range(1, 15)]
|
30 |
+
d = secs[bisect.bisect_right(secs, dim) - 1] // 2
|
31 |
+
m = d // 4
|
32 |
+
|
33 |
+
if N <= 60000:
|
34 |
+
index_factory = "Flat"
|
35 |
+
elif N <= 2555904:
|
36 |
+
index_factory = f"IVF{int(8*math.sqrt(N))},Flat"
|
37 |
+
elif N <= 10223616:
|
38 |
+
index_factory = f"OPQ{m}_{d},IVF65536_HNSW32,PQ{m}x4fsr"
|
39 |
+
elif N <= 1e8:
|
40 |
+
index_factory = f"OPQ{m}_{d},IVF262144_HNSW32,PQ{m}x4fsr"
|
41 |
+
else:
|
42 |
+
index_factory = f"OPQ{m}_{d},IVF1048576_HNSW32,PQ{m}x4fsr"
|
43 |
+
print(f"train {index_factory} index on {N:,} x {dim} data")
|
44 |
+
|
45 |
+
index = faiss.index_factory(dim, index_factory)
|
46 |
+
if gpus is not None and N > 60000:
|
47 |
+
index_ivf = faiss.extract_index_ivf(index)
|
48 |
+
res = []
|
49 |
+
for _ in gpus:
|
50 |
+
r = StandardGpuResources()
|
51 |
+
r.noTempMemory()
|
52 |
+
res.append(r)
|
53 |
+
clustering_index = index_cpu_to_gpu_multiple_py(
|
54 |
+
res, faiss.IndexFlatL2(index_ivf.d), None, gpus
|
55 |
+
)
|
56 |
+
index_ivf.clustering_index = clustering_index
|
57 |
+
|
58 |
+
print("train index...", end="\r")
|
59 |
+
ts = time.time()
|
60 |
+
# commented out for index_factory = "Flat"
|
61 |
+
# assert not index.is_trained
|
62 |
+
index.train(x)
|
63 |
+
assert index.is_trained
|
64 |
+
print(f"train index...done in {time.time() - ts:.2f} secs")
|
65 |
+
|
66 |
+
index.nprobe = 64
|
67 |
+
index.quantizer_efSearch = 32
|
68 |
+
|
69 |
+
return index
|
70 |
+
|
71 |
+
|
72 |
+
def nn_search(query, index, topk, bs=256, desc=None, disable_tqdm=True):
|
73 |
+
idx, dist = [], []
|
74 |
+
N = (len(query) - 1) // bs + 1
|
75 |
+
for i in tqdm(range(N), dynamic_ncols=True, desc=desc, disable=disable_tqdm):
|
76 |
+
D, I = index.search(query[i*bs: (i+1)*bs], topk)
|
77 |
+
idx.append(I)
|
78 |
+
dist.append(D)
|
79 |
+
idx = np.concatenate(idx)
|
80 |
+
dist = np.concatenate(dist)
|
81 |
+
|
82 |
+
return idx, dist
|
83 |
+
|
84 |
+
|
85 |
+
def radius_search(query, index, r, bs=256, desc=None, disable_tqdm=True):
|
86 |
+
idx, dist = [], []
|
87 |
+
N = (len(query) - 1) // bs + 1
|
88 |
+
for i in tqdm(range(N), dynamic_ncols=True, desc=desc, disable=disable_tqdm):
|
89 |
+
L, D, I = index.range_search(query[i*bs: (i+1)*bs], r)
|
90 |
+
idx.extend([I[L[j]:L[j+1]] for j in range(len(L)-1)])
|
91 |
+
dist.extend([D[L[j]:L[j+1]] for j in range(len(L)-1)])
|
92 |
+
|
93 |
+
return idx, dist
|
94 |
+
|
95 |
+
|
96 |
+
@torch.no_grad()
|
97 |
+
def refine_cosine(Xa, Xq, I, device, k=None):
|
98 |
+
if k is not None:
|
99 |
+
assert k <= I.shape[1]
|
100 |
+
else:
|
101 |
+
k = I.shape[1]
|
102 |
+
|
103 |
+
Xi = torch.tensor(Xq, device=device).unsqueeze(1) # bs x 1 x d
|
104 |
+
Xj = torch.tensor(Xa[I.flatten()], device=device) # K * bs x d
|
105 |
+
Xj = Xj.reshape(*I.shape, Xq.shape[-1]) # bs x K x d
|
106 |
+
|
107 |
+
sim = torch.sum(Xi * Xj, dim=-1) # bs x K
|
108 |
+
sort_idx = torch.argsort(sim, dim=1, descending=True).cpu().numpy()
|
109 |
+
I_refined, S_refined = [], []
|
110 |
+
for idx_i, sim_i, sort_i in zip(I, sim.cpu().numpy(), sort_idx):
|
111 |
+
I_refined.append(idx_i[sort_i][:k])
|
112 |
+
S_refined.append(sim_i[sort_i][:k])
|
113 |
+
I_refined = np.stack(I_refined)
|
114 |
+
S_refined = np.stack(S_refined)
|
115 |
+
|
116 |
+
return S_refined, I_refined
|
117 |
+
|
118 |
+
|
119 |
+
def test_nn_search():
|
120 |
+
key = np.random.random((3000000, 512)).astype(np.float32)
|
121 |
+
key /= np.linalg.norm(key, keepdims=True, axis=1)
|
122 |
+
index = build_faiss_index(key, -1)
|
123 |
+
|
124 |
+
query = np.random.random((100000, 512)).astype(np.float32)
|
125 |
+
query /= np.linalg.norm(query, keepdims=True, axis=1)
|
126 |
+
idx_r = nn_search(query, index, r=0.5)
|
127 |
+
idx_k = nn_search(query, index, topk=10)
|
model/.gitattributes
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.hdf5 filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .gptk import get_gptk_model, get_gptk_image_transform
|
model/ckpt/mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fab39af071b1e303f5976936a8662f75eb04952e03fa71bcb93291948892d2fd
|
3 |
+
size 31462530292
|
model/eva_vit.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on EVA, BEIT, timm and DeiT code bases
|
2 |
+
# https://github.com/baaivision/EVA
|
3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
4 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
5 |
+
# https://github.com/facebookresearch/deit/
|
6 |
+
# https://github.com/facebookresearch/dino
|
7 |
+
# --------------------------------------------------------'
|
8 |
+
import math
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint as checkpoint
|
15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
16 |
+
|
17 |
+
import sys
|
18 |
+
sys.path.append("./")
|
19 |
+
from model.utils import download_cached_file
|
20 |
+
|
21 |
+
|
22 |
+
class DropPath(nn.Module):
|
23 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
24 |
+
"""
|
25 |
+
def __init__(self, drop_prob=None):
|
26 |
+
super(DropPath, self).__init__()
|
27 |
+
self.drop_prob = drop_prob
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
return drop_path(x, self.drop_prob, self.training)
|
31 |
+
|
32 |
+
def extra_repr(self) -> str:
|
33 |
+
return 'p={}'.format(self.drop_prob)
|
34 |
+
|
35 |
+
|
36 |
+
class Mlp(nn.Module):
|
37 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
38 |
+
super().__init__()
|
39 |
+
out_features = out_features or in_features
|
40 |
+
hidden_features = hidden_features or in_features
|
41 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
42 |
+
self.act = act_layer()
|
43 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
44 |
+
self.drop = nn.Dropout(drop)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = self.fc1(x)
|
48 |
+
x = self.act(x)
|
49 |
+
# x = self.drop(x)
|
50 |
+
# commit this for the orignal BERT implement
|
51 |
+
x = self.fc2(x)
|
52 |
+
x = self.drop(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class Attention(nn.Module):
|
57 |
+
def __init__(
|
58 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
59 |
+
proj_drop=0., window_size=None, attn_head_dim=None):
|
60 |
+
super().__init__()
|
61 |
+
self.num_heads = num_heads
|
62 |
+
head_dim = dim // num_heads
|
63 |
+
if attn_head_dim is not None:
|
64 |
+
head_dim = attn_head_dim
|
65 |
+
all_head_dim = head_dim * self.num_heads
|
66 |
+
self.scale = qk_scale or head_dim ** -0.5
|
67 |
+
|
68 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
69 |
+
if qkv_bias:
|
70 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
71 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
72 |
+
else:
|
73 |
+
self.q_bias = None
|
74 |
+
self.v_bias = None
|
75 |
+
|
76 |
+
if window_size:
|
77 |
+
self.window_size = window_size
|
78 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
79 |
+
self.relative_position_bias_table = nn.Parameter(
|
80 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
81 |
+
# cls to token & token 2 cls & cls to cls
|
82 |
+
|
83 |
+
# get pair-wise relative position index for each token inside the window
|
84 |
+
coords_h = torch.arange(window_size[0])
|
85 |
+
coords_w = torch.arange(window_size[1])
|
86 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
87 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
88 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
89 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
90 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
91 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
92 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
93 |
+
relative_position_index = \
|
94 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
95 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
96 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
97 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
98 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
99 |
+
|
100 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
101 |
+
else:
|
102 |
+
self.window_size = None
|
103 |
+
self.relative_position_bias_table = None
|
104 |
+
self.relative_position_index = None
|
105 |
+
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
108 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
109 |
+
|
110 |
+
def forward(self, x, rel_pos_bias=None):
|
111 |
+
B, N, C = x.shape
|
112 |
+
qkv_bias = None
|
113 |
+
if self.q_bias is not None:
|
114 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
115 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
116 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
117 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
118 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
119 |
+
|
120 |
+
q = q * self.scale
|
121 |
+
attn = (q @ k.transpose(-2, -1))
|
122 |
+
|
123 |
+
if self.relative_position_bias_table is not None:
|
124 |
+
relative_position_bias = \
|
125 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
126 |
+
self.window_size[0] * self.window_size[1] + 1,
|
127 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
128 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
129 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
130 |
+
|
131 |
+
if rel_pos_bias is not None:
|
132 |
+
attn = attn + rel_pos_bias
|
133 |
+
|
134 |
+
attn = attn.softmax(dim=-1)
|
135 |
+
attn = self.attn_drop(attn)
|
136 |
+
|
137 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
138 |
+
x = self.proj(x)
|
139 |
+
x = self.proj_drop(x)
|
140 |
+
return x
|
141 |
+
|
142 |
+
|
143 |
+
class Block(nn.Module):
|
144 |
+
|
145 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
146 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
147 |
+
window_size=None, attn_head_dim=None):
|
148 |
+
super().__init__()
|
149 |
+
self.norm1 = norm_layer(dim)
|
150 |
+
self.attn = Attention(
|
151 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
152 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
153 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
154 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
155 |
+
self.norm2 = norm_layer(dim)
|
156 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
157 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
158 |
+
|
159 |
+
if init_values is not None and init_values > 0:
|
160 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
161 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
162 |
+
else:
|
163 |
+
self.gamma_1, self.gamma_2 = None, None
|
164 |
+
|
165 |
+
def forward(self, x, rel_pos_bias=None):
|
166 |
+
if self.gamma_1 is None:
|
167 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
168 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
169 |
+
else:
|
170 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
171 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
172 |
+
return x
|
173 |
+
|
174 |
+
|
175 |
+
class PatchEmbed(nn.Module):
|
176 |
+
""" Image to Patch Embedding
|
177 |
+
"""
|
178 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
179 |
+
super().__init__()
|
180 |
+
img_size = to_2tuple(img_size)
|
181 |
+
patch_size = to_2tuple(patch_size)
|
182 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
183 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
184 |
+
self.img_size = img_size
|
185 |
+
self.patch_size = patch_size
|
186 |
+
self.num_patches = num_patches
|
187 |
+
|
188 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
189 |
+
|
190 |
+
def forward(self, x, **kwargs):
|
191 |
+
B, C, H, W = x.shape
|
192 |
+
# FIXME look at relaxing size constraints
|
193 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
194 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
195 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
196 |
+
return x
|
197 |
+
|
198 |
+
|
199 |
+
class RelativePositionBias(nn.Module):
|
200 |
+
|
201 |
+
def __init__(self, window_size, num_heads):
|
202 |
+
super().__init__()
|
203 |
+
self.window_size = window_size
|
204 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
205 |
+
self.relative_position_bias_table = nn.Parameter(
|
206 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
207 |
+
# cls to token & token 2 cls & cls to cls
|
208 |
+
|
209 |
+
# get pair-wise relative position index for each token inside the window
|
210 |
+
coords_h = torch.arange(window_size[0])
|
211 |
+
coords_w = torch.arange(window_size[1])
|
212 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
213 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
214 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
215 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
216 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
217 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
218 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
219 |
+
relative_position_index = \
|
220 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
221 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
222 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
223 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
224 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
225 |
+
|
226 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
227 |
+
|
228 |
+
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
229 |
+
|
230 |
+
def forward(self):
|
231 |
+
relative_position_bias = \
|
232 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
233 |
+
self.window_size[0] * self.window_size[1] + 1,
|
234 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
235 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
236 |
+
|
237 |
+
|
238 |
+
class VisionTransformer(nn.Module):
|
239 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
240 |
+
"""
|
241 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
242 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
243 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
|
244 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
245 |
+
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
|
246 |
+
super().__init__()
|
247 |
+
self.image_size = img_size
|
248 |
+
self.num_classes = num_classes
|
249 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
250 |
+
|
251 |
+
self.patch_embed = PatchEmbed(
|
252 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
253 |
+
num_patches = self.patch_embed.num_patches
|
254 |
+
|
255 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
256 |
+
if use_abs_pos_emb:
|
257 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
258 |
+
else:
|
259 |
+
self.pos_embed = None
|
260 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
261 |
+
|
262 |
+
if use_shared_rel_pos_bias:
|
263 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
264 |
+
else:
|
265 |
+
self.rel_pos_bias = None
|
266 |
+
self.use_checkpoint = use_checkpoint
|
267 |
+
|
268 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
269 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
270 |
+
self.blocks = nn.ModuleList([
|
271 |
+
Block(
|
272 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
273 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
274 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
275 |
+
for i in range(depth)])
|
276 |
+
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
277 |
+
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
278 |
+
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
279 |
+
|
280 |
+
if self.pos_embed is not None:
|
281 |
+
trunc_normal_(self.pos_embed, std=.02)
|
282 |
+
trunc_normal_(self.cls_token, std=.02)
|
283 |
+
# trunc_normal_(self.mask_token, std=.02)
|
284 |
+
# if isinstance(self.head, nn.Linear):
|
285 |
+
# trunc_normal_(self.head.weight, std=.02)
|
286 |
+
self.apply(self._init_weights)
|
287 |
+
self.fix_init_weight()
|
288 |
+
# if isinstance(self.head, nn.Linear):
|
289 |
+
# self.head.weight.data.mul_(init_scale)
|
290 |
+
# self.head.bias.data.mul_(init_scale)
|
291 |
+
|
292 |
+
def fix_init_weight(self):
|
293 |
+
def rescale(param, layer_id):
|
294 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
295 |
+
|
296 |
+
for layer_id, layer in enumerate(self.blocks):
|
297 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
298 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
299 |
+
|
300 |
+
def _init_weights(self, m):
|
301 |
+
if isinstance(m, nn.Linear):
|
302 |
+
trunc_normal_(m.weight, std=.02)
|
303 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
304 |
+
nn.init.constant_(m.bias, 0)
|
305 |
+
elif isinstance(m, nn.LayerNorm):
|
306 |
+
nn.init.constant_(m.bias, 0)
|
307 |
+
nn.init.constant_(m.weight, 1.0)
|
308 |
+
|
309 |
+
def get_classifier(self):
|
310 |
+
return self.head
|
311 |
+
|
312 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
313 |
+
self.num_classes = num_classes
|
314 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
315 |
+
|
316 |
+
def forward_features(self, x):
|
317 |
+
x = self.patch_embed(x)
|
318 |
+
batch_size, seq_len, _ = x.size()
|
319 |
+
|
320 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
321 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
322 |
+
if self.pos_embed is not None:
|
323 |
+
x = x + self.pos_embed
|
324 |
+
x = self.pos_drop(x)
|
325 |
+
|
326 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
327 |
+
for blk in self.blocks:
|
328 |
+
if self.use_checkpoint:
|
329 |
+
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
330 |
+
else:
|
331 |
+
x = blk(x, rel_pos_bias)
|
332 |
+
return x
|
333 |
+
# x = self.norm(x)
|
334 |
+
|
335 |
+
# if self.fc_norm is not None:
|
336 |
+
# t = x[:, 1:, :]
|
337 |
+
# return self.fc_norm(t.mean(1))
|
338 |
+
# else:
|
339 |
+
# return x[:, 0]
|
340 |
+
|
341 |
+
def forward(self, x):
|
342 |
+
x = self.forward_features(x)
|
343 |
+
# x = self.head(x)
|
344 |
+
return x
|
345 |
+
|
346 |
+
def get_intermediate_layers(self, x):
|
347 |
+
x = self.patch_embed(x)
|
348 |
+
batch_size, seq_len, _ = x.size()
|
349 |
+
|
350 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
351 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
352 |
+
if self.pos_embed is not None:
|
353 |
+
x = x + self.pos_embed
|
354 |
+
x = self.pos_drop(x)
|
355 |
+
|
356 |
+
features = []
|
357 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
358 |
+
for blk in self.blocks:
|
359 |
+
x = blk(x, rel_pos_bias)
|
360 |
+
features.append(x)
|
361 |
+
|
362 |
+
return features
|
363 |
+
|
364 |
+
|
365 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
366 |
+
if 'pos_embed' in checkpoint_model:
|
367 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
|
368 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
369 |
+
num_patches = model.patch_embed.num_patches
|
370 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
371 |
+
# height (== width) for the checkpoint position embedding
|
372 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
373 |
+
# height (== width) for the new position embedding
|
374 |
+
new_size = int(num_patches ** 0.5)
|
375 |
+
# class_token and dist_token are kept unchanged
|
376 |
+
if orig_size != new_size:
|
377 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
378 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
379 |
+
# only the position tokens are interpolated
|
380 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
381 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
382 |
+
pos_tokens = torch.nn.functional.interpolate(
|
383 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
384 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
385 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
386 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
387 |
+
|
388 |
+
|
389 |
+
def convert_weights_to_fp16(model: nn.Module):
|
390 |
+
"""Convert applicable model parameters to fp16"""
|
391 |
+
|
392 |
+
def _convert_weights_to_fp16(l):
|
393 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
394 |
+
l.weight.data = l.weight.data.half()
|
395 |
+
if l.bias is not None:
|
396 |
+
l.bias.data = l.bias.data.half()
|
397 |
+
|
398 |
+
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
399 |
+
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
400 |
+
# tensor = getattr(l, attr)
|
401 |
+
# if tensor is not None:
|
402 |
+
# tensor.data = tensor.data.half()
|
403 |
+
|
404 |
+
model.apply(_convert_weights_to_fp16)
|
405 |
+
|
406 |
+
|
407 |
+
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
|
408 |
+
model = VisionTransformer(
|
409 |
+
img_size=img_size,
|
410 |
+
patch_size=14,
|
411 |
+
use_mean_pooling=False,
|
412 |
+
embed_dim=1408,
|
413 |
+
depth=39,
|
414 |
+
num_heads=1408//88,
|
415 |
+
mlp_ratio=4.3637,
|
416 |
+
qkv_bias=True,
|
417 |
+
drop_path_rate=drop_path_rate,
|
418 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
419 |
+
use_checkpoint=use_checkpoint,
|
420 |
+
)
|
421 |
+
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
422 |
+
cached_file = download_cached_file(
|
423 |
+
url, check_hash=False, progress=True
|
424 |
+
)
|
425 |
+
state_dict = torch.load(cached_file, map_location="cpu")
|
426 |
+
interpolate_pos_embed(model,state_dict)
|
427 |
+
|
428 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
429 |
+
# print(incompatible_keys)
|
430 |
+
|
431 |
+
if precision == "fp16":
|
432 |
+
# model.to("cuda")
|
433 |
+
convert_weights_to_fp16(model)
|
434 |
+
return model
|
model/gptk-7b.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, salesforce.com, inc.
|
2 |
+
# All rights reserved.
|
3 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
4 |
+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
5 |
+
|
6 |
+
arch: instruct_vicuna7b
|
7 |
+
pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
|
8 |
+
|
9 |
+
# vit encoder
|
10 |
+
image_size: 224
|
11 |
+
drop_path_rate: 0
|
12 |
+
use_grad_checkpoint: False
|
13 |
+
vit_precision: "fp16"
|
14 |
+
freeze_vit: True
|
15 |
+
|
16 |
+
# Q-Former
|
17 |
+
num_query_token: 32
|
18 |
+
|
19 |
+
# path to Vicuna checkpoint
|
20 |
+
llm_model: "model/llm/vicuna-7b-v1.1"
|
21 |
+
# llm_model: "lmsys/vicuna-7b-v1.3"
|
22 |
+
# llm_model: "lmsys/vicuna-7b-v1.5"
|
23 |
+
|
24 |
+
# generation configs
|
25 |
+
prompt: ""
|