Isaachh committed on
Commit
114dd13
1 Parent(s): 162e139
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +74 -0
  2. bunny/constants.py +7 -0
  3. bunny/conversation.py +239 -0
  4. bunny/eval/m4c_evaluator.py +334 -0
  5. bunny/eval/model_vqa.py +111 -0
  6. bunny/eval/model_vqa_cmmmu.py +234 -0
  7. bunny/eval/model_vqa_loader.py +143 -0
  8. bunny/eval/model_vqa_mmbench.py +167 -0
  9. bunny/eval/model_vqa_mmmu.py +326 -0
  10. bunny/eval/model_vqa_science.py +119 -0
  11. bunny/model/__init__.py +6 -0
  12. bunny/model/builder.py +197 -0
  13. bunny/model/bunny_arch.py +230 -0
  14. bunny/model/language_model/bunny_llama.py +102 -0
  15. bunny/model/language_model/bunny_minicpm.py +103 -0
  16. bunny/model/language_model/bunny_phi.py +100 -0
  17. bunny/model/language_model/bunny_phi3.py +100 -0
  18. bunny/model/language_model/bunny_qwen.py +100 -0
  19. bunny/model/language_model/bunny_stablelm.py +100 -0
  20. bunny/model/language_model/llama/__init__.py +114 -0
  21. bunny/model/language_model/llama/configuration_llama.py +191 -0
  22. bunny/model/language_model/llama/modeling_llama.py +1844 -0
  23. bunny/model/language_model/llama/tokenization_llama.py +471 -0
  24. bunny/model/language_model/llama/tokenization_llama_fast.py +281 -0
  25. bunny/model/language_model/minicpm/configuration_minicpm.py +202 -0
  26. bunny/model/language_model/minicpm/modeling_minicpm.py +1456 -0
  27. bunny/model/language_model/phi/__init__.py +69 -0
  28. bunny/model/language_model/phi/configuration_phi.py +195 -0
  29. bunny/model/language_model/phi/modeling_phi.py +1374 -0
  30. bunny/model/language_model/phi3/__init__.py +69 -0
  31. bunny/model/language_model/phi3/configuration_phi3.py +213 -0
  32. bunny/model/language_model/phi3/modeling_phi3.py +1597 -0
  33. bunny/model/language_model/qwen2/__init__.py +80 -0
  34. bunny/model/language_model/qwen2/configuration_qwen2.py +144 -0
  35. bunny/model/language_model/qwen2/modeling_qwen2.py +1403 -0
  36. bunny/model/language_model/qwen2/tokenization_qwen2.py +345 -0
  37. bunny/model/language_model/qwen2/tokenization_qwen2_fast.py +143 -0
  38. bunny/model/language_model/stable_lm/configuration_stablelm_epoch.py +113 -0
  39. bunny/model/language_model/stable_lm/modeling_stablelm_epoch.py +917 -0
  40. bunny/model/multimodal_encoder/builder.py +29 -0
  41. bunny/model/multimodal_encoder/clip/clip_encoder.py +76 -0
  42. bunny/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +63 -0
  43. bunny/model/multimodal_encoder/eva_clip/eva_clip_processors.py +68 -0
  44. bunny/model/multimodal_encoder/eva_clip/eva_vit.py +851 -0
  45. bunny/model/multimodal_encoder/siglip/siglip_encoder.py +129 -0
  46. bunny/model/multimodal_projector/builder.py +183 -0
  47. bunny/serve/cli.py +118 -0
  48. bunny/serve/controller.py +277 -0
  49. bunny/serve/examples/example_1.png +0 -0
  50. bunny/serve/examples/example_2.png +0 -0
app.py ADDED
@@ -0,0 +1,74 @@
+import sys
+import os
+import time
+import argparse
+import subprocess
+
+import bunny.serve.gradio_web_server as gws
+
+subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-e', '.'])
+
+
+def start_controller():
+    controller_command = [
+        sys.executable, '-m', 'bunny.serve.controller',
+        '--host', '0.0.0.0',
+        '--port', '10000'
+    ]
+    return subprocess.Popen(controller_command)
+
+
+def start_worker(port: int, model_path: str, model_type: str):
+    worker_command = [
+        sys.executable, '-m', 'bunny.serve.model_worker',
+        '--host', '0.0.0.0',
+        '--controller', 'http://localhost:10000',
+        '--port', f'{port}',
+        '--worker', f'http://localhost:{port}',
+        '--model-path', model_path,
+        '--model-type', model_type
+    ]
+    return subprocess.Popen(worker_command)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
+    parser.add_argument("--concurrency-count", type=int, default=5)
+    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
+    parser.add_argument("--share", action="store_true")
+    parser.add_argument("--moderate", action="store_true")
+    parser.add_argument("--embed", action="store_true")
+    gws.args = parser.parse_args()
+    gws.models = []
+
+    controller_proc = start_controller()
+
+    worker_procs = []
+
+    worker_procs.append(start_worker(port=40000, model_path='BAAI/Bunny-v1_1-Llama-3-8B-V', model_type='llama3-8b'))
+    worker_procs.append(start_worker(port=40001, model_path='BAAI/Bunny-v1_1-4B', model_type='phi-3'))
+    worker_procs.append(start_worker(port=40002, model_path='BAAI/Bunny-v1_0-3B', model_type='phi-2'))
+
+    time.sleep(60)
+
+    exit_status = 0
+    try:
+        demo = gws.build_demo(embed_mode=gws.args.embed)
+        demo.launch(
+            server_name=gws.args.host,
+            server_port=gws.args.port,
+            share=gws.args.share,
+            debug=True,
+            max_threads=10
+        )
+    except Exception as e:
+        print(e)
+        exit_status = 1
+    finally:
+        for worker_proc in worker_procs:
+            worker_proc.kill()
+        controller_proc.kill()
+    sys.exit(exit_status)
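Not part of the commit, but for orientation: app.py is the Space entry point, so a local run only needs the Gradio arguments it defines above. A minimal sketch, assuming a GPU and network access to the model weights; the port value and --share flag are arbitrary choices:

    import subprocess, sys

    # Launch the demo defined in app.py; the script itself installs the package,
    # starts the controller and the three model workers, then serves Gradio.
    subprocess.run([sys.executable, 'app.py', '--port', '7860', '--share'], check=True)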
bunny/constants.py ADDED
@@ -0,0 +1,7 @@
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+LOGDIR = "gradio-logs"
+WORKER_HEART_BEAT_INTERVAL = 15
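For context (not part of the commit): IMAGE_TOKEN_INDEX is the sentinel id spliced into the token sequence wherever DEFAULT_IMAGE_TOKEN appears in a prompt, marking where image features are injected later. A minimal sketch of that idea with a stand-in tokenizer; the repository's actual helper is tokenizer_image_token in bunny/util/mm_utils.py, which is not shown in this view:

    from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

    def splice_image_token(prompt, tokenize):
        # Tokenize the text around each "<image>" placeholder and insert the
        # sentinel id (-200) between the chunks.
        chunks = [tokenize(chunk) for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
        ids = []
        for i, chunk in enumerate(chunks):
            if i > 0:
                ids.append(IMAGE_TOKEN_INDEX)
            ids.extend(chunk)
        return ids

    # Toy "tokenizer", only to make the sketch runnable.
    print(splice_image_token(DEFAULT_IMAGE_TOKEN + '\nDescribe the scene.',
                             tokenize=lambda s: [hash(w) % 1000 for w in s.split()]))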
bunny/conversation.py ADDED
@@ -0,0 +1,239 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    TWO = auto()
+    PLAIN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+            if 'mmtag' in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image>\n" + init_msg)
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    from PIL import Image
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(pil_img.mode, (width, width), background_color)
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(pil_img.mode, (height, height), background_color)
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+
+                        image = expand2square(image)
+                    elif image_process_mode in ["Default", "Crop"]:
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((336, 336))
+                    else:
+                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="PNG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace('<image>', '').strip()
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_bunny = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="bunny",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|endoftext|>",
+)
+
+conv_phi3 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="phi3",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|endoftext|>",
+)
+
+conv_minicpm = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="minicpm",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_llama = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="llama",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|end_of_text|>",
+)
+
+conv_plain = Conversation(
+    system="",
+    roles=("", ""),
+    messages=(
+    ),
+    offset=0,
+    sep_style=SeparatorStyle.PLAIN,
+    sep="\n",
+)
+
+default_conversation = conv_bunny
+conv_templates = {
+    "default": conv_bunny,
+    "bunny": conv_bunny,
+    "phi3": conv_phi3,
+    "plain": conv_plain,
+    'minicpm': conv_minicpm,
+    'llama': conv_llama
+}
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
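A quick usage sketch, not part of the commit, showing how the eval scripts below drive these templates: copy one, append a user turn and an empty assistant turn, then render the prompt string:

    from bunny.conversation import conv_templates

    conv = conv_templates['bunny'].copy()
    conv.append_message(conv.roles[0], '<image>\nWhat is shown in this picture?')
    conv.append_message(conv.roles[1], None)
    # Prints: system prompt + " USER: <image>\n... ASSISTANT:"; sep2 ("<|endoftext|>")
    # terminates each completed assistant turn.
    print(conv.get_prompt())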
bunny/eval/m4c_evaluator.py ADDED
@@ -0,0 +1,334 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import re
+
+from tqdm import tqdm
+
+
+class EvalAIAnswerProcessor:
+    """
+    Processes an answer similar to Eval AI
+    copied from
+    https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
+    """
+
+    CONTRACTIONS = {
+        "aint": "ain't",
+        "arent": "aren't",
+        "cant": "can't",
+        "couldve": "could've",
+        "couldnt": "couldn't",
+        "couldn'tve": "couldn't've",
+        "couldnt've": "couldn't've",
+        "didnt": "didn't",
+        "doesnt": "doesn't",
+        "dont": "don't",
+        "hadnt": "hadn't",
+        "hadnt've": "hadn't've",
+        "hadn'tve": "hadn't've",
+        "hasnt": "hasn't",
+        "havent": "haven't",
+        "hed": "he'd",
+        "hed've": "he'd've",
+        "he'dve": "he'd've",
+        "hes": "he's",
+        "howd": "how'd",
+        "howll": "how'll",
+        "hows": "how's",
+        "Id've": "I'd've",
+        "I'dve": "I'd've",
+        "Im": "I'm",
+        "Ive": "I've",
+        "isnt": "isn't",
+        "itd": "it'd",
+        "itd've": "it'd've",
+        "it'dve": "it'd've",
+        "itll": "it'll",
+        "let's": "let's",
+        "maam": "ma'am",
+        "mightnt": "mightn't",
+        "mightnt've": "mightn't've",
+        "mightn'tve": "mightn't've",
+        "mightve": "might've",
+        "mustnt": "mustn't",
+        "mustve": "must've",
+        "neednt": "needn't",
+        "notve": "not've",
+        "oclock": "o'clock",
+        "oughtnt": "oughtn't",
+        "ow's'at": "'ow's'at",
+        "'ows'at": "'ow's'at",
+        "'ow'sat": "'ow's'at",
+        "shant": "shan't",
+        "shed've": "she'd've",
+        "she'dve": "she'd've",
+        "she's": "she's",
+        "shouldve": "should've",
+        "shouldnt": "shouldn't",
+        "shouldnt've": "shouldn't've",
+        "shouldn'tve": "shouldn't've",
+        "somebody'd": "somebodyd",
+        "somebodyd've": "somebody'd've",
+        "somebody'dve": "somebody'd've",
+        "somebodyll": "somebody'll",
+        "somebodys": "somebody's",
+        "someoned": "someone'd",
+        "someoned've": "someone'd've",
+        "someone'dve": "someone'd've",
+        "someonell": "someone'll",
+        "someones": "someone's",
+        "somethingd": "something'd",
+        "somethingd've": "something'd've",
+        "something'dve": "something'd've",
+        "somethingll": "something'll",
+        "thats": "that's",
+        "thered": "there'd",
+        "thered've": "there'd've",
+        "there'dve": "there'd've",
+        "therere": "there're",
+        "theres": "there's",
+        "theyd": "they'd",
+        "theyd've": "they'd've",
+        "they'dve": "they'd've",
+        "theyll": "they'll",
+        "theyre": "they're",
+        "theyve": "they've",
+        "twas": "'twas",
+        "wasnt": "wasn't",
+        "wed've": "we'd've",
+        "we'dve": "we'd've",
+        "weve": "we've",
+        "werent": "weren't",
+        "whatll": "what'll",
+        "whatre": "what're",
+        "whats": "what's",
+        "whatve": "what've",
+        "whens": "when's",
+        "whered": "where'd",
+        "wheres": "where's",
+        "whereve": "where've",
+        "whod": "who'd",
+        "whod've": "who'd've",
+        "who'dve": "who'd've",
+        "wholl": "who'll",
+        "whos": "who's",
+        "whove": "who've",
+        "whyll": "why'll",
+        "whyre": "why're",
+        "whys": "why's",
+        "wont": "won't",
+        "wouldve": "would've",
+        "wouldnt": "wouldn't",
+        "wouldnt've": "wouldn't've",
+        "wouldn'tve": "wouldn't've",
+        "yall": "y'all",
+        "yall'll": "y'all'll",
+        "y'allll": "y'all'll",
+        "yall'd've": "y'all'd've",
+        "y'alld've": "y'all'd've",
+        "y'all'dve": "y'all'd've",
+        "youd": "you'd",
+        "youd've": "you'd've",
+        "you'dve": "you'd've",
+        "youll": "you'll",
+        "youre": "you're",
+        "youve": "you've",
+    }
+
+    NUMBER_MAP = {
+        "none": "0",
+        "zero": "0",
+        "one": "1",
+        "two": "2",
+        "three": "3",
+        "four": "4",
+        "five": "5",
+        "six": "6",
+        "seven": "7",
+        "eight": "8",
+        "nine": "9",
+        "ten": "10",
+    }
+    ARTICLES = ["a", "an", "the"]
+    PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
+    COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
+    PUNCTUATIONS = [
+        ";",
+        r"/",
+        "[",
+        "]",
+        '"',
+        "{",
+        "}",
+        "(",
+        ")",
+        "=",
+        "+",
+        "\\",
+        "_",
+        "-",
+        ">",
+        "<",
+        "@",
+        "`",
+        ",",
+        "?",
+        "!",
+    ]
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def word_tokenize(self, word):
+        word = word.lower()
+        word = word.replace(",", "").replace("?", "").replace("'s", " 's")
+        return word.strip()
+
+    def process_punctuation(self, in_text):
+        out_text = in_text
+        for p in self.PUNCTUATIONS:
+            if (p + " " in in_text or " " + p in in_text) or (
+                re.search(self.COMMA_STRIP, in_text) is not None
+            ):
+                out_text = out_text.replace(p, "")
+            else:
+                out_text = out_text.replace(p, " ")
+        out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
+        return out_text
+
+    def process_digit_article(self, in_text):
+        out_text = []
+        temp_text = in_text.lower().split()
+        for word in temp_text:
+            word = self.NUMBER_MAP.setdefault(word, word)
+            if word not in self.ARTICLES:
+                out_text.append(word)
+            else:
+                pass
+        for word_id, word in enumerate(out_text):
+            if word in self.CONTRACTIONS:
+                out_text[word_id] = self.CONTRACTIONS[word]
+        out_text = " ".join(out_text)
+        return out_text
+
+    def __call__(self, item):
+        item = self.word_tokenize(item)
+        item = item.replace("\n", " ").replace("\t", " ").strip()
+        item = self.process_punctuation(item)
+        item = self.process_digit_article(item)
+        return item
+
+
+class TextVQAAccuracyEvaluator:
+    def __init__(self):
+        self.answer_processor = EvalAIAnswerProcessor()
+
+    def _compute_answer_scores(self, raw_answers):
+        """
+        compute the accuracy (soft score) of human answers
+        """
+        answers = [self.answer_processor(a) for a in raw_answers]
+        assert len(answers) == 10
+        gt_answers = list(enumerate(answers))
+        unique_answers = set(answers)
+        unique_answer_scores = {}
+
+        for unique_answer in unique_answers:
+            accs = []
+            for gt_answer in gt_answers:
+                other_answers = [item for item in gt_answers if item != gt_answer]
+                matching_answers = [
+                    item for item in other_answers if item[1] == unique_answer
+                ]
+                acc = min(1, float(len(matching_answers)) / 3)
+                accs.append(acc)
+            unique_answer_scores[unique_answer] = sum(accs) / len(accs)
+
+        return unique_answer_scores
+
+    def eval_pred_list(self, pred_list):
+        pred_scores = []
+        for entry in tqdm(pred_list):
+            pred_answer = self.answer_processor(entry["pred_answer"])
+            unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
+            score = unique_answer_scores.get(pred_answer, 0.0)
+            pred_scores.append(score)
+
+        accuracy = sum(pred_scores) / len(pred_scores)
+        return accuracy
+
+
+class STVQAAccuracyEvaluator:
+    def __init__(self):
+        self.answer_processor = EvalAIAnswerProcessor()
+
+    def eval_pred_list(self, pred_list):
+        pred_scores = []
+        for entry in pred_list:
+            pred_answer = self.answer_processor(entry["pred_answer"])
+            gts = [self.answer_processor(a) for a in entry["gt_answers"]]
+            score = 1.0 if pred_answer in gts else 0.0
+            pred_scores.append(score)
+
+        accuracy = sum(pred_scores) / len(pred_scores)
+        return accuracy
+
+
+class STVQAANLSEvaluator:
+    def __init__(self):
+        import editdistance  # install with `pip install editdistance`
+
+        self.get_edit_distance = editdistance.eval
+
+    def get_anls(self, s1, s2):
+        s1 = s1.lower().strip()
+        s2 = s2.lower().strip()
+        iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
+        anls = iou if iou >= 0.5 else 0.0
+        return anls
+
+    def eval_pred_list(self, pred_list):
+        pred_scores = []
+        for entry in pred_list:
+            anls = max(
+                self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
+            )
+            pred_scores.append(anls)
+
+        accuracy = sum(pred_scores) / len(pred_scores)
+        return accuracy
+
+
+class TextCapsBleu4Evaluator:
+    def __init__(self):
+        # The following script requires Java 1.8.0 and pycocotools installed.
+        # The pycocoevalcap can be installed with pip as
+        # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
+        # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
+        # but has no python3 support yet.
+        try:
+            from pycocoevalcap.bleu.bleu import Bleu
+            from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+        except ModuleNotFoundError:
+            print(
+                "Please install pycocoevalcap module using "
+                "pip install git+https://github.com/ronghanghu/coco-caption.git@python23"  # noqa
+            )
+            raise
+
+        self.tokenizer = PTBTokenizer()
+        self.scorer = Bleu(4)
+
+    def eval_pred_list(self, pred_list):
+        # Create reference and hypotheses captions.
+        gts = {}
+        res = {}
+        for idx, entry in enumerate(pred_list):
+            gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
+            res[idx] = [{"caption": entry["pred_answer"]}]
+
+        gts = self.tokenizer.tokenize(gts)
+        res = self.tokenizer.tokenize(res)
+        score, _ = self.scorer.compute_score(gts, res)
+
+        bleu4 = score[3]  # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
+        return bleu4
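Illustrative usage, not part of the commit: TextVQA-style scoring expects exactly 10 human answers per question and gives soft credit of min(1, matches / 3), so an answer most annotators gave scores 1.0. Toy data only:

    from bunny.eval.m4c_evaluator import TextVQAAccuracyEvaluator

    # Made-up entry: 'blue' matches 7 of the 10 annotators.
    pred_list = [{
        'pred_answer': 'blue',
        'gt_answers': ['blue'] * 7 + ['dark blue'] * 3,
    }]
    print(TextVQAAccuracyEvaluator().eval_pred_list(pred_list))  # 1.0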
bunny/eval/model_vqa.py ADDED
@@ -0,0 +1,111 @@
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+
+from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from bunny.conversation import conv_templates, SeparatorStyle
+from bunny.model.builder import load_pretrained_model
+from bunny.util.utils import disable_torch_init
+from bunny.util.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
+
+from PIL import Image
+import math
+
+
+def split_list(lst, n):
+    """Split a list into n (roughly) equal-sized chunks"""
+    chunk_size = math.ceil(len(lst) / n)  # ceiling division
+    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+    chunks = split_list(lst, n)
+    return chunks[k]
+
+
+def eval_model(args):
+    # Model
+    disable_torch_init()
+    model_path = os.path.expanduser(args.model_path)
+    model_name = get_model_name_from_path(model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
+                                                                           args.model_type)
+
+    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
+    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+    answers_file = os.path.expanduser(args.answers_file)
+    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+    ans_file = open(answers_file, "w")
+    for line in tqdm(questions):
+        idx = line["question_id"]
+        image_file = line["image"]
+        qs = line["text"]
+        cur_prompt = qs
+
+        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+        conv = conv_templates[args.conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+        image = Image.open(os.path.join(args.image_folder, image_file))
+        image_tensor = process_images([image], image_processor, model.config)[0]
+
+        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                num_beams=args.num_beams,
+                # no_repeat_ngram_size=3,
+                max_new_tokens=1024,
+                use_cache=True)
+
+        input_token_len = input_ids.shape[1]
+        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+        if n_diff_input_output > 0:
+            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+        outputs = outputs.strip()
+        if outputs.endswith(stop_str):
+            outputs = outputs[:-len(stop_str)]
+        outputs = outputs.strip()
+
+        ans_id = shortuuid.uuid()
+        ans_file.write(json.dumps({"question_id": idx,
+                                   "prompt": cur_prompt,
+                                   "text": outputs,
+                                   "answer_id": ans_id,
+                                   "model_id": model_name,
+                                   "metadata": {}}) + "\n")
+        ans_file.flush()
+    ans_file.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default=None)
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--model-type", type=str, default=None)
+    parser.add_argument("--image-folder", type=str, default=None)
+    parser.add_argument("--question-file", type=str, default=None)
+    parser.add_argument("--answers-file", type=str, default=None)
+    parser.add_argument("--conv-mode", type=str, default=None)
+    parser.add_argument("--num-chunks", type=int, default=1)
+    parser.add_argument("--chunk-idx", type=int, default=0)
+    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--top_p", type=float, default=None)
+    parser.add_argument("--num_beams", type=int, default=1)
+    args = parser.parse_args()
+
+    eval_model(args)
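For reference (not part of the commit), the script reads a JSONL question file with question_id, image and text fields and writes one JSON answer per line. A sketch of preparing such a file; the file names, sample question and model choice below are placeholders:

    import json

    questions = [{'question_id': 0, 'image': 'example_1.png',
                  'text': 'What is unusual about this image?'}]
    with open('questions.jsonl', 'w') as f:
        for q in questions:
            f.write(json.dumps(q) + '\n')

    # Then, for example:
    #   python -m bunny.eval.model_vqa --model-path BAAI/Bunny-v1_0-3B --model-type phi-2 \
    #       --image-folder ./images --question-file questions.jsonl \
    #       --answers-file answers.jsonl --conv-mode bunny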
bunny/eval/model_vqa_cmmmu.py ADDED
@@ -0,0 +1,234 @@
+import random
+import numpy as np
+import os
+import json
+import yaml
+import torch
+
+from tqdm import tqdm
+from datasets import load_dataset, concatenate_datasets
+from argparse import ArgumentParser
+
+from bunny.model.builder import load_pretrained_model
+from bunny.util.mm_utils import get_model_name_from_path, tokenizer_image_token, process_images
+from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from bunny.conversation import conv_templates
+
+CAT_CN2EN = {'艺术与设计': 'art_and_design',
+             '商业': 'business',
+             '健康与医学': 'health_and_medicine',
+             '人文社会科学': 'humanities_and_social_sciences',
+             '科学': 'science',
+             '技术与工程': 'technology_and_engineering'}
+
+
+def call_bunny_engine_df(args, sample, model, tokenizer=None, processor=None):
+    def deal_with_prompt(input_text):
+        qs = input_text
+        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+        return qs
+
+    prompt = sample['final_input_prompt']
+    prompt = deal_with_prompt(prompt)
+
+    conv = conv_templates[args.conv_mode].copy()
+    conv.append_message(conv.roles[0], prompt)
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+    image = sample['image_1']
+    if sample['image_2'] is not None:  # multiple images actually
+        if sample['type'] == '选择':  # multiple-choice question
+            all_choices = sample['all_choices']
+            response = random.choice(all_choices)
+        else:
+            response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
+    elif image is not None:
+        output_ids = model.generate(
+            input_ids,
+            images=image.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
+            do_sample=False,
+            temperature=0,
+            top_p=None,
+            # num_beams=5,
+            max_new_tokens=128,
+            use_cache=True)
+
+        input_token_len = input_ids.shape[1]
+        # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+        # if n_diff_input_output > 0:
+        #     print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+        response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+
+    return response
+
+
+def load_yaml(file_path):
+    with open(file_path, 'r') as stream:
+        try:
+            yaml_dict = yaml.safe_load(stream)
+        except yaml.YAMLError as exc:
+            print(exc)
+
+    return yaml_dict
+
+
+# DATA PROCESSING
+def construct_prompt(sample, config):
+    question = sample['question']
+    options = []
+    for i in range(1, 5):
+        if sample[f'option{i}'] is None:
+            break
+        options.append(sample[f'option{i}'])
+
+    example = ""
+    if sample['type'] == '选择':  # multiple-choice
+        start_chr = 'A'
+        prediction_range = []
+        for option in options:
+            prediction_range.append(start_chr)
+            example += f"({start_chr}) {option}\n"
+            start_chr = chr(ord(start_chr) + 1)
+        empty_prompt_sample_structure = config['multi_choice_example_format']
+        empty_prompt = empty_prompt_sample_structure.format(question, example)
+        res_dict = {}
+        res_dict['correct_choice'] = sample['answer']
+        res_dict['all_choices'] = prediction_range
+        res_dict['empty_prompt'] = empty_prompt
+        if config['task_instructions']:
+            res_dict['final_input_prompt'] = config['task_instructions'][0].strip() + '\n\n' + empty_prompt
+        else:
+            res_dict['final_input_prompt'] = empty_prompt
+
+        res_dict['gt_content'] = sample['answer']
+    elif sample['type'] == '判断':  # true/false
+        empty_prompt_sample_structure = config['T/F_example_format']
+        empty_prompt = empty_prompt_sample_structure.format(question, example)
+        res_dict = {}
+        res_dict['empty_prompt'] = empty_prompt
+        if config['task_instructions']:
+            res_dict['final_input_prompt'] = config['task_instructions'][1].strip() + '\n\n' + empty_prompt
+        else:
+            res_dict['final_input_prompt'] = empty_prompt
+        res_dict['gt_content'] = sample['answer']
+    else:  # short answer
+        empty_prompt_sample_structure = config['short_ans_example_format']
+        empty_prompt = empty_prompt_sample_structure.format(question)
+        res_dict = {}
+        res_dict['empty_prompt'] = empty_prompt
+        if config['task_instructions']:
+            res_dict['final_input_prompt'] = config['task_instructions'][2].strip() + '\n\n' + empty_prompt
+        else:
+            res_dict['final_input_prompt'] = empty_prompt
+        res_dict['gt_content'] = sample['answer']
+
+    res_dict.update(sample)
+    return res_dict
+
+
+def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
+    out_samples = []
+    with torch.no_grad():
+        for sample in tqdm(samples):
+            if args.small_gpu_usage:
+                sample['image_1'] = sample['image_1'].cuda()
+            response = call_model_engine_fn(args, sample, model, tokenizer, processor)
+            if args.small_gpu_usage:
+                sample['image_1'] = sample['image_1'].cpu()
+
+            out_sample = dict()
+            out_sample['id'] = sample['id']
+            out_sample['type'] = sample['type']
+            out_sample['response'] = response
+            out_samples.append(out_sample)
+    return out_samples
+
+
+def set_seed(seed_value):
+    """
+    Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
+
+    :param seed_value: An integer value to be used as the seed.
+    """
+    torch.manual_seed(seed_value)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed_value)
+        torch.cuda.manual_seed_all(seed_value)  # For multi-GPU setups
+    random.seed(seed_value)
+    np.random.seed(seed_value)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('--model-path', type=str, default=None)
+    parser.add_argument('--model-base', type=str, default=None)
+    parser.add_argument("--model-type", type=str, default=None)
+    parser.add_argument("--conv-mode", type=str, default=None)
+    parser.add_argument('--data-path', type=str, default=None)
+    parser.add_argument('--config-path', type=str, default=None)
+    parser.add_argument('--output-path', type=str, default=None)
+    parser.add_argument('--split', type=str, default='validation')
+    parser.add_argument('--seed', type=int, default=42)
+    parser.add_argument("--small-gpu-usage", action="store_true")
+
+    args = parser.parse_args()
+    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
+    set_seed(args.seed)
+
+    print('bunny_initializing...')
+    processor = None
+    call_model_engine = call_bunny_engine_df
+
+    # load config and process to one value
+    args.config = load_yaml(args.config_path)
+    for key, value in args.config.items():
+        if key == 'task_instructions':
+            args.config[key] = value
+        elif key != 'eval_params' and type(value) == list:
+            assert len(value) == 1, 'key {} has more than one value'.format(key)
+            args.config[key] = value[0]
+
+    # run for each subject
+    sub_dataset_list = []
+    for subject in CAT_CN2EN.values():
+        sub_dataset = load_dataset(args.data_path, subject, split=args.split)
+        sub_dataset_list.append(sub_dataset)
+
+    # merge all dataset
+    dataset = concatenate_datasets(sub_dataset_list)
+
+    # load model
+    model_path = os.path.expanduser(args.model_path)
+    model_name = get_model_name_from_path(model_path)
+    tokenizer, model, vis_processors, context_len = load_pretrained_model(model_path, args.model_base, model_name,
+                                                                          args.model_type)
+
+    samples = []
+    print('Processing CMMMU dataset...')
+    for sample in tqdm(dataset):
+
+        sample = construct_prompt(sample, args.config)
+        if sample['image_1']:
+            sample['image_1'] = process_images([sample['image_1'].convert('RGB')], vis_processors, model.config)[0]
+            if not args.small_gpu_usage:
+                sample['image_1'] = sample['image_1'].to(device)
+
+        samples.append(sample)
+
+    print('Start to evaluate...')
+    # run ex
+    out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
+
+    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
+
+    with open(args.output_path, 'w') as f:
+        for out_sample in out_samples:
+            f.write(json.dumps(out_sample) + '\n')
+
+
+if __name__ == '__main__':
+    main()
bunny/eval/model_vqa_loader.py ADDED
@@ -0,0 +1,143 @@
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+
+from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from bunny.conversation import conv_templates
+from bunny.model.builder import load_pretrained_model
+from bunny.util.utils import disable_torch_init
+from bunny.util.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
+from torch.utils.data import Dataset, DataLoader
+
+from PIL import Image
+import math
+
+
+def split_list(lst, n):
+    """Split a list into n (roughly) equal-sized chunks"""
+    chunk_size = math.ceil(len(lst) / n)  # ceiling division
+    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+    chunks = split_list(lst, n)
+    return chunks[k]
+
+
+# Custom dataset class
+class CustomDataset(Dataset):
+    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
+        self.questions = questions
+        self.image_folder = image_folder
+        self.tokenizer = tokenizer
+        self.image_processor = image_processor
+        self.model_config = model_config
+
+    def __getitem__(self, index):
+        line = self.questions[index]
+        image_file = line["image"]
+        qs = line["text"]
+
+        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+        conv = conv_templates[args.conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
+        image_tensor = process_images([image], self.image_processor, self.model_config)[0]
+
+        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
+
+        return input_ids, image_tensor
+
+    def __len__(self):
+        return len(self.questions)
+
+
+# DataLoader
+def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
+    assert batch_size == 1, "batch_size must be 1"
+    dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
+    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
+    return data_loader
+
+
+def eval_model(args):
+    # Model
+    disable_torch_init()
+    model_path = os.path.expanduser(args.model_path)
+    model_name = get_model_name_from_path(model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
+                                                                           args.model_type)
+
+    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
+    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+    answers_file = os.path.expanduser(args.answers_file)
+    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+    ans_file = open(answers_file, "w")
+
+    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
+        args.conv_mode = args.conv_mode + '_mmtag'
+        print(
+            f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
+
+    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
+
+    for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
+        idx = line["question_id"]
+        cur_prompt = line["text"]
+
+        input_ids = input_ids.to(device='cuda', non_blocking=True)
+
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor.to(dtype=model.dtype, device='cuda', non_blocking=True),
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                num_beams=args.num_beams,
+                max_new_tokens=args.max_new_tokens,
+                use_cache=True)
+
+        input_token_len = input_ids.shape[1]
+        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+        if n_diff_input_output > 0:
+            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+        outputs = outputs.strip()
+
+        ans_id = shortuuid.uuid()
+        ans_file.write(json.dumps({"question_id": idx,
+                                   "prompt": cur_prompt,
+                                   "text": outputs,
+                                   "answer_id": ans_id,
+                                   "model_id": model_name,
+                                   "metadata": {}}) + "\n")
+        # ans_file.flush()
+    ans_file.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default=None)
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--model-type", type=str, default=None)
+    parser.add_argument("--image-folder", type=str, default=None)
+    parser.add_argument("--question-file", type=str, default=None)
+    parser.add_argument("--answers-file", type=str, default=None)
+    parser.add_argument("--conv-mode", type=str, default=None)
+    parser.add_argument("--num-chunks", type=int, default=1)
+    parser.add_argument("--chunk-idx", type=int, default=0)
+    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--top_p", type=float, default=None)
+    parser.add_argument("--num_beams", type=int, default=1)
+    parser.add_argument("--max_new_tokens", type=int, default=128)
+    args = parser.parse_args()
+
+    eval_model(args)
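One design note, with a small demo that is not part of the commit: because create_data_loader pins batch_size to 1, the DataLoader's default collation already adds the leading batch dimension, which is why this script does not call unsqueeze(0) on the image tensor the way model_vqa.py does. The tensor shapes below are arbitrary toy values:

    import torch
    from torch.utils.data import DataLoader

    toy = [(torch.zeros(5, dtype=torch.long), torch.zeros(3, 384, 384))]
    loader = DataLoader(toy, batch_size=1, shuffle=False)
    ids, img = next(iter(loader))
    print(ids.shape, img.shape)  # torch.Size([1, 5]) torch.Size([1, 3, 384, 384])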
bunny/eval/model_vqa_mmbench.py ADDED
@@ -0,0 +1,167 @@
+import argparse
+import torch
+import os
+import json
+import pandas as pd
+from tqdm import tqdm
+import shortuuid
+
+from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from bunny.conversation import conv_templates, SeparatorStyle
+from bunny.model.builder import load_pretrained_model
+from bunny.util.utils import disable_torch_init
+from bunny.util.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, \
+    get_model_name_from_path
+
+import math
+
+all_options = ['A', 'B', 'C', 'D']
+
+
+def split_list(lst, n):
+    """Split a list into n (roughly) equal-sized chunks"""
+    chunk_size = math.ceil(len(lst) / n)  # ceiling division
+    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+    chunks = split_list(lst, n)
+    return chunks[k]
+
+
+def is_none(value):
+    if value is None:
+        return True
+    if type(value) is float and math.isnan(value):
+        return True
+    if type(value) is str and value.lower() == 'nan':
+        return True
+    if type(value) is str and value.lower() == 'none':
+        return True
+    return False
+
+
+def get_options(row, options):
+    parsed_options = []
+    for option in options:
+        option_value = row[option]
+        if is_none(option_value):
+            break
+        parsed_options.append(option_value)
+    return parsed_options
+
+
+def eval_model(args):
+    # Model
+    disable_torch_init()
+    model_path = os.path.expanduser(args.model_path)
+    model_name = get_model_name_from_path(model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
+                                                                           args.model_type)
+
+    questions = pd.read_table(os.path.expanduser(args.question_file))
+    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+    answers_file = os.path.expanduser(args.answers_file)
+    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+    ans_file = open(answers_file, "w")
+
+    for index, row in tqdm(questions.iterrows(), total=len(questions)):
+        options = get_options(row, all_options)
+        cur_option_char = all_options[:len(options)]
+
+        if args.all_rounds:
+            num_rounds = len(options)
+        else:
+            num_rounds = 1
+
+        for round_idx in range(num_rounds):
+            idx = row['index']
+            question = row['question']
+            hint = row['hint']
+            image = load_image_from_base64(row['image'])
+            if not is_none(hint):
+                question = hint + '\n' + question
+            for option_char, option in zip(all_options[:len(options)], options):
+                question = question + '\n' + option_char + '. ' + option
+            qs = cur_prompt = question
+
+            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+            if args.single_pred_prompt:
+                if args.lang == 'cn':
+                    qs = qs + '\n' + "请直接回答选项字母。"  # "Answer with the option's letter directly."
+                else:
+                    qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
+
+            conv = conv_templates[args.conv_mode].copy()
+            conv.append_message(conv.roles[0], qs)
+            conv.append_message(conv.roles[1], None)
+            prompt = conv.get_prompt()
+
+            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
+                0).cuda()
+
+            image_tensor = process_images([image], image_processor, model.config)[0]
+
+            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+
+            with torch.inference_mode():
+                output_ids = model.generate(
+                    input_ids,
+                    images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
+                    do_sample=True if args.temperature > 0 else False,
+                    temperature=args.temperature,
+                    top_p=args.top_p,
+                    num_beams=args.num_beams,
+                    # no_repeat_ngram_size=3,
+                    max_new_tokens=128,
+                    use_cache=True)
+
+            input_token_len = input_ids.shape[1]
+            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+            if n_diff_input_output > 0:
+                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+            outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+            outputs = outputs.strip()
+            if outputs.endswith(stop_str):
+                outputs = outputs[:-len(stop_str)]
+            outputs = outputs.strip()
+
+            ans_id = shortuuid.uuid()
+            ans_file.write(json.dumps({"question_id": idx,
+                                       "round_id": round_idx,
+                                       "prompt": cur_prompt,
+                                       "text": outputs,
+                                       "options": options,
+                                       "option_char": cur_option_char,
+                                       "answer_id": ans_id,
+                                       "model_id": model_name,
+                                       "metadata": {}}) + "\n")
+            ans_file.flush()
+
+            # rotate options
+            options = options[1:] + options[:1]
+            cur_option_char = cur_option_char[1:] + cur_option_char[:1]
+    ans_file.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default=None)
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--model-type", type=str, default=None)
+    parser.add_argument("--image-folder", type=str, default=None)
+    parser.add_argument("--question-file", type=str, default=None)
+    parser.add_argument("--answers-file", type=str, default=None)
+    parser.add_argument("--conv-mode", type=str, default=None)
+    parser.add_argument("--num-chunks", type=int, default=1)
+    parser.add_argument("--chunk-idx", type=int, default=0)
+    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--top_p", type=float, default=None)
+    parser.add_argument("--num_beams", type=int, default=1)
+    parser.add_argument("--all-rounds", action="store_true")
+    parser.add_argument("--single-pred-prompt", action="store_true")
+    parser.add_argument("--lang", type=str, default="en")
+    args = parser.parse_args()
+
+    eval_model(args)
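To make the --all-rounds behaviour concrete (example data only, not from the benchmark): each extra round rotates the option list by one position, with a matching rotation of the recorded option_char list, so the correct answer moves to a different letter each round:

    options = ['cat', 'dog', 'bird', 'fish']
    cur_option_char = ['A', 'B', 'C', 'D']
    for round_idx in range(len(options)):
        print(round_idx, options, cur_option_char)
        # the same circular shift eval_model applies between rounds
        options = options[1:] + options[:1]
        cur_option_char = cur_option_char[1:] + cur_option_char[:1]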
bunny/eval/model_vqa_mmmu.py ADDED
@@ -0,0 +1,326 @@