Spaces:
Sleeping
Sleeping
feat: enhance scenario authoring and caching mechanisms, update action submission terminology, and improve reward configuration for CyberSecurity_OWASP environment
be8eade | """Ephemeral generated app sandbox operations.""" | |
| from __future__ import annotations | |
| import difflib | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| try: | |
| from ..models import CyberSecurityOWASPState | |
| from ..safety import is_local_route | |
| from ..validators import is_path_allowed, simulate_request | |
| except ImportError: # pragma: no cover | |
| from models import CyberSecurityOWASPState | |
| from safety import is_local_route | |
| from validators import is_path_allowed, simulate_request | |
class AppSandbox:
    """Encapsulates all generated workspace reads, patches, and local requests.

    Every operation goes through the episode ``state``: the workspace location
    and editable file list are read from ``state.hidden_facts``, the route
    summary from ``state.visible_facts``, and patch / request bookkeeping is
    written back to ``state.patch_diff``, ``state.patch_attempt_count``,
    ``state.metrics`` and ``state.request_trace``.
    """

    def __init__(self, state: CyberSecurityOWASPState):
        self.state = state

    @property
    def workspace(self) -> Path:
        """Root directory of the generated app, from ``hidden_facts["workspace"]``.

        Exposed as a property because the rest of the class joins onto it
        directly (``self.workspace / rel``).
        """
        return Path(str(self.state.hidden_facts["workspace"]))

    def read_file(self, path: str) -> str:
        """Return the UTF-8 text of ``path`` after the path policy check.

        Raises:
            ValueError: if ``is_path_allowed`` rejects the path.
        """
        return self._resolve_path(path).read_text(encoding="utf-8")

    def search_code(self, query: str) -> str:
        """Case-insensitively search every editable file for ``query``.

        Returns:
            One ``"<relative path>:<line number>: <line>"`` entry per matching
            line, joined with newlines, or ``"No matches."`` when nothing hits.

        Raises:
            ValueError: if ``query`` is empty.
        """
        if not query:
            raise ValueError("query is required")
        needle = query.lower()  # hoisted: invariant across every line scanned
        root = self.workspace
        results: list[str] = []
        for rel in self.state.hidden_facts.get("editable_files", []):
            text = (root / rel).read_text(encoding="utf-8")
            results.extend(
                f"{rel}:{idx}: {line}"
                for idx, line in enumerate(text.splitlines(), start=1)
                if needle in line.lower()
            )
        return "\n".join(results) or "No matches."

    def patch_file(self, path: str, *, content: str | None = None, diff: str | None = None) -> dict[str, str]:
        """Overwrite or patch ``path`` and record the change on the state.

        Args:
            path: Workspace-relative file to modify; must pass the write policy.
            content: Full replacement text. Takes precedence over ``diff``.
            diff: Unified diff applied to the current file when ``content`` is
                not given.

        Returns:
            ``{"path": path, "diff": <unified diff of before vs. after>}``.

        Raises:
            ValueError: if the path is not writable, or if neither ``content``
                nor a non-blank ``diff`` is supplied.
        """
        target = self._resolve_path(path, write=True)
        before = target.read_text(encoding="utf-8")
        if content is not None:
            target.write_text(content, encoding="utf-8")
        else:
            self._apply_unified_diff(target, diff or "")
        after = target.read_text(encoding="utf-8")
        patch_diff = "".join(
            difflib.unified_diff(
                before.splitlines(True),
                after.splitlines(True),
                fromfile=path,
                tofile=path,
            )
        )
        # Only the most recent patch diff is kept; attempts are counted.
        self.state.patch_diff = patch_diff
        self.state.patch_attempt_count += 1
        files_touched = self.state.metrics.setdefault("files_touched", [])
        if path not in files_touched:
            files_touched.append(path)
        return {"path": path, "diff": patch_diff}

    def read_openapi(self) -> str:
        """Render the visible route summary as a minimal OpenAPI 3.1 JSON string.

        Each route contributes ``paths[<path>][<lowercased method>]`` with a
        single ``x-public`` boolean taken from the route's ``public`` flag.
        """
        routes = self.state.visible_facts.get("workspace_summary", {}).get("routes", [])
        paths: dict[str, Any] = {}
        for route in routes:
            paths.setdefault(route["path"], {})[route["method"].lower()] = {
                "x-public": bool(route.get("public", False))
            }
        return json.dumps(
            {
                "openapi": "3.1.0",
                "info": {"title": "Generated invoices app", "version": "0.1.0"},
                "paths": paths,
            },
            indent=2,
            sort_keys=True,
        )

    def send_local_request(self, method: str, path: str, user_id: str | None = None) -> dict[str, Any]:
        """Simulate one request against the generated app and trace it.

        Returns:
            The simulated response with a ``trace_id`` key added.

        Raises:
            ValueError: if ``path`` is not accepted by ``is_local_route``.
        """
        if not is_local_route(path):
            raise ValueError("send_local_request only accepts local route paths")
        response = simulate_request(self.state, method, path, user_id)
        trace_id = self._record_request_trace(
            method=method,
            path=path,
            user_id=user_id,
            # A missing or falsy status is recorded as 0.
            status=int(response.get("status", 0) or 0),
        )
        return {"trace_id": trace_id, **response}

    def compare_identities(
        self,
        method: str,
        path: str,
        first_user_id: str,
        second_user_id: str,
    ) -> dict[str, Any]:
        """Simulate the same request as two users and trace both outcomes.

        Returns:
            ``{"trace_id": ..., "first": <response>, "second": <response>}``.

        Raises:
            ValueError: if ``path`` is not accepted by ``is_local_route``.
        """
        if not is_local_route(path):
            raise ValueError("compare_identities only accepts local route paths")
        first = simulate_request(self.state, method, path, first_user_id)
        second = simulate_request(self.state, method, path, second_user_id)
        # Both identities share a single trace entry.
        trace_id = self._record_request_trace(
            method=method,
            path=path,
            user_id=first_user_id,
            status=int(first.get("status", 0) or 0),
            comparison_user_id=second_user_id,
            comparison_status=int(second.get("status", 0) or 0),
        )
        return {
            "trace_id": trace_id,
            "first": first,
            "second": second,
        }

    def _record_request_trace(
        self,
        *,
        method: str,
        path: str,
        user_id: str | None,
        status: int,
        comparison_user_id: str | None = None,
        comparison_status: int | None = None,
    ) -> str:
        """Append one entry to ``state.request_trace`` and return its id.

        The entry is flagged ``unauthorized_success`` when the path contains
        ``hidden_facts["other_invoice_id"]`` and either recorded identity equals
        ``hidden_facts["owner_user_id"]`` with status 200.

        Returns:
            A sequential id of the form ``req_001``, ``req_002``, ...
        """
        trace_id = f"req_{len(self.state.request_trace) + 1:03d}"
        hidden = self.state.hidden_facts
        raw_invoice_id = hidden.get("other_invoice_id")
        other_invoice_id = "" if raw_invoice_id is None else str(raw_invoice_id)
        owner_user_id = hidden.get("owner_user_id")
        # An absent id must not match: "" is a substring of every path, which
        # would otherwise flag every 200 response for the owner identity.
        touches_other_invoice = bool(other_invoice_id) and other_invoice_id in path

        def succeeded_unauthorized(identity: str | None, code: int | None) -> bool:
            return touches_other_invoice and identity == owner_user_id and code == 200

        unauthorized_success = succeeded_unauthorized(user_id, status)
        if comparison_user_id is not None and comparison_status is not None:
            unauthorized_success = unauthorized_success or succeeded_unauthorized(
                comparison_user_id, comparison_status
            )
        self.state.request_trace.append(
            {
                "trace_id": trace_id,
                "method": method.upper(),
                "path": path,
                "user_id": user_id,
                "status": status,
                "comparison_user_id": comparison_user_id,
                "comparison_status": comparison_status,
                "unauthorized_success": unauthorized_success,
            }
        )
        return trace_id

    def _resolve_path(self, path: str, *, write: bool = False) -> Path:
        """Map a caller-supplied path to a location inside the workspace.

        ``is_path_allowed`` returns ``(allowed, text)`` where ``text`` is the
        normalized path on success and the error message on rejection.

        Raises:
            ValueError: with that error message when the path is rejected.
        """
        allowed, normalized_or_error = is_path_allowed(self.state, path, write=write)
        if not allowed:
            raise ValueError(normalized_or_error)
        return self.workspace / normalized_or_error

    @staticmethod
    def _parse_old_range(header: str) -> tuple[int, int]:
        """Return ``(start, count)`` of the old-file range in an ``@@`` header.

        ``count`` defaults to 1 when omitted or unparsable, as in ``@@ -3 +3 @@``.

        Raises:
            ValueError: if the header has no readable ``-start`` field.
        """
        try:
            old_range = header.split()[1]
            start_text, _, count_text = old_range[1:].partition(",")
            start = int(start_text)
        except (IndexError, ValueError):
            raise ValueError(f"malformed hunk header: {header.strip()!r}") from None
        count = int(count_text) if count_text.isdigit() else 1
        return start, count

    def _apply_unified_diff(self, path: Path, diff: str) -> None:
        """Apply a unified diff to ``path`` in place.

        The parser is deliberately lenient: text before the first ``@@`` header
        is skipped, hunk line counts are not enforced, and context / removed
        lines are not compared against the file. The file is only written once
        the whole diff has been processed, so a failure leaves it untouched.

        Raises:
            ValueError: if ``diff`` is blank, a hunk header is malformed, or a
                context line runs past the end of the file.
        """
        if not diff.strip():
            raise ValueError("diff or content is required")
        if not diff.endswith(("\n", "\r")):
            # Otherwise a final "+" line would be glued onto the next original line.
            diff += "\n"
        original = path.read_text(encoding="utf-8").splitlines(True)
        output: list[str] = []
        old_index = 0
        lines = diff.splitlines(True)
        i = 0
        while i < len(lines):
            line = lines[i]
            if not line.startswith("@@"):
                i += 1
                continue
            old_start, old_count = self._parse_old_range(line)
            # A zero-length old range ("-N,0") names the line *after which* the
            # hunk inserts; any other range names the first line it touches.
            hunk_start = max(old_start if old_count == 0 else old_start - 1, 0)
            output.extend(original[old_index:hunk_start])
            old_index = hunk_start
            i += 1
            while i < len(lines) and not lines[i].startswith("@@"):
                hunk_line = lines[i]
                if hunk_line.startswith(" "):
                    if old_index >= len(original):
                        raise ValueError("diff context extends past the end of the file")
                    output.append(original[old_index])
                    old_index += 1
                elif hunk_line.startswith("-"):
                    old_index += 1
                elif hunk_line.startswith("+"):
                    output.append(hunk_line[1:])
                # Anything else (e.g. "\ No newline at end of file") is ignored.
                i += 1
        output.extend(original[old_index:])
        path.write_text("".join(output), encoding="utf-8")