Spaces:
Running
Running
| """ | |
| environment.py (Task 3 — Rule Checker) | |
| ----------------------------------------- | |
| OpenEnv-compliant RL environment. | |
| Episode setup | |
| ───────────── | |
| - A Solidity contract is selected that contains at least one function | |
| violating a known property. | |
| - The agent sees: contract description + the property in natural English. | |
| - The agent must identify which function breaks that property. | |
| Observation at reset | |
| ──────────────────── | |
| extra.property_english — the violated property in plain English | |
| extra.hint — instructions for the agent | |
| Actions & rewards | |
| ───────────────── | |
| list_functions -0.05 see all function names | |
| get_function_metadata -0.05 signature / visibility / modifiers / params | |
| get_function_code -0.10 full Solidity source of any function | |
| get_state_variables -0.05 list or inspect state variables | |
| get_call_graph -0.08 function call graph | |
| get_property_specification -0.03 formal pre/post-condition version of property | |
| submit_function terminal: +5.0 / +1.5 / -1.5 (ONE attempt) | |
| repeated_query -0.40 | |
| Difficulty: Easy | |
| The property text directly names the invariant broken; reading 2-3 functions | |
| should let most agents identify the culprit quickly. | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from typing import Any, Dict, List, Optional, Set | |
| from data.data_loader import load_contracts, sample_task3_episode | |
| from env.base_env import BaseEnv | |
| from env.schemas import ( | |
| Action, | |
| ActionType, | |
| Observation, | |
| Reward, | |
| ResetResult, | |
| StateResult, | |
| StepResult, | |
| ) | |
| from .grader import Task3Grader | |
| from server.tasks.task3 import actions | |
# Task identifier; the environment below reports it in ResetResult.info,
# StateResult.task_id and Observation.task_id.
TASK_ID = "task3_rule_checker"

# Action types belonging to this task.  The list is not referenced by the
# environment class in this module — presumably it is consumed elsewhere
# (TODO confirm against callers).
AVAILABLE_ACTIONS = [
    # Investigation queries.
    ActionType.LIST_FUNCTIONS,
    ActionType.GET_FUNCTION_METADATA,
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_STATE_VARIABLE,
    ActionType.GET_CALL_GRAPH,
    ActionType.GET_PROPERTY_SPECIFICATION,
    # Terminal submission (one attempt per episode, per the module docstring).
    ActionType.SUBMIT_FUNCTION,
]
class Task3Environment(BaseEnv):
    """Task 3: Rule Checker — identify the function that violates a given property.

    Each episode pairs one contract with one target function, both drawn by
    ``sample_task3_episode``.  Every action passed to :meth:`step` is routed by
    :meth:`_dispatch` to a handler in ``server.tasks.task3.actions``; the
    handler receives this environment instance as its first argument.

    NOTE(review): ``_done`` is never set to ``True`` and ``_is_repeated`` is
    never called inside this class — presumably the action handlers do both
    through the instance they are given; confirm against ``actions``.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """Load the contract corpus and initialise empty episode state.

        Args:
            contracts_path: Optional path forwarded to ``load_contracts``.
                When falsy, ``load_contracts()`` is called with no argument.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()
        self._max_steps = 20  # per-episode step budget enforced by step()

        # Episode state — initialised by reset()
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task3Grader] = None
        self._step_count: int = 0
        self._cum_reward: float = 0.0
        self._done: bool = False
        self._query_hist: List[str] = []
        self._seen: Set[str] = set()

    # ── OpenEnv interface ─────────────────────────────────────────────────────

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode and return its initial observation.

        Args:
            seed: When not ``None``, re-seeds the environment's private RNG
                before the episode is sampled.

        Returns:
            A ``ResetResult`` carrying the first observation and an ``info``
            dict with the task id.
        """
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_task3_episode(
            self._contracts, self._rng
        )
        self._grader = Task3Grader(
            target_function=self._target_fn,
            property_specification=self._target_fn.get("property_specification", ""),
            max_steps=self._max_steps,
        )

        # Clear all per-episode bookkeeping.
        self._step_count = 0
        self._cum_reward = 0.0
        self._done = False
        self._query_hist = []
        self._seen = set()

        obs = self._build_obs(
            last_action=None,
            last_result=(
                "New episode started.\n"
                f"Contract : {self._contract['contract_name']}\n\n"
                f"Property : {self._target_fn.get('property', '')}\n\n"
                "Find the function in this contract that violates the property above.\n"
                "Use list_functions then get_function_code to investigate.\n"
                f"Submit with submit_function, params={{\"function_name\": \"...\"}}.\n"
                "Only ONE submission allowed."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Apply one agent action and return the resulting transition.

        Args:
            action: The action to execute; its ``action_type`` selects the
                handler and its ``params`` are passed through to it.

        Returns:
            A ``StepResult`` with the new observation, the reward returned by
            the handler, the current ``done`` flag, and an ``info`` dict
            holding the step number and the cumulative reward.

        Raises:
            RuntimeError: If the episode is already done, or if ``_max_steps``
                steps have already been taken in this episode.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # ``_step_count`` is the number of steps already taken, so the budget
        # is exhausted as soon as it equals ``_max_steps``.  A strict ``>``
        # comparison here would let one extra step through.
        if self._step_count >= self._max_steps:
            raise RuntimeError(
                "Exceeded maximum number of steps allowed. "
                "Call reset() to start a new episode."
            )

        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cum_reward += reward.value
        # Keep only the first 100 characters of each result in the history.
        # NOTE(review): the separator character below looks mis-encoded
        # (probably an arrow originally) — confirm against the upstream file.
        self._query_hist.append(f"[{action.action_type}] β {result_text[:100]}")

        obs = self._build_obs(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={"step": self._step_count, "cumulative_reward": self._cum_reward},
        )

    def state(self) -> StateResult:
        """Return a snapshot of the episode state.

        The snapshot includes the target function's name and a copy of the
        query history, so mutating the returned list does not affect the
        environment.
        """
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cum_reward,
            done=self._done,
            query_history=list(self._query_hist),
        )

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation:
        """Assemble the observation for the current episode state.

        Args:
            last_action: Action type of the step just taken, or ``None`` for
                the observation produced by ``reset()``.
            last_result: Text shown to the agent as the action's result.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "property_english": self._target_fn.get("property", ""),
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Read the property, then inspect function code to find which one violates it. "
                    "Submit with: submit_function, params={'function_name': '<name>'}. "
                    "ONE submission per episode."
                ),
            },
        )

    def _qkey(self, at: str, params: Dict[str, Any]) -> str:
        """Build a string key for an (action type, params) pair.

        The params are sorted by item so that the key does not depend on the
        dict's insertion order.
        """
        return f"{at}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Return ``True`` if *key* was seen earlier this episode.

        A key that has not been seen is recorded as a side effect, so the
        next call with the same key returns ``True``.
        """
        if key in self._seen:
            return True
        self._seen.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route *action* to its handler and return ``(result_text, reward)``.

        Action types without a registered handler are passed to
        ``actions.unknown_action`` together with the offending type.
        """
        at = action.action_type
        params = action.params
        qkey = self._qkey(at, params)

        # Mapping from ActionType to handler function
        handlers = {
            ActionType.LIST_FUNCTIONS: actions.list_functions,
            ActionType.GET_FUNCTION_METADATA: actions.get_function_metadata,
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
            ActionType.GET_CALL_GRAPH: actions.get_call_graph,
            ActionType.GET_PROPERTY_SPECIFICATION: actions.get_property_specification,
            ActionType.SUBMIT_FUNCTION: actions.submit_function,
        }

        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)