|
import os, sys |
|
import traceback |
|
|
|
# Location of the "human-eval" directory: three dirname() calls on this file's
# absolute path climb to the grandparent of the directory containing this file,
# and "human-eval" is joined onto that.
HUMAN_EVAL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    "human-eval",
)

# Put that directory on the module search path before the `human_eval` import
# that follows.
sys.path.append(HUMAN_EVAL_PATH)
|
from human_eval.data import write_jsonl, read_problems |
|
from finetuning.conversation_template import msg_to_code_result_tok_temp |
|
from code_interpreter.llama_hf import build_model_from_hf_path |
|
from code_interpreter.LlamaCodeInterpreter import LlamaCodeInterpreter |
|
from code_interpreter.GPTCodeInterpreter import GPTCodeInterpreter |
|
from code_interpreter.RetrospectiveGPTCodeInterpreter import ( |
|
RetrospectiveGPTCodeInterpreter, |
|
) |
|
|
|
import re |
|
|
|
from rich import print |
|
from rich.panel import Panel |
|
from rich.syntax import Syntax |
|
from rich.text import Text |
|
|
|
from timeout_decorator import timeout |
|
|
|
# Module-level counter initialised to zero; nothing in the visible part of this
# file reads or updates it.
wrong = 0
|
|
|
|
|
def extract_text(prompt, remove_lines=True):
    """Return the natural-language description embedded in a prompt.

    The description is the text that follows the first triple-double-quote
    delimiter and runs up to the first ``>>>`` doctest marker after it.

    Missing markers are handled explicitly instead of letting ``str.find``'s
    ``-1`` leak into the slice:

    * no opening delimiter -> the description starts at the beginning of
      ``prompt``;
    * no ``>>>`` marker    -> the description ends at the next triple-double-
      quote delimiter (the closing one), or at the end of ``prompt`` if there
      is none.

    Args:
        prompt: Full prompt text (signature plus docstring).
        remove_lines: When true, newlines are replaced by spaces first. All
            whitespace runs are collapsed to a single space afterwards in
            either case, so the result is always a single line.

    Returns:
        The description with whitespace collapsed and the ends stripped.
    """
    delimiter = '"""'
    doctest_marker = ">>>"

    opening_pos = prompt.find(delimiter)
    start_idx = 0 if opening_pos == -1 else opening_pos + len(delimiter)

    # Search only after the opening delimiter so that a marker occurring
    # earlier in the prompt cannot produce an empty or inverted slice.
    end_idx = prompt.find(doctest_marker, start_idx)
    if end_idx == -1:
        end_idx = prompt.find(delimiter, start_idx)
    if end_idx == -1:
        end_idx = len(prompt)

    output = prompt[start_idx:end_idx]
    if remove_lines:
        output = output.replace("\n", " ")
    output = re.sub(r"\s+", " ", output).strip()

    return output
|
|
|
|
|
def extract_all_code_block(input_str: str) -> str:
    """Gather the contents of every ``[CODE_START_TOK]...[/CODE_END_TOK]`` span.

    Each captured snippet is stripped and the snippets are joined with
    newlines in the order they appear. Spans may extend over several lines.

    Returns:
        The joined snippets, or ``None`` (despite the ``str`` annotation) when
        ``input_str`` contains no such span.
    """
    snippets = re.findall(
        r"\[CODE_START_TOK\](.*?)\[/CODE_END_TOK\]",
        input_str,
        re.DOTALL,
    )
    if not snippets:
        return None
    return "\n".join(snippet.strip() for snippet in snippets)
|
|
|
|
|
def extract_all_code_block_gpt(input_str: str) -> str:
    """Gather the contents of every Markdown fence opened with the python tag.

    Each fenced body is stripped and the bodies are joined with newlines in
    the order they appear.

    Returns:
        The joined bodies, or ``None`` (despite the ``str`` annotation) when
        ``input_str`` contains no such fence.
    """
    fence = re.compile(r"```python(.*?)```", re.DOTALL)
    bodies = [body.strip() for body in fence.findall(input_str)]
    return "\n".join(bodies) if bodies else None
|
|
|
|
|
def delete_print_asser(code_text: str):
    """Neutralise single-line ``print(...)`` statements in ``code_text``.

    A line counts as a print statement when, ignoring leading whitespace, it
    starts with ``print(``.

    * Unindented print lines are removed outright.
    * Indented print lines are replaced by ``pass`` at the same indentation.
      Simply deleting them can leave a compound statement with an empty body
      (for example ``if __name__ == "__main__":`` followed only by a print),
      which makes the remaining code fail to compile.

    Only the line that starts the call is examined; a print call spanning
    several lines is not handled specially.

    Args:
        code_text: Source code as a single newline-separated string.

    Returns:
        The source with print lines removed or replaced as described.
    """
    kept_lines = []
    for line in code_text.split("\n"):
        stripped = line.lstrip()
        if not stripped.startswith("print("):
            kept_lines.append(line)
            continue

        indent = line[: len(line) - len(stripped)]
        if indent:
            # Keep the enclosing block syntactically non-empty.
            kept_lines.append(f"{indent}pass")

    return "\n".join(kept_lines)
