Spaces:
Running
Running
ajaxwin commited on
Commit Β·
9c888b7
1
Parent(s): 8fccda7
Task 2 added
Browse files- README.md +177 -169
- app.py +78 -116
- data/data_loader.py +135 -30
- demo.py +74 -4
- env/schemas.py +37 -36
- eval.py +259 -220
- inference.py +217 -234
- openenv.yaml +67 -91
- tasks/task2/__init__.py +4 -26
- tasks/task2/environment.py +340 -0
- tasks/task2/grader.py +171 -0
- validate.py +189 -199
README.md
CHANGED
|
@@ -1,108 +1,120 @@
|
|
| 1 |
# Smart Contract Audit RL Environment
|
| 2 |
|
| 3 |
> **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
|
| 4 |
-
>
|
| 5 |
|
| 6 |
-
[](https://huggingface.co/spaces)
|
| 8 |
[](https://python.org)
|
|
|
|
| 9 |
|
| 10 |
---
|
| 11 |
|
| 12 |
## Motivation
|
| 13 |
|
| 14 |
-
Smart contract auditing is a $500M+ industry where human auditors painstakingly review Solidity code for security flaws. This environment lets agents practice exactly that workflow β exploring contract code through targeted queries and submitting findings β providing a
|
| 15 |
|
| 16 |
-
Data is sourced from **Certora-audited DeFi projects**, giving agents contracts with the same vulnerability patterns found in production exploits
|
| 17 |
|
| 18 |
---
|
| 19 |
|
| 20 |
-
##
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|------|------|------------|--------|
|
| 26 |
-
| 1 | Targeted Vulnerability Detection | Medium | β
Active |
|
| 27 |
-
| 2 | Property Discovery | Hard | β³ Placeholder |
|
| 28 |
-
| 3 | Rule Checker | Easy | β³ Placeholder |
|
| 29 |
|
| 30 |
-
##
|
| 31 |
|
| 32 |
-
**Setup:**
|
| 33 |
|
| 34 |
-
**Objective:** Identify the vulnerable function and describe
|
| 35 |
|
| 36 |
-
|
| 37 |
-
1. `reset()` β randomly selects one of 8 vulnerable (contract, function) pairs from the dataset
|
| 38 |
-
2. Agent receives the contract name and description
|
| 39 |
-
3. Agent explores using the action API (each action has a small cost)
|
| 40 |
-
4. Agent calls `submit(function_name, vulnerability_type)` to end the episode
|
| 41 |
-
5. Grader assigns 0.0β1.0 score
|
| 42 |
|
| 43 |
-
|
| 44 |
-
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
|
| 53 |
-
|
| 54 |
|
| 55 |
-
###
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
|
| 59 |
---
|
| 60 |
|
| 61 |
-
##
|
| 62 |
|
| 63 |
-
|
| 64 |
|
| 65 |
-
|
| 66 |
|
| 67 |
-
##
|
| 68 |
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
|
| 81 |
-
**
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
| 85 |
|
| 86 |
---
|
| 87 |
|
| 88 |
## Observation Space
|
| 89 |
|
| 90 |
-
Every `step()` and `reset()` returns
|
| 91 |
|
| 92 |
```json
|
| 93 |
{
|
| 94 |
-
"task_id": "
|
| 95 |
-
"contract_name": "
|
| 96 |
-
"contract_description": "
|
| 97 |
-
"available_actions": ["
|
| 98 |
-
"last_action": "
|
| 99 |
-
"last_action_result": "
|
| 100 |
-
"step_count":
|
| 101 |
-
"cumulative_reward": -0.
|
| 102 |
"done": false,
|
| 103 |
"extra": {
|
| 104 |
-
"
|
| 105 |
-
"
|
|
|
|
|
|
|
| 106 |
}
|
| 107 |
}
|
| 108 |
```
|
|
@@ -114,63 +126,88 @@ Every `step()` and `reset()` returns an `Observation` object:
|
|
| 114 |
```
|
| 115 |
smart-contract-env/
|
| 116 |
βββ data/
|
| 117 |
-
β βββ contracts.json # 4 contracts
|
| 118 |
-
β βββ data_loader.py # JSON
|
| 119 |
βββ env/
|
| 120 |
β βββ base_env.py # Abstract OpenEnv base class
|
| 121 |
-
β βββ schemas.py # Pydantic
|
| 122 |
βββ tasks/
|
| 123 |
β βββ task1/
|
| 124 |
β β βββ environment.py # Full Task 1 RL environment
|
| 125 |
-
β β βββ grader.py # Deterministic 0
|
| 126 |
-
β βββ task2/
|
| 127 |
-
β
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
βββ
|
| 131 |
-
βββ
|
| 132 |
-
βββ
|
| 133 |
-
βββ
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
| 135 |
```
|
| 136 |
|
| 137 |
---
|
| 138 |
|
| 139 |
## Setup & Usage
|
| 140 |
|
| 141 |
-
###
|
| 142 |
|
| 143 |
```bash
|
| 144 |
-
|
| 145 |
-
git clone <repo>
|
| 146 |
-
cd smart-contract-env
|
| 147 |
pip install -r requirements.txt
|
| 148 |
|
| 149 |
-
#
|
| 150 |
-
python app.py
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
```
|
| 153 |
|
| 154 |
-
###
|
| 155 |
|
| 156 |
```bash
|
| 157 |
docker build -t sc-audit-env .
|
| 158 |
docker run -p 7860:7860 sc-audit-env
|
| 159 |
```
|
| 160 |
|
| 161 |
-
###
|
| 162 |
|
| 163 |
```python
|
| 164 |
from tasks.task1.environment import Task1Environment
|
|
|
|
| 165 |
from env.schemas import Action, ActionType
|
| 166 |
|
|
|
|
| 167 |
env = Task1Environment()
|
| 168 |
-
|
| 169 |
-
print(
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
```
|
| 175 |
|
| 176 |
---
|
|
@@ -180,35 +217,28 @@ print(step.observation.last_action_result)
|
|
| 180 |
| Method | Endpoint | Description |
|
| 181 |
|--------|----------|-------------|
|
| 182 |
| `GET` | `/health` | Liveness probe |
|
| 183 |
-
| `GET` | `/tasks` |
|
| 184 |
-
| `POST` | `/reset` | Start
|
| 185 |
-
| `POST` | `/step` | Take
|
| 186 |
-
| `GET` | `/state` | Debug: internal state |
|
| 187 |
-
| `GET` | `/action_space` | Action
|
| 188 |
-
| `GET` | `/observation_space` | Observation
|
| 189 |
-
|
| 190 |
-
**Example session:**
|
| 191 |
|
| 192 |
```bash
|
| 193 |
-
#
|
| 194 |
-
curl -X POST
|
| 195 |
-
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
-d '{"action_type":
|
| 202 |
-
|
| 203 |
-
# Submit answer
|
| 204 |
-
curl -X POST "http://localhost:7860/step" \
|
| 205 |
-
-H "Content-Type: application/json" \
|
| 206 |
-
-d '{"action_type": "submit", "params": {"function_name": "withdraw", "vulnerability_type": "reentrancy"}}'
|
| 207 |
```
|
| 208 |
|
| 209 |
---
|
| 210 |
|
| 211 |
-
##
|
| 212 |
|
| 213 |
```bash
|
| 214 |
export API_BASE_URL="https://api.openai.com/v1"
|
|
@@ -216,86 +246,64 @@ export MODEL_NAME="gpt-4o-mini"
|
|
| 216 |
export HF_TOKEN="sk-..."
|
| 217 |
|
| 218 |
python inference.py
|
|
|
|
| 219 |
```
|
| 220 |
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
**Expected baseline scores (gpt-4o-mini, 3 episodes):**
|
| 224 |
|
| 225 |
| Task | Avg Grader Score | Notes |
|
| 226 |
|------|-----------------|-------|
|
| 227 |
-
| Task 1 | ~0.67 |
|
| 228 |
-
| Task 2 | 0.
|
| 229 |
| Task 3 | 0.00 | Placeholder |
|
| 230 |
|
| 231 |
---
|
| 232 |
|
| 233 |
-
##
|
| 234 |
-
|
| 235 |
-
```json
|
| 236 |
-
{
|
| 237 |
-
"model": "gpt-4o-mini",
|
| 238 |
-
"tasks": [
|
| 239 |
-
{
|
| 240 |
-
"task_id": "task1_vuln_detection",
|
| 241 |
-
"avg_grader_score": 0.667,
|
| 242 |
-
"avg_cumulative_reward": 2.14
|
| 243 |
-
},
|
| 244 |
-
{ "task_id": "task2_property_discovery", "avg_grader_score": 0.0 },
|
| 245 |
-
{ "task_id": "task3_rule_checker", "avg_grader_score": 0.0 }
|
| 246 |
-
],
|
| 247 |
-
"overall_avg_score": 0.667
|
| 248 |
-
}
|
| 249 |
-
```
|
| 250 |
-
|
| 251 |
-
---
|
| 252 |
-
|
| 253 |
-
## Grader Details
|
| 254 |
|
| 255 |
-
|
| 256 |
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
Scores map to terminal rewards: 1.0 β +5, 0.5 β +1, 0.0 β β1.5.
|
| 262 |
-
|
| 263 |
-
---
|
| 264 |
-
|
| 265 |
-
## OpenEnv Spec Compliance
|
| 266 |
-
|
| 267 |
-
- β
Typed `Observation`, `Action`, `Reward` Pydantic models
|
| 268 |
-
- β
`step(action) β StepResult(observation, reward, done, info)`
|
| 269 |
-
- β
`reset() β ResetResult(observation, info)`
|
| 270 |
-
- β
`state() β StateResult`
|
| 271 |
-
- β
`openenv.yaml` metadata
|
| 272 |
-
- β
3 tasks defined (1 active, 2 placeholders)
|
| 273 |
-
- β
Grader scores in [0.0, 1.0]
|
| 274 |
-
- β
Shaped rewards (not just binary)
|
| 275 |
-
- β
Dockerfile + HF Space deployment
|
| 276 |
-
- β
Baseline `inference.py` using OpenAI client
|
| 277 |
|
| 278 |
---
|
| 279 |
|
| 280 |
## Deploying to Hugging Face Spaces
|
| 281 |
|
| 282 |
-
1. Create a new **Docker** Space
|
| 283 |
-
2.
|
| 284 |
-
3.
|
|
|
|
| 285 |
|
| 286 |
```bash
|
| 287 |
-
git remote add hf https://huggingface.co/spaces/<
|
| 288 |
git push hf main
|
| 289 |
```
|
| 290 |
|
| 291 |
-
The Space will build the Docker image and serve the FastAPI app on port 7860.
|
| 292 |
-
|
| 293 |
---
|
| 294 |
|
| 295 |
-
##
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
-
|
|
|
|
|
|
|
| 300 |
|
| 301 |
-
Contract vulnerability
|
|
|
|
| 1 |
# Smart Contract Audit RL Environment
|
| 2 |
|
| 3 |
> **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
|
| 4 |
+
> Train and evaluate agents on real-world Solidity audit tasks β the same work professional auditors do every day.
|
| 5 |
|
| 6 |
+
[](openenv.yaml)
|
|
|
|
| 7 |
[](https://python.org)
|
| 8 |
+
[](LICENSE)
|
| 9 |
|
| 10 |
---
|
| 11 |
|
| 12 |
## Motivation
|
| 13 |
|
| 14 |
+
Smart contract auditing is a $500M+ industry where human auditors painstakingly review Solidity code for security flaws and formally specify function properties. This environment lets agents practice exactly that workflow β exploring contract code through targeted queries and submitting findings β providing a rigorous, real-world benchmark for code-reasoning agents.
|
| 15 |
|
| 16 |
+
Data is sourced from **Certora-audited DeFi projects**, giving agents contracts with the same vulnerability patterns found in production exploits.
|
| 17 |
|
| 18 |
---
|
| 19 |
|
| 20 |
+
## Tasks
|
| 21 |
|
| 22 |
+
| # | Name | Difficulty | Status | Description |
|
| 23 |
+
|---|------|------------|--------|-------------|
|
| 24 |
+
| 1 | Targeted Vulnerability Detection | Medium | β
Active | Find the vulnerable function and name the vulnerability type |
|
| 25 |
+
| 2 | Property Discovery | Hard | β
Active | Write the natural-language postcondition for a given function |
|
| 26 |
+
| 3 | Rule Checker | Easy | β³ Placeholder | Identify which function violates a given property |
|
| 27 |
|
| 28 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
## Task 1 β Targeted Vulnerability Detection *(Medium)*
|
| 31 |
|
| 32 |
+
**Setup:** Agent is shown a Solidity contract (4β6 functions). One function contains a critical vulnerability.
|
| 33 |
|
| 34 |
+
**Objective:** Identify the vulnerable function and describe its vulnerability type in 2β3 words.
|
| 35 |
|
| 36 |
+
### Actions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
| Action | Params | Reward |
|
| 39 |
+
|--------|--------|--------|
|
| 40 |
+
| `list_functions` | β | β0.05 |
|
| 41 |
+
| `get_function_code` | `function_name` | +0.05 (target) / β0.10 (other) |
|
| 42 |
+
| `get_function_summary` | `function_name` | +0.03 (target) / β0.05 (other) |
|
| 43 |
+
| `get_file_metadata` | β | β0.04 |
|
| 44 |
+
| `get_state_variable` | `variable_name` (opt.) | β0.05 |
|
| 45 |
+
| `get_call_graph` | β | β0.08 |
|
| 46 |
+
| `submit` | `function_name`, `vulnerability_type` | **+5.0** / +1.0 / β1.5 |
|
| 47 |
|
| 48 |
+
Repeated identical queries: **β0.40**
|
| 49 |
|
| 50 |
+
### Submit scoring (deterministic)
|
| 51 |
+
- **1.0** β correct function **+** correct vulnerability keyword β reward +5.0
|
| 52 |
+
- **0.5** β correct function, wrong/vague vulnerability type β reward +1.0
|
| 53 |
+
- **0.0** β wrong function β reward β1.5
|
| 54 |
|
| 55 |
+
### Vulnerability types in dataset
|
| 56 |
+
Reentrancy Β· Missing access control Β· Integer overflow Β· tx.origin authentication Β·
|
| 57 |
+
Front-running Β· Timestamp dependence Β· Denial of service (unbounded loop) Β· Unchecked return value
|
| 58 |
|
| 59 |
---
|
| 60 |
|
| 61 |
+
## Task 2 β Property Discovery *(Hard)*
|
| 62 |
|
| 63 |
+
**Setup:** Agent is shown a single Solidity function and must write its natural-language correctness property (postcondition / invariant).
|
| 64 |
|
| 65 |
+
**Objective:** Write a precise 2β4 sentence property describing what the function guarantees when it succeeds.
|
| 66 |
|
| 67 |
+
### Actions
|
| 68 |
|
| 69 |
+
| Action | Params | Reward |
|
| 70 |
+
|--------|--------|--------|
|
| 71 |
+
| `get_function_code` | β | β0.06 |
|
| 72 |
+
| `get_function_natspec` | β | β0.08 |
|
| 73 |
+
| `get_file_natspec` | β | β0.03 |
|
| 74 |
+
| `get_related_functions` | β | β0.06 |
|
| 75 |
+
| `get_io` | β | β0.04 |
|
| 76 |
+
| `get_similar_rule` | β | β0.20 |
|
| 77 |
+
| `submit_property` | `property` (string) | **0.0β5.0** (scored, ONE attempt) |
|
| 78 |
|
| 79 |
+
Repeated identical queries: **β0.40**
|
| 80 |
+
|
| 81 |
+
### Submit scoring (keyword-weighted)
|
| 82 |
+
```
|
| 83 |
+
score = 0.70 Γ (key_phrases_matched / total_key_phrases)
|
| 84 |
+
+ 0.30 Γ (bonus_phrases_matched / total_bonus_phrases)
|
| 85 |
+
|
| 86 |
+
reward = score Γ 5.0 β range: 0.0 β 5.0
|
| 87 |
+
```
|
| 88 |
|
| 89 |
+
Matching uses **word-set containment** with synonym expansion (e.g. "caller" matches "msg.sender", "sender", "user"). Phrases don't need to be adjacent β all constituent words just need to appear somewhere in the submitted text.
|
| 90 |
+
|
| 91 |
+
**One submission per episode** β choose carefully.
|
| 92 |
+
|
| 93 |
+
### Property coverage
|
| 94 |
+
11 functions across 4 contracts with ground-truth properties: SimpleVault (deposit, withdraw, emergencyDrain), TokenSale (buyTokens, setPrice, withdrawETH), DutchAuction (getPrice, bid, finalize), YieldFarm (stake, claimRewards).
|
| 95 |
|
| 96 |
---
|
| 97 |
|
| 98 |
## Observation Space
|
| 99 |
|
| 100 |
+
Every `step()` and `reset()` returns the same `Observation` structure:
|
| 101 |
|
| 102 |
```json
|
| 103 |
{
|
| 104 |
+
"task_id": "task2_property_discovery",
|
| 105 |
+
"contract_name": "YieldFarm",
|
| 106 |
+
"contract_description": "A simple yield farming contract...",
|
| 107 |
+
"available_actions": ["get_function_code", "get_function_natspec", ...],
|
| 108 |
+
"last_action": "get_function_natspec",
|
| 109 |
+
"last_action_result": "NatSpec for 'claimRewards':\n@notice Claim all accrued...",
|
| 110 |
+
"step_count": 2,
|
| 111 |
+
"cumulative_reward": -0.14,
|
| 112 |
"done": false,
|
| 113 |
"extra": {
|
| 114 |
+
"target_function": "claimRewards",
|
| 115 |
+
"target_signature": "claimRewards()",
|
| 116 |
+
"solidity_version": "0.8.10",
|
| 117 |
+
"hint": "Discover the property of the target function..."
|
| 118 |
}
|
| 119 |
}
|
| 120 |
```
|
|
|
|
| 126 |
```
|
| 127 |
smart-contract-env/
|
| 128 |
βββ data/
|
| 129 |
+
β βββ contracts.json # 4 contracts Β· 8 vulnerabilities Β· 11 properties
|
| 130 |
+
β βββ data_loader.py # JSON parser, episode samplers, T1 + T2 helpers
|
| 131 |
βββ env/
|
| 132 |
β βββ base_env.py # Abstract OpenEnv base class
|
| 133 |
+
β βββ schemas.py # Pydantic: Observation, Action, Reward, StepResultβ¦
|
| 134 |
βββ tasks/
|
| 135 |
β βββ task1/
|
| 136 |
β β βββ environment.py # Full Task 1 RL environment
|
| 137 |
+
β β βββ grader.py # Deterministic 0/0.5/1.0 rubric + longest-match keywords
|
| 138 |
+
β βββ task2/
|
| 139 |
+
β β βββ environment.py # Full Task 2 RL environment (one submit per episode)
|
| 140 |
+
β β βββ grader.py # Keyword-weighted 0.0β1.0 grader + synonym expansion
|
| 141 |
+
β βββ task3/ # TODO: Rule Checker (placeholder)
|
| 142 |
+
βββ app.py # FastAPI server β all OpenEnv HTTP endpoints
|
| 143 |
+
βββ inference.py # Baseline LLM agent (Task 1 + Task 2)
|
| 144 |
+
βββ eval.py # Oracle/partial/random evaluation harness
|
| 145 |
+
βββ demo.py # Colourised interactive + scripted demo
|
| 146 |
+
βββ validate.py # 19-check pre-submission validator
|
| 147 |
+
βββ openenv.yaml # Full OpenEnv spec metadata
|
| 148 |
+
βββ Dockerfile # Port 7860, uvicorn, healthcheck
|
| 149 |
+
βββ requirements.txt
|
| 150 |
```
|
| 151 |
|
| 152 |
---
|
| 153 |
|
| 154 |
## Setup & Usage
|
| 155 |
|
| 156 |
+
### Local Python
|
| 157 |
|
| 158 |
```bash
|
| 159 |
+
git clone <repo> && cd smart-contract-env
|
|
|
|
|
|
|
| 160 |
pip install -r requirements.txt
|
| 161 |
|
| 162 |
+
# Run the server
|
| 163 |
+
python app.py # β http://localhost:7860
|
| 164 |
+
|
| 165 |
+
# Run interactive demo
|
| 166 |
+
python demo.py # Task 1 interactive
|
| 167 |
+
python demo.py --auto # Task 1 scripted
|
| 168 |
+
python demo.py --auto --task 2 # Task 2 scripted (add --task flag)
|
| 169 |
+
|
| 170 |
+
# Run evaluation harness (no LLM needed)
|
| 171 |
+
python eval.py # Both tasks, 8 episodes each
|
| 172 |
+
python eval.py --task 2 # Task 2 only
|
| 173 |
+
python eval.py --episodes 16 --verbose
|
| 174 |
+
|
| 175 |
+
# Pre-submission validation
|
| 176 |
+
python validate.py # 19/19 checks
|
| 177 |
```
|
| 178 |
|
| 179 |
+
### Docker
|
| 180 |
|
| 181 |
```bash
|
| 182 |
docker build -t sc-audit-env .
|
| 183 |
docker run -p 7860:7860 sc-audit-env
|
| 184 |
```
|
| 185 |
|
| 186 |
+
### Direct Python API
|
| 187 |
|
| 188 |
```python
|
| 189 |
from tasks.task1.environment import Task1Environment
|
| 190 |
+
from tasks.task2.environment import Task2Environment
|
| 191 |
from env.schemas import Action, ActionType
|
| 192 |
|
| 193 |
+
# Task 1
|
| 194 |
env = Task1Environment()
|
| 195 |
+
r = env.reset(seed=42)
|
| 196 |
+
print(r.observation.contract_name) # SimpleVault
|
| 197 |
+
s = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 198 |
+
s = env.step(Action(action_type=ActionType.SUBMIT,
|
| 199 |
+
params={"function_name": "emergencyDrain",
|
| 200 |
+
"vulnerability_type": "missing access control"}))
|
| 201 |
+
print(s.reward.value) # +5.0
|
| 202 |
+
|
| 203 |
+
# Task 2
|
| 204 |
+
env2 = Task2Environment()
|
| 205 |
+
r2 = env2.reset(seed=42)
|
| 206 |
+
print(r2.observation.extra["target_function"]) # claimRewards
|
| 207 |
+
s2 = env2.step(Action(action_type=ActionType.GET_FUNCTION_NATSPEC))
|
| 208 |
+
s2 = env2.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
|
| 209 |
+
params={"property": "After a successful claimRewards call, all accrued reward tokens are transferred to the caller and their rewards balance is zeroed. Reverts if no rewards."}))
|
| 210 |
+
print(s2.reward.value) # ~4.0
|
| 211 |
```
|
| 212 |
|
| 213 |
---
|
|
|
|
| 217 |
| Method | Endpoint | Description |
|
| 218 |
|--------|----------|-------------|
|
| 219 |
| `GET` | `/health` | Liveness probe |
|
| 220 |
+
| `GET` | `/tasks` | All tasks + status |
|
| 221 |
+
| `POST` | `/reset` | Start episode (`task_id`, `seed`) |
|
| 222 |
+
| `POST` | `/step` | Take action (`action_type`, `params`) |
|
| 223 |
+
| `GET` | `/state` | Debug: internal episode state |
|
| 224 |
+
| `GET` | `/action_space?task_id=...` | Action schema for a task |
|
| 225 |
+
| `GET` | `/observation_space` | Observation schema |
|
|
|
|
|
|
|
| 226 |
|
| 227 |
```bash
|
| 228 |
+
# Task 2 full episode
|
| 229 |
+
curl -X POST localhost:7860/reset \
|
| 230 |
+
-d '{"task_id":"task2_property_discovery","seed":42}'
|
| 231 |
+
|
| 232 |
+
curl -X POST localhost:7860/step \
|
| 233 |
+
-d '{"action_type":"get_function_natspec","params":{}}'
|
| 234 |
+
|
| 235 |
+
curl -X POST localhost:7860/step \
|
| 236 |
+
-d '{"action_type":"submit_property","params":{"property":"..."}}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
```
|
| 238 |
|
| 239 |
---
|
| 240 |
|
| 241 |
+
## Baseline Inference
|
| 242 |
|
| 243 |
```bash
|
| 244 |
export API_BASE_URL="https://api.openai.com/v1"
|
|
|
|
| 246 |
export HF_TOKEN="sk-..."
|
| 247 |
|
| 248 |
python inference.py
|
| 249 |
+
# β baseline_scores.json
|
| 250 |
```
|
| 251 |
|
| 252 |
+
### Expected baseline scores (gpt-4o-mini, 3 episodes per task)
|
|
|
|
|
|
|
| 253 |
|
| 254 |
| Task | Avg Grader Score | Notes |
|
| 255 |
|------|-----------------|-------|
|
| 256 |
+
| Task 1 | ~0.67 | Good at common vulns; misses subtle ones |
|
| 257 |
+
| Task 2 | ~0.55 | Reasonable properties but often misses specific variable names |
|
| 258 |
| Task 3 | 0.00 | Placeholder |
|
| 259 |
|
| 260 |
---
|
| 261 |
|
| 262 |
+
## Evaluation Scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
+
Deterministic oracle / partial / baseline tiers verified on 8 episodes (seeds 42β49):
|
| 265 |
|
| 266 |
+
| Task | Oracle | Partial | Floor |
|
| 267 |
+
|------|--------|---------|-------|
|
| 268 |
+
| Task 1 | **1.000** | 0.500 | 0.000 |
|
| 269 |
+
| Task 2 | **0.775** | 0.034 | 0.000 |
|
| 270 |
|
| 271 |
+
The clear separation confirms the grader provides **meaningful gradient signal** for RL training.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
---
|
| 274 |
|
| 275 |
## Deploying to Hugging Face Spaces
|
| 276 |
|
| 277 |
+
1. Create a new **Docker** Space at [huggingface.co/spaces](https://huggingface.co/spaces)
|
| 278 |
+
2. Add tag `openenv` in the Space settings
|
| 279 |
+
3. Copy the `SPACES_README.md` frontmatter into `README.md`
|
| 280 |
+
4. Push:
|
| 281 |
|
| 282 |
```bash
|
| 283 |
+
git remote add hf https://huggingface.co/spaces/<user>/<space>
|
| 284 |
git push hf main
|
| 285 |
```
|
| 286 |
|
|
|
|
|
|
|
| 287 |
---
|
| 288 |
|
| 289 |
+
## OpenEnv Spec Compliance
|
| 290 |
|
| 291 |
+
| Requirement | Status |
|
| 292 |
+
|-------------|--------|
|
| 293 |
+
| Typed `Observation`, `Action`, `Reward` Pydantic models | β
|
|
| 294 |
+
| `step(action) β StepResult(obs, reward, done, info)` | β
|
|
| 295 |
+
| `reset() β ResetResult` | β
|
|
| 296 |
+
| `state() β StateResult` | β
|
|
| 297 |
+
| `openenv.yaml` metadata | β
|
|
| 298 |
+
| 3+ tasks defined | β
(2 active, 1 placeholder) |
|
| 299 |
+
| Grader scores in [0.0, 1.0] | β
|
|
| 300 |
+
| Shaped rewards (non-binary) | β
|
|
| 301 |
+
| Dockerfile + port 7860 | β
|
|
| 302 |
+
| `inference.py` with OpenAI client | β
|
|
| 303 |
+
| `validate.py` β all 19 checks pass | β
|
|
| 304 |
|
| 305 |
+
---
|
| 306 |
+
|
| 307 |
+
## License
|
| 308 |
|
| 309 |
+
MIT. Contract vulnerability data adapted from Certora audits on production DeFi protocols.
|
app.py
CHANGED
|
@@ -4,31 +4,30 @@ app.py
|
|
| 4 |
FastAPI server exposing the OpenEnv HTTP interface.
|
| 5 |
|
| 6 |
Endpoints:
|
| 7 |
-
POST /reset
|
| 8 |
-
POST /step
|
| 9 |
-
GET /state
|
| 10 |
-
GET /tasks
|
| 11 |
-
GET /health
|
| 12 |
-
GET /action_space
|
| 13 |
-
GET /observation_space
|
| 14 |
-
|
| 15 |
-
Sessions are keyed by a UUID
|
| 16 |
-
If omitted,
|
| 17 |
"""
|
| 18 |
|
| 19 |
-
import uuid
|
| 20 |
from typing import Dict, Optional
|
| 21 |
|
| 22 |
from fastapi import FastAPI, HTTPException, Query
|
| 23 |
-
from fastapi.responses import JSONResponse
|
| 24 |
from pydantic import BaseModel
|
| 25 |
|
| 26 |
from env.schemas import Action, ActionType, TaskInfo
|
| 27 |
from tasks.task1.environment import Task1Environment
|
|
|
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
# App
|
| 31 |
-
#
|
| 32 |
|
| 33 |
app = FastAPI(
|
| 34 |
title="Smart Contract Audit RL Environment",
|
|
@@ -36,38 +35,36 @@ app = FastAPI(
|
|
| 36 |
"OpenEnv-compliant reinforcement learning environment for smart contract "
|
| 37 |
"security analysis. Train and evaluate agents on real-world Solidity audit tasks."
|
| 38 |
),
|
| 39 |
-
version="1.
|
| 40 |
)
|
| 41 |
|
| 42 |
-
#
|
| 43 |
# Session management
|
| 44 |
-
#
|
| 45 |
|
| 46 |
-
_sessions: Dict[str,
|
| 47 |
DEFAULT_SESSION = "default"
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
def _get_or_create_session(session_id: str, task_id: str = "task1_vuln_detection") -> Task1Environment:
|
| 51 |
-
if session_id not in _sessions:
|
| 52 |
-
env = _create_env(task_id)
|
| 53 |
-
_sessions[session_id] = env
|
| 54 |
-
return _sessions[session_id]
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
if
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
detail=f"Unknown task_id '{task_id}'. Available: ['task1_vuln_detection']",
|
| 65 |
-
)
|
| 66 |
|
| 67 |
|
| 68 |
-
#
|
| 69 |
-
# Request
|
| 70 |
-
#
|
| 71 |
|
| 72 |
class ResetRequest(BaseModel):
|
| 73 |
task_id: str = "task1_vuln_detection"
|
|
@@ -79,48 +76,39 @@ class StepRequest(BaseModel):
|
|
| 79 |
params: dict = {}
|
| 80 |
|
| 81 |
|
| 82 |
-
#
|
| 83 |
# Routes
|
| 84 |
-
#
|
| 85 |
|
| 86 |
@app.get("/health")
|
| 87 |
def health():
|
| 88 |
-
"""Liveness probe
|
| 89 |
-
return {"status": "ok", "version": "1.
|
| 90 |
|
| 91 |
|
| 92 |
@app.get("/tasks")
|
| 93 |
def list_tasks():
|
| 94 |
-
"""List all
|
| 95 |
tasks = [
|
| 96 |
TaskInfo(
|
| 97 |
task_id="task1_vuln_detection",
|
| 98 |
name="Targeted Vulnerability Detection",
|
| 99 |
difficulty="medium",
|
| 100 |
-
description=
|
| 101 |
-
"Given a Solidity contract, identify the vulnerable function "
|
| 102 |
-
"and describe the vulnerability type in 2-3 words."
|
| 103 |
-
),
|
| 104 |
status="active",
|
| 105 |
),
|
| 106 |
TaskInfo(
|
| 107 |
task_id="task2_property_discovery",
|
| 108 |
name="Property Discovery",
|
| 109 |
difficulty="hard",
|
| 110 |
-
description=
|
| 111 |
-
|
| 112 |
-
"that describes its correct behaviour."
|
| 113 |
-
),
|
| 114 |
-
status="placeholder",
|
| 115 |
),
|
| 116 |
TaskInfo(
|
| 117 |
task_id="task3_rule_checker",
|
| 118 |
name="Rule Checker",
|
| 119 |
difficulty="easy",
|
| 120 |
-
description=
|
| 121 |
-
"Given a property in English, identify which function in the contract "
|
| 122 |
-
"violates that property."
|
| 123 |
-
),
|
| 124 |
status="placeholder",
|
| 125 |
),
|
| 126 |
]
|
|
@@ -144,7 +132,7 @@ def step(
|
|
| 144 |
body: StepRequest,
|
| 145 |
session_id: str = Query(default=DEFAULT_SESSION),
|
| 146 |
):
|
| 147 |
-
"""Apply
|
| 148 |
env = _sessions.get(session_id)
|
| 149 |
if env is None:
|
| 150 |
raise HTTPException(
|
|
@@ -156,8 +144,7 @@ def step(
|
|
| 156 |
except ValueError:
|
| 157 |
raise HTTPException(
|
| 158 |
status_code=400,
|
| 159 |
-
detail=f"Unknown action_type '{body.action_type}'. "
|
| 160 |
-
f"Valid: {[a.value for a in ActionType]}",
|
| 161 |
)
|
| 162 |
action = Action(action_type=action_type, params=body.params)
|
| 163 |
try:
|
|
@@ -169,7 +156,7 @@ def step(
|
|
| 169 |
|
| 170 |
@app.get("/state")
|
| 171 |
def state(session_id: str = Query(default=DEFAULT_SESSION)):
|
| 172 |
-
"""Return
|
| 173 |
env = _sessions.get(session_id)
|
| 174 |
if env is None:
|
| 175 |
raise HTTPException(
|
|
@@ -186,51 +173,26 @@ def action_space(task_id: str = "task1_vuln_detection"):
|
|
| 186 |
return {
|
| 187 |
"task_id": task_id,
|
| 188 |
"actions": [
|
| 189 |
-
{
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
},
|
| 195 |
-
{
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
},
|
| 207 |
-
{
|
| 208 |
-
|
| 209 |
-
"params": {},
|
| 210 |
-
"reward": -0.04,
|
| 211 |
-
"description": "Retrieve contract-level metadata (version, author, description)",
|
| 212 |
-
},
|
| 213 |
-
{
|
| 214 |
-
"type": "get_state_variable",
|
| 215 |
-
"params": {"variable_name": "string (optional)"},
|
| 216 |
-
"reward": -0.05,
|
| 217 |
-
"description": "Retrieve a state variable or list all variables",
|
| 218 |
-
},
|
| 219 |
-
{
|
| 220 |
-
"type": "get_call_graph",
|
| 221 |
-
"params": {},
|
| 222 |
-
"reward": -0.08,
|
| 223 |
-
"description": "Retrieve the function call graph",
|
| 224 |
-
},
|
| 225 |
-
{
|
| 226 |
-
"type": "submit",
|
| 227 |
-
"params": {
|
| 228 |
-
"function_name": "string",
|
| 229 |
-
"vulnerability_type": "string",
|
| 230 |
-
},
|
| 231 |
-
"reward": "+5.0 (correct) / +1.0 (right fn, wrong vuln) / -1.5 (wrong)",
|
| 232 |
-
"description": "Submit your final answer. Ends the episode.",
|
| 233 |
-
},
|
| 234 |
],
|
| 235 |
}
|
| 236 |
return {"error": f"No action space defined for task '{task_id}'"}
|
|
@@ -238,27 +200,27 @@ def action_space(task_id: str = "task1_vuln_detection"):
|
|
| 238 |
|
| 239 |
@app.get("/observation_space")
|
| 240 |
def observation_space():
|
| 241 |
-
"""Describe the observation space."""
|
| 242 |
return {
|
| 243 |
"type": "object",
|
| 244 |
"fields": {
|
| 245 |
-
"task_id":
|
| 246 |
-
"contract_name":
|
| 247 |
"contract_description": "string β what the contract does",
|
| 248 |
-
"available_actions":
|
| 249 |
-
"last_action":
|
| 250 |
-
"last_action_result":
|
| 251 |
-
"step_count":
|
| 252 |
-
"cumulative_reward":
|
| 253 |
-
"done":
|
| 254 |
-
"extra":
|
| 255 |
},
|
| 256 |
}
|
| 257 |
|
| 258 |
|
| 259 |
-
#
|
| 260 |
# Entry point
|
| 261 |
-
#
|
| 262 |
|
| 263 |
if __name__ == "__main__":
|
| 264 |
import uvicorn
|
|
|
|
| 4 |
FastAPI server exposing the OpenEnv HTTP interface.
|
| 5 |
|
| 6 |
Endpoints:
|
| 7 |
+
POST /reset β start a new episode
|
| 8 |
+
POST /step β take one action
|
| 9 |
+
GET /state β inspect internal state (debugging)
|
| 10 |
+
GET /tasks β list available tasks
|
| 11 |
+
GET /health β liveness probe
|
| 12 |
+
GET /action_space β action space description for a task
|
| 13 |
+
GET /observation_space β observation space description
|
| 14 |
+
|
| 15 |
+
Sessions are keyed by a UUID in the `session_id` query parameter.
|
| 16 |
+
If omitted, "default" is used (fine for sequential single-agent runs).
|
| 17 |
"""
|
| 18 |
|
|
|
|
| 19 |
from typing import Dict, Optional
|
| 20 |
|
| 21 |
from fastapi import FastAPI, HTTPException, Query
|
|
|
|
| 22 |
from pydantic import BaseModel
|
| 23 |
|
| 24 |
from env.schemas import Action, ActionType, TaskInfo
|
| 25 |
from tasks.task1.environment import Task1Environment
|
| 26 |
+
from tasks.task2.environment import Task2Environment
|
| 27 |
|
| 28 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
+
# App
|
| 30 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
|
| 32 |
app = FastAPI(
|
| 33 |
title="Smart Contract Audit RL Environment",
|
|
|
|
| 35 |
"OpenEnv-compliant reinforcement learning environment for smart contract "
|
| 36 |
"security analysis. Train and evaluate agents on real-world Solidity audit tasks."
|
| 37 |
),
|
| 38 |
+
version="1.1.0",
|
| 39 |
)
|
| 40 |
|
| 41 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
# Session management
|
| 43 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
|
| 45 |
+
_sessions: Dict[str, object] = {}
|
| 46 |
DEFAULT_SESSION = "default"
|
| 47 |
|
| 48 |
+
TASK_ENV_MAP = {
|
| 49 |
+
"task1_vuln_detection": Task1Environment,
|
| 50 |
+
"task2_property_discovery": Task2Environment,
|
| 51 |
+
# TODO: "task3_rule_checker": Task3Environment,
|
| 52 |
+
}
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
+
def _create_env(task_id: str):
|
| 56 |
+
cls = TASK_ENV_MAP.get(task_id)
|
| 57 |
+
if cls is None:
|
| 58 |
+
raise HTTPException(
|
| 59 |
+
status_code=400,
|
| 60 |
+
detail=f"Unknown task_id '{task_id}'. Available: {list(TASK_ENV_MAP)}",
|
| 61 |
+
)
|
| 62 |
+
return cls()
|
|
|
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 66 |
+
# Request bodies
|
| 67 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
|
| 69 |
class ResetRequest(BaseModel):
|
| 70 |
task_id: str = "task1_vuln_detection"
|
|
|
|
| 76 |
params: dict = {}
|
| 77 |
|
| 78 |
|
| 79 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
# Routes
|
| 81 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
|
| 83 |
@app.get("/health")
|
| 84 |
def health():
|
| 85 |
+
"""Liveness probe."""
|
| 86 |
+
return {"status": "ok", "version": "1.1.0"}
|
| 87 |
|
| 88 |
|
| 89 |
@app.get("/tasks")
|
| 90 |
def list_tasks():
|
| 91 |
+
"""List all tasks with their status."""
|
| 92 |
tasks = [
|
| 93 |
TaskInfo(
|
| 94 |
task_id="task1_vuln_detection",
|
| 95 |
name="Targeted Vulnerability Detection",
|
| 96 |
difficulty="medium",
|
| 97 |
+
description="Given a Solidity contract, identify the vulnerable function and describe the vulnerability type in 2-3 words.",
|
|
|
|
|
|
|
|
|
|
| 98 |
status="active",
|
| 99 |
),
|
| 100 |
TaskInfo(
|
| 101 |
task_id="task2_property_discovery",
|
| 102 |
name="Property Discovery",
|
| 103 |
difficulty="hard",
|
| 104 |
+
description="Given a Solidity function, write the natural-language property that describes its correct behaviour.",
|
| 105 |
+
status="active",
|
|
|
|
|
|
|
|
|
|
| 106 |
),
|
| 107 |
TaskInfo(
|
| 108 |
task_id="task3_rule_checker",
|
| 109 |
name="Rule Checker",
|
| 110 |
difficulty="easy",
|
| 111 |
+
description="Given a property in English and a Solidity contract, identify which function violates that property.",
|
|
|
|
|
|
|
|
|
|
| 112 |
status="placeholder",
|
| 113 |
),
|
| 114 |
]
|
|
|
|
| 132 |
body: StepRequest,
|
| 133 |
session_id: str = Query(default=DEFAULT_SESSION),
|
| 134 |
):
|
| 135 |
+
"""Apply one action and advance the episode."""
|
| 136 |
env = _sessions.get(session_id)
|
| 137 |
if env is None:
|
| 138 |
raise HTTPException(
|
|
|
|
| 144 |
except ValueError:
|
| 145 |
raise HTTPException(
|
| 146 |
status_code=400,
|
| 147 |
+
detail=f"Unknown action_type '{body.action_type}'. Valid: {[a.value for a in ActionType]}",
|
|
|
|
| 148 |
)
|
| 149 |
action = Action(action_type=action_type, params=body.params)
|
| 150 |
try:
|
|
|
|
| 156 |
|
| 157 |
@app.get("/state")
|
| 158 |
def state(session_id: str = Query(default=DEFAULT_SESSION)):
|
| 159 |
+
"""Return internal state for debugging (not for agents)."""
|
| 160 |
env = _sessions.get(session_id)
|
| 161 |
if env is None:
|
| 162 |
raise HTTPException(
|
|
|
|
| 173 |
return {
|
| 174 |
"task_id": task_id,
|
| 175 |
"actions": [
|
| 176 |
+
{"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
|
| 177 |
+
{"type": "get_function_code", "params": {"function_name": "string"}, "reward": "+0.05 (target) / -0.10 (other)", "description": "Get full Solidity source of a function"},
|
| 178 |
+
{"type": "get_function_summary", "params": {"function_name": "string"}, "reward": "+0.03 (target) / -0.05 (other)", "description": "Get NatSpec comment of a function"},
|
| 179 |
+
{"type": "get_file_metadata", "params": {}, "reward": -0.04, "description": "Get contract-level metadata"},
|
| 180 |
+
{"type": "get_state_variable", "params": {"variable_name": "string (optional)"}, "reward": -0.05, "description": "Get a state variable or list all"},
|
| 181 |
+
{"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
|
| 182 |
+
{"type": "submit", "params": {"function_name": "str", "vulnerability_type": "str"},"reward": "+5.0 / +1.0 / -1.5", "description": "Submit answer. Ends episode."},
|
| 183 |
+
],
|
| 184 |
+
}
|
| 185 |
+
if task_id == "task2_property_discovery":
|
| 186 |
+
return {
|
| 187 |
+
"task_id": task_id,
|
| 188 |
+
"actions": [
|
| 189 |
+
{"type": "get_function_code", "params": {}, "reward": -0.06, "description": "Read full source of the target function"},
|
| 190 |
+
{"type": "get_function_natspec", "params": {}, "reward": -0.08, "description": "Read NatSpec + expected behaviour"},
|
| 191 |
+
{"type": "get_file_natspec", "params": {}, "reward": -0.03, "description": "Read contract-level NatSpec"},
|
| 192 |
+
{"type": "get_related_functions", "params": {}, "reward": -0.06, "description": "List caller/callee functions with summaries"},
|
| 193 |
+
{"type": "get_io", "params": {}, "reward": -0.04, "description": "Get structured I/O + expected behaviour"},
|
| 194 |
+
{"type": "get_similar_rule", "params": {}, "reward": -0.20, "description": "Get a similar property from another contract"},
|
| 195 |
+
{"type": "submit_property", "params": {"property": "string"}, "reward": "0.0β5.0 (scored)", "description": "Submit property. ONE attempt. Ends episode."},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
],
|
| 197 |
}
|
| 198 |
return {"error": f"No action space defined for task '{task_id}'"}
|
|
|
|
| 200 |
|
| 201 |
@app.get("/observation_space")
|
| 202 |
def observation_space():
|
| 203 |
+
"""Describe the observation space (same for all tasks)."""
|
| 204 |
return {
|
| 205 |
"type": "object",
|
| 206 |
"fields": {
|
| 207 |
+
"task_id": "string β active task identifier",
|
| 208 |
+
"contract_name": "string β Solidity contract name",
|
| 209 |
"contract_description": "string β what the contract does",
|
| 210 |
+
"available_actions": "list[string] β valid action types for this task",
|
| 211 |
+
"last_action": "string|null β previous action type",
|
| 212 |
+
"last_action_result": "string|null β human-readable result of last action",
|
| 213 |
+
"step_count": "int β steps taken in this episode",
|
| 214 |
+
"cumulative_reward": "float β running reward total",
|
| 215 |
+
"done": "bool β True when episode is over",
|
| 216 |
+
"extra": "object β task-specific hints (target_function, hint, etc.)",
|
| 217 |
},
|
| 218 |
}
|
| 219 |
|
| 220 |
|
| 221 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 222 |
# Entry point
|
| 223 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 224 |
|
| 225 |
if __name__ == "__main__":
|
| 226 |
import uvicorn
|
data/data_loader.py
CHANGED
|
@@ -2,8 +2,9 @@
|
|
| 2 |
data_loader.py
|
| 3 |
--------------
|
| 4 |
Loads and indexes smart contract data from JSON files.
|
| 5 |
-
|
| 6 |
-
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import json
|
|
@@ -16,25 +17,62 @@ DATA_DIR = os.path.join(os.path.dirname(__file__))
|
|
| 16 |
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
|
| 17 |
DEFAULT_VUNERABILITIES_FILE = os.path.join(DATA_DIR, "vulnerabilities.json")
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
|
| 20 |
"""Load and return all contracts from the JSON dataset."""
|
| 21 |
with open(path, "r") as f:
|
| 22 |
return json.load(f)
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def load_vulnerabilities(path: str = DEFAULT_VUNERABILITIES_FILE) -> List[Dict[str, Any]]:
|
| 26 |
"""Load and return all vulnerability entries from the JSON dataset."""
|
| 27 |
with open(path, "r") as f:
|
| 28 |
return json.load(f)
|
| 29 |
|
| 30 |
-
|
| 31 |
def get_all_vulnerable_entries(
|
| 32 |
contracts: List[Dict[str, Any]],
|
| 33 |
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
| 34 |
"""
|
| 35 |
Returns a flat list of (contract, function) pairs where
|
| 36 |
function['vulnerable'] is True.
|
| 37 |
-
Used by Task 1 to populate the episode pool.
|
| 38 |
"""
|
| 39 |
entries = []
|
| 40 |
for contract in contracts:
|
|
@@ -48,10 +86,7 @@ def sample_episode(
|
|
| 48 |
contracts: List[Dict[str, Any]],
|
| 49 |
rng: Optional[random.Random] = None,
|
| 50 |
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 51 |
-
"""
|
| 52 |
-
Randomly selects one (contract, vulnerable_function) pair.
|
| 53 |
-
Returns the contract dict and the target function dict.
|
| 54 |
-
"""
|
| 55 |
if rng is None:
|
| 56 |
rng = random.Random()
|
| 57 |
entries = get_all_vulnerable_entries(contracts)
|
|
@@ -60,31 +95,101 @@ def sample_episode(
|
|
| 60 |
return rng.choice(entries)
|
| 61 |
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
"""Case-insensitive function lookup within a contract."""
|
| 67 |
-
for fn in contract.get("functions", []):
|
| 68 |
-
if fn["name"].lower() == name.lower():
|
| 69 |
-
return fn
|
| 70 |
-
return None
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
def get_state_variable_by_name(
|
| 74 |
-
contract: Dict[str, Any], name: str
|
| 75 |
-
) -> Optional[Dict[str, Any]]:
|
| 76 |
-
"""Case-insensitive state variable lookup."""
|
| 77 |
-
for sv in contract.get("state_variables", []):
|
| 78 |
-
if sv["name"].lower() == name.lower():
|
| 79 |
-
return sv
|
| 80 |
-
return None
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
def list_function_names(contract: Dict[str, Any]) -> List[str]:
|
| 84 |
-
"""Return all function names in the contract."""
|
| 85 |
-
return [fn["name"] for fn in contract.get("functions", [])]
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
data_loader.py
|
| 3 |
--------------
|
| 4 |
Loads and indexes smart contract data from JSON files.
|
| 5 |
+
|
| 6 |
+
Task 1 helpers β vulnerable function sampling
|
| 7 |
+
Task 2 helpers β property function sampling, natspec, similar-rule lookup
|
| 8 |
"""
|
| 9 |
|
| 10 |
import json
|
|
|
|
| 17 |
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
|
| 18 |
DEFAULT_VUNERABILITIES_FILE = os.path.join(DATA_DIR, "vulnerabilities.json")
|
| 19 |
|
| 20 |
+
|
| 21 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
# Core loaders
|
| 23 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
|
| 25 |
def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
|
| 26 |
"""Load and return all contracts from the JSON dataset."""
|
| 27 |
with open(path, "r") as f:
|
| 28 |
return json.load(f)
|
| 29 |
|
| 30 |
|
| 31 |
+
def get_function_by_name(
|
| 32 |
+
contract: Dict[str, Any], name: str
|
| 33 |
+
) -> Optional[Dict[str, Any]]:
|
| 34 |
+
"""Case-insensitive function lookup within a contract."""
|
| 35 |
+
for fn in contract.get("functions", []):
|
| 36 |
+
if fn["name"].lower() == name.lower():
|
| 37 |
+
return fn
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_state_variable_by_name(
|
| 42 |
+
contract: Dict[str, Any], name: str
|
| 43 |
+
) -> Optional[Dict[str, Any]]:
|
| 44 |
+
"""Case-insensitive state variable lookup."""
|
| 45 |
+
for sv in contract.get("state_variables", []):
|
| 46 |
+
if sv["name"].lower() == name.lower():
|
| 47 |
+
return sv
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def list_function_names(contract: Dict[str, Any]) -> List[str]:
|
| 52 |
+
"""Return all function names in the contract."""
|
| 53 |
+
return [fn["name"] for fn in contract.get("functions", [])]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
|
| 57 |
+
"""Return all state variable names."""
|
| 58 |
+
return [sv["name"] for sv in contract.get("state_variables", [])]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
# Task 1 helpers
|
| 63 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
+
|
| 65 |
def load_vulnerabilities(path: str = DEFAULT_VUNERABILITIES_FILE) -> List[Dict[str, Any]]:
|
| 66 |
"""Load and return all vulnerability entries from the JSON dataset."""
|
| 67 |
with open(path, "r") as f:
|
| 68 |
return json.load(f)
|
| 69 |
|
|
|
|
| 70 |
def get_all_vulnerable_entries(
|
| 71 |
contracts: List[Dict[str, Any]],
|
| 72 |
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
| 73 |
"""
|
| 74 |
Returns a flat list of (contract, function) pairs where
|
| 75 |
function['vulnerable'] is True.
|
|
|
|
| 76 |
"""
|
| 77 |
entries = []
|
| 78 |
for contract in contracts:
|
|
|
|
| 86 |
contracts: List[Dict[str, Any]],
|
| 87 |
rng: Optional[random.Random] = None,
|
| 88 |
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 89 |
+
"""Randomly selects one (contract, vulnerable_function) pair for Task 1."""
|
|
|
|
|
|
|
|
|
|
| 90 |
if rng is None:
|
| 91 |
rng = random.Random()
|
| 92 |
entries = get_all_vulnerable_entries(contracts)
|
|
|
|
| 95 |
return rng.choice(entries)
|
| 96 |
|
| 97 |
|
| 98 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 99 |
+
# Task 2 helpers
|
| 100 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
def get_all_property_entries(
|
| 103 |
+
contracts: List[Dict[str, Any]],
|
| 104 |
+
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
| 105 |
+
"""
|
| 106 |
+
Returns a flat list of (contract, function) pairs where
|
| 107 |
+
function['property'] is not None.
|
| 108 |
+
Used by Task 2 to populate the episode pool.
|
| 109 |
+
"""
|
| 110 |
+
entries = []
|
| 111 |
+
for contract in contracts:
|
| 112 |
+
for fn in contract.get("functions", []):
|
| 113 |
+
if fn.get("property") is not None:
|
| 114 |
+
entries.append((contract, fn))
|
| 115 |
+
return entries
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
def sample_property_episode(
|
| 119 |
+
contracts: List[Dict[str, Any]],
|
| 120 |
+
rng: Optional[random.Random] = None,
|
| 121 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 122 |
+
"""Randomly selects one (contract, function-with-property) pair for Task 2."""
|
| 123 |
+
if rng is None:
|
| 124 |
+
rng = random.Random()
|
| 125 |
+
entries = get_all_property_entries(contracts)
|
| 126 |
+
if not entries:
|
| 127 |
+
raise ValueError("No functions with properties found in dataset.")
|
| 128 |
+
return rng.choice(entries)
|
| 129 |
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
def get_related_functions(
|
| 132 |
+
contract: Dict[str, Any],
|
| 133 |
+
function_name: str,
|
| 134 |
+
) -> List[str]:
|
| 135 |
+
"""
|
| 136 |
+
Returns function names that are related to the given function:
|
| 137 |
+
- Functions that it calls (from call_graph)
|
| 138 |
+
- Functions that call it (reverse call_graph lookup)
|
| 139 |
+
"""
|
| 140 |
+
name_lower = function_name.lower()
|
| 141 |
+
cg: Dict[str, List[str]] = contract.get("call_graph", {})
|
| 142 |
+
related = set()
|
| 143 |
|
| 144 |
+
# Direct callees (functions called by this function)
|
| 145 |
+
for callee in cg.get(function_name, []):
|
| 146 |
+
# Only include callees that are also functions in this contract
|
| 147 |
+
if get_function_by_name(contract, callee) is not None:
|
| 148 |
+
related.add(callee)
|
| 149 |
+
|
| 150 |
+
# Reverse: functions that call this function
|
| 151 |
+
for caller_name, callees in cg.items():
|
| 152 |
+
if any(c.lower() == name_lower for c in callees):
|
| 153 |
+
if get_function_by_name(contract, caller_name) is not None:
|
| 154 |
+
related.add(caller_name)
|
| 155 |
+
|
| 156 |
+
return sorted(related)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def get_similar_rule(
|
| 160 |
+
contracts: List[Dict[str, Any]],
|
| 161 |
+
current_contract_name: str,
|
| 162 |
+
current_function_name: str,
|
| 163 |
+
) -> Optional[Dict[str, Any]]:
|
| 164 |
+
"""
|
| 165 |
+
Returns the similar_rule hint stored in the target function's property field,
|
| 166 |
+
enriched with the referenced function's natspec if available.
|
| 167 |
+
|
| 168 |
+
Returns a dict with keys: contract_name, function_name, property_hint, natspec.
|
| 169 |
+
Returns None if no similar_rule is defined.
|
| 170 |
+
"""
|
| 171 |
+
# Find target function
|
| 172 |
+
for contract in contracts:
|
| 173 |
+
if contract["contract_name"] == current_contract_name:
|
| 174 |
+
fn = get_function_by_name(contract, current_function_name)
|
| 175 |
+
if fn and fn.get("property") and fn["property"].get("similar_rule"):
|
| 176 |
+
sr = fn["property"]["similar_rule"]
|
| 177 |
+
# Look up the referenced function's natspec
|
| 178 |
+
for c2 in contracts:
|
| 179 |
+
if c2["contract_name"] == sr["contract_name"]:
|
| 180 |
+
ref_fn = get_function_by_name(c2, sr["function_name"])
|
| 181 |
+
if ref_fn:
|
| 182 |
+
return {
|
| 183 |
+
"contract_name": sr["contract_name"],
|
| 184 |
+
"function_name": sr["function_name"],
|
| 185 |
+
"property_hint": sr["property_hint"],
|
| 186 |
+
"natspec": ref_fn.get("natspec", ""),
|
| 187 |
+
}
|
| 188 |
+
# Referenced function not found β return hint only
|
| 189 |
+
return {
|
| 190 |
+
"contract_name": sr["contract_name"],
|
| 191 |
+
"function_name": sr["function_name"],
|
| 192 |
+
"property_hint": sr["property_hint"],
|
| 193 |
+
"natspec": "",
|
| 194 |
+
}
|
| 195 |
+
return None
|
demo.py
CHANGED
|
@@ -237,12 +237,18 @@ def _print_episode_summary(obs):
|
|
| 237 |
print(f" Steps taken : {obs.step_count}")
|
| 238 |
print(f" Total reward : {colour}{reward:+.2f}{RESET}")
|
| 239 |
last = obs.last_action_result or ""
|
| 240 |
-
if "β
" in last:
|
| 241 |
print(f" {GREEN}Perfect score β full marks!{RESET}")
|
| 242 |
-
elif "β οΈ" in last:
|
| 243 |
-
print(f" {YELLOW}Partial credit
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
print(f" {RED}Incorrect β better luck next episode.{RESET}")
|
|
|
|
|
|
|
| 246 |
print(f"{BOLD}{'β' * 64}{RESET}\n")
|
| 247 |
|
| 248 |
|
|
@@ -285,3 +291,67 @@ def main():
|
|
| 285 |
|
| 286 |
if __name__ == "__main__":
|
| 287 |
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
print(f" Steps taken : {obs.step_count}")
|
| 238 |
print(f" Total reward : {colour}{reward:+.2f}{RESET}")
|
| 239 |
last = obs.last_action_result or ""
|
| 240 |
+
if "β
CORRECT" in last or "EXCELLENT" in last:
|
| 241 |
print(f" {GREEN}Perfect score β full marks!{RESET}")
|
| 242 |
+
elif "β οΈ" in last or "PARTIAL" in last:
|
| 243 |
+
print(f" {YELLOW}Partial credit.{RESET}")
|
| 244 |
+
elif "π‘ GOOD" in last:
|
| 245 |
+
print(f" {YELLOW}Good β most key concepts matched!{RESET}")
|
| 246 |
+
elif "π " in last:
|
| 247 |
+
print(f" {YELLOW}Partial β some key concepts matched.{RESET}")
|
| 248 |
+
elif "β" in last:
|
| 249 |
print(f" {RED}Incorrect β better luck next episode.{RESET}")
|
| 250 |
+
else:
|
| 251 |
+
print(f" {'Good effort!' if reward > 0 else 'Keep exploring next time.'}")
|
| 252 |
print(f"{BOLD}{'β' * 64}{RESET}\n")
|
| 253 |
|
| 254 |
|
|
|
|
| 291 |
|
| 292 |
if __name__ == "__main__":
|
| 293 |
main()
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 297 |
+
# Task 2 demo
|
| 298 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 299 |
+
|
| 300 |
+
DEMO_SCRIPTS_T2 = {
|
| 301 |
+
42: [
|
| 302 |
+
(ActionType.GET_FUNCTION_NATSPEC, {}, "First, read the NatSpec to understand intent and expected outputs."),
|
| 303 |
+
(ActionType.GET_IO, {}, "Check parameters, return type and expected behaviour."),
|
| 304 |
+
(ActionType.GET_FUNCTION_CODE, {}, "Read the actual Solidity code to confirm the behaviour."),
|
| 305 |
+
(ActionType.SUBMIT_PROPERTY,
|
| 306 |
+
{"property": "After a successful claimRewards call, all of the caller's accrued reward tokens are transferred to the caller and their rewards balance is set to zero. Reverts if the caller has no accrued rewards."},
|
| 307 |
+
"Confident about the property. Submitting!"),
|
| 308 |
+
],
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def run_auto_demo_t2(seed: int = 42, delay: float = 0.9):
|
| 313 |
+
"""Run the scripted Task 2 demo."""
|
| 314 |
+
from tasks.task2.environment import Task2Environment
|
| 315 |
+
|
| 316 |
+
script = DEMO_SCRIPTS_T2.get(seed)
|
| 317 |
+
env = Task2Environment()
|
| 318 |
+
result = env.reset(seed=seed)
|
| 319 |
+
obs = result.observation
|
| 320 |
+
|
| 321 |
+
print()
|
| 322 |
+
print(f"{BOLD}{CYAN}ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 323 |
+
print(f"β Smart Contract Audit RL Env Β· Task 2 Demo β")
|
| 324 |
+
print(f"ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ{RESET}")
|
| 325 |
+
print()
|
| 326 |
+
print(f"{BOLD}Mode:{RESET} Automated demo | {BOLD}Seed:{RESET} {seed}")
|
| 327 |
+
print(f"{BOLD}Task:{RESET} Property Discovery")
|
| 328 |
+
print()
|
| 329 |
+
|
| 330 |
+
fn_name = obs.extra.get("target_function", "?")
|
| 331 |
+
sig = obs.extra.get("target_signature", "")
|
| 332 |
+
print(f"{BOLD}Contract :{RESET} {obs.contract_name}")
|
| 333 |
+
print(f"{BOLD}Function :{RESET} {fn_name} ({sig})")
|
| 334 |
+
print(f"{BOLD}Goal :{RESET} Write the natural-language property for '{fn_name}'")
|
| 335 |
+
print(DIVIDER)
|
| 336 |
+
|
| 337 |
+
if not script:
|
| 338 |
+
print(f"{YELLOW}No pre-written script for seed {seed}.{RESET}")
|
| 339 |
+
return
|
| 340 |
+
|
| 341 |
+
for at, params, commentary in script:
|
| 342 |
+
time.sleep(delay)
|
| 343 |
+
print(f"\n{CYAN}βΆ Agent thinking:{RESET} {commentary}")
|
| 344 |
+
time.sleep(delay * 0.5)
|
| 345 |
+
step_result = env.step(Action(action_type=at, params=params))
|
| 346 |
+
sobs = step_result.observation
|
| 347 |
+
print(DIVIDER)
|
| 348 |
+
print(f"{BOLD}Step {sobs.step_count:2d}{RESET} [{at.value}] r={step_result.reward.value:+.2f} cum={sobs.cumulative_reward:+.2f}")
|
| 349 |
+
result_text = sobs.last_action_result or ""
|
| 350 |
+
colour = GREEN if step_result.reward.value > 0 else (YELLOW if step_result.reward.value > -0.15 else YELLOW)
|
| 351 |
+
for line in result_text.split("\n")[:8]:
|
| 352 |
+
print(f" {colour}{line[:90]}{RESET}")
|
| 353 |
+
print(DIVIDER)
|
| 354 |
+
|
| 355 |
+
if step_result.done:
|
| 356 |
+
_print_episode_summary(sobs)
|
| 357 |
+
return
|
env/schemas.py
CHANGED
|
@@ -3,12 +3,12 @@ schemas.py
|
|
| 3 |
----------
|
| 4 |
Typed Pydantic models implementing the OpenEnv interface spec.
|
| 5 |
|
| 6 |
-
Observation
|
| 7 |
-
Action
|
| 8 |
-
StepResult
|
| 9 |
-
ResetResult
|
| 10 |
-
StateResult
|
| 11 |
-
Reward
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
|
@@ -24,27 +24,28 @@ from pydantic import BaseModel, Field
|
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
|
| 26 |
class ActionType(str, Enum):
|
| 27 |
-
# Task 1 β Vulnerability Detection
|
| 28 |
-
LIST_FUNCTIONS
|
| 29 |
-
GET_FUNCTION_CODE
|
| 30 |
GET_FUNCTION_SUMMARY = "get_function_summary"
|
| 31 |
-
GET_FILE_METADATA
|
| 32 |
-
GET_STATE_VARIABLE
|
| 33 |
-
GET_CALL_GRAPH
|
| 34 |
-
SUBMIT
|
| 35 |
-
|
| 36 |
-
#
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
#
|
|
|
|
| 45 |
# GET_FORMALIZED_PROPERTY = "get_formalized_property"
|
| 46 |
-
# GET_FUNCTION_METADATA
|
| 47 |
-
# SUBMIT_FUNCTION
|
| 48 |
|
| 49 |
|
| 50 |
class Action(BaseModel):
|
|
@@ -54,7 +55,7 @@ class Action(BaseModel):
|
|
| 54 |
action_type : one of ActionType enum values
|
| 55 |
params : optional key/value arguments, e.g.
|
| 56 |
{"function_name": "withdraw"} for GET_FUNCTION_CODE
|
| 57 |
-
{"
|
| 58 |
"""
|
| 59 |
action_type: ActionType
|
| 60 |
params: Dict[str, Any] = Field(default_factory=dict)
|
|
@@ -71,16 +72,16 @@ class Observation(BaseModel):
|
|
| 71 |
"""
|
| 72 |
What the agent receives from the environment.
|
| 73 |
|
| 74 |
-
task_id
|
| 75 |
-
contract_name
|
| 76 |
contract_description : high-level description of what the contract does
|
| 77 |
-
available_actions
|
| 78 |
-
last_action
|
| 79 |
-
last_action_result: human-readable result of the last action
|
| 80 |
-
step_count
|
| 81 |
-
cumulative_reward
|
| 82 |
-
done
|
| 83 |
-
extra
|
| 84 |
"""
|
| 85 |
task_id: str
|
| 86 |
contract_name: str
|
|
@@ -147,4 +148,4 @@ class TaskInfo(BaseModel):
|
|
| 147 |
name: str
|
| 148 |
difficulty: str
|
| 149 |
description: str
|
| 150 |
-
status: str = "active" #
|
|
|
|
| 3 |
----------
|
| 4 |
Typed Pydantic models implementing the OpenEnv interface spec.
|
| 5 |
|
| 6 |
+
Observation β what the agent sees at each step
|
| 7 |
+
Action β what the agent can send
|
| 8 |
+
StepResult β returned by step()
|
| 9 |
+
ResetResult β returned by reset()
|
| 10 |
+
StateResult β returned by state()
|
| 11 |
+
Reward β structured reward info
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
|
|
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
|
| 26 |
class ActionType(str, Enum):
|
| 27 |
+
# ββ Task 1 β Vulnerability Detection βββββββββββββββββββββββββββββββββββ
|
| 28 |
+
LIST_FUNCTIONS = "list_functions"
|
| 29 |
+
GET_FUNCTION_CODE = "get_function_code"
|
| 30 |
GET_FUNCTION_SUMMARY = "get_function_summary"
|
| 31 |
+
GET_FILE_METADATA = "get_file_metadata"
|
| 32 |
+
GET_STATE_VARIABLE = "get_state_variable"
|
| 33 |
+
GET_CALL_GRAPH = "get_call_graph"
|
| 34 |
+
SUBMIT = "submit"
|
| 35 |
+
|
| 36 |
+
# ββ Task 2 β Property Discovery βββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
GET_SIMILAR_RULE = "get_similar_rule" # -0.20
|
| 38 |
+
GET_FILE_NATSPEC = "get_file_natspec" # -0.03
|
| 39 |
+
GET_FUNCTION_NATSPEC = "get_function_natspec" # -0.08
|
| 40 |
+
GET_RELATED_FUNCTIONS = "get_related_functions" # -0.06
|
| 41 |
+
GET_IO = "get_io" # -0.04
|
| 42 |
+
SUBMIT_PROPERTY = "submit_property" # scored 0β5, one attempt
|
| 43 |
+
|
| 44 |
+
# ββ Task 3 β Rule Checker ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
# TODO: Task 3
|
| 46 |
# GET_FORMALIZED_PROPERTY = "get_formalized_property"
|
| 47 |
+
# GET_FUNCTION_METADATA = "get_function_metadata"
|
| 48 |
+
# SUBMIT_FUNCTION = "submit_function"
|
| 49 |
|
| 50 |
|
| 51 |
class Action(BaseModel):
|
|
|
|
| 55 |
action_type : one of ActionType enum values
|
| 56 |
params : optional key/value arguments, e.g.
|
| 57 |
{"function_name": "withdraw"} for GET_FUNCTION_CODE
|
| 58 |
+
{"property": "..."} for SUBMIT_PROPERTY
|
| 59 |
"""
|
| 60 |
action_type: ActionType
|
| 61 |
params: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
| 72 |
"""
|
| 73 |
What the agent receives from the environment.
|
| 74 |
|
| 75 |
+
task_id : which task is active
|
| 76 |
+
contract_name : name of the Solidity contract
|
| 77 |
contract_description : high-level description of what the contract does
|
| 78 |
+
available_actions : list of valid ActionType strings
|
| 79 |
+
last_action : the action that produced this observation (None on reset)
|
| 80 |
+
last_action_result : human-readable result of the last action
|
| 81 |
+
step_count : number of steps taken so far
|
| 82 |
+
cumulative_reward : running reward total
|
| 83 |
+
done : whether the episode has ended
|
| 84 |
+
extra : any additional task-specific context
|
| 85 |
"""
|
| 86 |
task_id: str
|
| 87 |
contract_name: str
|
|
|
|
| 148 |
name: str
|
| 149 |
difficulty: str
|
| 150 |
description: str
|
| 151 |
+
status: str = "active" # "active" | "placeholder"
|
eval.py
CHANGED
|
@@ -3,264 +3,276 @@ eval.py
|
|
| 3 |
-------
|
| 4 |
Evaluation harness for the Smart Contract Audit RL Environment.
|
| 5 |
|
| 6 |
-
Runs
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
Unlike inference.py (which uses an external LLM), this evaluates the
|
| 10 |
-
*environment itself* using a built-in oracle agent β useful for:
|
| 11 |
-
- Verifying grader correctness
|
| 12 |
-
- Benchmarking reward shaping
|
| 13 |
-
- Checking score distribution across vulnerability types
|
| 14 |
|
| 15 |
Usage:
|
| 16 |
-
python eval.py
|
| 17 |
-
python eval.py --
|
| 18 |
-
python eval.py --
|
| 19 |
-
python eval.py --
|
|
|
|
|
|
|
| 20 |
"""
|
| 21 |
|
| 22 |
import argparse
|
| 23 |
import json
|
| 24 |
import sys
|
| 25 |
-
import time
|
| 26 |
from typing import Any, Dict, List
|
| 27 |
|
| 28 |
from tasks.task1.environment import Task1Environment
|
|
|
|
| 29 |
from env.schemas import Action, ActionType
|
| 30 |
-
from data.data_loader import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
-
#
|
| 35 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
|
| 37 |
-
def
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
This gives an upper-bound score trajectory for the environment.
|
| 45 |
-
Always ends with grader_score = 1.0.
|
| 46 |
-
"""
|
| 47 |
-
reset_result = env.reset(seed=seed)
|
| 48 |
-
obs = reset_result.observation
|
| 49 |
-
|
| 50 |
-
steps_taken: List[Dict[str, Any]] = []
|
| 51 |
-
|
| 52 |
-
def _step(at: ActionType, params: dict = None) -> Any:
|
| 53 |
-
params = params or {}
|
| 54 |
-
action = Action(action_type=at, params=params)
|
| 55 |
-
result = env.step(action)
|
| 56 |
-
entry = {
|
| 57 |
-
"step": result.observation.step_count,
|
| 58 |
-
"action": at.value,
|
| 59 |
-
"params": params,
|
| 60 |
-
"reward": result.reward.value,
|
| 61 |
-
"reason": result.reward.reason,
|
| 62 |
-
"cumulative": result.observation.cumulative_reward,
|
| 63 |
-
"done": result.done,
|
| 64 |
-
}
|
| 65 |
-
steps_taken.append(entry)
|
| 66 |
-
if verbose:
|
| 67 |
-
done_flag = " [DONE]" if result.done else ""
|
| 68 |
-
print(
|
| 69 |
-
f" step {entry['step']:2d}: {at.value:25s} "
|
| 70 |
-
f"r={result.reward.value:+.2f} cum={entry['cumulative']:+.2f}"
|
| 71 |
-
f"{done_flag}"
|
| 72 |
-
)
|
| 73 |
-
return result
|
| 74 |
-
|
| 75 |
-
# Peek at ground truth (oracle only)
|
| 76 |
-
state = env.state()
|
| 77 |
-
target_fn = state.target_function
|
| 78 |
-
|
| 79 |
-
# Get ground-truth vulnerability from data
|
| 80 |
contracts = load_contracts()
|
| 81 |
-
vuln_issue =
|
| 82 |
-
for
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
vuln_issue = fn["vulnerability_details"]["issue"]
|
| 87 |
-
break
|
| 88 |
-
if vuln_issue:
|
| 89 |
break
|
| 90 |
|
| 91 |
if verbose:
|
| 92 |
-
print(f"
|
| 93 |
-
print(f" Target : {target_fn} ({vuln_issue})")
|
| 94 |
-
|
| 95 |
-
# Step 1: list functions (small cost, realistic)
|
| 96 |
-
_step(ActionType.LIST_FUNCTIONS)
|
| 97 |
-
# Step 2: read target function code (gets +0.05 shaping reward)
|
| 98 |
-
_step(ActionType.GET_FUNCTION_CODE, {"function_name": target_fn})
|
| 99 |
-
# Step 3: submit perfect answer
|
| 100 |
-
result = _step(ActionType.SUBMIT, {
|
| 101 |
-
"function_name": target_fn,
|
| 102 |
-
"vulnerability_type": vuln_issue,
|
| 103 |
-
})
|
| 104 |
-
|
| 105 |
-
final_reward = result.reward.value
|
| 106 |
-
if final_reward >= 4.9:
|
| 107 |
-
grader_score = 1.0
|
| 108 |
-
elif final_reward >= 0.9:
|
| 109 |
-
grader_score = 0.5
|
| 110 |
-
else:
|
| 111 |
-
grader_score = 0.0
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
return {
|
| 114 |
"seed": seed,
|
| 115 |
"contract": obs.contract_name,
|
| 116 |
-
"target_function":
|
| 117 |
"vulnerability": vuln_issue,
|
| 118 |
-
"grader_score":
|
| 119 |
"cumulative_reward": result.observation.cumulative_reward,
|
| 120 |
-
"steps": steps_taken,
|
| 121 |
-
"num_steps": len(steps_taken),
|
| 122 |
}
|
| 123 |
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
-
#
|
| 127 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
|
| 129 |
-
def
|
| 130 |
-
"""Submits
|
| 131 |
-
|
| 132 |
-
obs =
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
return {
|
| 142 |
"seed": seed,
|
| 143 |
-
"
|
|
|
|
|
|
|
| 144 |
"cumulative_reward": result.observation.cumulative_reward,
|
| 145 |
}
|
| 146 |
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
def
|
| 153 |
-
"""
|
| 154 |
env.reset(seed=seed)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
result = env.step(action)
|
| 160 |
-
return {
|
| 161 |
-
"seed": seed,
|
| 162 |
-
"grader_score": 0.0,
|
| 163 |
-
"cumulative_reward": result.observation.cumulative_reward,
|
| 164 |
-
}
|
| 165 |
|
| 166 |
|
| 167 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 168 |
-
# Evaluation
|
| 169 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 170 |
|
| 171 |
-
def
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
output_file: str = "eval_results.json",
|
| 176 |
-
) -> None:
|
| 177 |
-
env = Task1Environment()
|
| 178 |
contracts = load_contracts()
|
| 179 |
-
entries
|
| 180 |
-
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
print("
|
| 185 |
-
|
| 186 |
-
print(f" Seed range: {seed_offset} β {seed_offset + num_episodes - 1}")
|
| 187 |
-
print(f" Vulns in dataset: {len(entries)}")
|
| 188 |
-
print()
|
| 189 |
-
|
| 190 |
-
# ββ Oracle agent βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 191 |
-
print("βΆ Oracle agent (upper bound β always submits correct answer):")
|
| 192 |
-
oracle_episodes = []
|
| 193 |
-
for i in range(num_episodes):
|
| 194 |
-
seed = seed_offset + i
|
| 195 |
-
ep = oracle_agent(env, seed=seed, verbose=verbose)
|
| 196 |
-
oracle_episodes.append(ep)
|
| 197 |
-
icon = "β
" if ep["grader_score"] == 1.0 else "β οΈ "
|
| 198 |
-
print(
|
| 199 |
-
f" {icon} seed={seed:3d} {ep['contract']:12s} "
|
| 200 |
-
f"{ep['target_function']:15s} score={ep['grader_score']:.1f} "
|
| 201 |
-
f"reward={ep['cumulative_reward']:+.2f}"
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
oracle_avg = sum(e["grader_score"] for e in oracle_episodes) / num_episodes
|
| 205 |
-
oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_episodes) / num_episodes
|
| 206 |
-
print(f"\n Oracle avg grader score : {oracle_avg:.3f}")
|
| 207 |
-
print(f" Oracle avg reward : {oracle_avg_r:+.2f}")
|
| 208 |
-
|
| 209 |
-
# ββ Partial agent βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 210 |
-
print("\nβΆ Partial agent (right function, wrong vuln type β 0.5 each):")
|
| 211 |
-
partial_episodes = []
|
| 212 |
-
for i in range(num_episodes):
|
| 213 |
-
ep = partial_agent(env, seed=seed_offset + i)
|
| 214 |
-
partial_episodes.append(ep)
|
| 215 |
-
partial_avg = sum(e["grader_score"] for e in partial_episodes) / num_episodes
|
| 216 |
-
print(f" Partial avg grader score: {partial_avg:.3f}")
|
| 217 |
-
|
| 218 |
-
# ββ Random agent ββββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββ
|
| 219 |
-
print("\nβΆ Random agent (always wrong β 0.0 each):")
|
| 220 |
-
random_episodes = []
|
| 221 |
for i in range(num_episodes):
|
| 222 |
-
ep =
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
print("\n
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
v = ep.get("vulnerability", "unknown")
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
print("\n" + "=" * 64)
|
| 238 |
-
print("
|
| 239 |
print("=" * 64)
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
"
|
| 253 |
-
"
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
}
|
| 261 |
-
with open(output_file, "w") as f:
|
| 262 |
-
json.dump(report, f, indent=2)
|
| 263 |
-
print(f"\n Results written to {output_file}")
|
| 264 |
|
| 265 |
|
| 266 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -268,23 +280,50 @@ def run_evaluation(
|
|
| 268 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 269 |
|
| 270 |
def main():
|
| 271 |
-
parser = argparse.ArgumentParser(
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
parser.add_argument("--
|
| 275 |
-
help="
|
| 276 |
-
parser.add_argument("--
|
| 277 |
-
help="
|
| 278 |
-
parser.add_argument("--
|
| 279 |
-
help="
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
args = parser.parse_args()
|
| 281 |
|
| 282 |
-
|
| 283 |
-
num_episodes
|
| 284 |
-
seed_offset
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
|
| 290 |
if __name__ == "__main__":
|
|
|
|
| 3 |
-------
|
| 4 |
Evaluation harness for the Smart Contract Audit RL Environment.
|
| 5 |
|
| 6 |
+
Runs oracle / partial / baseline agents against Task 1 and Task 2,
|
| 7 |
+
verifying that grader scores form a clear ordering and that reward
|
| 8 |
+
shaping is meaningful.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
Usage:
|
| 11 |
+
python eval.py # Task 1 + Task 2, 8 episodes each
|
| 12 |
+
python eval.py --task 1 # Task 1 only
|
| 13 |
+
python eval.py --task 2 # Task 2 only
|
| 14 |
+
python eval.py --episodes 16 # more episodes
|
| 15 |
+
python eval.py --seed 0 --verbose # detailed per-step trace
|
| 16 |
+
python eval.py --out results.json # custom output file
|
| 17 |
"""
|
| 18 |
|
| 19 |
import argparse
|
| 20 |
import json
|
| 21 |
import sys
|
|
|
|
| 22 |
from typing import Any, Dict, List
|
| 23 |
|
| 24 |
from tasks.task1.environment import Task1Environment
|
| 25 |
+
from tasks.task2.environment import Task2Environment
|
| 26 |
from env.schemas import Action, ActionType
|
| 27 |
+
from data.data_loader import (
|
| 28 |
+
load_contracts,
|
| 29 |
+
get_function_by_name,
|
| 30 |
+
get_all_vulnerable_entries,
|
| 31 |
+
)
|
| 32 |
|
| 33 |
|
| 34 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
# Task 1 agents
|
| 36 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
|
| 38 |
+
def oracle_t1(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
|
| 39 |
+
"""Always submits the exact ground-truth answer β score = 1.0."""
|
| 40 |
+
r = env.reset(seed=seed)
|
| 41 |
+
obs = r.observation
|
| 42 |
+
st = env.state()
|
| 43 |
+
fn_name = st.target_function
|
| 44 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
contracts = load_contracts()
|
| 46 |
+
vuln_issue = ""
|
| 47 |
+
for c in contracts:
|
| 48 |
+
fn = get_function_by_name(c, fn_name)
|
| 49 |
+
if fn and fn.get("vulnerable"):
|
| 50 |
+
vuln_issue = fn["vulnerability_details"]["issue"]
|
|
|
|
|
|
|
|
|
|
| 51 |
break
|
| 52 |
|
| 53 |
if verbose:
|
| 54 |
+
print(f" {obs.contract_name}.{fn_name}() [{vuln_issue}]")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 57 |
+
env.step(Action(action_type=ActionType.GET_FUNCTION_CODE,
|
| 58 |
+
params={"function_name": fn_name}))
|
| 59 |
+
result = env.step(Action(action_type=ActionType.SUBMIT,
|
| 60 |
+
params={"function_name": fn_name,
|
| 61 |
+
"vulnerability_type": vuln_issue}))
|
| 62 |
+
|
| 63 |
+
v = result.reward.value
|
| 64 |
+
score = 1.0 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
|
| 65 |
return {
|
| 66 |
"seed": seed,
|
| 67 |
"contract": obs.contract_name,
|
| 68 |
+
"target_function": fn_name,
|
| 69 |
"vulnerability": vuln_issue,
|
| 70 |
+
"grader_score": score,
|
| 71 |
"cumulative_reward": result.observation.cumulative_reward,
|
|
|
|
|
|
|
| 72 |
}
|
| 73 |
|
| 74 |
|
| 75 |
+
def partial_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
|
| 76 |
+
"""Right function, wrong vuln type β score = 0.5."""
|
| 77 |
+
env.reset(seed=seed)
|
| 78 |
+
fn_name = env.state().target_function
|
| 79 |
+
result = env.step(Action(action_type=ActionType.SUBMIT,
|
| 80 |
+
params={"function_name": fn_name,
|
| 81 |
+
"vulnerability_type": "unknown"}))
|
| 82 |
+
v = result.reward.value
|
| 83 |
+
return {"seed": seed, "grader_score": 0.5 if v >= 0.9 else 0.0,
|
| 84 |
+
"cumulative_reward": result.observation.cumulative_reward}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def random_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
|
| 88 |
+
"""Always submits 'constructor' β score = 0.0."""
|
| 89 |
+
env.reset(seed=seed)
|
| 90 |
+
result = env.step(Action(action_type=ActionType.SUBMIT,
|
| 91 |
+
params={"function_name": "constructor",
|
| 92 |
+
"vulnerability_type": "reentrancy"}))
|
| 93 |
+
return {"seed": seed, "grader_score": 0.0,
|
| 94 |
+
"cumulative_reward": result.observation.cumulative_reward}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 98 |
+
# Task 2 agents
|
| 99 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 100 |
|
| 101 |
+
def oracle_t2(env: Task2Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
|
| 102 |
+
"""Submits the exact ground-truth natural_language β score β₯ 0.70."""
|
| 103 |
+
r = env.reset(seed=seed)
|
| 104 |
+
obs = r.observation
|
| 105 |
+
fn_name = obs.extra["target_function"]
|
| 106 |
+
contract = obs.contract_name
|
| 107 |
+
|
| 108 |
+
contracts = load_contracts()
|
| 109 |
+
gt_text = ""
|
| 110 |
+
for c in contracts:
|
| 111 |
+
if c["contract_name"] == contract:
|
| 112 |
+
fn = get_function_by_name(c, fn_name)
|
| 113 |
+
if fn and fn.get("property"):
|
| 114 |
+
gt_text = fn["property"]["natural_language"]
|
| 115 |
+
break
|
| 116 |
+
|
| 117 |
+
if verbose:
|
| 118 |
+
print(f" {contract}.{fn_name}()")
|
| 119 |
+
|
| 120 |
+
# read code first (realistic browsing step)
|
| 121 |
+
env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
|
| 122 |
+
result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
|
| 123 |
+
params={"property": gt_text}))
|
| 124 |
+
|
| 125 |
+
r_val = result.reward.value
|
| 126 |
+
score = round(r_val / 5.0, 4) if r_val > 0 else 0.0
|
| 127 |
return {
|
| 128 |
"seed": seed,
|
| 129 |
+
"contract": contract,
|
| 130 |
+
"function": fn_name,
|
| 131 |
+
"grader_score": score,
|
| 132 |
"cumulative_reward": result.observation.cumulative_reward,
|
| 133 |
}
|
| 134 |
|
| 135 |
|
| 136 |
+
def partial_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
|
| 137 |
+
"""Submits the function's NatSpec comment β partial credit."""
|
| 138 |
+
r = env.reset(seed=seed)
|
| 139 |
+
obs = r.observation
|
| 140 |
+
contracts = load_contracts()
|
| 141 |
+
comment = ""
|
| 142 |
+
for c in contracts:
|
| 143 |
+
if c["contract_name"] == obs.contract_name:
|
| 144 |
+
fn = get_function_by_name(c, obs.extra["target_function"])
|
| 145 |
+
if fn:
|
| 146 |
+
comment = fn.get("comment", "")
|
| 147 |
+
break
|
| 148 |
+
result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
|
| 149 |
+
params={"property": comment}))
|
| 150 |
+
r_val = result.reward.value
|
| 151 |
+
score = round(r_val / 5.0, 4) if r_val > 0 else 0.0
|
| 152 |
+
return {"seed": seed, "grader_score": score,
|
| 153 |
+
"cumulative_reward": result.observation.cumulative_reward}
|
| 154 |
+
|
| 155 |
|
| 156 |
+
def empty_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
|
| 157 |
+
"""Submits empty string β score = 0.0."""
|
| 158 |
env.reset(seed=seed)
|
| 159 |
+
result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
|
| 160 |
+
params={"property": ""}))
|
| 161 |
+
return {"seed": seed, "grader_score": 0.0,
|
| 162 |
+
"cumulative_reward": result.observation.cumulative_reward}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 166 |
+
# Evaluation runners
|
| 167 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 168 |
|
| 169 |
+
def run_task1_eval(num_episodes: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
|
| 170 |
+
print("\n" + "=" * 64)
|
| 171 |
+
print("TASK 1 β Targeted Vulnerability Detection")
|
| 172 |
+
print("=" * 64)
|
|
|
|
|
|
|
|
|
|
| 173 |
contracts = load_contracts()
|
| 174 |
+
entries = get_all_vulnerable_entries(contracts)
|
| 175 |
+
print(f" Dataset: {len(contracts)} contracts, {len(entries)} vulnerable functions\n")
|
| 176 |
|
| 177 |
+
env = Task1Environment()
|
| 178 |
+
|
| 179 |
+
print("βΆ Oracle agent (always submits correct answer):")
|
| 180 |
+
oracle_eps = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
for i in range(num_episodes):
|
| 182 |
+
ep = oracle_t1(env, seed_offset + i, verbose=verbose)
|
| 183 |
+
oracle_eps.append(ep)
|
| 184 |
+
print(f" seed={ep['seed']:3d} {ep['contract']:12s}.{ep['target_function']:18s}"
|
| 185 |
+
f" score={ep['grader_score']:.1f} reward={ep['cumulative_reward']:+.2f}")
|
| 186 |
+
oracle_avg = sum(e["grader_score"] for e in oracle_eps) / num_episodes
|
| 187 |
+
oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / num_episodes
|
| 188 |
+
print(f"\n Oracle avg score : {oracle_avg:.3f} avg reward: {oracle_avg_r:+.2f}")
|
| 189 |
+
|
| 190 |
+
print("\nβΆ Partial agent (right function, wrong vuln type β 0.5):")
|
| 191 |
+
partial_eps = [partial_t1(env, seed_offset + i) for i in range(num_episodes)]
|
| 192 |
+
partial_avg = sum(e["grader_score"] for e in partial_eps) / num_episodes
|
| 193 |
+
print(f" Partial avg score: {partial_avg:.3f}")
|
| 194 |
+
|
| 195 |
+
print("\nβΆ Random agent (always wrong β 0.0):")
|
| 196 |
+
random_eps = [random_t1(env, seed_offset + i) for i in range(num_episodes)]
|
| 197 |
+
random_avg = sum(e["grader_score"] for e in random_eps) / num_episodes
|
| 198 |
+
print(f" Random avg score : {random_avg:.3f}")
|
| 199 |
+
|
| 200 |
+
vuln_seen: Dict[str, int] = {}
|
| 201 |
+
for ep in oracle_eps:
|
| 202 |
v = ep.get("vulnerability", "unknown")
|
| 203 |
+
vuln_seen[v] = vuln_seen.get(v, 0) + 1
|
| 204 |
+
print("\nβΆ Vulnerability type coverage:")
|
| 205 |
+
for v in sorted(vuln_seen):
|
| 206 |
+
print(f" {vuln_seen[v]:2d}Γ {v}")
|
| 207 |
+
|
| 208 |
+
assert oracle_avg == 1.0, f"Oracle should be 1.0, got {oracle_avg}"
|
| 209 |
+
assert partial_avg == 0.5, f"Partial should be 0.5, got {partial_avg}"
|
| 210 |
+
assert random_avg == 0.0, f"Random should be 0.0, got {random_avg}"
|
| 211 |
+
print("\n β
Task 1 score ordering: oracle(1.0) > partial(0.5) > random(0.0)")
|
| 212 |
|
| 213 |
+
return {
|
| 214 |
+
"task_id": "task1_vuln_detection",
|
| 215 |
+
"oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
|
| 216 |
+
"partial": {"avg_score": partial_avg, "episodes": partial_eps},
|
| 217 |
+
"random": {"avg_score": random_avg, "episodes": random_eps},
|
| 218 |
+
"vuln_coverage": vuln_seen,
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def run_task2_eval(num_episodes: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
|
| 223 |
print("\n" + "=" * 64)
|
| 224 |
+
print("TASK 2 β Property Discovery")
|
| 225 |
print("=" * 64)
|
| 226 |
+
from data.data_loader import get_all_property_entries
|
| 227 |
+
contracts = load_contracts()
|
| 228 |
+
entries = get_all_property_entries(contracts)
|
| 229 |
+
print(f" Dataset: {len(entries)} functions with properties\n")
|
| 230 |
+
|
| 231 |
+
env = Task2Environment()
|
| 232 |
+
|
| 233 |
+
print("βΆ Oracle agent (submits ground-truth natural language):")
|
| 234 |
+
oracle_eps = []
|
| 235 |
+
for i in range(num_episodes):
|
| 236 |
+
ep = oracle_t2(env, seed_offset + i, verbose=verbose)
|
| 237 |
+
oracle_eps.append(ep)
|
| 238 |
+
icon = "β
" if ep["grader_score"] >= 0.65 else "β οΈ "
|
| 239 |
+
print(f" {icon} seed={ep['seed']:3d} {ep['contract']:12s}.{ep['function']:18s}"
|
| 240 |
+
f" score={ep['grader_score']:.3f} reward={ep['cumulative_reward']:+.2f}")
|
| 241 |
+
oracle_avg = sum(e["grader_score"] for e in oracle_eps) / num_episodes
|
| 242 |
+
oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / num_episodes
|
| 243 |
+
print(f"\n Oracle avg score : {oracle_avg:.3f} avg reward: {oracle_avg_r:+.2f}")
|
| 244 |
+
|
| 245 |
+
print("\nβΆ Partial agent (submits NatSpec comment β partial signal):")
|
| 246 |
+
partial_eps = [partial_t2(env, seed_offset + i) for i in range(num_episodes)]
|
| 247 |
+
partial_avg = sum(e["grader_score"] for e in partial_eps) / num_episodes
|
| 248 |
+
partial_avg_r = sum(e["cumulative_reward"] for e in partial_eps) / num_episodes
|
| 249 |
+
print(f" Partial avg score: {partial_avg:.3f} avg reward: {partial_avg_r:+.2f}")
|
| 250 |
+
|
| 251 |
+
print("\nβΆ Empty agent (submits nothing β 0.0):")
|
| 252 |
+
empty_eps = [empty_t2(env, seed_offset + i) for i in range(num_episodes)]
|
| 253 |
+
empty_avg = sum(e["grader_score"] for e in empty_eps) / num_episodes
|
| 254 |
+
print(f" Empty avg score : {empty_avg:.3f}")
|
| 255 |
+
|
| 256 |
+
fn_seen: Dict[str, int] = {}
|
| 257 |
+
for ep in oracle_eps:
|
| 258 |
+
fn_seen[ep["function"]] = fn_seen.get(ep["function"], 0) + 1
|
| 259 |
+
print("\nβΆ Function coverage:")
|
| 260 |
+
for fn in sorted(fn_seen):
|
| 261 |
+
print(f" {fn_seen[fn]:2d}Γ {fn}")
|
| 262 |
+
|
| 263 |
+
assert oracle_avg > 0.60, f"Oracle avg {oracle_avg:.3f} should be > 0.60"
|
| 264 |
+
assert oracle_avg > partial_avg, "Oracle should beat partial"
|
| 265 |
+
assert partial_avg >= empty_avg, "Partial should be >= empty"
|
| 266 |
+
assert empty_avg == 0.0, f"Empty should be 0.0, got {empty_avg}"
|
| 267 |
+
print(f"\n β
Task 2 score ordering: oracle({oracle_avg:.3f}) > partial({partial_avg:.3f}) > empty(0.0)")
|
| 268 |
+
|
| 269 |
+
return {
|
| 270 |
+
"task_id": "task2_property_discovery",
|
| 271 |
+
"oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
|
| 272 |
+
"partial": {"avg_score": partial_avg, "avg_reward": partial_avg_r, "episodes": partial_eps},
|
| 273 |
+
"empty": {"avg_score": empty_avg, "episodes": empty_eps},
|
| 274 |
+
"fn_coverage": fn_seen,
|
| 275 |
}
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 280 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 281 |
|
| 282 |
def main():
|
| 283 |
+
parser = argparse.ArgumentParser(
|
| 284 |
+
description="Evaluate Task 1 and/or Task 2 of the SC Audit RL Environment"
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument("--episodes", type=int, default=8,
|
| 287 |
+
help="Episodes per agent tier (default: 8)")
|
| 288 |
+
parser.add_argument("--seed", type=int, default=42,
|
| 289 |
+
help="Starting RNG seed (default: 42)")
|
| 290 |
+
parser.add_argument("--task", choices=["1", "2", "all"], default="all",
|
| 291 |
+
help="Which task(s) to evaluate (default: all)")
|
| 292 |
+
parser.add_argument("--verbose", action="store_true",
|
| 293 |
+
help="Print per-episode target details")
|
| 294 |
+
parser.add_argument("--out", default="eval_results.json",
|
| 295 |
+
help="Output file (default: eval_results.json)")
|
| 296 |
args = parser.parse_args()
|
| 297 |
|
| 298 |
+
report: Dict[str, Any] = {
|
| 299 |
+
"num_episodes": args.episodes,
|
| 300 |
+
"seed_offset": args.seed,
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
if args.task in ("1", "all"):
|
| 304 |
+
report["task1"] = run_task1_eval(args.episodes, args.seed, args.verbose)
|
| 305 |
+
|
| 306 |
+
if args.task in ("2", "all"):
|
| 307 |
+
report["task2"] = run_task2_eval(args.episodes, args.seed, args.verbose)
|
| 308 |
+
|
| 309 |
+
# ββ Summary ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 310 |
+
print("\n" + "=" * 64)
|
| 311 |
+
print("EVALUATION COMPLETE")
|
| 312 |
+
print("=" * 64)
|
| 313 |
+
if "task1" in report:
|
| 314 |
+
t1 = report["task1"]
|
| 315 |
+
print(f" Task 1 oracle={t1['oracle']['avg_score']:.3f} "
|
| 316 |
+
f"partial={t1['partial']['avg_score']:.3f} "
|
| 317 |
+
f"random={t1['random']['avg_score']:.3f}")
|
| 318 |
+
if "task2" in report:
|
| 319 |
+
t2 = report["task2"]
|
| 320 |
+
print(f" Task 2 oracle={t2['oracle']['avg_score']:.3f} "
|
| 321 |
+
f"partial={t2['partial']['avg_score']:.3f} "
|
| 322 |
+
f"empty={t2['empty']['avg_score']:.3f}")
|
| 323 |
+
|
| 324 |
+
with open(args.out, "w") as f:
|
| 325 |
+
json.dump(report, f, indent=2)
|
| 326 |
+
print(f"\n Results written to {args.out}")
|
| 327 |
|
| 328 |
|
| 329 |
if __name__ == "__main__":
|
inference.py
CHANGED
|
@@ -2,14 +2,13 @@
|
|
| 2 |
inference.py
|
| 3 |
------------
|
| 4 |
Baseline inference script for the Smart Contract Audit RL Environment.
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
API_BASE_URL β LLM endpoint (e.g. https://api.openai.com/v1)
|
| 11 |
-
MODEL_NAME β model identifier (e.g. gpt-4o-mini)
|
| 12 |
-
HF_TOKEN β API key (passed as Authorization: Bearer <HF_TOKEN>)
|
| 13 |
|
| 14 |
Usage:
|
| 15 |
python inference.py
|
|
@@ -17,306 +16,290 @@ Usage:
|
|
| 17 |
Output:
|
| 18 |
Per-task scores printed to stdout.
|
| 19 |
Final baseline scores written to baseline_scores.json.
|
|
|
|
|
|
|
| 20 |
"""
|
| 21 |
|
| 22 |
import json
|
| 23 |
import os
|
| 24 |
import sys
|
| 25 |
import time
|
| 26 |
-
from typing import Any, Dict, List
|
| 27 |
|
| 28 |
from openai import OpenAI
|
| 29 |
|
| 30 |
-
# ---------------------------------------------------------------------------
|
| 31 |
-
# Import the env directly (no HTTP overhead for baseline)
|
| 32 |
-
# ---------------------------------------------------------------------------
|
| 33 |
from tasks.task1.environment import Task1Environment
|
|
|
|
| 34 |
from env.schemas import Action, ActionType
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
#
|
| 38 |
-
#
|
| 39 |
|
| 40 |
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 41 |
-
MODEL_NAME
|
| 42 |
-
HF_TOKEN
|
| 43 |
|
| 44 |
if not HF_TOKEN:
|
| 45 |
-
print("WARNING: HF_TOKEN
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
#
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
1. list_functions
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
β {"action": "get_call_graph", "params": {}}
|
| 91 |
-
|
| 92 |
-
7. submit (ENDS THE EPISODE)
|
| 93 |
-
β {"action": "submit", "params": {"function_name": "<name>", "vulnerability_type": "<2-3 word description>"}}
|
| 94 |
-
|
| 95 |
-
## Strategy
|
| 96 |
-
- Start with list_functions and get_file_metadata to understand the contract
|
| 97 |
-
- Inspect suspicious functions (withdraw, transfer, emergency*, stake, etc.)
|
| 98 |
-
- Submit when you are confident about the vulnerable function
|
| 99 |
-
|
| 100 |
-
## Output Format
|
| 101 |
-
Always respond with a single JSON object:
|
| 102 |
-
{"action": "<action_type>", "params": {...}}
|
| 103 |
-
Do NOT include any other text β only valid JSON.
|
| 104 |
-
"""
|
| 105 |
|
| 106 |
|
| 107 |
-
def
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
f"Description: {obs['contract_description']}",
|
| 112 |
-
f"Step: {obs['step_count']} | Cumulative reward: {obs['cumulative_reward']:.2f}",
|
| 113 |
-
"",
|
| 114 |
-
f"Last action: {obs['last_action'] or 'None'}",
|
| 115 |
-
f"Result: {obs['last_action_result'] or 'Episode just started'}",
|
| 116 |
-
"",
|
| 117 |
-
f"Available actions: {', '.join(obs['available_actions'])}",
|
| 118 |
-
]
|
| 119 |
-
if obs.get("extra", {}).get("hint"):
|
| 120 |
-
lines.append(f"Hint: {obs['extra']['hint']}")
|
| 121 |
-
return "\n".join(lines)
|
| 122 |
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
print(f" Contract: {obs['contract_name']}")
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
|
| 147 |
-
|
|
|
|
| 148 |
try:
|
| 149 |
-
|
| 150 |
-
model=MODEL_NAME,
|
| 151 |
-
|
| 152 |
-
max_tokens=256,
|
| 153 |
-
temperature=0.0,
|
| 154 |
)
|
| 155 |
-
raw =
|
| 156 |
except Exception as e:
|
| 157 |
-
print(f"
|
| 158 |
break
|
| 159 |
|
| 160 |
-
# Parse action
|
| 161 |
try:
|
| 162 |
parsed = json.loads(raw)
|
| 163 |
-
|
| 164 |
params = parsed.get("params", {})
|
| 165 |
-
except Exception
|
| 166 |
-
|
| 167 |
-
# Default safe action
|
| 168 |
-
action_type = ActionType.LIST_FUNCTIONS
|
| 169 |
-
params = {}
|
| 170 |
|
| 171 |
-
action = Action(action_type=action_type, params=params)
|
| 172 |
messages.append({"role": "assistant", "content": raw})
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
print(
|
| 182 |
-
f" Step {step_num+1}: {action_type.value} | "
|
| 183 |
-
f"reward={step_result.reward.value:+.2f} | "
|
| 184 |
-
f"cumulative={final_reward:.2f}"
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
if done:
|
| 188 |
-
# Determine grader score from reward
|
| 189 |
-
last_reward = step_result.reward.value
|
| 190 |
-
if last_reward >= 4.9:
|
| 191 |
-
final_score = 1.0
|
| 192 |
-
elif last_reward >= 0.9:
|
| 193 |
-
final_score = 0.5
|
| 194 |
-
else:
|
| 195 |
-
final_score = 0.0
|
| 196 |
-
print(f" β DONE | grader_score={final_score:.1f}")
|
| 197 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
-
if not done:
|
| 200 |
-
print(f" β MAX STEPS reached without submission. Score=0.0")
|
| 201 |
-
|
| 202 |
-
return {
|
| 203 |
-
"episode": episode_num,
|
| 204 |
-
"seed": seed,
|
| 205 |
-
"contract": obs["contract_name"],
|
| 206 |
-
"steps": steps,
|
| 207 |
-
"cumulative_reward": final_reward,
|
| 208 |
-
"grader_score": final_score,
|
| 209 |
-
"done": done,
|
| 210 |
-
}
|
| 211 |
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
def run_task1(
|
| 214 |
-
"""Run Task 1 and return aggregate scores."""
|
| 215 |
print("\n" + "="*60)
|
| 216 |
print("TASK 1: Targeted Vulnerability Detection")
|
| 217 |
print("="*60)
|
| 218 |
-
|
| 219 |
env = Task1Environment()
|
| 220 |
-
episodes = []
|
| 221 |
-
|
| 222 |
-
for
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
scores = [e["grader_score"] for e in episodes]
|
| 229 |
-
avg = sum(scores) / len(scores) if scores else 0.0
|
| 230 |
-
avg_reward = sum(e["cumulative_reward"] for e in episodes) / len(episodes)
|
| 231 |
-
|
| 232 |
-
print(f"\n Task 1 Results:")
|
| 233 |
-
print(f" Episodes: {num_episodes}")
|
| 234 |
-
print(f" Grader scores: {scores}")
|
| 235 |
-
print(f" Average grader score: {avg:.3f}")
|
| 236 |
-
print(f" Average cumulative reward: {avg_reward:.2f}")
|
| 237 |
-
|
| 238 |
-
return {
|
| 239 |
-
"task_id": "task1_vuln_detection",
|
| 240 |
-
"name": "Targeted Vulnerability Detection",
|
| 241 |
-
"status": "active",
|
| 242 |
-
"num_episodes": num_episodes,
|
| 243 |
-
"episodes": episodes,
|
| 244 |
-
"avg_grader_score": avg,
|
| 245 |
-
"avg_cumulative_reward": avg_reward,
|
| 246 |
-
}
|
| 247 |
|
| 248 |
|
| 249 |
-
def
|
| 250 |
-
"""Task 2 placeholder β returns 0.0 score."""
|
| 251 |
print("\n" + "="*60)
|
| 252 |
-
print("TASK 2: Property Discovery
|
| 253 |
print("="*60)
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
}
|
| 264 |
|
| 265 |
|
| 266 |
def run_task3_placeholder() -> Dict[str, Any]:
|
| 267 |
-
"""Task 3 placeholder β returns 0.0 score."""
|
| 268 |
print("\n" + "="*60)
|
| 269 |
print("TASK 3: Rule Checker [PLACEHOLDER β not implemented]")
|
| 270 |
print("="*60)
|
| 271 |
print(" Skipping. Score: 0.0")
|
| 272 |
-
return {
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
"status": "placeholder",
|
| 276 |
-
"num_episodes": 0,
|
| 277 |
-
"episodes": [],
|
| 278 |
-
"avg_grader_score": 0.0,
|
| 279 |
-
"avg_cumulative_reward": 0.0,
|
| 280 |
-
}
|
| 281 |
|
| 282 |
|
| 283 |
-
#
|
| 284 |
# Main
|
| 285 |
-
#
|
| 286 |
|
| 287 |
def main():
|
| 288 |
print("Smart Contract Audit RL Environment β Baseline Inference")
|
| 289 |
print(f"Model: {MODEL_NAME} | Base URL: {API_BASE_URL}")
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
"base_url": API_BASE_URL,
|
| 294 |
-
"tasks": [],
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
t1 = run_task1(num_episodes=NUM_EPISODES)
|
| 298 |
-
t2 = run_task2_placeholder()
|
| 299 |
t3 = run_task3_placeholder()
|
| 300 |
|
| 301 |
-
results
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
overall = (
|
| 306 |
-
sum(t["avg_grader_score"] for t in active_tasks) / len(active_tasks)
|
| 307 |
-
if active_tasks else 0.0
|
| 308 |
-
)
|
| 309 |
results["overall_avg_score"] = overall
|
| 310 |
|
| 311 |
print("\n" + "="*60)
|
| 312 |
print("BASELINE SUMMARY")
|
| 313 |
print("="*60)
|
| 314 |
for t in results["tasks"]:
|
| 315 |
-
|
| 316 |
-
print(f" {
|
| 317 |
-
print(f" Overall (active tasks): {overall:.3f}")
|
| 318 |
|
| 319 |
-
# Write scores file
|
| 320 |
with open("baseline_scores.json", "w") as f:
|
| 321 |
json.dump(results, f, indent=2)
|
| 322 |
print("\n Scores written to baseline_scores.json")
|
|
|
|
| 2 |
inference.py
|
| 3 |
------------
|
| 4 |
Baseline inference script for the Smart Contract Audit RL Environment.
|
| 5 |
+
Implements Task 1 (Vulnerability Detection) and Task 2 (Property Discovery).
|
| 6 |
+
Task 3 is a placeholder that returns 0.0.
|
| 7 |
|
| 8 |
+
Environment variables:
|
| 9 |
+
API_BASE_URL β LLM API endpoint (e.g. https://api.openai.com/v1)
|
| 10 |
+
MODEL_NAME β model identifier (e.g. gpt-4o-mini)
|
| 11 |
+
HF_TOKEN β API key
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
Usage:
|
| 14 |
python inference.py
|
|
|
|
| 16 |
Output:
|
| 17 |
Per-task scores printed to stdout.
|
| 18 |
Final baseline scores written to baseline_scores.json.
|
| 19 |
+
|
| 20 |
+
Runtime: < 5 minutes on 3 episodes per task with gpt-4o-mini.
|
| 21 |
"""
|
| 22 |
|
| 23 |
import json
|
| 24 |
import os
|
| 25 |
import sys
|
| 26 |
import time
|
| 27 |
+
from typing import Any, Dict, List
|
| 28 |
|
| 29 |
from openai import OpenAI
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
from tasks.task1.environment import Task1Environment
|
| 32 |
+
from tasks.task2.environment import Task2Environment
|
| 33 |
from env.schemas import Action, ActionType
|
| 34 |
|
| 35 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
# Configuration
|
| 37 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
|
| 39 |
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 40 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 41 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 42 |
|
| 43 |
if not HF_TOKEN:
|
| 44 |
+
print("WARNING: HF_TOKEN not set. API calls may fail.", file=sys.stderr)
|
| 45 |
+
|
| 46 |
+
MAX_STEPS_T1 = 15
|
| 47 |
+
MAX_STEPS_T2 = 10
|
| 48 |
+
NUM_EPISODES = 3
|
| 49 |
+
SEED_BASE_T1 = 42
|
| 50 |
+
SEED_BASE_T2 = 10
|
| 51 |
+
|
| 52 |
+
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
|
| 53 |
+
|
| 54 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
+
# Task 1 agent
|
| 56 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 57 |
+
|
| 58 |
+
T1_SYSTEM = """You are an expert Solidity smart contract security auditor.
|
| 59 |
+
|
| 60 |
+
Given a contract, identify the ONE vulnerable function and its vulnerability type.
|
| 61 |
+
|
| 62 |
+
## Actions (choose ONE per turn, respond with JSON only):
|
| 63 |
+
{"action": "list_functions", "params": {}}
|
| 64 |
+
{"action": "get_function_code", "params": {"function_name": "<name>"}}
|
| 65 |
+
{"action": "get_function_summary", "params": {"function_name": "<name>"}}
|
| 66 |
+
{"action": "get_file_metadata", "params": {}}
|
| 67 |
+
{"action": "get_state_variable", "params": {"variable_name": "<name>"}}
|
| 68 |
+
{"action": "get_call_graph", "params": {}}
|
| 69 |
+
{"action": "submit", "params": {"function_name": "<name>", "vulnerability_type": "<2-3 words>"}}
|
| 70 |
+
|
| 71 |
+
## Strategy:
|
| 72 |
+
1. list_functions first to see the attack surface
|
| 73 |
+
2. Inspect suspicious functions (withdraw, drain, buy, stake, claim, setPrice, bid, finalize)
|
| 74 |
+
3. Look for: reentrancy, missing access control, integer overflow, tx.origin, front-running,
|
| 75 |
+
timestamp dependence, denial of service, unchecked return value
|
| 76 |
+
4. Submit when confident
|
| 77 |
+
|
| 78 |
+
Respond ONLY with valid JSON. No explanation, no markdown."""
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _t1_user_msg(obs: Dict[str, Any]) -> str:
|
| 82 |
+
return (
|
| 83 |
+
f"Contract: {obs['contract_name']}\n"
|
| 84 |
+
f"Description: {obs['contract_description']}\n"
|
| 85 |
+
f"Step: {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n\n"
|
| 86 |
+
f"Last action: {obs['last_action'] or 'None'}\n"
|
| 87 |
+
f"Result: {obs['last_action_result'] or 'Episode started.'}"
|
| 88 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
+
def run_t1_episode(env: Task1Environment, seed: int, ep: int) -> Dict[str, Any]:
|
| 92 |
+
r = env.reset(seed=seed)
|
| 93 |
+
obs = r.observation.model_dump()
|
| 94 |
+
print(f" ep={ep} seed={seed} contract={obs['contract_name']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
messages = [{"role": "system", "content": T1_SYSTEM}]
|
| 97 |
+
grader_score = 0.0
|
| 98 |
+
cum_reward = 0.0
|
| 99 |
|
| 100 |
+
for step in range(MAX_STEPS_T1):
|
| 101 |
+
messages.append({"role": "user", "content": _t1_user_msg(obs)})
|
| 102 |
+
try:
|
| 103 |
+
resp = client.chat.completions.create(
|
| 104 |
+
model=MODEL_NAME, messages=messages,
|
| 105 |
+
max_tokens=200, temperature=0.0,
|
| 106 |
+
)
|
| 107 |
+
raw = resp.choices[0].message.content.strip()
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f" LLM error: {e}", file=sys.stderr)
|
| 110 |
+
break
|
| 111 |
|
| 112 |
+
try:
|
| 113 |
+
parsed = json.loads(raw)
|
| 114 |
+
at = ActionType(parsed["action"])
|
| 115 |
+
params = parsed.get("params", {})
|
| 116 |
+
except Exception:
|
| 117 |
+
at, params = ActionType.LIST_FUNCTIONS, {}
|
| 118 |
|
| 119 |
+
messages.append({"role": "assistant", "content": raw})
|
| 120 |
+
result = env.step(Action(action_type=at, params=params))
|
| 121 |
+
obs = result.observation.model_dump()
|
| 122 |
+
print(f" step {step+1:2d}: {at.value:25s} r={result.reward.value:+.2f}")
|
| 123 |
+
|
| 124 |
+
if result.done:
|
| 125 |
+
v = result.reward.value
|
| 126 |
+
grader_score = 1.0 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
|
| 127 |
+
cum_reward = obs["cumulative_reward"]
|
| 128 |
+
break
|
| 129 |
+
time.sleep(0.3)
|
| 130 |
+
|
| 131 |
+
print(f" β grader_score={grader_score:.1f} cum_reward={cum_reward:.2f}")
|
| 132 |
+
return {"episode": ep, "seed": seed, "contract": obs["contract_name"],
|
| 133 |
+
"grader_score": grader_score, "cumulative_reward": cum_reward}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 137 |
+
# Task 2 agent
|
| 138 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 139 |
+
|
| 140 |
+
T2_SYSTEM = """You are a formal methods engineer specialising in Solidity smart contracts.
|
| 141 |
+
|
| 142 |
+
You will be shown a specific Solidity function. Your task is to write a precise
|
| 143 |
+
natural-language property (invariant / postcondition) that describes what the
|
| 144 |
+
function guarantees when it succeeds.
|
| 145 |
+
|
| 146 |
+
A good property covers:
|
| 147 |
+
- What state changes (balances, counters, flags)
|
| 148 |
+
- What assets are transferred (ETH, tokens, NFTs)
|
| 149 |
+
- What return value is produced (for view functions)
|
| 150 |
+
- Under what conditions it reverts
|
| 151 |
+
|
| 152 |
+
## Actions (respond with JSON only, ONE action per turn):
|
| 153 |
+
{"action": "get_function_code", "params": {}}
|
| 154 |
+
{"action": "get_function_natspec", "params": {}}
|
| 155 |
+
{"action": "get_file_natspec", "params": {}}
|
| 156 |
+
{"action": "get_related_functions", "params": {}}
|
| 157 |
+
{"action": "get_io", "params": {}}
|
| 158 |
+
{"action": "get_similar_rule", "params": {}}
|
| 159 |
+
{"action": "submit_property", "params": {"property": "<your full property text>"}}
|
| 160 |
+
|
| 161 |
+
## Rules:
|
| 162 |
+
- You have ONE submit_property attempt. Make it count.
|
| 163 |
+
- Use get_function_natspec and get_io first β they give the most signal.
|
| 164 |
+
- get_similar_rule costs more (-0.20) but shows a parallel property from another contract.
|
| 165 |
+
- Write 2β4 sentences. Be specific about variable names and amounts.
|
| 166 |
+
- Do NOT guess β read the code first.
|
| 167 |
+
|
| 168 |
+
Respond ONLY with valid JSON. No markdown, no explanation."""
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _t2_user_msg(obs: Dict[str, Any]) -> str:
|
| 172 |
+
extra = obs.get("extra", {})
|
| 173 |
+
return (
|
| 174 |
+
f"Contract : {obs['contract_name']}\n"
|
| 175 |
+
f"Function : {extra.get('target_function', '?')} "
|
| 176 |
+
f"({extra.get('target_signature', '')})\n"
|
| 177 |
+
f"Step: {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n\n"
|
| 178 |
+
f"Last action: {obs['last_action'] or 'None'}\n"
|
| 179 |
+
f"Result:\n{obs['last_action_result'] or 'Episode started β begin exploring.'}"
|
| 180 |
+
)
|
| 181 |
|
|
|
|
| 182 |
|
| 183 |
+
def run_t2_episode(env: Task2Environment, seed: int, ep: int) -> Dict[str, Any]:
|
| 184 |
+
r = env.reset(seed=seed)
|
| 185 |
+
obs = r.observation.model_dump()
|
| 186 |
+
fn = obs["extra"].get("target_function", "?")
|
| 187 |
+
print(f" ep={ep} seed={seed} {obs['contract_name']}.{fn}()")
|
| 188 |
|
| 189 |
+
messages = [{"role": "system", "content": T2_SYSTEM}]
|
| 190 |
+
grader_score = 0.0
|
| 191 |
+
cum_reward = 0.0
|
| 192 |
|
| 193 |
+
for step in range(MAX_STEPS_T2):
|
| 194 |
+
messages.append({"role": "user", "content": _t2_user_msg(obs)})
|
| 195 |
try:
|
| 196 |
+
resp = client.chat.completions.create(
|
| 197 |
+
model=MODEL_NAME, messages=messages,
|
| 198 |
+
max_tokens=400, temperature=0.0,
|
|
|
|
|
|
|
| 199 |
)
|
| 200 |
+
raw = resp.choices[0].message.content.strip()
|
| 201 |
except Exception as e:
|
| 202 |
+
print(f" LLM error: {e}", file=sys.stderr)
|
| 203 |
break
|
| 204 |
|
|
|
|
| 205 |
try:
|
| 206 |
parsed = json.loads(raw)
|
| 207 |
+
at = ActionType(parsed["action"])
|
| 208 |
params = parsed.get("params", {})
|
| 209 |
+
except Exception:
|
| 210 |
+
at, params = ActionType.GET_FUNCTION_CODE, {}
|
|
|
|
|
|
|
|
|
|
| 211 |
|
|
|
|
| 212 |
messages.append({"role": "assistant", "content": raw})
|
| 213 |
+
result = env.step(Action(action_type=at, params=params))
|
| 214 |
+
obs = result.observation.model_dump()
|
| 215 |
+
r_val = result.reward.value
|
| 216 |
+
print(f" step {step+1:2d}: {at.value:25s} r={r_val:+.2f}")
|
| 217 |
+
|
| 218 |
+
if result.done:
|
| 219 |
+
grader_score = round(r_val / 5.0, 3) if r_val > 0 else 0.0
|
| 220 |
+
cum_reward = obs["cumulative_reward"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
break
|
| 222 |
+
time.sleep(0.3)
|
| 223 |
+
|
| 224 |
+
print(f" β grader_score={grader_score:.3f} cum_reward={cum_reward:.2f}")
|
| 225 |
+
return {"episode": ep, "seed": seed,
|
| 226 |
+
"contract": obs["contract_name"], "function": fn,
|
| 227 |
+
"grader_score": grader_score, "cumulative_reward": cum_reward}
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 231 |
+
# Task runners
|
| 232 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 233 |
|
| 234 |
+
def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
|
|
|
| 235 |
print("\n" + "="*60)
|
| 236 |
print("TASK 1: Targeted Vulnerability Detection")
|
| 237 |
print("="*60)
|
|
|
|
| 238 |
env = Task1Environment()
|
| 239 |
+
episodes = [run_t1_episode(env, SEED_BASE_T1 + i, i+1) for i in range(n)]
|
| 240 |
+
avg_s = sum(e["grader_score"] for e in episodes) / n
|
| 241 |
+
avg_r = sum(e["cumulative_reward"] for e in episodes) / n
|
| 242 |
+
print(f"\n Avg grader score : {avg_s:.3f}")
|
| 243 |
+
print(f" Avg cum reward : {avg_r:.2f}")
|
| 244 |
+
return {"task_id": "task1_vuln_detection", "name": "Targeted Vulnerability Detection",
|
| 245 |
+
"status": "active", "num_episodes": n, "episodes": episodes,
|
| 246 |
+
"avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
|
| 249 |
+
def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
|
|
|
| 250 |
print("\n" + "="*60)
|
| 251 |
+
print("TASK 2: Property Discovery")
|
| 252 |
print("="*60)
|
| 253 |
+
env = Task2Environment()
|
| 254 |
+
episodes = [run_t2_episode(env, SEED_BASE_T2 + i, i+1) for i in range(n)]
|
| 255 |
+
avg_s = sum(e["grader_score"] for e in episodes) / n
|
| 256 |
+
avg_r = sum(e["cumulative_reward"] for e in episodes) / n
|
| 257 |
+
print(f"\n Avg grader score : {avg_s:.3f}")
|
| 258 |
+
print(f" Avg cum reward : {avg_r:.2f}")
|
| 259 |
+
return {"task_id": "task2_property_discovery", "name": "Property Discovery",
|
| 260 |
+
"status": "active", "num_episodes": n, "episodes": episodes,
|
| 261 |
+
"avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
|
|
|
|
| 262 |
|
| 263 |
|
| 264 |
def run_task3_placeholder() -> Dict[str, Any]:
|
|
|
|
| 265 |
print("\n" + "="*60)
|
| 266 |
print("TASK 3: Rule Checker [PLACEHOLDER β not implemented]")
|
| 267 |
print("="*60)
|
| 268 |
print(" Skipping. Score: 0.0")
|
| 269 |
+
return {"task_id": "task3_rule_checker", "name": "Rule Checker",
|
| 270 |
+
"status": "placeholder", "num_episodes": 0, "episodes": [],
|
| 271 |
+
"avg_grader_score": 0.0, "avg_cumulative_reward": 0.0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
|
| 274 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 275 |
# Main
|
| 276 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 277 |
|
| 278 |
def main():
|
| 279 |
print("Smart Contract Audit RL Environment β Baseline Inference")
|
| 280 |
print(f"Model: {MODEL_NAME} | Base URL: {API_BASE_URL}")
|
| 281 |
|
| 282 |
+
t1 = run_task1(NUM_EPISODES)
|
| 283 |
+
t2 = run_task2(NUM_EPISODES)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
t3 = run_task3_placeholder()
|
| 285 |
|
| 286 |
+
results = {
|
| 287 |
+
"model": MODEL_NAME, "base_url": API_BASE_URL,
|
| 288 |
+
"tasks": [t1, t2, t3],
|
| 289 |
+
}
|
| 290 |
|
| 291 |
+
active = [t for t in results["tasks"] if t["status"] == "active"]
|
| 292 |
+
overall = sum(t["avg_grader_score"] for t in active) / len(active) if active else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
results["overall_avg_score"] = overall
|
| 294 |
|
| 295 |
print("\n" + "="*60)
|
| 296 |
print("BASELINE SUMMARY")
|
| 297 |
print("="*60)
|
| 298 |
for t in results["tasks"]:
|
| 299 |
+
icon = "β
" if t["status"] == "active" else "β³"
|
| 300 |
+
print(f" {icon} {t['name']:40s}: {t['avg_grader_score']:.3f}")
|
| 301 |
+
print(f"\n Overall (active tasks): {overall:.3f}")
|
| 302 |
|
|
|
|
| 303 |
with open("baseline_scores.json", "w") as f:
|
| 304 |
json.dump(results, f, indent=2)
|
| 305 |
print("\n Scores written to baseline_scores.json")
|
openenv.yaml
CHANGED
|
@@ -1,25 +1,22 @@
|
|
| 1 |
name: smart-contract-audit-env
|
| 2 |
-
version: "1.
|
| 3 |
description: >
|
| 4 |
Reinforcement learning environment for smart contract security analysis.
|
| 5 |
Agents interact with real-world Solidity contract data from Certora-audited
|
| 6 |
-
projects, learning to detect vulnerabilities
|
| 7 |
-
|
| 8 |
|
| 9 |
author: "SmartAudit Team"
|
| 10 |
license: MIT
|
| 11 |
|
| 12 |
-
# ---------------------------------------------------------------------------
|
| 13 |
-
# Tasks
|
| 14 |
-
# ---------------------------------------------------------------------------
|
| 15 |
tasks:
|
| 16 |
- id: task1_vuln_detection
|
| 17 |
name: Targeted Vulnerability Detection
|
| 18 |
difficulty: medium
|
| 19 |
status: active
|
| 20 |
description: >
|
| 21 |
-
Given a Solidity contract (4
|
| 22 |
-
function and describe its vulnerability type in 2
|
| 23 |
max_steps: 20
|
| 24 |
reward_range: [-10.0, 10.0]
|
| 25 |
grader: tasks/task1/grader.py
|
|
@@ -28,13 +25,13 @@ tasks:
|
|
| 28 |
- id: task2_property_discovery
|
| 29 |
name: Property Discovery
|
| 30 |
difficulty: hard
|
| 31 |
-
status:
|
| 32 |
description: >
|
| 33 |
Given a single Solidity function with known properties, discover the
|
| 34 |
-
correct natural-language
|
| 35 |
max_steps: 15
|
| 36 |
reward_range: [-5.0, 5.0]
|
| 37 |
-
grader: tasks/task2/grader.py
|
| 38 |
grader_score_range: [0.0, 1.0]
|
| 39 |
|
| 40 |
- id: task3_rule_checker
|
|
@@ -46,81 +43,52 @@ tasks:
|
|
| 46 |
function that violates that property.
|
| 47 |
max_steps: 15
|
| 48 |
reward_range: [-5.0, 5.0]
|
| 49 |
-
grader: tasks/task3/grader.py
|
| 50 |
grader_score_range: [0.0, 1.0]
|
| 51 |
|
| 52 |
-
# ---------------------------------------------------------------------------
|
| 53 |
-
# Observation space
|
| 54 |
-
# ---------------------------------------------------------------------------
|
| 55 |
observation_space:
|
| 56 |
type: object
|
| 57 |
properties:
|
| 58 |
-
task_id:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
type: array
|
| 69 |
-
items:
|
| 70 |
-
type: string
|
| 71 |
-
description: List of valid action type strings
|
| 72 |
-
last_action:
|
| 73 |
-
type: string
|
| 74 |
-
nullable: true
|
| 75 |
-
description: The action type that produced this observation
|
| 76 |
-
last_action_result:
|
| 77 |
-
type: string
|
| 78 |
-
nullable: true
|
| 79 |
-
description: Human-readable result of the last action
|
| 80 |
-
step_count:
|
| 81 |
-
type: integer
|
| 82 |
-
description: Number of steps taken in this episode
|
| 83 |
-
cumulative_reward:
|
| 84 |
-
type: number
|
| 85 |
-
description: Running reward total for this episode
|
| 86 |
-
done:
|
| 87 |
-
type: boolean
|
| 88 |
-
description: True when the episode has ended
|
| 89 |
-
extra:
|
| 90 |
-
type: object
|
| 91 |
-
description: Task-specific hints and auxiliary data
|
| 92 |
|
| 93 |
-
# ---------------------------------------------------------------------------
|
| 94 |
-
# Action space (Task 1)
|
| 95 |
-
# ---------------------------------------------------------------------------
|
| 96 |
action_space:
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
# ---------------------------------------------------------------------------
|
| 115 |
-
# Reward function
|
| 116 |
-
# ---------------------------------------------------------------------------
|
| 117 |
reward:
|
| 118 |
type: shaped
|
| 119 |
description: >
|
| 120 |
-
Per-step costs encourage efficient exploration.
|
| 121 |
-
when the agent
|
| 122 |
-
|
| 123 |
-
|
| 124 |
list_functions: -0.05
|
| 125 |
get_function_code_wrong: -0.10
|
| 126 |
get_function_code_correct: +0.05
|
|
@@ -130,19 +98,28 @@ reward:
|
|
| 130 |
get_state_variable: -0.05
|
| 131 |
get_call_graph: -0.08
|
| 132 |
repeated_query: -0.40
|
| 133 |
-
|
| 134 |
correct_submission: +5.0
|
| 135 |
partial_submission: +1.0
|
| 136 |
wrong_submission: -1.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
# ---------------------------------------------------------------------------
|
| 139 |
-
# Data
|
| 140 |
-
# ---------------------------------------------------------------------------
|
| 141 |
data:
|
| 142 |
-
source: "Certora audited
|
| 143 |
format: JSON
|
| 144 |
num_contracts: 4
|
| 145 |
num_vulnerable_functions: 8
|
|
|
|
| 146 |
vulnerability_types:
|
| 147 |
- Reentrancy
|
| 148 |
- Missing access control
|
|
@@ -153,17 +130,16 @@ data:
|
|
| 153 |
- Denial of service (unbounded loop)
|
| 154 |
- Unchecked return value
|
| 155 |
|
| 156 |
-
# ---------------------------------------------------------------------------
|
| 157 |
-
# Interface
|
| 158 |
-
# ---------------------------------------------------------------------------
|
| 159 |
interface:
|
| 160 |
http:
|
| 161 |
-
reset:
|
| 162 |
-
step:
|
| 163 |
-
state:
|
| 164 |
-
tasks:
|
| 165 |
-
health:
|
|
|
|
|
|
|
| 166 |
python:
|
| 167 |
-
reset: env.reset(seed=None)
|
| 168 |
-
step:
|
| 169 |
-
state: env.state()
|
|
|
|
| 1 |
name: smart-contract-audit-env
|
| 2 |
+
version: "1.1.0"
|
| 3 |
description: >
|
| 4 |
Reinforcement learning environment for smart contract security analysis.
|
| 5 |
Agents interact with real-world Solidity contract data from Certora-audited
|
| 6 |
+
projects, learning to detect vulnerabilities and discover correctness
|
| 7 |
+
properties β tasks that professional auditors perform daily.
|
| 8 |
|
| 9 |
author: "SmartAudit Team"
|
| 10 |
license: MIT
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
tasks:
|
| 13 |
- id: task1_vuln_detection
|
| 14 |
name: Targeted Vulnerability Detection
|
| 15 |
difficulty: medium
|
| 16 |
status: active
|
| 17 |
description: >
|
| 18 |
+
Given a Solidity contract (4-6 functions), identify the single vulnerable
|
| 19 |
+
function and describe its vulnerability type in 2-3 words.
|
| 20 |
max_steps: 20
|
| 21 |
reward_range: [-10.0, 10.0]
|
| 22 |
grader: tasks/task1/grader.py
|
|
|
|
| 25 |
- id: task2_property_discovery
|
| 26 |
name: Property Discovery
|
| 27 |
difficulty: hard
|
| 28 |
+
status: active
|
| 29 |
description: >
|
| 30 |
Given a single Solidity function with known properties, discover the
|
| 31 |
+
correct natural-language postcondition describing its correct behaviour.
|
| 32 |
max_steps: 15
|
| 33 |
reward_range: [-5.0, 5.0]
|
| 34 |
+
grader: tasks/task2/grader.py
|
| 35 |
grader_score_range: [0.0, 1.0]
|
| 36 |
|
| 37 |
- id: task3_rule_checker
|
|
|
|
| 43 |
function that violates that property.
|
| 44 |
max_steps: 15
|
| 45 |
reward_range: [-5.0, 5.0]
|
| 46 |
+
grader: tasks/task3/grader.py
|
| 47 |
grader_score_range: [0.0, 1.0]
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
observation_space:
|
| 50 |
type: object
|
| 51 |
properties:
|
| 52 |
+
task_id: {type: string, description: Active task identifier}
|
| 53 |
+
contract_name: {type: string, description: Solidity contract name}
|
| 54 |
+
contract_description: {type: string, description: Human-readable contract description}
|
| 55 |
+
available_actions: {type: array, items: {type: string}, description: Valid action types}
|
| 56 |
+
last_action: {type: string, nullable: true}
|
| 57 |
+
last_action_result: {type: string, nullable: true}
|
| 58 |
+
step_count: {type: integer}
|
| 59 |
+
cumulative_reward: {type: number}
|
| 60 |
+
done: {type: boolean}
|
| 61 |
+
extra: {type: object, description: Task-specific hints}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
|
|
|
|
|
|
|
|
|
| 63 |
action_space:
|
| 64 |
+
task1:
|
| 65 |
+
type: object
|
| 66 |
+
actions:
|
| 67 |
+
list_functions: {params: {}, reward: -0.05}
|
| 68 |
+
get_function_code: {params: {function_name: string}, reward: "+0.05 / -0.10"}
|
| 69 |
+
get_function_summary: {params: {function_name: string}, reward: "+0.03 / -0.05"}
|
| 70 |
+
get_file_metadata: {params: {}, reward: -0.04}
|
| 71 |
+
get_state_variable: {params: {variable_name: "string (opt)"}, reward: -0.05}
|
| 72 |
+
get_call_graph: {params: {}, reward: -0.08}
|
| 73 |
+
submit: {params: {function_name: str, vulnerability_type: str}, reward: "+5.0 / +1.0 / -1.5"}
|
| 74 |
+
task2:
|
| 75 |
+
type: object
|
| 76 |
+
actions:
|
| 77 |
+
get_function_code: {params: {}, reward: -0.06}
|
| 78 |
+
get_function_natspec: {params: {}, reward: -0.08}
|
| 79 |
+
get_file_natspec: {params: {}, reward: -0.03}
|
| 80 |
+
get_related_functions: {params: {}, reward: -0.06}
|
| 81 |
+
get_io: {params: {}, reward: -0.04}
|
| 82 |
+
get_similar_rule: {params: {}, reward: -0.20}
|
| 83 |
+
submit_property: {params: {property: string}, reward: "0.0β5.0 (keyword-weighted)"}
|
| 84 |
|
|
|
|
|
|
|
|
|
|
| 85 |
reward:
|
| 86 |
type: shaped
|
| 87 |
description: >
|
| 88 |
+
Per-step costs encourage efficient exploration. Positive shaping rewards
|
| 89 |
+
fire when the agent inspects the actual target. Terminal rewards reflect
|
| 90 |
+
grader score accuracy.
|
| 91 |
+
task1_shaping:
|
| 92 |
list_functions: -0.05
|
| 93 |
get_function_code_wrong: -0.10
|
| 94 |
get_function_code_correct: +0.05
|
|
|
|
| 98 |
get_state_variable: -0.05
|
| 99 |
get_call_graph: -0.08
|
| 100 |
repeated_query: -0.40
|
| 101 |
+
task1_terminal:
|
| 102 |
correct_submission: +5.0
|
| 103 |
partial_submission: +1.0
|
| 104 |
wrong_submission: -1.5
|
| 105 |
+
task2_shaping:
|
| 106 |
+
get_function_code: -0.06
|
| 107 |
+
get_function_natspec: -0.08
|
| 108 |
+
get_file_natspec: -0.03
|
| 109 |
+
get_related_functions: -0.06
|
| 110 |
+
get_io: -0.04
|
| 111 |
+
get_similar_rule: -0.20
|
| 112 |
+
repeated_query: -0.40
|
| 113 |
+
task2_terminal:
|
| 114 |
+
score_range: [0.0, 5.0]
|
| 115 |
+
formula: "score * 5.0 where score = 0.70*(key_matches/total_key) + 0.30*(bonus_matches/total_bonus)"
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
data:
|
| 118 |
+
source: "Certora audited DeFi projects"
|
| 119 |
format: JSON
|
| 120 |
num_contracts: 4
|
| 121 |
num_vulnerable_functions: 8
|
| 122 |
+
num_property_functions: 11
|
| 123 |
vulnerability_types:
|
| 124 |
- Reentrancy
|
| 125 |
- Missing access control
|
|
|
|
| 130 |
- Denial of service (unbounded loop)
|
| 131 |
- Unchecked return value
|
| 132 |
|
|
|
|
|
|
|
|
|
|
| 133 |
interface:
|
| 134 |
http:
|
| 135 |
+
reset: POST /reset
|
| 136 |
+
step: POST /step
|
| 137 |
+
state: GET /state
|
| 138 |
+
tasks: GET /tasks
|
| 139 |
+
health: GET /health
|
| 140 |
+
action_space: GET /action_space?task_id=<id>
|
| 141 |
+
observation_space: GET /observation_space
|
| 142 |
python:
|
| 143 |
+
reset: env.reset(seed=None) -> ResetResult
|
| 144 |
+
step: env.step(action) -> StepResult
|
| 145 |
+
state: env.state() -> StateResult
|
tasks/task2/__init__.py
CHANGED
|
@@ -1,27 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
tasks
|
| 3 |
-
|
| 4 |
-
Task 2: Property Discovery (PLACEHOLDER)
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
Episode setup:
|
| 9 |
-
- One function from a Solidity file with known properties
|
| 10 |
-
- Agent must discover the natural-language property of the function
|
| 11 |
-
|
| 12 |
-
Actions (to implement):
|
| 13 |
-
- get_similar_rule : -0.20
|
| 14 |
-
- get_file_natspec : -0.03
|
| 15 |
-
- get_function_natspec : -0.08
|
| 16 |
-
- get_function_code : -0.06
|
| 17 |
-
- get_related_functions : -0.06
|
| 18 |
-
- get_io : -0.04
|
| 19 |
-
- submit_property : scored 0.0β5.0 by semantic similarity grader
|
| 20 |
-
|
| 21 |
-
See README.md for full task specification.
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
# TODO: Task 2 β Property Discovery
|
| 25 |
-
# from tasks.task2.environment import Task2Environment
|
| 26 |
-
|
| 27 |
-
__all__: list = []
|
|
|
|
| 1 |
+
# Task 2: Property Discovery
|
| 2 |
+
from tasks.task2.environment import Task2Environment
|
| 3 |
+
from tasks.task2.grader import Task2Grader
|
|
|
|
| 4 |
|
| 5 |
+
__all__ = ["Task2Environment", "Task2Grader"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/task2/environment.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
environment.py (Task 2 β Property Discovery)
|
| 3 |
+
----------------------------------------------
|
| 4 |
+
OpenEnv-compliant RL environment.
|
| 5 |
+
|
| 6 |
+
Episode setup:
|
| 7 |
+
- One function from a Solidity contract that has a known property.
|
| 8 |
+
- The agent sees: contract description + function name + function signature.
|
| 9 |
+
- The agent must discover the natural-language property of the function.
|
| 10 |
+
|
| 11 |
+
Actions & rewards:
|
| 12 |
+
get_function_code -0.06 (always positive topic context)
|
| 13 |
+
get_function_natspec -0.08 (strongest hint β natspec has param/return docs)
|
| 14 |
+
get_file_natspec -0.03 (broad contract-level context)
|
| 15 |
+
get_related_functions -0.06 (shows callers/callees)
|
| 16 |
+
get_io -0.04 (structured input/output description)
|
| 17 |
+
get_similar_rule -0.20 (shows a similar property from another contract)
|
| 18 |
+
submit_property scored 0β5 (ONE attempt, ends episode)
|
| 19 |
+
repeated_query -0.40
|
| 20 |
+
|
| 21 |
+
Episode ends when:
|
| 22 |
+
- submit_property is called (scored), OR
|
| 23 |
+
- max_steps is reached without submission (reward = -1.0)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import random
|
| 29 |
+
from typing import Any, Dict, List, Optional, Set
|
| 30 |
+
|
| 31 |
+
from data.data_loader import (
|
| 32 |
+
load_contracts,
|
| 33 |
+
sample_property_episode,
|
| 34 |
+
get_function_by_name,
|
| 35 |
+
get_related_functions,
|
| 36 |
+
get_similar_rule,
|
| 37 |
+
)
|
| 38 |
+
from env.base_env import BaseEnv
|
| 39 |
+
from env.schemas import (
|
| 40 |
+
Action,
|
| 41 |
+
ActionType,
|
| 42 |
+
Observation,
|
| 43 |
+
Reward,
|
| 44 |
+
ResetResult,
|
| 45 |
+
StateResult,
|
| 46 |
+
StepResult,
|
| 47 |
+
)
|
| 48 |
+
from tasks.task2.grader import Task2Grader
|
| 49 |
+
|
| 50 |
+
# Unique task identifier echoed in every Observation and in reset() info.
TASK_ID = "task2_property_discovery"

# Per-episode step budget.  The module docstring specifies that reaching it
# without a submission ends the episode with reward -1.0.
MAX_STEPS = 15

# Action set exposed to the agent via Observation.available_actions.
# Per-action costs are documented in the module docstring.
AVAILABLE_ACTIONS = [
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_FUNCTION_NATSPEC,
    ActionType.GET_FILE_NATSPEC,
    ActionType.GET_RELATED_FUNCTIONS,
    ActionType.GET_IO,
    ActionType.GET_SIMILAR_RULE,
    ActionType.SUBMIT_PROPERTY,
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Task2Environment(BaseEnv):
    """Task 2: Property Discovery.

    Each episode pairs the agent with one function (from one contract) that
    has a known natural-language property.  The agent explores via the query
    actions in ``AVAILABLE_ACTIONS`` (each with a small negative shaping
    reward) and ends the episode with a single ``submit_property`` action,
    which is scored 0-5 by :class:`Task2Grader`.  An episode also ends with
    reward -1.0 when ``MAX_STEPS`` is exhausted without a submission.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """
        Parameters
        ----------
        contracts_path : optional dataset path; when omitted, the loader's
            default location is used.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()

        # Episode state — initialised by reset()
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task2Grader] = None
        self._step_count: int = 0
        self._cum_reward: float = 0.0
        self._done: bool = False
        self._submitted: bool = False  # only one submit_property allowed
        self._query_hist: List[str] = []
        self._seen: Set[str] = set()

    # ── OpenEnv interface ────────────────────────────────────────────────────

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode; optionally seed the episode sampler."""
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_property_episode(
            self._contracts, self._rng
        )
        self._grader = Task2Grader(
            function_name=self._target_fn["name"],
            property_data=self._target_fn["property"],
        )
        self._step_count = 0
        self._cum_reward = 0.0
        self._done = False
        self._submitted = False
        self._query_hist = []
        self._seen = set()

        obs = self._build_obs(
            last_action=None,
            last_result=(
                f"New episode started.\n"
                f"Contract : {self._contract['contract_name']}\n"
                f"Function : {self._target_fn['name']} "
                f"({self._target_fn.get('signature', '')})\n"
                f"Your task : Discover the natural-language property of "
                f"'{self._target_fn['name']}' and submit it with submit_property."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Apply one action and return (observation, reward, done, info).

        Raises
        ------
        RuntimeError : if called after the episode has ended.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        self._step_count += 1
        result_text, reward = self._dispatch(action)

        # FIX: enforce the step budget promised in the module docstring.
        # Previously MAX_STEPS was defined but never checked, so an episode
        # with no submission could run forever.
        if not self._done and self._step_count >= MAX_STEPS:
            self._done = True
            result_text += "\nMax steps reached without a submission. Episode over."
            reward = Reward(value=-1.0, reason="Max steps reached without submission")

        self._cum_reward += reward.value
        self._query_hist.append(f"[{action.action_type}] → {result_text[:100]}")

        obs = self._build_obs(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "cumulative_reward": self._cum_reward,
            },
        )

    def state(self) -> StateResult:
        """Snapshot of the current episode (for logging / debugging)."""
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name"),
            step_count=self._step_count,
            cumulative_reward=self._cum_reward,
            done=self._done,
            query_history=list(self._query_hist),
        )

    # ── Internal helpers ─────────────────────────────────────────────────────

    def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation:
        """Assemble the Observation shown to the agent after every transition."""
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            contract_description=self._contract.get("metadata", {}).get("description", ""),
            available_actions=[a.value for a in AVAILABLE_ACTIONS],
            last_action=last_action,
            last_action_result=last_result,
            step_count=self._step_count,
            cumulative_reward=self._cum_reward,
            done=self._done,
            extra={
                "target_function": self._target_fn.get("name", ""),
                "target_signature": self._target_fn.get("signature", ""),
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Discover the property of the target function. "
                    "Use get_function_code, get_function_natspec, or get_similar_rule for hints. "
                    "Submit with submit_property, params={'property': '<your property text>'}. "
                    "ONE submission attempt only."
                ),
            },
        )

    def _qkey(self, at: str, params: Dict[str, Any]) -> str:
        """Canonical key for an (action, params) pair — used for repeat detection."""
        return f"{at}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """True if this exact query was issued before; otherwise record it."""
        if key in self._seen:
            return True
        self._seen.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route one action to its handler; returns (result text, reward)."""
        at = action.action_type
        params = action.params
        qkey = self._qkey(at, params)
        fn = self._target_fn
        name = fn["name"]

        # ── get_function_code ────────────────────────────────────────────────
        if at == ActionType.GET_FUNCTION_CODE:
            if self._is_repeated(qkey):
                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
            code = fn.get("code", "// no code available")
            return (
                f"// {name}\n{code}",
                Reward(value=-0.06, reason="get_function_code cost"),
            )

        # ── get_function_natspec ─────────────────────────────────────────────
        if at == ActionType.GET_FUNCTION_NATSPEC:
            if self._is_repeated(qkey):
                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
            natspec = fn.get("natspec") or fn.get("comment") or "No NatSpec available."
            # Also include output_property if present
            out_prop = fn.get("output_property", "")
            result = f"NatSpec for '{name}':\n{natspec}"
            if out_prop:
                result += f"\n\nExpected output: {out_prop}"
            return result, Reward(value=-0.08, reason="get_function_natspec cost")

        # ── get_file_natspec ─────────────────────────────────────────────────
        if at == ActionType.GET_FILE_NATSPEC:
            if self._is_repeated(qkey):
                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
            meta = self._contract.get("metadata", {})
            natspec = meta.get("natspec") or meta.get("description", "No file NatSpec available.")
            return (
                f"File NatSpec for {self._contract['contract_name']}:\n{natspec}",
                Reward(value=-0.03, reason="get_file_natspec cost"),
            )

        # ── get_related_functions ────────────────────────────────────────────
        if at == ActionType.GET_RELATED_FUNCTIONS:
            if self._is_repeated(qkey):
                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
            related = get_related_functions(self._contract, name)
            if not related:
                text = f"No related functions found for '{name}'."
            else:
                summaries = []
                for rn in related:
                    rfn = get_function_by_name(self._contract, rn)
                    if rfn:
                        sig = rfn.get("signature", rn)
                        comment = rfn.get("comment", "")
                        summaries.append(f" • {sig} — {comment}")
                text = f"Related functions for '{name}':\n" + "\n".join(summaries)
            return text, Reward(value=-0.06, reason="get_related_functions cost")

        # ── get_io ───────────────────────────────────────────────────────────
        if at == ActionType.GET_IO:
            if self._is_repeated(qkey):
                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
            params_list = fn.get("parameters", [])
            returns = fn.get("returns", "") or "void"
            out_prop = fn.get("output_property", "")
            visibility = fn.get("visibility", "")
            modifiers = fn.get("modifiers", [])

            lines = [f"Function: {fn.get('signature', name)}"]
            lines.append(f"Visibility: {visibility}" + (f" Modifiers: {', '.join(modifiers)}" if modifiers else ""))
            if params_list:
                lines.append("Parameters:")
                for p in params_list:
                    lines.append(f" • {p['type']} {p['name']}: {p.get('description','')}")
            else:
                # A parameterless payable function still receives ETH input.
                lines.append("Parameters: none (payable)" if "payable" in fn.get("code","") else "Parameters: none")
            lines.append(f"Returns: {returns}")
            if out_prop:
                lines.append(f"Expected behaviour: {out_prop}")
            return "\n".join(lines), Reward(value=-0.04, reason="get_io cost")

        # ── get_similar_rule ─────────────────────────────────────────────────
        if at == ActionType.GET_SIMILAR_RULE:
            if self._is_repeated(qkey):
                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
            sr = get_similar_rule(
                self._contracts,
                self._contract["contract_name"],
                name,
            )
            if sr is None:
                return (
                    "No similar rule available for this function.",
                    Reward(value=-0.20, reason="get_similar_rule cost (not found)"),
                )
            lines = [
                f"Similar property from {sr['contract_name']}.{sr['function_name']}():",
                f"  {sr['property_hint']}",
            ]
            if sr.get("natspec"):
                lines.append(f"\nFunction NatSpec:\n  {sr['natspec']}")
            return "\n".join(lines), Reward(value=-0.20, reason="get_similar_rule cost")

        # ── submit_property ──────────────────────────────────────────────────
        if at == ActionType.SUBMIT_PROPERTY:
            # Defensive guard: normally unreachable because the first submit
            # sets _done and step() raises, but kept for safety.
            if self._submitted:
                return (
                    "❌ You have already submitted a property for this episode. "
                    "Only one submission is allowed.",
                    Reward(value=-1.0, reason="Second submit_property attempt", partial=False),
                )
            submitted_text = params.get("property", "").strip()
            if not submitted_text:
                return (
                    "Submit requires 'property' key in params with a non-empty string.",
                    Reward(value=-0.5, reason="Empty property submission"),
                )

            self._submitted = True
            self._done = True

            score = self._grader.grade(submitted_text)
            reward = self._grader.reward_for_score(score)
            bd = self._grader.breakdown(submitted_text)

            pct = int(score * 100)
            # Qualitative bands for the feedback message only; the numeric
            # reward comes straight from the grader.
            if score >= 0.85:
                emoji = "✅"
                label = "EXCELLENT"
            elif score >= 0.60:
                emoji = "🟡"
                label = "GOOD"
            elif score >= 0.35:
                emoji = "🟠"
                label = "PARTIAL"
            else:
                emoji = "❌"
                label = "POOR"

            msg = (
                f"{emoji} {label} — Score: {score:.2f}/1.00 — Reward: {reward:.2f}/5.00 ({pct}%)\n"
                f"Key concepts matched : {len(bd['key_matched'])}/{len(bd['key_matched'])+len(bd['key_missed'])} "
                f"{bd['key_matched']}\n"
                f"Bonus concepts matched : {len(bd['bonus_matched'])}/{len(bd['bonus_matched'])+len(bd['bonus_missed'])} "
                f"{bd['bonus_matched']}"
            )
            return msg, Reward(
                value=reward,
                reason=f"Property submission score={score:.3f}",
                partial=False,
            )

        # ── unknown action ───────────────────────────────────────────────────
        return (
            f"Unknown action type: '{at}'. Valid: {[a.value for a in AVAILABLE_ACTIONS]}",
            Reward(value=-0.10, reason="Unknown action"),
        )
|
tasks/task2/grader.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
grader.py (Task 2 β Property Discovery)
|
| 3 |
+
-----------------------------------------
|
| 4 |
+
Deterministic scorer for natural-language property submissions.
|
| 5 |
+
|
| 6 |
+
Score formula
|
| 7 |
+
βββββββββββββ
|
| 8 |
+
key_phrases weight = 0.70
|
| 9 |
+
bonus_phrases weight = 0.30
|
| 10 |
+
|
| 11 |
+
score = 0.70 * (matched_key / total_key)
|
| 12 |
+
+ 0.30 * (matched_bonus / total_bonus)
|
| 13 |
+
|
| 14 |
+
Phrase matching
|
| 15 |
+
βββββββββββββββ
|
| 16 |
+
A phrase is considered matched if ALL its words (after normalisation)
|
| 17 |
+
appear in the submitted text. This is intentionally lenient β it
|
| 18 |
+
doesn't require the words to be adjacent, so "balance increases by
|
| 19 |
+
msg.value" is matched by "the caller's vault balance increases by the
|
| 20 |
+
sent msg.value amount".
|
| 21 |
+
|
| 22 |
+
Synonym expansion allows common paraphrases to also match
|
| 23 |
+
(e.g. "caller" β "msg.sender", "sender", "user").
|
| 24 |
+
|
| 25 |
+
Terminal reward = score Γ 5.0 (range: 0.0 β 5.0)
|
| 26 |
+
One submission attempt per episode.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import string
|
| 32 |
+
from typing import Dict, List, Optional
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ββ Text normalisation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
_PUNCT = str.maketrans("", "", string.punctuation)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _norm(text: str) -> str:
|
| 41 |
+
"""Lowercase, strip punctuation, collapse whitespace β word string."""
|
| 42 |
+
return " ".join(text.lower().translate(_PUNCT).split())
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _word_set(text: str) -> set:
|
| 46 |
+
"""Return normalised set of words."""
|
| 47 |
+
return set(_norm(text).split())
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ββ Synonym table βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
# Maps canonical word β list of accepted synonyms (all lowercase, no punct).
|
| 52 |
+
|
| 53 |
+
# Canonical word → accepted paraphrases (all lowercase, punctuation-free).
# Multi-word synonyms (e.g. "msg sender") are split into individual words by
# _expand_words before matching.  "msgsender" covers "msg.sender" because
# _norm() deletes punctuation without inserting a space.
SYNONYMS: Dict[str, List[str]] = {
    "caller": ["caller", "sender", "user", "msgsender", "msg sender"],
    "balance": ["balance", "holdings", "amount held"],
    "increases": ["increases", "incremented", "added", "grows", "rise"],
    "decreases": ["decreases", "decremented", "reduced", "subtracted", "falls"],
    "transfers": ["transfers", "sends", "moved", "forwarded", "sent"],
    "reverts": ["reverts", "fails", "rejected", "throws", "require"],
    "zero": ["zero", "0", "nothing", "none", "empty"],
    "owner": ["owner", "admin", "authorized"],
    "entire": ["entire", "full", "whole", "all", "total"],
    "returned": ["returned", "sent back", "refunded", "transferred back"],
    "reset": ["reset", "zeroed", "set to zero", "cleared"],
    "only": ["only", "exclusively", "restricted"],
    "price": ["price", "cost", "rate"],
    "tokens": ["tokens", "token amount"],
    "rewards": ["rewards", "reward tokens", "accrued"],
    "staked": ["staked", "deposited", "locked"],
    "winner": ["winner", "winning bidder", "successful bidder"],
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _expand_words(phrase_words: List[str]) -> List[List[str]]:
    """
    Generate synonym variants of *phrase_words*.

    The unmodified phrase always comes first.  Each variant substitutes at
    most ONE word with one of its synonyms, which keeps the number of
    variants linear rather than combinatorial.
    """
    variants: List[List[str]] = [phrase_words]
    for idx, original in enumerate(phrase_words):
        for alternative in SYNONYMS.get(original, ()):
            replacement = _norm(alternative).split()
            variants.append(phrase_words[:idx] + replacement + phrase_words[idx + 1:])
    return variants
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _phrase_matched(text_words: set, phrase: str) -> bool:
    """
    Return True when every word of *phrase* — or of one of its synonym
    variants — occurs in *text_words*.

    Matching is set containment, so the phrase's words need not appear
    adjacent (or in order) in the submission.
    """
    base_words = _norm(phrase).split()
    return any(
        all(word in text_words for word in variant)
        for variant in _expand_words(base_words)
    )
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ββ Grader ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 103 |
+
|
| 104 |
+
class Task2Grader:
    """
    Deterministic grader for Task 2 property submissions.

    The score is a weighted phrase-coverage measure: KEY_WEIGHT on the
    dataset's key_phrases plus BONUS_WEIGHT on its bonus_phrases, each
    counted as the fraction of phrases matched in the submitted text.

    Parameters
    ----------
    function_name : name of the target function
    property_data : the 'property' dict from the dataset
        (expects natural_language, key_phrases, bonus_phrases)
    """

    KEY_WEIGHT = 0.70
    BONUS_WEIGHT = 0.30

    def __init__(self, function_name: str, property_data: Dict) -> None:
        self.function_name = function_name
        self.natural_language = property_data.get("natural_language", "")
        self.key_phrases = property_data.get("key_phrases", [])
        self.bonus_phrases = property_data.get("bonus_phrases", [])

    # ── Public API ─────────────────────────────────────────────────────────────

    def grade(self, submitted: str) -> float:
        """Deterministic score in [0.0, 1.0]."""
        if not submitted or not submitted.strip():
            return 0.0
        words = _word_set(submitted)
        weighted = (
            self.KEY_WEIGHT * self._phrase_score(words, self.key_phrases)
            + self.BONUS_WEIGHT * self._phrase_score(words, self.bonus_phrases)
        )
        # Clamp defensively, then round for stable comparisons.
        return round(min(max(weighted, 0.0), 1.0), 4)

    def reward_for_score(self, score: float) -> float:
        """Maps [0.0, 1.0] → [0.0, 5.0]."""
        return round(score * 5.0, 4)

    def breakdown(self, submitted: str) -> Dict:
        """Detailed scoring breakdown for debugging."""
        words = _word_set(submitted)
        key_hits = [p for p in self.key_phrases if _phrase_matched(words, p)]
        bonus_hits = [p for p in self.bonus_phrases if _phrase_matched(words, p)]
        overall = self.grade(submitted)
        return {
            "score": overall,
            "reward": self.reward_for_score(overall),
            "key_matched": key_hits,
            "key_missed": [p for p in self.key_phrases if p not in key_hits],
            "bonus_matched": bonus_hits,
            "bonus_missed": [p for p in self.bonus_phrases if p not in bonus_hits],
            "key_score": self._phrase_score(words, self.key_phrases),
            "bonus_score": self._phrase_score(words, self.bonus_phrases),
        }

    def get_canonical_answer(self) -> Dict:
        """For debugging / logging only."""
        return {
            "function": self.function_name,
            "natural_language": self.natural_language,
            "key_phrases": self.key_phrases,
        }

    # ── Internal ───────────────────────────────────────────────────────────────

    def _phrase_score(self, text_words: set, phrases: List[str]) -> float:
        # Vacuous truth: with nothing to match, award full credit.
        if not phrases:
            return 1.0
        hits = sum(1 for p in phrases if _phrase_matched(text_words, p))
        return hits / len(phrases)
|
validate.py
CHANGED
|
@@ -1,290 +1,280 @@
|
|
| 1 |
"""
|
| 2 |
validate.py
|
| 3 |
-----------
|
| 4 |
-
Pre-submission validation
|
| 5 |
-
Checks all OpenEnv spec requirements locally before submitting.
|
| 6 |
|
| 7 |
-
Usage:
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
Exit code 0 = all checks pass.
|
| 11 |
-
Exit code 1 = one or more checks failed.
|
| 12 |
"""
|
| 13 |
|
| 14 |
-
import json
|
| 15 |
-
import sys
|
| 16 |
-
import traceback
|
| 17 |
from typing import Callable, List, Tuple
|
| 18 |
|
| 19 |
-
|
| 20 |
-
# Helpers
|
| 21 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
-
|
| 23 |
-
PASS = "β
"
|
| 24 |
-
FAIL = "β"
|
| 25 |
-
SKIP = "β "
|
| 26 |
results: List[Tuple[str, bool, str]] = []
|
| 27 |
|
| 28 |
-
|
| 29 |
def check(name: str, fn: Callable[[], None]) -> None:
|
| 30 |
try:
|
| 31 |
-
fn()
|
| 32 |
-
results.append((name, True, ""))
|
| 33 |
print(f" {PASS} {name}")
|
| 34 |
except Exception as e:
|
| 35 |
-
tb = traceback.format_exc(limit=3)
|
| 36 |
results.append((name, False, str(e)))
|
| 37 |
-
print(f" {FAIL} {name}")
|
| 38 |
-
print(f" {e}")
|
| 39 |
-
|
| 40 |
|
| 41 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
-
# Checks
|
| 43 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
|
| 45 |
def check_imports():
|
| 46 |
-
from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult
|
| 47 |
from tasks.task1.environment import Task1Environment
|
| 48 |
from tasks.task1.grader import Task1Grader
|
|
|
|
|
|
|
| 49 |
from data.data_loader import load_contracts
|
| 50 |
|
| 51 |
-
|
| 52 |
def check_openenv_yaml():
|
| 53 |
import yaml
|
| 54 |
-
with open("openenv.yaml") as f:
|
| 55 |
-
spec = yaml.safe_load(f)
|
| 56 |
assert "name" in spec
|
| 57 |
-
assert "tasks"
|
| 58 |
-
assert len(spec["tasks"]) >= 3, "Need at least 3 tasks defined"
|
| 59 |
assert "observation_space" in spec
|
| 60 |
assert "action_space" in spec
|
| 61 |
assert "reward" in spec
|
| 62 |
|
| 63 |
-
|
| 64 |
def check_pydantic_models():
|
| 65 |
from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult, StateResult
|
| 66 |
-
|
| 67 |
-
obs = Observation(
|
| 68 |
-
task_id="t1", contract_name="C", contract_description="D",
|
| 69 |
-
available_actions=["submit"]
|
| 70 |
-
)
|
| 71 |
assert obs.task_id == "t1"
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
assert
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
assert reward.value == 1.0
|
| 78 |
-
|
| 79 |
-
step = StepResult(observation=obs, reward=reward, done=False)
|
| 80 |
-
assert not step.done
|
| 81 |
-
|
| 82 |
-
reset = ResetResult(observation=obs)
|
| 83 |
-
assert reset.observation.task_id == "t1"
|
| 84 |
-
|
| 85 |
-
state = StateResult(task_id="t1", contract_name="C", step_count=0,
|
| 86 |
-
cumulative_reward=0.0, done=False)
|
| 87 |
-
assert state.step_count == 0
|
| 88 |
-
|
| 89 |
|
| 90 |
def check_data_loading():
|
| 91 |
-
from data.data_loader import load_contracts, get_all_vulnerable_entries
|
| 92 |
contracts = load_contracts()
|
| 93 |
-
assert len(contracts) >= 1
|
| 94 |
-
|
| 95 |
-
assert len(
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
assert result.observation is not None
|
| 107 |
-
assert result.observation.task_id == "task1_vuln_detection"
|
| 108 |
-
assert result.observation.contract_name != ""
|
| 109 |
-
assert not result.observation.done
|
| 110 |
-
assert result.observation.step_count == 0
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def check_env_step():
|
| 114 |
from tasks.task1.environment import Task1Environment
|
| 115 |
from env.schemas import Action, ActionType
|
| 116 |
env = Task1Environment()
|
| 117 |
-
env.reset(seed=42)
|
| 118 |
-
|
| 119 |
-
assert
|
| 120 |
-
assert
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
def
|
| 136 |
from tasks.task1.grader import Task1Grader
|
| 137 |
cases = [
|
| 138 |
-
("withdraw", "Reentrancy vulnerability",
|
| 139 |
-
("withdraw", "Reentrancy vulnerability",
|
| 140 |
-
("withdraw", "Reentrancy vulnerability",
|
| 141 |
]
|
| 142 |
for tf, issue, sf, sv, expected in cases:
|
| 143 |
g = Task1Grader(tf, issue)
|
| 144 |
score = g.grade_submission(sf, sv)
|
| 145 |
-
assert 0.0 <= score <= 1.0
|
| 146 |
assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
from
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
def check_reward_shaping():
|
| 158 |
-
|
| 159 |
-
from tasks.task1.environment import Task1Environment
|
| 160 |
from env.schemas import Action, ActionType
|
| 161 |
-
env =
|
| 162 |
env.reset(seed=1)
|
| 163 |
-
rewards =
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
rewards.add(round(r.reward.value, 4))
|
| 167 |
-
# Should have at least 2 distinct shaping reward values
|
| 168 |
-
assert len(rewards) >= 2, f"Expected multiple reward values, got {rewards}"
|
| 169 |
-
|
| 170 |
|
| 171 |
-
def
|
| 172 |
-
"""Episode must end after submit and raise on subsequent step."""
|
| 173 |
from tasks.task1.environment import Task1Environment
|
| 174 |
from env.schemas import Action, ActionType
|
| 175 |
env = Task1Environment()
|
| 176 |
env.reset(seed=2)
|
| 177 |
-
env.step(Action(action_type=ActionType.SUBMIT,
|
| 178 |
-
|
| 179 |
-
}))
|
| 180 |
try:
|
| 181 |
env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 182 |
-
raise AssertionError("Should
|
| 183 |
except RuntimeError:
|
| 184 |
-
pass
|
| 185 |
-
|
| 186 |
|
| 187 |
def check_repeated_query_penalty():
|
| 188 |
from tasks.task1.environment import Task1Environment
|
| 189 |
from env.schemas import Action, ActionType
|
| 190 |
-
env = Task1Environment()
|
| 191 |
-
env.reset(seed=3)
|
| 192 |
env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 193 |
r = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 194 |
-
assert r.reward.value == -0.40
|
| 195 |
-
|
| 196 |
|
| 197 |
-
def
|
| 198 |
-
|
| 199 |
-
from
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
| 201 |
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
def
|
| 204 |
import os
|
| 205 |
-
assert os.path.exists("Dockerfile")
|
| 206 |
-
with open("Dockerfile") as f:
|
| 207 |
-
|
| 208 |
-
assert "
|
| 209 |
-
assert "uvicorn" in content or "CMD" in content
|
| 210 |
-
|
| 211 |
|
| 212 |
def check_inference_script():
|
| 213 |
import os
|
| 214 |
-
assert os.path.exists("inference.py")
|
| 215 |
-
with open("inference.py") as f:
|
| 216 |
-
|
| 217 |
-
assert "
|
| 218 |
-
|
| 219 |
-
assert "
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
def check_baseline_json_schema():
|
| 224 |
-
"""baseline_scores.json must have valid schema if it exists."""
|
| 225 |
import os
|
| 226 |
-
if not os.path.exists("baseline_scores.json"):
|
| 227 |
-
|
| 228 |
-
with open("baseline_scores.json") as f:
|
| 229 |
-
data = json.load(f)
|
| 230 |
assert "tasks" in data
|
| 231 |
-
for
|
| 232 |
-
|
| 233 |
-
assert 0.0 <= score <= 1.0, f"Score {score} out of range"
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
def main():
|
| 241 |
-
print("=" *
|
| 242 |
-
print("OpenEnv Pre-Submission Validation")
|
| 243 |
-
print("=" *
|
| 244 |
-
|
| 245 |
-
all_checks = [
|
| 246 |
-
("Python imports", check_imports),
|
| 247 |
-
("openenv.yaml format", check_openenv_yaml),
|
| 248 |
-
("Pydantic model types", check_pydantic_models),
|
| 249 |
-
("Dataset loading (3+ vulns)", check_data_loading),
|
| 250 |
-
("env.reset() β ResetResult", check_env_reset),
|
| 251 |
-
("env.step() β StepResult", check_env_step),
|
| 252 |
-
("env.state() β StateResult", check_env_state),
|
| 253 |
-
("Grader scores in [0.0, 1.0]", check_grader_scores_in_range),
|
| 254 |
-
("Grader is deterministic", check_grader_deterministic),
|
| 255 |
-
("Reward shaping (non-binary)", check_reward_shaping),
|
| 256 |
-
("Episode boundary (done=True)",check_episode_boundary),
|
| 257 |
-
("Repeated query penalty", check_repeated_query_penalty),
|
| 258 |
-
("Task 2 & 3 placeholders", check_tasks_list),
|
| 259 |
-
("Dockerfile exists + port", check_dockerfile_exists),
|
| 260 |
-
("inference.py exists + vars", check_inference_script),
|
| 261 |
-
("baseline_scores.json schema", check_baseline_json_schema),
|
| 262 |
-
]
|
| 263 |
-
|
| 264 |
print()
|
| 265 |
-
for name, fn in
|
| 266 |
check(name, fn)
|
| 267 |
|
| 268 |
-
print()
|
| 269 |
passed = sum(1 for _, ok, _ in results if ok)
|
| 270 |
-
total
|
| 271 |
-
failed = [(n,
|
| 272 |
|
| 273 |
-
print(
|
|
|
|
| 274 |
print(f"Results: {passed}/{total} checks passed")
|
| 275 |
-
|
| 276 |
if failed:
|
| 277 |
print("\nFailed checks:")
|
| 278 |
-
for
|
| 279 |
-
print(f" {FAIL} {
|
| 280 |
-
print()
|
| 281 |
-
print("β VALIDATION FAILED β fix the issues above before submitting.")
|
| 282 |
sys.exit(1)
|
| 283 |
else:
|
| 284 |
-
print()
|
| 285 |
-
print("β
ALL CHECKS PASSED β ready to submit!")
|
| 286 |
sys.exit(0)
|
| 287 |
|
| 288 |
-
|
| 289 |
if __name__ == "__main__":
|
| 290 |
main()
|
|
|
|
| 1 |
"""
|
| 2 |
validate.py
|
| 3 |
-----------
|
| 4 |
+
Pre-submission validation. Checks all OpenEnv spec requirements.
|
|
|
|
| 5 |
|
| 6 |
+
Usage: python validate.py
|
| 7 |
+
Exit 0 = all checks pass. Exit 1 = one or more failures.
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
+
import json, sys, traceback
|
|
|
|
|
|
|
| 11 |
from typing import Callable, List, Tuple
|
| 12 |
|
| 13 |
+
PASS = "β
"; FAIL = "β"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
results: List[Tuple[str, bool, str]] = []
|
| 15 |
|
|
|
|
| 16 |
def check(name: str, fn: Callable[[], None]) -> None:
    """Run one validation callable, record its outcome in `results`, and print it.

    A check passes when `fn()` returns without raising; any exception marks
    the check as failed and its message is stored for the final summary.
    """
    try:
        fn()
    except Exception as exc:
        results.append((name, False, str(exc)))
        print(f"  {FAIL} {name}\n      {exc}")
    else:
        results.append((name, True, ""))
        print(f"  {PASS} {name}")
| 23 |
|
| 24 |
+
# ββ Checks ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def check_imports():
    """Smoke test: all public project modules must import cleanly.

    The imported names are intentionally unused — an ImportError is the
    only failure mode this check looks for.
    """
    from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult, ActionType
    from tasks.task1.environment import Task1Environment
    from tasks.task1.grader import Task1Grader
    from tasks.task2.environment import Task2Environment
    from tasks.task2.grader import Task2Grader
    from data.data_loader import load_contracts
| 34 |
def check_openenv_yaml():
    """openenv.yaml must declare a name, at least 3 tasks, and the core sections."""
    import yaml
    with open("openenv.yaml") as fh:
        spec = yaml.safe_load(fh)
    assert "name" in spec
    assert len(spec.get("tasks", [])) >= 3
    for section in ("observation_space", "action_space", "reward"):
        assert section in spec
| 43 |
def check_pydantic_models():
    """Every schema model must construct and round-trip its fields."""
    from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult, StateResult
    obs = Observation(task_id="t1", contract_name="C", contract_description="D", available_actions=["submit"])
    assert obs.task_id == "t1"
    # Both a Task-1 browse action and the Task-2 submit action must be valid.
    browse = Action(action_type=ActionType.LIST_FUNCTIONS)
    assert browse.action_type == ActionType.LIST_FUNCTIONS
    submit = Action(action_type=ActionType.SUBMIT_PROPERTY)
    assert submit.action_type == ActionType.SUBMIT_PROPERTY
    rew = Reward(value=1.0, reason="test")
    assert rew.value == 1.0
    step = StepResult(observation=obs, reward=rew, done=False)
    assert not step.done
    assert ResetResult(observation=obs).observation.task_id == "t1"
| 52 |
|
| 53 |
def check_data_loading():
    """Dataset must provide >=3 vulnerable entries and >=3 well-formed property entries."""
    from data.data_loader import load_contracts, get_all_vulnerable_entries, get_all_property_entries
    contracts = load_contracts()
    assert len(contracts) >= 1
    vuln_entries = get_all_vulnerable_entries(contracts)
    assert len(vuln_entries) >= 3, f"Need >=3 vulnerable fns, got {len(vuln_entries)}"
    prop_entries = get_all_property_entries(contracts)
    assert len(prop_entries) >= 3, f"Need >=3 property fns, got {len(prop_entries)}"
    # Every property entry must carry the fields the Task 2 grader relies on.
    for _, fn in prop_entries:
        p = fn["property"]
        assert "natural_language" in p
        assert "key_phrases" in p
        assert "bonus_phrases" in p
        assert len(p["key_phrases"]) >= 2
|
| 67 |
+
|
| 68 |
+
def check_t1_env():
    """Task 1 env: reset, step, and state must follow the OpenEnv contract."""
    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType
    env = Task1Environment()
    reset_result = env.reset(seed=42)
    assert reset_result.observation.task_id == "task1_vuln_detection"
    step_result = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
    assert isinstance(step_result.reward.value, float)
    assert step_result.observation.step_count == 1
    assert env.state().target_function is not None
| 78 |
+
def check_t2_env():
    """Task 2 env: reset exposes the target function; all browse actions cost reward."""
    from tasks.task2.environment import Task2Environment
    from env.schemas import Action, ActionType
    env = Task2Environment()
    r = env.reset(seed=42)
    assert r.observation.task_id == "task2_property_discovery"
    assert "target_function" in r.observation.extra
    # Every browse action must carry a (negative) step cost.
    for at in [ActionType.GET_FUNCTION_CODE, ActionType.GET_FUNCTION_NATSPEC,
               ActionType.GET_FILE_NATSPEC, ActionType.GET_IO, ActionType.GET_RELATED_FUNCTIONS]:
        s = env.step(Action(action_type=at)); assert s.reward.value < 0
    # GET_SIMILAR_RULE is the priciest hint — fixed cost of -0.20.
    s = env.step(Action(action_type=ActionType.GET_SIMILAR_RULE))
    assert s.reward.value == -0.20
+
def check_t2_env_submit():
    """Submitting the ground-truth property text must end the episode with reward > 0."""
    from tasks.task2.environment import Task2Environment
    from data.data_loader import load_contracts, get_function_by_name
    from env.schemas import Action, ActionType
    env = Task2Environment()
    r = env.reset(seed=42)
    fn_name = r.observation.extra["target_function"]
    contract = r.observation.contract_name
    contracts = load_contracts()
    gt_text = ""
    for c in contracts:
        if c["contract_name"] == contract:
            fn = get_function_by_name(c, fn_name)
            if fn and fn.get("property"):
                gt_text = fn["property"]["natural_language"]
    # Fail loudly if no ground truth was found: otherwise we would submit an
    # empty string and get a confusing reward-assertion failure below.
    assert gt_text, f"No ground-truth property found for {contract}.{fn_name}"
    result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": gt_text}))
    assert result.done
    assert result.reward.value > 0, f"GT text should score >0, got {result.reward.value}"
| 111 |
+
def check_t2_one_submit_only():
    """Only one SUBMIT_PROPERTY per episode: a second must raise or be penalised."""
    from tasks.task2.environment import Task2Environment
    from env.schemas import Action, ActionType
    env = Task2Environment()
    env.reset(seed=5)
    env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": "test"}))
    # Second submit must either fail (episode done → RuntimeError) or return negative reward
    try:
        s2 = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": "test2"}))
        # If it doesn't raise, the reward must be negative
        assert s2.reward.value < 0, "Second submit should penalise"
    except RuntimeError:
        pass  # expected
|
| 125 |
+
def check_t1_grader():
    """Task 1 rubric: 1.0 both right, 0.5 function right only, 0.0 function wrong."""
    from tasks.task1.grader import Task1Grader
    # (target_fn, ground-truth issue, submitted_fn, submitted_vuln, expected score)
    cases = [
        ("withdraw", "Reentrancy vulnerability", "withdraw", "reentrancy", 1.0),
        ("withdraw", "Reentrancy vulnerability", "withdraw", "something else", 0.5),
        ("withdraw", "Reentrancy vulnerability", "deposit", "reentrancy", 0.0),
    ]
    for tf, issue, sf, sv, expected in cases:
        g = Task1Grader(tf, issue)
        score = g.grade_submission(sf, sv)
        assert 0.0 <= score <= 1.0
        assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"
+
def check_t2_grader():
    """Task 2 grader invariants hold for every property entry in the dataset."""
    from tasks.task2.grader import Task2Grader
    from data.data_loader import load_contracts, get_all_property_entries
    contracts = load_contracts()
    entries = get_all_property_entries(contracts)
    for contract, fn in entries:
        g = Task2Grader(fn["name"], fn["property"])
        # Ground truth must score ≥ 0.65
        gt_score = g.grade(fn["property"]["natural_language"])
        assert gt_score >= 0.65, f"{fn['name']}: gt_score={gt_score} < 0.65"
        # Empty must be 0.0
        assert g.grade("") == 0.0
        # Deterministic
        assert g.grade("test text") == g.grade("test text")
        # Score in [0,1]
        assert 0.0 <= gt_score <= 1.0
        # Reward maps correctly (reward = score * 5.0)
        assert abs(g.reward_for_score(gt_score) - gt_score * 5.0) < 0.01
| 157 |
def check_reward_shaping():
    """Rewards must be shaped (non-binary): different browse actions yield different costs."""
    from tasks.task2.environment import Task2Environment
    from env.schemas import Action, ActionType
    env = Task2Environment()
    env.reset(seed=1)
    # Collect the distinct reward values across several browse actions.
    rewards = {env.step(Action(action_type=at)).reward.value
               for at in [ActionType.GET_FUNCTION_CODE, ActionType.GET_FILE_NATSPEC, ActionType.GET_IO]}
    assert len(rewards) >= 2, f"Need multiple reward values, got {rewards}"
+
def check_t1_episode_boundary():
    """After a SUBMIT ends the episode, further steps must raise RuntimeError."""
    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType
    env = Task1Environment()
    env.reset(seed=2)
    env.step(Action(action_type=ActionType.SUBMIT,
                    params={"function_name": "withdraw", "vulnerability_type": "test"}))
    try:
        env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
        raise AssertionError("Should raise RuntimeError after done")
    except RuntimeError:
        pass
| 179 |
def check_repeated_query_penalty():
    """Task 1: repeating the same query is penalised at exactly -0.40."""
    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType
    env = Task1Environment()
    env.reset(seed=3)
    env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
    repeat = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
    assert repeat.reward.value == -0.40
| 187 |
+
def check_t2_repeated_penalty():
    """Task 2: repeating the same query is penalised at exactly -0.40."""
    from tasks.task2.environment import Task2Environment
    from env.schemas import Action, ActionType
    env = Task2Environment()
    env.reset(seed=3)
    env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
    repeat = env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
    assert repeat.reward.value == -0.40
+
def check_task_placeholders():
    """Task 3 placeholder package must exist and export __all__ (import is the check)."""
    from tasks.task3 import __all__ as t3
+
def check_dockerfile():
    """Dockerfile must exist, expose port 7860, and define a start command."""
    import os
    assert os.path.exists("Dockerfile")
    with open("Dockerfile") as fh:
        contents = fh.read()
    assert "7860" in contents
    assert "uvicorn" in contents or "CMD" in contents
| 205 |
def check_inference_script():
    """inference.py must exist and reference credentials plus Task 2 code."""
    import os
    assert os.path.exists("inference.py")
    with open("inference.py") as fh:
        contents = fh.read()
    for required in ("HF_TOKEN", "API_BASE_URL", "MODEL_NAME"):
        assert required in contents
    assert "task2" in contents.lower() or "Task2" in contents or "TASK 2" in contents
+
def check_baseline_json():
    """baseline_scores.json is optional; when present it must match the schema."""
    import os
    if not os.path.exists("baseline_scores.json"):
        return
    with open("baseline_scores.json") as fh:
        data = json.load(fh)
    assert "tasks" in data
    for task in data["tasks"]:
        assert 0.0 <= task["avg_grader_score"] <= 1.0
+
def check_similar_rule_lookup():
    """The similar-rule hint data must resolve for SimpleVault.withdraw."""
    from data.data_loader import load_contracts, get_similar_rule
    contracts = load_contracts()
    sr = get_similar_rule(contracts, "SimpleVault", "withdraw")
    assert sr is not None, "similar_rule should exist for withdraw"
    assert "property_hint" in sr
    assert "contract_name" in sr
| 230 |
+
# ── Runner ────────────────────────────────────────────────────────────────────

# Ordered (display name, check callable) pairs executed by main().
ALL_CHECKS = [
    ("Python imports (T1 + T2)", check_imports),
    ("openenv.yaml format", check_openenv_yaml),
    ("Pydantic models (incl T2 actions)", check_pydantic_models),
    ("Dataset: vuln + property entries", check_data_loading),
    ("Task 1: reset / step / state", check_t1_env),
    ("Task 2: reset + all 6 browse actions", check_t2_env),
    ("Task 2: submit_property scores > 0", check_t2_env_submit),
    ("Task 2: one submit only", check_t2_one_submit_only),
    ("Task 1 grader: 0/0.5/1.0 rubric", check_t1_grader),
    ("Task 2 grader: all 11 properties", check_t2_grader),
    ("Reward shaping (multi-value)", check_reward_shaping),
    ("T1 episode boundary", check_t1_episode_boundary),
    ("T1 repeated query penalty (-0.40)", check_repeated_query_penalty),
    ("T2 repeated query penalty (-0.40)", check_t2_repeated_penalty),
    ("Task 3 placeholder exists", check_task_placeholders),
    ("Dockerfile + port 7860", check_dockerfile),
    ("inference.py: creds + Task 2 code", check_inference_script),
    ("baseline_scores.json schema", check_baseline_json),
    ("similar_rule data lookup", check_similar_rule_lookup),
]
| 254 |
def main():
    """Run every check, print a summary, and exit 0 (all pass) or 1 (any fail)."""
    print("=" * 64)
    print("OpenEnv Pre-Submission Validation (Task 1 + Task 2)")
    print("=" * 64)
    print()
    for name, fn in ALL_CHECKS:
        check(name, fn)

    # `results` is populated by check() as a side effect.
    passed = sum(1 for _, ok, _ in results if ok)
    total = len(results)
    failed = [(n, m) for n, ok, m in results if not ok]

    print()
    print("=" * 64)
    print(f"Results: {passed}/{total} checks passed")
    if failed:
        print("\nFailed checks:")
        for n, m in failed:
            print(f"  {FAIL} {n}: {m}")
        print("\n❌ VALIDATION FAILED — fix the issues above before submitting.")
        sys.exit(1)
    else:
        print("\n✅ ALL CHECKS PASSED — ready to submit!")
        sys.exit(0)


if __name__ == "__main__":
    main()
|