# LLMBB-Agent / benchmark / metrics / visualization.py
# Last commit: "update md" (2319518) by vlff李飞飞
# (Hosting-page residue: raw | history | blame | No virus | 6.59 kB)
import logging
import os
import re
import base64
import torch
from config import get_model, get_react_parser
from utils.data_utils import load_jsonl, save_jsonl
# Seed torch's global RNG at import time (presumably so that locally run
# judge models generate reproducibly — TODO confirm).
torch.manual_seed(1234)

# Judge prompt templates. The judge is asked to reply "right" when the image
# matches the [Question] and "wrong" otherwise; `{query}` is filled in with
# str.format.
EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
[问题]:{query}
"""
EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
[Question]: {query}
"""

# Accuracy (in percent) per difficulty split. Every entry starts as None and
# is filled in by `eval_visualization_acc`.
visualization_code_correctness = dict.fromkeys(
    ('visualization-hard', 'visualization-easy'))
def encode_image(image_path):
    """Return the bytes of the file at ``image_path`` as a base64 text string.

    The file is read in binary mode. The base64 output is decoded as UTF-8,
    so the result is a ``str`` suitable for embedding in a ``data:`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
def judger_model_inference(judger_model_name, judger_model, imgs=None, prompt=''):
    """Ask a vision-language judge model whether ``imgs`` agree with ``prompt``.

    Args:
        judger_model_name: Judge backend. ``'gpt-4-vision-preview'`` is called
            through the OpenAI client; ``'qwen-vl-plus'`` and ``'qwen-vl-chat'``
            are called through ``judger_model.generate``.
        judger_model: Object exposing ``generate(inputs)`` for the Qwen
            backends; unused (may be ``None``) for the OpenAI backend.
        imgs: Image locations, either ``http://``/``https://`` URLs or local
            file paths. ``None`` (the default) means no images.
        prompt: Text prompt sent together with the images.

    Returns:
        For the OpenAI backend, the reply text of the first choice ('' when
        the API returned no content). For the Qwen backends, whatever
        ``judger_model.generate`` returns. For any other name, ''.
    """
    imgs = [] if imgs is None else imgs
    remote_prefixes = ('http://', 'https://')
    output = ""
    if judger_model_name == 'gpt-4-vision-preview':
        logging.warning("This is an example of `gpt-4-vision-preview`. "
                        "Please set the API key and use according to your actual situation.")
        import mimetypes
        from openai import OpenAI
        client = OpenAI()
        content_list = [{"type": "text", "text": prompt}]
        for img in imgs:
            if not img.startswith(remote_prefixes):
                # Local files are inlined as a base64 data URL. Derive the MIME
                # type from the extension instead of always claiming JPEG.
                mime_type = mimetypes.guess_type(img)[0] or 'image/jpeg'
                img = f"data:{mime_type};base64,{encode_image(img)}"
            content_list.append({"type": "image_url", "image_url": {"url": img}})
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": content_list,
                }
            ],
            max_tokens=300,
        )
        # `choices[0]` is a Choice object; the judge's reply text is in
        # `.message.content` (which may be None).
        output = response.choices[0].message.content or ''
    elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
        inputs = []
        for img in imgs:
            # Only the qwen-vl-plus backend gets local paths as file:// URLs.
            if not img.startswith(remote_prefixes) and judger_model_name == 'qwen-vl-plus':
                img = "file://" + img
            inputs.append({'image': img})
        inputs.append({'text': prompt})
        logging.info('Eval'.center(60, '-'))
        logging.info(inputs)
        output = judger_model.generate(inputs)
        logging.info(output)
        logging.info('=' * 60)
    return output
# Markdown figure reference `![fig-<name>](<path>)`. Both groups are
# non-greedy. With greedy `.+`, several references on one line collapsed into
# one match, and a trailing `)` later on the line was swallowed into the path.
_FIG_REF_RE = re.compile(r'!\[fig-(.+?)\]\((.+?)\)')


def extract_images(text):
    """Return paths of existing files referenced as ``![fig-<name>](<path>)``.

    References are returned in order of appearance in ``text``. Paths that
    do not exist on disk (per ``os.path.exists``) are skipped.
    """
    return [path for _name, path in _FIG_REF_RE.findall(text)
            if os.path.exists(path)]
def check_images_observation(text, images, model_name):
    """Return True if at least one of ``images`` appears in an error-free observation.

    For each image path, in order, take the text from just after the last
    observation marker (``get_react_parser(model_name).observation``) that
    precedes the first occurrence of the path, up to the end of that path.
    The segment counts as a successful execution when it contains neither
    ``'error:'`` nor ``'Traceback'``. Return True at the first such image,
    otherwise False.
    """
    marker = get_react_parser(model_name).observation
    for img_path in images:
        logging.info('Image'.center(60, '-'))
        logging.info(img_path)
        # NOTE(review): `find` locates the FIRST occurrence of the path, which
        # may be inside the generated code rather than in an observation. If no
        # marker precedes it, `rfind` returns -1 and the segment starts at
        # index len(marker) - 1. Both quirks are preserved deliberately.
        prefix = text[:text.find(img_path) + len(img_path)]
        segment = prefix[prefix.rfind(marker) + len(marker):]
        logging.info('Observation'.center(60, '-'))
        logging.info(segment)
        failed = 'error:' in segment or 'Traceback' in segment
        if not failed:
            return True
    return False
# Judge prompt templates keyed by an item's language code ('zh' / 'en').
eval_visual_prompt = dict(zh=EVAL_VISUAL_PROMPT_ZH, en=EVAL_VISUAL_PROMPT_EN)
def _percent(numerator, denominator):
    """Return ``numerator / denominator * 100``, or 0.0 when ``denominator`` is 0."""
    return numerator / denominator * 100 if denominator else 0.0


def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
    """Compute visualization accuracy for a benchmark output file.

    Reads JSONL records from ``output_fname``. For each record whose ``tags``
    contain 'visualization', it:

    - extracts the images referenced in ``gen``;
    - checks that at least one came from an error-free observation;
    - asks the judge model whether the images match the question.

    Each such record gets a boolean ``vis_acc``. It is counted into the Easy
    split when its query contains '<|im_end|>' and into the Hard split
    otherwise.

    Args:
        output_fname: Path of the JSONL file with model generations.
        model_name: Model whose ReAct parser defines the observation marker.
        judger_model_name: 'gpt-4-vision-preview', 'qwen-vl-plus' or
            'qwen-vl-chat'.

    Returns:
        The module-level ``visualization_code_correctness`` dict, updated
        with 'visualization-hard' and 'visualization-easy' accuracies in
        percent. A split with no records reports 0.0 instead of raising
        ZeroDivisionError.

    Raises:
        Exception: If ``judger_model_name`` is not supported.

    Side effects:
        Writes the visualization records not judged correct to
        ``<output_fname without extension>_vis_error.jsonl``.
    """
    if judger_model_name == 'gpt-4-vision-preview':
        # The OpenAI client is created inside `judger_model_inference`.
        judger_model = None
    elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
        if judger_model_name == 'qwen-vl-chat':
            logging.warning('In this benchmark of version 20231206, `Qwen-vl-chat` is no longer used as the '
                            'evaluation model for `Visualization` task.. If you insist on using it, '
                            'the evaluation results might differ from the official results.')
        judger_model = get_model(judger_model_name)
    else:
        raise Exception("Not supported judger model.")

    one_action, one_action_right = 0, 0    # Easy split
    zero_action, zero_action_right = 0, 0  # Hard split
    data_list = load_jsonl(output_fname)
    for item in data_list:
        if 'visualization' not in item['tags']:
            continue
        item['vis_acc'] = False
        # Queries containing '<|im_end|>' form the Easy split. Only the text
        # before the first marker is used as the judge's [Question]; for
        # other queries, split() leaves the query unchanged.
        is_easy = '<|im_end|>' in item['query']
        prompt = item['query'].split('<|im_end|>')[0]
        if is_easy:
            one_action += 1
        else:
            zero_action += 1

        images = extract_images(item['gen'])
        if not images or not check_images_observation(item['gen'], images, model_name):
            continue
        format_prompt = eval_visual_prompt[item.get('lang', 'en')].format(query=prompt)
        output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
        # A None reply counts as "not right" instead of crashing on .lower().
        if 'right' in (output or '').lower():
            item['vis_acc'] = True
            if is_easy:
                one_action_right += 1
            else:
                zero_action_right += 1

    hard_acc = _percent(zero_action_right, zero_action)
    easy_acc = _percent(one_action_right, one_action)
    all_acc = _percent(zero_action_right + one_action_right, zero_action + one_action)

    logging.info('*' * 60)
    logging.info('{:^60}'.format('Visualization Acc.'))
    logging.info('*' * 60)
    logging.info(
        'Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'
        .format(zero_action, zero_action_right, hard_acc))
    logging.info(
        'Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'
        .format(one_action, one_action_right, easy_acc))
    logging.info('all count={}, all right={}, all acc={:.2f}'.format(
        zero_action + one_action, zero_action_right + one_action_right, all_acc))

    visualization_code_correctness['visualization-hard'] = hard_acc
    visualization_code_correctness['visualization-easy'] = easy_acc

    error_data_list = [
        item for item in data_list
        if 'visualization' in item['tags'] and not item['vis_acc']
    ]
    error_data_output_fname = os.path.splitext(output_fname)[0] + '_vis_error.jsonl'
    save_jsonl(error_data_list, error_data_output_fname)
    return visualization_code_correctness