databoysu commited on
Commit
985e10f
·
1 Parent(s): a1e4e94

protect sandbox

Browse files
__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/__pycache__/__init__.cpython-312.pyc and b/__pycache__/__init__.cpython-312.pyc differ
 
__pycache__/environment.cpython-312.pyc CHANGED
Binary files a/__pycache__/environment.cpython-312.pyc and b/__pycache__/environment.cpython-312.pyc differ
 
__pycache__/sandbox.cpython-312.pyc CHANGED
Binary files a/__pycache__/sandbox.cpython-312.pyc and b/__pycache__/sandbox.cpython-312.pyc differ
 
context.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  from typing import List, Optional
6
 
7
  WINDOW_LINES: int = 10
@@ -9,6 +10,37 @@ WINDOW_LINES: int = 10
9
  MAX_CONTEXT_CHARS: int = 2_000
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def get_localized_context(
13
  code_lines: List[str],
14
  anchor_line: Optional[int],
 
2
 
3
  from __future__ import annotations
4
 
5
+ import re
6
  from typing import List, Optional
7
 
8
  WINDOW_LINES: int = 10
 
10
  MAX_CONTEXT_CHARS: int = 2_000
11
 
12
 
13
+ _TRACEBACK_FILE_LINE_RE = re.compile(r'File "([^"]+)", line (\d+)')
14
+ _SYNTAX_LINE_RE = re.compile(r"SyntaxError at line (\d+)")
15
+
16
+
17
+ def extract_error_line(traceback_str: str) -> Optional[int]:
18
+ """
19
+ Extract the most relevant crashing line number from sandbox output.
20
+
21
+ Preference order:
22
+ 1) Last frame pointing to agent code pseudo-files (<agent_code>, <string>).
23
+ 2) Last traceback frame line number.
24
+ 3) "SyntaxError at line N" fallback.
25
+ """
26
+ if not traceback_str:
27
+ return None
28
+
29
+ matches = _TRACEBACK_FILE_LINE_RE.findall(traceback_str)
30
+ if matches:
31
+ preferred_files = {"<agent_code>", "<string>"}
32
+ for file_name, line_str in reversed(matches):
33
+ if file_name in preferred_files:
34
+ return int(line_str)
35
+ return int(matches[-1][1])
36
+
37
+ syntax_match = _SYNTAX_LINE_RE.search(traceback_str)
38
+ if syntax_match:
39
+ return int(syntax_match.group(1))
40
+
41
+ return None
42
+
43
+
44
  def get_localized_context(
45
  code_lines: List[str],
46
  anchor_line: Optional[int],
environment.py CHANGED
@@ -7,12 +7,12 @@ import uuid
7
  from typing import Any, Dict, List, Optional, Tuple
8
 
9
  try:
10
- from .context import get_localized_context
11
  from .models import CodeAction, CodeObservation, TestResult
12
  from .sandbox import check_syntax, run_code_with_tests
13
  from .tasks import ALL_TASKS, TASKS_BY_DIFFICULTY
14
  except ImportError:
15
- from context import get_localized_context
16
  from models import CodeAction, CodeObservation, TestResult
17
  from sandbox import check_syntax, run_code_with_tests
18
  from tasks import ALL_TASKS, TASKS_BY_DIFFICULTY
@@ -147,6 +147,7 @@ class TraceFixRLGym:
147
  self._original_code: List[str] = []
148
  self._edit_history: List[List[str]] = []
149
  self.training_step: int = 0
 
150
 
151
 
152
  def _sample_task(self, task_override=None) -> Dict[str, Any]:
@@ -217,6 +218,7 @@ class TraceFixRLGym:
217
  self._edit_history = []
218
  self._last_action: Optional[str] = None
219
  self._consecutive_count: int = 0
 
220
 
221
  obs = self._build_observation(reward=0.0)
222
 
@@ -336,11 +338,13 @@ class TraceFixRLGym:
336
 
337
  if syntax_err:
338
  reward += R_SYNTAX_ERROR
 
339
  else:
340
  current_pass = sum(1 for t in results if t.passed)
341
  new_passes = max(0, current_pass - self._prev_pass_count)
342
  reward += new_passes * R_PER_NEW_PASS
343
  self._prev_pass_count = current_pass
 
344
 
345
  return reward
346
 
@@ -483,10 +487,18 @@ class TraceFixRLGym:
483
  def _build_observation(self, reward: float) -> CodeObservation:
484
  syntax_valid, _ = check_syntax(self._source())
485
 
486
- localized = get_localized_context(self._code_lines, self._last_edited_line)
 
 
 
 
 
 
487
 
488
  return CodeObservation(
489
- code_lines = list(self._code_lines),
 
 
490
  localized_context = localized,
491
  last_execution_output = self._last_output,
492
  syntax_error = not syntax_valid,
 
7
  from typing import Any, Dict, List, Optional, Tuple
8
 
9
  try:
10
+ from .context import extract_error_line, get_localized_context
11
  from .models import CodeAction, CodeObservation, TestResult
12
  from .sandbox import check_syntax, run_code_with_tests
13
  from .tasks import ALL_TASKS, TASKS_BY_DIFFICULTY
14
  except ImportError:
15
+ from context import extract_error_line, get_localized_context
16
  from models import CodeAction, CodeObservation, TestResult
17
  from sandbox import check_syntax, run_code_with_tests
18
  from tasks import ALL_TASKS, TASKS_BY_DIFFICULTY
 
147
  self._original_code: List[str] = []
148
  self._edit_history: List[List[str]] = []
149
  self.training_step: int = 0
150
+ self._last_run_all_passed: bool = False
151
 
152
 
153
  def _sample_task(self, task_override=None) -> Dict[str, Any]:
 
218
  self._edit_history = []
219
  self._last_action: Optional[str] = None
220
  self._consecutive_count: int = 0
221
+ self._last_run_all_passed = False
222
 
223
  obs = self._build_observation(reward=0.0)
224
 
 
338
 
339
  if syntax_err:
340
  reward += R_SYNTAX_ERROR
341
+ self._last_run_all_passed = False
342
  else:
343
  current_pass = sum(1 for t in results if t.passed)
344
  new_passes = max(0, current_pass - self._prev_pass_count)
345
  reward += new_passes * R_PER_NEW_PASS
346
  self._prev_pass_count = current_pass
347
+ self._last_run_all_passed = all(t.passed for t in results)
348
 
349
  return reward
350
 
 
487
  def _build_observation(self, reward: float) -> CodeObservation:
488
  syntax_valid, _ = check_syntax(self._source())
489
 
490
+ context_anchor = self._last_edited_line
491
+ if self._last_action == "RUN_TESTS" and not self._last_run_all_passed:
492
+ extracted_line = extract_error_line(self._last_output)
493
+ if extracted_line is not None:
494
+ context_anchor = extracted_line
495
+
496
+ localized = get_localized_context(self._code_lines, context_anchor)
497
 
498
  return CodeObservation(
499
+ code_dict = {
500
+ idx + 1: line for idx, line in enumerate(self._code_lines)
501
+ },
502
  localized_context = localized,
503
  last_execution_output = self._last_output,
504
  syntax_error = not syntax_valid,
inference.py CHANGED
@@ -134,7 +134,15 @@ def _extract_json(text: str) -> dict[str, Any]:
134
 
135
 
136
  def _build_observation_text(observation: Any) -> str:
137
- code_preview = "\n".join(observation.code_lines[:30]) if observation.code_lines else ""
 
 
 
 
 
 
 
 
138
  return (
139
  f"step_count={observation.step_count}\n"
140
  f"steps_remaining={observation.steps_remaining}\n"
 
134
 
135
 
136
  def _build_observation_text(observation: Any) -> str:
137
+ code_dict = getattr(observation, "code_dict", {}) or {}
138
+ sorted_items = sorted(
139
+ ((int(line_num), text) for line_num, text in code_dict.items()),
140
+ key=lambda x: x[0],
141
+ )
142
+ code_preview = "\n".join(
143
+ f"{line_num} | {text}"
144
+ for line_num, text in sorted_items[:30]
145
+ )
146
  return (
147
  f"step_count={observation.step_count}\n"
148
  f"steps_remaining={observation.steps_remaining}\n"
models.py CHANGED
@@ -95,7 +95,7 @@ class TestResult(BaseModel):
95
  class CodeObservation(Observation):
96
  """Full observation returned after each step."""
97
 
98
- code_lines: List[str] = Field(default_factory=list)
99
  localized_context: str = Field(default="")
100
  last_execution_output: str = Field(default="")
101
  syntax_error: bool = Field(default=False)
@@ -107,8 +107,9 @@ class CodeObservation(Observation):
107
 
108
  def render_code(self) -> str:
109
  """Render source with 1-indexed line numbers for prompts."""
110
- if not self.code_lines:
111
  return "<empty>"
112
  return "\n".join(
113
- f"{idx + 1:>3} | {line}" for idx, line in enumerate(self.code_lines)
 
114
  )
 
95
  class CodeObservation(Observation):
96
  """Full observation returned after each step."""
97
 
98
+ code_dict: Dict[int, str] = Field(default_factory=dict)
99
  localized_context: str = Field(default="")
100
  last_execution_output: str = Field(default="")
101
  syntax_error: bool = Field(default=False)
 
107
 
108
  def render_code(self) -> str:
109
  """Render source with 1-indexed line numbers for prompts."""
110
+ if not self.code_dict:
111
  return "<empty>"
112
  return "\n".join(
113
+ f"{line_num:>3} | {self.code_dict[line_num]}"
114
+ for line_num in sorted(self.code_dict.keys())
115
  )
sandbox.py CHANGED
@@ -32,11 +32,12 @@ import ast
32
  import io
33
  import inspect
34
  import multiprocessing
 
35
  import signal
36
  import sys
37
  import textwrap
38
  import traceback
39
- from typing import Any, Callable, Dict, List, Tuple
40
 
41
  try:
42
  from .models import TestResult
@@ -62,7 +63,20 @@ def _make_safe_stub(name: str) -> Callable:
62
  return _stub
63
 
64
 
65
- _SAFE_BUILTINS: Dict[str, Any] = {
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  "int": int, "float": float, "str": str, "bool": bool,
67
  "list": list, "dict": dict, "set": set, "tuple": tuple,
68
  "bytes": bytes, "bytearray": bytearray, "frozenset": frozenset,
@@ -104,6 +118,102 @@ _SAFE_BUILTINS: Dict[str, Any] = {
104
  }
105
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  def _tail_truncate(s: str, limit: int = MAX_OUTPUT_CHARS) -> str:
109
  """
@@ -150,9 +260,15 @@ def _worker(
150
  result_queue.put((_tail_truncate(err), [], True))
151
  return
152
 
153
- namespace: Dict[str, Any] = {"__builtins__": __builtins__}
154
  try:
155
- exec(code_obj, namespace) # noqa: S102
 
 
 
 
 
 
 
156
  except Exception: # noqa: BLE001
157
  tb = traceback.format_exc()
158
  sys.stdout, sys.stderr = old_stdout, old_stderr
@@ -162,15 +278,24 @@ def _worker(
162
  for test_src in test_sources:
163
  fn_name = "<unknown>"
164
  try:
165
- exec(test_src, namespace) # noqa: S102
 
 
 
 
 
 
 
 
 
166
 
167
  fn_name = [
168
  ln.split("(")[0].replace("def ", "").strip()
169
- for ln in test_src.splitlines()
170
  if ln.startswith("def ")
171
  ][-1]
172
 
173
- namespace[fn_name](namespace)
174
  test_results.append({"test_name": fn_name, "passed": True})
175
 
176
  except AssertionError as exc:
 
32
  import io
33
  import inspect
34
  import multiprocessing
35
+ import importlib
36
  import signal
37
  import sys
38
  import textwrap
39
  import traceback
40
+ from typing import Any, Callable, Dict, List, Set, Tuple
41
 
42
  try:
43
  from .models import TestResult
 
63
  return _stub
64
 
65
 
66
+ TEST_SUITE_ALLOWED_MODULES: Set[str] = {
67
+ "bisect",
68
+ "collections",
69
+ "functools",
70
+ "heapq",
71
+ "itertools",
72
+ "math",
73
+ "re",
74
+ "string",
75
+ "typing",
76
+ }
77
+
78
+
79
+ SAFE_BUILTINS: Dict[str, Any] = {
80
  "int": int, "float": float, "str": str, "bool": bool,
81
  "list": list, "dict": dict, "set": set, "tuple": tuple,
82
  "bytes": bytes, "bytearray": bytearray, "frozenset": frozenset,
 
118
  }
119
 
120
 
121
+ def _sanitize_imports_and_prepare_bindings(
122
+ source: str,
123
+ allowed_modules: Set[str],
124
+ ) -> Tuple[str, List[Tuple[str, str, str]], List[Tuple[str, str]]]:
125
+ """
126
+ Parse source, validate imports against allowlist, and strip import statements.
127
+
128
+ Returns
129
+ -------
130
+ sanitized_source:
131
+ Source with all import statements removed (so code never calls __import__).
132
+ module_alias_bindings:
133
+ List[(local_name, module_name, attribute_name)].
134
+ `attribute_name == ""` means bind module object itself.
135
+ modules_to_preload:
136
+ List[(root_name, import_target)] pairs.
137
+ """
138
+ tree = ast.parse(source)
139
+ blocked_lines: Set[int] = set()
140
+ module_alias_bindings: List[Tuple[str, str, str]] = []
141
+ modules_to_preload: Set[Tuple[str, str]] = set()
142
+
143
+ for node in ast.walk(tree):
144
+ if isinstance(node, ast.Import):
145
+ for alias in node.names:
146
+ module_name = alias.name
147
+ root_name = module_name.split(".")[0]
148
+ if root_name not in allowed_modules:
149
+ raise ImportError(
150
+ f"Import of '{root_name}' is not allowed in this sandbox."
151
+ )
152
+ local_name = alias.asname or root_name
153
+ module_alias_bindings.append((local_name, module_name, ""))
154
+ modules_to_preload.add((root_name, module_name))
155
+ if hasattr(node, "lineno") and hasattr(node, "end_lineno"):
156
+ blocked_lines.update(range(node.lineno, node.end_lineno + 1))
157
+
158
+ if isinstance(node, ast.ImportFrom):
159
+ if node.level != 0 or not node.module:
160
+ raise ImportError(
161
+ "Relative imports are not allowed in this sandbox."
162
+ )
163
+ module_name = node.module
164
+ root_name = module_name.split(".")[0]
165
+ if root_name not in allowed_modules:
166
+ raise ImportError(
167
+ f"Import of '{root_name}' is not allowed in this sandbox."
168
+ )
169
+ for alias in node.names:
170
+ if alias.name == "*":
171
+ raise ImportError(
172
+ "Wildcard imports are not allowed in this sandbox."
173
+ )
174
+ local_name = alias.asname or alias.name
175
+ module_alias_bindings.append((local_name, module_name, alias.name))
176
+ modules_to_preload.add((root_name, module_name))
177
+ if hasattr(node, "lineno") and hasattr(node, "end_lineno"):
178
+ blocked_lines.update(range(node.lineno, node.end_lineno + 1))
179
+
180
+ sanitized_lines = [
181
+ line
182
+ for i, line in enumerate(source.splitlines(), start=1)
183
+ if i not in blocked_lines
184
+ ]
185
+ return "\n".join(sanitized_lines), module_alias_bindings, sorted(modules_to_preload)
186
+
187
+
188
+ def _build_local_env_for_source(
189
+ source: str,
190
+ allowed_modules: Set[str],
191
+ ) -> Tuple[str, Dict[str, Any]]:
192
+ """
193
+ Build a local env with preloaded authorized modules/symbols.
194
+ """
195
+ sanitized_source, bindings, modules_to_preload = _sanitize_imports_and_prepare_bindings(
196
+ source, allowed_modules
197
+ )
198
+ local_env: Dict[str, Any] = {}
199
+ loaded_modules: Dict[str, Any] = {}
200
+
201
+ for root_name, import_target in modules_to_preload:
202
+ if import_target not in loaded_modules:
203
+ loaded_modules[import_target] = importlib.import_module(import_target)
204
+ if root_name not in loaded_modules:
205
+ loaded_modules[root_name] = importlib.import_module(root_name)
206
+
207
+ for local_name, module_name, attribute_name in bindings:
208
+ module_obj = loaded_modules[module_name]
209
+ if attribute_name:
210
+ local_env[local_name] = getattr(module_obj, attribute_name)
211
+ else:
212
+ local_env[local_name] = module_obj
213
+
214
+ return sanitized_source, local_env
215
+
216
+
217
 
218
  def _tail_truncate(s: str, limit: int = MAX_OUTPUT_CHARS) -> str:
219
  """
 
260
  result_queue.put((_tail_truncate(err), [], True))
261
  return
262
 
 
263
  try:
264
+ sanitized_source, local_env = _build_local_env_for_source(
265
+ source,
266
+ TEST_SUITE_ALLOWED_MODULES,
267
+ )
268
+ exec_env: Dict[str, Any] = {"__builtins__": SAFE_BUILTINS}
269
+ exec_env.update(local_env)
270
+ code_obj = compile(sanitized_source, "<agent_code>", "exec")
271
+ exec(code_obj, exec_env, exec_env) # noqa: S102
272
  except Exception: # noqa: BLE001
273
  tb = traceback.format_exc()
274
  sys.stdout, sys.stderr = old_stdout, old_stderr
 
278
  for test_src in test_sources:
279
  fn_name = "<unknown>"
280
  try:
281
+ sanitized_test_src, test_env_injections = _build_local_env_for_source(
282
+ test_src,
283
+ TEST_SUITE_ALLOWED_MODULES,
284
+ )
285
+ exec_env.update(test_env_injections)
286
+ exec(
287
+ compile(sanitized_test_src, "<sandbox_test>", "exec"),
288
+ exec_env,
289
+ exec_env,
290
+ ) # noqa: S102
291
 
292
  fn_name = [
293
  ln.split("(")[0].replace("def ", "").strip()
294
+ for ln in sanitized_test_src.splitlines()
295
  if ln.startswith("def ")
296
  ][-1]
297
 
298
+ exec_env[fn_name](exec_env)
299
  test_results.append({"test_name": fn_name, "passed": True})
300
 
301
  except AssertionError as exc:
uv.lock CHANGED
@@ -1599,7 +1599,7 @@ core = [
1599
  ]
1600
 
1601
  [[package]]
1602
- name = "openenv-python-debugging-gym"
1603
  version = "0.1.0"
1604
  source = { editable = "." }
1605
  dependencies = [
 
1599
  ]
1600
 
1601
  [[package]]
1602
+ name = "openenv-tracefix-rl"
1603
  version = "0.1.0"
1604
  source = { editable = "." }
1605
  dependencies = [