# SmartContractAudit / env / schemas.py
# ajaxwin
# refactor: Reward clamping in graders
# 41a051f
"""
schemas.py
----------
Typed Pydantic models implementing the OpenEnv interface spec.
Observation – what the agent sees at each step
Action – what the agent can send
StepResult – returned by step()
ResetResult – returned by reset()
StateResult – returned by state()
Reward – structured reward info
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Action types
# ---------------------------------------------------------------------------
class ActionType(str, Enum):
    """
    Every action an agent can issue, each paired with a numeric cost.

    Members are declared as ``(name, cost)`` tuples. ``__new__`` unpacks the
    tuple so that the member itself is the name string (it compares equal to
    that string and ``member.value`` returns it) and carries the number as
    ``member.cost``.

    NOTE(review): the costs mix signs (most Task 1 entries are negative,
    most Task 2/3 entries are positive). How the graders interpret the sign
    is not visible in this file; confirm before changing any value.
    """

    # Per-member number taken from the second element of the declaration tuple.
    cost: float

    def __new__(cls, str_value: str, cost: float):
        """Build a member whose string content and ``value`` are *str_value*."""
        member = str.__new__(cls, str_value)
        # Without this, Enum would use the whole (name, cost) tuple as the
        # value, and lookups such as ActionType("submit") would not match.
        member._value_ = str_value
        member.cost = cost
        return member

    # ── Task 1 – Vulnerability Detection ───────────────────────────────────
    LIST_FUNCTIONS = ("list_functions", -0.04)
    GET_FUNCTION_CODE = ("get_function_code", -0.14)
    GET_FUNCTION_SUMMARY = ("get_function_summary", -0.07)
    GET_FILE_METADATA = ("get_file_metadata", -0.02)
    GET_STATE_VARIABLE = ("get_state_variable", -0.06)
    GET_CALL_GRAPH = ("get_call_graph", -0.08)
    SUBMIT = ("submit", 0.0)

    # ── Task 2 – Property Discovery ─────────────────────────────────────────
    GET_SIMILAR_RULE = ("get_similar_rule", 0.15)
    GET_FILE_NATSPEC = ("get_file_natspec", 0.05)
    GET_FUNCTION_NATSPEC = ("get_function_natspec", -0.08)
    GET_RELATED_FUNCTIONS = ("get_related_functions", 0.07)
    GET_SIGNATURE = ("get_signature", 0.04)
    SUBMIT_PROPERTY = ("submit_property", 0.0)

    # ── Task 3 – Rule Checker ────────────────────────────────────────────────
    GET_PROPERTY_SPECIFICATION = ("get_property_specification", 0.02)
    GET_FUNCTION_METADATA = ("get_function_metadata", 0.04)
    SUBMIT_FUNCTION = ("submit_function", 0.0)

    # ─────── General Actions ─────────────────────────────────────────────────
    UNKNOWN = ("unknown", 0.0)
    REPEATED = ("repeated", -0.22)
    RESUBMIT = ("resubmit", 0.0)
class Action(BaseModel):
    """
    A single request sent by the agent.

    action_type : which operation to perform; one of the ActionType values
    params      : optional key/value arguments for that operation, e.g.
                  {"function_name": "withdraw"} for GET_FUNCTION_CODE
                  {"property": "..."} for SUBMIT_PROPERTY
    """

    action_type: ActionType
    # default_factory gives each instance its own dict when params is omitted.
    params: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        # Pydantic option: populate the model with the enum's value instead
        # of the enum member.
        # NOTE(review): the stored plain string would not carry
        # ActionType.cost; confirm callers re-wrap it with ActionType(...)
        # wherever the cost is needed.
        use_enum_values = True
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class Observation(BaseModel):
    """
    What the agent receives from the environment.

    task_id            : which task is active
    contract_name      : name of the Solidity contract
    last_action        : the action that produced this observation
                         (None on reset)
    last_action_result : human-readable result of the last action
    done               : whether the episode has ended
    extra              : any additional task-specific context

    An ``available_actions`` list is not part of the model at present; its
    declaration is kept below as a comment.
    """

    task_id: str
    contract_name: str
    # available_actions: List[str] # May need it, may not depends on the agent
    last_action: Optional[str] = None
    last_action_result: Optional[str] = None
    done: bool = False
    # default_factory gives each instance its own dict when extra is omitted.
    extra: Dict[str, Any] = Field(default_factory=dict)
# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------
class Reward(BaseModel):
    """
    Structured reward information attached to every step.

    value   : the step's reward as a float (may be negative)
    reason  : human-readable explanation of the value
    partial : True for a shaping reward, False for a terminal one;
              defaults to True
    """

    value: float
    reason: str
    partial: bool = True
# ---------------------------------------------------------------------------
# Step / Reset / State results
# ---------------------------------------------------------------------------
class StepResult(BaseModel):
    """
    Payload returned by step().

    observation : the Observation produced by the step
    reward      : the Reward for the step
    done        : episode-finished flag (Observation has a ``done`` field
                  too; presumably the two agree -- TODO confirm)
    info        : free-form extra data; an empty dict when omitted
    """

    observation: Observation
    reward: Reward
    done: bool
    info: Dict[str, Any] = Field(default_factory=dict)
class ResetResult(BaseModel):
    """
    Payload returned by reset().

    observation : the Observation that opens the new episode
    info        : free-form extra data; an empty dict when omitted
    """

    observation: Observation
    info: Dict[str, Any] = Field(default_factory=dict)
class StateResult(BaseModel):
    """
    Payload returned by state().

    task_id           : which task is active
    contract_name     : name of the contract in play
    target_function   : see the inline note on the field
    step_count        : integer step counter
    cumulative_reward : running reward total as a float
    done              : whether the episode has ended
    query_history     : list of strings, empty when omitted (presumably the
                        queries issued so far -- verify against the
                        environment)
    session_id        : optional session identifier, None when omitted
    """

    task_id: str
    contract_name: str
    target_function: str  # hidden in real eval, exposed here for debugging
    step_count: int
    cumulative_reward: float
    done: bool
    query_history: List[str] = Field(default_factory=list)
    session_id: Optional[str] = None
# ---------------------------------------------------------------------------
# Task registry entry
# ---------------------------------------------------------------------------
class TaskInfo(BaseModel):
    """
    One entry of the task registry.

    task_id     : identifier of the task
    name        : task name
    difficulty  : difficulty label (free-form string)
    description : description text for the task
    status      : "active" (the default) or "placeholder"; the type is a
                  plain str, so the model itself does not restrict it to
                  those two values
    """

    task_id: str
    name: str
    difficulty: str
    description: str
    status: str = "active"  # "active" | "placeholder"