import logging
import os
import re
import subprocess
import sys

import func_timeout
from func_timeout import func_set_timeout

from config import get_react_parser
from utils.code_utils import extract_code, replace_upload_fname
from utils.data_utils import load_jsonl, save_jsonl

# Prelude prepended to every generated snippet before execution: switches into
# the upload directory and imports the libraries generated code commonly uses.
pre_load = """
import os
if 'upload_file' not in os.getcwd():
    os.chdir("./upload_file/")

import seaborn as sns

import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.ion()

import numpy as np
import pandas as pd
from sympy import Eq, symbols, solve
import re
import json
import math
"""

# Per-tag execution settings: ``timelimit`` runs the code under the 10-second
# limit of ``exec_limit_time``; ``extract_first_code`` keeps only the first
# Action Input of a response.
tags_config = {
    'visualization': {
        'timelimit': True,
        'extract_first_code': True,
    },
    'math': {
        'timelimit': True,
        'extract_first_code': False,
    },
    'general': {
        'timelimit': False,
        'extract_first_code': True,
    }
}

# Execution rate (percent) per task tag, filled in by ``log_result``.
code_executability = {'math': None, 'visualization': None, 'general': None}

# Conservative shape of a pip project name. Anything else (whitespace, shell
# metacharacters, a leading '-' that pip would parse as an option) is refused
# instead of being handed to pip.
_PIP_NAME_RE = re.compile(r'[A-Za-z0-9][A-Za-z0-9._-]*')


# NOTE(review): exec() runs model-generated code with this process's full
# privileges; only run this evaluation inside a sandbox / throwaway machine.
@func_set_timeout(10)
def exec_limit_time(text):
    """Execute ``text`` and abort it after 10 seconds.

    Raises:
        func_timeout.exceptions.FunctionTimedOut: if the limit is exceeded.
    """
    exec(text, locals())


def exec_code(text, timelimit=False):
    """Execute ``text``, under the 10-second limit when ``timelimit`` is true."""
    if timelimit:
        exec_limit_time(text)
    else:
        exec(text, locals())


def postprocess_code(gen_code, line):
    """Turn code extracted from a response into a standalone runnable script.

    Steps, in order:

    - When the sample's 'query' contains '<|im_start|>', prepend the Action
      Input code found in the query.
    - Rewrite upload file names via ``replace_upload_fname``.
    - Append a ``solution()`` call when such a function is defined.
    - Close matplotlib figures after plotting.
    - Prepend ``pre_load``.

    Args:
        gen_code: code extracted from the model response.
        line: the sample dict (may be falsy); reads 'query' and
            'input_file_path' when present.

    Returns:
        The script text to execute.
    """
    line = line or {}
    query = line.get('query') or ''
    if '<|im_start|>' in query:
        gen_code = get_action_input_code(query) + gen_code

    upload_fname_list = line.get('input_file_path', [])
    gen_code = replace_upload_fname(gen_code, upload_fname_list)

    if 'def solution()' in gen_code:
        gen_code += '\nsolution()\n'

    # Interactive mode is on (plt.ion() in pre_load); give the figure a moment,
    # then close everything so figures do not accumulate across samples.
    if 'plt.show()' in gen_code:
        gen_code += "\nplt.pause(1)\nplt.close('all')\n"

    if 'sns.' in gen_code and 'plot' in gen_code:
        gen_code += "\nplt.close('all')\n"

    return pre_load + gen_code


def get_action_input_code(text,
                          model_name='qwen-14b-chat',
                          extract_first_code=False):
    """Concatenate the code of the Action Inputs found in ``text``.

    Args:
        text: a model response (or prompt) in ReAct format.
        model_name: selects the parser returned by ``get_react_parser``.
        extract_first_code: stop after the first Action Input.

    Returns:
        Each Action Input's code (via ``extract_code``), each preceded by a
        '# concat' line and followed by a newline; '' when none is found.
    """
    action_input_list = []
    remaining = text
    react_parser = get_react_parser(model_name)
    while True:
        action_input = react_parser.get_first_action_input(remaining)
        if not action_input:
            break
        action_input_list.append(action_input)
        # Split at the FIRST occurrence only. str.split(...)[1] would discard
        # everything after a second, identical Action Input, and would raise
        # IndexError if the parsed text is not literally present.
        _, found, remaining = remaining.partition(action_input)
        if not found or not remaining or extract_first_code:
            break
    return ''.join('# concat\n' + extract_code(action_input) + '\n'
                   for action_input in action_input_list)


