File size: 1,607 Bytes
9b4edaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201cfef
9b4edaf
 
 
 
 
201cfef
076bdf6
 
 
9b4edaf
076bdf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
950e174
 
9b4edaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from prompt import TA_prompt
import re
from utils import generate_response, run_code


def post_process_code(code, question):
    """Append a ``print(<func>(...))`` call to generated code.

    The first ``def`` found in ``code`` is taken as the function to call.
    Its parameters are paired, in order, with the numbers that appear in
    ``question`` and passed as keyword arguments.

    Args:
        code: Source text expected to contain a function definition.
        question: Text from which the numeric argument values are extracted.

    Returns:
        ``code`` with a trailing ``print`` call appended. If no function
        definition is found, ``code`` is returned unchanged.

    Note:
        If the question holds fewer numbers than the function has
        parameters, only the leading parameters receive a value.
        Parameters are split on commas, so annotations that themselves
        contain commas (e.g. ``Dict[str, int]``) are not fully supported.
    """
    # Locate the signature anywhere in the code, not only on its first line,
    # so leading imports or comments do not corrupt the parsed name/params.
    signature = re.search(
        r"^[ \t]*def\s+(\w+)\s*\(([^)]*)\)", code, re.MULTILINE
    )
    if signature is None:
        return code
    func_name, raw_params = signature.groups()

    parameters = []
    for raw in raw_params.split(","):
        # Drop any annotation (": int") or default (" = 1") from the name.
        name = raw.split(":")[0].split("=")[0].strip()
        # isidentifier() also filters out "", "*", "*args", "**kw" and "/".
        if name.isidentifier():
            parameters.append(name)

    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", question)[:len(parameters)]
    # The pattern matches decimals as well as integers; int() alone would
    # raise ValueError on tokens such as "2.5".
    values = [int(tok) if tok.isdigit() else float(tok) for tok in numbers]

    arg_string = ", ".join(
        f"{param}={val}" for param, val in zip(parameters, values)
    )
    return f"{code}\nprint({func_name}({arg_string}))"


def solve_ta(question):
    """Answer ``question`` using the ``TA_prompt`` few-shot prompt.

    The question is appended to ``TA_prompt`` as a ``Human:`` turn and sent
    to ``generate_response``. If the reply contains a fenced code block, the
    code is extracted, completed with a ``print`` call by
    ``post_process_code`` and executed with ``run_code``. Otherwise the text
    following ``Assistant:`` is returned as the answer.

    Args:
        question: The question text.

    Returns:
        A ``(prediction, code)`` tuple. ``prediction`` is ``None`` when the
        generated code asks for interactive input, when running it raises,
        or when a code-free reply has no ``Assistant:`` section. ``code`` is
        the executed source, or ``""`` when the reply contained no code.
    """
    question = "Human: " + question.strip()
    query = (TA_prompt + question).strip() + "\n"

    # NOTE(review): 0.9 is presumably the sampling temperature - confirm
    # against generate_response.
    response = generate_response(query, 0.9)

    # Skip the first len(prompt) characters; this assumes the response
    # starts with an echo of the prompt - TODO confirm. Then keep only the
    # text before the first "-----" separator.
    prompt_len = len(TA_prompt.strip())
    response = response[prompt_len:].strip().split("-----")[0]

    if "```" not in response:
        # Plain-text answer: take what follows "Assistant:" up to the next
        # "Human:" turn. A reply without the marker yields no prediction
        # instead of raising IndexError.
        answer = re.search(r"Assistant:(.*)", response, re.DOTALL)
        if answer is None:
            return None, ""
        return answer.group(1).split("Human:")[0].strip(), ""

    fence = "```python" if "```python" in response else "```"
    code = response.split(fence)[1].split("```")[0].strip()
    code = post_process_code(code, question)
    print(code)

    # Code that waits on stdin cannot be run unattended.
    if "input(" in code:
        return None, code

    try:
        pred = run_code(code)
    except Exception:
        # Generated code may fail in arbitrary ways; report "no prediction".
        return None, code
    return pred, code