|
|
|
|
|
def extract_function_from_code_block(code_block: str) -> str:
    """Extract the leading run of top-level function definitions.

    Scanning starts at the first line that begins with ``def `` and keeps
    collecting lines until a top-level line is reached that is neither a
    comment nor another ``def``. Trailing blank lines, comment lines and
    unindented lines are then trimmed from the collected text.

    Blank lines do not end the scan, because they are legal inside a function
    body; previously a blank line truncated the function at that point. Both
    space- and tab-indented bodies are recognised as indented.

    Args:
        code_block: Source code as a single newline-separated string.

    Returns:
        The extracted source, or ``""`` when nothing qualifies.
    """

    def _is_indented(text):
        # str.startswith accepts a tuple of alternatives.
        return text.startswith((" ", "\t"))

    function_lines = []
    inside_function = False

    for line in code_block.split("\n"):
        if line.startswith("def "):
            inside_function = True
        if not inside_function:
            continue

        function_lines.append(line)

        if line.strip() == "":
            # Blank line: may sit inside a body; trailing ones are trimmed below.
            continue
        if not (
            _is_indented(line)
            or line.startswith("#")
            or line.startswith("def ")
        ):
            # First top-level statement after the definitions ends the scan.
            break

    # Drop whatever trails the last indented body line, including the
    # top-level line that ended the scan.
    while function_lines and (
        function_lines[-1].strip() == ""
        or function_lines[-1].strip().startswith("#")
        or not _is_indented(function_lines[-1])
    ):
        function_lines.pop()

    return "\n".join(function_lines)
|
|
|
|
|
def get_last_outermost_function_name(function_str):
    """Return the name of the last function defined at column 0.

    Only ``def`` statements that begin a line are considered, so indented
    (nested) definitions are ignored. Returns ``""`` when there is none.
    """
    top_level_names = re.findall(r"^def (\w+)", function_str, re.MULTILINE)
    return top_level_names[-1] if top_level_names else ""
|
|
|
|
|
def get_last_function_name(function_str):
    """Return the name following the last ``def `` found anywhere in the text.

    Unlike the "outermost" variants, the match is not anchored to the start
    of a line, so nested definitions count. Returns ``""`` when there is none.
    """
    last_name = ""
    for found in re.finditer(r"def (\w+)", function_str):
        last_name = found.group(1)
    return last_name
|
|
|
|
|
def get_outermost_function_name(function_str):
    """Return the name of the first function defined at column 0.

    Only ``def`` statements that begin a line are considered. Returns ``""``
    when there is none.
    """
    first_top_level = re.search(r"^def (\w+)", function_str, re.MULTILINE)
    if first_top_level is None:
        return ""
    return first_top_level.group(1)
|
|
|
|
|
def get_function_name(function_str):
    """Return the name following the first ``def `` found anywhere in the text.

    The match is not anchored to the start of a line, so an indented
    definition counts as well.

    Returns:
        The bare function name, or ``""`` when no definition is found.
    """
    match = re.search(r"def (\w+)", function_str)
    if match:
        # group(1) is the captured identifier; group(0) would be the whole
        # match including the "def " keyword, which is not a usable name.
        return match.group(1)
    return ""
|
|
|
|
|
def extract_test_assertion(test_func: str):
    """Return the lines of ``test_func`` that contain the text ``assert``.

    Matching is a plain substring check on each line. Selected lines are
    stripped and joined with newlines, and the joined result is stripped
    once more.
    """
    assertion_lines = [
        line.strip() for line in test_func.split("\n") if "assert" in line
    ]
    return "\n".join(assertion_lines).strip()