def _resolve_tag_settings(tags_list, timelimit, extract_first_code):
    """Return ``(timelimit, extract_first_code)`` for one sample's tags.

    Starts from the caller's defaults. Every tag other than 'all_ci' overrides
    them from ``tags_config``, so the last such tag wins. An unknown tag
    raises KeyError.
    """
    for cur_tag in tags_list:
        if cur_tag != 'all_ci':
            timelimit = tags_config[cur_tag]['timelimit']
            extract_first_code = tags_config[cur_tag]['extract_first_code']
    return timelimit, extract_first_code


def _install_package(package):
    """pip-install ``package`` into the running interpreter (best effort)."""
    # Argument list, no shell: the name comes from an exception raised by
    # untrusted generated code. The exit status is deliberately not checked;
    # a failed install surfaces as the same ImportError on the retry.
    subprocess.run([sys.executable, '-m', 'pip', 'install', package],
                   check=False)
    logging.info(f'Automatic installation: {package}')


def _run_with_auto_install(gen_code, timelimit, pip_package):
    """Execute ``gen_code``, pip-installing a missing module at most once.

    Args:
        gen_code: the full script to run.
        timelimit: run under the 10-second limit.
        pip_package: names already attempted during this run; appended to in
            place so each package is tried only once.

    Returns:
        ``(executable, error_info)``; ``error_info`` is '' on success.
    """
    while True:
        try:
            exec_code(gen_code, timelimit=timelimit)
            return True, ''
        except func_timeout.exceptions.FunctionTimedOut as ex:
            return False, str(ex)
        except ImportError as ex:  # includes ModuleNotFoundError
            # Messages look like "No module named 'pkg'": take the quoted name.
            parts = str(ex).split("'")
            package = parts[1].strip() if len(parts) > 1 else ''
            if (package and package not in pip_package
                    and _PIP_NAME_RE.fullmatch(package)):
                pip_package.append(package)
                _install_package(package)
                continue  # retry now that the package may be available
            return False, str(ex)
        except Exception as ex:
            return False, str(ex)


def _double_check(line, gen_code, model_name):
    """Warn when local execution disagrees with the IPython observation in 'gen'."""
    observation = get_react_parser(model_name).get_first_observation(
        line['gen'])
    ipython_error = 'error:' in observation
    if line['executable_code'] and ipython_error:
        logging.warning(
            'The code executes correctly, but it has an error in IPython!')
        logging.warning(f'Code:\n{gen_code}')
        logging.warning(f'IPython error info:\n{observation}')
        logging.info('=' * 60)
    elif not line['executable_code'] and not ipython_error:
        logging.warning(
            'The code has an execution error, but it runs correctly in IPython!'
        )
        logging.warning(f'Code:\n{gen_code}')
        logging.warning(f"Exec error info:\n{line['code_error_info']}")
        logging.warning(f'IPython observation:\n{observation}')
        logging.info('=' * 60)


def eval_code_execution_rate(output_fname,
                             tag='all_ci',
                             model_name='qwen-14b-chat',
                             timelimit=False,
                             extract_first_code=False):
    """Execute the code in every generated response and report execution rates.

    Args:
        output_fname: JSONL file of samples. Each sample needs 'tags'
            (comma-separated) and 'gen'; 'query' and 'input_file_path' are
            optional.
        tag: only samples whose tag list contains this tag are evaluated.
        model_name: selects the ReAct parser used on 'gen'.
        timelimit: default for each evaluated sample. Every tag other than
            'all_ci' overrides it from ``tags_config``.
        extract_first_code: default for each evaluated sample, overridden the
            same way as ``timelimit``.

    Returns:
        The module-level ``code_executability`` dict, updated by ``log_result``.

    Side effects:
        - Annotates each evaluated sample in place ('executable_code',
          'missing_code', 'code', 'code_error_info'; every sample gets 'idx').
        - Writes failing samples to '<output_fname stem>_exec_error.jsonl'.
        - May pip-install missing modules.
        - Executes model-generated code in this process.
    """
    data_list = load_jsonl(output_fname)
    pip_package = []
    evaluated = []

    for line_id, line in enumerate(data_list):
        line['idx'] = line_id
        tags_list = line['tags'].split(',')
        if tag not in tags_list:
            continue
        evaluated.append(line)

        # Resolve per sample so one sample's tag overrides never leak into
        # the next sample (the original reassigned the function parameters).
        line_timelimit, line_extract_first = _resolve_tag_settings(
            tags_list, timelimit, extract_first_code)

        line['executable_code'] = False
        line['missing_code'] = False
        line['code_error_info'] = ''

        gen_code = get_action_input_code(line['gen'],
                                         model_name=model_name,
                                         extract_first_code=line_extract_first)
        if not gen_code:
            line['missing_code'] = True
            line['code'] = ''
            line['code_error_info'] = 'missing code'
            continue

        line['code'] = gen_code
        gen_code = postprocess_code(gen_code, line)
        line['executable_code'], line['code_error_info'] = (
            _run_with_auto_install(gen_code, line_timelimit, pip_package))

        _double_check(line, gen_code, model_name)

    # Only evaluated samples carry the result keys; skipped ones would raise
    # KeyError here and in log_result.
    error_data_list = [
        item for item in evaluated
        if not item['executable_code'] or item['missing_code']
    ]
    error_data_output_fname = os.path.splitext(
        output_fname)[0] + '_exec_error.jsonl'
    save_jsonl(error_data_list, error_data_output_fname)
    log_result(evaluated)

    return code_executability


