import ast
import re

from prompt import TA_prompt
from utils import generate_response, run_code

# Numeric literals in a question. The alternation is grouped so an optional
# sign applies to integers as well as decimals (the old pattern only signed
# decimals). A sign is taken only when it does not directly follow a word
# character, so "3-5", "2020-01-05" or "COVID-19" still yield unsigned numbers
# while "x = -5" yields -5.
_NUMBER_RE = re.compile(r"(?:(?<!\w)[-+])?(?:\d*\.\d+|\d+)")


def _parse_number(text):
    """Convert a matched literal to ``int``, or ``float`` if it has a decimal point."""
    return float(text) if "." in text else int(text)


def _extract_numbers(text):
    """Return the numeric literals found in *text*, in order of appearance."""
    return [_parse_number(match) for match in _NUMBER_RE.findall(text)]


def _first_function(code):
    """Return ``(name, parameter_names)`` for the first top-level ``def`` in *code*.

    Only parameters that can be passed by keyword are listed (regular and
    keyword-only ones). Returns ``None`` when *code* does not parse or
    defines no top-level function.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return None
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            params = [arg.arg for arg in node.args.args + node.args.kwonlyargs]
            return node.name, params
    return None


def post_process_code(code, question):
    """Append a ``print(...)`` call that invokes the generated solution function.

    The first top-level function defined in *code* is called with the numbers
    found in *question*, bound by keyword to its parameters in order. Extra
    numbers are ignored. If there are fewer numbers than parameters, only the
    leading parameters are passed.

    Args:
        code: Python source extracted from the model's reply.
        question: Question text the numeric arguments are taken from.

    Returns:
        *code* with the call appended, or *code* unchanged when it does not
        parse or defines no top-level function (nothing sensible to call).
    """
    signature = _first_function(code)
    if signature is None:
        return code
    func_name, parameters = signature
    values = _extract_numbers(question)
    arg_string = ", ".join(f"{name}={value!r}" for name, value in zip(parameters, values))
    return code + f"\nprint({func_name}({arg_string}))"


def _extract_text_answer(text):
    """Return the plain-text answer from a reply that contains no code fence.

    Takes the text after the first ``Assistant:`` marker (or the whole reply
    when there is no marker) up to the next ``Human:`` turn, stripped.
    """
    match = re.search(r"Assistant:(.*)", text, re.DOTALL)
    answer = match.group(1) if match else text
    return answer.split("Human:")[0].strip()


def solve_ta(question, temperature=0.9):
    """Answer *question* with the TA prompt, executing generated code if present.

    Args:
        question: The user question.
        temperature: Sampling temperature forwarded to ``generate_response``.

    Returns:
        A ``(prediction, code)`` pair. The two cases are:

        - Reply contains a fenced code block: ``prediction`` is the result of
          ``run_code``, or ``None`` when the code calls ``input(`` or raises.
          ``code`` is the processed source.
        - Plain-text reply: ``prediction`` is the answer text and ``code`` is
          ``""``.
    """
    question = "Human: " + question.strip()
    query = (TA_prompt + question).strip() + "\n"
    response = generate_response(query, temperature)
    # NOTE(review): assumes generate_response echoes the prompt before the
    # completion; the first len(TA_prompt.strip()) characters are dropped and
    # only the text before the first "-----" separator is kept.
    response = response[len(TA_prompt.strip()):].strip().split("-----")[0]

    if "```" not in response:
        return _extract_text_answer(response), ""

    fence = "```python" if "```python" in response else "```"
    code = response.split(fence)[1].split("```")[0].strip()
    code = post_process_code(code, question)
    print(code)

    # Code that reads from stdin would block the run, so it is not executed.
    if "input(" in code:
        return None, code
    # NOTE(review): run_code presumably executes model-generated code —
    # confirm it is sandboxed before running untrusted output.
    try:
        return run_code(code), code
    except Exception:  # best effort: a failing generated program gives no prediction
        return None, code