cwkuo committed
Commit 7962ed0 · 1 Parent(s): fb92e97

implement gpt-k demo

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +5 -0
  2. .vscode/settings.json +6 -0
  3. README.md +1 -1
  4. app.py +387 -0
  5. conversation.py +364 -0
  6. examples/diamond_head.jpg +3 -0
  7. examples/horseshoe_bend.jpg +3 -0
  8. examples/mona_lisa.jpg +3 -0
  9. examples/mona_lisa_dog.jpg +3 -0
  10. examples/titanic.jpg +3 -0
  11. knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index +3 -0
  12. knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt +1 -0
  13. knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 +3 -0
  14. knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy +3 -0
  15. knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index +3 -0
  16. knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt +1 -0
  17. knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 +3 -0
  18. knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy +3 -0
  19. knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index +3 -0
  20. knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt +1 -0
  21. knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 +3 -0
  22. knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy +3 -0
  23. knowledge/__init__.py +2 -0
  24. knowledge/__pycache__/__init__.cpython-37.pyc +0 -0
  25. knowledge/__pycache__/__init__.cpython-38.pyc +0 -0
  26. knowledge/__pycache__/cluster.cpython-38.pyc +0 -0
  27. knowledge/__pycache__/dbscan.cpython-37.pyc +0 -0
  28. knowledge/__pycache__/dbscan.cpython-38.pyc +0 -0
  29. knowledge/__pycache__/image_crops_idx.cpython-38.pyc +0 -0
  30. knowledge/__pycache__/image_tokens_idx.cpython-38.pyc +0 -0
  31. knowledge/__pycache__/revive.cpython-38.pyc +0 -0
  32. knowledge/__pycache__/sentence_db.cpython-37.pyc +0 -0
  33. knowledge/__pycache__/sentence_db.cpython-38.pyc +0 -0
  34. knowledge/__pycache__/sentence_idx.cpython-37.pyc +0 -0
  35. knowledge/__pycache__/sentence_idx.cpython-38.pyc +0 -0
  36. knowledge/__pycache__/text_db.cpython-38.pyc +0 -0
  37. knowledge/__pycache__/utils.cpython-37.pyc +0 -0
  38. knowledge/__pycache__/utils.cpython-38.pyc +0 -0
  39. knowledge/__pycache__/vis_vocab.cpython-37.pyc +0 -0
  40. knowledge/__pycache__/wordnet.cpython-37.pyc +0 -0
  41. knowledge/cluster.py +178 -0
  42. knowledge/retrieve.py +327 -0
  43. knowledge/text_db.py +197 -0
  44. knowledge/transforms.py +52 -0
  45. knowledge/utils.py +127 -0
  46. model/.gitattributes +2 -0
  47. model/__init__.py +1 -0
  48. model/ckpt/mp_rank_00_model_states.pt +3 -0
  49. model/eva_vit.py +434 -0
  50. model/gptk-7b.yaml +25 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ .bin filter=lfs diff=lfs merge=lfs -text
+ .pt filter=lfs diff=lfs merge=lfs -text
+ *.hdf5 filter=lfs diff=lfs merge=lfs -text
+ *.index filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.vscode/settings.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "[python]": {
+         "editor.defaultFormatter": "ms-python.autopep8"
+     },
+     "python.formatting.provider": "none"
+ }
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: K GPT
+ title: GPT-K
  emoji: 🚀
  colorFrom: green
  colorTo: red
app.py ADDED
@@ -0,0 +1,387 @@
1
+ from pathlib import Path
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import gradio as gr
7
+ import requests
8
+ import numpy as np
9
+
10
+ import torch
11
+ import open_clip
12
+ import faiss
13
+ from transformers import TextIteratorStreamer
14
+ from threading import Thread
15
+
16
+ from conversation import default_conversation, conv_templates, Conversation
17
+ from knowledge import TextDB
18
+ from knowledge.transforms import five_crop, nine_crop
19
+ from knowledge.utils import refine_cosine
20
+ from model import get_gptk_model, get_gptk_image_transform
21
+
22
+
23
+ no_change_btn = gr.Button.update()
24
+ enable_btn = gr.Button.update(interactive=True)
25
+ disable_btn = gr.Button.update(interactive=False)
26
+ knwl_none = (None, ) * 30
27
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
28
+
29
+
30
+ def violates_moderation(text):
31
+ """
32
+ Check whether the text violates OpenAI moderation API.
33
+ """
34
+ url = "https://api.openai.com/v1/moderations"
35
+ headers = {
36
+ "Content-Type": "application/json",
37
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]
38
+ }
39
+ text = text.replace("\n", "")
40
+ data = "{" + '"input": ' + f'"{text}"' + "}"
41
+ data = data.encode("utf-8")
42
+ try:
43
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
44
+ flagged = ret.json()["results"][0]["flagged"]
45
+ except requests.exceptions.RequestException as e:
46
+ flagged = False
47
+ except KeyError as e:
48
+ flagged = False
49
+
50
+ return flagged
51
+
52
+
53
+ def load_demo():
54
+ state = default_conversation.copy()
55
+ return (state, )
56
+
57
+
58
+ def regenerate(state: Conversation):
59
+ state.messages[-1][-1] = None
60
+ prev_human_msg = state.messages[-2]
61
+ if type(prev_human_msg[1]) in (tuple, list):
62
+ prev_human_msg[1] = prev_human_msg[1][:2]
63
+ state.skip_next = False
64
+
65
+ return (state, state.to_gradio_chatbot(), "", None, disable_btn, disable_btn)
66
+
67
+
68
+ def clear_history():
69
+ state = default_conversation.copy()
70
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2 + knwl_none
71
+
72
+
73
+ def add_text(state: Conversation, text, image):
74
+ if len(text) <= 0 and image is None:
75
+ state.skip_next = True
76
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 2
77
+
78
+ if violates_moderation(text):
79
+ state.skip_next = True
80
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 2
81
+
82
+ text = (text, image)
83
+ if len(state.get_images(return_pil=True)) > 0:
84
+ state = default_conversation.copy()
85
+ state.append_message(state.roles[0], text)
86
+ state.append_message(state.roles[1], None)
87
+ state.skip_next = False
88
+
89
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
90
+
91
+
92
+ def search(image, pos, topk, knwl_db, knwl_idx):
93
+ with torch.cuda.amp.autocast():
94
+ image = query_trans(image).unsqueeze(0).to(device)
95
+ query = query_enc.encode_image(image, normalize=True)
96
+ query = query.cpu().numpy()
97
+
98
+ _, I = knwl_idx.search(query, 4*topk)
99
+ score, I = refine_cosine(knwl_db.feature, query, I, device, topk)
100
+ score, I = score.flatten(), I.flatten()
101
+ embd, text = knwl_db[I]
102
+ pos = np.full((topk, ), fill_value=pos)
103
+
104
+ query = torch.FloatTensor(query).unsqueeze(0).to(device)
105
+ embd = torch.FloatTensor(embd).unsqueeze(0).to(device)
106
+ pos = torch.LongTensor(pos).unsqueeze(0).to(device)
107
+ score = torch.FloatTensor(score).unsqueeze(0).to(device)
108
+
109
+ return query, embd, pos, score, text
110
+
111
+
112
+ def retrieve_knowledge(image):
113
+ knwl_embd = {}
114
+ knwl_text = {}
115
+ for query_type, topk_q in topk.items():
116
+ if topk_q == 0: continue
117
+
118
+ if query_type == "whole":
119
+ images = [image, ]
120
+ knwl_text[query_type] = {i: {} for i in range(1)}
121
+ elif query_type == "five":
122
+ images = five_crop(image)
123
+ knwl_text[query_type] = {i: {} for i in range(5)}
124
+ elif query_type == "nine":
125
+ images = nine_crop(image)
126
+ knwl_text[query_type] = {i: {} for i in range(9)}
127
+ else:
128
+ raise ValueError
129
+
130
+ knwl_embd[query_type] = {}
131
+ for knwl_type, (knwl_db_t, knwl_idx_t) in knwl_db.items():
132
+ query, embed, pos, score = [], [], [], []
133
+ for i, img in enumerate(images):
134
+ query_i, embed_i, pos_i, score_i, text_i = search(
135
+ img, i, topk_q, knwl_db_t, knwl_idx_t
136
+ )
137
+ query.append(query_i)
138
+ embed.append(embed_i)
139
+ pos.append(pos_i)
140
+ score.append(score_i)
141
+ knwl_text[query_type][i][knwl_type] = text_i
142
+
143
+ query = torch.cat(query, dim=1)
144
+ embed = torch.cat(embed, dim=1)
145
+ pos = torch.cat(pos, dim=1)
146
+ score = torch.cat(score, dim=1)
147
+
148
+ knwl_embd[query_type][knwl_type] = {
149
+ "embed": embed, "query": query, "pos": pos, "score": score
150
+ }
151
+
152
+ return knwl_embd, knwl_text
153
+
154
+
155
+ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
156
+ if state.skip_next: # This generate call is skipped due to invalid inputs
157
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2 + knwl_none
158
+ return
159
+
160
+ if len(state.messages) == state.offset + 2: # First round of conversation
161
+ new_state = conv_templates["gptk"].copy()
162
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
163
+ new_state.append_message(new_state.roles[1], None)
164
+ state = new_state
165
+
166
+ # retrieve and visualize knowledge
167
+ image = state.get_images(return_pil=True)[0]
168
+ if bool(add_knwl):
169
+ knwl_embd, knwl = retrieve_knowledge(image)
170
+ knwl_img, knwl_txt, idx = [None, ] * 15, ["", ] * 15, 0
171
+ for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
172
+ if query_type == "whole":
173
+ images = [image, ]
174
+ elif query_type == "five":
175
+ images = five_crop(image)
176
+ elif query_type == "nine":
177
+ images = nine_crop(image)
178
+
179
+ for pos in range(knwl_pos):
180
+ try:
181
+ txt = ""
182
+ for k, v in knwl[query_type][pos].items():
183
+ v = ", ".join([vi.replace("_", " ") for vi in v])
184
+ txt += f"**[{k.upper()}]:** {v}\n\n"
185
+ knwl_txt[idx] += txt
186
+ knwl_img[idx] = images[pos]
187
+ except KeyError:
188
+ pass
189
+ idx += 1
190
+ knwl_vis = tuple(knwl_img + knwl_txt)
191
+ else:
192
+ knwl_embd = None
193
+ knwl_vis = knwl_none
194
+
195
+ # generate output
196
+ prompt = state.get_prompt()
197
+ prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
198
+ image_pt = image_trans(image).to(device).unsqueeze(0)
199
+ samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
200
+
201
+ if bool(do_beam_search):
202
+ new_text = gptk_model.generate(
203
+ samples=samples,
204
+ use_nucleus_sampling=bool(do_sampling),
205
+ max_length=min(int(max_new_tokens), 1024),
206
+ top_p=float(top_p),
207
+ temperature=float(temperature),
208
+ auto_cast=True
209
+ )[0]
210
+ streamer = [new_text, ]
211
+ else:
212
+ streamer = TextIteratorStreamer(
213
+ gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
214
+ )
215
+ thread = Thread(
216
+ target=gptk_model.generate,
217
+ kwargs=dict(
218
+ samples=samples,
219
+ use_nucleus_sampling=bool(do_sampling),
220
+ max_length=min(int(max_new_tokens), 1024),
221
+ top_p=float(top_p),
222
+ temperature=float(temperature),
223
+ streamer=streamer,
224
+ num_beams=1,
225
+ auto_cast=True
226
+ )
227
+ )
228
+ thread.start()
229
+
230
+ generated_text = ""
231
+ for new_text in streamer:
232
+ generated_text += new_text
233
+ state.messages[-1][-1] = generated_text + "▌"
234
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2 + knwl_vis
235
+ time.sleep(0.03)
236
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
237
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2 + knwl_vis
238
+
239
+
240
+ title_markdown = ("""
241
+ # GPT-K: Knowledge Augmented Vision-and-Language Assistant
242
+ """)
243
+
244
+ tos_markdown = ("""
245
+ ### Terms of use
246
+ By using this service, users are required to agree to the following terms:
247
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
248
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
249
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
250
+ """)
251
+
252
+ learn_more_markdown = ("""
253
+ ### License
254
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
255
+ """)
256
+
257
+
258
+ def build_demo():
259
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
260
+ imagebox = gr.Image(type="pil")
261
+ state = gr.State()
262
+
263
+ with gr.Blocks(title="GPT-K", theme=gr.themes.Base()) as demo:
264
+ gr.Markdown(title_markdown)
265
+ with gr.Row():
266
+ with gr.Column(scale=3):
267
+ gr.Examples(examples=[
268
+ ["examples/mona_lisa.jpg", "Discuss the historical impact and the significance of this painting in the art world."],
269
+ ["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
270
+ ["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
271
+ ["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
272
+ ["examples/titanic.jpg", "What happen in the scene in this movie?"],
273
+ ], inputs=[imagebox, textbox])
274
+
275
+ imagebox.render()
276
+ textbox.render()
277
+ with gr.Column():
278
+ submit_btn = gr.Button(value="📝 Submit")
279
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
280
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
281
+
282
+ with gr.Accordion("Parameters", open=True):
283
+ with gr.Row():
284
+ add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
285
+ do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
286
+ do_beam_search = gr.Checkbox(value=False, interactive=True, label="Beam search")
287
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
288
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
289
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
290
+
291
+ with gr.Column(scale=6):
292
+ chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
293
+
294
+ gr.Markdown("Retrieved Knowledge")
295
+ knwl_img, knwl_txt = [], []
296
+ for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
297
+ with gr.Tab(query_type):
298
+ for p in range(knwl_pos):
299
+ with gr.Tab(str(p)):
300
+ with gr.Row():
301
+ with gr.Column(scale=1):
302
+ knwl_img.append(gr.Image(type="pil", show_label=False, interactive=False))
303
+ with gr.Column(scale=7):
304
+ knwl_txt.append(gr.Markdown())
305
+ knwl_vis = knwl_img + knwl_txt
306
+
307
+ gr.Markdown(tos_markdown)
308
+ gr.Markdown(learn_more_markdown)
309
+
310
+ # Register listeners
311
+ btn_list = [regenerate_btn, clear_btn]
312
+ regenerate_btn.click(
313
+ regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
314
+ ).then(
315
+ generate,
316
+ [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
317
+ [state, chatbot] + btn_list + knwl_vis
318
+ )
319
+
320
+ clear_btn.click(
321
+ clear_history, None, [state, chatbot, textbox, imagebox] + btn_list + knwl_vis
322
+ )
323
+
324
+ textbox.submit(
325
+ add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
326
+ ).then(
327
+ generate,
328
+ [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
329
+ [state, chatbot] + btn_list + knwl_vis
330
+ )
331
+
332
+ submit_btn.click(
333
+ add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
334
+ ).then(
335
+ generate,
336
+ [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
337
+ [state, chatbot] + btn_list + knwl_vis
338
+ )
339
+
340
+ demo.load(load_demo, None, [state, ])
341
+
342
+ return demo
343
+
344
+
345
+ def build_model():
346
+ if torch.cuda.is_available():
347
+ device = torch.device("cuda")
348
+ else:
349
+ device = torch.device("cpu")
350
+
351
+ query_enc, _, query_trans = open_clip.create_model_and_transforms(
352
+ "ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16'
353
+ )
354
+ query_enc = query_enc.to(device).eval()
355
+
356
+ def get_knwl(knowledge_db):
357
+ knwl_db = TextDB(Path(knowledge_db)/"knowledge_db.hdf5")
358
+ knwl_idx = faiss.read_index(str(Path(knowledge_db)/"faiss.index"))
359
+ knwl_idx.add(knwl_db.feature)
360
+
361
+ return knwl_db, knwl_idx
362
+
363
+ knwl_db = {
364
+ "obj": get_knwl('knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
365
+ "act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
366
+ "attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
367
+ }
368
+ d_knwl = knwl_db["obj"][0].feature.shape[-1]
369
+
370
+ _, image_trans = get_gptk_image_transform()
371
+ topk = {"whole": 60, "five": 24, "nine": 16}
372
+ gptk_model = get_gptk_model(d_knwl=d_knwl, topk=topk)
373
+ gptk_ckpt = "model/ckpt/mp_rank_00_model_states.pt"
374
+ gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
375
+ gptk_ckpt = {
376
+ ".".join(k.split(".")[2:]): v
377
+ for k, v in gptk_ckpt["module"].items()
378
+ }
379
+ gptk_model.load_state_dict(gptk_ckpt)
380
+ gptk_model = gptk_model.to(device).eval()
381
+
382
+ return knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device
383
+
384
+
385
+ knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device = build_model()
386
+ demo = build_demo()
387
+ demo.queue().launch()
conversation.py ADDED
@@ -0,0 +1,364 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+
14
+
15
+ @dataclasses.dataclass
16
+ class Conversation:
17
+ """A class that keeps all conversation history."""
18
+ system: str
19
+ roles: List[str]
20
+ messages: List[List[str]]
21
+ offset: int
22
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23
+ sep: str = "###"
24
+ sep2: str = None
25
+ version: str = "Unknown"
26
+
27
+ skip_next: bool = False
28
+
29
+ def get_prompt(self):
30
+ messages = self.messages
31
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
32
+ messages = self.messages.copy()
33
+ init_role, init_msg = messages[0].copy()
34
+ init_msg = init_msg[0].replace("<image>", "").strip()
35
+ if 'mmtag' in self.version:
36
+ messages[0] = (init_role, init_msg)
37
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
38
+ messages.insert(1, (self.roles[1], "Received."))
39
+ else:
40
+ messages[0] = (init_role, "<image>\n" + init_msg)
41
+
42
+ if self.sep_style == SeparatorStyle.SINGLE:
43
+ ret = self.system + self.sep
44
+ for role, message in messages:
45
+ if message:
46
+ if type(message) is tuple:
47
+ message, _, _ = message
48
+ ret += role + ": " + message + self.sep
49
+ else:
50
+ ret += role + ":"
51
+ elif self.sep_style == SeparatorStyle.TWO:
52
+ seps = [self.sep, self.sep2]
53
+ ret = self.system + seps[0]
54
+ for i, (role, message) in enumerate(messages):
55
+ if message:
56
+ if type(message) is tuple:
57
+ message, _, _ = message
58
+ ret += role + ": " + message + seps[i % 2]
59
+ else:
60
+ ret += role + ":"
61
+ elif self.sep_style == SeparatorStyle.MPT:
62
+ ret = self.system + self.sep
63
+ for role, message in messages:
64
+ if message:
65
+ if type(message) is tuple:
66
+ message, _, _ = message
67
+ ret += role + message + self.sep
68
+ else:
69
+ ret += role
70
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
71
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
72
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
73
+ ret = ""
74
+
75
+ for i, (role, message) in enumerate(messages):
76
+ if i == 0:
77
+ assert message, "first message should not be none"
78
+ assert role == self.roles[0], "first message should come from user"
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ if i == 0: message = wrap_sys(self.system) + message
83
+ if i % 2 == 0:
84
+ message = wrap_inst(message)
85
+ ret += self.sep + message
86
+ else:
87
+ ret += " " + message + " " + self.sep2
88
+ else:
89
+ ret += ""
90
+ ret = ret.lstrip(self.sep)
91
+ elif self.sep_style == SeparatorStyle.PLAIN:
92
+ seps = [self.sep, self.sep2]
93
+ ret = self.system
94
+ for i, (role, message) in enumerate(messages):
95
+ if message:
96
+ if type(message) is tuple:
97
+ message, _, _ = message
98
+ ret += message + seps[i % 2]
99
+ else:
100
+ ret += ""
101
+ else:
102
+ raise ValueError(f"Invalid style: {self.sep_style}")
103
+
104
+ return ret
105
+
106
+ def append_message(self, role, message):
107
+ self.messages.append([role, message])
108
+
109
+ def get_images(self, return_pil=False):
110
+ images = []
111
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
112
+ if i % 2 == 0:
113
+ if type(msg) is tuple:
114
+ image = msg[1].convert('RGB')
115
+ if return_pil:
116
+ images.append(image)
117
+ else:
118
+ import base64
119
+ from io import BytesIO
120
+ buffered = BytesIO()
121
+ image.save(buffered, format="PNG")
122
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
123
+ images.append(img_b64_str)
124
+ return images
125
+
126
+ def to_gradio_chatbot(self):
127
+ ret = []
128
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
129
+ if i % 2 == 0:
130
+ if type(msg) is tuple:
131
+ import base64
132
+ from io import BytesIO
133
+ msg, image = msg
134
+ max_hw, min_hw = max(image.size), min(image.size)
135
+ aspect_ratio = max_hw / min_hw
136
+ max_len, min_len = 800, 400
137
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
138
+ longest_edge = int(shortest_edge * aspect_ratio)
139
+ W, H = image.size
140
+ if H > W:
141
+ H, W = longest_edge, shortest_edge
142
+ else:
143
+ H, W = shortest_edge, longest_edge
144
+ image = image.resize((W, H))
145
+ buffered = BytesIO()
146
+ image.save(buffered, format="JPEG")
147
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
148
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user uploaded image" />'
149
+ ret.append([img_str, None])
150
+ msg = msg.replace('<image>', '').strip()
151
+ if len(msg) > 0:
152
+ ret.append([msg, None])
153
+ else:
154
+ ret.append([msg, None])
155
+ else:
156
+ ret[-1][-1] = msg
157
+ return ret
158
+
159
+ def copy(self):
160
+ return Conversation(
161
+ system=self.system,
162
+ roles=self.roles,
163
+ messages=[[x, y] for x, y in self.messages],
164
+ offset=self.offset,
165
+ sep_style=self.sep_style,
166
+ sep=self.sep,
167
+ sep2=self.sep2,
168
+ version=self.version)
169
+
170
+ def dict(self):
171
+ if len(self.get_images()) > 0:
172
+ return {
173
+ "system": self.system,
174
+ "roles": self.roles,
175
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
176
+ "offset": self.offset,
177
+ "sep": self.sep,
178
+ "sep2": self.sep2,
179
+ }
180
+ return {
181
+ "system": self.system,
182
+ "roles": self.roles,
183
+ "messages": self.messages,
184
+ "offset": self.offset,
185
+ "sep": self.sep,
186
+ "sep2": self.sep2,
187
+ }
188
+
189
+
190
+ conv_gptk = Conversation(
191
+ system="",
192
+ roles=("USER", "ASSISTANT"),
193
+ version="v1",
194
+ messages=(),
195
+ offset=0,
196
+ sep_style=SeparatorStyle.SINGLE,
197
+ sep=""
198
+ )
199
+
200
+
201
+ conv_vicuna_v0 = Conversation(
202
+ system="A chat between a curious human and an artificial intelligence assistant. "
203
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
204
+ roles=("Human", "Assistant"),
205
+ messages=(
206
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
207
+ ("Assistant",
208
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
209
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
210
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
211
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
212
+ "renewable and non-renewable energy sources:\n"
213
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
214
+ "energy sources are finite and will eventually run out.\n"
215
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
216
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
217
+ "and other negative effects.\n"
218
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
219
+ "have lower operational costs than non-renewable sources.\n"
220
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
221
+ "locations than non-renewable sources.\n"
222
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
223
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
224
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
225
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
226
+ ),
227
+ offset=2,
228
+ sep_style=SeparatorStyle.SINGLE,
229
+ sep="###",
230
+ )
231
+
232
+ conv_vicuna_v1 = Conversation(
233
+ system="A chat between a curious user and an artificial intelligence assistant. "
234
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
235
+ roles=("USER", "ASSISTANT"),
236
+ version="v1",
237
+ messages=(),
238
+ offset=0,
239
+ sep_style=SeparatorStyle.TWO,
240
+ sep=" ",
241
+ sep2="</s>",
242
+ )
243
+
244
+ conv_llama_2 = Conversation(
245
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
246
+
247
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
248
+ roles=("USER", "ASSISTANT"),
249
+ version="llama_v2",
250
+ messages=(),
251
+ offset=0,
252
+ sep_style=SeparatorStyle.LLAMA_2,
253
+ sep="<s>",
254
+ sep2="</s>",
255
+ )
256
+
257
+ conv_llava_llama_2 = Conversation(
258
+ system="You are a helpful language and vision assistant. "
259
+ "You are able to understand the visual content that the user provides, "
260
+ "and assist the user with a variety of tasks using natural language.",
261
+ roles=("USER", "ASSISTANT"),
262
+ version="llama_v2",
263
+ messages=(),
264
+ offset=0,
265
+ sep_style=SeparatorStyle.LLAMA_2,
266
+ sep="<s>",
267
+ sep2="</s>",
268
+ )
269
+
270
+ conv_mpt = Conversation(
271
+ system="""<|im_start|>system
272
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
273
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
274
+ version="mpt",
275
+ messages=(),
276
+ offset=0,
277
+ sep_style=SeparatorStyle.MPT,
278
+ sep="<|im_end|>",
279
+ )
280
+
281
+ conv_llava_plain = Conversation(
282
+ system="",
283
+ roles=("", ""),
284
+ messages=(
285
+ ),
286
+ offset=0,
287
+ sep_style=SeparatorStyle.PLAIN,
288
+ sep="\n",
289
+ )
290
+
291
+ conv_llava_v0 = Conversation(
292
+ system="A chat between a curious human and an artificial intelligence assistant. "
293
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
294
+ roles=("Human", "Assistant"),
295
+ messages=(
296
+ ("Human", "Hi!"),
297
+ ("Assistant", "Hi there! How can I help you today?")
298
+ ),
299
+ offset=2,
300
+ sep_style=SeparatorStyle.SINGLE,
301
+ sep="###",
302
+ )
303
+
304
+ conv_llava_v0_mmtag = Conversation(
305
+ system="A chat between a curious user and an artificial intelligence assistant. "
306
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
307
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
308
+ roles=("Human", "Assistant"),
309
+ messages=(
310
+ ),
311
+ offset=0,
312
+ sep_style=SeparatorStyle.SINGLE,
313
+ sep="###",
314
+ version="v0_mmtag",
315
+ )
316
+
317
+ conv_llava_v1 = Conversation(
318
+ system="A chat between a curious human and an artificial intelligence assistant. "
319
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
320
+ roles=("USER", "ASSISTANT"),
321
+ version="v1",
322
+ messages=(),
323
+ offset=0,
324
+ sep_style=SeparatorStyle.TWO,
325
+ sep=" ",
326
+ sep2="</s>",
327
+ )
328
+
329
+ conv_llava_v1_mmtag = Conversation(
330
+ system="A chat between a curious user and an artificial intelligence assistant. "
331
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
332
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
333
+ roles=("USER", "ASSISTANT"),
334
+ messages=(),
335
+ offset=0,
336
+ sep_style=SeparatorStyle.TWO,
337
+ sep=" ",
338
+ sep2="</s>",
339
+ version="v1_mmtag",
340
+ )
341
+
342
+ default_conversation = conv_vicuna_v0
343
+ conv_templates = {
344
+ "default": conv_vicuna_v0,
345
+ "v0": conv_vicuna_v0,
346
+ "v1": conv_vicuna_v1,
347
+ "vicuna_v1": conv_vicuna_v1,
348
+ "llama_2": conv_llama_2,
349
+ "gptk": conv_gptk,
350
+
351
+ "plain": conv_llava_plain,
352
+ "v0_plain": conv_llava_plain,
353
+ "llava_v0": conv_llava_v0,
354
+ "v0_mmtag": conv_llava_v0_mmtag,
355
+ "llava_v1": conv_llava_v1,
356
+ "v1_mmtag": conv_llava_v1_mmtag,
357
+ "llava_llama_2": conv_llava_llama_2,
358
+
359
+ "mpt": conv_mpt,
360
+ }
361
+
362
+
363
+ if __name__ == "__main__":
364
+ print(default_conversation.get_prompt())
examples/diamond_head.jpg ADDED

Git LFS Details

  • SHA256: 33d2f8ebdcde47a8a3cef6af8baa13cbbfc148a25dc869c081f0c4bc4d5522b1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
examples/horseshoe_bend.jpg ADDED

Git LFS Details

  • SHA256: 749c3c49813a870440d101c482ec374c3fe0481a0ac281be062f4610760d75e7
  • Pointer size: 130 Bytes
  • Size of remote file: 41.1 kB
examples/mona_lisa.jpg ADDED

Git LFS Details

  • SHA256: fc9c58de87644926b98da728d809ba9fb9453c93d58c64bff4049f784ea39623
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
examples/mona_lisa_dog.jpg ADDED

Git LFS Details

  • SHA256: 992bfdc88a772a7a273ddd00bb502dbf44ceb9c07ae7b54fc0e537a1c534f41b
  • Pointer size: 131 Bytes
  • Size of remote file: 458 kB
examples/titanic.jpg ADDED

Git LFS Details

  • SHA256: e730a4a2d3efd7a99d5e120d22000cc51cf81176e32aa677fd2be1ea8dfb4a63
  • Pointer size: 131 Bytes
  • Size of remote file: 439 kB
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb05eb3ab8b8e775c1e10ab21a4f8d409b77a47ffacbc606050c2055bd78549a
+ size 45
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt ADDED
@@ -0,0 +1 @@
+ n_samples = 148,620; n_clusters = 43,296; noise_ratio = 0.000%
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6281557260322cacbfbe58d710e3dd537e823d6d6565da7c9fea27e30ced5e31
+ size 166074480
knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:881bf21972ffb9a9155d185282530a75a4ca4ffdb75c8a05d38dda901c0f366c
+ size 1189088
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1efe5c6accd575c85403aaeccaf24c6fb1cfff05bd6a0f1ecdbdbc0ce0a5befa
+ size 9093259
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt ADDED
@@ -0,0 +1 @@
+ n_samples = 191,836; n_clusters = 77,073; noise_ratio = 0.000%
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:51449f86c49a1651debdaf7ec1b4c1020db911785bd5f51e0766a4bfefe1897f
+ size 295832959
knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6cf9a479e50595e52e593f961d4f3dcc822d9c0caf097fed3498a64c175f7e2c
+ size 1534816
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/faiss.index ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ece63b94bf3672252b77fbbf47a3070a378280ef3eafb682f99340fc74e1d096
+ size 18702475
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/info.txt ADDED
@@ -0,0 +1 @@
+ n_samples = 770,808; n_clusters = 325,813; noise_ratio = 0.000%
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/knowledge_db.hdf5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9dcca3e4560c724f42128b8d476dd28ad0305ad66125213050c7fec7715d6a8b
+ size 1251033850
knowledge/(dataset-object)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)/labels.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c79a747b0551e46056391ad988317604dd29a8905acb3167127550dcc6b90890
+ size 6166592
knowledge/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .text_db import TextDB
+ from .retrieve import *
knowledge/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (254 Bytes)
 
knowledge/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (254 Bytes)
 
knowledge/__pycache__/cluster.cpython-38.pyc ADDED
Binary file (5.12 kB)
 
knowledge/__pycache__/dbscan.cpython-37.pyc ADDED
Binary file (2.29 kB)
 
knowledge/__pycache__/dbscan.cpython-38.pyc ADDED
Binary file (2.32 kB)
 
knowledge/__pycache__/image_crops_idx.cpython-38.pyc ADDED
Binary file (10.8 kB)
 
knowledge/__pycache__/image_tokens_idx.cpython-38.pyc ADDED
Binary file (7.7 kB)
 
knowledge/__pycache__/revive.cpython-38.pyc ADDED
Binary file (2.19 kB)
 
knowledge/__pycache__/sentence_db.cpython-37.pyc ADDED
Binary file (6.01 kB)
 
knowledge/__pycache__/sentence_db.cpython-38.pyc ADDED
Binary file (6.39 kB)
 
knowledge/__pycache__/sentence_idx.cpython-37.pyc ADDED
Binary file (9.12 kB)
 
knowledge/__pycache__/sentence_idx.cpython-38.pyc ADDED
Binary file (9.75 kB)
 
knowledge/__pycache__/text_db.cpython-38.pyc ADDED
Binary file (7.22 kB)
 
knowledge/__pycache__/utils.cpython-37.pyc ADDED
Binary file (3.05 kB)
 
knowledge/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.1 kB)
 
knowledge/__pycache__/vis_vocab.cpython-37.pyc ADDED
Binary file (8.46 kB)
 
knowledge/__pycache__/wordnet.cpython-37.pyc ADDED
Binary file (2.3 kB)
 
knowledge/cluster.py ADDED
@@ -0,0 +1,178 @@
1
+ import argparse
2
+ from pathlib import Path
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import h5py
6
+ import time
7
+
8
+ import faiss
9
+ import torch
10
+ from pytorch_lightning import seed_everything
11
+
12
+ import sys
13
+ sys.path.append('.')
14
+ from knowledge.text_db import TextDB
15
+ from knowledge.utils import nn_search, build_faiss_index, refine_cosine
16
+
17
+
18
+ UNSEEN = -2
19
+ NOISE = -1
20
+
21
+
22
+ def dbscan(X, faiss_index, device, eps=0.1, min_points=1, k=2048, bs=512):
23
+ neighbors = []
24
+ N = (len(X) - 1) // bs + 1
25
+ for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
26
+ Xi = X[i*bs: (i+1)*bs]
27
+ _, I = faiss_index.search(Xi, k*2)
28
+ S, I = refine_cosine(X, Xi, I, device, k)
29
+
30
+ for sim, idx in zip(S, I):
31
+ dist = 1. - sim
32
+ neighbors.append(idx[dist < eps])
33
+
34
+ cluster_id = 0
35
+ n_points = len(X)
36
+ labels = np.array([
37
+ NOISE if len(neighbors[i]) < min_points else UNSEEN
38
+ for i in range(n_points)
39
+ ])
40
+
41
+ with tqdm(total=n_points, dynamic_ncols=True, desc="DBSCAN clustering", mininterval=1.0) as pbar:
42
+ for i in range(n_points):
43
+ if labels[i] == UNSEEN:
44
+ seeds = np.array([i, ])
45
+ labels[seeds] = cluster_id
46
+
47
+ while len(seeds) > 0:
48
+ neighbor_seeds = set()
49
+ for s in seeds:
50
+ n = neighbors[s]
51
+ if len(n) > 0:
52
+ l = np.array(list(set(labels[n])))
53
+ l = l[np.logical_and(l >= 0, l != cluster_id)]
54
+ for li in l:
55
+ labels[labels == li] = cluster_id
56
+
57
+ n = n[labels[n] == UNSEEN]
58
+ neighbor_seeds.update(n)
59
+
60
+ seeds = np.array(list(neighbor_seeds))
61
+ if len(seeds) > 0:
62
+ assert np.all(labels[seeds] == UNSEEN)
63
+ labels[seeds] = cluster_id
64
+
65
+ cluster_id += 1
66
+
67
+ pbar.set_postfix(num_clusters=cluster_id)
68
+ pbar.update()
69
+
70
+ label_set = np.sort(list(set(labels)))
71
+ label_set = label_set[label_set >= 0]
72
+ labels_mapping = {l1: l2 for l2, l1 in enumerate(label_set)}
73
+ labels_mapping[-1] = -1
74
+ labels = np.array([labels_mapping[l] for l in labels])
75
+
76
+ return labels
77
+
78
+
79
+ def extract_clusters(feat, text, labels, faiss_index, device, k=128, bs=8192):
80
+ clusters = {}
81
+ for i, l in enumerate(tqdm(labels, dynamic_ncols=True, desc="Label each sample", mininterval=1.0)):
82
+ if l >= 0:
83
+ try:
84
+ clusters[l]["feat"] += feat[i].astype(np.float64)
85
+ clusters[l]["N"] += 1
86
+ except KeyError:
87
+ clusters[l] = {"feat": feat[i].astype(np.float64), "N": 1}
88
+
89
+ cc = []
90
+ for l in tqdm(list(clusters.keys()), dynamic_ncols=True, desc="Compute cluster centers", mininterval=1.0):
91
+ c = clusters[l]["feat"]/clusters[l]["N"]
92
+ cc.append(c.astype(np.float32))
93
+ cc = np.stack(cc)
94
+ cc /= np.linalg.norm(cc, keepdims=True, axis=-1)
95
+
96
+ idx = []
97
+ N = (len(cc) - 1) // bs + 1
98
+ for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
99
+ cc_i = cc[i*bs: (i+1)*bs]
100
+ _, I = faiss_index.search(cc_i, k)
101
+ _, I = refine_cosine(feat, cc_i, I, device, 1)
102
+ idx.append(I[:, 0])
103
+ idx = np.unique(np.concatenate(idx))
104
+ text = [text[i] for i in idx]
105
+ feat = np.stack([feat[i] for i in idx])
106
+
107
+ return feat, text
108
+
109
+
110
+ if __name__ == "__main__":
111
+ parser = argparse.ArgumentParser(description="Cluster knowledge database using DBSCAN")
112
+ parser.add_argument("--knowledge_db", type=str, required=True)
113
+ parser.add_argument("--seed", type=int, default=12345)
114
+ parser.add_argument("--eps", type=float, default=0.1)
115
+ parser.add_argument("--ms", type=int, default=1)
116
+ parser.add_argument("--ratio", type=float, default=None)
117
+ parser.add_argument("--device", type=int, default=None)
118
+ args = parser.parse_args()
119
+
120
+ # parse exp name
121
+ args.knowledge_db = Path(args.knowledge_db)
122
+ exp_name = args.knowledge_db.parent.name
123
+ exp_name += f"(dbscan)(eps-{args.eps})(ms-{args.ms})"
124
+ save_root = args.knowledge_db.parent.parent/exp_name
125
+ setattr(args, "save_root", save_root)
126
+ args.save_root.mkdir(parents=True, exist_ok=True)
127
+
128
+ args.device = torch.device("cuda", args.device) \
129
+ if args.device is not None else torch.device("cpu")
130
+
131
+ seed_everything(args.seed, workers=True)
132
+ print(args)
133
+
134
+ # load feature, text, and faiss index from knowledge db
135
+ knowledge_db = TextDB(args.knowledge_db)
136
+ feat = knowledge_db.feature.astype(np.float32)
137
+ text = knowledge_db.text
138
+ if args.ratio is not None:
139
+ N = int(len(feat) * args.ratio)
140
+ feat, text = feat[:N], text[:N]
141
+ faiss_index = faiss.read_index(str(args.knowledge_db.parent/"faiss.index"))
142
+ print("Add data to faiss index...", end="\r")
143
+ ts = time.time()
144
+ faiss_index.add(feat)
145
+ print(f"Add data to faiss index...done in {time.time() - ts:.2f} secs")
146
+
147
+ # DBSCAN clustering
148
+ labels_file = args.save_root/"labels.npy"
149
+ if labels_file.exists():
150
+ labels = np.load(labels_file)
151
+ else:
152
+ labels = dbscan(feat, faiss_index, args.device, args.eps, args.ms)
153
+ with open(labels_file, 'wb') as f:
154
+ np.save(f, labels)
155
+
156
+ # extract clusters
157
+ feat, text = extract_clusters(feat, text, labels, faiss_index, args.device)
158
+ with h5py.File(args.save_root/f"knowledge_db.hdf5", "w") as f:
159
+ bs = 65536
160
+ N = (len(feat) - 1) // bs + 1
161
+ for i in tqdm(range(N), dynamic_ncols=True, desc="Saving clustered DB", mininterval=1.0):
162
+ g = f.create_group(str(i))
163
+ g.create_dataset("feature", data=feat[i*bs: (i+1)*bs], compression="gzip")
164
+ g.create_dataset("text", data=text[i*bs: (i+1)*bs], compression="gzip")
165
+
166
+ # build faiss index for the clustered DB
167
+ index = build_faiss_index(feat, gpus=[args.device.index, ])
168
+ faiss.write_index(index, str(args.save_root/"faiss.index"))
169
+
170
+ # some stats
171
+ noise_ratio = np.sum(labels == -1) / len(labels)
172
+ n_clusters, n_samples = len(text), len(labels)
173
+ msg = f"n_samples = {n_samples:,}; n_clusters = {n_clusters:,}; noise_ratio = {noise_ratio*100:.3f}%\n"
174
+ with open(save_root/"info.txt", "w") as f:
175
+ f.write(msg)
176
+ print(msg)
177
+
178
+
knowledge/retrieve.py ADDED
@@ -0,0 +1,327 @@
1
+ import argparse
2
+ from pathlib import Path
3
+ import h5py
4
+ import time
5
+ import shutil
6
+ import numpy as np
7
+ import subprocess
8
+ import time
9
+ from tqdm import tqdm
10
+
11
+ import faiss
12
+ import open_clip
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch.utils.data import DataLoader
16
+ from pytorch_lightning import callbacks
17
+ from pytorch_lightning import Trainer, LightningModule, seed_everything
18
+
19
+ import sys
20
+ sys.path.append('.')
21
+ from dataset import coco, cc, llava
22
+ from knowledge.utils import refine_cosine
23
+ from knowledge import text_db
24
+ from knowledge import TextDB
25
+ from train.utils import ExpName
26
+
27
+
28
+ class ImageCropsIdx:
29
+ def __init__(self, knowledge_idx, topk_w, topk_f, topk_n):
30
+ topk = {"whole": topk_w, "five": topk_f, "nine": topk_n}
31
+ self.topk = {k: v for k, v in topk.items() if v > 0}
32
+
33
+ self.knowledge_idx, self.fdim, self.file_hash = self.load(knowledge_idx, self.topk)
34
+
35
+ def load(self, knowledge_idx, topk):
36
+ with h5py.File(knowledge_idx, "r") as f:
37
+ fdim = f.attrs["fdim"]
38
+ file_hash = f.attrs["file_hash"]
39
+
40
+ knowledge_idx_ = {}
41
+ for i in tqdm(range(len(f)), desc="Load sentence idx", dynamic_ncols=True, mininterval=1.0):
42
+ knowledge_idx_[str(i)] = {"image_ids": f[f"{i}/image_ids"][:]}
43
+ for k, v in topk.items():
44
+ knowledge_idx_[str(i)][k] = {
45
+ "index": f[f"{i}/{k}/index"][:, :, :v],
46
+ "score": f[f"{i}/{k}/score"][:, :, :v],
47
+ "query": f[f"{i}/{k}/query"][:]
48
+ }
49
+
50
+ knowledge_idx = {}
51
+ for i in knowledge_idx_.keys():
52
+ for j, id in enumerate(knowledge_idx_[i]["image_ids"]):
53
+ knowledge_idx[id] = {}
54
+ for k in topk.keys():
55
+ knowledge_idx[id][k] = {
56
+ "index": knowledge_idx_[i][k]["index"][j],
57
+ "score": knowledge_idx_[i][k]["score"][j],
58
+ "query": knowledge_idx_[i][k]["query"][j],
59
+ }
60
+
61
+ return knowledge_idx, fdim, file_hash
62
+
63
+ def __getitem__(self, image_id):
64
+ return self.knowledge_idx[image_id]
65
+
66
+
67
+ class KnowAugImageCrops:
68
+ def __init__(self, knowledge_db: TextDB, knowledge_idx: ImageCropsIdx, return_txt=False):
69
+ self.knowledge_db = knowledge_db
70
+ self.knowledge_idx = knowledge_idx
71
+ assert knowledge_db.file_hash == knowledge_idx.file_hash
72
+
73
+ self.ncrop = {"whole": 1, "five": 5, "nine": 9}
74
+ self.topk = knowledge_idx.topk
75
+ self.fdim = knowledge_idx.fdim
76
+
77
+ self.return_txt = return_txt
78
+
79
+ def __call__(self, image_id):
80
+ ret = {}
81
+ for k in self.topk.keys():
82
+ ki = self.knowledge_idx[image_id][k]["index"].flatten()
83
+ ke, kt = self.knowledge_db[ki]
84
+ kq = self.knowledge_idx[image_id][k]["query"]
85
+ kp = np.tile(np.arange(self.ncrop[k])[:, None], (1, self.topk[k])).flatten()
86
+ ks = self.knowledge_idx[image_id][k]["score"].flatten()
87
+
88
+ ke = torch.FloatTensor(ke)
89
+ kq = torch.FloatTensor(kq)
90
+ kp = torch.LongTensor(kp)
91
+ ks = torch.FloatTensor(ks)
92
+
93
+ ret[k] = {"embed": ke, "query": kq, "pos": kp, "score": ks}
94
+ if self.return_txt:
95
+ ret[k]["text"] = kt
96
+
97
+ return ret
98
+
99
+
100
+ class KnowAugImageCropsCombined:
101
+ def __init__(
102
+ self,
103
+ knwl_aug_obj: KnowAugImageCrops,
104
+ knwl_aug_attr: KnowAugImageCrops,
105
+ knwl_aug_act: KnowAugImageCrops
106
+ ):
107
+ self.knwl_aug_obj = knwl_aug_obj
108
+ self.knwl_aug_act = knwl_aug_act
109
+ self.knwl_aug_attr = knwl_aug_attr
110
+ self.fdim = knwl_aug_obj.fdim
111
+
112
+ def __call__(self, image_id):
113
+ knwl_obj = self.knwl_aug_obj(image_id)
114
+ knwl_attr = self.knwl_aug_attr(image_id)
115
+ knwl_act = self.knwl_aug_act(image_id)
116
+
117
+ ret = {}
118
+ for k in knwl_obj.keys():
119
+ ret[k] = {
120
+ "obj": knwl_obj[k],
121
+ "attr": knwl_attr[k],
122
+ "act": knwl_act[k]
123
+ }
124
+
125
+ return ret
126
+
127
+
128
+ class ImageCropsIdxBuilder(LightningModule):
129
+ def __init__(self, args, model: open_clip.model.CLIP):
130
+ super().__init__()
131
+
132
+ self.args = args
133
+ self.save_root = args.save_root
134
+ self.k = args.k
135
+ self.model = model
136
+
137
+ def on_validation_epoch_start(self):
138
+ if self.global_rank == 0:
139
+ knowledge_db = TextDB(self.args.knowledge_db)
140
+ self.feature = knowledge_db.feature
141
+ self.text = knowledge_db.text
142
+
143
+ self.faiss_index = faiss.read_index(
144
+ str(Path(self.args.knowledge_db).parent/"faiss.index")
145
+ )
146
+ print("\nAdd data to faiss index...", end="\r")
147
+ ts = time.time()
148
+ self.faiss_index.add(self.feature)
149
+ print(f"Add data to faiss index...done in {time.time() - ts:.2f} secs")
150
+
151
+ with h5py.File(self.save_root/"knowledge_idx.hdf5", "a") as f:
152
+ f.attrs["fdim"] = self.feature.shape[-1]
153
+ f.attrs["file_hash"] = knowledge_db.file_hash
154
+
155
+ self.trainer.strategy.barrier()
156
+
157
+ def all_gather_object(self, data):
158
+ if self.trainer.world_size > 1:
159
+ gathered = [None for _ in range(self.trainer.world_size)]
160
+ dist.all_gather_object(gathered, data)
161
+ data = gathered
162
+ else:
163
+ data = [data, ]
164
+
165
+ return data
166
+
167
+ def broadcast_object(self, data, src_rank=0):
168
+ if self.trainer.world_size > 1:
169
+ if self.global_rank == src_rank:
170
+ data_list = [data, ] * self.trainer.world_size
171
+ else:
172
+ data_list = [None, ] * self.trainer.world_size
173
+
174
+ dist.broadcast_object_list(data_list, src=src_rank)
175
+ return data_list[0]
176
+ else:
177
+ return data
178
+
179
+ def search(self, images, topk):
180
+ query = self.model.encode_image(images, normalize=True)
181
+ query = query.cpu().numpy()
182
+ query = self.all_gather_object(query)
183
+ query = np.concatenate(query)
184
+
185
+ if self.global_rank == 0:
186
+ _, I = self.faiss_index.search(query, 4*topk)
187
+ S, I = refine_cosine(self.feature, query, I, self.device, topk)
188
+ else:
189
+ S = I = None
190
+
191
+ return S, I, query
192
+
193
+ def validation_step(self, batch, batch_idx):
194
+ orig_imgs, five_imgs, nine_imgs, ids = batch
195
+
196
+ ids = ids.cpu().numpy()
197
+ ids = np.concatenate(self.all_gather_object(ids))
198
+
199
+ S_w, I_w, Q_w = self.search(orig_imgs, topk=self.k)
200
+
201
+ S_f, I_f, Q_f = [], [], []
202
+ for i in range(five_imgs.shape[1]):
203
+ Si, Ii, Qi = self.search(five_imgs[:, i], topk=self.k)
204
+ S_f.append(Si)
205
+ I_f.append(Ii)
206
+ Q_f.append(Qi)
207
+
208
+ S_n, I_n, Q_n = [], [], []
209
+ for i in range(nine_imgs.shape[1]):
210
+ Si, Ii, Qi = self.search(nine_imgs[:, i], topk=self.k)
211
+ S_n.append(Si)
212
+ I_n.append(Ii)
213
+ Q_n.append(Qi)
214
+
215
+ if self.global_rank == 0:
216
+ S_w, I_w, Q_w = np.expand_dims(S_w, axis=1), np.expand_dims(I_w, axis=1), np.expand_dims(Q_w, axis=1)
217
+ S_f, I_f, Q_f = np.stack(S_f, axis=1), np.stack(I_f, axis=1), np.stack(Q_f, axis=1)
218
+ S_n, I_n, Q_n = np.stack(S_n, axis=1), np.stack(I_n, axis=1), np.stack(Q_n, axis=1)
219
+
220
+ with h5py.File(self.save_root/"knowledge_idx.hdf5", "a") as f:
221
+ g = f.create_group(str(batch_idx))
222
+
223
+ g.create_dataset("image_ids", data=ids.astype(np.int32), compression="gzip")
224
+
225
+ gw = g.create_group("whole")
226
+ gw.create_dataset("index", data=I_w.astype(np.int32), compression="gzip")
227
+ gw.create_dataset("score", data=S_w.astype(np.float32), compression="gzip")
228
+ gw.create_dataset("query", data=Q_w.astype(np.float32), compression="gzip")
229
+
230
+ gf = g.create_group("five")
231
+ gf.create_dataset("index", data=I_f.astype(np.int32), compression="gzip")
232
+ gf.create_dataset("score", data=S_f.astype(np.float32), compression="gzip")
233
+ gf.create_dataset("query", data=Q_f.astype(np.float32), compression="gzip")
234
+
235
+ gn = g.create_group("nine")
236
+ gn.create_dataset("index", data=I_n.astype(np.int32), compression="gzip")
237
+ gn.create_dataset("score", data=S_n.astype(np.float32), compression="gzip")
238
+ gn.create_dataset("query", data=Q_n.astype(np.float32), compression="gzip")
239
+
240
+ def on_validation_epoch_end(self):
241
+ if self.args.azcopy and self.global_rank == 0:
242
+ with open("azcopy/sas_output", "r") as f:
243
+ sas = f.readline()
244
+ sas_base, sas_key = sas.split("?")
245
+ sas = f"{sas_base}/knowledge_idx?{sas_key}"
246
+
247
+ cmd = ["azcopy/azcopy", "copy", str(self.args.save_root), sas, "--recursive=true"]
248
+ print(f"start copying data with command {cmd}")
249
+ ts = time.time()
250
+ subprocess.run(cmd)
251
+ print(f"done copying data in {time.time() - ts:.2f} secs")
252
+
253
+
254
+ def main(args):
255
+ model, _, trans_img = open_clip.create_model_and_transforms(
256
+ args.clip_model, pretrained=text_db.CLIP_MODELS[args.clip_model]
257
+ )
258
+
259
+ print("load query dataset...")
260
+ if "coco" in args.query:
261
+ dset = coco.COCOImageCrops(Path(f"data/{args.query}"), trans=trans_img)
262
+ collate_crops = coco.collate_coco_crops
263
+ elif args.query == "cc3m":
264
+ dset = cc.CC3MImageCrops(Path("data/cc3m_instruct"), trans=trans_img)
265
+ collate_crops = cc.collate_cc_crops
266
+ elif args.query == "llava":
267
+ dset = llava.LLaVAImageCrops(Path("data/llava_bench"), trans=trans_img)
268
+ collate_crops = llava.collate_llava_crops
269
+ else:
270
+ raise ValueError
271
+ loader = DataLoader(
272
+ dset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
273
+ drop_last=False, collate_fn=collate_crops
274
+ )
275
+
276
+ print("build model and trainer...")
277
+ pl_model = ImageCropsIdxBuilder(args, model)
278
+ model_summary = callbacks.RichModelSummary()
279
+ progress_bar = callbacks.TQDMProgressBar(args.refresh_rate)
280
+ trainer_callbacks = [model_summary, progress_bar]
281
+ trainer = Trainer(
282
+ sync_batchnorm=True,
283
+ precision=16,
284
+ accelerator='gpu',
285
+ devices=args.devices,
286
+ strategy="ddp",
287
+ default_root_dir=args.save_root,
288
+ callbacks=trainer_callbacks,
289
+ limit_val_batches=args.limit_val_batches
290
+ )
291
+
292
+ print("retrieve knowledge...")
293
+ trainer.validate(pl_model, dataloaders=loader)
294
+
295
+
296
+ if __name__ == "__main__":
297
+ parser = argparse.ArgumentParser(description='Knowledge retrieval using image crops')
298
+ parser = Trainer.add_argparse_args(parser)
299
+ parser.add_argument('--query', type=str, choices=["coco14", "coco17", "cc3m", "llava"], required=True)
300
+ parser.add_argument('--knowledge_db', type=str, required=True)
301
+ parser.add_argument('--k', type=int, default=128)
302
+ parser.add_argument("--bs", type=int, default=128)
303
+ parser.add_argument("--num_workers", type=int, default=7)
304
+ parser.add_argument("--seed", type=int, default=12345)
305
+ parser.add_argument("--refresh_rate", type=int, default=1)
306
+ parser.add_argument("--azcopy", action="store_true")
307
+ args = parser.parse_args()
308
+
309
+ # parse exp_name
310
+ exp_name = ExpName(f"(query-{args.query})")
311
+ exp_name += Path(args.knowledge_db).parent.name
312
+ if args.azcopy:
313
+ setattr(args, "save_root", Path("azcopy")/str(exp_name))
314
+ else:
315
+ setattr(args, "save_root", Path("output")/"knowledge_idx"/str(exp_name))
316
+ shutil.rmtree(args.save_root, ignore_errors=True)
317
+ args.save_root.mkdir(parents=True, exist_ok=True)
318
+
319
+ # parse model
320
+ model = exp_name.get("clip-model")[1:-1]
321
+ model = model[len("clip-model-"):]
322
+ assert model in text_db.CLIP_MODELS.keys()
323
+ setattr(args, "clip_model", model)
324
+
325
+ print(args)
326
+ seed_everything(args.seed, workers=True)
327
+ main(args)
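
Usage note (not part of this commit): ImageCropsIdxBuilder writes one HDF5 group per validation batch, each holding per-crop "index"/"score"/"query" datasets. Below is a minimal read-back sketch; the output file name and the per-batch grouping are not visible in this diff and are assumptions.

import h5py
import numpy as np

with h5py.File("output/knowledge_idx/knowledge_idx.hdf5", "r") as f:  # assumed path/name
    for batch_id in sorted(f.keys(), key=int):        # one group per validation batch (assumed)
        g = f[batch_id]
        for crop in ("five", "nine"):                  # sub-groups created above
            idx = np.asarray(g[f"{crop}/index"])       # retrieved knowledge ids (int32)
            score = np.asarray(g[f"{crop}/score"])     # retrieval scores (float32)
            query = np.asarray(g[f"{crop}/query"])     # query crop features (float32)
            print(batch_id, crop, idx.shape, score.shape, query.shape)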
knowledge/text_db.py ADDED
@@ -0,0 +1,197 @@
1
+ import argparse
2
+ import itertools
3
+ from pathlib import Path
4
+ import shutil
5
+ import h5py
6
+ import time
7
+ import subprocess
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import codecs
11
+
12
+ import open_clip
13
+ import faiss
14
+ import torch
15
+ import torch.distributed as dist
16
+ from torch.utils.data import DataLoader
17
+ from pytorch_lightning import callbacks
18
+ from pytorch_lightning import Trainer, LightningModule, seed_everything
19
+
20
+ import sys
21
+ sys.path.append("./")
22
+ from dataset import cc, words
23
+ from knowledge.utils import file_hash, build_faiss_index
24
+
25
+
26
+ class TextDB:
27
+ def __init__(self, text_db):
28
+ self.feature, self.text = self.load(text_db)
29
+ self.file_hash = file_hash(text_db)
30
+
31
+ def load(self, text_db):
32
+ with h5py.File(text_db, 'r') as f:
33
+ db_size = 0
34
+ for i in range(len(f)):
35
+ db_size += len(f[f"{i}/feature"])
36
+ _, d = f[f"0/feature"].shape
37
+
38
+ with h5py.File(text_db, 'r') as f:
39
+ feature = np.zeros((db_size, d), dtype=np.float32)
40
+ text = []
41
+ N = 0
42
+ for i in tqdm(range(len(f)), desc="Load text DB", dynamic_ncols=True, mininterval=1.0):
43
+ fi = f[f"{i}/feature"][:]
44
+ feature[N:N+len(fi)] = fi
45
+ N += len(fi)
46
+
47
+ text.extend(f[f"{i}/text"][:])
48
+ text = [codecs.decode(t) for t in text]
49
+
50
+ return feature, text
51
+
52
+ def __getitem__(self, idx):
53
+ f = self.feature[idx]
54
+
55
+ try:
56
+ t = [self.text[i] for i in idx]
57
+ except TypeError:
58
+ t = self.text[idx]
59
+
60
+ return f, t
61
+
62
+
63
+ class TextDBBuilder(LightningModule):
64
+ def __init__(self, args, model: open_clip.model.CLIP):
65
+ super().__init__()
66
+ self.args = args
67
+ self.model = model
68
+
69
+ def validation_step(self, batch, batch_idx):
70
+ token, text = batch
71
+ feat = self.model.encode_text(token, normalize=True)
72
+
73
+ if self.trainer.world_size > 1:
74
+ text_gathered = [None for _ in range(self.trainer.world_size)]
75
+ dist.all_gather_object(text_gathered, text)
76
+ text = list(itertools.chain.from_iterable(text_gathered))
77
+
78
+ feat_gathered = [None for _ in range(self.trainer.world_size)]
79
+ dist.all_gather_object(feat_gathered, feat)
80
+ feat = torch.cat([x.to(self.device) for x in feat_gathered])
81
+ feat = feat.cpu().numpy()
82
+
83
+ if self.global_rank == 0:
84
+ with h5py.File(self.args.save_root/"knowledge_db.hdf5", "a") as f:
85
+ g = f.create_group(str(batch_idx))
86
+ g.create_dataset("feature", data=feat, compression="gzip")
87
+ g.create_dataset("text", data=text, compression="gzip")
88
+
89
+ def validation_epoch_end(self, outputs):
90
+ if self.global_rank == 0:
91
+ knowledge_db = TextDB(self.args.save_root/"knowledge_db.hdf5")
92
+ feat = knowledge_db.feature
93
+
94
+ if self.args.devices == "-1":
95
+ num_devices = torch.cuda.device_count()
96
+ devices = list(range(num_devices))
97
+ else:
98
+ devices = [int(x) for x in self.args.devices.split(",") if x]
99
+ print(f"CUDA devices: {devices}")
100
+
101
+ index = build_faiss_index(feat, gpus=devices)
102
+ faiss.write_index(index, str(self.args.save_root/"faiss.index"))
103
+ self.trainer.strategy.barrier()
104
+
105
+ if self.args.azcopy and self.global_rank == 0:
106
+ with open("azcopy/sas_output", "r") as f:
107
+ sas = f.readline()
108
+ sas_base, sas_key = sas.split("?")
109
+ sas = f"{sas_base}/knowledge_db?{sas_key}"
110
+
111
+ cmd = ["azcopy/azcopy", "copy", str(self.args.save_root), sas, "--recursive=true"]
112
+ print(f"start copying data with command {cmd}")
113
+ ts = time.time()
114
+ subprocess.run(cmd)
115
+ print(f"done copying data in {time.time() - ts:.2f} secs")
116
+ self.trainer.strategy.barrier()
117
+
118
+
119
+ DATASETS = {
120
+ "object": words.ObjsDataset,
121
+ "attribute": words.AttrsDataset,
122
+ "action": words.ActsDataset,
123
+ "cc3m": cc.CC3MTextDataset,
124
+ "cc12m": cc.CC12MTextDataset
125
+ }
126
+
127
+
128
+ def main(args):
129
+ model, _, _ = open_clip.create_model_and_transforms(
130
+ args.clip_model, pretrained=CLIP_MODELS[args.clip_model]
131
+ )
132
+ trans_txt = open_clip.get_tokenizer(args.clip_model)
133
+
134
+ print("load dataset...")
135
+ dset = DATASETS[args.dataset](Path(args.data_root), trans_txt)
136
+ loader = DataLoader(
137
+ dset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
138
+ drop_last=False, collate_fn=cc.collate_cc_txt
139
+ )
140
+
141
+ print("build model and trainer...")
142
+ pl_model = TextDBBuilder(args, model)
143
+ model_summary = callbacks.RichModelSummary()
144
+ progress_bar = callbacks.TQDMProgressBar(args.refresh_rate)
145
+ trainer_callbacks = [model_summary, progress_bar]
146
+ trainer = Trainer(
147
+ sync_batchnorm=True,
148
+ precision=16,
149
+ accelerator='gpu',
150
+ devices=args.devices,
151
+ strategy="ddp",
152
+ default_root_dir=args.save_root,
153
+ callbacks=trainer_callbacks,
154
+ limit_val_batches=args.limit_val_batches
155
+ )
156
+
157
+ print("compute textual features...")
158
+ trainer.validate(pl_model, dataloaders=loader)
159
+
160
+
161
+ CLIP_MODELS = {
162
+ 'ViT-B-32': 'openai',
163
+ 'ViT-B-16': 'openai',
164
+ 'ViT-L-14': 'openai',
165
+ 'ViT-g-14': 'laion2b_s34b_b88k',
166
+ 'ViT-bigG-14': 'laion2b_s39b_b160k',
167
+ 'convnext_xxlarge': 'laion2b_s34b_b82k_augreg_soup',
168
+ }
169
+
170
+
171
+ if __name__ == "__main__":
172
+ parser = argparse.ArgumentParser(description="Build knowledge database of words")
173
+ parser = Trainer.add_argparse_args(parser)
174
+ parser.add_argument(
175
+ "--dataset", type=str, required=True, choices=["object", "attribute", "action", "cc3m", "cc12m"]
176
+ )
177
+ parser.add_argument("--data_root", type=str, default="data/conceptnet/conceptnet-assertions-5.7.0.csv")
178
+ parser.add_argument("--clip_model", type=str, default="ViT-g-14", choices=CLIP_MODELS.keys())
179
+ parser.add_argument("--bs", type=int, default=2**10)
180
+ parser.add_argument("--num_workers", type=int, default=7)
181
+ parser.add_argument("--seed", type=int, default=12345)
182
+ parser.add_argument("--refresh_rate", type=int, default=1)
183
+ parser.add_argument("--azcopy", action="store_true")
184
+ args = parser.parse_args()
185
+
186
+ # feature dir
187
+ exp_name = f"(dataset-{args.dataset})(clip-model-{args.clip_model})"
188
+ if args.azcopy:
189
+ setattr(args, "save_root", Path("azcopy")/"knowledge_db"/exp_name)
190
+ else:
191
+ setattr(args, "save_root", Path("output")/"knowledge_db"/exp_name)
192
+ shutil.rmtree(args.save_root, ignore_errors=True)
193
+ args.save_root.mkdir(parents=True, exist_ok=True)
194
+
195
+ print(args)
196
+ seed_everything(args.seed, workers=True)
197
+ main(args)
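
Usage note (not part of this commit): after a run such as --dataset object --clip_model ViT-g-14, the resulting knowledge_db.hdf5 and faiss.index can be queried with CLIP text features. A rough sketch; the output directory name follows the exp_name pattern built in __main__ and is otherwise an assumption.

from pathlib import Path
import faiss
import open_clip
from knowledge.text_db import TextDB, CLIP_MODELS

db_dir = Path("output/knowledge_db/(dataset-object)(clip-model-ViT-g-14)")  # assumed path
db = TextDB(db_dir / "knowledge_db.hdf5")              # normalized features + decoded text
index = faiss.read_index(str(db_dir / "faiss.index"))

model, _, _ = open_clip.create_model_and_transforms("ViT-g-14", pretrained=CLIP_MODELS["ViT-g-14"])
tokenizer = open_clip.get_tokenizer("ViT-g-14")
q = model.encode_text(tokenizer(["a dog running on the beach"]), normalize=True)
dists, ids = index.search(q.detach().cpu().numpy(), 5)  # L2 distances from the Flat index
print([db.text[i] for i in ids[0]], dists[0])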
knowledge/transforms.py ADDED
@@ -0,0 +1,52 @@
1
+ import itertools
2
+ from torchvision.transforms import functional as F
3
+ import re
4
+
5
+
6
+ def five_crop(image, ratio=0.6):
7
+ w, h = image.size
8
+ hw = (int(h*ratio), int(w*ratio))  # cast to int so F.five_crop receives integer crop sizes
9
+
10
+ return F.five_crop(image, hw)
11
+
12
+ def nine_crop(image, ratio=0.4):
13
+ w, h = image.size
14
+
15
+ t = (0, int((0.5-ratio/2)*h), int((1.0 - ratio)*h))
16
+ b = (int(ratio*h), int((0.5+ratio/2)*h), h)
17
+ l = (0, int((0.5-ratio/2)*w), int((1.0 - ratio)*w))
18
+ r = (int(ratio*w), int((0.5+ratio/2)*w), w)
19
+ h, w = list(zip(t, b)), list(zip(l, r))
20
+
21
+ images = []
22
+ for s in itertools.product(h, w):
23
+ h, w = s
24
+ top, left = h[0], w[0]
25
+ height, width = h[1]-h[0], w[1]-w[0]
26
+ images.append(F.crop(image, top, left, height, width))
27
+
28
+ return images
29
+
30
+
31
+ def pre_caption(caption, max_words=None):
32
+ # Ref: https://github.com/salesforce/LAVIS/blob/main/lavis/processors/blip_processors.py#L49-L68
33
+ caption = re.sub(
34
+ r"([.!\"()*#:;~])",
35
+ " ",
36
+ caption.lower(),
37
+ )
38
+ caption = re.sub(
39
+ r"\s{2,}",
40
+ " ",
41
+ caption,
42
+ )
43
+ caption = caption.rstrip("\n")
44
+ caption = caption.strip(" ")
45
+
46
+ # truncate caption
47
+ caption_words = caption.split(" ")
48
+ if max_words is not None and len(caption_words) > max_words:
49
+ caption = " ".join(caption_words[: max_words])
50
+
51
+ return caption
52
+
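
Usage note (not part of this commit): a quick illustration of the three helpers above; the image path is a placeholder.

from PIL import Image
from knowledge.transforms import five_crop, nine_crop, pre_caption

img = Image.open("examples/titanic.jpg").convert("RGB")   # placeholder image path
crops5 = five_crop(img, ratio=0.6)   # 4 corner crops + 1 center crop
crops9 = nine_crop(img, ratio=0.4)   # 3 x 3 grid of overlapping crops
print(len(crops5), len(crops9))      # 5 9

print(pre_caption('A large ship -- the "Titanic"!!  Sailing away.'))
# -> a large ship -- the titanic sailing away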
knowledge/utils.py ADDED
@@ -0,0 +1,127 @@
1
+ from tqdm import tqdm
2
+ import numpy as np
3
+ import time
4
+ import math
5
+ import bisect
6
+ import hashlib
7
+ import faiss
8
+ from faiss import StandardGpuResources, index_cpu_to_gpu_multiple_py
9
+ import torch
10
+
11
+
12
+ def file_hash(file):
13
+ # Ref: https://stackoverflow.com/a/59056837
14
+ with open(file, "rb") as f:
15
+ hash_fn = hashlib.blake2b()
16
+ chunk = f.read(8192)
17
+ while chunk:
18
+ hash_fn.update(chunk)
19
+ chunk = f.read(8192)
20
+
21
+ return hash_fn.hexdigest()
22
+
23
+
24
+ def build_faiss_index(x, gpus=None):
25
+ # Ref: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
26
+ # Ref: https://gist.github.com/mdouze/46d6bbbaabca0b9778fca37ed2bcccf6
27
+
28
+ N, dim = x.shape
29
+ secs = [2**i for i in range(1, 15)]
30
+ d = secs[bisect.bisect_right(secs, dim) - 1] // 2
31
+ m = d // 4
32
+
33
+ if N <= 60000:
34
+ index_factory = "Flat"
35
+ elif N <= 2555904:
36
+ index_factory = f"IVF{int(8*math.sqrt(N))},Flat"
37
+ elif N <= 10223616:
38
+ index_factory = f"OPQ{m}_{d},IVF65536_HNSW32,PQ{m}x4fsr"
39
+ elif N <= 1e8:
40
+ index_factory = f"OPQ{m}_{d},IVF262144_HNSW32,PQ{m}x4fsr"
41
+ else:
42
+ index_factory = f"OPQ{m}_{d},IVF1048576_HNSW32,PQ{m}x4fsr"
43
+ print(f"train {index_factory} index on {N:,} x {dim} data")
44
+
45
+ index = faiss.index_factory(dim, index_factory)
46
+ if gpus is not None and N > 60000:
47
+ index_ivf = faiss.extract_index_ivf(index)
48
+ res = []
49
+ for _ in gpus:
50
+ r = StandardGpuResources()
51
+ r.noTempMemory()
52
+ res.append(r)
53
+ clustering_index = index_cpu_to_gpu_multiple_py(
54
+ res, faiss.IndexFlatL2(index_ivf.d), None, gpus
55
+ )
56
+ index_ivf.clustering_index = clustering_index
57
+
58
+ print("train index...", end="\r")
59
+ ts = time.time()
60
+ # commented out for index_factory = "Flat"
61
+ # assert not index.is_trained
62
+ index.train(x)
63
+ assert index.is_trained
64
+ print(f"train index...done in {time.time() - ts:.2f} secs")
65
+
66
+ index.nprobe = 64
67
+ index.quantizer_efSearch = 32
68
+
69
+ return index
70
+
71
+
72
+ def nn_search(query, index, topk, bs=256, desc=None, disable_tqdm=True):
73
+ idx, dist = [], []
74
+ N = (len(query) - 1) // bs + 1
75
+ for i in tqdm(range(N), dynamic_ncols=True, desc=desc, disable=disable_tqdm):
76
+ D, I = index.search(query[i*bs: (i+1)*bs], topk)
77
+ idx.append(I)
78
+ dist.append(D)
79
+ idx = np.concatenate(idx)
80
+ dist = np.concatenate(dist)
81
+
82
+ return idx, dist
83
+
84
+
85
+ def radius_search(query, index, r, bs=256, desc=None, disable_tqdm=True):
86
+ idx, dist = [], []
87
+ N = (len(query) - 1) // bs + 1
88
+ for i in tqdm(range(N), dynamic_ncols=True, desc=desc, disable=disable_tqdm):
89
+ L, D, I = index.range_search(query[i*bs: (i+1)*bs], r)
90
+ idx.extend([I[L[j]:L[j+1]] for j in range(len(L)-1)])
91
+ dist.extend([D[L[j]:L[j+1]] for j in range(len(L)-1)])
92
+
93
+ return idx, dist
94
+
95
+
96
+ @torch.no_grad()
97
+ def refine_cosine(Xa, Xq, I, device, k=None):
98
+ if k is not None:
99
+ assert k <= I.shape[1]
100
+ else:
101
+ k = I.shape[1]
102
+
103
+ Xi = torch.tensor(Xq, device=device).unsqueeze(1) # bs x 1 x d
104
+ Xj = torch.tensor(Xa[I.flatten()], device=device) # K * bs x d
105
+ Xj = Xj.reshape(*I.shape, Xq.shape[-1]) # bs x K x d
106
+
107
+ sim = torch.sum(Xi * Xj, dim=-1) # bs x K
108
+ sort_idx = torch.argsort(sim, dim=1, descending=True).cpu().numpy()
109
+ I_refined, S_refined = [], []
110
+ for idx_i, sim_i, sort_i in zip(I, sim.cpu().numpy(), sort_idx):
111
+ I_refined.append(idx_i[sort_i][:k])
112
+ S_refined.append(sim_i[sort_i][:k])
113
+ I_refined = np.stack(I_refined)
114
+ S_refined = np.stack(S_refined)
115
+
116
+ return S_refined, I_refined
117
+
118
+
119
+ def test_nn_search():
120
+ key = np.random.random((3000000, 512)).astype(np.float32)
121
+ key /= np.linalg.norm(key, keepdims=True, axis=1)
122
+ index = build_faiss_index(key, gpus=list(range(torch.cuda.device_count())))  # gpus must be a list of device ids (or None), not -1
123
+
124
+ query = np.random.random((100000, 512)).astype(np.float32)
125
+ query /= np.linalg.norm(query, keepdims=True, axis=1)
126
+ idx_r, dist_r = radius_search(query, index, r=0.5)  # nn_search has no "r" argument
127
+ idx_k, dist_k = nn_search(query, index, topk=10)
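
Usage note (not part of this commit): a minimal end-to-end sketch of the helpers above on random unit vectors, sized so that build_faiss_index takes its exact "Flat" branch.

import numpy as np
import torch
from knowledge.utils import build_faiss_index, nn_search, refine_cosine

keys = np.random.rand(10_000, 256).astype(np.float32)
keys /= np.linalg.norm(keys, axis=1, keepdims=True)
queries = np.random.rand(8, 256).astype(np.float32)
queries /= np.linalg.norm(queries, axis=1, keepdims=True)

index = build_faiss_index(keys)                  # N <= 60000 -> exact Flat (L2) index
idx, dist = nn_search(queries, index, topk=32)   # coarse top-32 by L2 distance
device = "cuda" if torch.cuda.is_available() else "cpu"
sim, idx = refine_cosine(keys, queries, idx, device, k=8)  # re-rank candidates by cosine
print(idx.shape, sim.shape)                      # (8, 8) (8, 8)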
model/.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.hdf5 filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .gptk import get_gptk_model, get_gptk_image_transform
model/ckpt/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fab39af071b1e303f5976936a8662f75eb04952e03fa71bcb93291948892d2fd
3
+ size 31462530292
model/eva_vit.py ADDED
@@ -0,0 +1,434 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+
17
+ import sys
18
+ sys.path.append("./")
19
+ from model.utils import download_cached_file
20
+
21
+
22
+ class DropPath(nn.Module):
23
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
24
+ """
25
+ def __init__(self, drop_prob=None):
26
+ super(DropPath, self).__init__()
27
+ self.drop_prob = drop_prob
28
+
29
+ def forward(self, x):
30
+ return drop_path(x, self.drop_prob, self.training)
31
+
32
+ def extra_repr(self) -> str:
33
+ return 'p={}'.format(self.drop_prob)
34
+
35
+
36
+ class Mlp(nn.Module):
37
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
38
+ super().__init__()
39
+ out_features = out_features or in_features
40
+ hidden_features = hidden_features or in_features
41
+ self.fc1 = nn.Linear(in_features, hidden_features)
42
+ self.act = act_layer()
43
+ self.fc2 = nn.Linear(hidden_features, out_features)
44
+ self.drop = nn.Dropout(drop)
45
+
46
+ def forward(self, x):
47
+ x = self.fc1(x)
48
+ x = self.act(x)
49
+ # x = self.drop(x)
50
+ # commented out to match the original BERT implementation
51
+ x = self.fc2(x)
52
+ x = self.drop(x)
53
+ return x
54
+
55
+
56
+ class Attention(nn.Module):
57
+ def __init__(
58
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
59
+ proj_drop=0., window_size=None, attn_head_dim=None):
60
+ super().__init__()
61
+ self.num_heads = num_heads
62
+ head_dim = dim // num_heads
63
+ if attn_head_dim is not None:
64
+ head_dim = attn_head_dim
65
+ all_head_dim = head_dim * self.num_heads
66
+ self.scale = qk_scale or head_dim ** -0.5
67
+
68
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
69
+ if qkv_bias:
70
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
71
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
72
+ else:
73
+ self.q_bias = None
74
+ self.v_bias = None
75
+
76
+ if window_size:
77
+ self.window_size = window_size
78
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
79
+ self.relative_position_bias_table = nn.Parameter(
80
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
81
+ # cls to token & token 2 cls & cls to cls
82
+
83
+ # get pair-wise relative position index for each token inside the window
84
+ coords_h = torch.arange(window_size[0])
85
+ coords_w = torch.arange(window_size[1])
86
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
87
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
88
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
89
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
90
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
91
+ relative_coords[:, :, 1] += window_size[1] - 1
92
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
93
+ relative_position_index = \
94
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
95
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
96
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
97
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
98
+ relative_position_index[0, 0] = self.num_relative_distance - 1
99
+
100
+ self.register_buffer("relative_position_index", relative_position_index)
101
+ else:
102
+ self.window_size = None
103
+ self.relative_position_bias_table = None
104
+ self.relative_position_index = None
105
+
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(all_head_dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ def forward(self, x, rel_pos_bias=None):
111
+ B, N, C = x.shape
112
+ qkv_bias = None
113
+ if self.q_bias is not None:
114
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
115
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
116
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
117
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
118
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
119
+
120
+ q = q * self.scale
121
+ attn = (q @ k.transpose(-2, -1))
122
+
123
+ if self.relative_position_bias_table is not None:
124
+ relative_position_bias = \
125
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
126
+ self.window_size[0] * self.window_size[1] + 1,
127
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
128
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
129
+ attn = attn + relative_position_bias.unsqueeze(0)
130
+
131
+ if rel_pos_bias is not None:
132
+ attn = attn + rel_pos_bias
133
+
134
+ attn = attn.softmax(dim=-1)
135
+ attn = self.attn_drop(attn)
136
+
137
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
138
+ x = self.proj(x)
139
+ x = self.proj_drop(x)
140
+ return x
141
+
142
+
143
+ class Block(nn.Module):
144
+
145
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
146
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
147
+ window_size=None, attn_head_dim=None):
148
+ super().__init__()
149
+ self.norm1 = norm_layer(dim)
150
+ self.attn = Attention(
151
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
152
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
153
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
154
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
155
+ self.norm2 = norm_layer(dim)
156
+ mlp_hidden_dim = int(dim * mlp_ratio)
157
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
158
+
159
+ if init_values is not None and init_values > 0:
160
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
161
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
162
+ else:
163
+ self.gamma_1, self.gamma_2 = None, None
164
+
165
+ def forward(self, x, rel_pos_bias=None):
166
+ if self.gamma_1 is None:
167
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
168
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
169
+ else:
170
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
171
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
172
+ return x
173
+
174
+
175
+ class PatchEmbed(nn.Module):
176
+ """ Image to Patch Embedding
177
+ """
178
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
179
+ super().__init__()
180
+ img_size = to_2tuple(img_size)
181
+ patch_size = to_2tuple(patch_size)
182
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
183
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
184
+ self.img_size = img_size
185
+ self.patch_size = patch_size
186
+ self.num_patches = num_patches
187
+
188
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
189
+
190
+ def forward(self, x, **kwargs):
191
+ B, C, H, W = x.shape
192
+ # FIXME look at relaxing size constraints
193
+ assert H == self.img_size[0] and W == self.img_size[1], \
194
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
195
+ x = self.proj(x).flatten(2).transpose(1, 2)
196
+ return x
197
+
198
+
199
+ class RelativePositionBias(nn.Module):
200
+
201
+ def __init__(self, window_size, num_heads):
202
+ super().__init__()
203
+ self.window_size = window_size
204
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
205
+ self.relative_position_bias_table = nn.Parameter(
206
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
207
+ # cls to token & token 2 cls & cls to cls
208
+
209
+ # get pair-wise relative position index for each token inside the window
210
+ coords_h = torch.arange(window_size[0])
211
+ coords_w = torch.arange(window_size[1])
212
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
213
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
214
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
215
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
216
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
217
+ relative_coords[:, :, 1] += window_size[1] - 1
218
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
219
+ relative_position_index = \
220
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
221
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
222
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
223
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
224
+ relative_position_index[0, 0] = self.num_relative_distance - 1
225
+
226
+ self.register_buffer("relative_position_index", relative_position_index)
227
+
228
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
229
+
230
+ def forward(self):
231
+ relative_position_bias = \
232
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
233
+ self.window_size[0] * self.window_size[1] + 1,
234
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
235
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
236
+
237
+
238
+ class VisionTransformer(nn.Module):
239
+ """ Vision Transformer with support for patch or hybrid CNN input stage
240
+ """
241
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
242
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
243
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
244
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
245
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
246
+ super().__init__()
247
+ self.image_size = img_size
248
+ self.num_classes = num_classes
249
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
250
+
251
+ self.patch_embed = PatchEmbed(
252
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
253
+ num_patches = self.patch_embed.num_patches
254
+
255
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
256
+ if use_abs_pos_emb:
257
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
258
+ else:
259
+ self.pos_embed = None
260
+ self.pos_drop = nn.Dropout(p=drop_rate)
261
+
262
+ if use_shared_rel_pos_bias:
263
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
264
+ else:
265
+ self.rel_pos_bias = None
266
+ self.use_checkpoint = use_checkpoint
267
+
268
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
269
+ self.use_rel_pos_bias = use_rel_pos_bias
270
+ self.blocks = nn.ModuleList([
271
+ Block(
272
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
273
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
274
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
275
+ for i in range(depth)])
276
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
277
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
278
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
279
+
280
+ if self.pos_embed is not None:
281
+ trunc_normal_(self.pos_embed, std=.02)
282
+ trunc_normal_(self.cls_token, std=.02)
283
+ # trunc_normal_(self.mask_token, std=.02)
284
+ # if isinstance(self.head, nn.Linear):
285
+ # trunc_normal_(self.head.weight, std=.02)
286
+ self.apply(self._init_weights)
287
+ self.fix_init_weight()
288
+ # if isinstance(self.head, nn.Linear):
289
+ # self.head.weight.data.mul_(init_scale)
290
+ # self.head.bias.data.mul_(init_scale)
291
+
292
+ def fix_init_weight(self):
293
+ def rescale(param, layer_id):
294
+ param.div_(math.sqrt(2.0 * layer_id))
295
+
296
+ for layer_id, layer in enumerate(self.blocks):
297
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
298
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
299
+
300
+ def _init_weights(self, m):
301
+ if isinstance(m, nn.Linear):
302
+ trunc_normal_(m.weight, std=.02)
303
+ if isinstance(m, nn.Linear) and m.bias is not None:
304
+ nn.init.constant_(m.bias, 0)
305
+ elif isinstance(m, nn.LayerNorm):
306
+ nn.init.constant_(m.bias, 0)
307
+ nn.init.constant_(m.weight, 1.0)
308
+
309
+ def get_classifier(self):
310
+ return self.head
311
+
312
+ def reset_classifier(self, num_classes, global_pool=''):
313
+ self.num_classes = num_classes
314
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
315
+
316
+ def forward_features(self, x):
317
+ x = self.patch_embed(x)
318
+ batch_size, seq_len, _ = x.size()
319
+
320
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
321
+ x = torch.cat((cls_tokens, x), dim=1)
322
+ if self.pos_embed is not None:
323
+ x = x + self.pos_embed
324
+ x = self.pos_drop(x)
325
+
326
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
327
+ for blk in self.blocks:
328
+ if self.use_checkpoint:
329
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
330
+ else:
331
+ x = blk(x, rel_pos_bias)
332
+ return x
333
+ # x = self.norm(x)
334
+
335
+ # if self.fc_norm is not None:
336
+ # t = x[:, 1:, :]
337
+ # return self.fc_norm(t.mean(1))
338
+ # else:
339
+ # return x[:, 0]
340
+
341
+ def forward(self, x):
342
+ x = self.forward_features(x)
343
+ # x = self.head(x)
344
+ return x
345
+
346
+ def get_intermediate_layers(self, x):
347
+ x = self.patch_embed(x)
348
+ batch_size, seq_len, _ = x.size()
349
+
350
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
351
+ x = torch.cat((cls_tokens, x), dim=1)
352
+ if self.pos_embed is not None:
353
+ x = x + self.pos_embed
354
+ x = self.pos_drop(x)
355
+
356
+ features = []
357
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
358
+ for blk in self.blocks:
359
+ x = blk(x, rel_pos_bias)
360
+ features.append(x)
361
+
362
+ return features
363
+
364
+
365
+ def interpolate_pos_embed(model, checkpoint_model):
366
+ if 'pos_embed' in checkpoint_model:
367
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
368
+ embedding_size = pos_embed_checkpoint.shape[-1]
369
+ num_patches = model.patch_embed.num_patches
370
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
371
+ # height (== width) for the checkpoint position embedding
372
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
373
+ # height (== width) for the new position embedding
374
+ new_size = int(num_patches ** 0.5)
375
+ # class_token and dist_token are kept unchanged
376
+ if orig_size != new_size:
377
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
378
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
379
+ # only the position tokens are interpolated
380
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
381
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
382
+ pos_tokens = torch.nn.functional.interpolate(
383
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
384
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
385
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
386
+ checkpoint_model['pos_embed'] = new_pos_embed
387
+
388
+
389
+ def convert_weights_to_fp16(model: nn.Module):
390
+ """Convert applicable model parameters to fp16"""
391
+
392
+ def _convert_weights_to_fp16(l):
393
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
394
+ l.weight.data = l.weight.data.half()
395
+ if l.bias is not None:
396
+ l.bias.data = l.bias.data.half()
397
+
398
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
399
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
400
+ # tensor = getattr(l, attr)
401
+ # if tensor is not None:
402
+ # tensor.data = tensor.data.half()
403
+
404
+ model.apply(_convert_weights_to_fp16)
405
+
406
+
407
+ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
408
+ model = VisionTransformer(
409
+ img_size=img_size,
410
+ patch_size=14,
411
+ use_mean_pooling=False,
412
+ embed_dim=1408,
413
+ depth=39,
414
+ num_heads=1408//88,
415
+ mlp_ratio=4.3637,
416
+ qkv_bias=True,
417
+ drop_path_rate=drop_path_rate,
418
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
419
+ use_checkpoint=use_checkpoint,
420
+ )
421
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
422
+ cached_file = download_cached_file(
423
+ url, check_hash=False, progress=True
424
+ )
425
+ state_dict = torch.load(cached_file, map_location="cpu")
426
+ interpolate_pos_embed(model,state_dict)
427
+
428
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
429
+ # print(incompatible_keys)
430
+
431
+ if precision == "fp16":
432
+ # model.to("cuda")
433
+ convert_weights_to_fp16(model)
434
+ return model
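
Usage note (not part of this commit): create_eva_vit_g downloads the pretrained EVA ViT-g weights on first call, so the sketch below is illustrative only. With img_size=224 and patch_size=14, the forward pass returns one CLS token plus 16x16 patch tokens of width 1408.

import torch
from model.eva_vit import create_eva_vit_g

vit = create_eva_vit_g(img_size=224, drop_path_rate=0.0, precision="fp32").eval()
with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 257, 1408])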
model/gptk-7b.yaml ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ arch: instruct_vicuna7b
7
+ pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
8
+
9
+ # vit encoder
10
+ image_size: 224
11
+ drop_path_rate: 0
12
+ use_grad_checkpoint: False
13
+ vit_precision: "fp16"
14
+ freeze_vit: True
15
+
16
+ # Q-Former
17
+ num_query_token: 32
18
+
19
+ # path to Vicuna checkpoint
20
+ llm_model: "model/llm/vicuna-7b-v1.1"
21
+ # llm_model: "lmsys/vicuna-7b-v1.3"
22
+ # llm_model: "lmsys/vicuna-7b-v1.5"
23
+
24
+ # generation configs
25
+ prompt: ""
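
Usage note (not part of this commit): a sketch of consuming this config. The signatures of get_gptk_model and get_gptk_image_transform are not visible in this diff, so the commented-out calls are assumptions.

import yaml
# from model import get_gptk_model, get_gptk_image_transform   # exported by model/__init__.py

with open("model/gptk-7b.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["arch"], cfg["num_query_token"], cfg["llm_model"])
# instruct_vicuna7b 32 model/llm/vicuna-7b-v1.1

# model = get_gptk_model(**cfg)                              # assumed interface
# transform = get_gptk_image_transform(cfg["image_size"])    # assumed interface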