lvwerra HF staff commited on
Commit
5b832b6
1 Parent(s): f5a9755

Update python_interpreter_tool.py

Browse files
Files changed (1) hide show
  1. python_interpreter_tool.py +54 -82
python_interpreter_tool.py CHANGED
@@ -1,12 +1,9 @@
1
- import io
2
- import sys
3
- import platform
4
- import faulthandler
5
- from contextlib import contextmanager
6
- import signal
7
-
8
  from transformers import Tool
9
-
 
 
10
 
11
  class PythonInterpreter(Tool):
12
  name = "python_interpreter_tool"
@@ -14,75 +11,41 @@ class PythonInterpreter(Tool):
14
 
15
  inputs = ["text"]
16
  outputs = ["text"]
17
- timeout = 10.0
18
 
19
  def __call__(self, task: str):
20
- import os
21
- import shutil
22
 
23
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
24
  return "Can't execute code, set code evaluation flag for PythonInterpreter"
25
 
26
- rmtree = shutil.rmtree
27
- rmdir = os.rmdir
28
- chdir = os.chdir
29
-
30
- #reliability_guard()
31
- try:
32
- exec_globals = {}
33
- with capture_stdout() as output:
34
- #with time_limit(self.timeout):
35
- exec(task, exec_globals)
36
-
37
- captured_output = output.getvalue().strip()
38
- except Exception as e:
39
- captured_output = str(e)
40
-
41
- # Needed for cleaning up.
42
- shutil.rmtree = rmtree
43
- os.rmdir = rmdir
44
- os.chdir = chdir
45
-
46
- return captured_output
47
-
48
-
49
- @contextmanager
50
- def capture_stdout():
51
- output_buffer = io.StringIO()
52
- original_stdout = sys.stdout
53
- try:
54
- sys.stdout = output_buffer
55
- yield output_buffer
56
- finally:
57
- sys.stdout = original_stdout
58
 
 
 
 
 
59
 
60
- @contextmanager
61
- def time_limit(seconds):
62
- def signal_handler(signum, frame):
63
- raise TimeoutException("Timed out!")
64
-
65
- signal.setitimer(signal.ITIMER_REAL, seconds)
66
- signal.signal(signal.SIGALRM, signal_handler)
67
- try:
68
- yield
69
- finally:
70
- signal.setitimer(signal.ITIMER_REAL, 0)
71
-
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def reliability_guard(maximum_memory_bytes=None):
74
- """
75
- This disables various destructive functions and prevents the generated code
76
- from interfering with the test (e.g. fork bomb, killing other processes,
77
- removing filesystem files, etc.)
78
-
79
- WARNING
80
- This function is NOT a security sandbox. Untrusted code, including, model-
81
- generated code, should not be blindly executed outside of one. See the
82
- Codex paper for more information about OpenAI's code sandbox, and proceed
83
- with caution.
84
- """
85
-
86
  if maximum_memory_bytes is not None:
87
  import resource
88
 
@@ -91,6 +54,8 @@ def reliability_guard(maximum_memory_bytes=None):
91
  if not platform.uname().system == "Darwin":
92
  resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
93
 
 
 
94
  faulthandler.disable()
95
 
96
  import builtins
@@ -99,13 +64,12 @@ def reliability_guard(maximum_memory_bytes=None):
99
  builtins.quit = None
100
 
101
  import os
102
- try:
103
- os.environ["OMP_NUM_THREADS"] = "1"
104
- except:
105
- pass
106
  os.kill = None
107
  os.system = None
108
- os.putenv = None
109
  os.remove = None
110
  os.removedirs = None
111
  os.rmdir = None
@@ -141,7 +105,7 @@ def reliability_guard(maximum_memory_bytes=None):
141
 
142
  subprocess.Popen = None # type: ignore
143
 
144
- #__builtins__["help"] = None
145
 
146
  import sys
147
 
@@ -150,13 +114,21 @@ def reliability_guard(maximum_memory_bytes=None):
150
  sys.modules["resource"] = None