def _percent(numerator, denominator):
    """Return ``numerator / denominator * 100``, or NaN when ``denominator`` is 0."""
    return numerator / denominator * 100 if denominator else float('nan')


def log_result(data_list, verbose=True):
    """Log per-sample details and per-tag execution statistics.

    Args:
        data_list: evaluated samples. Each needs 'tags', 'missing_code',
            'executable_code' and 'code_error_info', plus 'query', 'gen' and
            'code' when ``verbose`` is true. 'idx' is used as the sample's
            label when present.
        verbose: also log every sample and each tag's error list.

    Side effect:
        Stores each tag's execution rate (percent) in the module-level
        ``code_executability`` dict, except for 'all_ci'.
    """
    if verbose:
        logging.info('*' * 60)
        logging.info('{:^60}'.format('Detail'))
        logging.info('*' * 60)
        for line_id, line in enumerate(data_list):
            idx = line.get('idx', line_id)
            logging.info(f'Question {idx}'.center(60, '='))
            logging.info(line['query'])
            logging.info(f'Generated {idx}'.center(60, '-'))
            logging.info('\n' + line['gen'])
            logging.info(f'Code {idx}'.center(60, '-'))
            logging.info('\n' + line['code'])
            logging.info(f'Exec Result {idx}'.center(60, '-'))
            prefix_info = ('Exec Success'
                           if line['executable_code'] else 'Exec Error: ')
            logging.info(prefix_info + line['code_error_info'])
            logging.info('=' * 60)

    logging.info('{:^60}'.format('Code Execuation Rate'))
    logging.info('=' * 60)

    tags_per_item = [item['tags'].split(',') for item in data_list]
    # Sorted so the report order is stable across runs (set order is not).
    involved_tags = sorted({t for tags in tags_per_item for t in tags})
    for key in involved_tags:
        logging.info(f'task: {key}'.center(60, '='))
        # Exact tag match: a substring test on the raw 'tags' string would
        # also count e.g. 'ci' inside 'all_ci'.
        key_item_list = [
            item for item, tags in zip(data_list, tags_per_item) if key in tags
        ]
        all_count = len(key_item_list)
        missing_code_count = sum(1 for item in key_item_list
                                 if item['missing_code'])
        executable_code_count = sum(1 for item in key_item_list
                                    if item['executable_code'])
        with_code_count = all_count - missing_code_count

        logging.info(f'All Test: {all_count}')
        logging.info(f'Missing Code: {missing_code_count}')
        logging.info(f'Predict Exec Success: {executable_code_count}')
        # NaN when every sample of this tag is missing code (was ZeroDivisionError).
        logging.info('Codes available && Execution Rate: {:.2f}'.format(
            _percent(executable_code_count, with_code_count)))
        logging.info('Execution Rate: {:.2f}'.format(
            _percent(executable_code_count, all_count)))
        logging.info('Non-executable rate: {:.2f}'.format(
            _percent(with_code_count - executable_code_count, all_count)))
        logging.info('Missing code rate: {:.2f}'.format(
            _percent(missing_code_count, all_count)))

        if key != 'all_ci':
            code_executability[key] = _percent(executable_code_count,
                                               all_count)

        if verbose:
            logging.info('Error List: ')
            error_list = sorted(
                ((item.get('idx'), item['code_error_info'])
                 for item in key_item_list if item['code_error_info']),
                key=lambda x: x[1])
            for x in error_list:
                logging.info(x)