wcy1122 committed
Commit 35153f6
1 parent: 9066a31

update code

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50):
  1. minigemini/__init__.py +3 -0
  2. minigemini/constants.py +27 -0
  3. minigemini/conversation.py +460 -0
  4. minigemini/eval/MathVista/calculate_score.py +258 -0
  5. minigemini/eval/MathVista/extract_answer.py +160 -0
  6. minigemini/eval/MathVista/prompts/ext_ans.py +42 -0
  7. minigemini/eval/MathVista/utilities.py +200 -0
  8. minigemini/eval/eval_gpt_review.py +113 -0
  9. minigemini/eval/eval_gpt_review_bench.py +121 -0
  10. minigemini/eval/eval_gpt_review_visual.py +118 -0
  11. minigemini/eval/eval_pope.py +81 -0
  12. minigemini/eval/eval_science_qa.py +114 -0
  13. minigemini/eval/eval_science_qa_gpt4.py +104 -0
  14. minigemini/eval/eval_science_qa_gpt4_requery.py +149 -0
  15. minigemini/eval/eval_textvqa.py +65 -0
  16. minigemini/eval/generate_webpage_data_from_table.py +111 -0
  17. minigemini/eval/m4c_evaluator.py +334 -0
  18. minigemini/eval/model_math_vista.py +237 -0
  19. minigemini/eval/model_qa.py +64 -0
  20. minigemini/eval/model_vqa.py +154 -0
  21. minigemini/eval/model_vqa_loader.py +187 -0
  22. minigemini/eval/model_vqa_mmbench.py +212 -0
  23. minigemini/eval/model_vqa_qbench.py +122 -0
  24. minigemini/eval/model_vqa_science.py +162 -0
  25. minigemini/eval/qa_baseline_gpt35.py +74 -0
  26. minigemini/eval/run_llava.py +143 -0
  27. minigemini/eval/summarize_gpt_review.py +60 -0
  28. minigemini/mm_utils.py +105 -0
  29. minigemini/model/__init__.py +7 -0
  30. minigemini/model/builder.py +140 -0
  31. minigemini/model/consolidate.py +29 -0
  32. minigemini/model/language_model/mini_gemini_gemma.py +164 -0
  33. minigemini/model/language_model/mini_gemini_llama.py +203 -0
  34. minigemini/model/language_model/mini_gemini_mistral.py +162 -0
  35. minigemini/model/language_model/mini_gemini_mixtral.py +162 -0
  36. minigemini/model/llava_arch.py +299 -0
  37. minigemini/model/mini_gemini_arch.py +497 -0
  38. minigemini/model/multimodal_encoder/builder.py +34 -0
  39. minigemini/model/multimodal_encoder/clip_encoder.py +89 -0
  40. minigemini/model/multimodal_encoder/eva_encoder.py +551 -0
  41. minigemini/model/multimodal_encoder/openclip_encoder.py +225 -0
  42. minigemini/model/multimodal_projector/builder.py +50 -0
  43. minigemini/model/processor/video_processor.py +74 -0
  44. minigemini/serve/__init__.py +0 -0
  45. minigemini/serve/cli.py +237 -0
  46. minigemini/serve/controller.py +298 -0
  47. minigemini/serve/examples/extreme_ironing.jpg +3 -0
  48. minigemini/serve/examples/monday.jpg +3 -0
  49. minigemini/serve/examples/waterview.jpg +3 -0
  50. minigemini/serve/examples/woolen.png +3 -0
minigemini/__init__.py ADDED
@@ -0,0 +1,3 @@
import timm
import open_clip
from .model import MiniGeminiLlamaForCausalLM
minigemini/constants.py ADDED
@@ -0,0 +1,27 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
PREDICT_TOKEN_INDEX = -300
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
DEFAULT_PREDICT_TOKEN = "<predict>"

DESCRIPT_PROMPT = [
    "Describe this image thoroughly.",
    "Provide a detailed description in this picture.",
    "Detail every aspect of what's in this picture.",
    "Explain this image with precision and detail.",
    "Give a comprehensive description of this visual.",
    "Elaborate on the specifics within this image.",
    "Offer a detailed account of this picture's contents.",
    "Describe in detail what this image portrays.",
    "Break down this image into detailed descriptions.",
    "Provide a thorough description of the elements in this image."]
minigemini/conversation.py ADDED
@@ -0,0 +1,460 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    GEMMA = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.GEMMA:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += "<start_of_turn>" + role + "\n" + message + "<end_of_turn>\n" + seps[i % 2]
                else:
                    ret += "<start_of_turn>" + role + "\n"
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
        if image_process_mode == "Pad":
            def expand2square(pil_img, background_color=(122, 116, 104)):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image)
        elif image_process_mode in ["Default", "Crop"]:
            pass
        elif image_process_mode == "Resize":
            image = image.resize((336, 336))
        else:
            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
        if max(image.size) > max_len:
            max_hw, min_hw = max(image.size), min(image.size)
            aspect_ratio = max_hw / min_hw
            shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
            longest_edge = int(shortest_edge * aspect_ratio)
            W, H = image.size
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))
        if return_pil:
            return image
        else:
            buffered = BytesIO()
            image.save(buffered, format=image_format)
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            return img_b64_str

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    image = self.process_image(image, image_process_mode, return_pil=return_pil)
                    images.append(image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    img_b64_str = self.process_image(
                        image, "Default", return_pil=False,
                        image_format='JPEG')
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                if type(msg) is tuple and len(msg) == 2:
                    msg, img_b64_str = msg
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.strip() + img_str
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_vicuna_imgsp_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="imgsp_v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llava_plain_guided = Conversation(
    system="",
    roles=("", ""),
    version="plain_guided",
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

conv_phi_2 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_gemma = Conversation(
    system="",
    roles=("user", "model"),
    version="gemma",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.GEMMA,
    sep="",
    sep2="<eos>",
)

conv_chatml_direct = Conversation(
    system="""<|im_start|>system
Answer the questions.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

default_conversation = conv_vicuna_v1
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "phi_2": conv_phi_2,
    "gemma": conv_gemma,
    "llama_2": conv_llama_2,
    "imgsp_v1": conv_vicuna_imgsp_v1,
    "plain_guided": conv_llava_plain_guided,
    "mistral_instruct": conv_mistral_instruct,
    "chatml_direct": conv_chatml_direct,
    "mistral_direct": conv_chatml_direct,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,

    "mpt": conv_mpt,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
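Illustrative example (not part of the commit): a minimal usage sketch of the templates registered above, following the pattern of the `__main__` block and assuming the package is importable as `minigemini`.

    from minigemini.conversation import conv_templates

    # Templates are shared module-level objects, so copy before mutating.
    conv = conv_templates["vicuna_v1"].copy()
    conv.append_message(conv.roles[0], "<image>\nWhat is shown in this picture?")
    conv.append_message(conv.roles[1], None)  # None marks the slot the model should fill
    print(conv.get_prompt())                  # serialized with SeparatorStyle.TWO separators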
minigemini/eval/MathVista/calculate_score.py ADDED
@@ -0,0 +1,258 @@
import os
import re
import argparse
import pandas as pd

# !pip install python-Levenshtein
from Levenshtein import distance

import sys
sys.path.append('../')
from utilities import *


def get_most_similar(prediction, choices):
    """
    Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction
    """
    distances = [distance(prediction, choice) for choice in choices]
    ind = distances.index(min(distances))
    return choices[ind]
    # return min(choices, key=lambda choice: distance(prediction, choice))


def normalize_extracted_answer(extraction, choices, question_type, answer_type, precision):
    """
    Normalize the extracted answer to match the answer type
    """
    if question_type == 'multi_choice':
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except:
                extraction = ""

        # extract "A" from "(A) text"
        letter = re.findall(r'\(([a-zA-Z])\)', extraction)
        if len(letter) > 0:
            extraction = letter[0].upper()

        options = [chr(ord('A') + i) for i in range(len(choices))]

        if extraction in options:
            # convert option letter to text, e.g. "A" -> "text"
            ind = options.index(extraction)
            extraction = choices[ind]
        else:
            # select the most similar option
            extraction = get_most_similar(extraction, choices)
        assert extraction in choices

    elif answer_type == 'integer':
        try:
            extraction = str(int(float(extraction)))
        except:
            extraction = None

    elif answer_type == 'float':
        try:
            extraction = str(round(float(extraction), precision))
        except:
            extraction = None

    elif answer_type == 'list':
        try:
            extraction = str(extraction)
        except:
            extraction = None

    return extraction


def safe_equal(prediction, answer):
    """
    Check if the prediction is equal to the answer, even if they are of different types
    """
    try:
        if prediction == answer:
            return True
        return False
    except Exception as e:
        print(e)
        return False


def get_acc_with_contion(res_pd, key, value):
    if key == 'skills':
        # if value in res_pd[key]:
        total_pd = res_pd[res_pd[key].apply(lambda x: value in x)]
    else:
        total_pd = res_pd[res_pd[key] == value]

    correct_pd = total_pd[total_pd['true_false'] == True]
    acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)
    return len(correct_pd), len(total_pd), acc

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', type=str, default='output.json')
    parser.add_argument('--score_file', type=str, default='scores.json')
    parser.add_argument('--gt_file', type=str, default='../data/testmini.json', help='ground truth file')
    parser.add_argument('--number', type=int, default=-1, help='number of problems to run')
    parser.add_argument('--rerun', action='store_true', help='rerun the evaluation')
    parser.add_argument('--caculate_gain', action='store_true', help='caculate the socre gains over random guess')
    parser.add_argument('--random_file', type=str, default='score_random_guess.json')
    args = parser.parse_args()

    # args
    output_file = args.output_file

    # # quick test
    # output_file = '../results/llava-llama-2-13b/output_llava_llama_2_13b.json'

    # read json
    print(f"Reading {output_file}...")
    results = read_json(output_file)

    # read ground truth
    print(f"Reading {args.gt_file}...")
    gts = read_json(args.gt_file)

    # full pids
    full_pids = list(results.keys())
    if args.number > 0:
        full_pids = full_pids[:min(args.number, len(full_pids))]
    print("Number of testing problems:", len(full_pids))

    ## [1] Evaluate if the prediction is true or false
    print("\nEvaluating the predictions...")
    update_json_flag = False
    for pid in full_pids:
        problem = results[pid]
        # print(problem)

        if args.rerun:
            if 'prediction' in problem:
                del problem['prediction']
            if 'true_false' in problem:
                del problem['true_false']

        choices = problem['choices']
        question_type = problem['question_type']
        answer_type = problem['answer_type']
        precision = problem['precision']
        extraction = problem['extraction']

        if 'answer' in problem:
            answer = problem['answer']
        else:
            answer = gts[pid]['answer']
            problem['answer'] = answer

        # normalize the extracted answer to match the answer type
        prediction = normalize_extracted_answer(extraction, choices, question_type, answer_type, precision)

        # verify the prediction is true or false
        true_false = safe_equal(prediction, answer)

        # update the problem
        if "true_false" not in problem:
            update_json_flag = True

        elif true_false != problem['true_false']:
            update_json_flag = True

        if "prediction" not in problem:
            update_json_flag = True

        elif prediction != problem['prediction']:
            update_json_flag = True

        problem['prediction'] = prediction
        problem['true_false'] = true_false

    # save the updated json
    if update_json_flag:
        print("\n!!!Some problems are updated.!!!")
        print(f"\nSaving {output_file}...")
        save_json(results, output_file)

    ## [2] Calculate the average accuracy
    total = len(full_pids)
    correct = 0
    for pid in full_pids:
        if results[pid]['true_false']:
            correct += 1
    accuracy = str(round(correct / total * 100, 2))
    print(f"\nCorrect: {correct}, Total: {total}, Accuracy: {accuracy}%")

    scores = {"average": {"accuracy": accuracy, "correct": correct, "total": total}}

    ## [3] Calculate the fine-grained accuracy scores

    # merge the 'metadata' attribute into the data
    for pid in results:
        results[pid].update(results[pid].pop('metadata'))

    # convert the data to a pandas DataFrame
    df = pd.DataFrame(results).T

    print(len(df))
    print("Number of test problems:", len(df))
    # assert len(df) == 1000 # Important!!!

    # asign the target keys for evaluation
    target_keys = ['question_type', 'answer_type', 'language', 'source', 'category', 'task', 'context', 'grade', 'skills']

    for key in target_keys:
        print(f"\nType: [{key}]")
        # get the unique values of the key
        if key == 'skills':
            # the value is a list
            values = []
            for i in range(len(df)):
                values += df[key][i]
            values = list(set(values))
        else:
            values = df[key].unique()
        #print(values)

        # calculate the accuracy for each value
        scores[key] = {}
        for value in values:
            correct, total, acc = get_acc_with_contion(df, key, value)
            if total > 0:
                print(f"[{value}]: {acc}% ({correct}/{total})")
                scores[key][value] = {"accuracy": acc, "correct": correct, "total": total}

        # sort the scores by accuracy
        scores[key] = dict(sorted(scores[key].items(), key=lambda item: float(item[1]['accuracy']), reverse=True))

    # save the scores
    scores_file = args.score_file
    print(f"\nSaving {scores_file}...")
    save_json(scores, scores_file)
    print("\nDone!")

    # [4] Calculate the score gains over random guess
    if args.caculate_gain:
        random_file = args.random_file
        random_scores = json.load(open(random_file))

        print("\nCalculating the score gains...")
        for key in scores:
            if key == 'average':
                gain = round(float(scores[key]['accuracy']) - float(random_scores[key]['accuracy']), 2)
                scores[key]['acc_gain'] = gain
            else:
                for sub_key in scores[key]:
                    gain = round(float(scores[key][sub_key]['accuracy']) - float(random_scores[key][sub_key]['accuracy']), 2)
                    scores[key][sub_key]['acc_gain'] = str(gain)

        # save the score gains
        print(f"\nSaving {scores_file}...")
        save_json(scores, scores_file)
        print("\nDone!")
minigemini/eval/MathVista/extract_answer.py ADDED
@@ -0,0 +1,160 @@
import os
import re
import time
import argparse

from tqdm import tqdm

import sys
sys.path.append('../')
from utilities import *

# OpenAI
import openai

# load demo prompt
from prompts.ext_ans import demo_prompt


def verify_extraction(extraction):
    extraction = extraction.strip()
    if extraction == "" or extraction == None:
        return False
    return True


def create_test_prompt(demo_prompt, query, response):
    demo_prompt = demo_prompt.strip()
    test_prompt = f"{query}\n\n{response}"
    full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: "
    return full_prompt


def extract_answer(response, problem, quick_extract=False):
    question_type = problem['question_type']
    answer_type = problem['answer_type']
    choices = problem['choices']
    query = problem['query']
    pid = problem['pid']

    if response == "":
        return ""

    if question_type == 'multi_choice' and response in choices:
        return response

    if answer_type == "integer":
        try:
            extraction = int(response)
            return str(extraction)
        except:
            pass

    if answer_type == "float":
        try:
            extraction = str(float(response))
            return extraction
        except:
            pass

    # quick extraction
    if quick_extract:
        print("Quickly extracting answer...")
        # The answer is "text". -> "text"
        try:
            result = re.search(r'The answer is "(.*)"\.', response)
            if result:
                extraction = result.group(1)
                return extraction
        except:
            pass

    # general extraction
    try:
        full_prompt = create_test_prompt(demo_prompt, query, response)
        extraction = get_chat_response(full_prompt, openai.api_key, openai.api_base, model=args.llm_engine)
        return extraction
    except Exception as e:
        print(e)
        print(f"Error in extracting answer for {pid}")

    return ""


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # input
    parser.add_argument('--output_file', type=str, default='answer.json')
    parser.add_argument('--response_label', type=str, default='response', help='response label for the input file')
    # model
    parser.add_argument('--llm_engine', type=str, default='gpt-4-0613', help='llm engine',
                        choices = ['gpt-3.5-turbo', 'gpt-3.5', 'gpt-4', 'gpt-4-0314', 'gpt-4-0613'])
    parser.add_argument('--number', type=int, default=-1, help='number of problems to run')
    parser.add_argument('--quick_extract', action='store_true', help='use rules to extract answer for some problems')
    parser.add_argument('--rerun', action='store_true', help='rerun the answer extraction')
    # openai
    parser.add_argument("--api_key", required=True, type=str, help="OpenAI API key")
    parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base")
    # output
    parser.add_argument('--save_every', type=int, default=10, help='save every n problems')
    parser.add_argument('--output_label', type=str, default='', help='label for the output file')
    args = parser.parse_args()

    # args
    label = args.response_label
    result_file = args.output_file
    if args.output_label != '':
        output_file = result_file.replace('.json', f'_{args.output_label}.json')
    else:
        output_file = result_file

    # read results
    print(f"Reading {result_file}...")
    try:
        results = read_json(output_file)
    except:
        samples = [json.loads(line) for line in open(result_file)]
        results = {}
        for sample in samples:
            results[sample['pid']] = sample

    # full pids
    full_pids = list(results.keys())
    if args.number > 0:
        full_pids = full_pids[:min(args.number, len(full_pids))]
    print("Number of testing problems:", len(full_pids))

    # test pids
    if args.rerun:
        test_pids = full_pids
    else:
        test_pids = []
        for pid in full_pids:
            # print(pid)
            if 'extraction' not in results[pid] or not verify_extraction(results[pid]['extraction']):
                test_pids.append(pid)

    test_num = len(test_pids)
    print("Number of problems to run:", test_num)
    # print(test_pids)

    # openai api
    openai.api_key = args.api_key  # Your API key here
    if args.api_base:
        openai.api_base = args.api_base  # Your API base here

    # tqdm, enumerate results
    for i, pid in enumerate(tqdm(test_pids)):
        problem = results[pid]

        assert label in problem
        response = problem[label]


        extraction = extract_answer(response, problem, args.quick_extract)
        results[pid]['extraction'] = extraction

        if i % args.save_every == 0 or i == test_num - 1:
            print(f"Saving results to {output_file}...")
            save_json(results, output_file)
            print(f"Results saved.")
minigemini/eval/MathVista/prompts/ext_ans.py ADDED
@@ -0,0 +1,42 @@

# pids = 852, 104, 824, 506, 540

demo_prompt = """
Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.

Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.
Question: Which number is missing?

Model response: The number missing in the sequence is 14.

Extracted answer: 14

Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.
Question: What is the fraction of females facing the camera?

Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.

Extracted answer: 0.6

Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.
Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)

Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.

Extracted answer: 1.45

Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
Question: Between which two years does the line graph saw its maximum peak?

Model response: The line graph saw its maximum peak between 2007 and 2008.

Extracted answer: [2007, 2008]

Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5

Model response: The correct answer is (B) 8/11.

Extracted answer: B
"""
minigemini/eval/MathVista/utilities.py ADDED
@@ -0,0 +1,200 @@
import os
import cv2
import json
import time
import pickle
import openai
import re
from word2number import w2n


def create_dir(output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)


def read_csv(file):
    data = []
    with open(file, 'r') as f:
        for line in f:
            data.append(line.strip())
    return data


def read_pandas_csv(csv_path):
    # read a pandas csv sheet
    import pandas as pd
    df = pd.read_csv(csv_path)
    return df


def read_json(path):
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def read_jsonl(file):
    with open(file, 'r') as f:
        data = [json.loads(line) for line in f]
    return data


def read_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def save_json(data, path):
    with open(path, 'w') as f:
        json.dump(data, f, indent=4)


def save_array_img(path, image):
    cv2.imwrite(path, image)


def contains_digit(text):
    # check if text contains a digit
    if any(char.isdigit() for char in text):
        return True
    return False

def contains_number_word(text):
    # check if text contains a number word
    ignore_words = ["a", "an", "point"]
    words = re.findall(r'\b\w+\b', text)  # This regex pattern matches any word in the text
    for word in words:
        if word in ignore_words:
            continue
        try:
            w2n.word_to_num(word)
            return True  # If the word can be converted to a number, return True
        except ValueError:
            continue  # If the word can't be converted to a number, continue with the next word

    # check if text contains a digit
    if any(char.isdigit() for char in text):
        return True

    return False  # If none of the words could be converted to a number, return False


def contains_quantity_word(text, special_keep_words=[]):
    # check if text contains a quantity word
    quantity_words = ["most", "least", "fewest"
                      "more", "less", "fewer",
                      "largest", "smallest", "greatest",
                      "larger", "smaller", "greater",
                      "highest", "lowest", "higher", "lower",
                      "increase", "decrease",
                      "minimum", "maximum", "max", "min",
                      "mean", "average", "median",
                      "total", "sum", "add", "subtract",
                      "difference", "quotient", "gap",
                      "half", "double", "twice", "triple",
                      "square", "cube", "root",
                      "approximate", "approximation",
                      "triangle", "rectangle", "circle", "square", "cube", "sphere", "cylinder", "cone", "pyramid",
                      "multiply", "divide",
                      "percentage", "percent", "ratio", "proportion", "fraction", "rate",
                      ]

    quantity_words += special_keep_words  # dataset specific words

    words = re.findall(r'\b\w+\b', text)  # This regex pattern matches any word in the text
    if any(word in quantity_words for word in words):
        return True

    return False  # If none of the words could be converted to a number, return False


def is_bool_word(text):
    if text in ["Yes", "No", "True", "False",
                "yes", "no", "true", "false",
                "YES", "NO", "TRUE", "FALSE"]:
        return True
    return False


def is_digit_string(text):
    # remove ".0000"
    text = text.strip()
    text = re.sub(r'\.0+$', '', text)
    try:
        int(text)
        return True
    except ValueError:
        return False


def is_float_string(text):
    # text is a float string if it contains a "." and can be converted to a float
    if "." in text:
        try:
            float(text)
            return True
        except ValueError:
            return False
    return False


def copy_image(image_path, output_image_path):
    from shutil import copyfile
    copyfile(image_path, output_image_path)


def copy_dir(src_dir, dst_dir):
    from shutil import copytree
    # copy the source directory to the target directory
    copytree(src_dir, dst_dir)


import PIL.Image as Image
def get_image_size(img_path):
    img = Image.open(img_path)
    width, height = img.size
    return width, height


def get_chat_response(promot, api_key, api_base, model="gpt-3.5-turbo", temperature=0, max_tokens=256, n=1, patience=10000000,
                      sleep_time=0):
    messages = [
        {"role": "user", "content": promot},
    ]
    # print("I am here")
    while patience > 0:
        patience -= 1
        try:
            response = openai.ChatCompletion.create(model=model,
                                                    messages=messages,
                                                    api_key=api_key,
                                                    api_base=api_base,
                                                    temperature=temperature,
                                                    max_tokens=max_tokens,
                                                    n=n)
            if n == 1:
                prediction = response['choices'][0]['message']['content'].strip()
                if prediction != "" and prediction != None:
                    return prediction
            else:
                prediction = [choice['message']['content'].strip() for choice in response['choices']]
                if prediction[0] != "" and prediction[0] != None:
                    return prediction

        except Exception as e:
            if "Rate limit" not in str(e):
                print(e)

            if "Please reduce the length of the messages" in str(e):
                print("!!Reduce promot size")
                # reduce input prompt and keep the tail
                new_size = int(len(promot) * 0.9)
                new_start = len(promot) - new_size
                promot = promot[new_start:]
                messages = [
                    {"role": "user", "content": promot},
                ]

            if sleep_time > 0:
                time.sleep(sleep_time)
    return ""
minigemini/eval/eval_gpt_review.py ADDED
@@ -0,0 +1,113 @@
import argparse
import json
import os

import openai
import tqdm
import ray
import time

NUM_SECONDS_TO_SLEEP = 3

@ray.remote(num_cpus=4)
def get_eval(content: str, max_tokens: int):
    while True:
        try:
            response = openai.ChatCompletion.create(
                model='gpt-4',
                messages=[{
                    'role': 'system',
                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
                }, {
                    'role': 'user',
                    'content': content,
                }],
                temperature=0.2,  # TODO: figure out which temperature is best for evaluation
                max_tokens=max_tokens,
            )
            break
        except openai.error.RateLimitError:
            pass
        except Exception as e:
            print(e)
        time.sleep(NUM_SECONDS_TO_SLEEP)

    print('success!')
    return response['choices'][0]['message']['content']


def parse_score(review):
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        sp = score_pair.split(' ')
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            print('error', review)
            return [-1, -1]
    except Exception as e:
        print(e)
        print('error', review)
        return [-1, -1]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    # parser.add_argument('-a', '--answer')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    ray.init()

    f_q = open(os.path.expanduser(args.question))
    f_ans1 = open(os.path.expanduser(args.answer_list[0]))
    f_ans2 = open(os.path.expanduser(args.answer_list[1]))
    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))

    review_file = open(f'{args.output}', 'w')

    js_list = []
    handles = []
    idx = 0
    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
        # if idx == 1:
        #     break

        ques = json.loads(ques_js)
        ans1 = json.loads(ans1_js)
        ans2 = json.loads(ans2_js)

        category = json.loads(ques_js)['category']
        if category in rule_dict:
            rule = rule_dict[category]
        else:
            rule = rule_dict['default']
        prompt = rule['prompt']
        role = rule['role']
        content = (f'[Question]\n{ques["text"]}\n\n'
                   f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
                   f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
                   f'[System]\n{prompt}\n\n')
        js_list.append({
            'id': idx+1,
            'question_id': ques['question_id'],
            'answer1_id': ans1['answer_id'],
            'answer2_id': ans2['answer_id'],
            'category': category})
        idx += 1
        handles.append(get_eval.remote(content, args.max_tokens))
        # To avoid the rate limit set by OpenAI
        time.sleep(NUM_SECONDS_TO_SLEEP)

    reviews = ray.get(handles)
    for idx, review in enumerate(reviews):
        scores = parse_score(review)
        js_list[idx]['content'] = review
        js_list[idx]['tuple'] = scores
        review_file.write(json.dumps(js_list[idx]) + '\n')
    review_file.close()
minigemini/eval/eval_gpt_review_bench.py ADDED
@@ -0,0 +1,121 @@
import argparse
import json
import os

import openai
import time

NUM_SECONDS_TO_SLEEP = 0.5


def get_eval(content: str, max_tokens: int):
    while True:
        try:
            response = openai.ChatCompletion.create(
                model='gpt-4-0314',
                messages=[{
                    'role': 'system',
                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
                }, {
                    'role': 'user',
                    'content': content,
                }],
                temperature=0.2,  # TODO: figure out which temperature is best for evaluation
                max_tokens=max_tokens,
            )
            break
        except openai.error.RateLimitError:
            pass
        except Exception as e:
            print(e)
        time.sleep(NUM_SECONDS_TO_SLEEP)

    return response['choices'][0]['message']['content']


def parse_score(review):
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        sp = score_pair.split(' ')
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            print('error', review)
            return [-1, -1]
    except Exception as e:
        print(e)
        print('error', review)
        return [-1, -1]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    parser.add_argument('-c', '--context')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    f_q = open(os.path.expanduser(args.question))
    f_ans1 = open(os.path.expanduser(args.answer_list[0]))
    f_ans2 = open(os.path.expanduser(args.answer_list[1]))
    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))

    if os.path.isfile(os.path.expanduser(args.output)):
        cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
    else:
        cur_reviews = []

    review_file = open(f'{args.output}', 'a')

    context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
    image_to_context = {context['image']: context for context in context_list}

    handles = []
    idx = 0
    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
        ques = json.loads(ques_js)
        ans1 = json.loads(ans1_js)
        ans2 = json.loads(ans2_js)

        inst = image_to_context[ques['image']]

        if isinstance(inst['caption'], list):
            cap_str = '\n'.join(inst['caption'])
        else:
            cap_str = inst['caption']

        category = 'llava_bench_' + json.loads(ques_js)['category']
        if category in rule_dict:
            rule = rule_dict[category]
        else:
            assert False, f"Visual QA category not found in rule file: {category}."
        prompt = rule['prompt']
        role = rule['role']
        content = (f'[Context]\n{cap_str}\n\n'
                   f'[Question]\n{ques["text"]}\n\n'
                   f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
                   f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
                   f'[System]\n{prompt}\n\n')
        cur_js = {
            'id': idx+1,
            'question_id': ques['question_id'],
            'answer1_id': ans1.get('answer_id', ans1['question_id']),
            'answer2_id': ans2.get('answer_id', ans2['answer_id']),
            'category': category
        }
        if idx >= len(cur_reviews):
            review = get_eval(content, args.max_tokens)
            scores = parse_score(review)
            cur_js['content'] = review
            cur_js['tuple'] = scores
            review_file.write(json.dumps(cur_js) + '\n')
            review_file.flush()
        else:
            print(f'Skipping {idx} as we already have it.')
        idx += 1
        print(idx)
    review_file.close()
minigemini/eval/eval_gpt_review_visual.py ADDED
@@ -0,0 +1,118 @@
import argparse
import json
import os

import openai
import time

NUM_SECONDS_TO_SLEEP = 0.5


def get_eval(content: str, max_tokens: int):
    while True:
        try:
            response = openai.ChatCompletion.create(
                model='gpt-4-0314',
                messages=[{
                    'role': 'system',
                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
                }, {
                    'role': 'user',
                    'content': content,
                }],
                temperature=0.2,  # TODO: figure out which temperature is best for evaluation
                max_tokens=max_tokens,
            )
            break
        except openai.error.RateLimitError:
            pass
        except Exception as e:
            print(e)
        time.sleep(NUM_SECONDS_TO_SLEEP)

    return response['choices'][0]['message']['content']


def parse_score(review):
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        sp = score_pair.split(' ')
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            print('error', review)
            return [-1, -1]
    except Exception as e:
        print(e)
        print('error', review)
        return [-1, -1]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    parser.add_argument('-c', '--context')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    f_q = open(os.path.expanduser(args.question))
    f_ans1 = open(os.path.expanduser(args.answer_list[0]))
    f_ans2 = open(os.path.expanduser(args.answer_list[1]))
    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))

    if os.path.isfile(os.path.expanduser(args.output)):
        cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
    else:
        cur_reviews = []

    review_file = open(f'{args.output}', 'a')

    context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
    image_to_context = {context['image']: context for context in context_list}

    handles = []
    idx = 0
    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
        ques = json.loads(ques_js)
        ans1 = json.loads(ans1_js)
        ans2 = json.loads(ans2_js)

        inst = image_to_context[ques['image']]
        cap_str = '\n'.join(inst['captions'])
        box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])

        category = json.loads(ques_js)['category']
        if category in rule_dict:
            rule = rule_dict[category]
        else:
            assert False, f"Visual QA category not found in rule file: {category}."
        prompt = rule['prompt']
        role = rule['role']
        content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
                   f'[Question]\n{ques["text"]}\n\n'
                   f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
                   f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
                   f'[System]\n{prompt}\n\n')
        cur_js = {
            'id': idx+1,
            'question_id': ques['question_id'],
            'answer1_id': ans1.get('answer_id', ans1['question_id']),
            'answer2_id': ans2.get('answer_id', ans2['answer_id']),
            'category': category
        }
        if idx >= len(cur_reviews):
            review = get_eval(content, args.max_tokens)
            scores = parse_score(review)
            cur_js['content'] = review
            cur_js['tuple'] = scores
            review_file.write(json.dumps(cur_js) + '\n')
            review_file.flush()
        else:
            print(f'Skipping {idx} as we already have it.')
        idx += 1
        print(idx)
    review_file.close()
minigemini/eval/eval_pope.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ def eval_pope(answers, label_file):
6
+ label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7
+
8
+ for answer in answers:
9
+ text = answer['text']
10
+
11
+ # Only keep the first sentence
12
+ if text.find('.') != -1:
13
+ text = text.split('.')[0]
14
+
15
+ text = text.replace(',', '')
16
+ words = text.split(' ')
17
+ if 'No' in words or 'not' in words or 'no' in words:
18
+ answer['text'] = 'no'
19
+ else:
20
+ answer['text'] = 'yes'
21
+
22
+ for i in range(len(label_list)):
23
+ if label_list[i] == 'no':
24
+ label_list[i] = 0
25
+ else:
26
+ label_list[i] = 1
27
+
28
+ pred_list = []
29
+ for answer in answers:
30
+ if answer['text'] == 'no':
31
+ pred_list.append(0)
32
+ else:
33
+ pred_list.append(1)
34
+
35
+ pos = 1
36
+ neg = 0
37
+ yes_ratio = pred_list.count(1) / len(pred_list)
38
+
39
+ TP, TN, FP, FN = 0, 0, 0, 0
40
+ for pred, label in zip(pred_list, label_list):
41
+ if pred == pos and label == pos:
42
+ TP += 1
43
+ elif pred == pos and label == neg:
44
+ FP += 1
45
+ elif pred == neg and label == neg:
46
+ TN += 1
47
+ elif pred == neg and label == pos:
48
+ FN += 1
49
+
50
+ print('TP\tFP\tTN\tFN\t')
51
+ print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52
+
53
+ precision = float(TP) / float(TP + FP)
54
+ recall = float(TP) / float(TP + FN)
55
+ f1 = 2*precision*recall / (precision + recall)
56
+ acc = (TP + TN) / (TP + TN + FP + FN)
57
+ print('Accuracy: {}'.format(acc))
58
+ print('Precision: {}'.format(precision))
59
+ print('Recall: {}'.format(recall))
60
+ print('F1 score: {}'.format(f1))
61
+ print('Yes ratio: {}'.format(yes_ratio))
62
+ print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--annotation-dir", type=str)
67
+ parser.add_argument("--question-file", type=str)
68
+ parser.add_argument("--result-file", type=str)
69
+ args = parser.parse_args()
70
+
71
+ questions = [json.loads(line) for line in open(args.question_file)]
72
+ questions = {question['question_id']: question for question in questions}
73
+ answers = [json.loads(q) for q in open(args.result_file)]
74
+ for file in os.listdir(args.annotation_dir):
75
+ assert file.startswith('coco_pope_')
76
+ assert file.endswith('.json')
77
+ category = file[10:-5]
78
+ cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79
+ print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80
+ eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81
+ print("====================================")
minigemini/eval/eval_science_qa.py ADDED
@@ -0,0 +1,114 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import random
6
+
7
+
8
+ def get_args():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--base-dir', type=str)
11
+ parser.add_argument('--result-file', type=str)
12
+ parser.add_argument('--output-file', type=str)
13
+ parser.add_argument('--output-result', type=str)
14
+ parser.add_argument('--split', type=str, default='test')
15
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
+ return parser.parse_args()
17
+
18
+
19
+ def convert_caps(results):
20
+ fakecaps = []
21
+ for result in results:
22
+ image_id = result['question_id']
23
+ caption = result['text']
24
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
25
+ return fakecaps
26
+
27
+
28
+ def get_pred_idx(prediction, choices, options):
29
+ """
30
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
31
+ """
32
+ if prediction in options[:len(choices)]:
33
+ return options.index(prediction)
34
+ else:
35
+ return -1
36
+ # return random.choice(range(len(choices)))  # unreachable: this variant returns -1 above instead of a random fallback
37
+
38
+
39
+ if __name__ == "__main__":
40
+ args = get_args()
41
+
42
+ base_dir = args.base_dir
43
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
45
+ predictions = [json.loads(line) for line in open(args.result_file)]
46
+ predictions = {pred['question_id']: pred for pred in predictions}
47
+ split_problems = {idx: problems[idx] for idx in split_indices}
48
+
49
+ results = {'correct': [], 'incorrect': []}
50
+ sqa_results = {}
51
+ sqa_results['acc'] = None
52
+ sqa_results['correct'] = None
53
+ sqa_results['count'] = None
54
+ sqa_results['results'] = {}
55
+ sqa_results['outputs'] = {}
56
+
57
+ for prob_id, prob in split_problems.items():
58
+ if prob_id not in predictions:
59
+ pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60
+ pred_text = 'FAILED'
61
+ else:
62
+ pred = predictions[prob_id]
63
+ pred_text = pred['text']
64
+
65
+ if pred_text in args.options:
66
+ answer = pred_text
67
+ elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68
+ answer = pred_text[0]
69
+ else:
70
+ pattern = re.compile(r'The answer is ([A-Z]).')
71
+ res = pattern.findall(pred_text)
72
+ if len(res) == 1:
73
+ answer = res[0] # 'A', 'B', ...
74
+ else:
75
+ answer = "FAILED"
76
+
77
+ pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78
+
79
+ analysis = {
80
+ 'question_id': prob_id,
81
+ 'parsed_ans': answer,
82
+ 'ground_truth': args.options[prob['answer']],
83
+ 'question': pred['prompt'],
84
+ 'pred': pred_text,
85
+ 'is_multimodal': '<image>' in pred['prompt'],
86
+ }
87
+
88
+ sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89
+ sqa_results['outputs'][prob_id] = pred_text
90
+
91
+ if pred_idx == prob['answer']:
92
+ results['correct'].append(analysis)
93
+ else:
94
+ results['incorrect'].append(analysis)
95
+
96
+ correct = len(results['correct'])
97
+ total = len(results['correct']) + len(results['incorrect'])
98
+
99
+ ###### IMG ######
100
+ multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101
+ multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102
+ multimodal_total = multimodal_correct + multimodal_incorrect
103
+ ###### IMG ######
104
+
105
+ print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106
+
107
+ sqa_results['acc'] = correct / total * 100
108
+ sqa_results['correct'] = correct
109
+ sqa_results['count'] = total
110
+
111
+ with open(args.output_file, 'w') as f:
112
+ json.dump(results, f, indent=2)
113
+ with open(args.output_result, 'w') as f:
114
+ json.dump(sqa_results, f, indent=2)
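
As a quick illustration of the option parsing above, get_pred_idx (restated here, minus the unreachable fallback) maps a predicted letter to a choice index and returns -1 for anything it cannot match:

def get_pred_idx(prediction, choices, options):
    if prediction in options[:len(choices)]:
        return options.index(prediction)
    return -1

options = ["A", "B", "C", "D", "E"]
choices = ["cat", "dog", "bird"]                   # hypothetical 3-choice question
assert get_pred_idx("C", choices, options) == 2    # valid letter -> choice index
assert get_pred_idx("E", choices, options) == -1   # letter beyond the available choices
assert get_pred_idx("FAILED", choices, options) == -1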
minigemini/eval/eval_science_qa_gpt4.py ADDED
@@ -0,0 +1,104 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import random
6
+ from collections import defaultdict
7
+
8
+
9
+ def get_args():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--base-dir', type=str)
12
+ parser.add_argument('--gpt4-result', type=str)
13
+ parser.add_argument('--our-result', type=str)
14
+ parser.add_argument('--split', type=str, default='test')
15
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
+ return parser.parse_args()
17
+
18
+
19
+ def convert_caps(results):
20
+ fakecaps = []
21
+ for result in results:
22
+ image_id = result['question_id']
23
+ caption = result['text']
24
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
25
+ return fakecaps
26
+
27
+
28
+ def get_pred_idx(prediction, choices, options):
29
+ """
30
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
31
+ """
32
+ if prediction in options[:len(choices)]:
33
+ return options.index(prediction)
34
+ else:
35
+ return random.choice(range(len(choices)))
36
+
37
+
38
+ if __name__ == "__main__":
39
+ args = get_args()
40
+
41
+ base_dir = args.base_dir
42
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
44
+ our_predictions = [json.loads(line) for line in open(args.our_result)]
45
+ our_predictions = {pred['question_id']: pred for pred in our_predictions}
46
+ split_problems = {idx: problems[idx] for idx in split_indices}
47
+
48
+ gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49
+
50
+ results = defaultdict(lambda: 0)
51
+
52
+ for prob_id, prob in split_problems.items():
53
+ if prob_id not in our_predictions:
54
+ continue
55
+ if prob_id not in gpt4_predictions:
56
+ continue
57
+ our_pred = our_predictions[prob_id]['text']
58
+ gpt4_pred = gpt4_predictions[prob_id]
59
+
60
+ pattern = re.compile(r'The answer is ([A-Z]).')
61
+ our_res = pattern.findall(our_pred)
62
+ if len(our_res) == 1:
63
+ our_answer = our_res[0] # 'A', 'B', ...
64
+ else:
65
+ our_answer = "FAILED"
66
+ gpt4_res = pattern.findall(gpt4_pred)
67
+ if len(gpt4_res) == 1:
68
+ gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69
+ else:
70
+ gpt4_answer = "FAILED"
71
+
72
+ our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74
+
75
+ if gpt4_answer == 'FAILED':
76
+ results['gpt4_failed'] += 1
77
+ # continue
78
+ gpt4_pred_idx = our_pred_idx
79
+ # if our_pred_idx != prob['answer']:
80
+ # print(our_predictions[prob_id]['prompt'])
81
+ # print('-----------------')
82
+ # print(f'LECTURE: {prob["lecture"]}')
83
+ # print(f'SOLUTION: {prob["solution"]}')
84
+ # print('=====================')
85
+ else:
86
+ # continue
87
+ pass
88
+ # gpt4_pred_idx = our_pred_idx
89
+
90
+ if gpt4_pred_idx == prob['answer']:
91
+ results['correct'] += 1
92
+ else:
93
+ results['incorrect'] += 1
94
+
95
+
96
+ if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97
+ results['correct_upperbound'] += 1
98
+
99
+ correct = results['correct']
100
+ total = results['correct'] + results['incorrect']
101
+ print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102
+ print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103
+ print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104
+
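
Note that get_pred_idx here falls back to random.choice for unparseable answers, so repeated runs can differ slightly; seeding the RNG before the loop (a suggestion, not part of the original script) makes the reported numbers reproducible:

import random
random.seed(0)   # fix the random fallback used for FAILED answers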
minigemini/eval/eval_science_qa_gpt4_requery.py ADDED
@@ -0,0 +1,149 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import random
6
+ from collections import defaultdict
7
+
8
+
9
+ def get_args():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--base-dir', type=str)
12
+ parser.add_argument('--gpt4-result', type=str)
13
+ parser.add_argument('--requery-result', type=str)
14
+ parser.add_argument('--our-result', type=str)
15
+ parser.add_argument('--output-result', type=str)
16
+ parser.add_argument('--split', type=str, default='test')
17
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
18
+ return parser.parse_args()
19
+
20
+
21
+ def convert_caps(results):
22
+ fakecaps = []
23
+ for result in results:
24
+ image_id = result['question_id']
25
+ caption = result['text']
26
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
27
+ return fakecaps
28
+
29
+
30
+ def get_pred_idx(prediction, choices, options):
31
+ """
32
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
33
+ """
34
+ if prediction in options[:len(choices)]:
35
+ return options.index(prediction)
36
+ else:
37
+ return random.choice(range(len(choices)))
38
+
39
+
40
+ if __name__ == "__main__":
41
+ args = get_args()
42
+
43
+ base_dir = args.base_dir
44
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
45
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
46
+ our_predictions = [json.loads(line) for line in open(args.our_result)]
47
+ our_predictions = {pred['question_id']: pred for pred in our_predictions}
48
+ split_problems = {idx: problems[idx] for idx in split_indices}
49
+
50
+ requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51
+ requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52
+
53
+ gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54
+
55
+ results = defaultdict(lambda: 0)
56
+
57
+ sqa_results = {}
58
+ sqa_results['acc'] = None
59
+ sqa_results['correct'] = None
60
+ sqa_results['count'] = None
61
+ sqa_results['results'] = {}
62
+ sqa_results['outputs'] = {}
63
+
64
+ for prob_id, prob in split_problems.items():
65
+ if prob_id not in our_predictions:
66
+ assert False
67
+ if prob_id not in gpt4_predictions:
68
+ assert False
69
+ our_pred = our_predictions[prob_id]['text']
70
+ gpt4_pred = gpt4_predictions[prob_id]
71
+ if prob_id not in requery_predictions:
72
+ results['missing_requery'] += 1
73
+ requery_pred = "MISSING"
74
+ else:
75
+ requery_pred = requery_predictions[prob_id]['text']
76
+
77
+ pattern = re.compile(r'The answer is ([A-Z]).')
78
+ our_res = pattern.findall(our_pred)
79
+ if len(our_res) == 1:
80
+ our_answer = our_res[0] # 'A', 'B', ...
81
+ else:
82
+ our_answer = "FAILED"
83
+
84
+ requery_res = pattern.findall(requery_pred)
85
+ if len(requery_res) == 1:
86
+ requery_answer = requery_res[0] # 'A', 'B', ...
87
+ else:
88
+ requery_answer = "FAILED"
89
+
90
+ gpt4_res = pattern.findall(gpt4_pred)
91
+ if len(gpt4_res) == 1:
92
+ gpt4_answer = gpt4_res[0] # 'A', 'B', ...
93
+ else:
94
+ gpt4_answer = "FAILED"
95
+
96
+ our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98
+ requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99
+
100
+ results['total'] += 1
101
+
102
+ if gpt4_answer == 'FAILED':
103
+ results['gpt4_failed'] += 1
104
+ if gpt4_pred_idx == prob['answer']:
105
+ results['gpt4_correct'] += 1
106
+ if our_pred_idx == prob['answer']:
107
+ results['gpt4_ourvisual_correct'] += 1
108
+ elif gpt4_pred_idx == prob['answer']:
109
+ results['gpt4_correct'] += 1
110
+ results['gpt4_ourvisual_correct'] += 1
111
+
112
+ if our_pred_idx == prob['answer']:
113
+ results['our_correct'] += 1
114
+
115
+ if requery_answer == 'FAILED':
116
+ sqa_results['results'][prob_id] = our_pred_idx
117
+ if our_pred_idx == prob['answer']:
118
+ results['requery_correct'] += 1
119
+ else:
120
+ sqa_results['results'][prob_id] = requery_pred_idx
121
+ if requery_pred_idx == prob['answer']:
122
+ results['requery_correct'] += 1
123
+ else:
124
+ print(f"""
125
+ Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126
+ Our ({our_answer}): {our_pred}
127
+ GPT-4 ({gpt4_answer}): {gpt4_pred}
128
+ Requery ({requery_answer}): {requery_pred}
129
+ print("=====================================")
130
+ """)
131
+
132
+ if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133
+ results['correct_upperbound'] += 1
134
+
135
+ total = results['total']
136
+ print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137
+ print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138
+ print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139
+ print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140
+ print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141
+ print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142
+
143
+ sqa_results['acc'] = results["requery_correct"] / total * 100
144
+ sqa_results['correct'] = results["requery_correct"]
145
+ sqa_results['count'] = total
146
+
147
+ with open(args.output_result, 'w') as f:
148
+ json.dump(sqa_results, f, indent=2)
149
+
minigemini/eval/eval_textvqa.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import argparse
3
+ import json
4
+ import re
5
+
6
+ from minigemini.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7
+
8
+
9
+ def get_args():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--annotation-file', type=str)
12
+ parser.add_argument('--result-file', type=str)
13
+ parser.add_argument('--result-dir', type=str)
14
+ return parser.parse_args()
15
+
16
+
17
+ def prompt_processor(prompt):
18
+ if prompt.startswith('OCR tokens: '):
19
+ pattern = r"Question: (.*?) Short answer:"
20
+ match = re.search(pattern, prompt, re.DOTALL)
21
+ question = match.group(1)
22
+ elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23
+ if prompt.startswith('Reference OCR token:'):
24
+ question = prompt.split('\n')[1]
25
+ else:
26
+ question = prompt.split('\n')[0]
27
+ elif len(prompt.split('\n')) == 2:
28
+ question = prompt.split('\n')[0]
29
+ else:
30
+ assert False
31
+
32
+ return question.lower()
33
+
34
+
35
+ def eval_single(annotation_file, result_file):
36
+ experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37
+ print(experiment_name)
38
+ annotations = json.load(open(annotation_file))['data']
39
+ annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40
+ results = [json.loads(line) for line in open(result_file)]
41
+
42
+ pred_list = []
43
+ for result in results:
44
+ annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45
+ pred_list.append({
46
+ "pred_answer": result['text'],
47
+ "gt_answers": annotation['answers'],
48
+ })
49
+
50
+ evaluator = TextVQAAccuracyEvaluator()
51
+ print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52
+
53
+
54
+ if __name__ == "__main__":
55
+ args = get_args()
56
+
57
+ if args.result_file is not None:
58
+ eval_single(args.annotation_file, args.result_file)
59
+
60
+ if args.result_dir is not None:
61
+ for result_file in sorted(os.listdir(args.result_dir)):
62
+ if not result_file.endswith('.jsonl'):
63
+ print(f'Skipping {result_file}')
64
+ continue
65
+ eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
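
For reference, the shape of the entries handed to TextVQAAccuracyEvaluator.eval_pred_list above; the answers are made up, but each ground-truth list must contain the 10 human answers the evaluator asserts on:

pred_list = [{
    "pred_answer": "stop",
    "gt_answers": ["stop", "stop", "stop", "stop", "stop sign",
                   "stop", "stop", "stop", "stop", "stop"],   # 10 annotator answers
}]
# TextVQAAccuracyEvaluator().eval_pred_list(pred_list) then returns the soft accuracy.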
minigemini/eval/generate_webpage_data_from_table.py ADDED
@@ -0,0 +1,111 @@
1
+ """Generate json file for webpage."""
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ # models = ['llama', 'alpaca', 'gpt35', 'bard']
7
+ models = ['vicuna']
8
+
9
+
10
+ def read_jsonl(path: str, key: str=None):
11
+ data = []
12
+ with open(os.path.expanduser(path)) as f:
13
+ for line in f:
14
+ if not line:
15
+ continue
16
+ data.append(json.loads(line))
17
+ if key is not None:
18
+ data.sort(key=lambda x: x[key])
19
+ data = {item[key]: item for item in data}
20
+ return data
21
+
22
+
23
+ def trim_hanging_lines(s: str, n: int) -> str:
24
+ s = s.strip()
25
+ for _ in range(n):
26
+ s = s.split('\n', 1)[1].strip()
27
+ return s
28
+
29
+
30
+ if __name__ == '__main__':
31
+ questions = read_jsonl('table/question.jsonl', key='question_id')
32
+
33
+ # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34
+ # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35
+ # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36
+ # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37
+ vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38
+ ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39
+
40
+ review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41
+ # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42
+ # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43
+ # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44
+ # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45
+
46
+ records = []
47
+ for qid in questions.keys():
48
+ r = {
49
+ 'id': qid,
50
+ 'category': questions[qid]['category'],
51
+ 'question': questions[qid]['text'],
52
+ 'answers': {
53
+ # 'alpaca': alpaca_answers[qid]['text'],
54
+ # 'llama': llama_answers[qid]['text'],
55
+ # 'bard': bard_answers[qid]['text'],
56
+ # 'gpt35': gpt35_answers[qid]['text'],
57
+ 'vicuna': vicuna_answers[qid]['text'],
58
+ 'ours': ours_answers[qid]['text'],
59
+ },
60
+ 'evaluations': {
61
+ # 'alpaca': review_alpaca[qid]['text'],
62
+ # 'llama': review_llama[qid]['text'],
63
+ # 'bard': review_bard[qid]['text'],
64
+ 'vicuna': review_vicuna[qid]['content'],
65
+ # 'gpt35': review_gpt35[qid]['text'],
66
+ },
67
+ 'scores': {
68
+ 'vicuna': review_vicuna[qid]['tuple'],
69
+ # 'alpaca': review_alpaca[qid]['score'],
70
+ # 'llama': review_llama[qid]['score'],
71
+ # 'bard': review_bard[qid]['score'],
72
+ # 'gpt35': review_gpt35[qid]['score'],
73
+ },
74
+ }
75
+
76
+ # cleanup data
77
+ cleaned_evals = {}
78
+ for k, v in r['evaluations'].items():
79
+ v = v.strip()
80
+ lines = v.split('\n')
81
+ # trim the first line if it's a pair of numbers
82
+ if re.match(r'\d+[, ]+\d+', lines[0]):
83
+ lines = lines[1:]
84
+ v = '\n'.join(lines)
85
+ cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86
+
87
+ r['evaluations'] = cleaned_evals
88
+ records.append(r)
89
+
90
+ # Reorder the records, this is optional
91
+ for r in records:
92
+ if r['id'] <= 20:
93
+ r['id'] += 60
94
+ else:
95
+ r['id'] -= 20
96
+ for r in records:
97
+ if r['id'] <= 50:
98
+ r['id'] += 10
99
+ elif 50 < r['id'] <= 60:
100
+ r['id'] -= 50
101
+ for r in records:
102
+ if r['id'] == 7:
103
+ r['id'] = 1
104
+ elif r['id'] < 7:
105
+ r['id'] += 1
106
+
107
+ records.sort(key=lambda x: x['id'])
108
+
109
+ # Write to file
110
+ with open('webpage/data.json', 'w') as f:
111
+ json.dump({'questions': records, 'models': models}, f, indent=2)
minigemini/eval/m4c_evaluator.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import re
3
+
4
+ from tqdm import tqdm
5
+
6
+
7
+ class EvalAIAnswerProcessor:
8
+ """
9
+ Processes an answer similar to Eval AI
10
+ copied from
11
+ https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
12
+ """
13
+
14
+ CONTRACTIONS = {
15
+ "aint": "ain't",
16
+ "arent": "aren't",
17
+ "cant": "can't",
18
+ "couldve": "could've",
19
+ "couldnt": "couldn't",
20
+ "couldn'tve": "couldn't've",
21
+ "couldnt've": "couldn't've",
22
+ "didnt": "didn't",
23
+ "doesnt": "doesn't",
24
+ "dont": "don't",
25
+ "hadnt": "hadn't",
26
+ "hadnt've": "hadn't've",
27
+ "hadn'tve": "hadn't've",
28
+ "hasnt": "hasn't",
29
+ "havent": "haven't",
30
+ "hed": "he'd",
31
+ "hed've": "he'd've",
32
+ "he'dve": "he'd've",
33
+ "hes": "he's",
34
+ "howd": "how'd",
35
+ "howll": "how'll",
36
+ "hows": "how's",
37
+ "Id've": "I'd've",
38
+ "I'dve": "I'd've",
39
+ "Im": "I'm",
40
+ "Ive": "I've",
41
+ "isnt": "isn't",
42
+ "itd": "it'd",
43
+ "itd've": "it'd've",
44
+ "it'dve": "it'd've",
45
+ "itll": "it'll",
46
+ "let's": "let's",
47
+ "maam": "ma'am",
48
+ "mightnt": "mightn't",
49
+ "mightnt've": "mightn't've",
50
+ "mightn'tve": "mightn't've",
51
+ "mightve": "might've",
52
+ "mustnt": "mustn't",
53
+ "mustve": "must've",
54
+ "neednt": "needn't",
55
+ "notve": "not've",
56
+ "oclock": "o'clock",
57
+ "oughtnt": "oughtn't",
58
+ "ow's'at": "'ow's'at",
59
+ "'ows'at": "'ow's'at",
60
+ "'ow'sat": "'ow's'at",
61
+ "shant": "shan't",
62
+ "shed've": "she'd've",
63
+ "she'dve": "she'd've",
64
+ "she's": "she's",
65
+ "shouldve": "should've",
66
+ "shouldnt": "shouldn't",
67
+ "shouldnt've": "shouldn't've",
68
+ "shouldn'tve": "shouldn't've",
69
+ "somebody'd": "somebodyd",
70
+ "somebodyd've": "somebody'd've",
71
+ "somebody'dve": "somebody'd've",
72
+ "somebodyll": "somebody'll",
73
+ "somebodys": "somebody's",
74
+ "someoned": "someone'd",
75
+ "someoned've": "someone'd've",
76
+ "someone'dve": "someone'd've",
77
+ "someonell": "someone'll",
78
+ "someones": "someone's",
79
+ "somethingd": "something'd",
80
+ "somethingd've": "something'd've",
81
+ "something'dve": "something'd've",
82
+ "somethingll": "something'll",
83
+ "thats": "that's",
84
+ "thered": "there'd",
85
+ "thered've": "there'd've",
86
+ "there'dve": "there'd've",
87
+ "therere": "there're",
88
+ "theres": "there's",
89
+ "theyd": "they'd",
90
+ "theyd've": "they'd've",
91
+ "they'dve": "they'd've",
92
+ "theyll": "they'll",
93
+ "theyre": "they're",
94
+ "theyve": "they've",
95
+ "twas": "'twas",
96
+ "wasnt": "wasn't",
97
+ "wed've": "we'd've",
98
+ "we'dve": "we'd've",
99
+ "weve": "we've",
100
+ "werent": "weren't",
101
+ "whatll": "what'll",
102
+ "whatre": "what're",
103
+ "whats": "what's",
104
+ "whatve": "what've",
105
+ "whens": "when's",
106
+ "whered": "where'd",
107
+ "wheres": "where's",
108
+ "whereve": "where've",
109
+ "whod": "who'd",
110
+ "whod've": "who'd've",
111
+ "who'dve": "who'd've",
112
+ "wholl": "who'll",
113
+ "whos": "who's",
114
+ "whove": "who've",
115
+ "whyll": "why'll",
116
+ "whyre": "why're",
117
+ "whys": "why's",
118
+ "wont": "won't",
119
+ "wouldve": "would've",
120
+ "wouldnt": "wouldn't",
121
+ "wouldnt've": "wouldn't've",
122
+ "wouldn'tve": "wouldn't've",
123
+ "yall": "y'all",
124
+ "yall'll": "y'all'll",
125
+ "y'allll": "y'all'll",
126
+ "yall'd've": "y'all'd've",
127
+ "y'alld've": "y'all'd've",
128
+ "y'all'dve": "y'all'd've",
129
+ "youd": "you'd",
130
+ "youd've": "you'd've",
131
+ "you'dve": "you'd've",
132
+ "youll": "you'll",
133
+ "youre": "you're",
134
+ "youve": "you've",
135
+ }
136
+
137
+ NUMBER_MAP = {
138
+ "none": "0",
139
+ "zero": "0",
140
+ "one": "1",
141
+ "two": "2",
142
+ "three": "3",
143
+ "four": "4",
144
+ "five": "5",
145
+ "six": "6",
146
+ "seven": "7",
147
+ "eight": "8",
148
+ "nine": "9",
149
+ "ten": "10",
150
+ }
151
+ ARTICLES = ["a", "an", "the"]
152
+ PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
153
+ COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
154
+ PUNCTUATIONS = [
155
+ ";",
156
+ r"/",
157
+ "[",
158
+ "]",
159
+ '"',
160
+ "{",
161
+ "}",
162
+ "(",
163
+ ")",
164
+ "=",
165
+ "+",
166
+ "\\",
167
+ "_",
168
+ "-",
169
+ ">",
170
+ "<",
171
+ "@",
172
+ "`",
173
+ ",",
174
+ "?",
175
+ "!",
176
+ ]
177
+
178
+ def __init__(self, *args, **kwargs):
179
+ pass
180
+
181
+ def word_tokenize(self, word):
182
+ word = word.lower()
183
+ word = word.replace(",", "").replace("?", "").replace("'s", " 's")
184
+ return word.strip()
185
+
186
+ def process_punctuation(self, in_text):
187
+ out_text = in_text
188
+ for p in self.PUNCTUATIONS:
189
+ if (p + " " in in_text or " " + p in in_text) or (
190
+ re.search(self.COMMA_STRIP, in_text) is not None
191
+ ):
192
+ out_text = out_text.replace(p, "")
193
+ else:
194
+ out_text = out_text.replace(p, " ")
195
+ out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
196
+ return out_text
197
+
198
+ def process_digit_article(self, in_text):
199
+ out_text = []
200
+ temp_text = in_text.lower().split()
201
+ for word in temp_text:
202
+ word = self.NUMBER_MAP.setdefault(word, word)
203
+ if word not in self.ARTICLES:
204
+ out_text.append(word)
205
+ else:
206
+ pass
207
+ for word_id, word in enumerate(out_text):
208
+ if word in self.CONTRACTIONS:
209
+ out_text[word_id] = self.CONTRACTIONS[word]
210
+ out_text = " ".join(out_text)
211
+ return out_text
212
+
213
+ def __call__(self, item):
214
+ item = self.word_tokenize(item)
215
+ item = item.replace("\n", " ").replace("\t", " ").strip()
216
+ item = self.process_punctuation(item)
217
+ item = self.process_digit_article(item)
218
+ return item
219
+
220
+
221
+ class TextVQAAccuracyEvaluator:
222
+ def __init__(self):
223
+ self.answer_processor = EvalAIAnswerProcessor()
224
+
225
+ def _compute_answer_scores(self, raw_answers):
226
+ """
227
+ compute the accuracy (soft score) of human answers
228
+ """
229
+ answers = [self.answer_processor(a) for a in raw_answers]
230
+ assert len(answers) == 10
231
+ gt_answers = list(enumerate(answers))
232
+ unique_answers = set(answers)
233
+ unique_answer_scores = {}
234
+
235
+ for unique_answer in unique_answers:
236
+ accs = []
237
+ for gt_answer in gt_answers:
238
+ other_answers = [item for item in gt_answers if item != gt_answer]
239
+ matching_answers = [
240
+ item for item in other_answers if item[1] == unique_answer
241
+ ]
242
+ acc = min(1, float(len(matching_answers)) / 3)
243
+ accs.append(acc)
244
+ unique_answer_scores[unique_answer] = sum(accs) / len(accs)
245
+
246
+ return unique_answer_scores
247
+
248
+ def eval_pred_list(self, pred_list):
249
+ pred_scores = []
250
+ for entry in tqdm(pred_list):
251
+ pred_answer = self.answer_processor(entry["pred_answer"])
252
+ unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
253
+ score = unique_answer_scores.get(pred_answer, 0.0)
254
+ pred_scores.append(score)
255
+
256
+ accuracy = sum(pred_scores) / len(pred_scores)
257
+ return accuracy
258
+
259
+
260
+ class STVQAAccuracyEvaluator:
261
+ def __init__(self):
262
+ self.answer_processor = EvalAIAnswerProcessor()
263
+
264
+ def eval_pred_list(self, pred_list):
265
+ pred_scores = []
266
+ for entry in pred_list:
267
+ pred_answer = self.answer_processor(entry["pred_answer"])
268
+ gts = [self.answer_processor(a) for a in entry["gt_answers"]]
269
+ score = 1.0 if pred_answer in gts else 0.0
270
+ pred_scores.append(score)
271
+
272
+ accuracy = sum(pred_scores) / len(pred_scores)
273
+ return accuracy
274
+
275
+
276
+ class STVQAANLSEvaluator:
277
+ def __init__(self):
278
+ import editdistance # install with `pip install editdistance`
279
+
280
+ self.get_edit_distance = editdistance.eval
281
+
282
+ def get_anls(self, s1, s2):
283
+ s1 = s1.lower().strip()
284
+ s2 = s2.lower().strip()
285
+ iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
286
+ anls = iou if iou >= 0.5 else 0.0
287
+ return anls
288
+
289
+ def eval_pred_list(self, pred_list):
290
+ pred_scores = []
291
+ for entry in pred_list:
292
+ anls = max(
293
+ self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
294
+ )
295
+ pred_scores.append(anls)
296
+
297
+ accuracy = sum(pred_scores) / len(pred_scores)
298
+ return accuracy
299
+
300
+
301
+ class TextCapsBleu4Evaluator:
302
+ def __init__(self):
303
+ # The following script requires Java 1.8.0 and pycocotools installed.
304
+ # The pycocoevalcap can be installed with pip as
305
+ # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
306
+ # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
307
+ # but has no python3 support yet.
308
+ try:
309
+ from pycocoevalcap.bleu.bleu import Bleu
310
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
311
+ except ModuleNotFoundError:
312
+ print(
313
+ "Please install pycocoevalcap module using "
314
+ "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
315
+ )
316
+ raise
317
+
318
+ self.tokenizer = PTBTokenizer()
319
+ self.scorer = Bleu(4)
320
+
321
+ def eval_pred_list(self, pred_list):
322
+ # Create reference and hypotheses captions.
323
+ gts = {}
324
+ res = {}
325
+ for idx, entry in enumerate(pred_list):
326
+ gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
327
+ res[idx] = [{"caption": entry["pred_answer"]}]
328
+
329
+ gts = self.tokenizer.tokenize(gts)
330
+ res = self.tokenizer.tokenize(res)
331
+ score, _ = self.scorer.compute_score(gts, res)
332
+
333
+ bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
334
+ return bleu4
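
A worked example of the VQA soft score computed by _compute_answer_scores above, with made-up answers: each of the 10 annotator answers is held out in turn and the prediction scores min(1, matches/3) against the remaining nine.

gt = ["stop"] * 8 + ["stop sign"] * 2     # toy 8/2 split of the 10 human answers
pred = "stop sign"
scores = []
for i in range(len(gt)):
    others = gt[:i] + gt[i + 1:]          # leave one annotator out
    scores.append(min(1.0, others.count(pred) / 3))
print(sum(scores) / len(scores))          # (8 * 2/3 + 2 * 1/3) / 10 = 0.6
# STVQAANLSEvaluator instead uses 1 - edit_distance/max_len, zeroed below 0.5:
# e.g. pred "helo" vs gt "hello" gives 1 - 1/5 = 0.8.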
minigemini/eval/model_math_vista.py ADDED
@@ -0,0 +1,237 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+
7
+ from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
8
+ from minigemini.conversation import conv_templates, SeparatorStyle
9
+ from minigemini.model.builder import load_pretrained_model
10
+ from minigemini.utils import disable_torch_init
11
+ from minigemini.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
12
+
13
+ from PIL import Image
14
+ import math
15
+
16
+ def split_list(lst, n):
17
+ """Split a list into n (roughly) equal-sized chunks"""
18
+ chunk_size = math.ceil(len(lst) / n) # ceiling division
19
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
20
+
21
+
22
+ def get_chunk(lst, n, k):
23
+ chunks = split_list(lst, n)
24
+ return chunks[k]
25
+
26
+
27
+ def create_one_query(problem, shot_num, shot_type, use_caption):
28
+
29
+
30
+ ### [1] Demo prompt
31
+ demo_prompt = ""
32
+
33
+ ### [2] Test query
34
+ # problem info
35
+ question = problem['question']
36
+ unit = problem['unit']
37
+ choices = problem['choices']
38
+ # caption = problem['caption']
39
+ precision = problem['precision']
40
+ question_type = problem['question_type']
41
+ answer_type = problem['answer_type']
42
+
43
+ # hint
44
+ if shot_type == 'solution':
45
+ if question_type == "multi_choice":
46
+ assert answer_type == "text"
47
+ hint_text = f"Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end."
48
+ else:
49
+ assert answer_type in ["integer", "float", "list"]
50
+ if answer_type == "integer":
51
+ hint_text = f"Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end."
52
+
53
+ elif answer_type == "float" and precision == 1:
54
+ hint_text = f"Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end."
55
+
56
+ elif answer_type == "float" and precision == 2:
57
+ hint_text = f"Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end."
58
+
59
+ elif answer_type == "list":
60
+ hint_text = f"Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end."
61
+ else:
62
+ assert shot_type == 'code'
63
+ hint_text = "Hint: Please generate a python code to solve the problem"
64
+
65
+ # question
66
+ question_text = f"Question: {question}"
67
+ if unit:
68
+ question_text += f" (Unit: {unit})"
69
+
70
+ # choices
71
+ if choices:
72
+ # choices: (A) 1.2 (B) 1.3 (C) 1.4 (D) 1.5
73
+ texts = ["Choices:"]
74
+ for i, choice in enumerate(choices):
75
+ texts.append(f"({chr(ord('A')+i)}) {choice}")
76
+ choices_text = "\n".join(texts)
77
+ else:
78
+ choices_text = ""
79
+
80
+ # prompt
81
+ if shot_type == 'solution':
82
+ prompt = "Solution: "
83
+ else:
84
+ assert shot_type == 'code'
85
+ prompt = "Python code: "
86
+
87
+ elements = [hint_text, question_text, choices_text]
88
+ test_query = "\n".join([e for e in elements if e != ""])
89
+
90
+ ### [3] Final query
91
+ query = demo_prompt + "\n\n" + test_query
92
+ query = query.strip()
93
+ return query
94
+
95
+
96
+ def eval_model(args):
97
+ # Model
98
+ disable_torch_init()
99
+ model_path = os.path.expanduser(args.model_path)
100
+ model_name = get_model_name_from_path(model_path)
101
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
102
+ load_8bit=args.load_8bit)
103
+
104
+ questions = json.load(open(os.path.expanduser(args.question_file), "r"))
105
+ questions = [dict(pid=pid, info=qs) for pid, qs in questions.items()]
106
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
107
+
108
+ answers_file = os.path.expanduser(args.answers_file)
109
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
110
+
111
+ if os.path.exists(answers_file):
112
+ file = open(answers_file, "r")
113
+ pred_contents = [json.loads(line) for line in file]
114
+ done_pid = [sample['pid'] for sample in pred_contents]
115
+ else:
116
+ done_pid = []
117
+ ans_file = open(answers_file, "a")
118
+
119
+ for i, line in enumerate(tqdm(questions)):
120
+ idx = line['pid']
121
+ info = line['info']
122
+ if idx in done_pid:
123
+ continue
124
+
125
+ qs = create_one_query(
126
+ problem = info,
127
+ shot_num = 0,
128
+ shot_type = 'solution',
129
+ use_caption = False,
130
+ )
131
+ query = qs
132
+
133
+ if 'image' in info:
134
+ image_file = info["image"]
135
+ image = Image.open(os.path.join(args.image_folder, image_file))
136
+
137
+ if hasattr(model.config, 'image_size_aux'):
138
+ if not hasattr(image_processor, 'image_size_raw'):
139
+ image_processor.image_size_raw = image_processor.crop_size.copy()
140
+ image_processor.crop_size['height'] = model.config.image_size_aux
141
+ image_processor.crop_size['width'] = model.config.image_size_aux
142
+ image_processor.size['shortest_edge'] = model.config.image_size_aux
143
+
144
+ image_tensor = process_images([image], image_processor, model.config)[0]
145
+
146
+ image_grid = getattr(model.config, 'image_grid', 1)
147
+ if hasattr(model.config, 'image_size_aux'):
148
+ raw_shape = [image_processor.image_size_raw['height'] * image_grid,
149
+ image_processor.image_size_raw['width'] * image_grid]
150
+ image_tensor_aux = image_tensor
151
+ image_tensor = torch.nn.functional.interpolate(image_tensor[None],
152
+ size=raw_shape,
153
+ mode='bilinear',
154
+ align_corners=False)[0]
155
+ else:
156
+ image_tensor_aux = []
157
+
158
+ if image_grid >= 2:
159
+ raw_image = image_tensor.reshape(3,
160
+ image_grid,
161
+ image_processor.image_size_raw['height'],
162
+ image_grid,
163
+ image_processor.image_size_raw['width'])
164
+ raw_image = raw_image.permute(1, 3, 0, 2, 4)
165
+ raw_image = raw_image.reshape(-1, 3,
166
+ image_processor.image_size_raw['height'],
167
+ image_processor.image_size_raw['width'])
168
+
169
+ if getattr(model.config, 'image_global', False):
170
+ global_image = image_tensor
171
+ if len(global_image.shape) == 3:
172
+ global_image = global_image[None]
173
+ global_image = torch.nn.functional.interpolate(global_image,
174
+ size=[image_processor.image_size_raw['height'],
175
+ image_processor.image_size_raw['width']],
176
+ mode='bilinear',
177
+ align_corners=False)
178
+ # [image_crops, image_global]
179
+ raw_image = torch.cat([raw_image, global_image], dim=0)
180
+ image_tensor = raw_image.contiguous()
181
+
182
+ images = image_tensor[None].to(dtype=model.dtype, device='cuda', non_blocking=True)
183
+ images_aux = image_tensor_aux[None].to(dtype=model.dtype, device='cuda', non_blocking=True) if len(image_tensor_aux)>0 else None
184
+ if getattr(model.config, 'mm_use_im_start_end', False):
185
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
186
+ else:
187
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
188
+ else:
189
+ images = None
190
+ images_aux = None
191
+
192
+ conv = conv_templates[args.conv_mode].copy()
193
+ conv.append_message(conv.roles[0], qs)
194
+ conv.append_message(conv.roles[1], None)
195
+ prompt = conv.get_prompt()
196
+
197
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
198
+
199
+ with torch.inference_mode():
200
+ output_ids = model.generate(
201
+ input_ids,
202
+ images=images,
203
+ images_aux=images_aux,
204
+ do_sample=True if args.temperature > 0 else False,
205
+ temperature=args.temperature,
206
+ max_new_tokens=1024,
207
+ bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
208
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
209
+ pad_token_id=tokenizer.pad_token_id, # Pad token
210
+ use_cache=True,
211
+ )
212
+
213
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
214
+
215
+ info['query'] = query
216
+ info['response'] = outputs
217
+ ans_file.write(json.dumps(info) + "\n")
218
+ ans_file.flush()
219
+ ans_file.close()
220
+
221
+ if __name__ == "__main__":
222
+ parser = argparse.ArgumentParser()
223
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
224
+ parser.add_argument("--model-base", type=str, default=None)
225
+ parser.add_argument("--image-folder", type=str, default="")
226
+ parser.add_argument("--question-file", type=str, default="tables/question.json")
227
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
228
+ parser.add_argument("--conv-mode", type=str, default="llava_v0")
229
+ parser.add_argument("--num-chunks", type=int, default=1)
230
+ parser.add_argument("--chunk-idx", type=int, default=0)
231
+ parser.add_argument("--temperature", type=float, default=0.2)
232
+ parser.add_argument("--answer-prompter", action="store_true")
233
+ parser.add_argument('--load_8bit', type=bool, default=False)
234
+ parser.add_argument("--single-pred-prompt", action="store_true")
235
+ args = parser.parse_args()
236
+
237
+ eval_model(args)
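
For a multiple-choice MathVista problem, create_one_query above assembles roughly the following query (field values are made up):

problem = {
    "question": "What is the value of x?", "unit": "", "choices": ["1", "2", "3"],
    "precision": None, "question_type": "multi_choice", "answer_type": "text",
}
# create_one_query(problem, shot_num=0, shot_type='solution', use_caption=False) ->
# Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
# Question: What is the value of x?
# Choices:
# (A) 1
# (B) 2
# (C) 3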
minigemini/eval/model_qa.py ADDED
@@ -0,0 +1,64 @@
1
+ import argparse
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3
+ import torch
4
+ import os
5
+ import json
6
+ from tqdm import tqdm
7
+ import shortuuid
8
+
9
+ from minigemini.conversation import default_conversation
10
+ from minigemini.utils import disable_torch_init
11
+
12
+
13
+ @torch.inference_mode()
14
+ def eval_model(model_name, questions_file, answers_file):
15
+ # Model
16
+ disable_torch_init()
17
+ model_name = os.path.expanduser(model_name)
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
19
+ model = AutoModelForCausalLM.from_pretrained(model_name,
20
+ torch_dtype=torch.float16).cuda()
21
+
22
+
23
+ ques_file = open(os.path.expanduser(questions_file), "r")
24
+ ans_file = open(os.path.expanduser(answers_file), "w")
25
+ for i, line in enumerate(tqdm(ques_file)):
26
+ idx = json.loads(line)["question_id"]
27
+ qs = json.loads(line)["text"]
28
+ cat = json.loads(line)["category"]
29
+ conv = default_conversation.copy()
30
+ conv.append_message(conv.roles[0], qs)
31
+ prompt = conv.get_prompt()
32
+ inputs = tokenizer([prompt])
33
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
34
+ output_ids = model.generate(
35
+ input_ids,
36
+ do_sample=True,
37
+ use_cache=True,
38
+ temperature=0.7,
39
+ max_new_tokens=1024,)
40
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
41
+ try:
42
+ index = outputs.index(conv.sep, len(prompt))
43
+ except ValueError:
44
+ outputs += conv.sep
45
+ index = outputs.index(conv.sep, len(prompt))
46
+
47
+ outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
48
+ ans_id = shortuuid.uuid()
49
+ ans_file.write(json.dumps({"question_id": idx,
50
+ "text": outputs,
51
+ "answer_id": ans_id,
52
+ "model_id": model_name,
53
+ "metadata": {}}) + "\n")
54
+ ans_file.flush()
55
+ ans_file.close()
56
+
57
+ if __name__ == "__main__":
58
+ parser = argparse.ArgumentParser()
59
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
60
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
61
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
62
+ args = parser.parse_args()
63
+
64
+ eval_model(args.model_name, args.question_file, args.answers_file)
minigemini/eval/model_vqa.py ADDED
@@ -0,0 +1,154 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from minigemini.conversation import conv_templates, SeparatorStyle
10
+ from minigemini.model.builder import load_pretrained_model
11
+ from minigemini.utils import disable_torch_init
12
+ from minigemini.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n) # ceiling division
21
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ def eval_model(args):
30
+ # Model
31
+ disable_torch_init()
32
+ model_path = os.path.expanduser(args.model_path)
33
+ model_name = get_model_name_from_path(model_path)
34
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
+
36
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38
+ answers_file = os.path.expanduser(args.answers_file)
39
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40
+ ans_file = open(answers_file, "w")
41
+ for line in tqdm(questions):
42
+ idx = line["question_id"]
43
+ image_file = line["image"]
44
+ qs = line["text"]
45
+ cur_prompt = qs
46
+
47
+ if hasattr(model, "update_prompt"):
48
+ model.update_prompt([[cur_prompt]])
49
+
50
+ if model.config.mm_use_im_start_end:
51
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
52
+ else:
53
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
54
+
55
+ conv = conv_templates[args.conv_mode].copy()
56
+ conv.append_message(conv.roles[0], qs)
57
+ conv.append_message(conv.roles[1], None)
58
+ prompt = conv.get_prompt()
59
+
60
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
61
+
62
+ image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
63
+
64
+ if hasattr(model.config, 'image_size_aux'):
65
+ if not hasattr(image_processor, 'image_size_raw'):
66
+ image_processor.image_size_raw = image_processor.crop_size.copy()
67
+ image_processor.crop_size['height'] = model.config.image_size_aux
68
+ image_processor.crop_size['width'] = model.config.image_size_aux
69
+ image_processor.size['shortest_edge'] = model.config.image_size_aux
70
+
71
+ image_tensor = process_images([image], image_processor, model.config)[0]
72
+
73
+ image_grid = getattr(model.config, 'image_grid', 1)
74
+ if hasattr(model.config, 'image_size_aux'):
75
+ raw_shape = [image_processor.image_size_raw['height'] * image_grid,
76
+ image_processor.image_size_raw['width'] * image_grid]
77
+ image_tensor_aux = image_tensor
78
+ image_tensor = torch.nn.functional.interpolate(image_tensor[None],
79
+ size=raw_shape,
80
+ mode='bilinear',
81
+ align_corners=False)[0]
82
+ else:
83
+ image_tensor_aux = []
84
+
85
+ if image_grid >= 2:
86
+ raw_image = image_tensor.reshape(3,
87
+ image_grid,
88
+ image_processor.image_size_raw['height'],
89
+ image_grid,
90
+ image_processor.image_size_raw['width'])
91
+ raw_image = raw_image.permute(1, 3, 0, 2, 4)
92
+ raw_image = raw_image.reshape(-1, 3,
93
+ image_processor.image_size_raw['height'],
94
+ image_processor.image_size_raw['width'])
95
+
96
+ if getattr(model.config, 'image_global', False):
97
+ global_image = image_tensor
98
+ if len(global_image.shape) == 3:
99
+ global_image = global_image[None]
100
+ global_image = torch.nn.functional.interpolate(global_image,
101
+ size=[image_processor.image_size_raw['height'],
102
+ image_processor.image_size_raw['width']],
103
+ mode='bilinear',
104
+ align_corners=False)
105
+ # [image_crops, image_global]
106
+ raw_image = torch.cat([raw_image, global_image], dim=0)
107
+ image_tensor = raw_image.contiguous()
108
+
109
+ images = image_tensor[None].to(dtype=model.dtype, device='cuda', non_blocking=True)
110
+ images_aux = image_tensor_aux[None].to(dtype=model.dtype, device='cuda', non_blocking=True) if len(image_tensor_aux)>0 else None
111
+
112
+ with torch.inference_mode():
113
+ output_ids = model.generate(
114
+ input_ids,
115
+ images=images,
116
+ images_aux=images_aux,
117
+ do_sample=True if args.temperature > 0 else False,
118
+ temperature=args.temperature,
119
+ top_p=args.top_p,
120
+ num_beams=args.num_beams,
121
+ max_new_tokens=1024,
122
+ bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
123
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
124
+ pad_token_id=tokenizer.pad_token_id, # Pad token
125
+ use_cache=True)
126
+
127
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
128
+
129
+ ans_id = shortuuid.uuid()
130
+ ans_file.write(json.dumps({"question_id": idx,
131
+ "prompt": cur_prompt,
132
+ "text": outputs,
133
+ "answer_id": ans_id,
134
+ "model_id": model_name,
135
+ "metadata": {}}) + "\n")
136
+ ans_file.flush()
137
+ ans_file.close()
138
+
139
+ if __name__ == "__main__":
140
+ parser = argparse.ArgumentParser()
141
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
142
+ parser.add_argument("--model-base", type=str, default=None)
143
+ parser.add_argument("--image-folder", type=str, default="")
144
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
145
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
146
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
147
+ parser.add_argument("--num-chunks", type=int, default=1)
148
+ parser.add_argument("--chunk-idx", type=int, default=0)
149
+ parser.add_argument("--temperature", type=float, default=0.2)
150
+ parser.add_argument("--top_p", type=float, default=None)
151
+ parser.add_argument("--num_beams", type=int, default=1)
152
+ args = parser.parse_args()
153
+
154
+ eval_model(args)
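
A minimal standalone check of the grid-splitting reshape used above: with image_grid = 2, a (3, 2H, 2W) tensor becomes four (3, H, W) crops in row-major order.

import torch

H = W = 4
image_grid = 2
img = torch.arange(3 * 2 * H * 2 * W, dtype=torch.float32).reshape(3, 2 * H, 2 * W)
raw = (img.reshape(3, image_grid, H, image_grid, W)
          .permute(1, 3, 0, 2, 4)
          .reshape(-1, 3, H, W))
assert raw.shape == (4, 3, H, W)
assert torch.equal(raw[0], img[:, :H, :W])   # top-left crop
assert torch.equal(raw[1], img[:, :H, W:])   # top-right crop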
minigemini/eval/model_vqa_loader.py ADDED
@@ -0,0 +1,187 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from minigemini.conversation import conv_templates, SeparatorStyle
10
+ from minigemini.model.builder import load_pretrained_model
11
+ from minigemini.utils import disable_torch_init
12
+ from minigemini.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+ from torch.utils.data import Dataset, DataLoader
14
+
15
+ from PIL import Image
16
+ import math
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n) # ceiling division
21
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ # Custom dataset class
30
+ class CustomDataset(Dataset):
31
+ def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
32
+ self.questions = questions
33
+ self.image_folder = image_folder
34
+ self.tokenizer = tokenizer
35
+ self.image_processor = image_processor
36
+ self.model_config = model_config
37
+
38
+ def __getitem__(self, index):
39
+ line = self.questions[index]
40
+ image_file = line["image"]
41
+ qs = line["text"]
42
+
43
+ if self.model_config.mm_use_im_start_end:
44
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
45
+ else:
46
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
47
+
48
+ conv = conv_templates[args.conv_mode].copy()
49
+ conv.append_message(conv.roles[0], qs)
50
+ conv.append_message(conv.roles[1], None)
51
+ prompt = conv.get_prompt()
52
+
53
+ image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
54
+
55
+ if hasattr(self.model_config, 'image_size_aux'):
56
+ if not hasattr(self.image_processor, 'image_size_raw'):
57
+ self.image_processor.image_size_raw = self.image_processor.crop_size.copy()
58
+ self.image_processor.crop_size['height'] = self.model_config.image_size_aux
59
+ self.image_processor.crop_size['width'] = self.model_config.image_size_aux
60
+ self.image_processor.size['shortest_edge'] = self.model_config.image_size_aux
61
+
62
+ image_tensor = process_images([image], self.image_processor, self.model_config)[0]
63
+
64
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
65
+
66
+ image_grid = getattr(self.model_config, 'image_grid', 1)
67
+ if hasattr(self.model_config, 'image_size_aux'):
68
+ raw_shape = [self.image_processor.image_size_raw['height'] * image_grid,
69
+ self.image_processor.image_size_raw['width'] * image_grid]
70
+ image_tensor_aux = image_tensor
71
+ image_tensor = torch.nn.functional.interpolate(image_tensor[None],
72
+ size=raw_shape,
73
+ mode='bilinear',
74
+ align_corners=False)[0]
75
+ else:
76
+ image_tensor_aux = []
77
+
78
+ if image_grid >= 2:
79
+ raw_image = image_tensor.reshape(3,
80
+ image_grid,
81
+ self.image_processor.image_size_raw['height'],
82
+ image_grid,
83
+ self.image_processor.image_size_raw['width'])
84
+ raw_image = raw_image.permute(1, 3, 0, 2, 4)
85
+ raw_image = raw_image.reshape(-1, 3,
86
+ self.image_processor.image_size_raw['height'],
87
+ self.image_processor.image_size_raw['width'])
88
+
89
+ if getattr(self.model_config, 'image_global', False):
90
+ global_image = image_tensor
91
+ if len(global_image.shape) == 3:
92
+ global_image = global_image[None]
93
+ global_image = torch.nn.functional.interpolate(global_image,
94
+ size=[self.image_processor.image_size_raw['height'],
95
+ self.image_processor.image_size_raw['width']],
96
+ mode='bilinear',
97
+ align_corners=False)
98
+ # [image_crops, image_global]
99
+ raw_image = torch.cat([raw_image, global_image], dim=0)
100
+ image_tensor = raw_image.contiguous()
101
+
102
+ return input_ids, image_tensor, image_tensor_aux
103
+
104
+ def __len__(self):
105
+ return len(self.questions)
106
+
107
+
108
+ # DataLoader
109
+ def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
110
+ assert batch_size == 1, "batch_size must be 1"
111
+ dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
112
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
113
+ return data_loader
114
+
115
+
116
+ def eval_model(args):
117
+ # Model
118
+ disable_torch_init()
119
+ model_path = os.path.expanduser(args.model_path)
120
+ model_name = get_model_name_from_path(model_path)
121
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, load_8bit=args.load_8bit)
122
+
123
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
124
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
125
+ answers_file = os.path.expanduser(args.answers_file)
126
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
127
+ ans_file = open(answers_file, "w")
128
+
129
+ if 'plain' in args.conv_mode and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
130
+ args.conv_mode = args.conv_mode + '_mmtag'
131
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
132
+
133
+ data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
134
+
135
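+ # batch_size is fixed to 1 and shuffling is off, so loader outputs stay aligned with the question list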
+ for (input_ids, image_tensor, image_tensor_aux), line in tqdm(zip(data_loader, questions), total=len(questions)):
136
+ idx = line["question_id"]
137
+ cur_prompt = line["text"]
138
+
139
+ input_ids = input_ids.to(device=model.device, non_blocking=True)
140
+ if hasattr(model, "update_prompt"):
141
+ model.update_prompt([[cur_prompt]])
142
+
143
+ with torch.inference_mode():
144
+ output_ids = model.generate(
145
+ input_ids,
146
+ images=image_tensor.to(dtype=model.dtype, device=model.device, non_blocking=True),
147
+ images_aux=image_tensor_aux.to(dtype=model.dtype, device=model.device, non_blocking=True) if len(image_tensor_aux)>0 else None,
148
+ do_sample=True if args.temperature > 0 else False,
149
+ temperature=args.temperature,
150
+ top_p=args.top_p,
151
+ num_beams=args.num_beams,
152
+ max_new_tokens=args.max_new_tokens,
153
+ bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
154
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
155
+ pad_token_id=tokenizer.pad_token_id, # Pad token
156
+ use_cache=True)
157
+
158
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
159
+
160
+ ans_id = shortuuid.uuid()
161
+ ans_file.write(json.dumps({"question_id": idx,
162
+ "prompt": cur_prompt,
163
+ "text": outputs,
164
+ "answer_id": ans_id,
165
+ "model_id": model_name,
166
+ "metadata": {}}) + "\n")
167
+ # ans_file.flush()
168
+ ans_file.close()
169
+
170
+ if __name__ == "__main__":
171
+ parser = argparse.ArgumentParser()
172
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
173
+ parser.add_argument("--model-base", type=str, default=None)
174
+ parser.add_argument("--image-folder", type=str, default="")
175
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
176
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
177
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
178
+ parser.add_argument("--num-chunks", type=int, default=1)
179
+ parser.add_argument("--chunk-idx", type=int, default=0)
180
+ parser.add_argument("--temperature", type=float, default=0.2)
181
+ parser.add_argument("--top_p", type=float, default=None)
182
+ parser.add_argument("--num_beams", type=int, default=1)
183
+ parser.add_argument('--load_8bit', type=bool, default=False)
184
+ parser.add_argument("--max_new_tokens", type=int, default=128)
185
+ args = parser.parse_args()
186
+
187
+ eval_model(args)
minigemini/eval/model_vqa_mmbench.py ADDED
@@ -0,0 +1,212 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ import shortuuid
8
+
9
+ from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10
+ from minigemini.conversation import conv_templates, SeparatorStyle
11
+ from minigemini.model.builder import load_pretrained_model
12
+ from minigemini.utils import disable_torch_init
13
+ from minigemini.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
14
+
15
+ from PIL import Image
16
+ import math
17
+
18
+
19
+ all_options = ['A', 'B', 'C', 'D']
20
+
21
+
22
+ def split_list(lst, n):
23
+ """Split a list into n (roughly) equal-sized chunks"""
24
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is assigned to a chunk
25
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
26
+
27
+
28
+ def get_chunk(lst, n, k):
29
+ chunks = split_list(lst, n)
30
+ return chunks[k]
31
+
32
+
33
+ def is_none(value):
34
+ if value is None:
35
+ return True
36
+ if type(value) is float and math.isnan(value):
37
+ return True
38
+ if type(value) is str and value.lower() == 'nan':
39
+ return True
40
+ if type(value) is str and value.lower() == 'none':
41
+ return True
42
+ return False
43
+
44
+ def get_options(row, options):
45
+ parsed_options = []
46
+ for option in options:
47
+ option_value = row[option]
48
+ if is_none(option_value):
49
+ break
50
+ parsed_options.append(option_value)
51
+ return parsed_options
52
+
53
+
54
+ def eval_model(args):
55
+ # Model
56
+ disable_torch_init()
57
+ model_path = os.path.expanduser(args.model_path)
58
+ model_name = get_model_name_from_path(model_path)
59
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
60
+
61
+ questions = pd.read_table(os.path.expanduser(args.question_file))
62
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
63
+ answers_file = os.path.expanduser(args.answers_file)
64
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
65
+ ans_file = open(answers_file, "w")
66
+
67
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
68
+ args.conv_mode = args.conv_mode + '_mmtag'
69
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
70
+
71
+ for index, row in tqdm(questions.iterrows(), total=len(questions)):
72
+ options = get_options(row, all_options)
73
+ cur_option_char = all_options[:len(options)]
74
+
75
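+ # with --all-rounds, each question is evaluated once per option, with the choices rotated between rounds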
+ if args.all_rounds:
76
+ num_rounds = len(options)
77
+ else:
78
+ num_rounds = 1
79
+
80
+ for round_idx in range(num_rounds):
81
+ idx = row['index']
82
+ question = row['question']
83
+ hint = row['hint']
84
+ image = load_image_from_base64(row['image'])
85
+ if not is_none(hint):
86
+ question = hint + '\n' + question
87
+ for option_char, option in zip(all_options[:len(options)], options):
88
+ question = question + '\n' + option_char + '. ' + option
89
+ qs = cur_prompt = question
90
+
91
+ if hasattr(model, "update_prompt"):
92
+ model.update_prompt([[cur_prompt]])
93
+
94
+ if model.config.mm_use_im_start_end:
95
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
96
+ else:
97
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
98
+
99
+ if args.single_pred_prompt:
100
+ if args.lang == 'cn':
101
+ qs = qs + '\n' + "请直接回答选项字母。"
102
+ else:
103
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
104
+
105
+ conv = conv_templates[args.conv_mode].copy()
106
+ conv.append_message(conv.roles[0], qs)
107
+ conv.append_message(conv.roles[1], None)
108
+ prompt = conv.get_prompt()
109
+
110
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
111
+
112
+ if hasattr(model.config, 'image_size_aux'):
113
+ if not hasattr(image_processor, 'image_size_raw'):
114
+ image_processor.image_size_raw = image_processor.crop_size.copy()
115
+ image_processor.crop_size['height'] = model.config.image_size_aux
116
+ image_processor.crop_size['width'] = model.config.image_size_aux
117
+ image_processor.size['shortest_edge'] = model.config.image_size_aux
118
+
119
+ image_tensor = process_images([image], image_processor, model.config)[0]
120
+ image_grid = getattr(model.config, 'image_grid', 1)
121
+ if hasattr(model.config, 'image_size_aux'):
122
+ raw_shape = [image_processor.image_size_raw['height'] * image_grid,
123
+ image_processor.image_size_raw['width'] * image_grid]
124
+ image_tensor_aux = image_tensor
125
+ image_tensor = torch.nn.functional.interpolate(image_tensor[None],
126
+ size=raw_shape,
127
+ mode='bilinear',
128
+ align_corners=False)[0]
129
+ else:
130
+ image_tensor_aux = []
131
+
132
+ if image_grid >= 2:
133
+ raw_image = image_tensor.reshape(3,
134
+ image_grid,
135
+ image_processor.image_size_raw['height'],
136
+ image_grid,
137
+ image_processor.image_size_raw['width'])
138
+ raw_image = raw_image.permute(1, 3, 0, 2, 4)
139
+ raw_image = raw_image.reshape(-1, 3,
140
+ image_processor.image_size_raw['height'],
141
+ image_processor.image_size_raw['width'])
142
+
143
+ if getattr(model.config, 'image_global', False):
144
+ global_image = image_tensor
145
+ if len(global_image.shape) == 3:
146
+ global_image = global_image[None]
147
+ global_image = torch.nn.functional.interpolate(global_image,
148
+ size=[image_processor.image_size_raw['height'],
149
+ image_processor.image_size_raw['width']],
150
+ mode='bilinear',
151
+ align_corners=False)
152
+ # [image_crops, image_global]
153
+ raw_image = torch.cat([raw_image, global_image], dim=0)
154
+ image_tensor = raw_image.contiguous()
155
+
156
+ images = image_tensor[None].to(dtype=model.dtype, device='cuda', non_blocking=True)
157
+ images_aux = image_tensor_aux[None].to(dtype=model.dtype, device='cuda', non_blocking=True) if len(image_tensor_aux)>0 else None
158
+
159
+ with torch.inference_mode():
160
+ output_ids = model.generate(
161
+ input_ids,
162
+ images=images,
163
+ images_aux=images_aux,
164
+ do_sample=True if args.temperature > 0 else False,
165
+ temperature=args.temperature,
166
+ top_p=args.top_p,
167
+ num_beams=args.num_beams,
168
+ # no_repeat_ngram_size=3,
169
+ max_new_tokens=1024,
170
+ bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
171
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
172
+ pad_token_id=tokenizer.pad_token_id, # Pad token
173
+ use_cache=True)
174
+
175
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
176
+
177
+ ans_id = shortuuid.uuid()
178
+ ans_file.write(json.dumps({"question_id": idx,
179
+ "round_id": round_idx,
180
+ "prompt": cur_prompt,
181
+ "text": outputs,
182
+ "options": options,
183
+ "option_char": cur_option_char,
184
+ "answer_id": ans_id,
185
+ "model_id": model_name,
186
+ "metadata": {}}) + "\n")
187
+ ans_file.flush()
188
+
189
+ # rotate options
190
+ options = options[1:] + options[:1]
191
+ cur_option_char = cur_option_char[1:] + cur_option_char[:1]
192
+ ans_file.close()
193
+
194
+ if __name__ == "__main__":
195
+ parser = argparse.ArgumentParser()
196
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
197
+ parser.add_argument("--model-base", type=str, default=None)
198
+ parser.add_argument("--image-folder", type=str, default="")
199
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
200
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
201
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
202
+ parser.add_argument("--num-chunks", type=int, default=1)
203
+ parser.add_argument("--chunk-idx", type=int, default=0)
204
+ parser.add_argument("--temperature", type=float, default=0.2)
205
+ parser.add_argument("--top_p", type=float, default=None)
206
+ parser.add_argument("--num_beams", type=int, default=1)
207
+ parser.add_argument("--all-rounds", action="store_true")
208
+ parser.add_argument("--single-pred-prompt", action="store_true")
209
+ parser.add_argument("--lang", type=str, default="en")
210
+ args = parser.parse_args()
211
+
212
+ eval_model(args)
minigemini/eval/model_vqa_qbench.py ADDED
@@ -0,0 +1,122 @@
1
+ import argparse
2
+ import torch
3
+ from tqdm import tqdm
4
+ import json
5
+
6
+ from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
7
+ from minigemini.conversation import conv_templates, SeparatorStyle
8
+ from minigemini.model.builder import load_pretrained_model
9
+ from minigemini.utils import disable_torch_init
10
+ from minigemini.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
11
+
12
+ from PIL import Image
13
+
14
+ import requests
15
+ from PIL import Image
16
+ from io import BytesIO
17
+
18
+
19
+ def load_image(image_file):
20
+ if image_file.startswith('http') or image_file.startswith('https'):
21
+ response = requests.get(image_file)
22
+ image = Image.open(BytesIO(response.content)).convert('RGB')
23
+ else:
24
+ image = Image.open(image_file).convert('RGB')
25
+ return image
26
+
27
+
28
+ def eval_model(args):
29
+ # Model
30
+ disable_torch_init()
31
+
32
+ model_name = get_model_name_from_path(args.model_path)
33
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True)
34
+
35
+
36
+
37
+
38
+ with open(args.questions_file) as f:
39
+ llvqa_data = json.load(f)
40
+
41
+ for i, llddata in enumerate(tqdm(llvqa_data)):
42
+ filename = llddata["img_path"]
43
+ if args.lang == "en":
44
+ message = llddata["question"] + "\nChoose between one of the options as follows:\n"
45
+ elif args.lang == "zh":
46
+ message = llddata["question"] + "\n在下列选项中选择一个:\n"
47
+ else:
48
+ raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")
49
+ for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
50
+ message += f"{choice} {ans}\n"
51
+ qs = message
52
+
53
+ if model.config.mm_use_im_start_end:
54
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
55
+ else:
56
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
57
+
58
+ if 'llama-2' in model_name.lower():
59
+ conv_mode = "llava_llama_2"
60
+ elif "v1" in model_name.lower():
61
+ conv_mode = "llava_v1"
62
+ elif "mpt" in model_name.lower():
63
+ conv_mode = "mpt"
64
+ else:
65
+ conv_mode = "llava_v0"
66
+
67
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
68
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
69
+ else:
70
+ args.conv_mode = conv_mode
71
+
72
+ conv = conv_templates[args.conv_mode].copy()
73
+ conv.append_message(conv.roles[0], qs)
74
+ conv.append_message(conv.roles[1], None)
75
+ prompt = conv.get_prompt()
76
+
77
+ image = load_image(args.image_folder + filename)
78
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
79
+
80
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
81
+
82
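+ # stop generation once the conversation separator appears in the decoded output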
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
83
+ keywords = [stop_str]
84
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
85
+
86
+
87
+ with torch.inference_mode():
88
+ output_ids = model.generate(
89
+ input_ids,
90
+ images=image_tensor,
91
+ num_beams=1,
92
+ do_sample=False,
93
+ temperature=0,
94
+ max_new_tokens=1024,
95
+ use_cache=True,
96
+ stopping_criteria=[stopping_criteria])
97
+
98
+ input_token_len = input_ids.shape[1]
99
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
100
+ if n_diff_input_output > 0:
101
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
102
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
103
+ outputs = outputs.strip()
104
+ if outputs.endswith(stop_str):
105
+ outputs = outputs[:-len(stop_str)]
106
+ outputs = outputs.strip()
107
+ llddata["response"] = outputs
108
+ with open(args.answers_file, "a") as wf:
109
+ json.dump(llddata, wf)
110
+
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser()
113
+ parser.add_argument("--model-path", type=str, default="llava-v1.5")
114
+ parser.add_argument("--model-base", type=str, default=None)
115
+ parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa")
116
+ parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json")
117
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
118
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
119
+ parser.add_argument("--lang", type=str, default="en")
120
+ args = parser.parse_args()
121
+
122
+ eval_model(args)
minigemini/eval/model_vqa_science.py ADDED
@@ -0,0 +1,162 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from minigemini.conversation import conv_templates, SeparatorStyle
10
+ from minigemini.model.builder import load_pretrained_model
11
+ from minigemini.utils import disable_torch_init
12
+ from minigemini.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+ def split_list(lst, n):
18
+ """Split a list into n (roughly) equal-sized chunks"""
19
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is assigned to a chunk
20
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
21
+
22
+
23
+ def get_chunk(lst, n, k):
24
+ chunks = split_list(lst, n)
25
+ return chunks[k]
26
+
27
+
28
+ def eval_model(args):
29
+ # Model
30
+ disable_torch_init()
31
+ model_path = os.path.expanduser(args.model_path)
32
+ model_name = get_model_name_from_path(model_path)
33
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
34
+
35
+ questions = json.load(open(os.path.expanduser(args.question_file), "r"))
36
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
37
+ answers_file = os.path.expanduser(args.answers_file)
38
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
39
+ ans_file = open(answers_file, "w")
40
+
41
+ for i, line in enumerate(tqdm(questions)):
42
+ idx = line["id"]
43
+ question = line['conversations'][0]
44
+ qs = question['value'].replace('<image>', '').strip()
45
+ cur_prompt = qs
46
+
47
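+ # questions without an 'image' entry are treated as text-only, leaving images and images_aux as None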
+ if 'image' in line:
48
+ image_file = line["image"]
49
+ image = Image.open(os.path.join(args.image_folder, image_file))
50
+
51
+ if hasattr(model.config, 'image_size_aux'):
52
+ if not hasattr(image_processor, 'image_size_raw'):
53
+ image_processor.image_size_raw = image_processor.crop_size.copy()
54
+ image_processor.crop_size['height'] = model.config.image_size_aux
55
+ image_processor.crop_size['width'] = model.config.image_size_aux
56
+ image_processor.size['shortest_edge'] = model.config.image_size_aux
57
+
58
+ image_tensor = process_images([image], image_processor, model.config)[0]
59
+
60
+ image_grid = getattr(model.config, 'image_grid', 1)
61
+ if hasattr(model.config, 'image_size_aux'):
62
+ raw_shape = [image_processor.image_size_raw['height'] * image_grid,
63
+ image_processor.image_size_raw['width'] * image_grid]
64
+ image_tensor_aux = image_tensor
65
+ image_tensor = torch.nn.functional.interpolate(image_tensor[None],
66
+ size=raw_shape,
67
+ mode='bilinear',
68
+ align_corners=False)[0]
69
+ else:
70
+ image_tensor_aux = []
71
+
72
+ if image_grid >= 2:
73
+ raw_image = image_tensor.reshape(3,
74
+ image_grid,
75
+ image_processor.image_size_raw['height'],
76
+ image_grid,
77
+ image_processor.image_size_raw['width'])
78
+ raw_image = raw_image.permute(1, 3, 0, 2, 4)
79
+ raw_image = raw_image.reshape(-1, 3,
80
+ image_processor.image_size_raw['height'],
81
+ image_processor.image_size_raw['width'])
82
+
83
+ if getattr(model.config, 'image_global', False):
84
+ global_image = image_tensor
85
+ if len(global_image.shape) == 3:
86
+ global_image = global_image[None]
87
+ global_image = torch.nn.functional.interpolate(global_image,
88
+ size=[image_processor.image_size_raw['height'],
89
+ image_processor.image_size_raw['width']],
90
+ mode='bilinear',
91
+ align_corners=False)
92
+ # [image_crops, image_global]
93
+ raw_image = torch.cat([raw_image, global_image], dim=0)
94
+ image_tensor = raw_image.contiguous()
95
+
96
+ images = image_tensor[None].to(dtype=model.dtype, device='cuda', non_blocking=True)
97
+ images_aux = image_tensor_aux[None].to(dtype=model.dtype, device='cuda', non_blocking=True) if len(image_tensor_aux)>0 else None
98
+ if getattr(model.config, 'mm_use_im_start_end', False):
99
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
100
+ else:
101
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
102
+ cur_prompt = '<image>' + '\n' + cur_prompt
103
+ else:
104
+ images = None
105
+ images_aux = None
106
+
107
+ if args.single_pred_prompt:
108
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
109
+ cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
110
+
111
+ conv = conv_templates[args.conv_mode].copy()
112
+ conv.append_message(conv.roles[0], qs)
113
+ conv.append_message(conv.roles[1], None)
114
+ prompt = conv.get_prompt()
115
+
116
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
117
+
118
+ if hasattr(model, "update_prompt"):
119
+ model.update_prompt([[cur_prompt]])
120
+
121
+ with torch.inference_mode():
122
+ output_ids = model.generate(
123
+ input_ids,
124
+ images=images,
125
+ images_aux=images_aux,
126
+ do_sample=True if args.temperature > 0 else False,
127
+ temperature=args.temperature,
128
+ max_new_tokens=1024,
129
+ bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
130
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
131
+ pad_token_id=tokenizer.pad_token_id, # Pad token
132
+ use_cache=True,
133
+ )
134
+
135
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
136
+
137
+ ans_id = shortuuid.uuid()
138
+ ans_file.write(json.dumps({"question_id": idx,
139
+ "prompt": cur_prompt,
140
+ "text": outputs,
141
+ "answer_id": ans_id,
142
+ "model_id": model_name,
143
+ "metadata": {}}) + "\n")
144
+ ans_file.flush()
145
+ ans_file.close()
146
+
147
+ if __name__ == "__main__":
148
+ parser = argparse.ArgumentParser()
149
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
150
+ parser.add_argument("--model-base", type=str, default=None)
151
+ parser.add_argument("--image-folder", type=str, default="")
152
+ parser.add_argument("--question-file", type=str, default="tables/question.json")
153
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
154
+ parser.add_argument("--conv-mode", type=str, default="llava_v0")
155
+ parser.add_argument("--num-chunks", type=int, default=1)
156
+ parser.add_argument("--chunk-idx", type=int, default=0)
157
+ parser.add_argument("--temperature", type=float, default=0.2)
158
+ parser.add_argument("--answer-prompter", action="store_true")
159
+ parser.add_argument("--single-pred-prompt", action="store_true")
160
+ args = parser.parse_args()
161
+
162
+ eval_model(args)
minigemini/eval/qa_baseline_gpt35.py ADDED
@@ -0,0 +1,74 @@
1
+ """Generate answers with GPT-3.5"""
2
+ # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3
+ import argparse
4
+ import json
5
+ import os
6
+ import time
7
+ import concurrent.futures
8
+
9
+ import openai
10
+ import tqdm
11
+ import shortuuid
12
+
13
+ MODEL = 'gpt-3.5-turbo'
14
+ MODEL_ID = 'gpt-3.5-turbo:20230327'
15
+
16
+ def get_answer(question_id: int, question: str, max_tokens: int):
17
+ ans = {
18
+ 'answer_id': shortuuid.uuid(),
19
+ 'question_id': question_id,
20
+ 'model_id': MODEL_ID,
21
+ }
22
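+ # retry the API call up to three times, returning an '#ERROR#' answer if all attempts fail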
+ for _ in range(3):
23
+ try:
24
+ response = openai.ChatCompletion.create(
25
+ model=MODEL,
26
+ messages=[{
27
+ 'role': 'system',
28
+ 'content': 'You are a helpful assistant.'
29
+ }, {
30
+ 'role': 'user',
31
+ 'content': question,
32
+ }],
33
+ max_tokens=max_tokens,
34
+ )
35
+ ans['text'] = response['choices'][0]['message']['content']
36
+ return ans
37
+ except Exception as e:
38
+ print('[ERROR]', e)
39
+ ans['text'] = '#ERROR#'
40
+ time.sleep(1)
41
+ return ans
42
+
43
+
44
+ if __name__ == '__main__':
45
+ parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46
+ parser.add_argument('-q', '--question')
47
+ parser.add_argument('-o', '--output')
48
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49
+ args = parser.parse_args()
50
+
51
+ questions_dict = {}
52
+ with open(os.path.expanduser(args.question)) as f:
53
+ for line in f:
54
+ if not line:
55
+ continue
56
+ q = json.loads(line)
57
+ questions_dict[q['question_id']] = q['text']
58
+
59
+ answers = []
60
+
61
+ with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62
+ futures = []
63
+ for qid, question in questions_dict.items():
64
+ future = executor.submit(get_answer, qid, question, args.max_tokens)
65
+ futures.append(future)
66
+
67
+ for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68
+ answers.append(future.result())
69
+
70
+ answers.sort(key=lambda x: x['question_id'])
71
+
72
+ with open(os.path.expanduser(args.output), 'w') as f:
73
+ table = [json.dumps(ans) for ans in answers]
74
+ f.write('\n'.join(table))
minigemini/eval/run_llava.py ADDED
@@ -0,0 +1,143 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from minigemini.constants import (
5
+ IMAGE_TOKEN_INDEX,
6
+ DEFAULT_IMAGE_TOKEN,
7
+ DEFAULT_IM_START_TOKEN,
8
+ DEFAULT_IM_END_TOKEN,
9
+ IMAGE_PLACEHOLDER,
10
+ )
11
+ from minigemini.conversation import conv_templates, SeparatorStyle
12
+ from minigemini.model.builder import load_pretrained_model
13
+ from minigemini.utils import disable_torch_init
14
+ from minigemini.mm_utils import (
15
+ process_images,
16
+ tokenizer_image_token,
17
+ get_model_name_from_path,
18
+ )
19
+
20
+ from PIL import Image
21
+
22
+ import requests
23
+ from PIL import Image
24
+ from io import BytesIO
25
+ import re
26
+
27
+
28
+ def image_parser(args):
29
+ out = args.image_file.split(args.sep)
30
+ return out
31
+
32
+
33
+ def load_image(image_file):
34
+ if image_file.startswith("http") or image_file.startswith("https"):
35
+ response = requests.get(image_file)
36
+ image = Image.open(BytesIO(response.content)).convert("RGB")
37
+ else:
38
+ image = Image.open(image_file).convert("RGB")
39
+ return image
40
+
41
+
42
+ def load_images(image_files):
43
+ out = []
44
+ for image_file in image_files:
45
+ image = load_image(image_file)
46
+ out.append(image)
47
+ return out
48
+
49
+
50
+ def eval_model(args):
51
+ # Model
52
+ disable_torch_init()
53
+
54
+ model_name = get_model_name_from_path(args.model_path)
55
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
56
+ args.model_path, args.model_base, model_name
57
+ )
58
+
59
+ qs = args.query
60
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
61
+ if IMAGE_PLACEHOLDER in qs:
62
+ if model.config.mm_use_im_start_end:
63
+ qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
64
+ else:
65
+ qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
66
+ else:
67
+ if model.config.mm_use_im_start_end:
68
+ qs = image_token_se + "\n" + qs
69
+ else:
70
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
71
+
72
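+ # infer the conversation template from the model name; an explicit --conv-mode overrides the guess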
+ if "llama-2" in model_name.lower():
73
+ conv_mode = "llava_llama_2"
74
+ elif "mistral" in model_name.lower():
75
+ conv_mode = "mistral_instruct"
76
+ elif "v1.6-34b" in model_name.lower():
77
+ conv_mode = "chatml_direct"
78
+ elif "v1" in model_name.lower():
79
+ conv_mode = "llava_v1"
80
+ elif "mpt" in model_name.lower():
81
+ conv_mode = "mpt"
82
+ else:
83
+ conv_mode = "llava_v0"
84
+
85
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
86
+ print(
87
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
88
+ conv_mode, args.conv_mode, args.conv_mode
89
+ )
90
+ )
91
+ else:
92
+ args.conv_mode = conv_mode
93
+
94
+ conv = conv_templates[args.conv_mode].copy()
95
+ conv.append_message(conv.roles[0], qs)
96
+ conv.append_message(conv.roles[1], None)
97
+ prompt = conv.get_prompt()
98
+
99
+ image_files = image_parser(args)
100
+ images = load_images(image_files)
101
+ images_tensor = process_images(
102
+ images,
103
+ image_processor,
104
+ model.config
105
+ ).to(model.device, dtype=torch.float16)
106
+
107
+ input_ids = (
108
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
109
+ .unsqueeze(0)
110
+ .cuda()
111
+ )
112
+
113
+ with torch.inference_mode():
114
+ output_ids = model.generate(
115
+ input_ids,
116
+ images=images_tensor,
117
+ do_sample=True if args.temperature > 0 else False,
118
+ temperature=args.temperature,
119
+ top_p=args.top_p,
120
+ num_beams=args.num_beams,
121
+ max_new_tokens=args.max_new_tokens,
122
+ use_cache=True,
123
+ )
124
+
125
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
126
+ print(outputs)
127
+
128
+
129
+ if __name__ == "__main__":
130
+ parser = argparse.ArgumentParser()
131
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
132
+ parser.add_argument("--model-base", type=str, default=None)
133
+ parser.add_argument("--image-file", type=str, required=True)
134
+ parser.add_argument("--query", type=str, required=True)
135
+ parser.add_argument("--conv-mode", type=str, default=None)
136
+ parser.add_argument("--sep", type=str, default=",")
137
+ parser.add_argument("--temperature", type=float, default=0.2)
138
+ parser.add_argument("--top_p", type=float, default=None)
139
+ parser.add_argument("--num_beams", type=int, default=1)
140
+ parser.add_argument("--max_new_tokens", type=int, default=512)
141
+ args = parser.parse_args()
142
+
143
+ eval_model(args)
minigemini/eval/summarize_gpt_review.py ADDED
@@ -0,0 +1,60 @@
1
+ import json
2
+ import os
3
+ from collections import defaultdict
4
+
5
+ import numpy as np
6
+
7
+ import argparse
8
+
9
+ def parse_args():
10
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11
+ parser.add_argument('-d', '--dir', default=None)
12
+ parser.add_argument('-v', '--version', default=None)
13
+ parser.add_argument('-s', '--select', nargs='*', default=None)
14
+ parser.add_argument('-f', '--files', nargs='*', default=[])
15
+ parser.add_argument('-i', '--ignore', nargs='*', default=[])
16
+ return parser.parse_args()
17
+
18
+
19
+ if __name__ == '__main__':
20
+ args = parse_args()
21
+
22
+ if args.ignore is not None:
23
+ args.ignore = [int(x) for x in args.ignore]
24
+
25
+ if len(args.files) > 0:
26
+ review_files = args.files
27
+ else:
28
+ review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29
+
30
+ for review_file in sorted(review_files):
31
+ config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32
+ if args.select is not None and any(x not in config for x in args.select):
33
+ continue
34
+ if '0613' in config:
35
+ version = '0613'
36
+ else:
37
+ version = '0314'
38
+ if args.version is not None and args.version != version:
39
+ continue
40
+ scores = defaultdict(list)
41
+ print(config)
42
+ with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43
+ for review_str in f:
44
+ review = json.loads(review_str)
45
+ if review['question_id'] in args.ignore:
46
+ continue
47
+ if 'category' in review:
48
+ scores[review['category']].append(review['tuple'])
49
+ scores['all'].append(review['tuple'])
50
+ else:
51
+ if 'tuple' in review:
52
+ scores['all'].append(review['tuple'])
53
+ else:
54
+ scores['all'].append(review['score'])
55
+ for k, v in sorted(scores.items()):
56
+ stats = np.asarray(v).mean(0).tolist()
57
+ stats = [round(x, 3) for x in stats]
58
+ # print(k, stats, round(stats[1]/stats[0]*100, 1))
59
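+ # report the ratio of the second mean score to the first (in percent), followed by both means rescaled by 10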
+ print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60
+ print('=================================')
minigemini/mm_utils.py ADDED
@@ -0,0 +1,105 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from minigemini.constants import IMAGE_TOKEN_INDEX
8
+
9
+
10
+ def load_image_from_base64(image):
11
+ return Image.open(BytesIO(base64.b64decode(image)))
12
+
13
+
14
+ def expand2square(pil_img, background_color):
15
+ width, height = pil_img.size
16
+ if width == height:
17
+ return pil_img
18
+ elif width > height:
19
+ result = Image.new(pil_img.mode, (width, width), background_color)
20
+ result.paste(pil_img, (0, (width - height) // 2))
21
+ return result
22
+ else:
23
+ result = Image.new(pil_img.mode, (height, height), background_color)
24
+ result.paste(pil_img, ((height - width) // 2, 0))
25
+ return result
26
+
27
+
28
+ def process_images(images, image_processor, model_cfg):
29
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30
+ new_images = []
31
+ if image_aspect_ratio == 'pad':
32
+ for image in images:
33
+ image = expand2square(image.convert('RGB'), tuple(int(x*255) for x in image_processor.image_mean))
34
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
35
+ new_images.append(image)
36
+ else:
37
+ return image_processor(images, return_tensors='pt')['pixel_values']
38
+ if all(x.shape == new_images[0].shape for x in new_images):
39
+ new_images = torch.stack(new_images, dim=0)
40
+ return new_images
41
+
42
+
43
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
44
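+ # split the prompt on '<image>' placeholders, tokenize each chunk, and splice image_token_index between the chunks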
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
45
+
46
+ def insert_separator(X, sep):
47
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
48
+
49
+ input_ids = []
50
+ offset = 0
51
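+ # when the tokenizer adds BOS to every chunk, keep it only once at the start and strip the duplicates while splicing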
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
52
+ offset = 1
53
+ input_ids.append(prompt_chunks[0][0])
54
+
55
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
56
+ input_ids.extend(x[offset:])
57
+
58
+ if return_tensors is not None:
59
+ if return_tensors == 'pt':
60
+ return torch.tensor(input_ids, dtype=torch.long)
61
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
62
+ return input_ids
63
+
64
+
65
+ def get_model_name_from_path(model_path):
66
+ model_path = model_path.strip("/")
67
+ model_paths = model_path.split("/")
68
+ if model_paths[-1].startswith('checkpoint-'):
69
+ return model_paths[-2] + "_" + model_paths[-1]
70
+ else:
71
+ return model_paths[-1]
72
+
73
+ class KeywordsStoppingCriteria(StoppingCriteria):
74
+ def __init__(self, keywords, tokenizer, input_ids):
75
+ self.keywords = keywords
76
+ self.keyword_ids = []
77
+ self.max_keyword_len = 0
78
+ for keyword in keywords:
79
+ cur_keyword_ids = tokenizer(keyword).input_ids
80
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
81
+ cur_keyword_ids = cur_keyword_ids[1:]
82
+ if len(cur_keyword_ids) > self.max_keyword_len:
83
+ self.max_keyword_len = len(cur_keyword_ids)
84
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
85
+ self.tokenizer = tokenizer
86
+ self.start_len = input_ids.shape[1]
87
+
88
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
89
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
90
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
91
+ for keyword_id in self.keyword_ids:
92
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
93
+ if torch.equal(truncated_output_ids, keyword_id):
94
+ return True
95
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
96
+ for keyword in self.keywords:
97
+ if keyword in outputs:
98
+ return True
99
+ return False
100
+
101
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
102
+ outputs = []
103
+ for i in range(output_ids.shape[0]):
104
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
105
+ return all(outputs)
minigemini/model/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .language_model.mini_gemini_llama import MiniGeminiLlamaForCausalLM
2
+ try:
3
+ from .language_model.mini_gemini_mistral import MiniGeminiMistralForCausalLM
4
+ from .language_model.mini_gemini_mixtral import MiniGeminiMixtralForCausalLM
5
+ from .language_model.mini_gemini_gemma import MiniGeminiGemmaForCausalLM
6
+ except:
7
+ import warnings
+ warnings.warn("New model not imported. Try to update Transformers.", ImportWarning)
minigemini/model/builder.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ import os
20
+ import warnings
21
+ import logging
22
+
23
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
24
+ import torch
25
+ from minigemini.model import *
26
+ from minigemini.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27
+
28
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
29
+ kwargs = {"device_map": device_map, **kwargs}
30
+
31
+ if device != "cuda":
32
+ kwargs['device_map'] = {"": device}
33
+
34
+ if load_8bit:
35
+ kwargs['load_in_8bit'] = True
36
+ elif load_4bit:
37
+ kwargs['load_in_4bit'] = True
38
+ kwargs['quantization_config'] = BitsAndBytesConfig(
39
+ load_in_4bit=True,
40
+ bnb_4bit_compute_dtype=torch.float16,
41
+ bnb_4bit_use_double_quant=True,
42
+ bnb_4bit_quant_type='nf4'
43
+ )
44
+ else:
45
+ kwargs['torch_dtype'] = torch.float16
46
+
47
+ if use_flash_attn:
48
+ kwargs['attn_implementation'] = 'flash_attention_2'
49
+
50
+ logging.getLogger("transformers").setLevel(logging.ERROR)
51
+
52
+ if 'mini-gemini' in model_name.lower():
53
+ # Load MiniGemini model
54
+ if model_base is not None:
55
+ # this may be mm projector only
56
+ print('Loading MiniGemini from base model...')
57
+
58
+ if "8x7b" in model_name.lower():
59
+ tokenizer = AutoTokenizer.from_pretrained(model_base)
60
+ model = MiniGeminiMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
61
+ elif "2b" in model_name.lower():
62
+ tokenizer = AutoTokenizer.from_pretrained(model_base)
63
+ model = MiniGeminiGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
64
+ else:
65
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
66
+ model = MiniGeminiLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
67
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
68
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
69
+ model.load_state_dict(mm_projector_weights, strict=False)
70
+ else:
71
+ if "8x7b" in model_name.lower():
72
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
73
+ model = MiniGeminiMixtralForCausalLM.from_pretrained(model_path, **kwargs)
74
+ elif "2b" in model_name.lower():
75
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
76
+ model = MiniGeminiGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
77
+ else:
78
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
79
+ model = MiniGeminiLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
80
+
81
+ else:
82
+ # Load language model
83
+ if model_base is not None:
84
+ # PEFT model
85
+ from peft import PeftModel
86
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
87
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
88
+ print(f"Loading LoRA weights from {model_path}")
89
+ model = PeftModel.from_pretrained(model, model_path)
90
+ print(f"Merging weights")
91
+ model = model.merge_and_unload()
92
+ print('Convert to FP16...')
93
+ model.to(torch.float16)
94
+ else:
95
+ if 'mpt' in model_name.lower():
96
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
97
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
98
+ else:
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
100
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
101
+
102
+ image_processor = None
103
+
104
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
105
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
106
+ if mm_use_im_patch_token:
107
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
108
+ if mm_use_im_start_end:
109
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
110
+
111
+ model.resize_token_embeddings(len(tokenizer))
112
+
113
+ vision_tower = model.get_vision_tower()
114
+ if not vision_tower.is_loaded:
115
+ vision_tower.load_model()
116
+ vision_tower.to(device=device, dtype=torch.float16)
117
+ image_processor = vision_tower.image_processor
118
+
119
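+ # Mini-Gemini checkpoints also load the auxiliary (high-resolution) vision tower and initialize the unified query/aux/value projectors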
+ if 'mini-gemini' in model_name.lower():
120
+ vision_tower_aux = model.get_vision_tower_aux()
121
+ if not vision_tower_aux.is_loaded:
122
+ vision_tower_aux.load_model()
123
+ vision_tower_aux.to(device=device, dtype=torch.float16)
124
+
125
+ # initialize attention modules
126
+ model.config.model_path = model_path
127
+ model.get_model().initialize_uni_modules(model.config, for_eval=True)
128
+
129
+ model.get_model().vlm_uni_query_projector.to(device=device)
130
+ model.get_model().vlm_uni_aux_projector.to(device=device)
131
+ model.get_model().vlm_uni_val_projector.to(device=device)
132
+
133
+ if hasattr(model.config, "max_sequence_length"):
134
+ context_len = model.config.max_sequence_length
135
+ else:
136
+ context_len = 2048
137
+
138
+ logging.getLogger("transformers").setLevel(logging.WARNING)
139
+
140
+ return tokenizer, model, image_processor, context_len
minigemini/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Usage:
3
+ python3 -m minigemini.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from minigemini.model import *
10
+ from minigemini.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
minigemini/model/language_model/mini_gemini_gemma.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ try:
25
+ from transformers import AutoConfig, AutoModelForCausalLM, \
26
+ GemmaConfig, GemmaModel, GemmaForCausalLM
27
+ except:
28
+ print("New model not imported. Try to update Transformers to 4.38.0 or later.")
29
+ from transformers.modeling_outputs import CausalLMOutputWithPast
30
+ from transformers.generation.utils import GenerateOutput
31
+ from transformers.generation.utils import logging
32
+
33
+ from ..mini_gemini_arch import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ class MiniGeminiConfig(GemmaConfig):
38
+ model_type = "mini_gemini_gemma"
39
+
40
+
41
+ class MiniGeminiGemmaModel(MiniGeminiMetaModel, GemmaModel):
42
+ config_class = MiniGeminiConfig
43
+
44
+ def __init__(self, config: GemmaConfig):
45
+ super(MiniGeminiGemmaModel, self).__init__(config)
46
+
47
+ class MiniGeminiGemmaForCausalLM(GemmaForCausalLM, MiniGeminiMetaForCausalLM):
48
+ config_class = MiniGeminiConfig
49
+
50
+ def __init__(self, config):
51
+ super(GemmaForCausalLM, self).__init__(config)
52
+ self.model = MiniGeminiGemmaModel(config)
53
+ self.vocab_size = config.vocab_size
54
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
55
+
56
+ # Initialize weights and apply final processing
57
+ self.post_init()
58
+
59
+ def get_model(self):
60
+ return self.model
61
+
62
+ def forward(
63
+ self,
64
+ input_ids: torch.LongTensor = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.LongTensor] = None,
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ labels: Optional[torch.LongTensor] = None,
70
+ use_cache: Optional[bool] = None,
71
+ cache_position: Optional[torch.LongTensor] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ images: Optional[torch.FloatTensor] = None,
75
+ images_aux: Optional[torch.FloatTensor] = None,
76
+ return_dict: Optional[bool] = None,
77
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
78
+
79
+ if inputs_embeds is None:
80
+ (
81
+ input_ids,
82
+ position_ids,
83
+ attention_mask,
84
+ past_key_values,
85
+ inputs_embeds,
86
+ labels,
87
+ ) = self.prepare_inputs_labels_for_multimodal(
88
+ input_ids,
89
+ position_ids,
90
+ attention_mask,
91
+ past_key_values,
92
+ labels,
93
+ images,
94
+ images_aux
95
+ )
96
+
97
+ return super().forward(
98
+ input_ids=input_ids,
99
+ attention_mask=attention_mask,
100
+ position_ids=position_ids,
101
+ past_key_values=past_key_values,
102
+ inputs_embeds=inputs_embeds,
103
+ labels=labels,
104
+ use_cache=use_cache,
105
+ cache_position=cache_position,
106
+ output_attentions=output_attentions,
107
+ output_hidden_states=output_hidden_states,
108
+ return_dict=return_dict
109
+ )
110
+
111
+ @torch.no_grad()
112
+ def generate(
113
+ self,
114
+ inputs: Optional[torch.Tensor] = None,
115
+ images: Optional[torch.Tensor] = None,
116
+ images_aux: Optional[torch.FloatTensor] = None,
117
+ **kwargs,
118
+ ) -> Union[GenerateOutput, torch.LongTensor]:
119
+ position_ids = kwargs.pop("position_ids", None)
120
+ attention_mask = kwargs.pop("attention_mask", None)
121
+ if "inputs_embeds" in kwargs:
122
+ raise NotImplementedError("`inputs_embeds` is not supported")
123
+
124
+ if images is not None:
125
+ (
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ _,
130
+ inputs_embeds,
131
+ _
132
+ ) = self.prepare_inputs_labels_for_multimodal(
133
+ inputs,
134
+ position_ids,
135
+ attention_mask,
136
+ None,
137
+ None,
138
+ images,
139
+ images_aux
140
+ )
141
+ else:
142
+ inputs_embeds = self.get_model().embed_tokens(inputs)
143
+
144
+ return super().generate(
145
+ position_ids=position_ids,
146
+ attention_mask=attention_mask,
147
+ inputs_embeds=inputs_embeds,
148
+ **kwargs
149
+ )
150
+
151
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
152
+ images = kwargs.pop("images", None)
153
+ images_aux = kwargs.pop("images_aux", None)
154
+ _inputs = super().prepare_inputs_for_generation(
155
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
156
+ )
157
+ if images is not None:
158
+ _inputs['images'] = images
159
+ if images_aux is not None:
160
+ _inputs['images_aux'] = images_aux
161
+ return _inputs
162
+
163
+ AutoConfig.register("mini_gemini_gemma", MiniGeminiConfig)
164
+ AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiGemmaForCausalLM)
minigemini/model/language_model/mini_gemini_llama.py ADDED
@@ -0,0 +1,203 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from transformers import AutoConfig, AutoModelForCausalLM, \
25
+ LlamaConfig, LlamaModel, LlamaForCausalLM
26
+
27
+ from transformers.modeling_outputs import CausalLMOutputWithPast
28
+ from transformers.utils import logging
29
+ from transformers.generation.utils import GenerateOutput
30
+
31
+ from minigemini.model.mini_gemini_arch import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM
32
+ from torch.nn import CrossEntropyLoss
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ class MiniGeminiConfig(LlamaConfig):
38
+ model_type = "mini_gemini"
39
+
40
+ class MiniGeminiLlamaModel(MiniGeminiMetaModel, LlamaModel):
41
+ config_class = MiniGeminiConfig
42
+
43
+ def __init__(self, config: LlamaConfig):
44
+ super(MiniGeminiLlamaModel, self).__init__(config)
45
+
46
+
47
+ class MiniGeminiLlamaForCausalLM(LlamaForCausalLM, MiniGeminiMetaForCausalLM):
48
+ config_class = MiniGeminiConfig
49
+
50
+ def __init__(self, config):
51
+ super(LlamaForCausalLM, self).__init__(config)
52
+ self.model = MiniGeminiLlamaModel(config)
53
+ self.pretraining_tp = config.pretraining_tp
54
+ self.vocab_size = config.vocab_size
55
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
56
+
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def get_model(self):
61
+ return self.model
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.LongTensor] = None,
68
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
69
+ inputs_embeds: Optional[torch.FloatTensor] = None,
70
+ labels: Optional[torch.LongTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ images: Optional[torch.FloatTensor] = None,
75
+ images_aux: Optional[torch.FloatTensor] = None,
76
+ return_dict: Optional[bool] = None,
77
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
78
+
79
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
80
+ output_hidden_states = (
81
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
82
+ )
83
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
84
+
85
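+ # merge image (and auxiliary image) features into the token embeddings before running the language model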
+ if inputs_embeds is None:
86
+ (
87
+ input_ids,
88
+ position_ids,
89
+ attention_mask,
90
+ past_key_values,
91
+ inputs_embeds,
92
+ labels
93
+ ) = self.prepare_inputs_labels_for_multimodal(
94
+ input_ids,
95
+ position_ids,
96
+ attention_mask,
97
+ past_key_values,
98
+ labels,
99
+ images,
100
+ images_aux
101
+ )
102
+
103
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
104
+ outputs = self.model(
105
+ input_ids=input_ids,
106
+ attention_mask=attention_mask,
107
+ position_ids=position_ids,
108
+ past_key_values=past_key_values,
109
+ inputs_embeds=inputs_embeds,
110
+ use_cache=use_cache,
111
+ output_attentions=output_attentions,
112
+ output_hidden_states=output_hidden_states,
113
+ return_dict=return_dict,
114
+ )
115
+
116
+ hidden_states = outputs[0]
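+ # when config.pretraining_tp > 1, the lm_head weight is split into vocabulary slices and applied slice by slice before concatenation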
117
+ if self.pretraining_tp > 1:
118
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
119
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
120
+ logits = torch.cat(logits, dim=-1)
121
+ else:
122
+ logits = self.lm_head(hidden_states)
123
+ logits = logits.float()
124
+
125
+ loss = None
126
+ if labels is not None:
127
+ # Shift so that tokens < n predict n
128
+ shift_logits = logits[..., :-1, :].contiguous()
129
+ shift_labels = labels[..., 1:].contiguous()
130
+ # Flatten the tokens
131
+ loss_fct = CrossEntropyLoss()
132
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
133
+ shift_labels = shift_labels.view(-1)
134
+ # Enable model parallelism
135
+ shift_labels = shift_labels.to(shift_logits.device)
136
+ loss = loss_fct(shift_logits, shift_labels)
137
+
138
+ if not return_dict:
139
+ output = (logits,) + outputs[1:]
140
+ return (loss,) + output if loss is not None else output
141
+
142
+ return CausalLMOutputWithPast(
143
+ loss=loss,
144
+ logits=logits,
145
+ past_key_values=outputs.past_key_values,
146
+ hidden_states=outputs.hidden_states,
147
+ attentions=outputs.attentions,
148
+ )
149
+
150
+ @torch.no_grad()
151
+ def generate(
152
+ self,
153
+ inputs: Optional[torch.Tensor] = None,
154
+ images: Optional[torch.Tensor] = None,
155
+ images_aux: Optional[torch.FloatTensor] = None,
156
+ **kwargs,
157
+ ) -> Union[GenerateOutput, torch.LongTensor]:
158
+ position_ids = kwargs.pop("position_ids", None)
159
+ attention_mask = kwargs.pop("attention_mask", None)
160
+ if "inputs_embeds" in kwargs:
161
+ raise NotImplementedError("`inputs_embeds` is not supported")
162
+
163
+ if images is not None:
164
+ (
165
+ inputs,
166
+ position_ids,
167
+ attention_mask,
168
+ _,
169
+ inputs_embeds,
170
+ _
171
+ ) = self.prepare_inputs_labels_for_multimodal(
172
+ inputs,
173
+ position_ids,
174
+ attention_mask,
175
+ None,
176
+ None,
177
+ images,
178
+ images_aux
179
+ )
180
+ else:
181
+ inputs_embeds = self.get_model().embed_tokens(inputs)
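+ # generation always runs from precomputed inputs_embeds: multimodal prompts are spliced above, text-only prompts are embedded here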
182
+
183
+ return super().generate(
184
+ position_ids=position_ids,
185
+ attention_mask=attention_mask,
186
+ inputs_embeds=inputs_embeds,
187
+ **kwargs
188
+ )
189
+
190
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
191
+ images = kwargs.pop("images", None)
192
+ images_aux = kwargs.pop("images_aux", None)
193
+ _inputs = super().prepare_inputs_for_generation(
194
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
195
+ )
196
+ if images is not None:
197
+ _inputs['images'] = images
198
+ if images_aux is not None:
199
+ _inputs['images_aux'] = images_aux
200
+ return _inputs
201
+
202
+ AutoConfig.register("mini_gemini", MiniGeminiConfig)
203
+ AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiLlamaForCausalLM)
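+ # registering the pair lets AutoConfig / AutoModelForCausalLM resolve checkpoints whose config.json declares model_type "mini_gemini"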
minigemini/model/language_model/mini_gemini_mistral.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from transformers import AutoConfig, AutoModelForCausalLM, \
25
+ MistralConfig, MistralModel, MistralForCausalLM
26
+
27
+ from transformers.modeling_outputs import CausalLMOutputWithPast
28
+ from transformers.generation.utils import GenerateOutput
29
+ from transformers.generation.utils import logging
30
+
31
+ from ..mini_gemini_arch import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ class MiniGeminiConfig(MistralConfig):
36
+ model_type = "mini_gemini_mistral"
37
+
38
+
39
+ class MiniGeminiMistralModel(MiniGeminiMetaModel, MistralModel):
40
+ config_class = MiniGeminiConfig
41
+
42
+ def __init__(self, config: MistralConfig):
43
+ super(MiniGeminiMistralModel, self).__init__(config)
44
+ # self.max_pos_idx = 0
45
+
46
+ class MiniGeminiMistralForCausalLM(MistralForCausalLM, MiniGeminiMetaForCausalLM):
47
+ config_class = MiniGeminiConfig
48
+
49
+ def __init__(self, config):
50
+ super(MistralForCausalLM, self).__init__(config)
51
+ self.model = MiniGeminiMistralModel(config)
52
+ # self.pretraining_tp = config.pretraining_tp
53
+ self.vocab_size = config.vocab_size
54
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
55
+
56
+ # Initialize weights and apply final processing
57
+ self.post_init()
58
+
59
+ def get_model(self):
60
+ return self.model
61
+
62
+ def forward(
63
+ self,
64
+ input_ids: torch.LongTensor = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.LongTensor] = None,
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ labels: Optional[torch.LongTensor] = None,
70
+ use_cache: Optional[bool] = None,
71
+ output_attentions: Optional[bool] = None,
72
+ output_hidden_states: Optional[bool] = None,
73
+ images: Optional[torch.FloatTensor] = None,
74
+ images_aux: Optional[torch.FloatTensor] = None,
75
+ return_dict: Optional[bool] = None,
76
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
77
+
78
+ if inputs_embeds is None:
79
+ (
80
+ input_ids,
81
+ position_ids,
82
+ attention_mask,
83
+ past_key_values,
84
+ inputs_embeds,
85
+ labels
86
+ ) = self.prepare_inputs_labels_for_multimodal(
87
+ input_ids,
88
+ position_ids,
89
+ attention_mask,
90
+ past_key_values,
91
+ labels,
92
+ images,
93
+ images_aux
94
+ )
95
+
96
+ return super().forward(
97
+ input_ids=input_ids,
98
+ attention_mask=attention_mask,
99
+ position_ids=position_ids,
100
+ past_key_values=past_key_values,
101
+ inputs_embeds=inputs_embeds,
102
+ labels=labels,
103
+ use_cache=use_cache,
104
+ output_attentions=output_attentions,
105
+ output_hidden_states=output_hidden_states,
106
+ return_dict=return_dict
107
+ )
108
+
109
+ @torch.no_grad()
110
+ def generate(
111
+ self,
112
+ inputs: Optional[torch.Tensor] = None,
113
+ images: Optional[torch.Tensor] = None,
114
+ images_aux: Optional[torch.FloatTensor] = None,
115
+ **kwargs,
116
+ ) -> Union[GenerateOutput, torch.LongTensor]:
117
+ position_ids = kwargs.pop("position_ids", None)
118
+ attention_mask = kwargs.pop("attention_mask", None)
119
+ if "inputs_embeds" in kwargs:
120
+ raise NotImplementedError("`inputs_embeds` is not supported")
121
+
122
+ if images is not None:
123
+ (
124
+ inputs,
125
+ position_ids,
126
+ attention_mask,
127
+ _,
128
+ inputs_embeds,
129
+ _
130
+ ) = self.prepare_inputs_labels_for_multimodal(
131
+ inputs,
132
+ position_ids,
133
+ attention_mask,
134
+ None,
135
+ None,
136
+ images,
137
+ images_aux
138
+ )
139
+ else:
140
+ inputs_embeds = self.get_model().embed_tokens(inputs)
141
+
142
+ return super().generate(
143
+ position_ids=position_ids,
144
+ attention_mask=attention_mask,
145
+ inputs_embeds=inputs_embeds,
146
+ **kwargs
147
+ )
148
+
149
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
150
+ images = kwargs.pop("images", None)
151
+ images_aux = kwargs.pop("images_aux", None)
152
+ _inputs = super().prepare_inputs_for_generation(
153
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
154
+ )
155
+ if images is not None:
156
+ _inputs['images'] = images
157
+ if images_aux is not None:
158
+ _inputs['images_aux'] = images_aux
159
+ return _inputs
160
+
161
+ AutoConfig.register("mini_gemini_mistral", MiniGeminiConfig)
162
+ AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiMistralForCausalLM)
minigemini/model/language_model/mini_gemini_mixtral.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from transformers import AutoConfig, AutoModelForCausalLM, \
25
+ MixtralConfig, MixtralModel, MixtralForCausalLM
26
+
27
+ from transformers.modeling_outputs import CausalLMOutputWithPast
28
+ from transformers.generation.utils import GenerateOutput
29
+ from transformers.generation.utils import logging
30
+
31
+ from ..mini_gemini_arch import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ class MiniGeminiConfig(MixtralConfig):
36
+ model_type = "mini_gemini_mixtral"
37
+
38
+
39
+ class MiniGeminiMixtralModel(MiniGeminiMetaModel, MixtralModel):
40
+ config_class = MiniGeminiConfig
41
+
42
+ def __init__(self, config: MixtralConfig):
43
+ super(MiniGeminiMixtralModel, self).__init__(config)
44
+ # self.max_pos_idx = 0
45
+
46
+ class MiniGeminiMixtralForCausalLM(MixtralForCausalLM, MiniGeminiMetaForCausalLM):
47
+ config_class = MiniGeminiConfig
48
+
49
+ def __init__(self, config):
50
+ super(MixtralForCausalLM, self).__init__(config)
51
+ self.model = MiniGeminiMixtralModel(config)
52
+ # self.pretraining_tp = config.pretraining_tp
53
+ self.vocab_size = config.vocab_size
54
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
55
+
56
+ # Initialize weights and apply final processing
57
+ self.post_init()
58
+
59
+ def get_model(self):
60
+ return self.model
61
+
62
+ def forward(
63
+ self,
64
+ input_ids: torch.LongTensor = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.LongTensor] = None,
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ labels: Optional[torch.LongTensor] = None,
70
+ use_cache: Optional[bool] = None,
71
+ output_attentions: Optional[bool] = None,
72
+ output_hidden_states: Optional[bool] = None,
73
+ images: Optional[torch.FloatTensor] = None,
74
+ images_aux: Optional[torch.FloatTensor] = None,
75
+ return_dict: Optional[bool] = None,
76
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
77
+
78
+ if inputs_embeds is None:
79
+ (
80
+ input_ids,
81
+ position_ids,
82
+ attention_mask,
83
+ past_key_values,
84
+ inputs_embeds,
85
+ labels
86
+ ) = self.prepare_inputs_labels_for_multimodal(
87
+ input_ids,
88
+ position_ids,
89
+ attention_mask,
90
+ past_key_values,
91
+ labels,
92
+ images,
93
+ images_aux
94
+ )
95
+
96
+ return super().forward(
97
+ input_ids=input_ids,
98
+ attention_mask=attention_mask,
99
+ position_ids=position_ids,
100
+ past_key_values=past_key_values,
101
+ inputs_embeds=inputs_embeds,
102
+ labels=labels,
103
+ use_cache=use_cache,
104
+ output_attentions=output_attentions,
105
+ output_hidden_states=output_hidden_states,
106
+ return_dict=return_dict
107
+ )
108
+
109
+ @torch.no_grad()
110
+ def generate(
111
+ self,
112
+ inputs: Optional[torch.Tensor] = None,
113
+ images: Optional[torch.Tensor] = None,
114
+ images_aux: Optional[torch.FloatTensor] = None,
115
+ **kwargs,
116
+ ) -> Union[GenerateOutput, torch.LongTensor]:
117
+ position_ids = kwargs.pop("position_ids", None)
118
+ attention_mask = kwargs.pop("attention_mask", None)
119
+ if "inputs_embeds" in kwargs:
120
+ raise NotImplementedError("`inputs_embeds` is not supported")
121
+
122
+ if images is not None:
123
+ (
124
+ inputs,
125
+ position_ids,
126
+ attention_mask,
127
+ _,
128
+ inputs_embeds,
129
+ _
130
+ ) = self.prepare_inputs_labels_for_multimodal(
131
+ inputs,
132
+ position_ids,
133
+ attention_mask,
134
+ None,
135
+ None,
136
+ images,
137
+ images_aux
138
+ )
139
+ else:
140
+ inputs_embeds = self.get_model().embed_tokens(inputs)
141
+
142
+ return super().generate(
143
+ position_ids=position_ids,
144
+ attention_mask=attention_mask,
145
+ inputs_embeds=inputs_embeds,
146
+ **kwargs
147
+ )
148
+
149
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
150
+ images = kwargs.pop("images", None)
151
+ images_aux = kwargs.pop("images_aux", None)
152
+ _inputs = super().prepare_inputs_for_generation(
153
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
154
+ )
155
+ if images is not None:
156
+ _inputs['images'] = images
157
+ if images_aux is not None:
158
+ _inputs['images_aux'] = images_aux
159
+ return _inputs
160
+
161
+ AutoConfig.register("mini_gemini_mixtral", MiniGeminiConfig)
162
+ AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiMixtralForCausalLM)
minigemini/model/llava_arch.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from minigemini.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+
27
+ class LlavaMetaModel:
28
+
29
+ def __init__(self, config):
30
+ super(LlavaMetaModel, self).__init__(config)
31
+
32
+ if hasattr(config, "mm_vision_tower"):
33
+ self.vision_tower = build_vision_tower(config, delay_load=True)
34
+ self.mm_projector = build_vision_projector(config)
35
+
36
+ def get_vision_tower(self):
37
+ vision_tower = getattr(self, 'vision_tower', None)
38
+ if type(vision_tower) is list:
39
+ vision_tower = vision_tower[0]
40
+ return vision_tower
41
+
42
+ def initialize_vision_modules(self, model_args, fsdp=None):
43
+ vision_tower = model_args.vision_tower
44
+ mm_vision_select_layer = model_args.mm_vision_select_layer
45
+ mm_vision_select_feature = model_args.mm_vision_select_feature
46
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
47
+
48
+ self.config.mm_vision_tower = vision_tower
49
+
50
+ if self.get_vision_tower() is None:
51
+ vision_tower = build_vision_tower(model_args)
52
+
53
+ if fsdp is not None and len(fsdp) > 0:
54
+ self.vision_tower = [vision_tower]
55
+ else:
56
+ self.vision_tower = vision_tower
57
+ else:
58
+ if fsdp is not None and len(fsdp) > 0:
59
+ vision_tower = self.vision_tower[0]
60
+ else:
61
+ vision_tower = self.vision_tower
62
+ vision_tower.load_model()
63
+
64
+ self.config.use_mm_proj = True
65
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
66
+ self.config.mm_hidden_size = vision_tower.hidden_size
67
+ self.config.mm_vision_select_layer = mm_vision_select_layer
68
+ self.config.mm_vision_select_feature = mm_vision_select_feature
69
+
70
+ if getattr(self, 'mm_projector', None) is None:
71
+ self.mm_projector = build_vision_projector(self.config)
72
+ else:
73
+ # In case it is frozen by LoRA
74
+ for p in self.mm_projector.parameters():
75
+ p.requires_grad = True
76
+
77
+ if pretrain_mm_mlp_adapter is not None:
78
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
79
+ def get_w(weights, keyword):
80
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
81
+
82
+ if 'model' in mm_projector_weights.keys():
83
+ mm_projector_weights = mm_projector_weights['model']
84
+ status = self.mm_projector.load_state_dict(mm_projector_weights, strict=False)
85
+ print('missing_keys:', status.missing_keys)
86
+ else:
87
+ status = self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
88
+ print('missing_keys:', status.missing_keys)
89
+
90
+ # class_embedding_weights = get_w(mm_projector_weights, 'model.vision_tower.vision_tower.vision_model.embeddings')
91
+ # if len(class_embedding_weights) > 0:
92
+ # self.vision_tower.vision_tower.vision_model.embeddings.load_state_dict(class_embedding_weights, strict=False)
93
+
94
+
95
+ class LlavaMetaForCausalLM(ABC):
96
+
97
+ @abstractmethod
98
+ def get_model(self):
99
+ pass
100
+
101
+ def get_vision_tower(self):
102
+ return self.get_model().get_vision_tower()
103
+
104
+ def encode_images(self, images=None, points=None):
105
+ if images is not None:
106
+ image_features = self.get_model().get_vision_tower()(images)
107
+ image_features = self.get_model().mm_projector(image_features)
108
+ if points is not None:
109
+ # use pre-computed features here
110
+ point_features = [self.get_model().mm_projector(_point).squeeze() for _point in points]
111
+ return image_features
112
+
113
+ def prepare_inputs_labels_for_multimodal(
114
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images=None, points=None
115
+ ):
116
+ vision_tower = self.get_vision_tower()
117
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
118
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
119
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
120
+ attention_mask = torch.cat((attention_mask, torch.ones(
121
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
122
+ dtype=attention_mask.dtype,
123
+ device=attention_mask.device
124
+ )), dim=1)
125
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
126
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
127
+
128
+ if type(images) is list or images.ndim == 5:
129
+ concat_images = torch.cat([image for image in images], dim=0)
130
+ image_features = self.encode_images(concat_images)
131
+ split_sizes = [image.shape[0] for image in images]
132
+ image_features = torch.split(image_features, split_sizes, dim=0)
133
+ image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
134
+ else:
135
+ image_features = self.encode_images(images).to(self.device)
136
+
137
+ # TODO: image start / end is not implemented here to support pretraining.
138
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
139
+ raise NotImplementedError
140
+
141
+ # Let's just add dummy tensors if they do not exist,
142
+ # it is a headache to deal with None all the time.
143
+ # But it is not ideal, and if you have a better idea,
144
+ # please open an issue / submit a PR, thanks.
145
+ _labels = labels
146
+ _position_ids = position_ids
147
+ _attention_mask = attention_mask
148
+ if attention_mask is None:
149
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
150
+ else:
151
+ attention_mask = attention_mask.bool()
152
+ if position_ids is None:
153
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
154
+ if labels is None:
155
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
156
+
157
+ # remove the padding using attention_mask -- TODO: double check
158
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
159
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
160
+
161
+ new_input_embeds = []
162
+ new_labels = []
163
+ cur_image_idx = 0
164
+ for batch_idx, cur_input_ids in enumerate(input_ids):
165
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
166
+ if num_images == 0:
167
+ cur_image_features = image_features[cur_image_idx]
168
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
169
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
170
+ new_input_embeds.append(cur_input_embeds)
171
+ new_labels.append(labels[batch_idx])
172
+ cur_image_idx += 1
173
+ continue
174
+
175
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
176
+ cur_input_ids_noim = []
177
+ cur_labels = labels[batch_idx]
178
+ cur_labels_noim = []
179
+ for i in range(len(image_token_indices) - 1):
180
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
181
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
182
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
183
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
184
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
185
+ cur_new_input_embeds = []
186
+ cur_new_labels = []
187
+
188
+ for i in range(num_images + 1):
189
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
190
+ cur_new_labels.append(cur_labels_noim[i])
191
+ if i < num_images:
192
+ cur_image_features = image_features[cur_image_idx]
193
+ cur_image_idx += 1
194
+ cur_new_input_embeds.append(cur_image_features)
195
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
196
+
197
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
198
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
199
+ cur_new_labels = torch.cat(cur_new_labels)
200
+
201
+ new_input_embeds.append(cur_new_input_embeds)
202
+ new_labels.append(cur_new_labels)
203
+
204
+ # Truncate sequences to max length as image embeddings can make the sequence longer
205
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
206
+ if tokenizer_model_max_length is not None:
207
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
208
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
209
+
210
+ # Combine them
211
+ max_len = max(x.shape[0] for x in new_input_embeds)
212
+ batch_size = len(new_input_embeds)
213
+
214
+ new_input_embeds_padded = []
215
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
216
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
217
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
218
+
219
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
220
+ cur_len = cur_new_embed.shape[0]
221
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
222
+ new_input_embeds_padded.append(torch.cat((
223
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
224
+ cur_new_embed
225
+ ), dim=0))
226
+ if cur_len > 0:
227
+ new_labels_padded[i, -cur_len:] = cur_new_labels
228
+ attention_mask[i, -cur_len:] = True
229
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
230
+ else:
231
+ new_input_embeds_padded.append(torch.cat((
232
+ cur_new_embed,
233
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
234
+ ), dim=0))
235
+ if cur_len > 0:
236
+ new_labels_padded[i, :cur_len] = cur_new_labels
237
+ attention_mask[i, :cur_len] = True
238
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
239
+
240
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
241
+
242
+ if _labels is None:
243
+ new_labels = None
244
+ else:
245
+ new_labels = new_labels_padded
246
+
247
+ if _attention_mask is None:
248
+ attention_mask = None
249
+ else:
250
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
251
+
252
+ if _position_ids is None:
253
+ position_ids = None
254
+
255
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
256
+
257
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
258
+ if model_args.mm_use_im_patch_token:
259
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
260
+ self.resize_token_embeddings(len(tokenizer))
261
+
262
+ if model_args.mm_use_im_start_end:
263
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
264
+ self.resize_token_embeddings(len(tokenizer))
265
+
266
+ if num_new_tokens > 0:
267
+ input_embeddings = self.get_input_embeddings().weight.data
268
+ output_embeddings = self.get_output_embeddings().weight.data
269
+
270
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
271
+ dim=0, keepdim=True)
272
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
273
+ dim=0, keepdim=True)
274
+
275
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
276
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
277
+
278
+ if model_args.tune_mm_mlp_adapter:
279
+ for p in self.get_input_embeddings().parameters():
280
+ p.requires_grad = True
281
+ for p in self.get_output_embeddings().parameters():
282
+ p.requires_grad = False
283
+
284
+ if model_args.pretrain_mm_mlp_adapter:
285
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
286
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
287
+ assert num_new_tokens == 2
288
+ if input_embeddings.shape == embed_tokens_weight.shape:
289
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
290
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
291
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
292
+ else:
293
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
294
+ elif model_args.mm_use_im_patch_token:
295
+ if model_args.tune_mm_mlp_adapter:
296
+ for p in self.get_input_embeddings().parameters():
297
+ p.requires_grad = False
298
+ for p in self.get_output_embeddings().parameters():
299
+ p.requires_grad = False
minigemini/model/mini_gemini_arch.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ from abc import ABC, abstractmethod
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import json
25
+ import os
26
+ import transformers
27
+ import safetensors
28
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
29
+ import deepspeed
30
+
31
+ from .multimodal_encoder.builder import build_vision_tower, build_vision_tower_aux
32
+ from .multimodal_projector.builder import build_vision_projector
33
+
34
+ from minigemini.constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN,
35
+ DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
36
+
37
+ IS_NEW_TRANSFORMERS = transformers.__version__ >= "4.34.0"
38
+
39
+ class MiniGeminiMetaModel:
40
+
41
+ def __init__(self, config):
42
+ super(MiniGeminiMetaModel, self).__init__(config)
43
+
44
+ if hasattr(config, "mm_vision_tower"):
45
+ self.vision_tower = build_vision_tower(config, delay_load=True)
46
+ self.mm_projector = build_vision_projector(config)
47
+
48
+ if hasattr(config, "mm_vision_tower_aux"):
49
+ self.vision_tower_aux = build_vision_tower_aux(config, delay_load=True)
50
+
51
+ def get_vision_tower(self):
52
+ vision_tower = getattr(self, 'vision_tower', None)
53
+ if type(vision_tower) is list:
54
+ vision_tower = vision_tower[0]
55
+ return vision_tower
56
+
57
+ def get_vision_tower_aux(self):
58
+ vision_tower_aux = getattr(self, 'vision_tower_aux', None)
59
+ if type(vision_tower_aux) is list:
60
+ vision_tower_aux = vision_tower_aux[0]
61
+ return vision_tower_aux
62
+
63
+ def initialize_vision_modules(self, model_args, fsdp=None):
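+ # build the main vision tower and the optional high-resolution auxiliary tower (wrapped in a list under FSDP), record their hidden sizes in the config, and load pretrained mm_projector weights when provided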
64
+ vision_tower = model_args.vision_tower
65
+ vision_tower_aux = model_args.vision_tower_aux
66
+ mm_vision_select_layer = model_args.mm_vision_select_layer
67
+ mm_vision_select_feature = model_args.mm_vision_select_feature
68
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
69
+
70
+ self.config.mm_vision_tower = vision_tower
71
+ self.config.mm_vision_tower_aux = vision_tower_aux
72
+
73
+ if self.get_vision_tower() is None:
74
+ vision_tower = build_vision_tower(model_args)
75
+
76
+ if fsdp is not None and len(fsdp) > 0:
77
+ self.vision_tower = [vision_tower]
78
+ else:
79
+ self.vision_tower = vision_tower
80
+ else:
81
+ if fsdp is not None and len(fsdp) > 0:
82
+ vision_tower = self.vision_tower[0]
83
+ else:
84
+ vision_tower = self.vision_tower
85
+ vision_tower.load_model()
86
+
87
+ if vision_tower_aux is not None:
88
+ if self.get_vision_tower_aux() is None:
89
+ vision_tower_aux = build_vision_tower_aux(model_args)
90
+
91
+ if fsdp is not None and len(fsdp) > 0:
92
+ self.vision_tower_aux = [vision_tower_aux]
93
+ else:
94
+ self.vision_tower_aux = vision_tower_aux
95
+ else:
96
+ if fsdp is not None and len(fsdp) > 0:
97
+ vision_tower_aux = self.vision_tower_aux[0]
98
+ else:
99
+ vision_tower_aux = self.vision_tower_aux
100
+ vision_tower_aux.load_model()
101
+ self.config.mm_hidden_size_aux = vision_tower_aux.hidden_size
102
+
103
+ self.config.use_mm_proj = True
104
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
105
+ self.config.mm_hidden_size = vision_tower.hidden_size
106
+ self.config.mm_vision_select_layer = mm_vision_select_layer
107
+ self.config.mm_vision_select_feature = mm_vision_select_feature
108
+
109
+ if getattr(self, 'mm_projector', None) is None:
110
+ self.mm_projector = build_vision_projector(self.config)
111
+ else:
112
+ # In case it is frozen by LoRA
113
+ for p in self.mm_projector.parameters():
114
+ p.requires_grad = True
115
+
116
+ if pretrain_mm_mlp_adapter is not None:
117
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
118
+ def get_w(weights, keyword):
119
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
120
+
121
+ if 'model' in mm_projector_weights.keys():
122
+ mm_projector_weights = mm_projector_weights['model']
123
+ if is_deepspeed_zero3_enabled():
124
+ if len(mm_projector_weights) > 0:
125
+ with deepspeed.zero.GatheredParameters(mm_projector_weights, modifier_rank=0):
126
+ if torch.distributed.get_rank() == 0:
127
+ self.mm_projector.load_state_dict(mm_projector_weights)
128
+ else:
129
+ status = self.mm_projector.load_state_dict(mm_projector_weights, strict=False)
130
+ print('missing_keys:', status.missing_keys)
131
+ else:
132
+ if is_deepspeed_zero3_enabled():
133
+ named_parameters = get_w(mm_projector_weights, 'mm_projector')
134
+ if len(named_parameters) > 0:
135
+ with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
136
+ if torch.distributed.get_rank() == 0:
137
+ self.mm_projector.load_state_dict(named_parameters)
138
+ else:
139
+ status = self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
140
+ print('missing_keys:', status.missing_keys)
141
+ self.mm_projector = self.mm_projector.to(device=self.device)
142
+
143
+ def initialize_uni_modules(self, model_args, for_eval=False):
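+ # build the query/aux/value projectors that align auxiliary high-resolution features with the main visual hidden size, then restore any matching pretrained weights from the checkpoint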
144
+ pretrain_mm_mlp_adapter = getattr(model_args, "pretrain_mm_mlp_adapter", None)
145
+ self.config.image_size_aux = getattr(model_args, 'image_size_aux', 320)
146
+ self.config.optimize_vision_tower = getattr(model_args, 'optimize_vision_tower', False)
147
+ self.config.optimize_vision_tower_aux = getattr(model_args, 'optimize_vision_tower_aux', False)
148
+
149
+ self.vlm_uni_query_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size),
150
+ nn.Linear(self.config.mm_hidden_size, self.config.mm_hidden_size))
151
+ self.vlm_uni_aux_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
152
+ nn.Linear(self.config.mm_hidden_size_aux, self.config.mm_hidden_size))
153
+ self.vlm_uni_val_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
154
+ nn.Linear(self.config.mm_hidden_size_aux, self.config.mm_hidden_size))
155
+
156
+ if pretrain_mm_mlp_adapter is not None:
157
+ projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
158
+ else:
159
+ trainable_module = ['vlm_uni', 'vision_fpn', 'vision_stages']
160
+ if hasattr(model_args, 'model_name_or_path'):
161
+ model_save_path = model_args.model_name_or_path
162
+ else:
163
+ model_save_path = model_args.model_path
164
+ model_idx_path = getattr(model_args, 'model_path', model_save_path)
165
+ if IS_NEW_TRANSFORMERS:
166
+ try:
167
+ weight_file = json.load(open(os.path.join(model_idx_path, 'model.safetensors.index.json'), 'r'))['weight_map']
168
+ except:
169
+ weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))['weight_map']
170
+ else:
171
+ weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))['weight_map']
172
+ model_path = set([weight_file[_key] for _key in weight_file if any([_module in _key for _module in trainable_module])])
173
+ projector_weights = {}
174
+ for _model in model_path:
175
+ if not IS_NEW_TRANSFORMERS:
176
+ projector_weights.update(torch.load(os.path.join(model_idx_path, _model), map_location='cpu'))
177
+ else:
178
+ with safetensors.safe_open(os.path.join(model_idx_path, _model), framework="pt", device='cpu') as f:
179
+ for _key in f.keys():
180
+ projector_weights.update({_key: f.get_tensor(_key)})
181
+ if len(projector_weights) == 0:
182
+ return
183
+
184
+ def get_w(weights, keyword, main_module, sub_module):
185
+ if getattr(main_module, sub_module, None) is None:
186
+ return
187
+
188
+ pretrain_weight = {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
189
+ if len(pretrain_weight) == 0:
190
+ return
191
+ if is_deepspeed_zero3_enabled():
192
+ named_parameters = [v for k, v in getattr(main_module, sub_module).named_parameters()]
193
+ if len(named_parameters) > 0:
194
+ # because zero3 puts placeholders in model params, this context
195
+ # manager gathers (unpartitions) the params of the current layer, then loads from
196
+ # the state dict and then re-partitions them again
197
+ with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
198
+ if torch.distributed.get_rank() == 0:
199
+ getattr(main_module, sub_module).load_state_dict(pretrain_weight)
200
+ with deepspeed.zero.GatheredParameters(self.mm_projector[0].weight, modifier_rank=None):
201
+ weight_type = self.mm_projector[0].weight.dtype
202
+ device_type = self.mm_projector[0].weight.device
203
+ else:
204
+ weight_type = self.mm_projector[0].weight.dtype
205
+ device_type = self.mm_projector[0].weight.device
206
+ getattr(main_module, sub_module).load_state_dict(pretrain_weight)
207
+ if weight_type == torch.uint8 or weight_type == torch.int8 or weight_type == torch.int16:
208
+ weight_type = torch.float16
209
+
210
+ getattr(main_module, sub_module).to(dtype=weight_type)
211
+ print(f"Loading {sub_module} weights...")
212
+
213
+ # load pretrained weights
214
+ get_w(projector_weights, 'vision_tower.vision_tower', self.vision_tower, 'vision_tower')
215
+
216
+ # load pretrained weights
217
+ if self.config.optimize_vision_tower_aux:
218
+ # the vision stem is not optimized here; its weights are loaded only as a sanity check
219
+ get_w(projector_weights, 'vision_tower_aux.vision_stem', self.vision_tower_aux, 'vision_stem')
220
+ get_w(projector_weights, 'vision_tower_aux.vision_stages', self.vision_tower_aux, 'vision_stages')
221
+ get_w(projector_weights, 'vlm_uni_query_projector', self, 'vlm_uni_query_projector')
222
+ get_w(projector_weights, 'vlm_uni_aux_projector', self, 'vlm_uni_aux_projector')
223
+ get_w(projector_weights, 'vlm_uni_val_projector', self, 'vlm_uni_val_projector')
224
+
225
+ class MiniGeminiMetaForCausalLM(ABC):
226
+
227
+ @abstractmethod
228
+ def get_model(self):
229
+ pass
230
+
231
+ def get_vision_tower(self):
232
+ return self.get_model().get_vision_tower()
233
+
234
+ def get_vision_tower_aux(self):
235
+ return self.get_model().get_vision_tower_aux()
236
+
237
+ def encode_images(self, images, images_aux=None, is_video=False):
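+ # fuse low-resolution vision-tower tokens with high-resolution auxiliary features: optionally split the image into a grid (plus a global view), mine the auxiliary features with unified_resampler, sum the two streams, and project with mm_projector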
238
+ image_grid = getattr(self.config, 'image_grid', 1)
239
+ image_global = getattr(self.config, 'image_global', False)
240
+ if image_grid > 1:
241
+ batch_size = images.shape[0]
242
+ if image_global:
243
+ global_images = images[:, -1:].flatten(0,1).contiguous()
244
+ grid_images = images[:, :-1].flatten(0,1).contiguous()
245
+ images = torch.cat([grid_images, global_images], dim=0)
246
+ else:
247
+ images = images.flatten(0,1).contiguous()
248
+
249
+ image_features = self.get_model().get_vision_tower()(images)
250
+
251
+ if image_global:
252
+ image_feat_global = image_features[-len(global_images):]
253
+ image_features = image_features[:len(grid_images)]
254
+
255
+ if images_aux is not None:
256
+ image_aux_features_raw = self.get_model().get_vision_tower_aux()(images_aux).to(
257
+ dtype=image_features.dtype, device=image_features.device)
258
+
259
+ if image_global:
260
+ image_aux_features_global = F.interpolate(image_aux_features_raw.float(),
261
+ scale_factor=1/image_grid,
262
+ mode='bilinear',
263
+ align_corners=False).to(dtype=image_aux_features_raw.dtype)
264
+ image_feat_global, image_aux_feat_global = self.unified_resampler(image_feat_global, image_aux_features_global)
265
+
266
+ if image_grid > 1:
267
+ image_aux_features_raw = image_aux_features_raw.reshape(*image_aux_features_raw.shape[:2],
268
+ image_grid,
269
+ image_aux_features_raw.shape[-2]//image_grid,
270
+ image_grid,
271
+ image_aux_features_raw.shape[-1]//image_grid)
272
+ image_aux_features_raw = image_aux_features_raw.permute(0, 2, 4, 1, 3, 5).flatten(1,2).flatten(0,1).contiguous()
273
+ image_features, image_aux_features = self.unified_resampler(image_features, image_aux_features_raw)
274
+
275
+ if image_grid > 1:
276
+ image_features = image_features.reshape(batch_size, image_grid**2, *image_features.shape[1:])
277
+ image_features = image_features.flatten(1,2).contiguous()
278
+ image_aux_features = image_aux_features.reshape(batch_size, image_grid**2, *image_aux_features.shape[1:])
279
+ image_aux_features = image_aux_features.flatten(1,2).contiguous()
280
+
281
+ # add global features, [global, local]
282
+ if image_global:
283
+ image_features = torch.cat([image_feat_global, image_features], dim=1)
284
+ image_aux_features = torch.cat([image_aux_feat_global, image_aux_features], dim=1)
285
+
286
+ # token generation
287
+ image_features = image_features + image_aux_features
288
+
289
+ # process image features after token generation
290
+ image_features = self.get_model().mm_projector(image_features)
291
+
292
+ return image_features
293
+
294
+ def unified_resampler(self, images, images_aux):
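+ # each low-resolution query token attends over the pixels of its corresponding high-resolution patch; returns the untouched queries together with the attended, value-projected features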
295
+ # patchwise with square images
296
+ patch_num = int(images.shape[1]**0.5)
297
+ patch_size = images_aux.shape[-1]//patch_num
298
+ # within patch attention
299
+ images_aux = images_aux.permute(0,2,3,1)
300
+ images_aux = images_aux.reshape(len(images_aux), patch_num, patch_size, patch_num, patch_size, images_aux.shape[-1])
301
+ images_aux = images_aux.permute(0,1,3,2,4,5)
302
+ images_aux = images_aux.reshape(len(images_aux), patch_num**2, patch_size**2, images_aux.shape[-1]).contiguous()
303
+
304
+ # token
305
+ print(self.get_model().vlm_uni_query_projector[0].weight.device)
306
+ embed_query = self.get_model().vlm_uni_query_projector(images)
307
+ embed_aux = self.get_model().vlm_uni_aux_projector(images_aux)
308
+ embed_value = self.get_model().vlm_uni_val_projector(images_aux)
309
+ embed_att = embed_query[:,:,None] @ (embed_aux.transpose(-1,-2) / (embed_aux.shape[-1]**0.5))
310
+ embed_att = embed_att.nan_to_num()
311
+ embed_feat = (embed_att.softmax(-1) @ embed_value).mean(2)
312
+
313
+ return images, embed_feat
314
+
315
+ def prepare_inputs_labels_for_multimodal(
316
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images=None, images_aux=None,
317
+ ):
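+ # splice projected image features into the token embeddings at every IMAGE_TOKEN_INDEX placeholder, rebuild labels/attention_mask/position_ids around them, and return inputs_embeds in place of input_ids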
318
+ vision_tower = self.get_vision_tower()
319
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
320
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
321
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
322
+ attention_mask = torch.cat((attention_mask, torch.ones(
323
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
324
+ dtype=attention_mask.dtype,
325
+ device=attention_mask.device
326
+ )), dim=1)
327
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
328
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
329
+
330
+ image_features = self.encode_images(images, images_aux)
331
+
332
+ # TODO: image start / end is not implemented here to support pretraining.
333
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
334
+ raise NotImplementedError
335
+
336
+ # Let's just add dummy tensors if they do not exist,
337
+ # it is a headache to deal with None all the time.
338
+ # But it is not ideal, and if you have a better idea,
339
+ # please open an issue / submit a PR, thanks.
340
+ _labels = labels
341
+ _position_ids = position_ids
342
+ _attention_mask = attention_mask
343
+ if attention_mask is None:
344
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
345
+ else:
346
+ attention_mask = attention_mask.bool()
347
+ if position_ids is None:
348
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
349
+ if labels is None:
350
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
351
+
352
+ # remove the padding using attention_mask -- TODO: double check
353
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
354
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
355
+
356
+ new_input_embeds = []
357
+ new_labels = []
358
+ cur_image_idx = 0
359
+ for batch_idx, cur_input_ids in enumerate(input_ids):
360
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
361
+ if num_images == 0:
362
+ cur_image_features = image_features[cur_image_idx]
363
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
364
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
365
+ new_input_embeds.append(cur_input_embeds)
366
+ new_labels.append(labels[batch_idx])
367
+ cur_image_idx += 1
368
+ continue
369
+
370
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
371
+ cur_input_ids_noim = []
372
+ cur_labels = labels[batch_idx]
373
+ cur_labels_noim = []
374
+ for i in range(len(image_token_indices) - 1):
375
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
376
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
377
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
378
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
379
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
380
+ cur_new_input_embeds = []
381
+ cur_new_labels = []
382
+
383
+ max_pos_id = 0
384
+ for i in range(num_images + 1):
385
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
386
+ cur_new_labels.append(cur_labels_noim[i])
387
+ max_pos_id += cur_input_embeds_no_im[i].shape[0]
388
+ if i < num_images:
389
+ cur_image_features = image_features[cur_image_idx]
390
+ cur_image_idx += 1
391
+ cur_new_input_embeds.append(cur_image_features)
392
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
393
+ max_pos_id += cur_image_features.shape[0]
394
+
395
+ cur_new_input_embeds = [x.to(device=cur_input_embeds.device) for x in cur_new_input_embeds]
396
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
397
+ cur_new_labels = torch.cat(cur_new_labels)
398
+
399
+ new_input_embeds.append(cur_new_input_embeds)
400
+ new_labels.append(cur_new_labels)
401
+
402
+ # Truncate sequences to max length as image embeddings can make the sequence longer
403
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
404
+ if tokenizer_model_max_length is not None:
405
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
406
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
407
+
408
+ # Combine them
409
+ max_len = max(x.shape[0] for x in new_input_embeds)
410
+ batch_size = len(new_input_embeds)
411
+
412
+ new_input_embeds_padded = []
413
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
414
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
415
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
416
+
417
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
418
+ cur_len = cur_new_embed.shape[0]
419
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
420
+ new_input_embeds_padded.append(torch.cat((
421
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
422
+ cur_new_embed
423
+ ), dim=0))
424
+ if cur_len > 0:
425
+ new_labels_padded[i, -cur_len:] = cur_new_labels
426
+ attention_mask[i, -cur_len:] = True
427
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
428
+ else:
429
+ new_input_embeds_padded.append(torch.cat((
430
+ cur_new_embed,
431
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
432
+ ), dim=0))
433
+ if cur_len > 0:
434
+ new_labels_padded[i, :cur_len] = cur_new_labels
435
+ attention_mask[i, :cur_len] = True
436
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
437
+
438
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
439
+
440
+ if _labels is None:
441
+ new_labels = None
442
+ else:
443
+ new_labels = new_labels_padded
444
+
445
+ if _attention_mask is None:
446
+ attention_mask = None
447
+ else:
448
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
449
+
450
+ if _position_ids is None:
451
+ position_ids = None
452
+
453
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
454
+
455
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
456
+ if model_args.mm_use_im_patch_token:
457
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
458
+ self.resize_token_embeddings(len(tokenizer))
459
+
460
+ if model_args.mm_use_im_start_end:
461
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
462
+ self.resize_token_embeddings(len(tokenizer))
463
+
464
+ if num_new_tokens > 0:
465
+ input_embeddings = self.get_input_embeddings().weight.data
466
+ output_embeddings = self.get_output_embeddings().weight.data
467
+
468
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
469
+ dim=0, keepdim=True)
470
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
471
+ dim=0, keepdim=True)
472
+
473
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
474
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
475
+
476
+ if model_args.tune_mm_mlp_adapter:
477
+ for p in self.get_input_embeddings().parameters():
478
+ p.requires_grad = True
479
+ for p in self.get_output_embeddings().parameters():
480
+ p.requires_grad = False
481
+
482
+ if model_args.pretrain_mm_mlp_adapter:
483
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
484
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
485
+ assert num_new_tokens == 2
486
+ if input_embeddings.shape == embed_tokens_weight.shape:
487
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
488
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
489
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
490
+ else:
491
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
492
+ elif model_args.mm_use_im_patch_token:
493
+ if model_args.tune_mm_mlp_adapter:
494
+ for p in self.get_input_embeddings().parameters():
495
+ p.requires_grad = False
496
+ for p in self.get_output_embeddings().parameters():
497
+ p.requires_grad = False
minigemini/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,34 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+ from .openclip_encoder import OpenCLIPVisionTower
4
+
5
+
6
+ def build_vision_tower(vision_tower_cfg, **kwargs):
7
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
8
+ image_processor = getattr(vision_tower_cfg, 'image_processor', "../processor/clip-patch14-224")
9
+
10
+ # if not os.path.exists(vision_tower):
11
+ # raise ValueError(f'Not find vision tower: {vision_tower}')
12
+
13
+ if "openai" in vision_tower.lower() or "ShareGPT4V" in vision_tower:
14
+ vision_tower = 'openai/clip-vit-large-patch14-336'
15
+ # vision_tower = '/dataset/chengyaowang/official/MiniGemini/model_zoo/OpenAI/clip-vit-large-patch14-336'
16
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
17
+ else:
18
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
19
+
20
+
21
+ def build_vision_tower_aux(vision_tower_cfg, **kwargs):
22
+ vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', getattr(vision_tower_cfg, 'vision_tower_aux', None))
23
+
24
+ # if not os.path.exists(vision_tower_aux):
25
+ # raise ValueError(f'Not find vision tower: {vision_tower_aux}')
26
+
27
+ if "openclip" in vision_tower_aux.lower():
28
+ vision_tower_aux = './checkpoints/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup'
29
+ # vision_tower_aux = '/dataset/chengyaowang/official/MiniGemini/model_zoo/OpenAI/openclip-convnext-large-d-320-laion2B-s29B-b131K-ft-soup'
30
+ return OpenCLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
31
+ elif "openai" in vision_tower_aux.lower():
32
+ return CLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
33
+ else:
34
+ raise ValueError(f'Unknown vision tower: {vision_tower_aux}')
minigemini/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+ from ..processor.video_processor import VideoFramesProcessor
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+ self.is_optimize = getattr(args, 'optimize_vision_tower', False)
17
+
18
+ if not delay_load:
19
+ self.load_model()
20
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
21
+ self.load_model()
22
+ else:
23
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
24
+
25
+ def load_model(self):
26
+ self.image_processor = VideoFramesProcessor.from_pretrained(self.vision_tower_name)
27
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
28
+ self.vision_tower.requires_grad_(False)
29
+
30
+ self.is_loaded = True
31
+
32
+ def feature_select(self, image_forward_outs):
33
+ image_features = image_forward_outs.hidden_states[self.select_layer]
34
+ if self.select_feature == 'patch':
35
+ image_features = image_features[:, 1:]
36
+ elif self.select_feature == 'cls_patch':
37
+ image_features = image_features
38
+ else:
39
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
40
+ return image_features
41
+
42
+ def image_forward(self, images):
43
+ if type(images) is list:
44
+ image_features = []
45
+ for image in images:
46
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
47
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
48
+ image_features.append(image_feature)
49
+ else:
50
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
51
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
52
+
53
+ return image_features
54
+
55
+ def forward(self, images):
56
+ if not self.is_optimize:
57
+ with torch.no_grad():
58
+ image_features = self.image_forward(images)
59
+ else:
60
+ image_features = self.image_forward(images)
61
+
62
+ return image_features
63
+
64
+ @property
65
+ def dummy_feature(self):
66
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
67
+
68
+ @property
69
+ def dtype(self):
70
+ return self.vision_tower.dtype
71
+
72
+ @property
73
+ def device(self):
74
+ return self.vision_tower.device
75
+
76
+ @property
77
+ def config(self):
78
+ if self.is_loaded:
79
+ return self.vision_tower.config
80
+ else:
81
+ return self.cfg_only
82
+
83
+ @property
84
+ def hidden_size(self):
85
+ return self.config.hidden_size
86
+
87
+ @property
88
+ def num_patches(self):
89
+ return (self.config.image_size // self.config.patch_size) ** 2
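
For reference, a shape sketch for the tower above (my own check, not part of the commit): with clip-vit-large-patch14-336 and select_feature='patch', the CLS token is dropped, leaving (336 / 14)**2 = 576 patch tokens of width hidden_size = 1024.

import torch
from types import SimpleNamespace
from minigemini.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
tower = CLIPVisionTower('openai/clip-vit-large-patch14-336', args=args)  # downloads HF weights
feats = tower(torch.zeros(1, 3, 336, 336))
assert feats.shape == (1, tower.num_patches, tower.hidden_size)  # (1, 576, 1024)
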
minigemini/model/multimodal_encoder/eva_encoder.py ADDED
@@ -0,0 +1,551 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+ from transformers import CLIPImageProcessor, CLIPVisionConfig
18
+ from ..processor.video_processor import VideoFramesProcessor
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+ class DropPath(nn.Module):
30
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
31
+ """
32
+ def __init__(self, drop_prob=None):
33
+ super(DropPath, self).__init__()
34
+ self.drop_prob = drop_prob
35
+
36
+ def forward(self, x):
37
+ return drop_path(x, self.drop_prob, self.training)
38
+
39
+ def extra_repr(self) -> str:
40
+ return 'p={}'.format(self.drop_prob)
41
+
42
+
43
+ class Mlp(nn.Module):
44
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
45
+ super().__init__()
46
+ out_features = out_features or in_features
47
+ hidden_features = hidden_features or in_features
48
+ self.fc1 = nn.Linear(in_features, hidden_features)
49
+ self.act = act_layer()
50
+ self.fc2 = nn.Linear(hidden_features, out_features)
51
+ self.drop = nn.Dropout(drop)
52
+
53
+ def forward(self, x):
54
+ x = self.fc1(x)
55
+ x = self.act(x)
56
+ # x = self.drop(x)
57
+ # (dropout after fc1 is commented out to match the original BERT implementation)
58
+ x = self.fc2(x)
59
+ x = self.drop(x)
60
+ return x
61
+
62
+
63
+ class Attention(nn.Module):
64
+ def __init__(
65
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
66
+ proj_drop=0., window_size=None, attn_head_dim=None):
67
+ super().__init__()
68
+ self.num_heads = num_heads
69
+ head_dim = dim // num_heads
70
+ if attn_head_dim is not None:
71
+ head_dim = attn_head_dim
72
+ all_head_dim = head_dim * self.num_heads
73
+ self.scale = qk_scale or head_dim ** -0.5
74
+
75
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
76
+ if qkv_bias:
77
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
78
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ else:
80
+ self.q_bias = None
81
+ self.v_bias = None
82
+
83
+ if window_size:
84
+ self.window_size = window_size
85
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
86
+ self.relative_position_bias_table = nn.Parameter(
87
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
88
+ # cls to token & token 2 cls & cls to cls
89
+
90
+ # get pair-wise relative position index for each token inside the window
91
+ coords_h = torch.arange(window_size[0])
92
+ coords_w = torch.arange(window_size[1])
93
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
94
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
95
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
96
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
97
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
98
+ relative_coords[:, :, 1] += window_size[1] - 1
99
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
100
+ relative_position_index = \
101
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
102
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
104
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
105
+ relative_position_index[0, 0] = self.num_relative_distance - 1
106
+
107
+ self.register_buffer("relative_position_index", relative_position_index)
108
+ else:
109
+ self.window_size = None
110
+ self.relative_position_bias_table = None
111
+ self.relative_position_index = None
112
+
113
+ self.attn_drop = nn.Dropout(attn_drop)
114
+ self.proj = nn.Linear(all_head_dim, dim)
115
+ self.proj_drop = nn.Dropout(proj_drop)
116
+
117
+ def forward(self, x, rel_pos_bias=None):
118
+ B, N, C = x.shape
119
+ qkv_bias = None
120
+ if self.q_bias is not None:
121
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
122
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
124
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
125
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
126
+
127
+ q = q * self.scale
128
+ attn = (q @ k.transpose(-2, -1))
129
+
130
+ if self.relative_position_bias_table is not None:
131
+ relative_position_bias = \
132
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
133
+ self.window_size[0] * self.window_size[1] + 1,
134
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
135
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
136
+ attn = attn + relative_position_bias.unsqueeze(0)
137
+
138
+ if rel_pos_bias is not None:
139
+ attn = attn + rel_pos_bias
140
+
141
+ attn = attn.softmax(dim=-1)
142
+ attn = self.attn_drop(attn)
143
+
144
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
145
+ x = self.proj(x)
146
+ x = self.proj_drop(x)
147
+ return x
148
+
149
+
150
+ class Block(nn.Module):
151
+
152
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
153
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
154
+ window_size=None, attn_head_dim=None):
155
+ super().__init__()
156
+ self.norm1 = norm_layer(dim)
157
+ self.attn = Attention(
158
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
159
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
160
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
161
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
162
+ self.norm2 = norm_layer(dim)
163
+ mlp_hidden_dim = int(dim * mlp_ratio)
164
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
165
+
166
+ if init_values is not None and init_values > 0:
167
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
168
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ else:
170
+ self.gamma_1, self.gamma_2 = None, None
171
+
172
+ def forward(self, x, rel_pos_bias=None):
173
+ if self.gamma_1 is None:
174
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
175
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
176
+ else:
177
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
178
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
179
+ return x
180
+
181
+
182
+ class PatchEmbed(nn.Module):
183
+ """ Image to Patch Embedding
184
+ """
185
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
186
+ super().__init__()
187
+ img_size = to_2tuple(img_size)
188
+ patch_size = to_2tuple(patch_size)
189
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
190
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
191
+ self.img_size = img_size
192
+ self.patch_size = patch_size
193
+ self.num_patches = num_patches
194
+
195
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
196
+
197
+ def forward(self, x, **kwargs):
198
+ B, C, H, W = x.shape
199
+ # FIXME look at relaxing size constraints
200
+ assert H == self.img_size[0] and W == self.img_size[1], \
201
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
202
+ x = self.proj(x).flatten(2).transpose(1, 2)
203
+ return x
204
+
205
+
206
+ class RelativePositionBias(nn.Module):
207
+
208
+ def __init__(self, window_size, num_heads):
209
+ super().__init__()
210
+ self.window_size = window_size
211
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
212
+ self.relative_position_bias_table = nn.Parameter(
213
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
214
+ # cls to token & token 2 cls & cls to cls
215
+
216
+ # get pair-wise relative position index for each token inside the window
217
+ coords_h = torch.arange(window_size[0])
218
+ coords_w = torch.arange(window_size[1])
219
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
220
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
221
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
222
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
223
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
224
+ relative_coords[:, :, 1] += window_size[1] - 1
225
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
226
+ relative_position_index = \
227
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
228
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
229
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
230
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
231
+ relative_position_index[0, 0] = self.num_relative_distance - 1
232
+
233
+ self.register_buffer("relative_position_index", relative_position_index)
234
+
235
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
236
+
237
+ def forward(self):
238
+ relative_position_bias = \
239
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
240
+ self.window_size[0] * self.window_size[1] + 1,
241
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
242
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
243
+
244
+
245
+ class VisionTransformer(nn.Module):
246
+ """ Vision Transformer with support for patch or hybrid CNN input stage
247
+ """
248
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
249
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
250
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
251
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
252
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
253
+ super().__init__()
254
+ self.image_size = img_size
255
+ self.num_classes = num_classes
256
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
257
+
258
+ self.patch_embed = PatchEmbed(
259
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
260
+ num_patches = self.patch_embed.num_patches
261
+
262
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
263
+ if use_abs_pos_emb:
264
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
265
+ else:
266
+ self.pos_embed = None
267
+ self.pos_drop = nn.Dropout(p=drop_rate)
268
+
269
+ if use_shared_rel_pos_bias:
270
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
271
+ else:
272
+ self.rel_pos_bias = None
273
+ self.use_checkpoint = use_checkpoint
274
+
275
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
276
+ self.use_rel_pos_bias = use_rel_pos_bias
277
+ self.blocks = nn.ModuleList([
278
+ Block(
279
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
280
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
281
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
282
+ for i in range(depth)])
283
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
284
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
285
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
286
+
287
+ if self.pos_embed is not None:
288
+ trunc_normal_(self.pos_embed, std=.02)
289
+ trunc_normal_(self.cls_token, std=.02)
290
+ # trunc_normal_(self.mask_token, std=.02)
291
+ # if isinstance(self.head, nn.Linear):
292
+ # trunc_normal_(self.head.weight, std=.02)
293
+ self.apply(self._init_weights)
294
+ self.fix_init_weight()
295
+ # if isinstance(self.head, nn.Linear):
296
+ # self.head.weight.data.mul_(init_scale)
297
+ # self.head.bias.data.mul_(init_scale)
298
+
299
+ def fix_init_weight(self):
300
+ def rescale(param, layer_id):
301
+ param.div_(math.sqrt(2.0 * layer_id))
302
+
303
+ for layer_id, layer in enumerate(self.blocks):
304
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
305
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
306
+
307
+ def _init_weights(self, m):
308
+ if isinstance(m, nn.Linear):
309
+ trunc_normal_(m.weight, std=.02)
310
+ if isinstance(m, nn.Linear) and m.bias is not None:
311
+ nn.init.constant_(m.bias, 0)
312
+ elif isinstance(m, nn.LayerNorm):
313
+ nn.init.constant_(m.bias, 0)
314
+ nn.init.constant_(m.weight, 1.0)
315
+
316
+ def get_classifier(self):
317
+ return self.head
318
+
319
+ def reset_classifier(self, num_classes, global_pool=''):
320
+ self.num_classes = num_classes
321
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
322
+
323
+ def forward_features(self, x):
324
+ x = self.patch_embed(x)
325
+ batch_size, seq_len, _ = x.size()
326
+
327
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
328
+ x = torch.cat((cls_tokens, x), dim=1)
329
+ if self.pos_embed is not None:
330
+ x = x + self.pos_embed
331
+ x = self.pos_drop(x)
332
+
333
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
334
+ for blk in self.blocks:
335
+ if self.use_checkpoint:
336
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
337
+ else:
338
+ x = blk(x, rel_pos_bias)
339
+ return x
340
+ # x = self.norm(x)
341
+
342
+ # if self.fc_norm is not None:
343
+ # t = x[:, 1:, :]
344
+ # return self.fc_norm(t.mean(1))
345
+ # else:
346
+ # return x[:, 0]
347
+
348
+ def forward(self, x):
349
+ x = self.forward_features(x)
350
+ # x = self.head(x)
351
+ return x
352
+
353
+ def get_intermediate_layers(self, x):
354
+ x = self.patch_embed(x)
355
+ batch_size, seq_len, _ = x.size()
356
+
357
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
358
+ x = torch.cat((cls_tokens, x), dim=1)
359
+ if self.pos_embed is not None:
360
+ x = x + self.pos_embed
361
+ x = self.pos_drop(x)
362
+
363
+ features = []
364
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
365
+ for blk in self.blocks:
366
+ x = blk(x, rel_pos_bias)
367
+ features.append(x)
368
+
369
+ return features
370
+
371
+ @property
372
+ def dtype(self):
373
+ return self.cls_token.dtype
374
+
375
+ @property
376
+ def device(self):
377
+ return self.cls_token.device
378
+
379
+ def get_num_layer(self, var_name=""):
380
+ if var_name in ("cls_token", "mask_token", "pos_embed"):
381
+ return 0
382
+ elif var_name.startswith("patch_embed"):
383
+ return 0
384
+ elif var_name.startswith("rel_pos_bias"):
385
+ return len(self.blocks) - 1
386
+ elif var_name.startswith("blocks"):
387
+ layer_id = int(var_name.split('.')[1])
388
+ return layer_id + 1
389
+ else:
390
+ return len(self.blocks)
391
+
392
+
393
+ def interpolate_pos_embed(model, checkpoint_model):
394
+ if 'pos_embed' in checkpoint_model:
395
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
396
+ embedding_size = pos_embed_checkpoint.shape[-1]
397
+ num_patches = model.patch_embed.num_patches
398
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
399
+ # height (== width) for the checkpoint position embedding
400
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
401
+ # height (== width) for the new position embedding
402
+ new_size = int(num_patches ** 0.5)
403
+ # class_token and dist_token are kept unchanged
404
+ if orig_size != new_size:
405
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
406
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
407
+ # only the position tokens are interpolated
408
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
409
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
410
+ pos_tokens = torch.nn.functional.interpolate(
411
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
412
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
413
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
414
+ checkpoint_model['pos_embed'] = new_pos_embed
415
+
416
+
417
+ def convert_weights_to_fp16(model: nn.Module):
418
+ """Convert applicable model parameters to fp16"""
419
+
420
+ def _convert_weights_to_fp16(l):
421
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
422
+ l.weight.data = l.weight.data.half()
423
+ if l.bias is not None:
424
+ l.bias.data = l.bias.data.half()
425
+
426
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
427
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
428
+ # tensor = getattr(l, attr)
429
+ # if tensor is not None:
430
+ # tensor.data = tensor.data.half()
431
+
432
+ model.apply(_convert_weights_to_fp16)
433
+
434
+ class EVAVisionTower(nn.Module):
435
+ def __init__(self, vision_tower, image_processor, args, use_checkpoint=False, drop_path_rate=0.0, delay_load=False, dtype=torch.float32):
436
+ super().__init__()
437
+
438
+ self.is_loaded = False
439
+ self.use_checkpoint = use_checkpoint
440
+ self.vision_tower_name = vision_tower
441
+ self.image_processor_name = image_processor
442
+ self.drop_path_rate = drop_path_rate
443
+ self.patch_size = 14
444
+ self.out_channel = 1408
445
+ if not delay_load:
446
+ self.load_model()
447
+
448
+ self.vision_config = CLIPVisionConfig.from_pretrained(image_processor)
449
+
450
+ def load_model(self):
451
+ # self.image_processor = CLIPImageProcessor.from_pretrained(self.image_processor_name)
452
+ self.image_processor = VideoFramesProcessor.from_pretrained(self.image_processor_name)
453
+ self.vision_tower = VisionTransformer(
454
+ img_size=self.image_processor.size['shortest_edge'],
455
+ patch_size=self.patch_size,
456
+ use_mean_pooling=False,
457
+ embed_dim=self.out_channel,
458
+ depth=39,
459
+ num_heads=self.out_channel//88,
460
+ mlp_ratio=4.3637,
461
+ qkv_bias=True,
462
+ drop_path_rate=self.drop_path_rate,
463
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
464
+ use_checkpoint=self.use_checkpoint,
465
+ )
466
+
467
+ state_dict = torch.load(self.vision_tower_name, map_location="cpu")
468
+ interpolate_pos_embed(self.vision_tower, state_dict)
469
+ incompatible_keys = self.vision_tower.load_state_dict(state_dict, strict=False)
470
+ print(incompatible_keys)
471
+ self.vision_tower.requires_grad_(False)
472
+
473
+ self.is_loaded = True
474
+
475
+ @torch.no_grad()
476
+ def forward(self, images):
477
+ if type(images) is list:
478
+ image_features = []
479
+ for image in images:
480
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
481
+ image_feature = image_forward_out.to(image.dtype)
482
+ image_features.append(image_feature)
483
+ else:
484
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
485
+ image_features = image_forward_outs.to(images.dtype)
486
+
487
+ return image_features
488
+
489
+ def feature_select(self, image_features):
490
+ # image_features = image_features.hidden_states[self.select_layer]
491
+ if self.select_feature == 'patch':
492
+ image_features = image_features[:, 1:]
493
+ elif self.select_feature == 'cls_patch':
494
+ image_features = image_features
495
+ else:
496
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
497
+ return image_features
498
+
499
+ @property
500
+ def dummy_feature(self):
501
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
502
+
503
+ @property
504
+ def dtype(self):
505
+ return self.vision_tower.dtype
506
+
507
+ @property
508
+ def device(self):
509
+ return self.vision_tower.device
510
+
511
+ @property
512
+ def config(self):
513
+ return self.vision_config
514
+
515
+ @property
516
+ def hidden_size(self):
517
+ return self.out_channel
518
+
519
+ @property
520
+ def num_patches(self):
521
+ return (self.image_processor.size['shortest_edge'] // self.patch_size) ** 2
522
+
523
+
524
+
525
+ def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,model_path=None,precision="fp16"):
526
+ model = VisionTransformer(
527
+ img_size=img_size,
528
+ patch_size=14,
529
+ use_mean_pooling=False,
530
+ embed_dim=1408,
531
+ depth=39,
532
+ num_heads=1408//88,
533
+ mlp_ratio=4.3637,
534
+ qkv_bias=True,
535
+ drop_path_rate=drop_path_rate,
536
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
537
+ use_checkpoint=use_checkpoint,
538
+ )
539
+ # url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
540
+ # cached_file = download_cached_file(
541
+ # url, check_hash=False, progress=True
542
+ # )
543
+ state_dict = torch.load(model_path, map_location="cpu")
544
+ interpolate_pos_embed(model,state_dict)
545
+
546
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
547
+ print(incompatible_keys)
548
+
549
+ if precision == "fp16":
550
+ convert_weights_to_fp16(model)
551
+ return model
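
A small self-contained check of interpolate_pos_embed above (synthetic tensors, not real EVA weights): a 224-px, patch-14 checkpoint stores a 16x16 position grid plus one cls token, and resizing the model to 336 px needs a 24x24 grid, so the grid is bicubically interpolated.

import torch
from minigemini.model.multimodal_encoder.eva_encoder import VisionTransformer, interpolate_pos_embed

# depth=1 keeps the toy model cheap; the other hyper-parameters mirror create_eva_vit_g above.
model = VisionTransformer(img_size=336, patch_size=14, embed_dim=1408, depth=1,
                          num_heads=16, mlp_ratio=4.3637, qkv_bias=True)
ckpt = {'pos_embed': torch.randn(1, 257, 1408)}    # 1 cls token + 16*16 patch positions
interpolate_pos_embed(model, ckpt)
assert ckpt['pos_embed'].shape == (1, 577, 1408)   # 1 cls token + 24*24 patch positions
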
minigemini/model/multimodal_encoder/openclip_encoder.py ADDED
@@ -0,0 +1,225 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import json
6
+ import logging
7
+ import deepspeed
8
+ from pathlib import Path
9
+ from open_clip.factory import load_state_dict, get_model_config
10
+ from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed
11
+ from typing import Dict, Optional
12
+ from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
13
+
14
+ open_clip_config = {
15
+ "model_cfg": {
16
+ "embed_dim": 768,
17
+ "vision_cfg": {
18
+ "timm_model_name": "convnext_large",
19
+ "timm_model_pretrained": False,
20
+ "timm_pool": "",
21
+ "timm_proj": "mlp",
22
+ "timm_drop": 0.0,
23
+ "timm_drop_path": 0.1,
24
+ "image_size": 320
25
+ },
26
+ "text_cfg": {
27
+ "context_length": 77,
28
+ "vocab_size": 49408,
29
+ "width": 768,
30
+ "heads": 12,
31
+ "layers": 16
32
+ }
33
+ },
34
+ "preprocess_cfg": {
35
+ "mean": [
36
+ 0.48145466,
37
+ 0.4578275,
38
+ 0.40821073
39
+ ],
40
+ "std": [
41
+ 0.26862954,
42
+ 0.26130258,
43
+ 0.27577711
44
+ ]
45
+ }
46
+ }
47
+
48
+ # Auxiliary high-resolution vision tower built on OpenCLIP ConvNeXt backbones
49
+ class OpenCLIPVisionTower(nn.Module):
50
+ def __init__(self, vision_tower, args, delay_load=False):
51
+ super().__init__()
52
+
53
+ self.is_loaded = False
54
+ self.vision_tower_name = vision_tower
55
+ self.vision_config = open_clip_config
56
+ # json.load(open(os.path.join(vision_tower,'open_clip_config.json'), 'r'))
57
+ self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False)
58
+
59
+ if not delay_load:
60
+ self.load_model()
61
+
62
+ def load_model(self):
63
+ # print(self.vision_tower_name)
64
+
65
+ ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin')
66
+ if 'convnext' in self.vision_tower_name:
67
+ if 'large' in self.vision_tower_name and 'd_320' in self.vision_tower_name:
68
+ self.model_type = 'convnext_large_d_320'
69
+ self.model_channel = [192, 384, 768, 1536] # stage 0-3
70
+ elif 'base' in self.vision_tower_name and 'w_320' in self.vision_tower_name:
71
+ self.model_type = 'convnext_base_w_320'
72
+ self.model_channel = [128, 256, 512, 1024]
73
+ elif 'xxlarge' in self.vision_tower_name:
74
+ self.model_type = 'convnext_xxlarge'
75
+ self.model_channel = [384, 768, 1536, 3072]
76
+
77
+ clip_model = CLIP(**get_model_config(self.model_type))
78
+ clip_model.visual.trunk.norm_pre = None
79
+ clip_model.visual.trunk.head = None
80
+ clip_model.visual.head = None
81
+ print(f'Loading pretrained weights ({self.model_type}).')
82
+ load_checkpoint(clip_model, ckpt_path, strict=False)
83
+
84
+ self.is_loaded = True
85
+ # decompose stem and stages blocks in vision tower
86
+ self.vision_stem = clip_model.visual.trunk.stem
87
+ self.vision_stages = clip_model.visual.trunk.stages
88
+ self.vision_stem.requires_grad_(False)
89
+ self.vision_stages.requires_grad_(False)
90
+
91
+ def forward(self, images):
92
+ if type(images) is list:
93
+ image_features = []
94
+ for image in images:
95
+ image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
96
+ image_features.append(image_feature)
97
+ else:
98
+ image_features = self.backbone(images.to(device=self.device, dtype=self.dtype))
99
+
100
+ return image_features
101
+
102
+ def backbone(self, images):
103
+ if not self.is_optimize:
104
+ with torch.no_grad():
105
+ results = self.basic_forward(images)
106
+ else:
107
+ results = self.basic_forward(images)
108
+
109
+ target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1])
110
+ result_cat = []
111
+ for _stage in results:
112
+ if _stage == 'stage_0':
113
+ result_cat.append(results[_stage].contiguous())
114
+ else:
115
+ result_cat.append(F.interpolate(results[_stage].float().contiguous() ,
116
+ size=target_size,
117
+ mode='bilinear',
118
+ align_corners=False).to(dtype=results[_stage].dtype))
119
+ result_cat = torch.cat(result_cat, dim=1)
120
+
121
+ return result_cat.contiguous()
122
+
123
+ def basic_forward(self, images):
124
+ results = {}
125
+ x = self.vision_stem(images)
126
+ for _idx in range(len(self.vision_stages)):
127
+ x = self.vision_stages[_idx](x)
128
+ results[f'stage_{_idx}'] = x
129
+ return results
130
+
131
+ @property
132
+ def dummy_feature(self):
133
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
134
+
135
+ @property
136
+ def dtype(self):
137
+ return self.vision_stem[0].weight.dtype
138
+
139
+ @property
140
+ def device(self):
141
+ return self.vision_stem[0].weight.device
142
+
143
+ @property
144
+ def config(self):
145
+ return self.vision_config
146
+
147
+ @property
148
+ def hidden_size(self):
149
+ return sum(self.model_channel)
150
+
151
+ # modified function from open_clip to support zero3 stage
152
+ def load_checkpoint(model, checkpoint_path, strict=True):
153
+ if Path(checkpoint_path).suffix in ('.npz', '.npy'):
154
+ from open_clip.big_vision import load_big_vision_weights
155
+ load_big_vision_weights(model, checkpoint_path)
156
+ return {}
157
+
158
+ state_dict = load_state_dict(checkpoint_path)
159
+ # detect old format and make compatible with new format
160
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
161
+ state_dict = convert_to_custom_text_state_dict(state_dict)
162
+ # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
163
+ # if 'logit_bias' not in state_dict and model.logit_bias is not None:
164
+ # state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
165
+ # Certain text transformers no longer expect position_ids after transformers==4.31
166
+ position_id_key = 'text.transformer.embeddings.position_ids'
167
+ if position_id_key in state_dict and not hasattr(model, position_id_key):
168
+ del state_dict[position_id_key]
169
+ resize_pos_embed(state_dict, model)
170
+ # resize_text_pos_embed(state_dict, model)
171
+ #incompatible_keys = model.load_state_dict(state_dict, strict=strict)
172
+ if is_deepspeed_zero3_enabled():
173
+
174
+ error_msgs = []
175
+
176
+ def load(module: nn.Module, state_dict, prefix=""):
177
+ metadata = None
178
+
179
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
180
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
181
+ # Parameters of module and children will start with prefix. We can exit early if there are none in this
182
+ # state_dict
183
+ if len([key for key in state_dict if key.startswith(prefix)]) > 0:
184
+ if is_deepspeed_zero3_enabled():
185
+ # In sharded models, each shard has only part of the full state_dict, so only gather
186
+ # parameters that are in the current state_dict.
187
+ named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
188
+ params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
189
+ if len(params_to_gather) > 0:
190
+ # because zero3 puts placeholders in model params, this context
191
+ # manager gathers (unpartitions) the params of the current layer, then loads from
192
+ # the state dict and then re-partitions them again
193
+ with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
194
+ if torch.distributed.get_rank() == 0:
195
+ module._load_from_state_dict(*args)
196
+ else:
197
+ module._load_from_state_dict(*args)
198
+
199
+ for name, child in module._modules.items():
200
+ if child is not None:
201
+ load(child, state_dict, prefix + name + ".")
202
+
203
+ load(model, state_dict)
204
+ incompatible_keys = []
205
+ else:
206
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
207
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
208
+ return incompatible_keys
209
+
210
+ class CLIP(nn.Module):
211
+ output_dict: torch.jit.Final[bool]
212
+
213
+ def __init__(
214
+ self,
215
+ embed_dim: int,
216
+ vision_cfg: CLIPVisionCfg,
217
+ text_cfg: CLIPTextCfg,
218
+ quick_gelu: bool = False,
219
+ cast_dtype: Optional[torch.dtype] = None,
220
+ output_dict: bool = False,
221
+ ):
222
+ super().__init__()
223
+ self.output_dict = output_dict
224
+
225
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
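
The backbone() method above fuses the four ConvNeXt stages into one feature map; a stand-alone sketch with dummy tensors (channel counts and spatial sizes assume convnext_large_d_320 at a 320-px input): every stage is upsampled to the stage_0 resolution and concatenated along channels, so hidden_size == sum(model_channel) == 192 + 384 + 768 + 1536 == 2880.

import torch
import torch.nn.functional as F

channels = [192, 384, 768, 1536]   # stages 0-3 of convnext_large_d_320
sizes = [80, 40, 20, 10]           # 320-px input at strides 4/8/16/32
stages = {f'stage_{i}': torch.randn(2, c, s, s) for i, (c, s) in enumerate(zip(channels, sizes))}

target = stages['stage_0'].shape[-2:]
fused = torch.cat(
    [stages['stage_0']] +
    [F.interpolate(stages[f'stage_{i}'], size=target, mode='bilinear', align_corners=False)
     for i in range(1, 4)],
    dim=1)
assert fused.shape == (2, sum(channels), 80, 80)
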
minigemini/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+ class IdentityMap(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ def forward(self, x, *args, **kwargs):
10
+ return x
11
+
12
+ @property
13
+ def config(self):
14
+ return {"mm_projector_type": 'identity'}
15
+
16
+
17
+ class SimpleResBlock(nn.Module):
18
+ def __init__(self, channels):
19
+ super().__init__()
20
+ self.pre_norm = nn.LayerNorm(channels)
21
+
22
+ self.proj = nn.Sequential(
23
+ nn.Linear(channels, channels),
24
+ nn.GELU(),
25
+ nn.Linear(channels, channels)
26
+ )
27
+ def forward(self, x):
28
+ x = self.pre_norm(x)
29
+ return x + self.proj(x)
30
+
31
+
32
+ def build_vision_projector(config, delay_load=False, **kwargs):
33
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
34
+
35
+ if projector_type == 'linear':
36
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
37
+
38
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
39
+ if mlp_gelu_match:
40
+ mlp_depth = int(mlp_gelu_match.group(1))
41
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
45
+ return nn.Sequential(*modules)
46
+
47
+ if projector_type == 'identity':
48
+ return IdentityMap()
49
+
50
+ raise ValueError(f'Unknown projector type: {projector_type}')
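
A quick example of the naming convention parsed above (the hidden sizes are illustrative; 'mlp2x_gelu' is the form matched by the regex):

from types import SimpleNamespace
from minigemini.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1024, hidden_size=4096)
proj = build_vision_projector(cfg)
print(proj)  # Sequential(Linear(1024 -> 4096), GELU(), Linear(4096 -> 4096))
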
minigemini/model/processor/video_processor.py ADDED
@@ -0,0 +1,74 @@
1
+ from transformers import CLIPImageProcessor
2
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
3
+ from transformers.image_transforms import get_resize_output_image_size
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ import numpy as np
9
+
10
+
11
+ class VideoFramesProcessor(CLIPImageProcessor):
12
+
13
+ def __init__(self, **kwargs):
14
+ super().__init__(**kwargs)
15
+
16
+ def preprocess(self, images, **kwargs):
17
+ if not isinstance(images, np.ndarray):
18
+ return super().preprocess(images=images, **kwargs)
19
+
20
+ do_resize = kwargs.get('do_resize', self.do_resize)
21
+ size = kwargs.get('size', self.size)
22
+ size = get_size_dict(size, param_name="size", default_to_square=False)
23
+ do_center_crop = kwargs.get('do_center_crop', self.do_center_crop)
24
+ crop_size = kwargs.get('crop_size', self.crop_size)
25
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
26
+ do_rescale = kwargs.get('do_rescale', self.do_rescale)
27
+ rescale_factor = kwargs.get('rescale_factor', self.rescale_factor)
28
+ do_normalize = kwargs.get('do_normalize', self.do_normalize)
29
+ image_mean = kwargs.get('image_mean', self.image_mean)
30
+ image_std = kwargs.get('image_std', self.image_std)
31
+ return_tensors = kwargs.get('return_tensors', None)
32
+
33
+ def resize(images, output_size):
34
+ images = images.permute((0, 3, 1, 2))
35
+ images = F.interpolate(images, size=output_size, mode='bicubic')
36
+ images = images.permute((0, 2, 3, 1))
37
+ return images
38
+
39
+ def center_crop(images, crop_size):
40
+ crop_width, crop_height = crop_size["width"], crop_size["height"]
41
+ img_width, img_height = images.shape[1:3]
42
+ x = (img_width - crop_width) // 2
43
+ y = (img_height - crop_height) // 2
44
+ images = images[:, x:x+crop_width, y:y+crop_height]
45
+ return images
46
+
47
+ def rescale(images, rescale_factor):
48
+ images = images * rescale_factor
49
+ return images
50
+
51
+ def normalize(images, mean, std):
52
+ mean = torch.tensor(mean)
53
+ std = torch.tensor(std)
54
+ images = (images - mean) / std
55
+ return images
56
+
57
+ images = torch.from_numpy(images).float()
58
+
59
+ if do_resize:
60
+ output_size = get_resize_output_image_size(images[0], size=size["shortest_edge"], default_to_square=False)
61
+ images = resize(images, output_size)
62
+
63
+ if do_center_crop:
64
+ images = center_crop(images, crop_size)
65
+
66
+ if do_rescale:
67
+ images = rescale(images, rescale_factor)
68
+
69
+ if do_normalize:
70
+ images = normalize(images, image_mean, image_std)
71
+
72
+ images = images.permute((0, 3, 1, 2))
73
+ data = {"pixel_values": images}
74
+ return BatchFeature(data=data, tensor_type=return_tensors)
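
A usage sketch for the numpy branch above (synthetic frames; the processor repo id is an assumption, and any CLIP image-processor config with size and crop_size set behaves the same way):

import numpy as np
from minigemini.model.processor.video_processor import VideoFramesProcessor

processor = VideoFramesProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
frames = np.random.randint(0, 256, size=(8, 360, 640, 3)).astype(np.float32)  # (T, H, W, C)
batch = processor.preprocess(frames, return_tensors='pt')
print(batch['pixel_values'].shape)  # torch.Size([8, 3, 336, 336]) after resize + center crop
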
minigemini/serve/__init__.py ADDED
File without changes
minigemini/serve/cli.py ADDED
@@ -0,0 +1,237 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
+ from minigemini.conversation import conv_templates, SeparatorStyle
6
+ from minigemini.model.builder import load_pretrained_model
7
+ from minigemini.utils import disable_torch_init
8
+ from minigemini.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9
+
10
+ from PIL import Image
11
+
12
+ import requests
13
+ from PIL import Image
14
+ from io import BytesIO
15
+ from transformers import TextStreamer
16
+ try:
17
+ from diffusers import StableDiffusionXLPipeline
18
+ except ImportError:
19
+ print('please install diffusers==0.26.3')
20
+
21
+ try:
22
+ from paddleocr import PaddleOCR
23
+ except ImportError:
24
+ print('please install paddleocr following https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/README_en.md')
25
+
26
+
27
+ def load_image(image_file):
28
+ if image_file.startswith('http://') or image_file.startswith('https://'):
29
+ response = requests.get(image_file)
30
+ image = Image.open(BytesIO(response.content)).convert('RGB')
31
+ else:
32
+ image = Image.open(image_file).convert('RGB')
33
+ return image
34
+
35
+
36
+ def main(args):
37
+ # Model
38
+ disable_torch_init()
39
+
40
+ if args.ocr and args.image_file is not None:
41
+ ocr = PaddleOCR(use_angle_cls=True, use_gpu=True, lang="ch")
42
+ result = ocr.ocr(args.image_file)
43
+ str_in_image = ''
44
+ if result[0] is not None:
45
+ result = [res[1][0] for res in result[0] if res[1][1] > 0.1]
46
+ if len(result) > 0:
47
+ str_in_image = ', '.join(result)
48
+ print('OCR Token: ' + str_in_image)
49
+
50
+ if args.gen:
51
+ pipe = StableDiffusionXLPipeline.from_pretrained(
52
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
53
+ ).to("cuda")
54
+
55
+ model_name = get_model_name_from_path(args.model_path)
56
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
57
+
58
+ if '8x7b' in model_name.lower():
59
+ conv_mode = "mistral_instruct"
60
+ elif '34b' in model_name.lower():
61
+ conv_mode = "chatml_direct"
62
+ elif '2b' in model_name.lower():
63
+ conv_mode = "gemma"
64
+ else:
65
+ conv_mode = "vicuna_v1"
66
+
67
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
68
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
69
+ else:
70
+ args.conv_mode = conv_mode
71
+
72
+ conv = conv_templates[args.conv_mode].copy()
73
+ if "mpt" in model_name.lower():
74
+ roles = ('user', 'assistant')
75
+ else:
76
+ roles = conv.roles
77
+
78
+ if args.image_file is not None:
79
+ images = []
80
+ if ',' in args.image_file:
81
+ images = args.image_file.split(',')
82
+ else:
83
+ images = [args.image_file]
84
+
85
+ image_convert = []
86
+ for _image in images:
87
+ image_convert.append(load_image(_image))
88
+
89
+ if hasattr(model.config, 'image_size_aux'):
90
+ if not hasattr(image_processor, 'image_size_raw'):
91
+ image_processor.image_size_raw = image_processor.crop_size.copy()
92
+ image_processor.crop_size['height'] = model.config.image_size_aux
93
+ image_processor.crop_size['width'] = model.config.image_size_aux
94
+ image_processor.size['shortest_edge'] = model.config.image_size_aux
95
+
96
+ # Similar operation in model_worker.py
97
+ image_tensor = process_images(image_convert, image_processor, model.config)
98
+
99
+ image_grid = getattr(model.config, 'image_grid', 1)
100
+ if hasattr(model.config, 'image_size_aux'):
101
+ raw_shape = [image_processor.image_size_raw['height'] * image_grid,
102
+ image_processor.image_size_raw['width'] * image_grid]
103
+ image_tensor_aux = image_tensor
104
+ image_tensor = torch.nn.functional.interpolate(image_tensor,
105
+ size=raw_shape,
106
+ mode='bilinear',
107
+ align_corners=False)
108
+ else:
109
+ image_tensor_aux = []
110
+
111
+ if image_grid >= 2:
112
+ raw_image = image_tensor.reshape(3,
113
+ image_grid,
114
+ image_processor.image_size_raw['height'],
115
+ image_grid,
116
+ image_processor.image_size_raw['width'])
117
+ raw_image = raw_image.permute(1, 3, 0, 2, 4)
118
+ raw_image = raw_image.reshape(-1, 3,
119
+ image_processor.image_size_raw['height'],
120
+ image_processor.image_size_raw['width'])
121
+
122
+ if getattr(model.config, 'image_global', False):
123
+ global_image = image_tensor
124
+ if len(global_image.shape) == 3:
125
+ global_image = global_image[None]
126
+ global_image = torch.nn.functional.interpolate(global_image,
127
+ size=[image_processor.image_size_raw['height'],
128
+ image_processor.image_size_raw['width']],
129
+ mode='bilinear',
130
+ align_corners=False)
131
+ # [image_crops, image_global]
132
+ raw_image = torch.cat([raw_image, global_image], dim=0)
133
+ image_tensor = raw_image.contiguous()
134
+ image_tensor = image_tensor.unsqueeze(0)
135
+
136
+ if type(image_tensor) is list:
137
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
138
+ image_tensor_aux = [image.to(model.device, dtype=torch.float16) for image in image_tensor_aux]
139
+ else:
140
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
141
+ image_tensor_aux = image_tensor_aux.to(model.device, dtype=torch.float16)
142
+ else:
143
+ images = None
144
+ image_tensor = None
145
+ image_tensor_aux = []
146
+
147
+
148
+ while True:
149
+ try:
150
+ inp = input(f"{roles[0]}: ")
151
+ except EOFError:
152
+ inp = ""
153
+ if not inp:
154
+ print("exit...")
155
+ break
156
+
157
+ print(f"{roles[1]}: ", end="")
158
+
159
+ if args.ocr and len(str_in_image) > 0:
160
+ inp = inp + '\nReference OCR Token: ' + str_in_image + '\n'
161
+ if args.gen:
162
+ inp = inp + ' <GEN>'
163
+ # print(inp, '====')
164
+
165
+ if images is not None:
166
+ # first message
167
+ if model.config.mm_use_im_start_end:
168
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
169
+ else:
170
+ inp = (DEFAULT_IMAGE_TOKEN + '\n')*len(images) + inp
171
+ conv.append_message(conv.roles[0], inp)
172
+ images = None
173
+ else:
174
+ # later messages
175
+ conv.append_message(conv.roles[0], inp)
176
+ conv.append_message(conv.roles[1], None)
177
+ prompt = conv.get_prompt()
178
+
179
+ # add image split string
180
+ if prompt.count(DEFAULT_IMAGE_TOKEN) >= 2:
181
+ final_str = ''
182
+ sent_split = prompt.split(DEFAULT_IMAGE_TOKEN)
183
+ for _idx, _sub_sent in enumerate(sent_split):
184
+ if _idx == len(sent_split) - 1:
185
+ final_str = final_str + _sub_sent
186
+ else:
187
+ final_str = final_str + _sub_sent + f'Image {_idx+1}:' + DEFAULT_IMAGE_TOKEN
188
+ prompt = final_str
189
+
190
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
191
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
192
+
193
+ with torch.inference_mode():
194
+ output_ids = model.generate(
195
+ input_ids,
196
+ images=image_tensor,
197
+ images_aux=image_tensor_aux if len(image_tensor_aux)>0 else None,
198
+ do_sample=True if args.temperature > 0 else False,
199
+ temperature=args.temperature,
200
+ max_new_tokens=args.max_new_tokens,
201
+ bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
202
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
203
+ pad_token_id=tokenizer.pad_token_id, # Pad token
204
+ streamer=streamer,
205
+ use_cache=True)
206
+
207
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
208
+ conv.messages[-1][-1] = outputs
209
+
210
+ if args.gen and '<h>' in outputs and '</h>' in outputs:
211
+ common_neg_prompt = "out of frame, lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
212
+ prompt = outputs.split("</h>")[-2].split("<h>")[-1]
213
+ output_img = pipe(prompt, negative_prompt=common_neg_prompt).images[0]
214
+ output_img.save(args.output_file)
215
+ print(f'Generate an image, save at {args.output_file}')
216
+
217
+ if args.debug:
218
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
219
+
220
+
221
+ if __name__ == "__main__":
222
+ parser = argparse.ArgumentParser()
223
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
224
+ parser.add_argument("--model-base", type=str, default=None)
225
+ parser.add_argument("--image-file", type=str, default=None) # file_0.jpg,file_1.jpg for multi image
226
+ parser.add_argument("--device", type=str, default="cuda")
227
+ parser.add_argument("--conv-mode", type=str, default=None)
228
+ parser.add_argument("--temperature", type=float, default=0.2)
229
+ parser.add_argument("--max-new-tokens", type=int, default=512)
230
+ parser.add_argument("--load-8bit", action="store_true")
231
+ parser.add_argument("--load-4bit", action="store_true")
232
+ parser.add_argument("--ocr", action="store_true")
233
+ parser.add_argument("--gen", action="store_true")
234
+ parser.add_argument("--output-file", type=str, default='generate.png')
235
+ parser.add_argument("--debug", action="store_true")
236
+ args = parser.parse_args()
237
+ main(args)
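
For reference, the image_grid reshuffle in the CLI above can be checked in isolation (dummy tensor; grid=2 and a 336-px raw size are illustrative values):

import torch

grid, raw = 2, 336
image = torch.randn(3, raw * grid, raw * grid)         # high-resolution input
crops = image.reshape(3, grid, raw, grid, raw)
crops = crops.permute(1, 3, 0, 2, 4).reshape(-1, 3, raw, raw)
print(crops.shape)                                     # torch.Size([4, 3, 336, 336]): a 2x2 grid of crops
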
minigemini/serve/controller.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from minigemini.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from minigemini.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError(f"Invalid dispatch method: {name}")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: str
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stable_workers_by_expiration()
55
+
56
+
57
+ class Controller:
58
+ def __init__(self, dispatch_method: str):
59
+ # Dict[str -> WorkerInfo]
60
+ self.worker_info = {}
61
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
+
63
+ self.heart_beat_thread = threading.Thread(
64
+ target=heart_beat_controller, args=(self,))
65
+ self.heart_beat_thread.start()
66
+
67
+ logger.info("Init controller")
68
+
69
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
70
+ worker_status: dict):
71
+ if worker_name not in self.worker_info:
72
+ logger.info(f"Register a new worker: {worker_name}")
73
+ else:
74
+ logger.info(f"Register an existing worker: {worker_name}")
75
+
76
+ if not worker_status:
77
+ worker_status = self.get_worker_status(worker_name)
78
+ if not worker_status:
79
+ return False
80
+
81
+ self.worker_info[worker_name] = WorkerInfo(
82
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
+ check_heart_beat, time.time())
84
+
85
+ logger.info(f"Register done: {worker_name}, {worker_status}")
86
+ return True
87
+
88
+ def get_worker_status(self, worker_name: str):
89
+ try:
90
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
+ except requests.exceptions.RequestException as e:
92
+ logger.error(f"Get status fails: {worker_name}, {e}")
93
+ return None
94
+
95
+ if r.status_code != 200:
96
+ logger.error(f"Get status fails: {worker_name}, {r}")
97
+ return None
98
+
99
+ return r.json()
100
+
101
+ def remove_worker(self, worker_name: str):
102
+ del self.worker_info[worker_name]
103
+
104
+ def refresh_all_workers(self):
105
+ old_info = dict(self.worker_info)
106
+ self.worker_info = {}
107
+
108
+ for w_name, w_info in old_info.items():
109
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
+ logger.info(f"Remove stale worker: {w_name}")
111
+
112
+ def list_models(self):
113
+ model_names = set()
114
+
115
+ for w_name, w_info in self.worker_info.items():
116
+ model_names.update(w_info.model_names)
117
+
118
+ return list(model_names)
119
+
120
+ def get_worker_address(self, model_name: str):
121
+ if self.dispatch_method == DispatchMethod.LOTTERY:
122
+ worker_names = []
123
+ worker_speeds = []
124
+ for w_name, w_info in self.worker_info.items():
125
+ if model_name in w_info.model_names:
126
+ worker_names.append(w_name)
127
+ worker_speeds.append(w_info.speed)
128
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
+ norm = np.sum(worker_speeds)
130
+ if norm < 1e-4:
131
+ return ""
132
+ worker_speeds = worker_speeds / norm
133
+ if True: # Directly return address
134
+ pt = np.random.choice(np.arange(len(worker_names)),
135
+ p=worker_speeds)
136
+ worker_name = worker_names[pt]
137
+ return worker_name
138
+
139
+ # Check status before returning
140
+ while True:
141
+ pt = np.random.choice(np.arange(len(worker_names)),
142
+ p=worker_speeds)
143
+ worker_name = worker_names[pt]
144
+
145
+ if self.get_worker_status(worker_name):
146
+ break
147
+ else:
148
+ self.remove_worker(worker_name)
149
+ worker_speeds[pt] = 0
150
+ norm = np.sum(worker_speeds)
151
+ if norm < 1e-4:
152
+ return ""
153
+ worker_speeds = worker_speeds / norm
154
+ continue
155
+ return worker_name
156
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
+ worker_names = []
158
+ worker_qlen = []
159
+ for w_name, w_info in self.worker_info.items():
160
+ if model_name in w_info.model_names:
161
+ worker_names.append(w_name)
162
+ worker_qlen.append(w_info.queue_length / w_info.speed)
163
+ if len(worker_names) == 0:
164
+ return ""
165
+ min_index = np.argmin(worker_qlen)
166
+ w_name = worker_names[min_index]
167
+ self.worker_info[w_name].queue_length += 1
168
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
+ return w_name
170
+ else:
171
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
+
173
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
174
+ if worker_name not in self.worker_info:
175
+ logger.info(f"Receive unknown heart beat. {worker_name}")
176
+ return False
177
+
178
+ self.worker_info[worker_name].queue_length = queue_length
179
+ self.worker_info[worker_name].last_heart_beat = time.time()
180
+ logger.info(f"Receive heart beat. {worker_name}")
181
+ return True
182
+
183
+ def remove_stable_workers_by_expiration(self):
184
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
+ to_delete = []
186
+ for worker_name, w_info in self.worker_info.items():
187
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
+ to_delete.append(worker_name)
189
+
190
+ for worker_name in to_delete:
191
+ self.remove_worker(worker_name)
192
+
+     def worker_api_generate_stream(self, params):
+         worker_addr = self.get_worker_address(params["model"])
+         if not worker_addr:
+             logger.info(f"no worker: {params['model']}")
+             ret = {
+                 "text": server_error_msg,
+                 "error_code": 2,
+             }
+             yield json.dumps(ret).encode() + b"\0"
+             # Stop here: without a worker address there is nothing to stream from.
+             return
+
+         try:
+             response = requests.post(worker_addr + "/worker_generate_stream",
+                                      json=params, stream=True, timeout=5)
+             for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                 if chunk:
+                     yield chunk + b"\0"
+         except requests.exceptions.RequestException as e:
+             logger.info(f"worker timeout: {worker_addr}")
+             ret = {
+                 "text": server_error_msg,
+                 "error_code": 3,
+             }
+             yield json.dumps(ret).encode() + b"\0"
+
+
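Because the generator above forwards the worker's bytes unchanged, clients of /worker_generate_stream (defined further below) see null-byte-delimited JSON messages. A minimal client sketch, assuming the controller listens on localhost:21001, the model name is a placeholder, and the worker accepts a "prompt" field and answers with the same "text"/"error_code" fields used for errors above:

import json
import requests

# Hypothetical payload: the controller only reads "model"; the rest is passed to the worker.
params = {"model": "Mini-Gemini-7B", "prompt": "Describe the image.", "temperature": 0.2}

resp = requests.post("http://localhost:21001/worker_generate_stream",
                     json=params, stream=True)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data.get("error_code", 0) != 0:
        print("error:", data["text"])
        break
    print(data["text"])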
+     # Let the controller act as a worker to achieve hierarchical
+     # management. This can be used to connect isolated sub networks.
+     def worker_api_get_status(self):
+         model_names = set()
+         speed = 0
+         queue_length = 0
+
+         for w_name in self.worker_info:
+             worker_status = self.get_worker_status(w_name)
+             if worker_status is not None:
+                 model_names.update(worker_status["model_names"])
+                 speed += worker_status["speed"]
+                 queue_length += worker_status["queue_length"]
+
+         return {
+             "model_names": list(model_names),
+             "speed": speed,
+             "queue_length": queue_length,
+         }
+
+
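As the comment above says, a controller can itself be registered with a parent controller because it answers /worker_get_status with the aggregated state of everything behind it. A minimal sketch of that wiring, where both addresses are placeholders and the payload keys mirror the /register_worker endpoint defined below:

import requests

child_addr = "http://child-controller:21001"    # this controller (placeholder)
parent_addr = "http://parent-controller:21001"  # upstream controller (placeholder)

# Reuse the child's aggregated status as the initial "worker_status" snapshot.
status = requests.post(child_addr + "/worker_get_status").json()
requests.post(parent_addr + "/register_worker", json={
    "worker_name": child_addr,
    "check_heart_beat": False,   # let the parent poll status rather than expect heartbeats
    "worker_status": status,
})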
+ app = FastAPI()
+
+
+ @app.post("/register_worker")
+ async def register_worker(request: Request):
+     data = await request.json()
+     controller.register_worker(
+         data["worker_name"], data["check_heart_beat"],
+         data.get("worker_status", None))
+
+
+ @app.post("/refresh_all_workers")
+ async def refresh_all_workers():
+     models = controller.refresh_all_workers()
+
+
+ @app.post("/list_models")
+ async def list_models():
+     models = controller.list_models()
+     return {"models": models}
+
+
+ @app.post("/get_worker_address")
+ async def get_worker_address(request: Request):
+     data = await request.json()
+     addr = controller.get_worker_address(data["model"])
+     return {"address": addr}
+
+
+ @app.post("/receive_heart_beat")
+ async def receive_heart_beat(request: Request):
+     data = await request.json()
+     exist = controller.receive_heart_beat(
+         data["worker_name"], data["queue_length"])
+     return {"exist": exist}
+
+
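On the other end of /receive_heart_beat, each model worker is expected to report its address and queue length periodically; anything silent for longer than CONTROLLER_HEART_BEAT_EXPIRATION gets swept out. A minimal worker-side sketch, where the addresses, the interval, and get_queue_length are all placeholders:

import time
import requests

controller_addr = "http://localhost:21001"   # placeholder controller address
worker_addr = "http://worker-host:40000"     # placeholder address of this worker

def send_heart_beats(get_queue_length, interval: float = 15.0):
    while True:
        ret = requests.post(controller_addr + "/receive_heart_beat",
                            json={"worker_name": worker_addr,
                                  "queue_length": get_queue_length()},
                            timeout=5)
        if not ret.json()["exist"]:
            # Controller restarted or dropped us; a real worker would re-register
            # through /register_worker at this point.
            pass
        time.sleep(interval)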
+ @app.post("/worker_generate_stream")
+ async def worker_api_generate_stream(request: Request):
+     params = await request.json()
+     generator = controller.worker_api_generate_stream(params)
+     return StreamingResponse(generator)
+
+
+ @app.post("/worker_get_status")
+ async def worker_api_get_status(request: Request):
+     return controller.worker_api_get_status()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="localhost")
+     parser.add_argument("--port", type=int, default=21001)
+     parser.add_argument("--dispatch-method", type=str, choices=[
+         "lottery", "shortest_queue"], default="shortest_queue")
+     args = parser.parse_args()
+     logger.info(f"args: {args}")
+
+     controller = Controller(args.dispatch_method)
+     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
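Putting the pieces together: the controller is started as a module with the arguments parsed above, and callers then talk to it over plain HTTP. A minimal end-to-end sketch, assuming at least one model worker has already registered (host, port, and model names are placeholders):

# In a separate shell (assumed invocation):
#   python -m minigemini.serve.controller --host 0.0.0.0 --port 21001
import requests

controller_addr = "http://localhost:21001"

# Every model name advertised by the registered workers.
models = requests.post(controller_addr + "/list_models").json()["models"]
print(models)

# Ask the dispatcher for a worker serving one of them; "" means none is available.
if models:
    addr = requests.post(controller_addr + "/get_worker_address",
                         json={"model": models[0]}).json()["address"]
    print(addr or "no worker available")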
minigemini/serve/examples/extreme_ironing.jpg ADDED

Git LFS Details

  • SHA256: a54caa21bc513ed25c8ca7f5747555c05dfd4e33f6a3cf5c08b3d9138a4da1d9
  • Pointer size: 130 Bytes
  • Size of remote file: 62.6 kB
minigemini/serve/examples/monday.jpg ADDED

Git LFS Details

  • SHA256: f516b74860919074ea7bd855c2073a565283cf0f888139841f60512655996066
  • Pointer size: 129 Bytes
  • Size of remote file: 7.14 kB
minigemini/serve/examples/waterview.jpg ADDED

Git LFS Details

  • SHA256: d092764cc9f21b9bc535ff5284b5add4d8256148bab1bc2f5b5ab3fd32759a36
  • Pointer size: 130 Bytes
  • Size of remote file: 95.5 kB
minigemini/serve/examples/woolen.png ADDED

Git LFS Details

  • SHA256: fb303bfbac3cdb104972daf87e4d0515d08f370897d361f63a1407f302f98a9f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB