# LLMBB-Agent / benchmark / inference_and_execute.py
# Source-viewer metadata: author "vlff李飞飞", commit 2319518 ("update md"),
# page labels "raw", "history blame", "No virus"; file size 9.87 kB.
import argparse
import json
import logging
import os
import subprocess

import prettytable
import tqdm
from datasets import load_dataset

from code_interpreter import code_interpreter
from config import (get_model, get_react_parser, get_react_prompt,
                    model_path_map)
from metrics.code_execution import eval_code_execution_rate
from metrics.gsm8k import eval_gsm8k_acc, is_correct
from metrics.visualization import eval_visualization_acc
from parser import ReActParser
from utils.code_utils import replace_upload_fname
from utils.data_utils import load_jsonl
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

# Work directory taken from CODE_INTERPRETER_WORK_DIR (default /tmp/workspace);
# created if it does not exist yet.
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
os.makedirs(WORK_DIR, exist_ok=True)


def _copy_tree_best_effort(src, dst):
    """Recursively copy ``src`` to ``dst`` with ``cp -r``; log instead of raising on failure.

    The command is passed as an argument list (no shell), so paths containing
    spaces or shell metacharacters -- e.g. a CODE_INTERPRETER_WORK_DIR value --
    are used literally. As with plain ``cp -r``, an existing ``dst`` directory
    receives ``src`` inside it rather than being replaced.
    """
    try:
        completed = subprocess.run(['cp', '-r', src, dst], check=False)
    except OSError as err:  # e.g. no `cp` executable on this platform
        logging.warning('Copy %s -> %s failed: %s', src, dst, err)
        return
    if completed.returncode != 0:
        logging.warning('Copy %s -> %s exited with status %d', src, dst,
                        completed.returncode)


# Copy the upload_file_clean fixture directory into the work dir and into the
# current directory (both copies are made at import time).
_copy_tree_best_effort('upload_file_clean', os.path.join(WORK_DIR, 'upload_file'))
_copy_tree_best_effort('upload_file_clean', './upload_file')

# Aggregated scores, filled in by eval_metrics (via dict.update) and printed
# by gather_eval_result. None means "not evaluated in this run".
global_eval_result = {
    'code_executability': {
        'math': None,
        'visualization': None,
        'general': None,
    },
    'code_correctness': {
        'math': None,
        'visualization-hard': None,
        'visualization-easy': None,
    },
}
def llm_with_plugin(args, query, item=None, exec_limit=3):
    """Generate a ReAct-style answer for ``query``, executing code actions as they appear.

    Args:
        args: parsed CLI namespace; reads ``model``, ``llm`` and ``gen_only``.
        query: the user query. If it contains ``'<|im_start|>'``, the latest
            plugin call embedded in it is executed once up front, and the
            execution budget is raised by one to compensate.
        item: optional test-case dict; ``'input_file_path'`` and ``'lang'`` are
            read when present (fallbacks: ``[]`` and ``'en'``).
        exec_limit: maximum number of model-issued code executions.

    Returns:
        The accumulated completions with their observations appended
        (the planning prompt itself is not included).

    The loop stops when generation-only mode is on, when no plugin call can be
    parsed from a completion, when an observation contains ``'error:'`` or
    ``'Traceback'``, or when the execution budget is used up.
    """
    upload_fname_list, lang = [], 'en'
    if item:
        if 'input_file_path' in item:
            upload_fname_list = item['input_file_path']
        if 'lang' in item:
            lang = item['lang']

    react_prompt_obj = get_react_prompt(args.model, query, lang,
                                        upload_fname_list)
    planning_prompt = react_prompt_obj.build_prompt()

    exec_count = 0
    if '<|im_start|>' in query:
        # Replay the action already present in the query before the model continues.
        first_action, prepend_code, _unused = ReActParser().parse_latest_plugin_call(query)
        prepend_code = replace_upload_fname(prepend_code, upload_fname_list)
        call_plugin(first_action, [prepend_code], clear=True)
        exec_count = 1
        exec_limit += 1

    transcript = ''
    while exec_count < exec_limit:
        stop_words_list = react_prompt_obj.get_stop_words_list()
        output = text_completion(args.llm,
                                 planning_prompt + transcript,
                                 stop_words=stop_words_list)
        if args.gen_only:
            # Generation-only mode: keep the raw completion, execute nothing.
            transcript += output
            break

        react_parser = get_react_parser(args.model)
        action, action_input, output = react_parser.parse_latest_plugin_call(output)
        if not action:
            # No plugin call parsed: keep the completion and stop.
            transcript += output
            break

        action_input = replace_upload_fname(action_input, upload_fname_list)
        observation = call_plugin(action, [action_input],
                                  clear=(exec_count == 0))
        output += react_prompt_obj.build_observation(observation)
        transcript += output
        exec_count += 1
        if 'error:' in observation or 'Traceback' in observation:
            # Stop at the first failed execution.
            break
    return transcript