151
  sys.modules["psutil"] = None
152
  sys.modules["tkinter"] = None
153
-
154
- class TimeoutException(Exception):
155
- pass
156
-
157
  """
158
- import os
159
- tool = PythonInterpreter()
160
- print(tool("import os; os.getcwd()"))
161
- print(os.getcwd())
162
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import subprocess
 
 
 
 
 
3
  from transformers import Tool
4
+ import sys
5
+ import os
6
+ import re
7
 
8
  class PythonInterpreter(Tool):
9
  name = "python_interpreter_tool"
 
11
 
12
  inputs = ["text"]
13
  outputs = ["text"]
14
+ timeout = 1.0
15
 
16
  def __call__(self, task: str):
 
 
17
 
18
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
19
  return "Can't execute code, set code evaluation flag for PythonInterpreter"
20
 
21
+ with tempfile.TemporaryDirectory() as temp_dir:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ code = "from safeguard import reliability_guard\nreliability_guard()\n" + task
24
+ code_file = os.path.join(temp_dir, 'code.py')
25
+ with open(code_file, 'w') as f:
26
+ f.write(code)
27
 
28
+ safeguard_file = os.path.join(temp_dir, 'safeguard.py')
29
+ with open(safeguard_file, 'w') as f:
30
+ f.write(reliability_guard_code)
 
 
 
 
 
 
 
 
 
31
 
32
+ cmd = f"python {code_file}"
33
+
34
+ try:
35
+ output = subprocess.check_output(cmd.split(), stderr=subprocess.STDOUT, timeout=self.timeout)
36
+ output_text = output.decode(sys.stdout.encoding).strip()
37
+ except subprocess.TimeoutExpired:
38
+ output_text = "Code execution timed out."
39
+ except subprocess.CalledProcessError as e:
40
+ output_text = e.output.decode(sys.stdout.encoding).strip()
41
+ output_text = output_text.replace(temp_dir + "/", "/tmp/")
42
+ output_text = fix_stacktrace_linenumber(output_text)
43
+ output_text = output_text.replace(temp_dir + "/", "/tmp/")
44
+ output_text = fix_warning_linenumber(output_text)
45
+ return output_text
46
+
47
+ reliability_guard_code = """\
48
  def reliability_guard(maximum_memory_bytes=None):
 
 
 
 
 
 
 
 
 
 
 
 
49
  if maximum_memory_bytes is not None:
50
  import resource
51
 
 
54
  if not platform.uname().system == "Darwin":
55
  resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
56
 
57
+ import faulthandler
58
+
59
  faulthandler.disable()
60
 
61
  import builtins
 
64
  builtins.quit = None
65
 
66
  import os
67
+
68
+ os.environ["OMP_NUM_THREADS"] = "1"
69
+
 
70
  os.kill = None
71
  os.system = None
72
+ # os.putenv = None # breaks e.g. numpy
73
  os.remove = None
74
  os.removedirs = None
75
  os.rmdir = None
 
105
 
106
  subprocess.Popen = None # type: ignore
107
 
108
+ __builtins__["help"] = None
109
 
110
  import sys
111
 
 
114
  sys.modules["resource"] = None
115
  sys.modules["psutil"] = None
116
  sys.modules["tkinter"] = None
 
 
 
 
117
  """
118
+
119
+ def fix_stacktrace_linenumber(text):
120
+ start = ' File "/tmp/code.py", line '
121
+ lines = text.split("\n")
122
+ fixed_lines = []
123
+ for line in lines:
124
+ if line.startswith(start):
125
+ number, end = line.split(start)[1].split(",")
126
+ new_number = int(number)-2
127
+ line = start + str(new_number) + end
128
+ fixed_lines.append(line)
129
+ return "\n".join(fixed_lines)
130
+
131
+ def fix_warning_linenumber(string):
132
+ pattern = r'code\.py:(\d+):'
133
+ replaced_string = re.sub(pattern, lambda match: 'code.py:' + str(int(match.group(1))-2) + ':', string)
134
+ return replaced_string