MathSolver / codeexecutor.py
Makima57's picture
Upload codeexecutor.py with huggingface_hub
945bc2c verified
raw
history blame
3.87 kB
import os
import re
import subprocess
import tempfile
import multiprocessing
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
class PythonREPL:
    """Run short Python snippets in a separate ``python3`` process.

    Calling an instance with a source string returns ``(success, output)``:
    the stripped stdout when the interpreter exits with status 0, otherwise a
    condensed traceback, or a timeout message if the time limit is exceeded.
    """

    def __init__(self, timeout=5):
        """Remember the per-snippet time limit (seconds)."""
        self.timeout = timeout

    @staticmethod
    def _condense_traceback(stderr_text, temp_file_path):
        """Reduce a traceback to its header, script frames and final line.

        Kept lines are: any line containing ``Traceback``, the last line
        (normally the exception message), every frame line that mentions
        ``temp_file_path`` and the single line following such a frame line.
        The temporary directory prefix is removed from frame lines so that
        they read ``File "tmp.py", line N, ...``.
        """
        lines = stderr_text.split("\n")
        last_index = len(lines) - 1

        # Only strip a prefix when the path really has a directory part;
        # otherwise we would be deleting every path separator in the line.
        directory = os.path.dirname(temp_file_path)
        dir_prefix = directory + os.sep if directory else ""

        kept = []
        want_next = False
        for index, line in enumerate(lines):
            if "Traceback" in line:
                kept.append(line)
            elif index == last_index:
                # Compare by position, not by value, so that an earlier line
                # that merely equals the final one is not treated as "last".
                kept.append(line)
            elif temp_file_path in line:
                if dir_prefix:
                    line = line.replace(dir_prefix, "")
                kept.append(line)
                want_next = True
            elif want_next:
                kept.append(line)
                want_next = False
        return "\n".join(kept).strip()

    @staticmethod
    def _run_code(temp_file_path, timeout=None):
        """Execute the script at ``temp_file_path`` with ``python3``.

        Args:
            temp_file_path: Path of the script to run.
            timeout: Optional limit in seconds, forwarded to
                ``subprocess.run``. ``None`` means no limit.

        Returns:
            ``(True, stdout)`` on exit status 0, otherwise
            ``(False, condensed_traceback)``.

        Raises:
            subprocess.TimeoutExpired: if ``timeout`` elapses; the child
                process has been killed by ``subprocess.run`` at that point.
        """
        result = subprocess.run(
            ["python3", temp_file_path],
            capture_output=True,
            check=False,
            text=True,
            timeout=timeout,
        )
        if result.returncode == 0:
            return True, result.stdout.strip()
        return False, PythonREPL._condense_traceback(
            result.stderr.strip(), temp_file_path
        )

    def __call__(self, query):
        """Run ``query`` and return ``(success, output)``.

        ``math``, ``numpy as np`` and ``sympy as sp`` imports are prepended.
        If the last line contains no ``print(`` call, any trailing ``#``
        comment is removed from it and the line is wrapped in ``print(...)``
        so that its value shows up on stdout.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        source = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(source)
            # The timeout is enforced by subprocess.run itself, which kills
            # the interpreter on expiry instead of leaving it running.
            try:
                return self._run_code(temp_file_path, timeout=self.timeout)
            except subprocess.TimeoutExpired:
                return False, f"Timed out after {self.timeout} seconds."
def execute_completion(executor, completion, return_status, last_code_block):
    """Run the ```python fenced blocks found in ``completion``.

    Args:
        executor: Callable taking a code string and returning
            ``(success, output)``.
        completion: Text that may contain ```python ... ``` fenced blocks.
        return_status: When true, return ``(output, success)``; otherwise
            return only the output string.
        last_code_block: When true, only the final fenced block is run.

    Returns:
        The stripped output of the last processed block (and its success flag
        when ``return_status`` is true). If no fenced block is present, the
        original ``completion`` is returned (paired with ``False`` when
        ``return_status`` is true). When ``return_status`` is false, the
        output of a failed execution is replaced by an empty string.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if not executions:
        # Parenthesised so the conditional picks between the tuple and the
        # bare string, rather than only choosing the tuple's second element.
        return (completion, False) if return_status else completion
    if last_code_block:
        executions = [executions[-1]]

    outputs = []
    successes = []
    for code in executions:
        success = False

        # Refuse blocks that mention a disallowed library; they must not
        # reach the executor at all.
        banned = next(
            (lib for lib in ("subprocess", "venv") if lib in code), None
        )
        if banned is not None:
            outputs.append(f"{banned} is not allowed")
            successes.append(success)
            continue

        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)

    output = str(outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output
def postprocess_completion(text, return_status, last_code_block):
    """Execute the code blocks in ``text`` with a fresh ``PythonREPL``.

    A new executor is created for each call and handed to
    ``execute_completion`` together with the two flags; whatever that
    function returns is passed straight back to the caller.
    """
    repl = PythonREPL()
    outcome = execute_completion(
        repl,
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )
    # Drop the local reference to the executor before returning.
    del repl
    return outcome
def get_majority_vote(answers):
    """Return the most frequent element of ``answers``.

    An empty collection yields ``0``. When several values share the highest
    count, the one that was encountered first wins.
    """
    if len(answers) == 0:
        return 0
    tally = Counter(answers)
    (winner, _count), = tally.most_common(1)
    return winner