def text_completion(llm, input_text, stop_words=None):
    """Run one generation call, logging both the prompt and the completion.

    Args:
        llm: model wrapper exposing ``generate(text, stop_words)``.
        input_text: the full prompt to complete.
        stop_words: stop strings forwarded to ``llm.generate``. ``None`` (the
            default) means an empty list. A fresh list is created on every call,
            so a ``generate`` implementation that mutates it cannot leak state
            into later calls (the old ``[]`` default was shared across calls).

    Returns:
        Whatever ``llm.generate`` returns (presumably the completion string).
    """
    if stop_words is None:
        stop_words = []
    logging.info('Generating'.center(60, '='))
    logging.info('Input'.center(60, '-'))
    logging.info(input_text)
    output = llm.generate(input_text, stop_words)
    logging.info('Output'.center(60, '-'))
    logging.info(output)
    return output
def call_plugin(plugin_name, plugin_args_list, clear=False):
    """Execute ``plugin_args_list`` with the code interpreter and return its observation.

    ``plugin_name`` is intentionally not checked: every action is routed to the
    code interpreter regardless of the name the model produced.
    ``clear`` is forwarded to ``code_interpreter`` unchanged.
    """
    banner = 'Call code interpreter'.center(60, '=')
    logging.info(banner)
    observation = code_interpreter(plugin_args_list, clear=clear)
    logging.info(observation)
    return observation
def process_code_interpreter(item, writer):
    """Run one code-interpreter test case and append it, with its generation, to ``writer``.

    Relies on the module-level ``args`` assigned in the ``__main__`` block.
    Visualization-tagged cases get an execution limit of 3; all others get 1.
    The response is stored in ``item['gen']`` (mutating ``item``). The item is
    then written as one JSON line, and the writer is flushed.
    """
    query = item['query']
    is_visualization = 'visualization' in item['tags']
    item['gen'] = llm_with_plugin(
        args=args,
        query=query,
        item=item,
        exec_limit=3 if is_visualization else 1,
    )
    record = json.dumps(item, ensure_ascii=False)
    writer.write(record + '\n')
    writer.flush()
def process_gsm8k(doc, writer):
    """Answer one GSM8K problem, score it, and append the record to ``writer`` as JSONL.

    Relies on the module-level ``args`` assigned in the ``__main__`` block.
    Adds ``'completion'`` and ``'acc'`` to ``doc`` in place before writing it.
    """
    completion = llm_with_plugin(args=args, query=doc['question'])
    correct = is_correct(completion, doc['answer'])
    doc['completion'], doc['acc'] = completion, correct
    writer.write(f"{json.dumps(doc, ensure_ascii=False)}\n")
    writer.flush()
def sequential_processing(args, data_list, process_func, writer):
    """Apply ``process_func(entry, writer)`` to every entry, one at a time, with a tqdm bar.

    ``args`` is accepted for a uniform call signature but is not used here.
    """
    progress = tqdm.tqdm(data_list)
    for entry in progress:
        process_func(entry, writer)
# Task name -> per-item handler. Tasks not listed here (e.g. 'general') fall
# back to process_code_interpreter through the .get() default used in main().
process_func_map = dict(
    gsm8k=process_gsm8k,
    visualization=process_code_interpreter,
)
def gather_eval_result(model_name):
    """Log one PrettyTable per metric group in ``global_eval_result`` for ``model_name``.

    Truthy scores are rounded to two decimals; falsy ones (e.g. ``None`` for
    "not evaluated") are printed via ``str`` as-is.
    """
    for metric, scores in global_eval_result.items():
        logging.info(metric)
        table = prettytable.PrettyTable()
        table.field_names = ['model', *scores.keys()]
        cells = [model_name]
        cells.extend(str(round(value, 2)) if value else str(value)
                     for value in scores.values())
        table.add_row(cells)
        logging.info('\n' + str(table))
def eval_metrics(args, test_set, full_output_fname):
    """Score a finished inference file and merge the results into ``global_eval_result``.

    Args:
        args: parsed CLI namespace; reads ``task``, ``model``,
            ``eval_code_exec_only`` and ``vis_judger``.
        test_set: the evaluation cases; the results file must hold exactly one
            line per case.
        full_output_fname: path of the JSONL results file written by ``main``.

    Raises:
        FileNotFoundError: if ``full_output_fname`` does not exist.
        ValueError: if the number of result lines differs from ``len(test_set)``.
            (Explicit raises instead of ``assert`` so the checks still run under
            ``python -O``.)
    """
    if not os.path.exists(full_output_fname):
        raise FileNotFoundError(f'Not Found File {full_output_fname}.')
    inference_res = load_jsonl(full_output_fname)
    if len(inference_res) != len(test_set):
        raise ValueError(
            f'There are still {len(test_set)-len(inference_res)} cases left.')

    # Resolve against the current working directory -- the same base used by the
    # existence check and load_jsonl above -- so the metric helpers read exactly
    # the file that was just validated. (Joining onto this script's directory
    # pointed elsewhere whenever the CWD differed.)
    abs_output_fname = os.path.abspath(full_output_fname)

    if args.task == 'gsm8k':
        math_code_correctness = eval_gsm8k_acc(abs_output_fname)
        global_eval_result['code_correctness'].update(math_code_correctness)
        return

    code_executability = eval_code_execution_rate(abs_output_fname, args.task,
                                                  args.model)
    global_eval_result['code_executability'].update(code_executability)
    if args.task in ('all_ci', 'visualization') and not args.eval_code_exec_only:
        visualization_code_correctness = eval_visualization_acc(
            abs_output_fname, args.model, args.vis_judger)
        global_eval_result['code_correctness'].update(
            visualization_code_correctness)
