import os
import re
import subprocess
import tempfile
import multiprocessing
from collections import Counter


class PythonREPL:
    """Run Python snippets in a separate `python3` process with a timeout in seconds."""

    def __init__(self, timeout=5):
        self.timeout = timeout

    @staticmethod
    def _run_code(temp_file_path):
        """Run the script at `temp_file_path`; return (success, stdout or trimmed traceback)."""
        result = subprocess.run(
            ["python3", temp_file_path],
            capture_output=True,
            check=False,
            text=True
        )
        if result.returncode == 0:
            return True, result.stdout.strip()
        else:
            # Keep only the informative parts of the traceback: the "Traceback"
            # header, frames from the user's script (with the temporary directory
            # hidden), the source line after each such frame, and the final error.
            error_msg = result.stderr.strip()
            msgs = error_msg.split("\n")
            new_msgs = []
            want_next = False
            for m in msgs:
                if "Traceback" in m:
                    new_msgs.append(m)
                elif m == msgs[-1]:
                    new_msgs.append(m)
                elif temp_file_path in m:
                    # 'File "/tmp/<dir>/tmp.py", line N' -> 'File "tmp.py", line N'
                    m = m.replace(temp_file_path, os.path.basename(temp_file_path))
                    new_msgs.append(m)
                    want_next = True
                elif want_next:
                    new_msgs.append(m)
                    want_next = False
            return False, "\n".join(new_msgs).strip()

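    def __call__(self, query):
        # Prepend common imports so snippets can use math, np and sp directly.
        # This shifts line numbers in tracebacks by 3.
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        query = query.strip().split("\n")
        # If the last line does not already print, drop any trailing comment and
        # wrap it in print() so a final expression's value reaches stdout.
        # This is a textual heuristic: a final assignment such as `x = 5`
        # becomes `print(x = 5)` and fails.
        if "print(" not in query[-1]:
            if "#" in query[-1]:
                query[-1] = query[-1].split("#")[0]
            query[-1] = "print(" + query[-1] + ")"
        query = "\n".join(query)
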
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(query)

            with multiprocessing.Pool(1) as pool:
                result = pool.apply_async(self._run_code, (temp_file_path,))
                try:
                    success, output = result.get(self.timeout)
                except multiprocessing.TimeoutError:
                    pool.terminate()
                    return False, f"Timed out after {self.timeout} seconds."
        return success, output
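

# Sketch (not part of the original utilities): the auto-print step in
# PythonREPL.__call__ works on raw text, so it cannot tell an expression from an
# assignment or from a line inside an indented block. One alternative, assuming
# Python 3.9+ for ast.unparse(), is to parse the snippet with the standard-library
# `ast` module and wrap the last top-level statement only when it is a bare
# expression. The helper name `wrap_last_expression_in_print` is hypothetical.
import ast


def wrap_last_expression_in_print(code):
    """Return `code` with its final top-level expression wrapped in print().

    Hypothetical helper; requires Python 3.9+ and does not preserve comments.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Leave unparsable code as-is; running it will report the SyntaxError.
        return code
    if not tree.body or not isinstance(tree.body[-1], ast.Expr):
        # Empty code, or the last statement is not a bare expression
        # (an assignment, a loop, ...): nothing to print.
        return code
    last = tree.body[-1].value
    if (isinstance(last, ast.Call) and isinstance(last.func, ast.Name)
            and last.func.id == "print"):
        # The snippet already ends with a print() call.
        return code
    # Replace the final expression statement with print(<expression>).
    tree.body[-1] = ast.Expr(
        value=ast.Call(func=ast.Name(id="print", ctx=ast.Load()),
                       args=[last], keywords=[])
    )
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)

# Usage: wrap_last_expression_in_print("x = 2\nx + 3") ends with the line
# "print(x + 3)", while "x = 5" and "print(1)" are returned unchanged.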


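def execute_completion(executor, completion, return_status, last_code_block):
    """Execute the Python code blocks found in `completion`.

    Only the last block is run when `last_code_block` is true. Returns the
    output of the last executed block, as `(output, success)` when
    `return_status` is true and as a bare string otherwise.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if len(executions) == 0:
        # Nothing to execute: return the completion itself.
        return (completion, False) if return_status else completion
    if last_code_block:
        executions = [executions[-1]]
    outputs = []
    successes = []
    for code in executions:
        success = False
        # Refuse code that mentions a blocked module (a simple substring check)
        # and skip execution of this block entirely.
        blocked = [lib for lib in ("subprocess", "venv") if lib in code]
        if blocked:
            outputs.append(f"{blocked[0]} is not allowed")
            successes.append(success)
            continue
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)
    output = str(outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output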


def postprocess_completion(text, return_status, last_code_block):
    executor = PythonREPL()
    result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block)
    del executor
    return result
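

# Sketch of an alternative timeout strategy (an assumption, not the original
# design): PythonREPL enforces its timeout by running subprocess.run() inside a
# multiprocessing.Pool worker and calling pool.terminate(), which stops the
# worker but can leave the `python3` child it started still running.
# subprocess.run() also accepts a `timeout` argument: when it expires, the child
# is killed and subprocess.TimeoutExpired is raised, so no pool is needed.
# The helper name `run_snippet_with_timeout` is hypothetical, and unlike
# PythonREPL it neither prepends imports nor trims the traceback.
import os
import subprocess
import sys
import tempfile


def run_snippet_with_timeout(code, timeout=5):
    """Run `code` in a fresh interpreter; return (success, stdout or stderr)."""
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, "tmp.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(code)
        try:
            result = subprocess.run(
                [sys.executable, path],  # same interpreter as the caller
                capture_output=True,
                text=True,
                timeout=timeout,
                check=False,
            )
        except subprocess.TimeoutExpired:
            # subprocess.run() has already killed and reaped the child here.
            return False, f"Timed out after {timeout} seconds."
    if result.returncode == 0:
        return True, result.stdout.strip()
    return False, result.stderr.strip()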


def get_majority_vote(answers):
    """Return the most frequent answer (ties go to the earliest), or 0 if there are none."""
    if not len(answers):
        return 0
    c = Counter(answers)
    value, _ = c.most_common()[0]
    return value
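

# Illustrative variant of get_majority_vote (the name and the convention that
# failed answer extractions are represented as None are assumptions): drop None
# values before counting and also report how many votes the winner received.
# Counter.most_common() orders equal counts by first occurrence, so the
# earliest answer wins a tie.
from collections import Counter


def get_majority_vote_with_count(answers):
    """Return (most_common_answer, votes), or (None, 0) if no answer is valid."""
    valid = [a for a in answers if a is not None]
    if not valid:
        return None, 0
    value, votes = Counter(valid).most_common(1)[0]
    return value, votes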