|
|
|
|
|
# Import prelude that the script passes to exec_with_timeout, which places it
# in front of the generated solution and its test code before executing them.
import_str = """
import re
import math
from typing import List, Tuple, Optional
"""
|
|
|
|
|
@timeout(100, timeout_exception=TimeoutError)
def exec_with_timeout(import_str, full_test_code):
    """Execute ``import_str`` followed by ``full_test_code`` and report success.

    Args:
        import_str: Import prelude placed in front of the code.
        full_test_code: Candidate solution together with its test code.

    Returns:
        ``True`` when execution finishes without raising, ``False`` when it
        raises any ``Exception`` other than ``TimeoutError`` (the error type
        and message are printed).

    Raises:
        TimeoutError: Propagated to the caller instead of being reported as an
            ordinary failure. The decorator above is configured with this
            exception type and a limit of 100; if it delivers the timeout by
            raising inside this frame (presumably via a signal handler; see the
            timeout_decorator documentation), the generic handler below would
            otherwise swallow it and the caller's ``except TimeoutError``
            branch could never run.
    """
    # The namespace handed to exec starts out holding the two arguments.
    env = {**locals()}
    code_to_exec = f"{import_str}\n{full_test_code}"
    try:
        # NOTE(security): this runs generated code in-process with no
        # sandboxing; only use it in a disposable/isolated environment.
        exec(code_to_exec, env)
    except TimeoutError:
        raise
    except Exception as e:
        print(f"Error Type: {type(e).__name__}, Error Message: {e}")
        return False
    return True
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run every HumanEval problem through a code
    # interpreter, execute the extracted solution against the problem's tests
    # and print a running accuracy.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    import argparse

    parser = argparse.ArgumentParser(description="Process path for LLAMA2_FINETUNEED.")
    # --path used to be required=True *and* carry a default, which made the
    # default unreachable and forced a path even for GPT runs that never use
    # it. It is optional now; passing it explicitly works exactly as before.
    # The defaults no longer contain literal double-quote characters.
    parser.add_argument(
        "--path",
        type=str,
        required=False,
        help="Path to the finetuned LLAMA2 model.",
        default="./output/llama-2-7b-chat-ci",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=False,
        help=(
            "Model identifier. A value containing 'gpt' selects the GPT "
            "interpreter; anything else uses the LLAMA2 model at --path."
        ),
        default="./output/llama-2-7b-chat-ci",
    )
    parser.add_argument(
        "--max-retry",
        type=int,
        required=False,
        help="Maximum number of retries.",
        default=5,
    )
    args = parser.parse_args()

    # True: send the full prompt (signature + docstring) to the model.
    # False: send only the prose description extracted from the docstring.
    PROGRAMMING_PUZZLE_Q = True
    LLAMA2_FINETUNEED_PATH = args.path
    use_gpt = "gpt" in args.model.lower()

    # Make sure the dialog dump directory exists before the (long) run starts.
    output_dir = "./eval/gpt_humaneval_output"
    os.makedirs(output_dir, exist_ok=True)

    problems = read_problems()
    correct_total = 0
    total_problems = len(problems)

    for idx, task_id in enumerate(problems):
        # A fresh interpreter is created for every problem and closed at the
        # end of the iteration.
        if not use_gpt:
            interpreter = LlamaCodeInterpreter(
                model_path=LLAMA2_FINETUNEED_PATH,
            )
        else:
            interpreter = RetrospectiveGPTCodeInterpreter(
                model=args.model,
            )

        # Convert four-space indentation to tabs. Replacing every single
        # space would also turn the spaces inside the prose and the code
        # itself into tabs.
        # NOTE(review): the four-space literal is restored on the assumption
        # that it was collapsed to one space -- confirm against the upstream file.
        programming_puzzle = problems[task_id]["prompt"].replace("    ", "\t")
        text_only_problem = extract_text(programming_puzzle)

        interpreter.dialog = [
            {
                "role": "system",
                "content": "You are helpful robot that can generate code , excute it and debug then answer",
            }
        ]

        if PROGRAMMING_PUZZLE_Q:
            output_str = interpreter.chat(
                user_message=f"Write a Python script to solve the following problem:\n{programming_puzzle}\nEnsure the solution is verified by printing the expected output.",
                MAX_TRY=args.max_retry,
                VERBOSE=True,
                code_exec_prefix="\nfrom typing import List,Tuple\nimport math\n",
                feedback_prompt="Ensure the output matches the expected result, taking into account any corner cases. If discrepancies arise, pinpoint where you went wrong. Then, refine the code to achieve the desired outcome.",
                append_result=True,
            )["content"]
        else:
            output_str = interpreter.chat(
                user_message=f"Write a Python script for this problem:\n{text_only_problem}",
                MAX_TRY=args.max_retry,
                VERBOSE=True,
                code_exec_prefix="\nfrom typing import List,Tuple\nimport math\n",
                feedback_prompt="Ensure the output matches the expected result. If not tell where you got wrong, then refine the code to achieve the desired outcome.",
                append_result=True,
            )["content"]

        # Pull the generated code out of the reply; the delimiter format
        # depends on which interpreter produced it.
        function_str = ""
        if not use_gpt:
            code_block = extract_all_code_block(output_str)
        else:
            code_block = extract_all_code_block_gpt(output_str)
        if (code_block is not None) and ("def" in code_block):
            function_str = code_block

        function_str = delete_print_asser(function_str)
        function_name = get_last_outermost_function_name(function_str)
        # When no function was found, function_name is "" and the generated
        # "check()" call fails inside exec, which counts as incorrect.
        full_test_code = f"{function_str}\n#-----------\n{problems[task_id]['test']}\ncheck({function_name})"

        syntax = Syntax(
            f"{full_test_code}",
            "python",
            theme="monokai",
            line_numbers=True,
        )
        print(syntax)

        is_correct = False
        try:
            is_correct = exec_with_timeout(import_str, full_test_code)
        except TimeoutError as e:
            print(f"Timeout with error msg : {e}")

        if is_correct:
            correct_total += 1

        acc = correct_total / (idx + 1)

        interpreter.save_dialog(
            path=f"{output_dir}/{task_id.replace('/','_')}_{is_correct}.json"
        )
        interpreter.close()
        del interpreter

        accuracy_text = Text(
            f"Accuracy: {correct_total}/{idx+1}[{total_problems}] = {acc:.2%} [{is_correct}]",
            style="bold blue",
        )
        panel = Panel(accuracy_text, title="Results", border_style="green")
        print(panel)
|
|