def main(args):
    """Run one task end to end: build the test set, infer missing cases, evaluate.

    Results go to ``<output_path>/<output_fname or {task}_{model}_res.jsonl>``.
    Cases whose query/question already appears in that file are skipped (resume),
    unless ``--force`` is given. ``--force`` also truncates the file before writing.
    With ``--eval-only`` no inference is run. The working directory is restored
    on exit, including when an exception propagates.
    """
    current_dir = os.getcwd()
    try:
        os.makedirs(args.output_path, exist_ok=True)
        full_output_fname = os.path.join(
            args.output_path,
            (args.output_fname or f'{args.task}_{args.model}_res.jsonl'))
        if not os.path.exists(full_output_fname):
            # Create an empty results file so the later reads of this path find it.
            with open(full_output_fname, 'w'):
                logging.info(f'Create file {full_output_fname} done.')

        # build data
        if args.task == 'gsm8k':
            test_set = load_dataset('gsm8k', 'main')['test']
        else:
            eval_data_path = os.path.join(args.input_path, args.input_fname)
            test_set = [
                item for item in load_jsonl(eval_data_path)
                if args.task in item['tags']
            ]
        logging.info(f'Test set: {len(test_set)}')

        if args.eval_only:
            eval_metrics(args, test_set, full_output_fname)
            return

        key = 'question' if args.task == 'gsm8k' else 'query'
        # A set gives O(1) membership checks for the resume filter below.
        cached_questions = set() if args.force else {
            item[key] for item in load_jsonl(full_output_fname)
        }
        data_list = [
            item for item in test_set if item[key] not in cached_questions
        ]
        logging.info(f'Left cases: {len(data_list)}')

        # inference; the `with` guarantees the file is closed even if a case raises.
        writer_mode = 'w' if args.force else 'a'
        process_func = process_func_map.get(args.task,
                                            process_code_interpreter)
        with open(full_output_fname, writer_mode,
                  encoding='utf-8') as f_output:
            sequential_processing(args, data_list, process_func, f_output)

        # evaluate
        if not args.gen_exec_only:
            eval_metrics(args, test_set, full_output_fname)
    finally:
        os.chdir(current_dir)
def parse_args():
    """Parse the benchmark runner's command-line options.

    Returns:
        argparse.Namespace with the options defined below.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--model',
                            type=str,
                            default='qwen-14b-chat',
                            choices=list(model_path_map.keys()))
    arg_parser.add_argument(
        '--task',
        type=str,
        default='all',
        choices=['all', 'gsm8k', 'visualization', 'general'])
    arg_parser.add_argument('--output-path', type=str, default='output_data')
    arg_parser.add_argument('--input-path', type=str, default='eval_data')
    arg_parser.add_argument('-o', '--output-fname', type=str, default='')
    arg_parser.add_argument('-i',
                            '--input-fname',
                            type=str,
                            default='eval_code_interpreter_v1.jsonl')
    arg_parser.add_argument('-f', '--force', action='store_true', default=False)
    arg_parser.add_argument('--eval-only', action='store_true', default=False)
    arg_parser.add_argument('--eval-code-exec-only',
                            action='store_true',
                            default=False)
    arg_parser.add_argument('--gen-exec-only', action='store_true', default=False)
    arg_parser.add_argument('--gen-only', action='store_true', default=False)
    # The default must be a bare model name that appears in `choices`: argparse
    # does not validate defaults against `choices`. The previous value
    # "'gpt-4-vision-preview'" carried literal quote characters into args.vis_judger.
    arg_parser.add_argument(
        '--vis-judger',
        type=str,
        default='gpt-4-vision-preview',
        choices=['gpt-4-vision-preview', 'qwen-vl-chat', 'qwen-vl-plus'])
    args = arg_parser.parse_args()
    return args
if __name__ == '__main__':
    # `args` must remain a module-level global: process_code_interpreter and
    # process_gsm8k read it directly instead of receiving it as a parameter.
    args = parse_args()
    if not args.eval_only:
        args.llm = get_model(args.model)
        logging.info(f'Init {args.model} done.')
    # 'all' runs every task in turn on the same (mutated) namespace.
    task_queue = (['gsm8k', 'visualization', 'general']
                  if args.task == 'all' else [args.task])
    for task_name in task_queue:
        args.task = task_name
        main(args)
    gather_eval_result(args.model)