SmartContractAudit / server /tasks /task3 /environment.py
ajaxwin
refactor: Reward clamping in graders
41a051f
"""
environment.py (Task 3 – Rule Checker)
-----------------------------------------
OpenEnv-compliant RL environment.
Episode setup
─────────────
- A Solidity contract is selected that contains at least one function
violating a known property.
- The agent sees: contract description + the property in natural English.
- The agent must identify which function breaks that property.
Observation at reset
────────────────────
extra.property_english – the violated property in plain English
extra.solidity_version – Solidity version from the contract metadata (may be empty)
extra.hint – instructions for the agent
Actions & rewards
─────────────────
list_functions -0.05 see all function names
get_function_metadata -0.05 signature / visibility / modifiers / params
get_function_code -0.10 full Solidity source of any function
get_state_variable -0.05 list or inspect state variables
get_call_graph -0.08 function call graph
get_property_specification -0.03 formal pre/post-condition version of property
submit_function terminal: +5.0 / +1.5 / -1.5 (ONE attempt)
repeated_query -0.40
Difficulty: Easy
The property text directly names the invariant broken; reading 2-3 functions
should let most agents identify the culprit quickly.
"""
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Set
from data.data_loader import load_contracts, sample_task3_episode
from env.base_env import BaseEnv
from env.schemas import (
Action,
ActionType,
Observation,
Reward,
ResetResult,
StateResult,
StepResult,
)
from .grader import Task3Grader
from server.tasks.task3 import actions
# Identifier reported in observations, reset info and state snapshots.
TASK_ID: str = "task3_rule_checker"

# Actions an agent may issue in this task. The investigation actions come
# first; SUBMIT_FUNCTION (the single terminal action) is last.
AVAILABLE_ACTIONS: List[ActionType] = [
    # -- investigation ------------------------------------------------------
    ActionType.LIST_FUNCTIONS,
    ActionType.GET_FUNCTION_METADATA,
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_STATE_VARIABLE,
    ActionType.GET_CALL_GRAPH,
    ActionType.GET_PROPERTY_SPECIFICATION,
    # -- terminal -----------------------------------------------------------
    ActionType.SUBMIT_FUNCTION,
]
class Task3Environment(BaseEnv):
    """Task 3: Rule Checker — identify the function that violates a given property.

    Each episode samples one (contract, target function) pair with
    ``sample_task3_episode``. The agent investigates with query actions and
    ends the episode with ``submit_function``. Per-action behaviour and
    rewards are delegated to the handlers in ``server.tasks.task3.actions``.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """Load the contract corpus and initialise empty episode state.

        Args:
            contracts_path: Optional path forwarded to ``load_contracts``. When
                falsy, ``load_contracts()`` is called with its own default.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()
        self._max_steps = 20

        # Episode state — initialised by reset()
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task3Grader] = None
        self._step_count: int = 0
        self._cum_reward: float = 0.0
        self._done: bool = False
        self._query_hist: List[str] = []
        self._seen: Set[str] = set()

    # ── OpenEnv interface ─────────────────────────────────────────────────────

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode.

        Args:
            seed: If given, reseeds the environment RNG before sampling, so the
                same seed yields the same (contract, target function) draw.

        Returns:
            A ``ResetResult`` with the initial observation and the task id.
        """
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_task3_episode(
            self._contracts, self._rng
        )
        self._grader = Task3Grader(
            target_function=self._target_fn,
            property_specification=self._target_fn.get("property_specification", ""),
            max_steps=self._max_steps,
        )

        # Clear all per-episode bookkeeping.
        self._step_count = 0
        self._cum_reward = 0.0
        self._done = False
        self._query_hist = []
        self._seen = set()

        obs = self._build_obs(
            last_action=None,
            last_result=(
                f"New episode started.\n"
                f"Contract : {self._contract['contract_name']}\n\n"
                f"Property : {self._target_fn.get('property', '')}\n\n"
                f"Find the function in this contract that violates the property above.\n"
                f"Use list_functions then get_function_code to investigate.\n"
                f"Submit with submit_function, params={{\"function_name\": \"...\"}}.\n"
                f"Only ONE submission allowed."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Apply one agent action.

        The episode is marked done once ``_max_steps`` steps have been taken,
        and can also be ended by a handler setting ``_done``.

        Args:
            action: The agent's action; dispatched on ``action.action_type``.

        Returns:
            A ``StepResult`` with the new observation, the handler's reward,
            the done flag and step/cumulative-reward info.

        Raises:
            RuntimeError: If ``reset()`` was never called, the episode is
                already done, or the step budget is exhausted.
        """
        if self._grader is None:
            # Handlers would otherwise run against empty episode state.
            raise RuntimeError("Environment not initialised. Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # Checked before incrementing: `>=` allows exactly `_max_steps` steps
        # (the previous `>` comparison allowed one extra step).
        if self._step_count >= self._max_steps:
            raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")

        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cum_reward += reward.value

        # Budget used up: end the episode so callers looping on `done` stop
        # cleanly instead of hitting the RuntimeError above on the next call.
        if self._step_count >= self._max_steps:
            self._done = True

        self._query_hist.append(f"[{action.action_type}] → {result_text[:100]}")

        obs = self._build_obs(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={"step": self._step_count, "cumulative_reward": self._cum_reward},
        )

    def state(self) -> StateResult:
        """Return a snapshot of the current episode.

        NOTE(review): the snapshot includes the target function name, i.e. the
        answer, so it should not be exposed to the agent being evaluated.
        """
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cum_reward,
            done=self._done,
            query_history=list(self._query_hist),  # copy: callers can't mutate ours
        )

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation:
        """Build an observation for the current contract and target property.

        Args:
            last_action: The action type just executed, or None right after reset.
            last_result: Text result shown to the agent.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "property_english": self._target_fn.get("property", ""),
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Read the property, then inspect function code to find which one violates it. "
                    "Submit with: submit_function, params={'function_name': '<name>'}. "
                    "ONE submission per episode."
                ),
            },
        )

    def _qkey(self, at: str, params: Dict[str, Any]) -> str:
        """Return a key identifying a query by action type and params.

        Params are sorted by key so that dict insertion order does not matter.
        """
        return f"{at}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Return True if ``key`` was already seen; otherwise record it and return False."""
        if key in self._seen:
            return True
        self._seen.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route ``action`` to its handler in ``actions``.

        Each handler is called as ``handler(env, qkey, params)``. Unknown
        action types go to ``actions.unknown_action``.

        Returns:
            The ``(result_text, reward)`` pair produced by the handler.
        """
        at = action.action_type
        params = action.params
        qkey = self._qkey(at, params)

        # Mapping from ActionType to handler function
        handlers = {
            ActionType.LIST_FUNCTIONS: actions.list_functions,
            ActionType.GET_FUNCTION_METADATA: actions.get_function_metadata,
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
            ActionType.GET_CALL_GRAPH: actions.get_call_graph,
            ActionType.GET_PROPERTY_SPECIFICATION: actions.get_property_specification,
            ActionType.SUBMIT_FUNCTION: actions.submit_function,
        }

        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)