Spaces: Running on Zero
Commit: init
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- app.py +74 -0
- bunny/constants.py +7 -0
- bunny/conversation.py +239 -0
- bunny/eval/m4c_evaluator.py +334 -0
- bunny/eval/model_vqa.py +111 -0
- bunny/eval/model_vqa_cmmmu.py +234 -0
- bunny/eval/model_vqa_loader.py +143 -0
- bunny/eval/model_vqa_mmbench.py +167 -0
- bunny/eval/model_vqa_mmmu.py +326 -0
- bunny/eval/model_vqa_science.py +119 -0
- bunny/model/__init__.py +6 -0
- bunny/model/builder.py +197 -0
- bunny/model/bunny_arch.py +230 -0
- bunny/model/language_model/bunny_llama.py +102 -0
- bunny/model/language_model/bunny_minicpm.py +103 -0
- bunny/model/language_model/bunny_phi.py +100 -0
- bunny/model/language_model/bunny_phi3.py +100 -0
- bunny/model/language_model/bunny_qwen.py +100 -0
- bunny/model/language_model/bunny_stablelm.py +100 -0
- bunny/model/language_model/llama/__init__.py +114 -0
- bunny/model/language_model/llama/configuration_llama.py +191 -0
- bunny/model/language_model/llama/modeling_llama.py +1844 -0
- bunny/model/language_model/llama/tokenization_llama.py +471 -0
- bunny/model/language_model/llama/tokenization_llama_fast.py +281 -0
- bunny/model/language_model/minicpm/configuration_minicpm.py +202 -0
- bunny/model/language_model/minicpm/modeling_minicpm.py +1456 -0
- bunny/model/language_model/phi/__init__.py +69 -0
- bunny/model/language_model/phi/configuration_phi.py +195 -0
- bunny/model/language_model/phi/modeling_phi.py +1374 -0
- bunny/model/language_model/phi3/__init__.py +69 -0
- bunny/model/language_model/phi3/configuration_phi3.py +213 -0
- bunny/model/language_model/phi3/modeling_phi3.py +1597 -0
- bunny/model/language_model/qwen2/__init__.py +80 -0
- bunny/model/language_model/qwen2/configuration_qwen2.py +144 -0
- bunny/model/language_model/qwen2/modeling_qwen2.py +1403 -0
- bunny/model/language_model/qwen2/tokenization_qwen2.py +345 -0
- bunny/model/language_model/qwen2/tokenization_qwen2_fast.py +143 -0
- bunny/model/language_model/stable_lm/configuration_stablelm_epoch.py +113 -0
- bunny/model/language_model/stable_lm/modeling_stablelm_epoch.py +917 -0
- bunny/model/multimodal_encoder/builder.py +29 -0
- bunny/model/multimodal_encoder/clip/clip_encoder.py +76 -0
- bunny/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +63 -0
- bunny/model/multimodal_encoder/eva_clip/eva_clip_processors.py +68 -0
- bunny/model/multimodal_encoder/eva_clip/eva_vit.py +851 -0
- bunny/model/multimodal_encoder/siglip/siglip_encoder.py +129 -0
- bunny/model/multimodal_projector/builder.py +183 -0
- bunny/serve/cli.py +118 -0
- bunny/serve/controller.py +277 -0
- bunny/serve/examples/example_1.png +0 -0
- bunny/serve/examples/example_2.png +0 -0
app.py
ADDED
@@ -0,0 +1,74 @@
import sys
import os
import time
import argparse
import subprocess

import bunny.serve.gradio_web_server as gws

subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-e', '.'])


def start_controller():
    controller_command = [
        sys.executable, '-m', 'bunny.serve.controller',
        '--host', '0.0.0.0',
        '--port', '10000'
    ]
    return subprocess.Popen(controller_command)


def start_worker(port: int, model_path: str, model_type: str):
    worker_command = [
        sys.executable, '-m', 'bunny.serve.model_worker',
        '--host', '0.0.0.0',
        '--controller', 'http://localhost:10000',
        '--port', f'{port}',
        '--worker', f'http://localhost:{port}',
        '--model-path', model_path,
        '--model-type', model_type
    ]
    return subprocess.Popen(worker_command)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    controller_proc = start_controller()

    worker_procs = []

    worker_procs.append(start_worker(port=40000, model_path='BAAI/Bunny-v1_1-Llama-3-8B-V', model_type='llama3-8b'))
    worker_procs.append(start_worker(port=40001, model_path='BAAI/Bunny-v1_1-4B', model_type='phi-3'))
    worker_procs.append(start_worker(port=40002, model_path='BAAI/Bunny-v1_0-3B', model_type='phi-2'))

    time.sleep(60)

    exit_status = 0
    try:
        demo = gws.build_demo(embed_mode=gws.args.embed)
        demo.launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share,
            debug=True,
            max_threads=10
        )
    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        for worker_proc in worker_procs:
            worker_proc.kill()
        controller_proc.kill()
        sys.exit(exit_status)
bunny/constants.py
ADDED
@@ -0,0 +1,7 @@
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
CONTROLLER_HEART_BEAT_EXPIRATION = 30
LOGDIR = "gradio-logs"
WORKER_HEART_BEAT_INTERVAL = 15
bunny/conversation.py
ADDED
@@ -0,0 +1,239 @@
import dataclasses
from enum import auto, Enum
from typing import List


class SeparatorStyle(Enum):
    """Different separator style."""
    TWO = auto()
    PLAIN = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_bunny = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="bunny",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

conv_phi3 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

conv_minicpm = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="minicpm",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="llama",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|end_of_text|>",
)

conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

default_conversation = conv_bunny
conv_templates = {
    "default": conv_bunny,
    "bunny": conv_bunny,
    "phi3": conv_phi3,
    "plain": conv_plain,
    'minicpm': conv_minicpm,
    'llama': conv_llama
}

if __name__ == "__main__":
    print(default_conversation.get_prompt())
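These templates are used the same way throughout the serving and evaluation scripts: copy a template, append one user turn and an empty assistant turn, then render the prompt string. A minimal sketch of that pattern (the question text here is invented for illustration):

    from bunny.conversation import conv_templates

    conv = conv_templates["bunny"].copy()
    conv.append_message(conv.roles[0], "<image>\nWhat is shown in this picture?")  # user turn
    conv.append_message(conv.roles[1], None)  # empty slot, so the rendered prompt ends with "ASSISTANT:"
    prompt = conv.get_prompt()
    print(prompt)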
bunny/eval/m4c_evaluator.py
ADDED
@@ -0,0 +1,334 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import re

from tqdm import tqdm


class EvalAIAnswerProcessor:
    """
    Processes an answer similar to Eval AI
    copied from
    https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
    """

    CONTRACTIONS = {
        "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've",
        "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've",
        "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
        "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't",
        "havent": "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've",
        "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's",
        "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", "Ive": "I've",
        "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've",
        "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't",
        "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
        "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've",
        "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't",
        "shed've": "she'd've", "she'dve": "she'd've", "she's": "she's",
        "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd",
        "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've",
        "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
        "someoned've": "someone'd've", "someone'dve": "someone'd've",
        "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd",
        "somethingd've": "something'd've", "something'dve": "something'd've",
        "somethingll": "something'll", "thats": "that's", "thered": "there'd",
        "thered've": "there'd've", "there'dve": "there'd've", "therere": "there're",
        "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
        "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're",
        "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've",
        "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll",
        "whatre": "what're", "whats": "what's", "whatve": "what've", "whens": "when's",
        "whered": "where'd", "wheres": "where's", "whereve": "where've", "whod": "who'd",
        "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's",
        "whove": "who've", "whyll": "why'll", "whyre": "why're", "whys": "why's",
        "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't",
        "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've", "yall": "y'all",
        "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd",
        "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll",
        "youre": "you're", "youve": "you've",
    }

    NUMBER_MAP = {
        "none": "0", "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
        "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
    }
    ARTICLES = ["a", "an", "the"]
    PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
    COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
    PUNCTUATIONS = [
        ";", r"/", "[", "]", '"', "{", "}", "(", ")", "=", "+", "\\", "_", "-",
        ">", "<", "@", "`", ",", "?", "!",
    ]

    def __init__(self, *args, **kwargs):
        pass

    def word_tokenize(self, word):
        word = word.lower()
        word = word.replace(",", "").replace("?", "").replace("'s", " 's")
        return word.strip()

    def process_punctuation(self, in_text):
        out_text = in_text
        for p in self.PUNCTUATIONS:
            if (p + " " in in_text or " " + p in in_text) or (
                    re.search(self.COMMA_STRIP, in_text) is not None
            ):
                out_text = out_text.replace(p, "")
            else:
                out_text = out_text.replace(p, " ")
        out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
        return out_text

    def process_digit_article(self, in_text):
        out_text = []
        temp_text = in_text.lower().split()
        for word in temp_text:
            word = self.NUMBER_MAP.setdefault(word, word)
            if word not in self.ARTICLES:
                out_text.append(word)
            else:
                pass
        for word_id, word in enumerate(out_text):
            if word in self.CONTRACTIONS:
                out_text[word_id] = self.CONTRACTIONS[word]
        out_text = " ".join(out_text)
        return out_text

    def __call__(self, item):
        item = self.word_tokenize(item)
        item = item.replace("\n", " ").replace("\t", " ").strip()
        item = self.process_punctuation(item)
        item = self.process_digit_article(item)
        return item


class TextVQAAccuracyEvaluator:
    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def _compute_answer_scores(self, raw_answers):
        """
        compute the accuracy (soft score) of human answers
        """
        answers = [self.answer_processor(a) for a in raw_answers]
        assert len(answers) == 10
        gt_answers = list(enumerate(answers))
        unique_answers = set(answers)
        unique_answer_scores = {}

        for unique_answer in unique_answers:
            accs = []
            for gt_answer in gt_answers:
                other_answers = [item for item in gt_answers if item != gt_answer]
                matching_answers = [
                    item for item in other_answers if item[1] == unique_answer
                ]
                acc = min(1, float(len(matching_answers)) / 3)
                accs.append(acc)
            unique_answer_scores[unique_answer] = sum(accs) / len(accs)

        return unique_answer_scores

    def eval_pred_list(self, pred_list):
        pred_scores = []
        for entry in tqdm(pred_list):
            pred_answer = self.answer_processor(entry["pred_answer"])
            unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
            score = unique_answer_scores.get(pred_answer, 0.0)
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class STVQAAccuracyEvaluator:
    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def eval_pred_list(self, pred_list):
        pred_scores = []
        for entry in pred_list:
            pred_answer = self.answer_processor(entry["pred_answer"])
            gts = [self.answer_processor(a) for a in entry["gt_answers"]]
            score = 1.0 if pred_answer in gts else 0.0
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class STVQAANLSEvaluator:
    def __init__(self):
        import editdistance  # install with `pip install editdistance`

        self.get_edit_distance = editdistance.eval

    def get_anls(self, s1, s2):
        s1 = s1.lower().strip()
        s2 = s2.lower().strip()
        iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
        anls = iou if iou >= 0.5 else 0.0
        return anls

    def eval_pred_list(self, pred_list):
        pred_scores = []
        for entry in pred_list:
            anls = max(
                self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
            )
            pred_scores.append(anls)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class TextCapsBleu4Evaluator:
    def __init__(self):
        # The following script requires Java 1.8.0 and pycocotools installed.
        # The pycocoevalcap can be installed with pip as
        # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
        # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
        # but has no python3 support yet.
        try:
            from pycocoevalcap.bleu.bleu import Bleu
            from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        except ModuleNotFoundError:
            print(
                "Please install pycocoevalcap module using "
                "pip install git+https://github.com/ronghanghu/coco-caption.git@python23"  # noqa
            )
            raise

        self.tokenizer = PTBTokenizer()
        self.scorer = Bleu(4)

    def eval_pred_list(self, pred_list):
        # Create reference and hypotheses captions.
        gts = {}
        res = {}
        for idx, entry in enumerate(pred_list):
            gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
            res[idx] = [{"caption": entry["pred_answer"]}]

        gts = self.tokenizer.tokenize(gts)
        res = self.tokenizer.tokenize(res)
        score, _ = self.scorer.compute_score(gts, res)

        bleu4 = score[3]  # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
        return bleu4
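TextVQAAccuracyEvaluator applies the usual VQA soft-accuracy rule: each prediction is scored against the 10 human answers for a question, and agreeing with at least 3 of them earns full credit. A small sketch of the input format eval_pred_list expects (the answers below are invented for illustration):

    from bunny.eval.m4c_evaluator import TextVQAAccuracyEvaluator

    evaluator = TextVQAAccuracyEvaluator()
    pred_list = [{
        "pred_answer": "two",
        # exactly 10 human answers per question (asserted in _compute_answer_scores)
        "gt_answers": ["two", "2", "two", "two", "three", "two", "2", "two", "two", "two"],
    }]
    print(evaluator.eval_pred_list(pred_list))  # soft accuracy in [0, 1]; 1.0 for this toy entry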
bunny/eval/model_vqa.py
ADDED
@@ -0,0 +1,111 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from bunny.conversation import conv_templates, SeparatorStyle
from bunny.model.builder import load_pretrained_model
from bunny.util.utils import disable_torch_init
from bunny.util.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images

from PIL import Image
import math


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # integer division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
                                                                           args.model_type)

    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
    ans_file = open(answers_file, "w")
    for line in tqdm(questions):
        idx = line["question_id"]
        image_file = line["image"]
        qs = line["text"]
        cur_prompt = qs

        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        image = Image.open(os.path.join(args.image_folder, image_file))
        image_tensor = process_images([image], image_processor, model.config)[0]

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                # no_repeat_ngram_size=3,
                max_new_tokens=1024,
                use_cache=True)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        ans_id = shortuuid.uuid()
        ans_file.write(json.dumps({"question_id": idx,
                                   "prompt": cur_prompt,
                                   "text": outputs,
                                   "answer_id": ans_id,
                                   "model_id": model_name,
                                   "metadata": {}}) + "\n")
        ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-type", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default=None)
    parser.add_argument("--question-file", type=str, default=None)
    parser.add_argument("--answers-file", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    args = parser.parse_args()

    eval_model(args)
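split_list and get_chunk exist so that several copies of this script can shard one question file across processes: the run started with --num-chunks N --chunk-idx K only answers chunk K of N. A quick standalone illustration of the chunking behaviour on toy data:

    import math

    def split_list(lst, n):
        chunk_size = math.ceil(len(lst) / n)
        return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

    questions = list(range(10))
    print(split_list(questions, 3))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]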
bunny/eval/model_vqa_cmmmu.py
ADDED
@@ -0,0 +1,234 @@
import random
import numpy as np
import os
import json
import yaml
import torch

from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from argparse import ArgumentParser

from bunny.model.builder import load_pretrained_model
from bunny.util.mm_utils import get_model_name_from_path, tokenizer_image_token, process_images
from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from bunny.conversation import conv_templates

CAT_CN2EN = {'艺术与设计': 'art_and_design',
             '商业': 'business',
             '健康与医学': 'health_and_medicine',
             '人文社会科学': 'humanities_and_social_sciences',
             '科学': 'science',
             '技术与工程': 'technology_and_engineering'}


def call_bunny_engine_df(args, sample, model, tokenizer=None, processor=None):
    def deal_with_prompt(input_text):
        qs = input_text
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
        return qs

    prompt = sample['final_input_prompt']
    prompt = deal_with_prompt(prompt)

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

    image = sample['image_1']
    if sample['image_2'] is not None:  # multiple images actually
        if sample['type'] == '选择':
            all_choices = sample['all_choices']
            response = random.choice(all_choices)
        else:
            response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
    elif image is not None:
        output_ids = model.generate(
            input_ids,
            images=image.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
            do_sample=False,
            temperature=0,
            top_p=None,
            # num_beams=5,
            max_new_tokens=128,
            use_cache=True)

        input_token_len = input_ids.shape[1]
        # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        # if n_diff_input_output > 0:
        #     print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]

    return response


def load_yaml(file_path):
    with open(file_path, 'r') as stream:
        try:
            yaml_dict = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)

    return yaml_dict


# DATA PROCESSING
def construct_prompt(sample, config):
    question = sample['question']
    options = []
    for i in range(1, 5):
        if sample[f'option{i}'] is None:
            break
        options.append(sample[f'option{i}'])

    example = ""
    if sample['type'] == '选择':
        start_chr = 'A'
        prediction_range = []
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            start_chr = chr(ord(start_chr) + 1)
        empty_prompt_sample_structure = config['multi_choice_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {}
        res_dict['correct_choice'] = sample['answer']
        res_dict['all_choices'] = prediction_range
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'][0].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt

        res_dict['gt_content'] = sample['answer']
    elif sample['type'] == '判断':
        empty_prompt_sample_structure = config['T/F_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {}
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'][1].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt
        res_dict['gt_content'] = sample['answer']
    else:
        empty_prompt_sample_structure = config['short_ans_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question)
        res_dict = {}
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'][2].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt
        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict


def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
    out_samples = []
    with torch.no_grad():
        for sample in tqdm(samples):
            if args.small_gpu_usage:
                sample['image_1'] = sample['image_1'].cuda()
            response = call_model_engine_fn(args, sample, model, tokenizer, processor)
            if args.small_gpu_usage:
                sample['image_1'] = sample['image_1'].cpu()

            out_sample = dict()
            out_sample['id'] = sample['id']
            out_sample['type'] = sample['type']
            out_sample['response'] = response
            out_samples.append(out_sample)
    return out_samples


def set_seed(seed_value):
    """
    Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.

    :param seed_value: An integer value to be used as the seed.
    """
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # For multi-GPU setups
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    parser = ArgumentParser()
    parser.add_argument('--model-path', type=str, default=None)
    parser.add_argument('--model-base', type=str, default=None)
    parser.add_argument("--model-type", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument('--data-path', type=str, default=None)
    parser.add_argument('--config-path', type=str, default=None)
    parser.add_argument('--output-path', type=str, default=None)
    parser.add_argument('--split', type=str, default='validation')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--small-gpu-usage", action="store_true")

    args = parser.parse_args()
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    set_seed(args.seed)

    print('bunny_initializing...')
    processor = None
    call_model_engine = call_bunny_engine_df

    # load config and process to one value
    args.config = load_yaml(args.config_path)
    for key, value in args.config.items():
        if key == 'task_instructions':
            args.config[key] = value
        elif key != 'eval_params' and type(value) == list:
            assert len(value) == 1, 'key {} has more than one value'.format(key)
            args.config[key] = value[0]

    # run for each subject
    sub_dataset_list = []
    for subject in CAT_CN2EN.values():
        sub_dataset = load_dataset(args.data_path, subject, split=args.split)
        sub_dataset_list.append(sub_dataset)

    # merge all dataset
    dataset = concatenate_datasets(sub_dataset_list)

    # load model
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, vis_processors, context_len = load_pretrained_model(model_path, args.model_base, model_name,
                                                                          args.model_type)

    samples = []
    print('Processing CMMMU dataset...')
    for sample in tqdm(dataset):

        sample = construct_prompt(sample, args.config)
        if sample['image_1']:
            sample['image_1'] = process_images([sample['image_1'].convert('RGB')], vis_processors, model.config)[0]
            if not args.small_gpu_usage:
                sample['image_1'] = sample['image_1'].to(device)

        samples.append(sample)

    print('Start to evaluate...')
    # run ex
    out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)

    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)

    with open(args.output_path, 'w') as f:
        for out_sample in out_samples:
            f.write(json.dumps(out_sample) + '\n')


if __name__ == '__main__':
    main()
bunny/eval/model_vqa_loader.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from bunny.conversation import conv_templates
from bunny.model.builder import load_pretrained_model
from bunny.util.utils import disable_torch_init
from bunny.util.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import math


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # integer division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config

    def __getitem__(self, index):
        line = self.questions[index]
        image_file = line["image"]
        qs = line["text"]

        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model_config)[0]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return input_ids, image_tensor

    def __len__(self):
        return len(self.questions)


# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    return data_loader


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
                                                                           args.model_type)

    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
    ans_file = open(answers_file, "w")

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(
            f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)

    for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
        idx = line["question_id"]
        cur_prompt = line["text"]

        input_ids = input_ids.to(device='cuda', non_blocking=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.to(dtype=model.dtype, device='cuda', non_blocking=True),
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                max_new_tokens=args.max_new_tokens,
                use_cache=True)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()

        ans_id = shortuuid.uuid()
        ans_file.write(json.dumps({"question_id": idx,
                                   "prompt": cur_prompt,
                                   "text": outputs,
                                   "answer_id": ans_id,
                                   "model_id": model_name,
                                   "metadata": {}}) + "\n")
        # ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-type", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default=None)
    parser.add_argument("--question-file", type=str, default=None)
    parser.add_argument("--answers-file", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    args = parser.parse_args()

    eval_model(args)
bunny/eval/model_vqa_mmbench.py
ADDED
@@ -0,0 +1,167 @@
import argparse
import torch
import os
import json
import pandas as pd
from tqdm import tqdm
import shortuuid

from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from bunny.conversation import conv_templates, SeparatorStyle
from bunny.model.builder import load_pretrained_model
from bunny.util.utils import disable_torch_init
from bunny.util.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, \
    get_model_name_from_path

import math

all_options = ['A', 'B', 'C', 'D']


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # integer division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


def is_none(value):
    if value is None:
        return True
    if type(value) is float and math.isnan(value):
        return True
    if type(value) is str and value.lower() == 'nan':
        return True
    if type(value) is str and value.lower() == 'none':
        return True
    return False


def get_options(row, options):
    parsed_options = []
    for option in options:
        option_value = row[option]
        if is_none(option_value):
            break
        parsed_options.append(option_value)
    return parsed_options


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
                                                                           args.model_type)

    questions = pd.read_table(os.path.expanduser(args.question_file))
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
    ans_file = open(answers_file, "w")

    for index, row in tqdm(questions.iterrows(), total=len(questions)):
        options = get_options(row, all_options)
        cur_option_char = all_options[:len(options)]

        if args.all_rounds:
            num_rounds = len(options)
        else:
            num_rounds = 1

        for round_idx in range(num_rounds):
            idx = row['index']
            question = row['question']
            hint = row['hint']
            image = load_image_from_base64(row['image'])
            if not is_none(hint):
                question = hint + '\n' + question
            for option_char, option in zip(all_options[:len(options)], options):
                question = question + '\n' + option_char + '. ' + option
            qs = cur_prompt = question

            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            if args.single_pred_prompt:
                if args.lang == 'cn':
                    qs = qs + '\n' + "请直接回答选项字母。"
                else:
                    qs = qs + '\n' + "Answer with the option's letter from the given choices directly."

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
                0).cuda()

            image_tensor = process_images([image], image_processor, model.config)[0]

            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    # no_repeat_ngram_size=3,
                    max_new_tokens=128,
                    use_cache=True)

            input_token_len = input_ids.shape[1]
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
            outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({"question_id": idx,
                                       "round_id": round_idx,
                                       "prompt": cur_prompt,
                                       "text": outputs,
                                       "options": options,
                                       "option_char": cur_option_char,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
            ans_file.flush()

            # rotate options
            options = options[1:] + options[:1]
            cur_option_char = cur_option_char[1:] + cur_option_char[:1]
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-type", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default=None)
    parser.add_argument("--question-file", type=str, default=None)
    parser.add_argument("--answers-file", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--all-rounds", action="store_true")
    parser.add_argument("--single-pred-prompt", action="store_true")
    parser.add_argument("--lang", type=str, default="en")
    args = parser.parse_args()

    eval_model(args)
bunny/eval/model_vqa_mmmu.py
ADDED
@@ -0,0 +1,326 @@