Makima57 commited on
Commit
945bc2c
1 Parent(s): 444d1cb

Upload codeexecutor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. codeexecutor.py +125 -0
codeexecutor.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import re
4
+ import subprocess
5
+ import tempfile
6
+ import multiprocessing
7
+ from collections import Counter
8
+ from contextlib import contextmanager
9
+ from dataclasses import dataclass
10
+
11
+
12
+ class PythonREPL:
13
+ def __init__(self, timeout=5):
14
+ self.timeout = timeout
15
+
16
+ @staticmethod
17
+ def _run_code(temp_file_path):
18
+ result = subprocess.run(
19
+ ["python3", temp_file_path],
20
+ capture_output=True,
21
+ check=False,
22
+ text=True
23
+ )
24
+ if result.returncode == 0:
25
+ return True, result.stdout.strip()
26
+ else:
27
+ error_msg = result.stderr.strip()
28
+ msgs = error_msg.split("
29
+ ")
30
+ new_msgs = []
31
+ want_next = False
32
+ for m in msgs:
33
+ if "Traceback" in m:
34
+ new_msgs.append(m)
35
+ elif m == msgs[-1]:
36
+ new_msgs.append(m)
37
+ elif temp_file_path in m:
38
+ st = m.index('"/') + 1 if '"/' in m else 0
39
+ ed = m.index(temp_file_path) + 1 if temp_file_path in m else None
40
+ clr = m[st:ed] if not ed else m[st:]
41
+ m = m.replace(clr, "")
42
+ new_msgs.append(m)
43
+ want_next = True
44
+ elif want_next:
45
+ new_msgs.append(m)
46
+ want_next = False
47
+ return False, "
48
+ ".join(new_msgs).strip()
49
+
50
+ def __call__(self, query):
51
+ query = "import math
52
+ import numpy as np
53
+ import sympy as sp
54
+ " + query
55
+ query = query.strip().split("
56
+ ")
57
+ if "print(" not in query[-1]:
58
+ if "#" in query[-1]:
59
+ query[-1] = query[-1].split("#")[0]
60
+ query[-1] = "print(" + query[-1] + ")"
61
+ query = "
62
+ ".join(query)
63
+
64
+ with tempfile.TemporaryDirectory() as temp_dir:
65
+ temp_file_path = os.path.join(temp_dir, "tmp.py")
66
+ with open(temp_file_path, "w", encoding="utf-8") as f:
67
+ f.write(query)
68
+
69
+ with multiprocessing.Pool(1) as pool:
70
+ result = pool.apply_async(self._run_code, (temp_file_path,))
71
+ try:
72
+ success, output = result.get(self.timeout)
73
+ except multiprocessing.TimeoutError:
74
+ pool.terminate()
75
+ return False, f"Timed out after {self.timeout} seconds."
76
+ return success, output
77
+
78
+
79
+ def execute_completion(executor, completion, return_status, last_code_block):
80
+ executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
81
+ if len(executions) == 0:
82
+ return completion, False if return_status else completion
83
+ if last_code_block:
84
+ executions = [executions[-1]]
85
+ outputs = []
86
+ successes = []
87
+ for code in executions:
88
+ success = False
89
+ for lib in ("subprocess", "venv"):
90
+ if lib in code:
91
+ output = f"{lib} is not allowed"
92
+ outputs.append(output)
93
+ successes.append(success)
94
+ continue
95
+ try:
96
+ success, output = executor(code)
97
+ except TimeoutError as e:
98
+ print("Code timed out")
99
+ output = e
100
+ if not success and not return_status:
101
+ output = ""
102
+ outputs.append(output)
103
+ successes.append(success)
104
+ output = str(outputs[-1]).strip()
105
+ success = successes[-1]
106
+ if return_status:
107
+ return output, success
108
+ return output
109
+
110
+
111
+ def postprocess_completion(text, return_status, last_code_block):
112
+ executor = PythonREPL()
113
+ result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block)
114
+ del executor
115
+ return result
116
+
117
+
118
+ def get_majority_vote(answers):
119
+ if not len(answers):
120
+ return 0
121
+ c = Counter(answers)
122
+ value, _ = c.most_common()[0]
123
+ return value
124
+
125
+