Spaces:
Sleeping
Sleeping
File size: 4,010 Bytes
753cb33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import os
import re
import signal
import subprocess
import tempfile
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
class PythonREPL:
    """Run a snippet of Python source in a separate ``python3`` process.

    The instance is callable: ``repl(code)`` writes ``code`` to a temporary
    file, executes it, and returns ``(success, text)`` where ``text`` is the
    stripped stdout on success or a condensed traceback on failure.
    """

    def __init__(self, timeout=5):
        """Store ``timeout``, the per-run time limit in seconds."""
        self.timeout = timeout

    @contextmanager
    def time_limit(self, seconds):
        """Raise ``TimeoutError`` if the ``with`` body runs longer than ``seconds``.

        Implemented with ``SIGALRM``; the handler that was installed before
        entering the block is put back on exit so the process-wide signal
        state is not left modified.
        """

        def signal_handler(*_):
            raise TimeoutError(f"Timed out after {seconds} seconds.")

        previous_handler = signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(seconds)
        try:
            yield
        finally:
            signal.alarm(0)
            # ``signal.signal`` returns None for handlers not installed from
            # Python; those cannot be passed back in, so leave them alone.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    def __call__(self, query):
        """Execute ``query`` and return ``(success, output_or_error)``.

        ``math``, ``numpy`` (as ``np``) and ``sympy`` (as ``sp``) imports are
        prepended to the code. If the last line does not contain ``print(``,
        anything after a ``#`` on it is dropped and the remainder is wrapped
        in ``print(...)`` so the value of a trailing expression is shown.

        Raises:
            TimeoutError: if the subprocess runs longer than ``self.timeout``.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        source = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(source)
            # The subprocess timeout is the only timer used here: it kills the
            # child itself and, unlike SIGALRM, also works off the main thread.
            try:
                result = subprocess.run(
                    ["python3", temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired as exc:
                # Surface the builtin exception type with a stable message.
                raise TimeoutError(
                    f"Timed out after {self.timeout} seconds."
                ) from exc

        if result.returncode == 0:
            return True, result.stdout.strip()
        return False, self._clean_traceback(result.stderr.strip(), temp_file_path)

    @staticmethod
    def _clean_traceback(error_msg, temp_file_path):
        """Condense a traceback to the lines that concern the executed script.

        Keeps any line containing ``Traceback``, every frame line that
        mentions ``temp_file_path`` (with the temporary directory prefix
        removed, so it reads ``File "tmp.py", line N``) together with the line
        right after it, and the final line of the message.
        """
        msgs = error_msg.split("\n")
        last_index = len(msgs) - 1
        dir_prefix = os.path.dirname(temp_file_path) + os.sep
        kept = []
        want_next = False
        for index, m in enumerate(msgs):
            if "Traceback" in m:
                kept.append(m)
            elif index == last_index:
                # Compare by position: an earlier line that merely has the
                # same text as the last one must not be treated as the last.
                kept.append(m)
            elif temp_file_path in m:
                # Strip only the directory so the line number is preserved.
                kept.append(m.replace(dir_prefix, ""))
                want_next = True
            elif want_next:
                kept.append(m)
                want_next = False
        return "\n".join(kept).strip()
def execute_completion(executor, completion, return_status, last_code_block):
    """Run the ```` ```python ```` blocks found in ``completion`` and report the last result.

    Args:
        executor: callable taking a code string and returning
            ``(success, output)``; it may raise ``TimeoutError``.
        completion: text that may contain fenced ```` ```python ```` blocks.
        return_status: when true, return ``(output, success)``; otherwise
            return only ``output``.
        last_code_block: when true, only the final code block is run.

    Returns:
        If no code block is present, ``completion`` itself (paired with
        ``False`` when ``return_status`` is true). Otherwise the stringified,
        stripped output of the last processed block (and its success flag
        when ``return_status`` is true). When ``return_status`` is false the
        output of a failed execution is replaced by an empty string.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if not executions:
        # Parenthesised so the conditional selects between the tuple and the
        # bare string, rather than only choosing the tuple's second element.
        return (completion, False) if return_status else completion
    if last_code_block:
        executions = [executions[-1]]

    outputs = []
    successes = []
    for code in executions:
        success = False
        blocked = next(
            (lib for lib in ("subprocess", "venv") if lib in code), None
        )
        if blocked is not None:
            # Refuse to run the block at all; record why and move on.
            outputs.append(f"{blocked} is not allowed")
            successes.append(success)
            continue
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)

    output = str(outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output
def postprocess_completion(text, return_status, last_code_block):
    """Process ``text`` with a freshly created ``PythonREPL``.

    Builds a new ``PythonREPL`` with its default settings, hands it to
    ``execute_completion`` together with ``text`` and the two flags, and
    returns whatever that call returns. The REPL object is discarded
    before returning.
    """
    repl = PythonREPL()
    outcome = execute_completion(
        repl,
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )
    # Drop the local reference explicitly once the run has finished.
    del repl
    return outcome
def get_majority_vote(answers):
    """Return the most frequent item of ``answers``, or ``0`` if it is empty.

    ``answers`` may be any iterable of hashable items, including one-shot
    iterators. When several items share the highest count, the one that
    appears first in ``answers`` is returned.
    """
    counts = Counter(answers)
    # Test the counter rather than ``len(answers)`` so that iterables
    # without a length are accepted too.
    if not counts:
        return 0
    value, _ = counts.most_common()[0]
    return value