ajaxwin committed · Commit 9c888b7 · 1 Parent(s): 8fccda7

Task 2 added

Files changed (12)
  1. README.md +177 -169
  2. app.py +78 -116
  3. data/data_loader.py +135 -30
  4. demo.py +74 -4
  5. env/schemas.py +37 -36
  6. eval.py +259 -220
  7. inference.py +217 -234
  8. openenv.yaml +67 -91
  9. tasks/task2/__init__.py +4 -26
  10. tasks/task2/environment.py +340 -0
  11. tasks/task2/grader.py +171 -0
  12. validate.py +189 -199
README.md CHANGED
@@ -1,108 +1,120 @@
  # Smart Contract Audit RL Environment
 
  > **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
- > Agents learn to audit real-world Solidity contracts — finding vulnerabilities, discovering properties, and checking rule compliance — tasks that professional auditors perform daily.
 
- [![OpenEnv Spec](https://img.shields.io/badge/OpenEnv-1.0-blue)](openenv.yaml)
- [![HF Space](https://img.shields.io/badge/HuggingFace-Space-yellow)](https://huggingface.co/spaces)
  [![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-brightgreen)](https://python.org)
 
  ---
 
  ## Motivation
 
- Smart contract auditing is a $500M+ industry where human auditors painstakingly review Solidity code for security flaws. This environment lets agents practice exactly that workflow — exploring contract code through targeted queries and submitting findings — providing a challenging, real-world benchmark for reasoning and code-understanding agents.
 
- Data is sourced from **Certora-audited DeFi projects**, giving agents contracts with the same vulnerability patterns found in production exploits (reentrancy, integer overflow, access control bypasses, etc.).
 
  ---
 
- ## Environment Description
 
- The environment hosts **3 tasks** of increasing difficulty:
 
- | Task | Name | Difficulty | Status |
- |------|------|------------|--------|
- | 1 | Targeted Vulnerability Detection | Medium | ✅ Active |
- | 2 | Property Discovery | Hard | ⏳ Placeholder |
- | 3 | Rule Checker | Easy | ⏳ Placeholder |
 
- ### Task 1 — Targeted Vulnerability Detection *(Medium)*
 
- **Setup:** The agent is shown a Solidity contract (4–6 functions). One function contains a critical vulnerability.
 
- **Objective:** Identify the vulnerable function and describe the vulnerability type in 2–3 words.
 
- **Episode lifecycle:**
- 1. `reset()` — randomly selects one of 8 vulnerable (contract, function) pairs from the dataset
- 2. Agent receives the contract name and description
- 3. Agent explores using the action API (each action has a small cost)
- 4. Agent calls `submit(function_name, vulnerability_type)` to end the episode
- 5. Grader assigns 0.0–1.0 score
 
- **Vulnerability types in the dataset:**
- - Reentrancy
- - Missing access control
- - Integer overflow (Solidity <0.8)
- - tx.origin authentication
- - Front-running
- - Timestamp dependence
- - Denial of service (unbounded loop)
- - Unchecked ERC-20 return value
 
- ---
 
- ### Task 2 — Property Discovery *(Hard)* [Placeholder]
 
- Given a single Solidity function, the agent must discover its natural-language correctness property. Grading uses semantic similarity to the ground-truth property. *Implementation coming soon.*
 
  ---
 
- ### Task 3 — Rule Checker *(Easy)* [Placeholder]
 
- Given a natural-language property and a contract, the agent must identify which function violates that property. *Implementation coming soon.*
 
- ---
 
- ## Action Space
 
- All actions are described below. **Repeated identical queries cost −0.40.**
 
- | Action | Key Params | Reward |
- |--------|-----------|--------|
- | `list_functions` | — | −0.05 |
- | `get_function_code` | `function_name` | +0.05 (target) / −0.10 (other) |
- | `get_function_summary` | `function_name` | +0.03 (target) / −0.05 (other) |
- | `get_file_metadata` | — | −0.04 |
- | `get_state_variable` | `variable_name` (opt.) | −0.05 |
- | `get_call_graph` | — | −0.08 |
- | `submit` | `function_name`, `vulnerability_type` | +5.0 / +1.0 / −1.5 |
 
- **Submit scoring:**
- - **+5.0** — correct function AND correct vulnerability keyword → grader score = 1.0
- - **+1.0** — correct function, unrecognised vulnerability type → grader score = 0.5
- - **−1.5** — wrong function → grader score = 0.0
 
  ---
 
  ## Observation Space
 
- Every `step()` and `reset()` returns an `Observation` object:
 
  ```json
  {
-   "task_id": "task1_vuln_detection",
-   "contract_name": "SimpleVault",
-   "contract_description": "An ETH vault that allows users to deposit and withdraw...",
-   "available_actions": ["list_functions", "get_function_code", ...],
-   "last_action": "get_function_code",
-   "last_action_result": "// withdraw\nfunction withdraw(uint256 amount) ...",
-   "step_count": 3,
-   "cumulative_reward": -0.05,
    "done": false,
    "extra": {
-     "solidity_version": "0.8.0",
-     "hint": "Identify the vulnerable function and its issue."
    }
  }
  ```
@@ -114,63 +126,88 @@ Every `step()` and `reset()` returns an `Observation` object:
  ```
  smart-contract-env/
  ├── data/
- │   ├── contracts.json     # 4 contracts, 8 vulnerabilities
- │   └── data_loader.py     # JSON parsing and episode sampling
  ├── env/
  │   ├── base_env.py        # Abstract OpenEnv base class
- │   └── schemas.py         # Pydantic models (Observation, Action, Reward…)
  ├── tasks/
  │   ├── task1/
  │   │   ├── environment.py # Full Task 1 RL environment
- │   │   └── grader.py      # Deterministic 0.0–1.0 grader
- │   ├── task2/             # TODO: Property Discovery
- │   └── task3/             # TODO: Rule Checker
- ├── utils/
- ├── app.py                 # FastAPI server (OpenEnv HTTP interface)
- ├── inference.py           # Baseline inference script (OpenAI client)
- ├── openenv.yaml           # OpenEnv spec metadata
- ├── Dockerfile
- ├── requirements.txt
- └── README.md
  ```
 
  ---
 
  ## Setup & Usage
 
- ### Option A — Run locally
 
  ```bash
- # 1. Clone and install
- git clone <repo>
- cd smart-contract-env
  pip install -r requirements.txt
 
- # 2. Start the server
- python app.py
- # → http://localhost:7860
  ```
 
- ### Option B — Docker
 
  ```bash
  docker build -t sc-audit-env .
  docker run -p 7860:7860 sc-audit-env
  ```
 
- ### Option C — Python (no server)
 
  ```python
  from tasks.task1.environment import Task1Environment
  from env.schemas import Action, ActionType
 
  env = Task1Environment()
- result = env.reset(seed=42)
- print(result.observation.contract_name)
-
- action = Action(action_type=ActionType.LIST_FUNCTIONS)
- step = env.step(action)
- print(step.observation.last_action_result)
  ```
 
  ---
@@ -180,35 +217,28 @@ print(step.observation.last_action_result)
  | Method | Endpoint | Description |
  |--------|----------|-------------|
  | `GET` | `/health` | Liveness probe |
- | `GET` | `/tasks` | List all tasks |
- | `POST` | `/reset` | Start new episode |
- | `POST` | `/step` | Take one action |
- | `GET` | `/state` | Debug: internal state |
- | `GET` | `/action_space` | Action space definition |
- | `GET` | `/observation_space` | Observation space definition |
-
- **Example session:**
 
  ```bash
- # Reset
- curl -X POST http://localhost:7860/reset \
-      -H "Content-Type: application/json" \
-      -d '{"task_id": "task1_vuln_detection", "seed": 42}'
-
- # List functions
- curl -X POST "http://localhost:7860/step" \
-      -H "Content-Type: application/json" \
-      -d '{"action_type": "list_functions", "params": {}}'
-
- # Submit answer
- curl -X POST "http://localhost:7860/step" \
-      -H "Content-Type: application/json" \
-      -d '{"action_type": "submit", "params": {"function_name": "withdraw", "vulnerability_type": "reentrancy"}}'
  ```
 
  ---
 
- ## Running the Baseline
 
  ```bash
  export API_BASE_URL="https://api.openai.com/v1"
@@ -216,86 +246,64 @@ export MODEL_NAME="gpt-4o-mini"
  export HF_TOKEN="sk-..."
 
  python inference.py
  ```
 
- Outputs results to stdout and writes `baseline_scores.json`.
-
- **Expected baseline scores (gpt-4o-mini, 3 episodes):**
 
  | Task | Avg Grader Score | Notes |
  |------|-----------------|-------|
- | Task 1 | ~0.67 | Medium difficulty; model identifies common vulns well |
- | Task 2 | 0.00 | Placeholder |
  | Task 3 | 0.00 | Placeholder |
 
  ---
 
- ## Baseline Scores
-
- ```json
- {
-   "model": "gpt-4o-mini",
-   "tasks": [
-     {
-       "task_id": "task1_vuln_detection",
-       "avg_grader_score": 0.667,
-       "avg_cumulative_reward": 2.14
-     },
-     { "task_id": "task2_property_discovery", "avg_grader_score": 0.0 },
-     { "task_id": "task3_rule_checker", "avg_grader_score": 0.0 }
-   ],
-   "overall_avg_score": 0.667
- }
- ```
-
- ---
-
- ## Grader Details
 
- The Task 1 grader is **fully deterministic**:
 
- 1. **Function name check** — case-insensitive exact match against the ground-truth vulnerable function. Wrong function → score = 0.0 immediately.
 
- 2. **Vulnerability type check** — checks whether the submitted string contains any accepted keyword from a predefined keyword table (e.g. `"reentrancy"` table includes: `reentrancy`, `re-entrancy`, `reentrant`, `recursive call`). Match → 1.0; no match → 0.5.
-
- Scores map to terminal rewards: 1.0 → +5, 0.5 → +1, 0.0 → −1.5.
-
- ---
-
- ## OpenEnv Spec Compliance
-
- - ✅ Typed `Observation`, `Action`, `Reward` Pydantic models
- - ✅ `step(action) → StepResult(observation, reward, done, info)`
- - ✅ `reset() → ResetResult(observation, info)`
- - ✅ `state() → StateResult`
- - ✅ `openenv.yaml` metadata
- - ✅ 3 tasks defined (1 active, 2 placeholders)
- - ✅ Grader scores in [0.0, 1.0]
- - ✅ Shaped rewards (not just binary)
- - ✅ Dockerfile + HF Space deployment
- - ✅ Baseline `inference.py` using OpenAI client
 
  ---
 
  ## Deploying to Hugging Face Spaces
 
- 1. Create a new **Docker** Space on [huggingface.co/spaces](https://huggingface.co/spaces)
- 2. Set the tag `openenv` in the Space metadata
- 3. Push this repository:
 
  ```bash
- git remote add hf https://huggingface.co/spaces/<your-username>/<space-name>
  git push hf main
  ```
 
- The Space will build the Docker image and serve the FastAPI app on port 7860.
-
  ---
 
- ## License
 
- MIT — see `LICENSE`.
 
- ## Data Attribution
 
- Contract vulnerability patterns inspired by and adapted from **Certora** audit findings on production DeFi protocols.
  # Smart Contract Audit RL Environment
 
  > **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
+ > Train and evaluate agents on real-world Solidity audit tasks — the same work professional auditors do every day.
 
+ [![OpenEnv Spec](https://img.shields.io/badge/OpenEnv-1.1-blue)](openenv.yaml)
  [![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-brightgreen)](https://python.org)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow)](LICENSE)
 
  ---
 
  ## Motivation
 
+ Smart contract auditing is a $500M+ industry where human auditors painstakingly review Solidity code for security flaws and formally specify function properties. This environment lets agents practice exactly that workflow — exploring contract code through targeted queries and submitting findings — providing a rigorous, real-world benchmark for code-reasoning agents.
 
+ Data is sourced from **Certora-audited DeFi projects**, giving agents contracts with the same vulnerability patterns found in production exploits.
 
  ---
 
+ ## Tasks
 
+ | # | Name | Difficulty | Status | Description |
+ |---|------|------------|--------|-------------|
+ | 1 | Targeted Vulnerability Detection | Medium | ✅ Active | Find the vulnerable function and name the vulnerability type |
+ | 2 | Property Discovery | Hard | ✅ Active | Write the natural-language postcondition for a given function |
+ | 3 | Rule Checker | Easy | ⏳ Placeholder | Identify which function violates a given property |
 
+ ---
 
+ ## Task 1 — Targeted Vulnerability Detection *(Medium)*
 
+ **Setup:** The agent is shown a Solidity contract (4–6 functions). One function contains a critical vulnerability.
 
+ **Objective:** Identify the vulnerable function and describe its vulnerability type in 2–3 words.
 
+ ### Actions
 
+ | Action | Params | Reward |
+ |--------|--------|--------|
+ | `list_functions` | — | −0.05 |
+ | `get_function_code` | `function_name` | +0.05 (target) / −0.10 (other) |
+ | `get_function_summary` | `function_name` | +0.03 (target) / −0.05 (other) |
+ | `get_file_metadata` | — | −0.04 |
+ | `get_state_variable` | `variable_name` (opt.) | −0.05 |
+ | `get_call_graph` | — | −0.08 |
+ | `submit` | `function_name`, `vulnerability_type` | **+5.0** / +1.0 / −1.5 |
 
+ Repeated identical queries: **−0.40**
 
+ ### Submit scoring (deterministic)
+ - **1.0** → correct function **+** correct vulnerability keyword → reward +5.0
+ - **0.5** → correct function, wrong/vague vulnerability type → reward +1.0
+ - **0.0** → wrong function → reward −1.5
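This rubric maps directly to code. A minimal sketch of the deterministic grader (the keyword table is abbreviated and partly assumed — only the `reentrancy` synonyms are documented; the real implementation lives in `tasks/task1/grader.py`):

```python
# Hypothetical sketch of the deterministic Task 1 rubric.
# Only the "reentrancy" keyword row is documented; the others are assumed.
KEYWORDS = {
    "reentrancy": ["reentrancy", "re-entrancy", "reentrant", "recursive call"],
    "missing access control": ["access control", "unprotected", "missing modifier"],
}

def grade(submitted_fn, submitted_vuln, truth_fn, truth_vuln):
    """Return (grader_score, terminal_reward) per the 0 / 0.5 / 1.0 rubric."""
    if submitted_fn.lower() != truth_fn.lower():
        return 0.0, -1.5          # wrong function: fail immediately
    text = submitted_vuln.lower()
    if any(kw in text for kw in KEYWORDS.get(truth_vuln, [])):
        return 1.0, 5.0           # right function + recognised keyword
    return 0.5, 1.0               # right function, unrecognised vulnerability type
```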
 
+ ### Vulnerability types in dataset
+ Reentrancy · Missing access control · Integer overflow · tx.origin authentication ·
+ Front-running · Timestamp dependence · Denial of service (unbounded loop) · Unchecked return value
 
  ---
 
+ ## Task 2 — Property Discovery *(Hard)*
 
+ **Setup:** The agent is shown a single Solidity function and must write its natural-language correctness property (postcondition / invariant).
 
+ **Objective:** Write a precise 2–4 sentence property describing what the function guarantees when it succeeds.
 
+ ### Actions
 
+ | Action | Params | Reward |
+ |--------|--------|--------|
+ | `get_function_code` | — | −0.06 |
+ | `get_function_natspec` | — | −0.08 |
+ | `get_file_natspec` | — | −0.03 |
+ | `get_related_functions` | — | −0.06 |
+ | `get_io` | — | −0.04 |
+ | `get_similar_rule` | — | −0.20 |
+ | `submit_property` | `property` (string) | **0.0–5.0** (scored, ONE attempt) |
 
+ Repeated identical queries: **−0.40**
+
+ ### Submit scoring (keyword-weighted)
+ ```
+ score  = 0.70 × (key_phrases_matched / total_key_phrases)
+        + 0.30 × (bonus_phrases_matched / total_bonus_phrases)
+
+ reward = score × 5.0   → range: 0.0 – 5.0
+ ```
 
+ Matching uses **word-set containment** with synonym expansion (e.g. "caller" matches "msg.sender", "sender", "user"). Phrases don't need to be adjacent — all constituent words just need to appear somewhere in the submitted text.
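The formula and matching rule above can be sketched directly. A toy version (the synonym table is assumed — the real grader lives in `tasks/task2/grader.py`):

```python
# Hypothetical sketch of the keyword-weighted Task 2 scorer.
SYNONYMS = {"caller": {"caller", "msg.sender", "sender", "user"}}  # assumed table

def expand(word):
    return SYNONYMS.get(word, {word})

def phrase_matched(phrase, submitted_words):
    # Word-set containment: every word of the phrase (or one of its synonyms)
    # must appear somewhere in the submission; adjacency is not required.
    return all(expand(w) & submitted_words for w in phrase.lower().split())

def score_property(text, key_phrases, bonus_phrases):
    words = set(text.lower().split())
    key = sum(phrase_matched(p, words) for p in key_phrases) / len(key_phrases)
    bonus = (sum(phrase_matched(p, words) for p in bonus_phrases) / len(bonus_phrases)
             if bonus_phrases else 0.0)
    score = 0.70 * key + 0.30 * bonus
    return score, score * 5.0  # (grader score, terminal reward)
```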
+
+ **One submission per episode** — choose carefully.
+
+ ### Property coverage
+ 11 functions across 4 contracts with ground-truth properties: SimpleVault (deposit, withdraw, emergencyDrain), TokenSale (buyTokens, setPrice, withdrawETH), DutchAuction (getPrice, bid, finalize), YieldFarm (stake, claimRewards).
 
  ---
 
  ## Observation Space
 
+ Every `step()` and `reset()` returns the same `Observation` structure:
 
  ```json
  {
+   "task_id": "task2_property_discovery",
+   "contract_name": "YieldFarm",
+   "contract_description": "A simple yield farming contract...",
+   "available_actions": ["get_function_code", "get_function_natspec", ...],
+   "last_action": "get_function_natspec",
+   "last_action_result": "NatSpec for 'claimRewards':\n@notice Claim all accrued...",
+   "step_count": 2,
+   "cumulative_reward": -0.14,
    "done": false,
    "extra": {
+     "target_function": "claimRewards",
+     "target_signature": "claimRewards()",
+     "solidity_version": "0.8.10",
+     "hint": "Discover the property of the target function..."
    }
  }
  ```
 
  ```
  smart-contract-env/
  ├── data/
+ │   ├── contracts.json      # 4 contracts · 8 vulnerabilities · 11 properties
+ │   └── data_loader.py      # JSON parser, episode samplers, T1 + T2 helpers
  ├── env/
  │   ├── base_env.py         # Abstract OpenEnv base class
+ │   └── schemas.py          # Pydantic: Observation, Action, Reward, StepResult…
  ├── tasks/
  │   ├── task1/
  │   │   ├── environment.py  # Full Task 1 RL environment
+ │   │   └── grader.py       # Deterministic 0/0.5/1.0 rubric + longest-match keywords
+ │   ├── task2/
+ │   │   ├── environment.py  # Full Task 2 RL environment (one submit per episode)
+ │   │   └── grader.py       # Keyword-weighted 0.0–1.0 grader + synonym expansion
+ │   └── task3/              # TODO: Rule Checker (placeholder)
+ ├── app.py                  # FastAPI server — all OpenEnv HTTP endpoints
+ ├── inference.py            # Baseline LLM agent (Task 1 + Task 2)
+ ├── eval.py                 # Oracle/partial/random evaluation harness
+ ├── demo.py                 # Colourised interactive + scripted demo
+ ├── validate.py             # 19-check pre-submission validator
+ ├── openenv.yaml            # Full OpenEnv spec metadata
+ ├── Dockerfile              # Port 7860, uvicorn, healthcheck
+ └── requirements.txt
  ```
 
  ---
 
  ## Setup & Usage
 
+ ### Local Python
 
  ```bash
+ git clone <repo> && cd smart-contract-env
  pip install -r requirements.txt
 
+ # Run the server
+ python app.py                    # → http://localhost:7860
+
+ # Run interactive demo
+ python demo.py                   # Task 1 interactive
+ python demo.py --auto            # Task 1 scripted
+ python demo.py --auto --task 2   # Task 2 scripted
+
+ # Run evaluation harness (no LLM needed)
+ python eval.py                   # Both tasks, 8 episodes each
+ python eval.py --task 2          # Task 2 only
+ python eval.py --episodes 16 --verbose
+
+ # Pre-submission validation
+ python validate.py               # 19/19 checks
  ```
 
+ ### Docker
 
  ```bash
  docker build -t sc-audit-env .
  docker run -p 7860:7860 sc-audit-env
  ```
 
+ ### Direct Python API
 
  ```python
  from tasks.task1.environment import Task1Environment
+ from tasks.task2.environment import Task2Environment
  from env.schemas import Action, ActionType
 
+ # Task 1
  env = Task1Environment()
+ r = env.reset(seed=42)
+ print(r.observation.contract_name)               # SimpleVault
+ s = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+ s = env.step(Action(action_type=ActionType.SUBMIT,
+                     params={"function_name": "emergencyDrain",
+                             "vulnerability_type": "missing access control"}))
+ print(s.reward.value)                            # +5.0
+
+ # Task 2
+ env2 = Task2Environment()
+ r2 = env2.reset(seed=42)
+ print(r2.observation.extra["target_function"])   # claimRewards
+ s2 = env2.step(Action(action_type=ActionType.GET_FUNCTION_NATSPEC))
+ s2 = env2.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
+                       params={"property": "After a successful claimRewards call, all accrued reward tokens are transferred to the caller and their rewards balance is zeroed. Reverts if no rewards."}))
+ print(s2.reward.value)                           # ~4.0
  ```
 
  ---
 
  | Method | Endpoint | Description |
  |--------|----------|-------------|
  | `GET` | `/health` | Liveness probe |
+ | `GET` | `/tasks` | All tasks + status |
+ | `POST` | `/reset` | Start episode (`task_id`, `seed`) |
+ | `POST` | `/step` | Take action (`action_type`, `params`) |
+ | `GET` | `/state` | Debug: internal episode state |
+ | `GET` | `/action_space?task_id=...` | Action schema for a task |
+ | `GET` | `/observation_space` | Observation schema |
 
  ```bash
+ # Task 2 full episode
+ curl -X POST localhost:7860/reset \
+      -d '{"task_id":"task2_property_discovery","seed":42}'
+
+ curl -X POST localhost:7860/step \
+      -d '{"action_type":"get_function_natspec","params":{}}'
+
+ curl -X POST localhost:7860/step \
+      -d '{"action_type":"submit_property","params":{"property":"..."}}'
  ```
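The same episode can be driven from Python instead of curl. A minimal sketch against the endpoints above (the `post` callable is injected so the flow can be exercised without a live server; with a running server, wire it to `requests.post` as shown):

```python
# Hypothetical client sketch for the HTTP API above.
from typing import Callable

BASE = "http://localhost:7860"  # assumes the server from `python app.py` is running

def run_episode(post: Callable[[str, dict], dict]) -> dict:
    """Drive one Task 2 episode: reset, one probe, one submission."""
    post("/reset", {"task_id": "task2_property_discovery", "seed": 42})
    post("/step", {"action_type": "get_function_natspec", "params": {}})
    return post("/step", {"action_type": "submit_property",
                          "params": {"property": "..."}})

def http_post(path: str, body: dict) -> dict:
    import requests  # pip install requests
    return requests.post(BASE + path, json=body).json()

# run_episode(http_post)  # uncomment with the server running
```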
 
  ---
 
+ ## Baseline Inference
 
  ```bash
  export API_BASE_URL="https://api.openai.com/v1"
  export HF_TOKEN="sk-..."
 
  python inference.py
+ # → baseline_scores.json
  ```
 
+ ### Expected baseline scores (gpt-4o-mini, 3 episodes per task)
 
  | Task | Avg Grader Score | Notes |
  |------|-----------------|-------|
+ | Task 1 | ~0.67 | Good at common vulns; misses subtle ones |
+ | Task 2 | ~0.55 | Reasonable properties, but often misses specific variable names |
  | Task 3 | 0.00 | Placeholder |
 
  ---
 
+ ## Evaluation Scores
 
+ Deterministic oracle / partial / baseline tiers verified on 8 episodes (seeds 42–49):
 
+ | Task | Oracle | Partial | Floor |
+ |------|--------|---------|-------|
+ | Task 1 | **1.000** | 0.500 | 0.000 |
+ | Task 2 | **0.775** | 0.034 | 0.000 |
 
+ The clear separation confirms the grader provides a **meaningful gradient signal** for RL training.
 
  ---
 
  ## Deploying to Hugging Face Spaces
 
+ 1. Create a new **Docker** Space at [huggingface.co/spaces](https://huggingface.co/spaces)
+ 2. Add the tag `openenv` in the Space settings
+ 3. Copy the `SPACES_README.md` frontmatter into `README.md`
+ 4. Push:
 
  ```bash
+ git remote add hf https://huggingface.co/spaces/<user>/<space>
  git push hf main
  ```
 
  ---
 
+ ## OpenEnv Spec Compliance
 
+ | Requirement | Status |
+ |-------------|--------|
+ | Typed `Observation`, `Action`, `Reward` Pydantic models | ✅ |
+ | `step(action) → StepResult(obs, reward, done, info)` | ✅ |
+ | `reset() → ResetResult` | ✅ |
+ | `state() → StateResult` | ✅ |
+ | `openenv.yaml` metadata | ✅ |
+ | 3+ tasks defined | ✅ (2 active, 1 placeholder) |
+ | Grader scores in [0.0, 1.0] | ✅ |
+ | Shaped rewards (non-binary) | ✅ |
+ | Dockerfile + port 7860 | ✅ |
+ | `inference.py` with OpenAI client | ✅ |
+ | `validate.py` — all 19 checks pass | ✅ |
 
+ ---
+
+ ## License
 
+ MIT. Contract vulnerability data adapted from Certora audits on production DeFi protocols.
app.py CHANGED
@@ -4,31 +4,30 @@ app.py
  FastAPI server exposing the OpenEnv HTTP interface.
 
  Endpoints:
-     POST /reset              – start a new episode
-     POST /step               – take one action
-     GET  /state              – inspect internal state (debugging)
-     GET  /tasks              – list available tasks
-     GET  /health             – liveness probe
-     GET  /action_space       – action space description
-     GET  /observation_space  – observation space description
-
- Sessions are keyed by a UUID passed as the `session_id` query parameter.
- If omitted, a default single-session is used (fine for sequential runs).
  """
 
- import uuid
  from typing import Dict, Optional
 
  from fastapi import FastAPI, HTTPException, Query
- from fastapi.responses import JSONResponse
  from pydantic import BaseModel
 
  from env.schemas import Action, ActionType, TaskInfo
  from tasks.task1.environment import Task1Environment
 
- # ---------------------------------------------------------------------------
- # App init
- # ---------------------------------------------------------------------------
 
  app = FastAPI(
      title="Smart Contract Audit RL Environment",
@@ -36,38 +35,36 @@ app = FastAPI(
          "OpenEnv-compliant reinforcement learning environment for smart contract "
          "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
      ),
-     version="1.0.0",
  )
 
- # ---------------------------------------------------------------------------
  # Session management
- # ---------------------------------------------------------------------------
 
- _sessions: Dict[str, Task1Environment] = {}
  DEFAULT_SESSION = "default"
 
- def _get_or_create_session(session_id: str, task_id: str = "task1_vuln_detection") -> Task1Environment:
-     if session_id not in _sessions:
-         env = _create_env(task_id)
-         _sessions[session_id] = env
-     return _sessions[session_id]
-
- def _create_env(task_id: str) -> Task1Environment:
-     if task_id == "task1_vuln_detection":
-         return Task1Environment()
-     # TODO: elif task_id == "task2_property_discovery": return Task2Environment()
-     # TODO: elif task_id == "task3_rule_checker": return Task3Environment()
-     raise HTTPException(
-         status_code=400,
-         detail=f"Unknown task_id '{task_id}'. Available: ['task1_vuln_detection']",
-     )
 
- # ---------------------------------------------------------------------------
- # Request/response models
- # ---------------------------------------------------------------------------
 
  class ResetRequest(BaseModel):
      task_id: str = "task1_vuln_detection"
@@ -79,48 +76,39 @@ class StepRequest(BaseModel):
      params: dict = {}
 
- # ---------------------------------------------------------------------------
  # Routes
- # ---------------------------------------------------------------------------
 
  @app.get("/health")
  def health():
-     """Liveness probe — returns 200 OK."""
-     return {"status": "ok", "version": "1.0.0"}
 
  @app.get("/tasks")
  def list_tasks():
-     """List all available tasks."""
      tasks = [
          TaskInfo(
              task_id="task1_vuln_detection",
              name="Targeted Vulnerability Detection",
              difficulty="medium",
-             description=(
-                 "Given a Solidity contract, identify the vulnerable function "
-                 "and describe the vulnerability type in 2-3 words."
-             ),
              status="active",
          ),
          TaskInfo(
              task_id="task2_property_discovery",
              name="Property Discovery",
              difficulty="hard",
-             description=(
-                 "Given a Solidity function, discover the natural-language property "
-                 "that describes its correct behaviour."
-             ),
-             status="placeholder",
          ),
          TaskInfo(
              task_id="task3_rule_checker",
              name="Rule Checker",
              difficulty="easy",
-             description=(
-                 "Given a property in English, identify which function in the contract "
-                 "violates that property."
-             ),
              status="placeholder",
          ),
      ]
@@ -144,7 +132,7 @@ def step(
      body: StepRequest,
      session_id: str = Query(default=DEFAULT_SESSION),
  ):
-     """Apply an action and advance the episode."""
      env = _sessions.get(session_id)
      if env is None:
          raise HTTPException(
@@ -156,8 +144,7 @@ def step(
      except ValueError:
          raise HTTPException(
              status_code=400,
-             detail=f"Unknown action_type '{body.action_type}'. "
-                    f"Valid: {[a.value for a in ActionType]}",
          )
      action = Action(action_type=action_type, params=body.params)
      try:
@@ -169,7 +156,7 @@ def step(
 
  @app.get("/state")
  def state(session_id: str = Query(default=DEFAULT_SESSION)):
-     """Return current internal state (for debugging; not for agents)."""
      env = _sessions.get(session_id)
      if env is None:
          raise HTTPException(
@@ -186,51 +173,26 @@ def action_space(task_id: str = "task1_vuln_detection"):
      return {
          "task_id": task_id,
          "actions": [
-             {
-                 "type": "list_functions",
-                 "params": {},
-                 "reward": -0.05,
-                 "description": "List all function names in the contract",
-             },
-             {
-                 "type": "get_function_code",
-                 "params": {"function_name": "string"},
-                 "reward": "+0.05 (target fn) / -0.10 (wrong fn)",
-                 "description": "Retrieve the full Solidity code of a function",
-             },
-             {
-                 "type": "get_function_summary",
-                 "params": {"function_name": "string"},
-                 "reward": "+0.03 (target fn) / -0.05 (wrong fn)",
-                 "description": "Retrieve the NatSpec comment/summary of a function",
-             },
-             {
-                 "type": "get_file_metadata",
-                 "params": {},
-                 "reward": -0.04,
-                 "description": "Retrieve contract-level metadata (version, author, description)",
-             },
-             {
-                 "type": "get_state_variable",
-                 "params": {"variable_name": "string (optional)"},
-                 "reward": -0.05,
-                 "description": "Retrieve a state variable or list all variables",
-             },
-             {
-                 "type": "get_call_graph",
-                 "params": {},
-                 "reward": -0.08,
-                 "description": "Retrieve the function call graph",
-             },
-             {
-                 "type": "submit",
-                 "params": {
-                     "function_name": "string",
-                     "vulnerability_type": "string",
-                 },
-                 "reward": "+5.0 (correct) / +1.0 (right fn, wrong vuln) / -1.5 (wrong)",
-                 "description": "Submit your final answer. Ends the episode.",
-             },
          ],
      }
      return {"error": f"No action space defined for task '{task_id}'"}
@@ -238,27 +200,27 @@ def action_space(task_id: str = "task1_vuln_detection"):
 
  @app.get("/observation_space")
  def observation_space():
-     """Describe the observation space."""
      return {
          "type": "object",
          "fields": {
-             "task_id": "string – active task identifier",
-             "contract_name": "string – name of the Solidity contract",
              "contract_description": "string – what the contract does",
-             "available_actions": "list[string] – valid action types",
-             "last_action": "string|null – the previous action type",
-             "last_action_result": "string|null – human-readable result of last action",
-             "step_count": "int – steps taken so far",
-             "cumulative_reward": "float – running reward total",
-             "done": "bool – True when episode is over",
-             "extra": "object – task-specific hints and metadata",
          },
      }
 
- # ---------------------------------------------------------------------------
  # Entry point
- # ---------------------------------------------------------------------------
 
  if __name__ == "__main__":
      import uvicorn
 
 FastAPI server exposing the OpenEnv HTTP interface.

 Endpoints:
+    POST /reset             – start a new episode
+    POST /step              – take one action
+    GET  /state             – inspect internal state (debugging)
+    GET  /tasks             – list available tasks
+    GET  /health            – liveness probe
+    GET  /action_space      – action space description for a task
+    GET  /observation_space – observation space description
+
+Sessions are keyed by a UUID in the `session_id` query parameter.
+If omitted, "default" is used (fine for sequential single-agent runs).
 """

 from typing import Dict, Optional
 from fastapi import FastAPI, HTTPException, Query
 from pydantic import BaseModel

 from env.schemas import Action, ActionType, TaskInfo
 from tasks.task1.environment import Task1Environment
+from tasks.task2.environment import Task2Environment

+# ─────────────────────────────────────────────────────────────────────────────
+# App
+# ─────────────────────────────────────────────────────────────────────────────

 app = FastAPI(
     title="Smart Contract Audit RL Environment",
     description=(
         "OpenEnv-compliant reinforcement learning environment for smart contract "
         "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
     ),
+    version="1.1.0",
 )

+# ─────────────────────────────────────────────────────────────────────────────
 # Session management
+# ─────────────────────────────────────────────────────────────────────────────

+_sessions: Dict[str, object] = {}
 DEFAULT_SESSION = "default"

+TASK_ENV_MAP = {
+    "task1_vuln_detection": Task1Environment,
+    "task2_property_discovery": Task2Environment,
+    # TODO: "task3_rule_checker": Task3Environment,
+}


+def _create_env(task_id: str):
+    cls = TASK_ENV_MAP.get(task_id)
+    if cls is None:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unknown task_id '{task_id}'. Available: {list(TASK_ENV_MAP)}",
+        )
+    return cls()


+# ─────────────────────────────────────────────────────────────────────────────
+# Request bodies
+# ─────────────────────────────────────────────────────────────────────────────

 class ResetRequest(BaseModel):
     task_id: str = "task1_vuln_detection"

     params: dict = {}


+# ─────────────────────────────────────────────────────────────────────────────
 # Routes
+# ─────────────────────────────────────────────────────────────────────────────

 @app.get("/health")
 def health():
+    """Liveness probe."""
+    return {"status": "ok", "version": "1.1.0"}


 @app.get("/tasks")
 def list_tasks():
+    """List all tasks with their status."""
     tasks = [
         TaskInfo(
             task_id="task1_vuln_detection",
             name="Targeted Vulnerability Detection",
             difficulty="medium",
+            description="Given a Solidity contract, identify the vulnerable function and describe the vulnerability type in 2-3 words.",
             status="active",
         ),
         TaskInfo(
             task_id="task2_property_discovery",
             name="Property Discovery",
             difficulty="hard",
+            description="Given a Solidity function, write the natural-language property that describes its correct behaviour.",
+            status="active",
         ),
         TaskInfo(
             task_id="task3_rule_checker",
             name="Rule Checker",
             difficulty="easy",
+            description="Given a property in English and a Solidity contract, identify which function violates that property.",
             status="placeholder",
         ),
     ]

     body: StepRequest,
     session_id: str = Query(default=DEFAULT_SESSION),
 ):
+    """Apply one action and advance the episode."""
     env = _sessions.get(session_id)
     if env is None:
         raise HTTPException(

     except ValueError:
         raise HTTPException(
             status_code=400,
+            detail=f"Unknown action_type '{body.action_type}'. Valid: {[a.value for a in ActionType]}",
         )
     action = Action(action_type=action_type, params=body.params)
     try:


 @app.get("/state")
 def state(session_id: str = Query(default=DEFAULT_SESSION)):
+    """Return internal state for debugging (not for agents)."""
     env = _sessions.get(session_id)
     if env is None:
         raise HTTPException(

     return {
         "task_id": task_id,
         "actions": [
+            {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
+            {"type": "get_function_code", "params": {"function_name": "string"}, "reward": "+0.05 (target) / -0.10 (other)", "description": "Get full Solidity source of a function"},
+            {"type": "get_function_summary", "params": {"function_name": "string"}, "reward": "+0.03 (target) / -0.05 (other)", "description": "Get NatSpec comment of a function"},
+            {"type": "get_file_metadata", "params": {}, "reward": -0.04, "description": "Get contract-level metadata"},
+            {"type": "get_state_variable", "params": {"variable_name": "string (optional)"}, "reward": -0.05, "description": "Get a state variable or list all"},
+            {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
+            {"type": "submit", "params": {"function_name": "str", "vulnerability_type": "str"}, "reward": "+5.0 / +1.0 / -1.5", "description": "Submit answer. Ends episode."},
+        ],
+    }
+    if task_id == "task2_property_discovery":
+        return {
+            "task_id": task_id,
+            "actions": [
+                {"type": "get_function_code", "params": {}, "reward": -0.06, "description": "Read full source of the target function"},
+                {"type": "get_function_natspec", "params": {}, "reward": -0.08, "description": "Read NatSpec + expected behaviour"},
+                {"type": "get_file_natspec", "params": {}, "reward": -0.03, "description": "Read contract-level NatSpec"},
+                {"type": "get_related_functions", "params": {}, "reward": -0.06, "description": "List caller/callee functions with summaries"},
+                {"type": "get_io", "params": {}, "reward": -0.04, "description": "Get structured I/O + expected behaviour"},
+                {"type": "get_similar_rule", "params": {}, "reward": -0.20, "description": "Get a similar property from another contract"},
+                {"type": "submit_property", "params": {"property": "string"}, "reward": "0.0–5.0 (scored)", "description": "Submit property. ONE attempt. Ends episode."},
+            ],
+        }
     return {"error": f"No action space defined for task '{task_id}'"}


 @app.get("/observation_space")
 def observation_space():
+    """Describe the observation space (same for all tasks)."""
     return {
         "type": "object",
         "fields": {
+            "task_id": "string – active task identifier",
+            "contract_name": "string – Solidity contract name",
             "contract_description": "string – what the contract does",
+            "available_actions": "list[string] – valid action types for this task",
+            "last_action": "string|null – previous action type",
+            "last_action_result": "string|null – human-readable result of last action",
+            "step_count": "int – steps taken in this episode",
+            "cumulative_reward": "float – running reward total",
+            "done": "bool – True when episode is over",
+            "extra": "object – task-specific hints (target_function, hint, etc.)",
         },
     }


+# ─────────────────────────────────────────────────────────────────────────────
 # Entry point
+# ─────────────────────────────────────────────────────────────────────────────

 if __name__ == "__main__":
     import uvicorn
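The `TASK_ENV_MAP` registry added in this commit is a plain name-to-class dispatch. A minimal standalone sketch of the same pattern, using toy classes rather than the real `Task1Environment`/`Task2Environment` (and a `ValueError` in place of FastAPI's `HTTPException`):

```python
# Registry-based environment dispatch, mirroring TASK_ENV_MAP / _create_env.

class ToyTask1:
    task_id = "task1_vuln_detection"

class ToyTask2:
    task_id = "task2_property_discovery"

TASK_ENV_MAP = {
    "task1_vuln_detection": ToyTask1,
    "task2_property_discovery": ToyTask2,
}

def create_env(task_id: str):
    cls = TASK_ENV_MAP.get(task_id)
    if cls is None:
        # The server raises HTTPException(400) here; a plain error suffices offline.
        raise ValueError(f"Unknown task_id '{task_id}'. Available: {list(TASK_ENV_MAP)}")
    return cls()

env = create_env("task2_property_discovery")
print(env.task_id)  # → task2_property_discovery
```

Adding Task 3 later then only requires one new entry in the dict, which is why the commit leaves a `# TODO: "task3_rule_checker"` placeholder.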
data/data_loader.py CHANGED
@@ -2,8 +2,9 @@
 data_loader.py
 --------------
 Loads and indexes smart contract data from JSON files.
-Each contract is parsed into a structured dict; vulnerable functions
-are indexed for fast lookup by Task 1.
 """

 import json
@@ -16,25 +17,62 @@ DATA_DIR = os.path.join(os.path.dirname(__file__))
 DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
 DEFAULT_VUNERABILITIES_FILE = os.path.join(DATA_DIR, "vulnerabilities.json")

 def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
     """Load and return all contracts from the JSON dataset."""
     with open(path, "r") as f:
         return json.load(f)


 def load_vulnerabilities(path: str = DEFAULT_VUNERABILITIES_FILE) -> List[Dict[str, Any]]:
     """Load and return all vulnerability entries from the JSON dataset."""
     with open(path, "r") as f:
         return json.load(f)

-
 def get_all_vulnerable_entries(
     contracts: List[Dict[str, Any]],
 ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
     """
     Returns a flat list of (contract, function) pairs where
     function['vulnerable'] is True.
-    Used by Task 1 to populate the episode pool.
     """
     entries = []
     for contract in contracts:
@@ -48,10 +86,7 @@ def sample_episode(
     contracts: List[Dict[str, Any]],
     rng: Optional[random.Random] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-    """
-    Randomly selects one (contract, vulnerable_function) pair.
-    Returns the contract dict and the target function dict.
-    """
     if rng is None:
         rng = random.Random()
     entries = get_all_vulnerable_entries(contracts)
@@ -60,31 +95,101 @@
     return rng.choice(entries)


-def get_function_by_name(
-    contract: Dict[str, Any], name: str
-) -> Optional[Dict[str, Any]]:
-    """Case-insensitive function lookup within a contract."""
-    for fn in contract.get("functions", []):
-        if fn["name"].lower() == name.lower():
-            return fn
-    return None


-def get_state_variable_by_name(
-    contract: Dict[str, Any], name: str
-) -> Optional[Dict[str, Any]]:
-    """Case-insensitive state variable lookup."""
-    for sv in contract.get("state_variables", []):
-        if sv["name"].lower() == name.lower():
-            return sv
-    return None


-def list_function_names(contract: Dict[str, Any]) -> List[str]:
-    """Return all function names in the contract."""
-    return [fn["name"] for fn in contract.get("functions", [])]


-def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
-    """Return all state variable names."""
-    return [sv["name"] for sv in contract.get("state_variables", [])]
 data_loader.py
 --------------
 Loads and indexes smart contract data from JSON files.
+
+Task 1 helpers – vulnerable function sampling
+Task 2 helpers – property function sampling, natspec, similar-rule lookup
 """

 import json

 DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
 DEFAULT_VUNERABILITIES_FILE = os.path.join(DATA_DIR, "vulnerabilities.json")

+
+# ────────────────────────────────────────────────────────────────
+# Core loaders
+# ────────────────────────────────────────────────────────────────
+
 def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
     """Load and return all contracts from the JSON dataset."""
     with open(path, "r") as f:
         return json.load(f)


+def get_function_by_name(
+    contract: Dict[str, Any], name: str
+) -> Optional[Dict[str, Any]]:
+    """Case-insensitive function lookup within a contract."""
+    for fn in contract.get("functions", []):
+        if fn["name"].lower() == name.lower():
+            return fn
+    return None
+
+
+def get_state_variable_by_name(
+    contract: Dict[str, Any], name: str
+) -> Optional[Dict[str, Any]]:
+    """Case-insensitive state variable lookup."""
+    for sv in contract.get("state_variables", []):
+        if sv["name"].lower() == name.lower():
+            return sv
+    return None
+
+
+def list_function_names(contract: Dict[str, Any]) -> List[str]:
+    """Return all function names in the contract."""
+    return [fn["name"] for fn in contract.get("functions", [])]
+
+
+def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
+    """Return all state variable names."""
+    return [sv["name"] for sv in contract.get("state_variables", [])]
+
+
+# ────────────────────────────────────────────────────────────────
+# Task 1 helpers
+# ────────────────────────────────────────────────────────────────
+
 def load_vulnerabilities(path: str = DEFAULT_VUNERABILITIES_FILE) -> List[Dict[str, Any]]:
     """Load and return all vulnerability entries from the JSON dataset."""
     with open(path, "r") as f:
         return json.load(f)

 def get_all_vulnerable_entries(
     contracts: List[Dict[str, Any]],
 ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
     """
     Returns a flat list of (contract, function) pairs where
     function['vulnerable'] is True.
     """
     entries = []
     for contract in contracts:

     contracts: List[Dict[str, Any]],
     rng: Optional[random.Random] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """Randomly selects one (contract, vulnerable_function) pair for Task 1."""
     if rng is None:
         rng = random.Random()
     entries = get_all_vulnerable_entries(contracts)

     return rng.choice(entries)


+# ────────────────────────────────────────────────────────────────
+# Task 2 helpers
+# ────────────────────────────────────────────────────────────────

+def get_all_property_entries(
+    contracts: List[Dict[str, Any]],
+) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
+    """
+    Returns a flat list of (contract, function) pairs where
+    function['property'] is not None.
+    Used by Task 2 to populate the episode pool.
+    """
+    entries = []
+    for contract in contracts:
+        for fn in contract.get("functions", []):
+            if fn.get("property") is not None:
+                entries.append((contract, fn))
+    return entries


+def sample_property_episode(
+    contracts: List[Dict[str, Any]],
+    rng: Optional[random.Random] = None,
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """Randomly selects one (contract, function-with-property) pair for Task 2."""
+    if rng is None:
+        rng = random.Random()
+    entries = get_all_property_entries(contracts)
+    if not entries:
+        raise ValueError("No functions with properties found in dataset.")
+    return rng.choice(entries)


+def get_related_functions(
+    contract: Dict[str, Any],
+    function_name: str,
+) -> List[str]:
+    """
+    Returns function names that are related to the given function:
+      - Functions that it calls (from call_graph)
+      - Functions that call it (reverse call_graph lookup)
+    """
+    name_lower = function_name.lower()
+    cg: Dict[str, List[str]] = contract.get("call_graph", {})
+    related = set()
+
+    # Direct callees (functions called by this function)
+    for callee in cg.get(function_name, []):
+        # Only include callees that are also functions in this contract
+        if get_function_by_name(contract, callee) is not None:
+            related.add(callee)
+
+    # Reverse: functions that call this function
+    for caller_name, callees in cg.items():
+        if any(c.lower() == name_lower for c in callees):
+            if get_function_by_name(contract, caller_name) is not None:
+                related.add(caller_name)
+
+    return sorted(related)


+def get_similar_rule(
+    contracts: List[Dict[str, Any]],
+    current_contract_name: str,
+    current_function_name: str,
+) -> Optional[Dict[str, Any]]:
+    """
+    Returns the similar_rule hint stored in the target function's property field,
+    enriched with the referenced function's natspec if available.
+
+    Returns a dict with keys: contract_name, function_name, property_hint, natspec.
+    Returns None if no similar_rule is defined.
+    """
+    # Find target function
+    for contract in contracts:
+        if contract["contract_name"] == current_contract_name:
+            fn = get_function_by_name(contract, current_function_name)
+            if fn and fn.get("property") and fn["property"].get("similar_rule"):
+                sr = fn["property"]["similar_rule"]
+                # Look up the referenced function's natspec
+                for c2 in contracts:
+                    if c2["contract_name"] == sr["contract_name"]:
+                        ref_fn = get_function_by_name(c2, sr["function_name"])
+                        if ref_fn:
+                            return {
+                                "contract_name": sr["contract_name"],
+                                "function_name": sr["function_name"],
+                                "property_hint": sr["property_hint"],
+                                "natspec": ref_fn.get("natspec", ""),
+                            }
+                # Referenced function not found — return hint only
+                return {
+                    "contract_name": sr["contract_name"],
+                    "function_name": sr["function_name"],
+                    "property_hint": sr["property_hint"],
+                    "natspec": "",
+                }
+    return None
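The new `get_related_functions` walks the stored `call_graph` in both directions and filters out callees that are not functions of the same contract (e.g. library calls). Its behaviour can be checked against a hand-built contract dict; the toy data below is illustrative, not from the real dataset:

```python
# Forward (callees) + reverse (callers) lookup over a toy call_graph,
# matching the logic added to data_loader.py in this commit.

def get_function_by_name(contract, name):
    for fn in contract.get("functions", []):
        if fn["name"].lower() == name.lower():
            return fn
    return None

def get_related_functions(contract, function_name):
    name_lower = function_name.lower()
    cg = contract.get("call_graph", {})
    related = set()
    # Callees of the target function, restricted to this contract's functions.
    for callee in cg.get(function_name, []):
        if get_function_by_name(contract, callee) is not None:
            related.add(callee)
    # Callers of the target function (reverse lookup).
    for caller_name, callees in cg.items():
        if any(c.lower() == name_lower for c in callees):
            if get_function_by_name(contract, caller_name) is not None:
                related.add(caller_name)
    return sorted(related)

toy = {
    "functions": [{"name": n} for n in ("deposit", "withdraw", "_transfer")],
    "call_graph": {
        "withdraw": ["_transfer", "SafeMath.sub"],  # external callee gets filtered out
        "deposit": ["_transfer"],
    },
}
print(get_related_functions(toy, "_transfer"))  # → ['deposit', 'withdraw'] (callers only)
print(get_related_functions(toy, "withdraw"))   # → ['_transfer']
```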
demo.py CHANGED
@@ -237,12 +237,18 @@ def _print_episode_summary(obs):
     print(f"  Steps taken : {obs.step_count}")
     print(f"  Total reward : {colour}{reward:+.2f}{RESET}")
     last = obs.last_action_result or ""
-    if "✅" in last:
         print(f"  {GREEN}Perfect score — full marks!{RESET}")
-    elif "⚠️" in last:
-        print(f"  {YELLOW}Partial credit — right function, imprecise vulnerability type.{RESET}")
-    else:
         print(f"  {RED}Incorrect — better luck next episode.{RESET}")
     print(f"{BOLD}{'═' * 64}{RESET}\n")


@@ -285,3 +291,67 @@ def main():

 if __name__ == "__main__":
     main()
     print(f"  Steps taken : {obs.step_count}")
     print(f"  Total reward : {colour}{reward:+.2f}{RESET}")
     last = obs.last_action_result or ""
+    if "✅ CORRECT" in last or "EXCELLENT" in last:
         print(f"  {GREEN}Perfect score — full marks!{RESET}")
+    elif "⚠️" in last or "PARTIAL" in last:
+        print(f"  {YELLOW}Partial credit.{RESET}")
+    elif "🟡 GOOD" in last:
+        print(f"  {YELLOW}Good — most key concepts matched!{RESET}")
+    elif "🟠" in last:
+        print(f"  {YELLOW}Partial — some key concepts matched.{RESET}")
+    elif "❌" in last:
         print(f"  {RED}Incorrect — better luck next episode.{RESET}")
+    else:
+        print(f"  {'Good effort!' if reward > 0 else 'Keep exploring next time.'}")
     print(f"{BOLD}{'═' * 64}{RESET}\n")


 if __name__ == "__main__":
     main()
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Task 2 demo
+# ─────────────────────────────────────────────────────────────────────────────
+
+DEMO_SCRIPTS_T2 = {
+    42: [
+        (ActionType.GET_FUNCTION_NATSPEC, {}, "First, read the NatSpec to understand intent and expected outputs."),
+        (ActionType.GET_IO, {}, "Check parameters, return type and expected behaviour."),
+        (ActionType.GET_FUNCTION_CODE, {}, "Read the actual Solidity code to confirm the behaviour."),
+        (ActionType.SUBMIT_PROPERTY,
+         {"property": "After a successful claimRewards call, all of the caller's accrued reward tokens are transferred to the caller and their rewards balance is set to zero. Reverts if the caller has no accrued rewards."},
+         "Confident about the property. Submitting!"),
+    ],
+}
+
+
+def run_auto_demo_t2(seed: int = 42, delay: float = 0.9):
+    """Run the scripted Task 2 demo."""
+    from tasks.task2.environment import Task2Environment
+
+    script = DEMO_SCRIPTS_T2.get(seed)
+    env = Task2Environment()
+    result = env.reset(seed=seed)
+    obs = result.observation
+
+    print()
+    print(f"{BOLD}{CYAN}╔══════════════════════════════════════════════════════════╗")
+    print(f"║   Smart Contract Audit RL Env · Task 2 Demo              ║")
+    print(f"╚══════════════════════════════════════════════════════════╝{RESET}")
+    print()
+    print(f"{BOLD}Mode:{RESET} Automated demo | {BOLD}Seed:{RESET} {seed}")
+    print(f"{BOLD}Task:{RESET} Property Discovery")
+    print()
+
+    fn_name = obs.extra.get("target_function", "?")
+    sig = obs.extra.get("target_signature", "")
+    print(f"{BOLD}Contract :{RESET} {obs.contract_name}")
+    print(f"{BOLD}Function :{RESET} {fn_name} ({sig})")
+    print(f"{BOLD}Goal     :{RESET} Write the natural-language property for '{fn_name}'")
+    print(DIVIDER)
+
+    if not script:
+        print(f"{YELLOW}No pre-written script for seed {seed}.{RESET}")
+        return
+
+    for at, params, commentary in script:
+        time.sleep(delay)
+        print(f"\n{CYAN}▶ Agent thinking:{RESET} {commentary}")
+        time.sleep(delay * 0.5)
+        step_result = env.step(Action(action_type=at, params=params))
+        sobs = step_result.observation
+        print(DIVIDER)
+        print(f"{BOLD}Step {sobs.step_count:2d}{RESET} [{at.value}] r={step_result.reward.value:+.2f} cum={sobs.cumulative_reward:+.2f}")
+        result_text = sobs.last_action_result or ""
+        colour = GREEN if step_result.reward.value > 0 else YELLOW
+        for line in result_text.split("\n")[:8]:
+            print(f"  {colour}{line[:90]}{RESET}")
+        print(DIVIDER)
+
+        if step_result.done:
+            _print_episode_summary(sobs)
+            return
env/schemas.py CHANGED
@@ -3,12 +3,12 @@ schemas.py
 ----------
 Typed Pydantic models implementing the OpenEnv interface spec.

-    Observation  - what the agent sees at each step
-    Action       - what the agent can send
-    StepResult   - returned by step()
-    ResetResult  - returned by reset()
-    StateResult  - returned by state()
-    Reward       - structured reward info
 """

 from __future__ import annotations
@@ -24,27 +24,28 @@ from pydantic import BaseModel, Field
 # ---------------------------------------------------------------------------

 class ActionType(str, Enum):
-    # Task 1 – Vulnerability Detection
-    LIST_FUNCTIONS = "list_functions"
-    GET_FUNCTION_CODE = "get_function_code"
     GET_FUNCTION_SUMMARY = "get_function_summary"
-    GET_FILE_METADATA = "get_file_metadata"
-    GET_STATE_VARIABLE = "get_state_variable"
-    GET_CALL_GRAPH = "get_call_graph"
-    SUBMIT = "submit"
-
-    # TODO: Task 2 – Property Discovery
-    # GET_SIMILAR_RULE = "get_similar_rule"
-    # GET_FILE_NATSPEC = "get_file_natspec"
-    # GET_FUNCTION_NATSPEC = "get_function_natspec"
-    # GET_RELATED_FUNCTIONS = "get_related_functions"
-    # GET_IO = "get_io"
-    # SUBMIT_PROPERTY = "submit_property"
-
-    # TODO: Task 3 – Rule Checker
     # GET_FORMALIZED_PROPERTY = "get_formalized_property"
-    # GET_FUNCTION_METADATA = "get_function_metadata"
-    # SUBMIT_FUNCTION = "submit_function"


 class Action(BaseModel):
@@ -54,7 +55,7 @@ class Action(BaseModel):
     action_type : one of ActionType enum values
     params      : optional key/value arguments, e.g.
                   {"function_name": "withdraw"} for GET_FUNCTION_CODE
-                  {"function_name": "withdraw", "vulnerability_type": "reentrancy"} for SUBMIT
     """
     action_type: ActionType
     params: Dict[str, Any] = Field(default_factory=dict)
@@ -71,16 +72,16 @@ class Observation(BaseModel):
     """
     What the agent receives from the environment.

-    task_id           : which task is active
-    contract_name     : name of the Solidity contract
     contract_description : high-level description of what the contract does
-    available_actions : list of valid ActionType strings
-    last_action       : the action that produced this observation (None on reset)
-    last_action_result: human-readable result of the last action
-    step_count        : number of steps taken so far
-    cumulative_reward : running reward total
-    done              : whether the episode has ended
-    extra             : any additional task-specific context
     """
     task_id: str
     contract_name: str
@@ -147,4 +148,4 @@ class TaskInfo(BaseModel):
     name: str
     difficulty: str
     description: str
-    status: str = "active"   # or "placeholder"
 
 ----------
 Typed Pydantic models implementing the OpenEnv interface spec.

+    Observation  – what the agent sees at each step
+    Action       – what the agent can send
+    StepResult   – returned by step()
+    ResetResult  – returned by reset()
+    StateResult  – returned by state()
+    Reward       – structured reward info
 """

 from __future__ import annotations

 # ---------------------------------------------------------------------------

 class ActionType(str, Enum):
+    # ── Task 1 – Vulnerability Detection ────────────────────────────────────
+    LIST_FUNCTIONS = "list_functions"
+    GET_FUNCTION_CODE = "get_function_code"
     GET_FUNCTION_SUMMARY = "get_function_summary"
+    GET_FILE_METADATA = "get_file_metadata"
+    GET_STATE_VARIABLE = "get_state_variable"
+    GET_CALL_GRAPH = "get_call_graph"
+    SUBMIT = "submit"
+
+    # ── Task 2 – Property Discovery ─────────────────────────────────────────
+    GET_SIMILAR_RULE = "get_similar_rule"            # -0.20
+    GET_FILE_NATSPEC = "get_file_natspec"            # -0.03
+    GET_FUNCTION_NATSPEC = "get_function_natspec"    # -0.08
+    GET_RELATED_FUNCTIONS = "get_related_functions"  # -0.06
+    GET_IO = "get_io"                                # -0.04
+    SUBMIT_PROPERTY = "submit_property"              # scored 0–5, one attempt
+
+    # ── Task 3 – Rule Checker ───────────────────────────────────────────────
+    # TODO: Task 3
     # GET_FORMALIZED_PROPERTY = "get_formalized_property"
+    # GET_FUNCTION_METADATA = "get_function_metadata"
+    # SUBMIT_FUNCTION = "submit_function"


 class Action(BaseModel):

     action_type : one of ActionType enum values
     params      : optional key/value arguments, e.g.
                   {"function_name": "withdraw"} for GET_FUNCTION_CODE
+                  {"property": "..."} for SUBMIT_PROPERTY
     """
     action_type: ActionType
     params: Dict[str, Any] = Field(default_factory=dict)

     """
     What the agent receives from the environment.

+    task_id            : which task is active
+    contract_name      : name of the Solidity contract
     contract_description : high-level description of what the contract does
+    available_actions  : list of valid ActionType strings
+    last_action        : the action that produced this observation (None on reset)
+    last_action_result : human-readable result of the last action
+    step_count         : number of steps taken so far
+    cumulative_reward  : running reward total
+    done               : whether the episode has ended
+    extra              : any additional task-specific context
     """
     task_id: str
     contract_name: str

     name: str
     difficulty: str
     description: str
+    status: str = "active"   # "active" | "placeholder"
eval.py CHANGED
@@ -3,264 +3,276 @@ eval.py
3
  -------
4
  Evaluation harness for the Smart Contract Audit RL Environment.
5
 
6
- Runs a configurable number of episodes per task, collecting grader scores
7
- and reward trajectories. Produces a detailed JSON report.
8
-
9
- Unlike inference.py (which uses an external LLM), this evaluates the
10
- *environment itself* using a built-in oracle agent β€” useful for:
11
- - Verifying grader correctness
12
- - Benchmarking reward shaping
13
- - Checking score distribution across vulnerability types
14
 
15
  Usage:
16
- python eval.py # all 8 vuln episodes
17
- python eval.py --episodes 16 # more episodes
18
- python eval.py --seed 0 --verbose # detailed per-step output
19
- python eval.py --out results.json # custom output file
 
 
20
  """
21
 
22
  import argparse
23
  import json
24
  import sys
25
- import time
26
  from typing import Any, Dict, List
27
 
28
  from tasks.task1.environment import Task1Environment
 
29
  from env.schemas import Action, ActionType
30
- from data.data_loader import load_contracts, get_all_vulnerable_entries
 
 
 
 
31
 
32
 
33
  # ─────────────────────────────────────────────────────────────────────────────
34
- # Oracle agent (always submits the ground-truth answer)
35
  # ─────────────────────────────────────────────────────────────────────────────
36
 
37
- def oracle_agent(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
38
- """
39
- Runs one episode using the oracle strategy:
40
- 1. list_functions
41
- 2. get_function_code (for the target function β€” peeked from state)
42
- 3. submit correct answer
43
-
44
- This gives an upper-bound score trajectory for the environment.
45
- Always ends with grader_score = 1.0.
46
- """
47
- reset_result = env.reset(seed=seed)
48
- obs = reset_result.observation
49
-
50
- steps_taken: List[Dict[str, Any]] = []
51
-
52
- def _step(at: ActionType, params: dict = None) -> Any:
53
- params = params or {}
54
- action = Action(action_type=at, params=params)
55
- result = env.step(action)
56
- entry = {
57
- "step": result.observation.step_count,
58
- "action": at.value,
59
- "params": params,
60
- "reward": result.reward.value,
61
- "reason": result.reward.reason,
62
- "cumulative": result.observation.cumulative_reward,
63
- "done": result.done,
64
- }
65
- steps_taken.append(entry)
66
- if verbose:
67
- done_flag = " [DONE]" if result.done else ""
68
- print(
69
- f" step {entry['step']:2d}: {at.value:25s} "
70
- f"r={result.reward.value:+.2f} cum={entry['cumulative']:+.2f}"
71
- f"{done_flag}"
72
- )
73
- return result
74
-
75
- # Peek at ground truth (oracle only)
76
- state = env.state()
77
- target_fn = state.target_function
78
-
79
- # Get ground-truth vulnerability from data
80
  contracts = load_contracts()
81
- vuln_issue = None
82
- for contract in contracts:
83
- for fn in contract.get("functions", []):
84
- if fn["name"].lower() == target_fn.lower() and fn.get("vulnerable"):
85
- # ! SINCE OUR MATCHER IS BASED ON FACT THAT EXPECTED STRING IS 2-3 WORDS, THIS DOESN'T MATCH WELL
86
- vuln_issue = fn["vulnerability_details"]["issue"]
87
- break
88
- if vuln_issue:
89
  break
90
 
91
  if verbose:
92
- print(f" Contract : {obs.contract_name}")
93
- print(f" Target : {target_fn} ({vuln_issue})")
94
-
95
- # Step 1: list functions (small cost, realistic)
96
- _step(ActionType.LIST_FUNCTIONS)
97
- # Step 2: read target function code (gets +0.05 shaping reward)
98
- _step(ActionType.GET_FUNCTION_CODE, {"function_name": target_fn})
99
- # Step 3: submit perfect answer
100
- result = _step(ActionType.SUBMIT, {
101
- "function_name": target_fn,
102
- "vulnerability_type": vuln_issue,
103
- })
104
-
105
- final_reward = result.reward.value
106
- if final_reward >= 4.9:
107
- grader_score = 1.0
108
- elif final_reward >= 0.9:
109
- grader_score = 0.5
110
- else:
111
- grader_score = 0.0
112
 
 
 
 
 
 
 
 
 
 
113
  return {
114
  "seed": seed,
115
  "contract": obs.contract_name,
116
- "target_function": target_fn,
117
  "vulnerability": vuln_issue,
118
- "grader_score": grader_score,
119
  "cumulative_reward": result.observation.cumulative_reward,
120
- "steps": steps_taken,
121
- "num_steps": len(steps_taken),
122
  }


# ─────────────────────────────────────────────────────────────────────────────
-# Partial agent (submits correct function, wrong vuln type)
# ─────────────────────────────────────────────────────────────────────────────

-def partial_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
-    """Submits right function, always uses 'unknown' as vulnerability type → score 0.5."""
-    reset_result = env.reset(seed=seed)
-    obs = reset_result.observation
-    state = env.state()
-    target_fn = state.target_function
-
-    action = Action(action_type=ActionType.SUBMIT, params={
-        "function_name": target_fn,
-        "vulnerability_type": "unknown vulnerability",
-    })
-    result = env.step(action)
    return {
        "seed": seed,
-        "grader_score": 0.5,
        "cumulative_reward": result.observation.cumulative_reward,
    }
146
 
147
 
148
- # ─────────────────────────────────────────────────────────────────────────────
149
- # Random agent (submits a random wrong function)
150
- # ──────────────────────────────────────────────��──────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- def random_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
153
- """Always submits 'constructor' β€” always wrong β†’ score 0.0."""
154
  env.reset(seed=seed)
155
- action = Action(action_type=ActionType.SUBMIT, params={
156
- "function_name": "constructor",
157
- "vulnerability_type": "reentrancy",
158
- })
159
- result = env.step(action)
160
- return {
161
- "seed": seed,
162
- "grader_score": 0.0,
163
- "cumulative_reward": result.observation.cumulative_reward,
164
- }
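The three baseline tiers are designed to produce a strict score ordering (oracle > partial > random). A tiny sanity-check sketch of that invariant, independent of the environment; the function name is mine:

```python
def check_tier_ordering(avgs: dict) -> bool:
    """Return True when the baseline tiers separate cleanly:
    oracle above partial, partial above random, over average grader scores."""
    return avgs["oracle"] > avgs["partial"] > avgs["random"]
```

If this ever fails, the reward shaping no longer distinguishes a perfect answer from a partial or wrong one.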


# ─────────────────────────────────────────────────────────────────────────────
-# Evaluation runner
# ─────────────────────────────────────────────────────────────────────────────

-def run_evaluation(
-    num_episodes: int = 8,
-    seed_offset: int = 0,
-    verbose: bool = False,
-    output_file: str = "eval_results.json",
-) -> None:
-    env = Task1Environment()
    contracts = load_contracts()
-    entries = get_all_vulnerable_entries(contracts)
-    vuln_types = list({fn["vulnerability_details"]["issue"] for _, fn in entries})

-    print("=" * 64)
-    print("Smart Contract Audit RL Environment — Evaluation")
-    print("=" * 64)
-    print(f"  Episodes  : {num_episodes}")
-    print(f"  Seed range: {seed_offset} – {seed_offset + num_episodes - 1}")
-    print(f"  Vulns in dataset: {len(entries)}")
-    print()
-
-    # ── Oracle agent ─────────────────────────────────────────────────────────
-    print("▶ Oracle agent (upper bound — always submits correct answer):")
-    oracle_episodes = []
-    for i in range(num_episodes):
-        seed = seed_offset + i
-        ep = oracle_agent(env, seed=seed, verbose=verbose)
-        oracle_episodes.append(ep)
-        icon = "✅" if ep["grader_score"] == 1.0 else "⚠️ "
-        print(
-            f"  {icon} seed={seed:3d} {ep['contract']:12s} "
-            f"{ep['target_function']:15s} score={ep['grader_score']:.1f} "
-            f"reward={ep['cumulative_reward']:+.2f}"
-        )
-
-    oracle_avg = sum(e["grader_score"] for e in oracle_episodes) / num_episodes
-    oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_episodes) / num_episodes
-    print(f"\n  Oracle avg grader score : {oracle_avg:.3f}")
-    print(f"  Oracle avg reward       : {oracle_avg_r:+.2f}")
-
-    # ── Partial agent ────────────────────────────────────────────────────────
-    print("\n▶ Partial agent (right function, wrong vuln type → 0.5 each):")
-    partial_episodes = []
-    for i in range(num_episodes):
-        ep = partial_agent(env, seed=seed_offset + i)
-        partial_episodes.append(ep)
-    partial_avg = sum(e["grader_score"] for e in partial_episodes) / num_episodes
-    print(f"  Partial avg grader score: {partial_avg:.3f}")
-
-    # ── Random agent ─────────────────────────────────────────────────────────
-    print("\n▶ Random agent (always wrong → 0.0 each):")
-    random_episodes = []
    for i in range(num_episodes):
-        ep = random_agent(env, seed=seed_offset + i)
-        random_episodes.append(ep)
-    random_avg = sum(e["grader_score"] for e in random_episodes) / num_episodes
-    print(f"  Random avg grader score : {random_avg:.3f}")
-
-    # ── Score distribution ───────────────────────────────────────────────────
-    print("\n▶ Coverage across vulnerability types:")
-    seen = {}
-    for ep in oracle_episodes:
        v = ep.get("vulnerability", "unknown")
-        seen[v] = seen.get(v, 0) + 1
-    for v in sorted(seen):
-        print(f"  {seen[v]:2d}x {v}")

-    # ── Summary ──────────────────────────────────────────────────────────────
    print("\n" + "=" * 64)
-    print("SUMMARY")
    print("=" * 64)
-    print(f"  Oracle  (ceiling): {oracle_avg:.3f} {'✅' if oracle_avg == 1.0 else '⚠️ '}")
-    print(f"  Partial (partial): {partial_avg:.3f} ✅")
-    print(f"  Random  (floor)  : {random_avg:.3f} ✅")
-
-    assert oracle_avg == 1.0, "Oracle should always score 1.0"
-    assert partial_avg == 0.5, "Partial should always score 0.5"
-    assert random_avg == 0.0, "Random should always score 0.0"
-
-    print("\n  ✅ All score sanity checks passed.")
-
-    # ── Write results ────────────────────────────────────────────────────────
-    report = {
-        "num_episodes": num_episodes,
-        "seed_offset": seed_offset,
-        "agents": {
-            "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_episodes},
-            "partial": {"avg_score": partial_avg, "episodes": partial_episodes},
-            "random": {"avg_score": random_avg, "episodes": random_episodes},
-        },
-        "vulnerability_coverage": seen,
    }
-    with open(output_file, "w") as f:
-        json.dump(report, f, indent=2)
-    print(f"\n  Results written to {output_file}")


# ─────────────────────────────────────────────────────────────────────────────
@@ -268,23 +280,50 @@ def run_evaluation(
# ─────────────────────────────────────────────────────────────────────────────

def main():
-    parser = argparse.ArgumentParser(description="Evaluate the SC Audit RL Environment")
-    parser.add_argument("--episodes", type=int, default=8,
-                        help="Number of episodes per agent (default: 8)")
-    parser.add_argument("--seed", type=int, default=42,
-                        help="Starting seed (default: 42)")
-    parser.add_argument("--verbose", action="store_true",
-                        help="Print per-step details for oracle agent")
-    parser.add_argument("--out", default="eval_results.json",
-                        help="Output JSON file (default: eval_results.json)")
    args = parser.parse_args()

-    run_evaluation(
-        num_episodes=args.episodes,
-        seed_offset=args.seed,
-        verbose=args.verbose,
-        output_file=args.out,
-    )


if __name__ == "__main__":

-------
Evaluation harness for the Smart Contract Audit RL Environment.

+Runs oracle / partial / baseline agents against Task 1 and Task 2,
+verifying that grader scores form a clear ordering and that reward
+shaping is meaningful.

Usage:
+    python eval.py                      # Task 1 + Task 2, 8 episodes each
+    python eval.py --task 1             # Task 1 only
+    python eval.py --task 2             # Task 2 only
+    python eval.py --episodes 16        # more episodes
+    python eval.py --seed 0 --verbose   # detailed per-step trace
+    python eval.py --out results.json   # custom output file
"""

import argparse
import json
import sys
from typing import Any, Dict, List

from tasks.task1.environment import Task1Environment
+from tasks.task2.environment import Task2Environment
from env.schemas import Action, ActionType
+from data.data_loader import (
+    load_contracts,
+    get_function_by_name,
+    get_all_vulnerable_entries,
+)


# ─────────────────────────────────────────────────────────────────────────────
+# Task 1 agents
# ─────────────────────────────────────────────────────────────────────────────

+def oracle_t1(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
+    """Always submits the exact ground-truth answer → score = 1.0."""
+    r = env.reset(seed=seed)
+    obs = r.observation
+    st = env.state()
+    fn_name = st.target_function
+
    contracts = load_contracts()
+    vuln_issue = ""
+    for c in contracts:
+        fn = get_function_by_name(c, fn_name)
+        if fn and fn.get("vulnerable"):
+            vuln_issue = fn["vulnerability_details"]["issue"]
            break

    if verbose:
+        print(f"    {obs.contract_name}.{fn_name}()  [{vuln_issue}]")

+    env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+    env.step(Action(action_type=ActionType.GET_FUNCTION_CODE,
+                    params={"function_name": fn_name}))
+    result = env.step(Action(action_type=ActionType.SUBMIT,
+                             params={"function_name": fn_name,
+                                     "vulnerability_type": vuln_issue}))
+
+    v = result.reward.value
+    score = 1.0 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
    return {
        "seed": seed,
        "contract": obs.contract_name,
+        "target_function": fn_name,
        "vulnerability": vuln_issue,
+        "grader_score": score,
        "cumulative_reward": result.observation.cumulative_reward,
    }


+def partial_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
+    """Right function, wrong vuln type → score = 0.5."""
+    env.reset(seed=seed)
+    fn_name = env.state().target_function
+    result = env.step(Action(action_type=ActionType.SUBMIT,
+                             params={"function_name": fn_name,
+                                     "vulnerability_type": "unknown"}))
+    v = result.reward.value
+    return {"seed": seed, "grader_score": 0.5 if v >= 0.9 else 0.0,
+            "cumulative_reward": result.observation.cumulative_reward}
+
+
+def random_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
+    """Always submits 'constructor' → score = 0.0."""
+    env.reset(seed=seed)
+    result = env.step(Action(action_type=ActionType.SUBMIT,
+                             params={"function_name": "constructor",
+                                     "vulnerability_type": "reentrancy"}))
+    return {"seed": seed, "grader_score": 0.0,
+            "cumulative_reward": result.observation.cumulative_reward}
+
+
# ─────────────────────────────────────────────────────────────────────────────
+# Task 2 agents
# ─────────────────────────────────────────────────────────────────────────────

+def oracle_t2(env: Task2Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
+    """Submits the exact ground-truth natural_language → score ≥ 0.70."""
+    r = env.reset(seed=seed)
+    obs = r.observation
+    fn_name = obs.extra["target_function"]
+    contract = obs.contract_name
+
+    contracts = load_contracts()
+    gt_text = ""
+    for c in contracts:
+        if c["contract_name"] == contract:
+            fn = get_function_by_name(c, fn_name)
+            if fn and fn.get("property"):
+                gt_text = fn["property"]["natural_language"]
+            break
+
+    if verbose:
+        print(f"    {contract}.{fn_name}()")
+
+    # read code first (realistic browsing step)
+    env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
+    result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
+                             params={"property": gt_text}))
+
+    r_val = result.reward.value
+    score = round(r_val / 5.0, 4) if r_val > 0 else 0.0
    return {
        "seed": seed,
+        "contract": contract,
+        "function": fn_name,
+        "grader_score": score,
        "cumulative_reward": result.observation.cumulative_reward,
    }
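Task 2 normalizes the terminal reward into a [0, 1] grader score by dividing by the 5.0 reward ceiling, with a positivity guard. A sketch of that mapping in isolation; the `round(…, 4)` and the guard follow the diff, while the helper name is mine:

```python
def normalize_t2_score(reward_value: float, max_reward: float = 5.0) -> float:
    """Scale a Task 2 terminal reward into [0, 1].

    Non-positive rewards (step penalties, invalid submissions) score 0.0;
    positive rewards are divided by the reward ceiling and rounded.
    """
    if reward_value <= 0:
        return 0.0
    return round(reward_value / max_reward, 4)
```

Unlike Task 1's discrete 0 / 0.5 / 1.0 tiers, this gives Task 2 a continuous score, matching its similarity-based grading.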


+def partial_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
+    """Submits the function's NatSpec comment — partial credit."""
+    r = env.reset(seed=seed)
+    obs = r.observation
+    contracts = load_contracts()
+    comment = ""
+    for c in contracts:
+        if c["contract_name"] == obs.contract_name:
+            fn = get_function_by_name(c, obs.extra["target_function"])
+            if fn:
+                comment = fn.get("comment", "")
+            break
+    result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
+                             params={"property": comment}))
+    r_val = result.reward.value
+    score = round(r_val / 5.0, 4) if r_val > 0 else 0.0
+    return {"seed": seed, "grader_score": score,
+            "cumulative_reward": result.observation.cumulative_reward}
+

+def empty_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
+    """Submits empty string → score = 0.0."""
    env.reset(seed=seed)
+    result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
+                             params={"property": ""}))
+    return {"seed": seed, "grader_score": 0.0,
+            "cumulative_reward": result.observation.cumulative_reward}


# ─────────────────────────────────────────────────────────────────────────────
+# Evaluation runners
# ─────────────────────────────────────────────────────────────────────────────

+def run_task1_eval(num_episodes: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
+    print("\n" + "=" * 64)
+    print("TASK 1 — Targeted Vulnerability Detection")
+    print("=" * 64)
    contracts = load_contracts()
+    entries = get_all_vulnerable_entries(contracts)
+    print(f"  Dataset: {len(contracts)} contracts, {len(entries)} vulnerable functions\n")

+    env = Task1Environment()
+
+    print("▶ Oracle agent (always submits correct answer):")
+    oracle_eps = []
    for i in range(num_episodes):
+        ep = oracle_t1(env, seed_offset + i, verbose=verbose)
+        oracle_eps.append(ep)
+        print(f"  seed={ep['seed']:3d} {ep['contract']:12s}.{ep['target_function']:18s}"
+              f" score={ep['grader_score']:.1f} reward={ep['cumulative_reward']:+.2f}")
+    oracle_avg = sum(e["grader_score"] for e in oracle_eps) / num_episodes
+    oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / num_episodes
+    print(f"\n  Oracle avg score : {oracle_avg:.3f}   avg reward: {oracle_avg_r:+.2f}")
+
+    print("\n▶ Partial agent (right function, wrong vuln type → 0.5):")
+    partial_eps = [partial_t1(env, seed_offset + i) for i in range(num_episodes)]
+    partial_avg = sum(e["grader_score"] for e in partial_eps) / num_episodes
+    print(f"  Partial avg score: {partial_avg:.3f}")
+
+    print("\n▶ Random agent (always wrong → 0.0):")
+    random_eps = [random_t1(env, seed_offset + i) for i in range(num_episodes)]
+    random_avg = sum(e["grader_score"] for e in random_eps) / num_episodes
+    print(f"  Random avg score : {random_avg:.3f}")
+
+    vuln_seen: Dict[str, int] = {}
+    for ep in oracle_eps:
        v = ep.get("vulnerability", "unknown")
+        vuln_seen[v] = vuln_seen.get(v, 0) + 1
+    print("\n▶ Vulnerability type coverage:")
+    for v in sorted(vuln_seen):
+        print(f"  {vuln_seen[v]:2d}× {v}")
+
+    assert oracle_avg == 1.0, f"Oracle should be 1.0, got {oracle_avg}"
+    assert partial_avg == 0.5, f"Partial should be 0.5, got {partial_avg}"
+    assert random_avg == 0.0, f"Random should be 0.0, got {random_avg}"
+    print("\n  ✅ Task 1 score ordering: oracle(1.0) > partial(0.5) > random(0.0)")

+    return {
+        "task_id": "task1_vuln_detection",
+        "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
+        "partial": {"avg_score": partial_avg, "episodes": partial_eps},
+        "random": {"avg_score": random_avg, "episodes": random_eps},
+        "vuln_coverage": vuln_seen,
+    }
+
+
+def run_task2_eval(num_episodes: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
    print("\n" + "=" * 64)
+    print("TASK 2 — Property Discovery")
    print("=" * 64)
+    from data.data_loader import get_all_property_entries
+    contracts = load_contracts()
+    entries = get_all_property_entries(contracts)
+    print(f"  Dataset: {len(entries)} functions with properties\n")
+
+    env = Task2Environment()
+
+    print("▶ Oracle agent (submits ground-truth natural language):")
+    oracle_eps = []
+    for i in range(num_episodes):
+        ep = oracle_t2(env, seed_offset + i, verbose=verbose)
+        oracle_eps.append(ep)
+        icon = "✅" if ep["grader_score"] >= 0.65 else "⚠️ "
+        print(f"  {icon} seed={ep['seed']:3d} {ep['contract']:12s}.{ep['function']:18s}"
+              f" score={ep['grader_score']:.3f} reward={ep['cumulative_reward']:+.2f}")
+    oracle_avg = sum(e["grader_score"] for e in oracle_eps) / num_episodes
+    oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / num_episodes
+    print(f"\n  Oracle avg score : {oracle_avg:.3f}   avg reward: {oracle_avg_r:+.2f}")
+
+    print("\n▶ Partial agent (submits NatSpec comment — partial signal):")
+    partial_eps = [partial_t2(env, seed_offset + i) for i in range(num_episodes)]
+    partial_avg = sum(e["grader_score"] for e in partial_eps) / num_episodes
+    partial_avg_r = sum(e["cumulative_reward"] for e in partial_eps) / num_episodes
+    print(f"  Partial avg score: {partial_avg:.3f}   avg reward: {partial_avg_r:+.2f}")
+
+    print("\n▶ Empty agent (submits nothing → 0.0):")
+    empty_eps = [empty_t2(env, seed_offset + i) for i in range(num_episodes)]
+    empty_avg = sum(e["grader_score"] for e in empty_eps) / num_episodes
+    print(f"  Empty avg score  : {empty_avg:.3f}")
+
+    fn_seen: Dict[str, int] = {}
+    for ep in oracle_eps:
+        fn_seen[ep["function"]] = fn_seen.get(ep["function"], 0) + 1
+    print("\n▶ Function coverage:")
+    for fn in sorted(fn_seen):
+        print(f"  {fn_seen[fn]:2d}× {fn}")
+
+    assert oracle_avg > 0.60, f"Oracle avg {oracle_avg:.3f} should be > 0.60"
+    assert oracle_avg > partial_avg, "Oracle should beat partial"
+    assert partial_avg >= empty_avg, "Partial should be >= empty"
+    assert empty_avg == 0.0, f"Empty should be 0.0, got {empty_avg}"
+    print(f"\n  ✅ Task 2 score ordering: oracle({oracle_avg:.3f}) > partial({partial_avg:.3f}) > empty(0.0)")
+
+    return {
+        "task_id": "task2_property_discovery",
+        "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
+        "partial": {"avg_score": partial_avg, "avg_reward": partial_avg_r, "episodes": partial_eps},
+        "empty": {"avg_score": empty_avg, "episodes": empty_eps},
+        "fn_coverage": fn_seen,
    }


# ─────────────────────────────────────────────────────────────────────────────

# ─────────────────────────────────────────────────────────────────────────────

def main():
+    parser = argparse.ArgumentParser(
+        description="Evaluate Task 1 and/or Task 2 of the SC Audit RL Environment"
+    )
+    parser.add_argument("--episodes", type=int, default=8,
+                        help="Episodes per agent tier (default: 8)")
+    parser.add_argument("--seed", type=int, default=42,
+                        help="Starting RNG seed (default: 42)")
+    parser.add_argument("--task", choices=["1", "2", "all"], default="all",
+                        help="Which task(s) to evaluate (default: all)")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Print per-episode target details")
+    parser.add_argument("--out", default="eval_results.json",
+                        help="Output file (default: eval_results.json)")
    args = parser.parse_args()

+    report: Dict[str, Any] = {
+        "num_episodes": args.episodes,
+        "seed_offset": args.seed,
+    }
+
+    if args.task in ("1", "all"):
+        report["task1"] = run_task1_eval(args.episodes, args.seed, args.verbose)
+
+    if args.task in ("2", "all"):
+        report["task2"] = run_task2_eval(args.episodes, args.seed, args.verbose)
+
+    # ── Summary ──────────────────────────────────────────────────────────────
+    print("\n" + "=" * 64)
+    print("EVALUATION COMPLETE")
+    print("=" * 64)
+    if "task1" in report:
+        t1 = report["task1"]
+        print(f"  Task 1  oracle={t1['oracle']['avg_score']:.3f} "
+              f"partial={t1['partial']['avg_score']:.3f} "
+              f"random={t1['random']['avg_score']:.3f}")
+    if "task2" in report:
+        t2 = report["task2"]
+        print(f"  Task 2  oracle={t2['oracle']['avg_score']:.3f} "
+              f"partial={t2['partial']['avg_score']:.3f} "
+              f"empty={t2['empty']['avg_score']:.3f}")
+
+    with open(args.out, "w") as f:
+        json.dump(report, f, indent=2)
+    print(f"\n  Results written to {args.out}")


if __name__ == "__main__":
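The new eval.py writes a nested JSON report keyed by task. A minimal consumer sketch; the schema keys (`task1`, `oracle`, `avg_score`, …) follow the diff, while the summarizing function is mine:

```python
import json

def summarize_report(report: dict) -> dict:
    """Collapse an eval.py report into {task_key: {tier: avg_score}}."""
    summary = {}
    for task_key in ("task1", "task2"):
        task = report.get(task_key)
        if not task:
            continue
        # Keep only tier entries (dicts carrying an avg_score), skipping
        # scalar fields like task_id and coverage maps.
        summary[task_key] = {
            tier: data["avg_score"]
            for tier, data in task.items()
            if isinstance(data, dict) and "avg_score" in data
        }
    return summary

report = {
    "num_episodes": 8,
    "task1": {
        "task_id": "task1_vuln_detection",
        "oracle": {"avg_score": 1.0, "avg_reward": 5.1, "episodes": []},
        "partial": {"avg_score": 0.5, "episodes": []},
        "random": {"avg_score": 0.0, "episodes": []},
    },
}
print(json.dumps(summarize_report(report)))
# → {"task1": {"oracle": 1.0, "partial": 0.5, "random": 0.0}}
```

This is how a dashboard or CI step might ingest `eval_results.json` without depending on the per-episode payloads.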
inference.py CHANGED
@@ -2,14 +2,13 @@
inference.py
------------
Baseline inference script for the Smart Contract Audit RL Environment.

-Uses the OpenAI-compatible API client to run an LLM agent against Task 1.
-Tasks 2 and 3 are placeholders — they reset and immediately record 0.0.
-
-Environment variables required:
-    API_BASE_URL – LLM endpoint (e.g. https://api.openai.com/v1)
-    MODEL_NAME   – model identifier (e.g. gpt-4o-mini)
-    HF_TOKEN     – API key (passed as Authorization: Bearer <HF_TOKEN>)

Usage:
    python inference.py
@@ -17,306 +16,290 @@ Usage:
Output:
    Per-task scores printed to stdout.
    Final baseline scores written to baseline_scores.json.
"""

import json
import os
import sys
import time
-from typing import Any, Dict, List, Optional

from openai import OpenAI

-# ---------------------------------------------------------------------------
-# Import the env directly (no HTTP overhead for baseline)
-# ---------------------------------------------------------------------------
from tasks.task1.environment import Task1Environment
from env.schemas import Action, ActionType

-# ---------------------------------------------------------------------------
-# Config
-# ---------------------------------------------------------------------------

API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
-MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
-HF_TOKEN = os.environ.get("HF_TOKEN", "")

if not HF_TOKEN:
-    print("WARNING: HF_TOKEN is not set. API calls may fail.", file=sys.stderr)
-
-MAX_STEPS = 15        # Safety limit per episode
-NUM_EPISODES = 3      # Episodes per task
-TASK1_SEED_BASE = 42  # Reproducible seeds
-
-
-# ---------------------------------------------------------------------------
-# OpenAI client
-# ---------------------------------------------------------------------------
-
-client = OpenAI(
-    api_key=HF_TOKEN,
-    base_url=API_BASE_URL,
-)
-
-
-# ---------------------------------------------------------------------------
-# System prompt
-# ---------------------------------------------------------------------------
-
-SYSTEM_PROMPT = """You are an expert smart contract security auditor.
-
-You are given a Solidity contract and must identify the SINGLE most critical vulnerable function and name its vulnerability type.
-
-## Available Actions
-You interact by choosing ONE action per turn from:
-
-1. list_functions
-   → {"action": "list_functions", "params": {}}
-
-2. get_function_code
-   → {"action": "get_function_code", "params": {"function_name": "<name>"}}
-
-3. get_function_summary
-   → {"action": "get_function_summary", "params": {"function_name": "<name>"}}
-
-4. get_file_metadata
-   → {"action": "get_file_metadata", "params": {}}
-
-5. get_state_variable
-   → {"action": "get_state_variable", "params": {"variable_name": "<name>"}}
-   (omit variable_name to list all variables)
-
-6. get_call_graph
-   → {"action": "get_call_graph", "params": {}}
-
-7. submit (ENDS THE EPISODE)
-   → {"action": "submit", "params": {"function_name": "<name>", "vulnerability_type": "<2-3 word description>"}}
-
-## Strategy
-- Start with list_functions and get_file_metadata to understand the contract
-- Inspect suspicious functions (withdraw, transfer, emergency*, stake, etc.)
-- Submit when you are confident about the vulnerable function
-
-## Output Format
-Always respond with a single JSON object:
-{"action": "<action_type>", "params": {...}}
-Do NOT include any other text — only valid JSON.
-"""
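The system prompt above asks the model to reply with exactly one JSON object of the form `{"action": ..., "params": {...}}`. A defensive parser sketch for such replies; the fall-back to `list_functions` mirrors the script's behaviour, while the function name and the explicit action whitelist are mine:

```python
import json

# Action names from the prompt above (assumed to match ActionType values).
VALID_ACTIONS = {
    "list_functions", "get_function_code", "get_function_summary",
    "get_file_metadata", "get_state_variable", "get_call_graph", "submit",
}

def parse_agent_reply(raw: str) -> tuple[str, dict]:
    """Parse an LLM reply into (action, params).

    Falls back to the safe 'list_functions' action on malformed JSON,
    a missing "action" key, or an unknown action name, so one bad reply
    never aborts the episode.
    """
    try:
        payload = json.loads(raw)
        action = payload["action"]
        params = payload.get("params", {})
        if action not in VALID_ACTIONS or not isinstance(params, dict):
            raise ValueError(f"bad action payload: {action!r}")
        return action, params
    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
        return "list_functions", {}
```

Whitelisting actions up front means a hallucinated action name degrades to a cheap exploration step instead of an environment error.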


-def build_user_message(obs: Dict[str, Any]) -> str:
-    """Format the observation as a user message."""
-    lines = [
-        f"=== CONTRACT: {obs['contract_name']} ===",
-        f"Description: {obs['contract_description']}",
-        f"Step: {obs['step_count']} | Cumulative reward: {obs['cumulative_reward']:.2f}",
-        "",
-        f"Last action: {obs['last_action'] or 'None'}",
-        f"Result: {obs['last_action_result'] or 'Episode just started'}",
-        "",
-        f"Available actions: {', '.join(obs['available_actions'])}",
-    ]
-    if obs.get("extra", {}).get("hint"):
-        lines.append(f"Hint: {obs['extra']['hint']}")
-    return "\n".join(lines)


-# ---------------------------------------------------------------------------
-# Agent loop
-# ---------------------------------------------------------------------------

-def run_episode(env: Task1Environment, seed: int, episode_num: int) -> Dict[str, Any]:
-    """Run one episode and return result info."""
-    print(f"\n  Episode {episode_num} (seed={seed})")

-    reset_result = env.reset(seed=seed)
-    obs = reset_result.observation.model_dump()

-    print(f"  Contract: {obs['contract_name']}")

-    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
-    final_score = 0.0
-    final_reward = 0.0
-    steps = 0
-    done = False

-    for step_num in range(MAX_STEPS):
-        user_msg = build_user_message(obs)
-        messages.append({"role": "user", "content": user_msg})

-        # LLM call
        try:
-            response = client.chat.completions.create(
-                model=MODEL_NAME,
-                messages=messages,
-                max_tokens=256,
-                temperature=0.0,
            )
-            raw = response.choices[0].message.content.strip()
        except Exception as e:
-            print(f"  LLM error at step {step_num}: {e}", file=sys.stderr)
            break

-        # Parse action
        try:
            parsed = json.loads(raw)
-            action_type = ActionType(parsed["action"])
            params = parsed.get("params", {})
-        except Exception as e:
-            print(f"  Parse error: {e} | Raw: {raw[:100]}", file=sys.stderr)
-            # Default safe action
-            action_type = ActionType.LIST_FUNCTIONS
-            params = {}

-        action = Action(action_type=action_type, params=params)
        messages.append({"role": "assistant", "content": raw})
-
-        # Step
-        step_result = env.step(action)
-        obs = step_result.observation.model_dump()
-        done = step_result.done
-        steps += 1
-        final_reward = obs["cumulative_reward"]
-
-        print(
-            f"  Step {step_num+1}: {action_type.value} | "
-            f"reward={step_result.reward.value:+.2f} | "
-            f"cumulative={final_reward:.2f}"
-        )
-
-        if done:
-            # Determine grader score from reward
-            last_reward = step_result.reward.value
-            if last_reward >= 4.9:
-                final_score = 1.0
-            elif last_reward >= 0.9:
-                final_score = 0.5
-            else:
-                final_score = 0.0
-            print(f"  → DONE | grader_score={final_score:.1f}")
            break

-    if not done:
-        print(f"  → MAX STEPS reached without submission. Score=0.0")
-
-    return {
-        "episode": episode_num,
-        "seed": seed,
-        "contract": obs["contract_name"],
-        "steps": steps,
-        "cumulative_reward": final_reward,
-        "grader_score": final_score,
-        "done": done,
-    }


-def run_task1(num_episodes: int = NUM_EPISODES) -> Dict[str, Any]:
-    """Run Task 1 and return aggregate scores."""
    print("\n" + "="*60)
    print("TASK 1: Targeted Vulnerability Detection")
    print("="*60)
-
    env = Task1Environment()
-    episodes = []
-
-    for i in range(num_episodes):
-        seed = TASK1_SEED_BASE + i
-        result = run_episode(env, seed=seed, episode_num=i + 1)
-        episodes.append(result)
-        time.sleep(0.5)  # Rate limit courtesy
-
-    scores = [e["grader_score"] for e in episodes]
-    avg = sum(scores) / len(scores) if scores else 0.0
-    avg_reward = sum(e["cumulative_reward"] for e in episodes) / len(episodes)
-
-    print(f"\n  Task 1 Results:")
-    print(f"    Episodes: {num_episodes}")
-    print(f"    Grader scores: {scores}")
-    print(f"    Average grader score: {avg:.3f}")
-    print(f"    Average cumulative reward: {avg_reward:.2f}")
-
-    return {
-        "task_id": "task1_vuln_detection",
-        "name": "Targeted Vulnerability Detection",
-        "status": "active",
-        "num_episodes": num_episodes,
-        "episodes": episodes,
-        "avg_grader_score": avg,
-        "avg_cumulative_reward": avg_reward,
-    }


-def run_task2_placeholder() -> Dict[str, Any]:
-    """Task 2 placeholder — returns 0.0 score."""
    print("\n" + "="*60)
-    print("TASK 2: Property Discovery [PLACEHOLDER — not implemented]")
    print("="*60)
-    print("  Skipping. Score: 0.0")
-    return {
-        "task_id": "task2_property_discovery",
-        "name": "Property Discovery",
-        "status": "placeholder",
-        "num_episodes": 0,
-        "episodes": [],
-        "avg_grader_score": 0.0,
-        "avg_cumulative_reward": 0.0,
-    }


def run_task3_placeholder() -> Dict[str, Any]:
-    """Task 3 placeholder — returns 0.0 score."""
    print("\n" + "="*60)
    print("TASK 3: Rule Checker [PLACEHOLDER — not implemented]")
    print("="*60)
    print("  Skipping. Score: 0.0")
-    return {
-        "task_id": "task3_rule_checker",
-        "name": "Rule Checker",
-        "status": "placeholder",
-        "num_episodes": 0,
-        "episodes": [],
-        "avg_grader_score": 0.0,
-        "avg_cumulative_reward": 0.0,
-    }


-# ---------------------------------------------------------------------------
# Main
-# ---------------------------------------------------------------------------

def main():
    print("Smart Contract Audit RL Environment — Baseline Inference")
    print(f"Model: {MODEL_NAME} | Base URL: {API_BASE_URL}")

-    results = {
-        "model": MODEL_NAME,
-        "base_url": API_BASE_URL,
-        "tasks": [],
-    }
-
-    t1 = run_task1(num_episodes=NUM_EPISODES)
-    t2 = run_task2_placeholder()
    t3 = run_task3_placeholder()

-    results["tasks"] = [t1, t2, t3]

-    # Summary
-    active_tasks = [t for t in results["tasks"] if t["status"] == "active"]
-    overall = (
-        sum(t["avg_grader_score"] for t in active_tasks) / len(active_tasks)
-        if active_tasks else 0.0
-    )
    results["overall_avg_score"] = overall

    print("\n" + "="*60)
    print("BASELINE SUMMARY")
    print("="*60)
    for t in results["tasks"]:
-        status = "✅" if t["status"] == "active" else "⏳"
-        print(f"  {status} {t['name']}: {t['avg_grader_score']:.3f}")
-    print(f"  Overall (active tasks): {overall:.3f}")

-    # Write scores file
    with open("baseline_scores.json", "w") as f:
        json.dump(results, f, indent=2)
    print("\n  Scores written to baseline_scores.json")
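Both scripts resolve their LLM connection settings from environment variables with documented defaults. A standalone sketch of that pattern; the variable names and defaults come from the docstring, while the `load_config` helper and its dict-in/dict-out shape are mine (chosen so the logic is testable without touching the real environment):

```python
import os

def load_config(env: dict) -> dict:
    """Resolve LLM client settings, falling back to the documented defaults."""
    return {
        "base_url": env.get("API_BASE_URL", "https://api.openai.com/v1"),
        "model": env.get("MODEL_NAME", "gpt-4o-mini"),
        "api_key": env.get("HF_TOKEN", ""),  # empty key → startup warning
    }

cfg = load_config(dict(os.environ))
```

Passing the mapping in explicitly, rather than reading `os.environ` inside the function, makes the fallback behaviour easy to unit-test.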
 
  inference.py
  ------------
  Baseline inference script for the Smart Contract Audit RL Environment.
+ Implements Task 1 (Vulnerability Detection) and Task 2 (Property Discovery).
+ Task 3 is a placeholder that returns 0.0.

+ Environment variables:
+     API_BASE_URL – LLM API endpoint (e.g. https://api.openai.com/v1)
+     MODEL_NAME   – model identifier (e.g. gpt-4o-mini)
+     HF_TOKEN     – API key

  Usage:
      python inference.py

  Output:
      Per-task scores printed to stdout.
      Final baseline scores written to baseline_scores.json.
+
+ Runtime: < 5 minutes on 3 episodes per task with gpt-4o-mini.
  """

  import json
  import os
  import sys
  import time
+ from typing import Any, Dict, List

  from openai import OpenAI

  from tasks.task1.environment import Task1Environment
+ from tasks.task2.environment import Task2Environment
  from env.schemas import Action, ActionType

+ # ─────────────────────────────────────────────────────────────────────────────
+ # Configuration
+ # ─────────────────────────────────────────────────────────────────────────────

  API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")

  if not HF_TOKEN:
+     print("WARNING: HF_TOKEN not set. API calls may fail.", file=sys.stderr)
+
+ MAX_STEPS_T1 = 15
+ MAX_STEPS_T2 = 10
+ NUM_EPISODES = 3
+ SEED_BASE_T1 = 42
+ SEED_BASE_T2 = 10
+
+ client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Task 1 agent
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ T1_SYSTEM = """You are an expert Solidity smart contract security auditor.
+
+ Given a contract, identify the ONE vulnerable function and its vulnerability type.
+
+ ## Actions (choose ONE per turn, respond with JSON only):
+ {"action": "list_functions", "params": {}}
+ {"action": "get_function_code", "params": {"function_name": "<name>"}}
+ {"action": "get_function_summary", "params": {"function_name": "<name>"}}
+ {"action": "get_file_metadata", "params": {}}
+ {"action": "get_state_variable", "params": {"variable_name": "<name>"}}
+ {"action": "get_call_graph", "params": {}}
+ {"action": "submit", "params": {"function_name": "<name>", "vulnerability_type": "<2-3 words>"}}
+
+ ## Strategy:
+ 1. list_functions first to see the attack surface
+ 2. Inspect suspicious functions (withdraw, drain, buy, stake, claim, setPrice, bid, finalize)
+ 3. Look for: reentrancy, missing access control, integer overflow, tx.origin, front-running,
+    timestamp dependence, denial of service, unchecked return value
+ 4. Submit when confident
+
+ Respond ONLY with valid JSON. No explanation, no markdown."""
+
+
+ def _t1_user_msg(obs: Dict[str, Any]) -> str:
+     return (
+         f"Contract: {obs['contract_name']}\n"
+         f"Description: {obs['contract_description']}\n"
+         f"Step: {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n\n"
+         f"Last action: {obs['last_action'] or 'None'}\n"
+         f"Result: {obs['last_action_result'] or 'Episode started.'}"
+     )
+
+
+ def run_t1_episode(env: Task1Environment, seed: int, ep: int) -> Dict[str, Any]:
+     r = env.reset(seed=seed)
+     obs = r.observation.model_dump()
+     print(f"  ep={ep} seed={seed} contract={obs['contract_name']}")
+
+     messages = [{"role": "system", "content": T1_SYSTEM}]
+     grader_score = 0.0
+     cum_reward = 0.0
+
+     for step in range(MAX_STEPS_T1):
+         messages.append({"role": "user", "content": _t1_user_msg(obs)})
+         try:
+             resp = client.chat.completions.create(
+                 model=MODEL_NAME, messages=messages,
+                 max_tokens=200, temperature=0.0,
+             )
+             raw = resp.choices[0].message.content.strip()
+         except Exception as e:
+             print(f"  LLM error: {e}", file=sys.stderr)
+             break
+
+         try:
+             parsed = json.loads(raw)
+             at = ActionType(parsed["action"])
+             params = parsed.get("params", {})
+         except Exception:
+             at, params = ActionType.LIST_FUNCTIONS, {}
+
+         messages.append({"role": "assistant", "content": raw})
+         result = env.step(Action(action_type=at, params=params))
+         obs = result.observation.model_dump()
+         print(f"  step {step+1:2d}: {at.value:25s} r={result.reward.value:+.2f}")
+
+         if result.done:
+             v = result.reward.value
+             grader_score = 1.0 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
+             cum_reward = obs["cumulative_reward"]
+             break
+         time.sleep(0.3)
+
+     print(f"  → grader_score={grader_score:.1f} cum_reward={cum_reward:.2f}")
+     return {"episode": ep, "seed": seed, "contract": obs["contract_name"],
+             "grader_score": grader_score, "cumulative_reward": cum_reward}
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Task 2 agent
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ T2_SYSTEM = """You are a formal methods engineer specialising in Solidity smart contracts.
+
+ You will be shown a specific Solidity function. Your task is to write a precise
+ natural-language property (invariant / postcondition) that describes what the
+ function guarantees when it succeeds.
+
+ A good property covers:
+   - What state changes (balances, counters, flags)
+   - What assets are transferred (ETH, tokens, NFTs)
+   - What return value is produced (for view functions)
+   - Under what conditions it reverts
+
+ ## Actions (respond with JSON only, ONE action per turn):
+ {"action": "get_function_code", "params": {}}
+ {"action": "get_function_natspec", "params": {}}
+ {"action": "get_file_natspec", "params": {}}
+ {"action": "get_related_functions", "params": {}}
+ {"action": "get_io", "params": {}}
+ {"action": "get_similar_rule", "params": {}}
+ {"action": "submit_property", "params": {"property": "<your full property text>"}}
+
+ ## Rules:
+ - You have ONE submit_property attempt. Make it count.
+ - Use get_function_natspec and get_io first — they give the most signal.
+ - get_similar_rule costs more (-0.20) but shows a parallel property from another contract.
+ - Write 2–4 sentences. Be specific about variable names and amounts.
+ - Do NOT guess — read the code first.
+
+ Respond ONLY with valid JSON. No markdown, no explanation."""
+
+
+ def _t2_user_msg(obs: Dict[str, Any]) -> str:
+     extra = obs.get("extra", {})
+     return (
+         f"Contract : {obs['contract_name']}\n"
+         f"Function : {extra.get('target_function', '?')} "
+         f"({extra.get('target_signature', '')})\n"
+         f"Step: {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n\n"
+         f"Last action: {obs['last_action'] or 'None'}\n"
+         f"Result:\n{obs['last_action_result'] or 'Episode started — begin exploring.'}"
+     )
+
+
+ def run_t2_episode(env: Task2Environment, seed: int, ep: int) -> Dict[str, Any]:
+     r = env.reset(seed=seed)
+     obs = r.observation.model_dump()
+     fn = obs["extra"].get("target_function", "?")
+     print(f"  ep={ep} seed={seed} {obs['contract_name']}.{fn}()")
+
+     messages = [{"role": "system", "content": T2_SYSTEM}]
+     grader_score = 0.0
+     cum_reward = 0.0
+
+     for step in range(MAX_STEPS_T2):
+         messages.append({"role": "user", "content": _t2_user_msg(obs)})
          try:
+             resp = client.chat.completions.create(
+                 model=MODEL_NAME, messages=messages,
+                 max_tokens=400, temperature=0.0,
              )
+             raw = resp.choices[0].message.content.strip()
          except Exception as e:
+             print(f"  LLM error: {e}", file=sys.stderr)
              break

          try:
              parsed = json.loads(raw)
+             at = ActionType(parsed["action"])
              params = parsed.get("params", {})
+         except Exception:
+             at, params = ActionType.GET_FUNCTION_CODE, {}

          messages.append({"role": "assistant", "content": raw})
+         result = env.step(Action(action_type=at, params=params))
+         obs = result.observation.model_dump()
+         r_val = result.reward.value
+         print(f"  step {step+1:2d}: {at.value:25s} r={r_val:+.2f}")
+
+         if result.done:
+             grader_score = round(r_val / 5.0, 3) if r_val > 0 else 0.0
+             cum_reward = obs["cumulative_reward"]
              break
+         time.sleep(0.3)
+
+     print(f"  → grader_score={grader_score:.3f} cum_reward={cum_reward:.2f}")
+     return {"episode": ep, "seed": seed,
+             "contract": obs["contract_name"], "function": fn,
+             "grader_score": grader_score, "cumulative_reward": cum_reward}
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Task runners
+ # ─────────────────────────────────────────────────────────────────────────────

+ def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
      print("\n" + "="*60)
      print("TASK 1: Targeted Vulnerability Detection")
      print("="*60)
      env = Task1Environment()
+     episodes = [run_t1_episode(env, SEED_BASE_T1 + i, i+1) for i in range(n)]
+     avg_s = sum(e["grader_score"] for e in episodes) / n
+     avg_r = sum(e["cumulative_reward"] for e in episodes) / n
+     print(f"\n  Avg grader score : {avg_s:.3f}")
+     print(f"  Avg cum reward   : {avg_r:.2f}")
+     return {"task_id": "task1_vuln_detection", "name": "Targeted Vulnerability Detection",
+             "status": "active", "num_episodes": n, "episodes": episodes,
+             "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}


+ def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
      print("\n" + "="*60)
+     print("TASK 2: Property Discovery")
      print("="*60)
+     env = Task2Environment()
+     episodes = [run_t2_episode(env, SEED_BASE_T2 + i, i+1) for i in range(n)]
+     avg_s = sum(e["grader_score"] for e in episodes) / n
+     avg_r = sum(e["cumulative_reward"] for e in episodes) / n
+     print(f"\n  Avg grader score : {avg_s:.3f}")
+     print(f"  Avg cum reward   : {avg_r:.2f}")
+     return {"task_id": "task2_property_discovery", "name": "Property Discovery",
+             "status": "active", "num_episodes": n, "episodes": episodes,
+             "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}


  def run_task3_placeholder() -> Dict[str, Any]:
      print("\n" + "="*60)
      print("TASK 3: Rule Checker [PLACEHOLDER — not implemented]")
      print("="*60)
      print("  Skipping. Score: 0.0")
+     return {"task_id": "task3_rule_checker", "name": "Rule Checker",
+             "status": "placeholder", "num_episodes": 0, "episodes": [],
+             "avg_grader_score": 0.0, "avg_cumulative_reward": 0.0}


+ # ─────────────────────────────────────────────────────────────────────────────
  # Main
+ # ─────────────────────────────────────────────────────────────────────────────

  def main():
      print("Smart Contract Audit RL Environment — Baseline Inference")
      print(f"Model: {MODEL_NAME} | Base URL: {API_BASE_URL}")

+     t1 = run_task1(NUM_EPISODES)
+     t2 = run_task2(NUM_EPISODES)
      t3 = run_task3_placeholder()

+     results = {
+         "model": MODEL_NAME, "base_url": API_BASE_URL,
+         "tasks": [t1, t2, t3],
+     }

+     active = [t for t in results["tasks"] if t["status"] == "active"]
+     overall = sum(t["avg_grader_score"] for t in active) / len(active) if active else 0.0
      results["overall_avg_score"] = overall

      print("\n" + "="*60)
      print("BASELINE SUMMARY")
      print("="*60)
      for t in results["tasks"]:
+         icon = "✅" if t["status"] == "active" else "⏳"
+         print(f"  {icon} {t['name']:40s}: {t['avg_grader_score']:.3f}")
+     print(f"\n  Overall (active tasks): {overall:.3f}")

      with open("baseline_scores.json", "w") as f:
          json.dump(results, f, indent=2)
      print("\n  Scores written to baseline_scores.json")
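The two episode runners above collapse raw terminal rewards into normalised grader scores in slightly different ways: Task 1 maps the terminal reward onto a discrete {0.0, 0.5, 1.0} scale, while Task 2 divides its 0–5 reward by 5. A minimal sketch of both mappings (the helper names are illustrative, not part of the repo):

```python
def t1_grader_score(terminal_reward: float) -> float:
    """Correct submission (+5.0) -> 1.0, partial (+1.0) -> 0.5, otherwise 0.0."""
    if terminal_reward >= 4.9:
        return 1.0
    if terminal_reward >= 0.9:
        return 0.5
    return 0.0


def t2_grader_score(terminal_reward: float) -> float:
    """Normalise Task 2's 0-5 keyword-weighted terminal reward into [0, 1]."""
    return round(terminal_reward / 5.0, 3) if terminal_reward > 0 else 0.0


print(t1_grader_score(5.0), t1_grader_score(1.0), t1_grader_score(-1.5))  # 1.0 0.5 0.0
print(t2_grader_score(3.5))  # 0.7
```

The `>= 4.9` / `>= 0.9` thresholds deliberately sit just below the exact terminal rewards (+5.0 / +1.0) so that floating-point accumulation cannot drop a correct submission into the lower bucket.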
openenv.yaml CHANGED
@@ -1,25 +1,22 @@
  name: smart-contract-audit-env
- version: "1.0.0"
  description: >
    Reinforcement learning environment for smart contract security analysis.
    Agents interact with real-world Solidity contract data from Certora-audited
-   projects, learning to detect vulnerabilities, discover properties, and
-   verify rule compliance — tasks that professional auditors perform daily.

  author: "SmartAudit Team"
  license: MIT

- # ---------------------------------------------------------------------------
- # Tasks
- # ---------------------------------------------------------------------------
  tasks:
    - id: task1_vuln_detection
      name: Targeted Vulnerability Detection
      difficulty: medium
      status: active
      description: >
-       Given a Solidity contract (4–6 functions), identify the single vulnerable
-       function and describe its vulnerability type in 2–3 words.
      max_steps: 20
      reward_range: [-10.0, 10.0]
      grader: tasks/task1/grader.py
@@ -28,13 +25,13 @@ tasks:
    - id: task2_property_discovery
      name: Property Discovery
      difficulty: hard
-     status: placeholder
      description: >
        Given a single Solidity function with known properties, discover the
-       correct natural-language property describing its expected behaviour.
      max_steps: 15
      reward_range: [-5.0, 5.0]
-     grader: tasks/task2/grader.py  # TODO: implement
      grader_score_range: [0.0, 1.0]

    - id: task3_rule_checker
@@ -46,81 +43,52 @@
      function that violates that property.
      max_steps: 15
      reward_range: [-5.0, 5.0]
-     grader: tasks/task3/grader.py  # TODO: implement
      grader_score_range: [0.0, 1.0]

- # ---------------------------------------------------------------------------
- # Observation space
- # ---------------------------------------------------------------------------
  observation_space:
    type: object
    properties:
-     task_id:
-       type: string
-       description: Active task identifier
-     contract_name:
-       type: string
-       description: Name of the Solidity contract
-     contract_description:
-       type: string
-       description: Human-readable description of what the contract does
-     available_actions:
-       type: array
-       items:
-         type: string
-       description: List of valid action type strings
-     last_action:
-       type: string
-       nullable: true
-       description: The action type that produced this observation
-     last_action_result:
-       type: string
-       nullable: true
-       description: Human-readable result of the last action
-     step_count:
-       type: integer
-       description: Number of steps taken in this episode
-     cumulative_reward:
-       type: number
-       description: Running reward total for this episode
-     done:
-       type: boolean
-       description: True when the episode has ended
-     extra:
-       type: object
-       description: Task-specific hints and auxiliary data

- # ---------------------------------------------------------------------------
- # Action space (Task 1)
- # ---------------------------------------------------------------------------
  action_space:
-   type: object
-   description: Named action with optional parameters
-   properties:
-     action_type:
-       type: string
-       enum:
-         - list_functions
-         - get_function_code
-         - get_function_summary
-         - get_file_metadata
-         - get_state_variable
-         - get_call_graph
-         - submit
-     params:
-       type: object
-       description: Key-value arguments for the action

- # ---------------------------------------------------------------------------
- # Reward function
- # ---------------------------------------------------------------------------
  reward:
    type: shaped
    description: >
-     Per-step costs encourage efficient exploration. A positive signal is given
-     when the agent accesses the actual vulnerable function. Terminal rewards
-     reflect submission accuracy (0 → 1 grader score).
-   shaping:
      list_functions: -0.05
      get_function_code_wrong: -0.10
      get_function_code_correct: +0.05
@@ -130,19 +98,28 @@ reward:
      get_state_variable: -0.05
      get_call_graph: -0.08
      repeated_query: -0.40
-   terminal:
      correct_submission: +5.0
      partial_submission: +1.0
      wrong_submission: -1.5

- # ---------------------------------------------------------------------------
- # Data
- # ---------------------------------------------------------------------------
  data:
-   source: "Certora audited projects (Aave, Compound-style protocols)"
    format: JSON
    num_contracts: 4
    num_vulnerable_functions: 8
    vulnerability_types:
      - Reentrancy
      - Missing access control
@@ -153,17 +130,16 @@ data:
      - Denial of service (unbounded loop)
      - Unchecked return value

- # ---------------------------------------------------------------------------
- # Interface
- # ---------------------------------------------------------------------------
  interface:
    http:
-     reset: POST /reset
-     step: POST /step
-     state: GET /state
-     tasks: GET /tasks
-     health: GET /health
    python:
-     reset: env.reset(seed=None) -> ResetResult
-     step: env.step(action) -> StepResult
-     state: env.state() -> StateResult
  name: smart-contract-audit-env
+ version: "1.1.0"
  description: >
    Reinforcement learning environment for smart contract security analysis.
    Agents interact with real-world Solidity contract data from Certora-audited
+   projects, learning to detect vulnerabilities and discover correctness
+   properties — tasks that professional auditors perform daily.

  author: "SmartAudit Team"
  license: MIT

  tasks:
    - id: task1_vuln_detection
      name: Targeted Vulnerability Detection
      difficulty: medium
      status: active
      description: >
+       Given a Solidity contract (4-6 functions), identify the single vulnerable
+       function and describe its vulnerability type in 2-3 words.
      max_steps: 20
      reward_range: [-10.0, 10.0]
      grader: tasks/task1/grader.py

    - id: task2_property_discovery
      name: Property Discovery
      difficulty: hard
+     status: active
      description: >
        Given a single Solidity function with known properties, discover the
+       correct natural-language postcondition describing its correct behaviour.
      max_steps: 15
      reward_range: [-5.0, 5.0]
+     grader: tasks/task2/grader.py
      grader_score_range: [0.0, 1.0]

    - id: task3_rule_checker
      function that violates that property.
      max_steps: 15
      reward_range: [-5.0, 5.0]
+     grader: tasks/task3/grader.py
      grader_score_range: [0.0, 1.0]

  observation_space:
    type: object
    properties:
+     task_id: {type: string, description: Active task identifier}
+     contract_name: {type: string, description: Solidity contract name}
+     contract_description: {type: string, description: Human-readable contract description}
+     available_actions: {type: array, items: {type: string}, description: Valid action types}
+     last_action: {type: string, nullable: true}
+     last_action_result: {type: string, nullable: true}
+     step_count: {type: integer}
+     cumulative_reward: {type: number}
+     done: {type: boolean}
+     extra: {type: object, description: Task-specific hints}

  action_space:
+   task1:
+     type: object
+     actions:
+       list_functions: {params: {}, reward: -0.05}
+       get_function_code: {params: {function_name: string}, reward: "+0.05 / -0.10"}
+       get_function_summary: {params: {function_name: string}, reward: "+0.03 / -0.05"}
+       get_file_metadata: {params: {}, reward: -0.04}
+       get_state_variable: {params: {variable_name: "string (opt)"}, reward: -0.05}
+       get_call_graph: {params: {}, reward: -0.08}
+       submit: {params: {function_name: str, vulnerability_type: str}, reward: "+5.0 / +1.0 / -1.5"}
+   task2:
+     type: object
+     actions:
+       get_function_code: {params: {}, reward: -0.06}
+       get_function_natspec: {params: {}, reward: -0.08}
+       get_file_natspec: {params: {}, reward: -0.03}
+       get_related_functions: {params: {}, reward: -0.06}
+       get_io: {params: {}, reward: -0.04}
+       get_similar_rule: {params: {}, reward: -0.20}
+       submit_property: {params: {property: string}, reward: "0.0–5.0 (keyword-weighted)"}

  reward:
    type: shaped
    description: >
+     Per-step costs encourage efficient exploration. Positive shaping rewards
+     fire when the agent inspects the actual target. Terminal rewards reflect
+     grader score accuracy.
+   task1_shaping:
      list_functions: -0.05
      get_function_code_wrong: -0.10
      get_function_code_correct: +0.05
      get_state_variable: -0.05
      get_call_graph: -0.08
      repeated_query: -0.40
+   task1_terminal:
      correct_submission: +5.0
      partial_submission: +1.0
      wrong_submission: -1.5
+   task2_shaping:
+     get_function_code: -0.06
+     get_function_natspec: -0.08
+     get_file_natspec: -0.03
+     get_related_functions: -0.06
+     get_io: -0.04
+     get_similar_rule: -0.20
+     repeated_query: -0.40
+   task2_terminal:
+     score_range: [0.0, 5.0]
+     formula: "score * 5.0 where score = 0.70*(key_matches/total_key) + 0.30*(bonus_matches/total_bonus)"

  data:
+   source: "Certora audited DeFi projects"
    format: JSON
    num_contracts: 4
    num_vulnerable_functions: 8
+   num_property_functions: 11
    vulnerability_types:
      - Reentrancy
      - Missing access control
      - Denial of service (unbounded loop)
      - Unchecked return value

  interface:
    http:
+     reset: POST /reset
+     step: POST /step
+     state: GET /state
+     tasks: GET /tasks
+     health: GET /health
+     action_space: GET /action_space?task_id=<id>
+     observation_space: GET /observation_space
    python:
+     reset: env.reset(seed=None) -> ResetResult
+     step: env.step(action) -> StepResult
+     state: env.state() -> StateResult
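The `task2_terminal` formula declared above can be written out directly. A sketch under the stated weights (the function name and count arguments are illustrative stand-ins for whatever keyword counts the grader produces):

```python
def task2_terminal_reward(key_matches: int, total_key: int,
                          bonus_matches: int, total_bonus: int) -> float:
    """Terminal reward = 5.0 * (0.70 * key fraction + 0.30 * bonus fraction)."""
    # Guard against empty keyword sets so the formula never divides by zero.
    key_frac = key_matches / total_key if total_key else 0.0
    bonus_frac = bonus_matches / total_bonus if total_bonus else 0.0
    score = 0.70 * key_frac + 0.30 * bonus_frac
    return score * 5.0
```

With these weights, matching every key phrase but no bonus phrases caps the reward at 3.5 of 5.0, which is why the baseline's normalised grader score for a keywords-only property tops out at 0.7.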
tasks/task2/__init__.py CHANGED
@@ -1,27 +1,5 @@
- """
- tasks/task2/__init__.py
- -----------------------
- Task 2: Property Discovery (PLACEHOLDER)

- TODO: Implement this task.
-
- Episode setup:
-   - One function from a Solidity file with known properties
-   - Agent must discover the natural-language property of the function
-
- Actions (to implement):
-   - get_similar_rule      : -0.20
-   - get_file_natspec      : -0.03
-   - get_function_natspec  : -0.08
-   - get_function_code     : -0.06
-   - get_related_functions : -0.06
-   - get_io                : -0.04
-   - submit_property       : scored 0.0–5.0 by semantic similarity grader
-
- See README.md for full task specification.
- """
-
- # TODO: Task 2 – Property Discovery
- # from tasks.task2.environment import Task2Environment
-
- __all__: list = []

+ # Task 2: Property Discovery
+ from tasks.task2.environment import Task2Environment
+ from tasks.task2.grader import Task2Grader
+
+ __all__ = ["Task2Environment", "Task2Grader"]
tasks/task2/environment.py ADDED
@@ -0,0 +1,340 @@
+ """
+ environment.py (Task 2 – Property Discovery)
+ --------------------------------------------
+ OpenEnv-compliant RL environment.
+
+ Episode setup:
+     - One function from a Solidity contract that has a known property.
+     - The agent sees: contract description + function name + function signature.
+     - The agent must discover the natural-language property of the function.
+
+ Actions & rewards:
+     get_function_code      -0.06  (full source — always useful context)
+     get_function_natspec   -0.08  (strongest hint — natspec has param/return docs)
+     get_file_natspec       -0.03  (broad contract-level context)
+     get_related_functions  -0.06  (shows callers/callees)
+     get_io                 -0.04  (structured input/output description)
+     get_similar_rule       -0.20  (shows a similar property from another contract)
+     submit_property        scored 0–5 (ONE attempt, ends episode)
+     repeated_query         -0.40
+
+ Episode ends when:
+     - submit_property is called (scored), OR
+     - max_steps is reached without submission (reward = -1.0)
+ """
+
+ from __future__ import annotations
+
+ import random
+ from typing import Any, Dict, List, Optional, Set
+
+ from data.data_loader import (
+     load_contracts,
+     sample_property_episode,
+     get_function_by_name,
+     get_related_functions,
+     get_similar_rule,
+ )
+ from env.base_env import BaseEnv
+ from env.schemas import (
+     Action,
+     ActionType,
+     Observation,
+     Reward,
+     ResetResult,
+     StateResult,
+     StepResult,
+ )
+ from tasks.task2.grader import Task2Grader
+
+ TASK_ID = "task2_property_discovery"
+ MAX_STEPS = 15
+
+ AVAILABLE_ACTIONS = [
+     ActionType.GET_FUNCTION_CODE,
+     ActionType.GET_FUNCTION_NATSPEC,
+     ActionType.GET_FILE_NATSPEC,
+     ActionType.GET_RELATED_FUNCTIONS,
+     ActionType.GET_IO,
+     ActionType.GET_SIMILAR_RULE,
+     ActionType.SUBMIT_PROPERTY,
+ ]
+
+
+ class Task2Environment(BaseEnv):
+     """Task 2: Property Discovery."""
+
+     def __init__(self, contracts_path: Optional[str] = None) -> None:
+         self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
+         self._rng = random.Random()
+
+         # Episode state – initialised by reset()
+         self._contract: Dict[str, Any] = {}
+         self._target_fn: Dict[str, Any] = {}
+         self._grader: Optional[Task2Grader] = None
+         self._step_count: int = 0
+         self._cum_reward: float = 0.0
+         self._done: bool = False
+         self._submitted: bool = False  # only one submit_property allowed
+         self._query_hist: List[str] = []
+         self._seen: Set[str] = set()
+
+     # ── OpenEnv interface ────────────────────────────────────────────────────
+
+     def reset(self, seed: Optional[int] = None) -> ResetResult:
+         if seed is not None:
+             self._rng.seed(seed)
+
+         self._contract, self._target_fn = sample_property_episode(
+             self._contracts, self._rng
+         )
+         self._grader = Task2Grader(
+             function_name=self._target_fn["name"],
+             property_data=self._target_fn["property"],
+         )
+         self._step_count = 0
+         self._cum_reward = 0.0
+         self._done = False
+         self._submitted = False
+         self._query_hist = []
+         self._seen = set()
+
+         obs = self._build_obs(
+             last_action=None,
+             last_result=(
+                 f"New episode started.\n"
+                 f"Contract : {self._contract['contract_name']}\n"
+                 f"Function : {self._target_fn['name']} "
+                 f"({self._target_fn.get('signature', '')})\n"
+                 f"Your task : Discover the natural-language property of "
+                 f"'{self._target_fn['name']}' and submit it with submit_property."
+             ),
+         )
+         return ResetResult(observation=obs, info={"task_id": TASK_ID})
+
+     def step(self, action: Action) -> StepResult:
+         if self._done:
+             raise RuntimeError("Episode is done. Call reset() to start a new episode.")
+
+         self._step_count += 1
+         result_text, reward = self._dispatch(action)
+         self._cum_reward += reward.value
+         self._query_hist.append(f"[{action.action_type}] → {result_text[:100]}")
+
+         obs = self._build_obs(
+             last_action=action.action_type,
+             last_result=result_text,
+         )
+         return StepResult(
+             observation=obs,
+             reward=reward,
+             done=self._done,
+             info={
+                 "step": self._step_count,
+                 "cumulative_reward": self._cum_reward,
+             },
+         )
+
+     def state(self) -> StateResult:
+         return StateResult(
+             task_id=TASK_ID,
+             contract_name=self._contract.get("contract_name", ""),
+             target_function=self._target_fn.get("name"),
+             step_count=self._step_count,
+             cumulative_reward=self._cum_reward,
+             done=self._done,
+             query_history=list(self._query_hist),
+         )
+
+     # ── Internal helpers ─────────────────────────────────────────────────────
+
+     def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation:
+         return Observation(
+             task_id=TASK_ID,
+             contract_name=self._contract.get("contract_name", ""),
+             contract_description=self._contract.get("metadata", {}).get("description", ""),
+             available_actions=[a.value for a in AVAILABLE_ACTIONS],
+             last_action=last_action,
+             last_action_result=last_result,
+             step_count=self._step_count,
+             cumulative_reward=self._cum_reward,
+             done=self._done,
+             extra={
+                 "target_function": self._target_fn.get("name", ""),
+                 "target_signature": self._target_fn.get("signature", ""),
+                 "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
+                 "hint": (
+                     "Discover the property of the target function. "
+                     "Use get_function_code, get_function_natspec, or get_similar_rule for hints. "
+                     "Submit with submit_property, params={'property': '<your property text>'}. "
+                     "ONE submission attempt only."
+                 ),
+             },
+         )
+
+     def _qkey(self, at: str, params: Dict[str, Any]) -> str:
+         return f"{at}:{sorted(params.items())}"
+
+     def _is_repeated(self, key: str) -> bool:
+         if key in self._seen:
+             return True
+         self._seen.add(key)
+         return False
+
+     def _dispatch(self, action: Action) -> tuple[str, Reward]:
+         at = action.action_type
+         params = action.params
+         qkey = self._qkey(at, params)
+         fn = self._target_fn
+         name = fn["name"]
+
+         # ── get_function_code ────────────────────────────────────────────────
+         if at == ActionType.GET_FUNCTION_CODE:
+             if self._is_repeated(qkey):
+                 return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+             code = fn.get("code", "// no code available")
+             return (
+                 f"// {name}\n{code}",
+                 Reward(value=-0.06, reason="get_function_code cost"),
+             )
+
+         # ── get_function_natspec ─────────────────────────────────────────────
+         if at == ActionType.GET_FUNCTION_NATSPEC:
+             if self._is_repeated(qkey):
+                 return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+             natspec = fn.get("natspec") or fn.get("comment") or "No NatSpec available."
+             # Also include output_property if present
+             out_prop = fn.get("output_property", "")
+             result = f"NatSpec for '{name}':\n{natspec}"
+             if out_prop:
+                 result += f"\n\nExpected output: {out_prop}"
+             return result, Reward(value=-0.08, reason="get_function_natspec cost")
+
+         # ── get_file_natspec ─────────────────────────────────────────────────
+         if at == ActionType.GET_FILE_NATSPEC:
+             if self._is_repeated(qkey):
+                 return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+             meta = self._contract.get("metadata", {})
+             natspec = meta.get("natspec") or meta.get("description", "No file NatSpec available.")
+             return (
+                 f"File NatSpec for {self._contract['contract_name']}:\n{natspec}",
+                 Reward(value=-0.03, reason="get_file_natspec cost"),
+             )
+
+         # ── get_related_functions ────────────────────────────────────────────
+         if at == ActionType.GET_RELATED_FUNCTIONS:
+             if self._is_repeated(qkey):
+                 return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+             related = get_related_functions(self._contract, name)
+             if not related:
+                 text = f"No related functions found for '{name}'."
+             else:
+                 summaries = []
+                 for rn in related:
+                     rfn = get_function_by_name(self._contract, rn)
+                     if rfn:
+                         sig = rfn.get("signature", rn)
+                         comment = rfn.get("comment", "")
+                         summaries.append(f"  • {sig} — {comment}")
+                 text = f"Related functions for '{name}':\n" + "\n".join(summaries)
+             return text, Reward(value=-0.06, reason="get_related_functions cost")
+
+         # ── get_io ───────────────────────────────────────────────────────────
+         if at == ActionType.GET_IO:
+             if self._is_repeated(qkey):
+                 return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+             params_list = fn.get("parameters", [])
+             returns = fn.get("returns", "") or "void"
+             out_prop = fn.get("output_property", "")
+             visibility = fn.get("visibility", "")
+             modifiers = fn.get("modifiers", [])
+
+             lines = [f"Function: {fn.get('signature', name)}"]
+             lines.append(f"Visibility: {visibility}" + (f"  Modifiers: {', '.join(modifiers)}" if modifiers else ""))
+             if params_list:
+                 lines.append("Parameters:")
+                 for p in params_list:
+                     lines.append(f"  • {p['type']} {p['name']}: {p.get('description', '')}")
+             else:
+                 lines.append("Parameters: none (payable)" if "payable" in fn.get("code", "") else "Parameters: none")
+             lines.append(f"Returns: {returns}")
+             if out_prop:
+                 lines.append(f"Expected behaviour: {out_prop}")
+             return "\n".join(lines), Reward(value=-0.04, reason="get_io cost")
+
+         # ── get_similar_rule ─────────────────────────────────────────────────
+         if at == ActionType.GET_SIMILAR_RULE:
+             if self._is_repeated(qkey):
+                 return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+             sr = get_similar_rule(
+                 self._contracts,
+                 self._contract["contract_name"],
+                 name,
+             )
+             if sr is None:
275
+ return (
276
+ "No similar rule available for this function.",
277
+ Reward(value=-0.20, reason="get_similar_rule cost (not found)"),
278
+ )
279
+ lines = [
280
+ f"Similar property from {sr['contract_name']}.{sr['function_name']}():",
281
+ f" {sr['property_hint']}",
282
+ ]
283
+ if sr.get("natspec"):
284
+ lines.append(f"\nFunction NatSpec:\n {sr['natspec']}")
285
+ return "\n".join(lines), Reward(value=-0.20, reason="get_similar_rule cost")
286
+
287
+ # ── submit_property ──────────────────────────────────────────────────
288
+ if at == ActionType.SUBMIT_PROPERTY:
289
+ if self._submitted:
290
+ return (
291
+ "❌ You have already submitted a property for this episode. "
292
+ "Only one submission is allowed.",
293
+ Reward(value=-1.0, reason="Second submit_property attempt", partial=False),
294
+ )
295
+ submitted_text = params.get("property", "").strip()
296
+ if not submitted_text:
297
+ return (
298
+ "Submit requires 'property' key in params with a non-empty string.",
299
+ Reward(value=-0.5, reason="Empty property submission"),
300
+ )
301
+
302
+ self._submitted = True
303
+ self._done = True
304
+
305
+ score = self._grader.grade(submitted_text)
306
+ reward = self._grader.reward_for_score(score)
307
+ bd = self._grader.breakdown(submitted_text)
308
+
309
+ pct = int(score * 100)
310
+ if score >= 0.85:
311
+ emoji = "βœ…"
312
+ label = "EXCELLENT"
313
+ elif score >= 0.60:
314
+ emoji = "🟑"
315
+ label = "GOOD"
316
+ elif score >= 0.35:
317
+ emoji = "🟠"
318
+ label = "PARTIAL"
319
+ else:
320
+ emoji = "❌"
321
+ label = "POOR"
322
+
323
+ msg = (
324
+ f"{emoji} {label} β€” Score: {score:.2f}/1.00 β†’ Reward: {reward:.2f}/5.00 ({pct}%)\n"
325
+ f"Key concepts matched : {len(bd['key_matched'])}/{len(bd['key_matched'])+len(bd['key_missed'])} "
326
+ f"{bd['key_matched']}\n"
327
+ f"Bonus concepts matched : {len(bd['bonus_matched'])}/{len(bd['bonus_matched'])+len(bd['bonus_missed'])} "
328
+ f"{bd['bonus_matched']}"
329
+ )
330
+ return msg, Reward(
331
+ value=reward,
332
+ reason=f"Property submission score={score:.3f}",
333
+ partial=False,
334
+ )
335
+
336
+ # ── unknown action ────────────────────────────────────────────────────
337
+ return (
338
+ f"Unknown action type: '{at}'. Valid: {[a.value for a in AVAILABLE_ACTIONS]}",
339
+ Reward(value=-0.10, reason="Unknown action"),
340
+ )
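The `_dispatch` hunk above fixes the episode economics: each browse action pays a small negative shaping reward, any repeated query pays -0.40, and the single `submit_property` pays up to +5.0. A minimal standalone sketch of that accounting, with the per-action values copied from the hunk (the `episode_reward` helper is illustrative only, not part of the environment API):

```python
# Per-query shaping costs, as wired into _dispatch above.
QUERY_COST = {
    "get_function_code": -0.06,
    "get_function_natspec": -0.08,
    "get_file_natspec": -0.03,
    "get_related_functions": -0.06,
    "get_io": -0.04,
    "get_similar_rule": -0.20,
}
REPEAT_PENALTY = -0.40  # charged instead of the normal cost on a repeat


def episode_reward(queries, final_score):
    """Cumulative reward: query costs plus the terminal score x 5.0."""
    seen, total = set(), 0.0
    for q in queries:
        total += REPEAT_PENALTY if q in seen else QUERY_COST[q]
        seen.add(q)
    return total + final_score * 5.0


# Two cheap queries, one wasteful repeat, then a 0.85-scoring submission.
print(round(episode_reward(["get_io", "get_function_code", "get_function_code"], 0.85), 2))  # → 3.75
```

The repeat penalty dominates the honest query costs, so an agent that re-reads the same view loses more than it would by simply remembering the first answer.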
tasks/task2/grader.py ADDED
@@ -0,0 +1,171 @@
+"""
+grader.py (Task 2 – Property Discovery)
+-----------------------------------------
+Deterministic scorer for natural-language property submissions.
+
+Score formula
+─────────────
+    key_phrases   weight = 0.70
+    bonus_phrases weight = 0.30
+
+    score = 0.70 * (matched_key / total_key)
+          + 0.30 * (matched_bonus / total_bonus)
+
+Phrase matching
+───────────────
+A phrase is considered matched if ALL its words (after normalisation)
+appear in the submitted text. This is intentionally lenient — it
+doesn't require the words to be adjacent, so "balance increases by
+msg.value" is matched by "the caller's vault balance increases by the
+sent msg.value amount".
+
+Synonym expansion allows common paraphrases to also match
+(e.g. "caller" → "msg.sender", "sender", "user").
+
+Terminal reward = score × 5.0 (range: 0.0 – 5.0)
+One submission attempt per episode.
+"""
+
+from __future__ import annotations
+
+import string
+from typing import Dict, List, Optional
+
+
+# ── Text normalisation ────────────────────────────────────────────────────────
+
+_PUNCT = str.maketrans("", "", string.punctuation)
+
+
+def _norm(text: str) -> str:
+    """Lowercase, strip punctuation, collapse whitespace → word string."""
+    return " ".join(text.lower().translate(_PUNCT).split())
+
+
+def _word_set(text: str) -> set:
+    """Return normalised set of words."""
+    return set(_norm(text).split())
+
+
+# ── Synonym table ─────────────────────────────────────────────────────────────
+# Maps canonical word → list of accepted synonyms (all lowercase, no punct).
+
+SYNONYMS: Dict[str, List[str]] = {
+    "caller": ["caller", "sender", "user", "msgsender", "msg sender"],
+    "balance": ["balance", "holdings", "amount held"],
+    "increases": ["increases", "incremented", "added", "grows", "rise"],
+    "decreases": ["decreases", "decremented", "reduced", "subtracted", "falls"],
+    "transfers": ["transfers", "sends", "moved", "forwarded", "sent"],
+    "reverts": ["reverts", "fails", "rejected", "throws", "require"],
+    "zero": ["zero", "0", "nothing", "none", "empty"],
+    "owner": ["owner", "admin", "authorized"],
+    "entire": ["entire", "full", "whole", "all", "total"],
+    "returned": ["returned", "sent back", "refunded", "transferred back"],
+    "reset": ["reset", "zeroed", "set to zero", "cleared"],
+    "only": ["only", "exclusively", "restricted"],
+    "price": ["price", "cost", "rate"],
+    "tokens": ["tokens", "token amount"],
+    "rewards": ["rewards", "reward tokens", "accrued"],
+    "staked": ["staked", "deposited", "locked"],
+    "winner": ["winner", "winning bidder", "successful bidder"],
+}
+
+
+def _expand_words(phrase_words: List[str]) -> List[List[str]]:
+    """
+    For each word in the phrase, generate synonym variants.
+    Returns a list of word-list variants to try.
+    Only substitutes ONE word at a time to avoid combinatorial explosion.
+    """
+    variants = [phrase_words]  # original
+    for i, word in enumerate(phrase_words):
+        if word in SYNONYMS:
+            for syn in SYNONYMS[word]:
+                syn_words = _norm(syn).split()
+                new_variant = phrase_words[:i] + syn_words + phrase_words[i + 1:]
+                variants.append(new_variant)
+    return variants
+
+
+def _phrase_matched(text_words: set, phrase: str) -> bool:
+    """
+    True if ALL words in the phrase (or a synonym variant) appear in text_words.
+    Uses word-set containment, not substring adjacency.
+    """
+    norm_words = _norm(phrase).split()
+    for variant_words in _expand_words(norm_words):
+        if all(w in text_words for w in variant_words):
+            return True
+    return False
+
+
+# ── Grader ────────────────────────────────────────────────────────────────────
+
+class Task2Grader:
+    """
+    Grades a Task 2 property submission.
+
+    Parameters
+    ----------
+    function_name : name of the target function
+    property_data : the 'property' dict from the dataset
+                    Must have: natural_language, key_phrases, bonus_phrases
+    """
+
+    KEY_WEIGHT = 0.70
+    BONUS_WEIGHT = 0.30
+
+    def __init__(self, function_name: str, property_data: Dict) -> None:
+        self.function_name = function_name
+        self.natural_language = property_data.get("natural_language", "")
+        self.key_phrases = property_data.get("key_phrases", [])
+        self.bonus_phrases = property_data.get("bonus_phrases", [])
+
+    # ── Public API ────────────────────────────────────────────────────────────
+
+    def grade(self, submitted: str) -> float:
+        """Deterministic score in [0.0, 1.0]."""
+        if not submitted or not submitted.strip():
+            return 0.0
+        tw = _word_set(submitted)
+        key_score = self._phrase_score(tw, self.key_phrases)
+        bonus_score = self._phrase_score(tw, self.bonus_phrases)
+        raw = self.KEY_WEIGHT * key_score + self.BONUS_WEIGHT * bonus_score
+        return round(min(max(raw, 0.0), 1.0), 4)
+
+    def reward_for_score(self, score: float) -> float:
+        """Maps [0.0, 1.0] → [0.0, 5.0]."""
+        return round(score * 5.0, 4)
+
+    def breakdown(self, submitted: str) -> Dict:
+        """Detailed scoring breakdown for debugging."""
+        tw = _word_set(submitted)
+        key_hits = [p for p in self.key_phrases if _phrase_matched(tw, p)]
+        bonus_hits = [p for p in self.bonus_phrases if _phrase_matched(tw, p)]
+        score = self.grade(submitted)
+        return {
+            "score": score,
+            "reward": self.reward_for_score(score),
+            "key_matched": key_hits,
+            "key_missed": [p for p in self.key_phrases if p not in key_hits],
+            "bonus_matched": bonus_hits,
+            "bonus_missed": [p for p in self.bonus_phrases if p not in bonus_hits],
+            "key_score": self._phrase_score(tw, self.key_phrases),
+            "bonus_score": self._phrase_score(tw, self.bonus_phrases),
+        }
+
+    def get_canonical_answer(self) -> Dict:
+        """For debugging / logging only."""
+        return {
+            "function": self.function_name,
+            "natural_language": self.natural_language,
+            "key_phrases": self.key_phrases,
+        }
+
+    # ── Internal ──────────────────────────────────────────────────────────────
+
+    def _phrase_score(self, text_words: set, phrases: List[str]) -> float:
+        if not phrases:
+            return 1.0
+        matched = sum(1 for p in phrases if _phrase_matched(text_words, p))
+        return matched / len(phrases)
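The weighted word-set rule in grader.py is easy to sanity-check in isolation. A self-contained sketch of the core scoring idea (the bare `words`/`matched`/`score` helpers are illustrative stand-ins, not the module's API, and they assume non-empty phrase lists, the case the real `_phrase_score` guards against):

```python
import string

_PUNCT = str.maketrans("", "", string.punctuation)


def words(text):
    # Normalise like the grader: lowercase, drop punctuation, split to a set.
    return set(text.lower().translate(_PUNCT).split())


def matched(text_words, phrase):
    # Lenient containment: every word of the phrase appears somewhere.
    return all(w in text_words for w in words(phrase))


def score(submission, key_phrases, bonus_phrases):
    # 0.70 / 0.30 weighting of key vs. bonus phrase hit rates.
    tw = words(submission)
    key = sum(matched(tw, p) for p in key_phrases) / len(key_phrases)
    bonus = sum(matched(tw, p) for p in bonus_phrases) / len(bonus_phrases)
    return 0.70 * key + 0.30 * bonus


s = score(
    "After deposit, the caller's balance increases by msg.value.",
    key_phrases=["balance increases", "msg.value"],
    bonus_phrases=["caller"],
)
print(round(s, 2))  # → 0.7
```

Both key phrases match ("msg.value." normalises to "msgvalue", and non-adjacent words are fine), but "caller's" normalises to "callers" and misses the bare-word bonus phrase: exactly the paraphrase gap the `SYNONYMS` table in the real grader is there to close.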
validate.py CHANGED
@@ -1,290 +1,280 @@
 """
 validate.py
 -----------
-Pre-submission validation script.
-Checks all OpenEnv spec requirements locally before submitting.

-Usage:
-    python validate.py
-
-Exit code 0 = all checks pass.
-Exit code 1 = one or more checks failed.
 """

-import json
-import sys
-import traceback
 from typing import Callable, List, Tuple

-# ─────────────────────────────────────────────────────────────────────────────
-# Helpers
-# ─────────────────────────────────────────────────────────────────────────────
-
-PASS = "✅"
-FAIL = "❌"
-SKIP = "⏭ "
 results: List[Tuple[str, bool, str]] = []

-
 def check(name: str, fn: Callable[[], None]) -> None:
     try:
-        fn()
-        results.append((name, True, ""))
         print(f"  {PASS} {name}")
     except Exception as e:
-        tb = traceback.format_exc(limit=3)
         results.append((name, False, str(e)))
-        print(f"  {FAIL} {name}")
-        print(f"      {e}")
-

-# ─────────────────────────────────────────────────────────────────────────────
-# Checks
-# ─────────────────────────────────────────────────────────────────────────────

 def check_imports():
-    from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult
     from tasks.task1.environment import Task1Environment
     from tasks.task1.grader import Task1Grader
     from data.data_loader import load_contracts

-
 def check_openenv_yaml():
     import yaml
-    with open("openenv.yaml") as f:
-        spec = yaml.safe_load(f)
     assert "name" in spec
-    assert "tasks" in spec
-    assert len(spec["tasks"]) >= 3, "Need at least 3 tasks defined"
     assert "observation_space" in spec
     assert "action_space" in spec
     assert "reward" in spec

-
 def check_pydantic_models():
     from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult, StateResult
-    # Instantiate each model
-    obs = Observation(
-        task_id="t1", contract_name="C", contract_description="D",
-        available_actions=["submit"]
-    )
     assert obs.task_id == "t1"
-
-    action = Action(action_type=ActionType.LIST_FUNCTIONS)
-    assert action.action_type == ActionType.LIST_FUNCTIONS
-
-    reward = Reward(value=1.0, reason="test")
-    assert reward.value == 1.0
-
-    step = StepResult(observation=obs, reward=reward, done=False)
-    assert not step.done
-
-    reset = ResetResult(observation=obs)
-    assert reset.observation.task_id == "t1"
-
-    state = StateResult(task_id="t1", contract_name="C", step_count=0,
-                        cumulative_reward=0.0, done=False)
-    assert state.step_count == 0
-

 def check_data_loading():
-    from data.data_loader import load_contracts, get_all_vulnerable_entries
     contracts = load_contracts()
-    assert len(contracts) >= 1, "No contracts loaded"
-    entries = get_all_vulnerable_entries(contracts)
-    assert len(entries) >= 3, f"Need >= 3 vulnerable functions, got {len(entries)}"
-    for contract, fn in entries:
-        assert fn.get("vulnerable") is True
-        assert fn.get("vulnerability_details") is not None
-        assert "issue" in fn["vulnerability_details"]
-
-
-def check_env_reset():
-    from tasks.task1.environment import Task1Environment
-    env = Task1Environment()
-    result = env.reset(seed=42)
-    assert result.observation is not None
-    assert result.observation.task_id == "task1_vuln_detection"
-    assert result.observation.contract_name != ""
-    assert not result.observation.done
-    assert result.observation.step_count == 0
-
-
-def check_env_step():
     from tasks.task1.environment import Task1Environment
     from env.schemas import Action, ActionType
     env = Task1Environment()
-    env.reset(seed=42)
-    result = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
-    assert result.observation is not None
-    assert isinstance(result.reward.value, float)
-    assert isinstance(result.done, bool)
-    assert "info" in result.model_dump()
-
-
-def check_env_state():
-    from tasks.task1.environment import Task1Environment
-    env = Task1Environment()
-    env.reset(seed=42)
-    state = env.state()
-    assert state.task_id == "task1_vuln_detection"
-    assert state.contract_name != ""
-    assert state.target_function is not None  # exposed for debugging
-

-def check_grader_scores_in_range():
     from tasks.task1.grader import Task1Grader
     cases = [
-        ("withdraw", "Reentrancy vulnerability", "withdraw", "reentrancy", 1.0),
-        ("withdraw", "Reentrancy vulnerability", "withdraw", "something else", 0.5),
-        ("withdraw", "Reentrancy vulnerability", "deposit", "reentrancy", 0.0),
     ]
     for tf, issue, sf, sv, expected in cases:
         g = Task1Grader(tf, issue)
         score = g.grade_submission(sf, sv)
-        assert 0.0 <= score <= 1.0, f"Score {score} out of range"
         assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"

-
-def check_grader_deterministic():
-    from tasks.task1.grader import Task1Grader
-    g = Task1Grader("withdraw", "Reentrancy vulnerability")
-    s1 = g.grade_submission("withdraw", "reentrancy")
-    s2 = g.grade_submission("withdraw", "reentrancy")
-    assert s1 == s2 == 1.0, "Grader must be deterministic"
-

 def check_reward_shaping():
-    """Verify reward is non-binary (multiple distinct values across steps)."""
-    from tasks.task1.environment import Task1Environment
     from env.schemas import Action, ActionType
-    env = Task1Environment()
     env.reset(seed=1)
-    rewards = set()
-    for at in [ActionType.LIST_FUNCTIONS, ActionType.GET_FILE_METADATA, ActionType.GET_CALL_GRAPH]:
-        r = env.step(Action(action_type=at))
-        rewards.add(round(r.reward.value, 4))
-    # Should have at least 2 distinct shaping reward values
-    assert len(rewards) >= 2, f"Expected multiple reward values, got {rewards}"
-

-def check_episode_boundary():
-    """Episode must end after submit and raise on subsequent step."""
     from tasks.task1.environment import Task1Environment
     from env.schemas import Action, ActionType
     env = Task1Environment()
     env.reset(seed=2)
-    env.step(Action(action_type=ActionType.SUBMIT, params={
-        "function_name": "withdraw", "vulnerability_type": "test"
-    }))
     try:
         env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
-        raise AssertionError("Should have raised RuntimeError after episode end")
     except RuntimeError:
-        pass  # Expected
-

 def check_repeated_query_penalty():
     from tasks.task1.environment import Task1Environment
     from env.schemas import Action, ActionType
-    env = Task1Environment()
-    env.reset(seed=3)
     env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
     r = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
-    assert r.reward.value == -0.40, f"Expected -0.40 for repeated query, got {r.reward.value}"
-

-def check_tasks_list():
-    """All three tasks must be listed (even if placeholders)."""
-    from tasks.task2 import __all__ as t2  # noqa
-    from tasks.task3 import __all__ as t3  # noqa

-def check_dockerfile_exists():
     import os
-    assert os.path.exists("Dockerfile"), "Dockerfile is missing"
-    with open("Dockerfile") as f:
-        content = f.read()
-    assert "7860" in content, "Dockerfile must EXPOSE 7860 (HF Spaces)"
-    assert "uvicorn" in content or "CMD" in content
-

 def check_inference_script():
     import os
-    assert os.path.exists("inference.py"), "inference.py is missing"
-    with open("inference.py") as f:
-        content = f.read()
-    assert "OPENAI_API_KEY" in content or "HF_TOKEN" in content, \
-        "inference.py must read API credentials from env vars"
-    assert "API_BASE_URL" in content
-    assert "MODEL_NAME" in content
-
-
-def check_baseline_json_schema():
-    """baseline_scores.json must have valid schema if it exists."""
     import os
-    if not os.path.exists("baseline_scores.json"):
-        return  # OK — file is generated at runtime
-    with open("baseline_scores.json") as f:
-        data = json.load(f)
     assert "tasks" in data
-    for task in data["tasks"]:
-        score = task["avg_grader_score"]
-        assert 0.0 <= score <= 1.0, f"Score {score} out of range"

-
-# ─────────────────────────────────────────────────────────────────────────────
-# Runner
-# ─────────────────────────────────────────────────────────────────────────────

 def main():
-    print("=" * 60)
-    print("OpenEnv Pre-Submission Validation")
-    print("=" * 60)
-
-    all_checks = [
-        ("Python imports", check_imports),
-        ("openenv.yaml format", check_openenv_yaml),
-        ("Pydantic model types", check_pydantic_models),
-        ("Dataset loading (3+ vulns)", check_data_loading),
-        ("env.reset() → ResetResult", check_env_reset),
-        ("env.step() → StepResult", check_env_step),
-        ("env.state() → StateResult", check_env_state),
-        ("Grader scores in [0.0, 1.0]", check_grader_scores_in_range),
-        ("Grader is deterministic", check_grader_deterministic),
-        ("Reward shaping (non-binary)", check_reward_shaping),
-        ("Episode boundary (done=True)", check_episode_boundary),
-        ("Repeated query penalty", check_repeated_query_penalty),
-        ("Task 2 & 3 placeholders", check_tasks_list),
-        ("Dockerfile exists + port", check_dockerfile_exists),
-        ("inference.py exists + vars", check_inference_script),
-        ("baseline_scores.json schema", check_baseline_json_schema),
-    ]
-
     print()
-    for name, fn in all_checks:
         check(name, fn)

-    print()
     passed = sum(1 for _, ok, _ in results if ok)
-    total = len(results)
-    failed = [(n, msg) for n, ok, msg in results if not ok]

-    print("=" * 60)
     print(f"Results: {passed}/{total} checks passed")
-
     if failed:
         print("\nFailed checks:")
-        for name, msg in failed:
-            print(f"  {FAIL} {name}: {msg}")
-        print()
-        print("❌ VALIDATION FAILED — fix the issues above before submitting.")
         sys.exit(1)
     else:
-        print()
-        print("✅ ALL CHECKS PASSED — ready to submit!")
         sys.exit(0)
-

 if __name__ == "__main__":
     main()
 
1
  """
2
  validate.py
3
  -----------
4
+ Pre-submission validation. Checks all OpenEnv spec requirements.
 
5
 
6
+ Usage: python validate.py
7
+ Exit 0 = all checks pass. Exit 1 = one or more failures.
 
 
 
8
  """
9
 
10
+ import json, sys, traceback
 
 
11
  from typing import Callable, List, Tuple
12
 
13
+ PASS = "βœ…"; FAIL = "❌"
 
 
 
 
 
 
14
  results: List[Tuple[str, bool, str]] = []
15
 
 
16
  def check(name: str, fn: Callable[[], None]) -> None:
17
  try:
18
+ fn(); results.append((name, True, ""))
 
19
  print(f" {PASS} {name}")
20
  except Exception as e:
 
21
  results.append((name, False, str(e)))
22
+ print(f" {FAIL} {name}\n {e}")
 
 
23
 
24
+ # ── Checks ────────────────────────────────────────────────────────────────────
 
 
25
 
26
  def check_imports():
27
+ from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult, ActionType
28
  from tasks.task1.environment import Task1Environment
29
  from tasks.task1.grader import Task1Grader
30
+ from tasks.task2.environment import Task2Environment
31
+ from tasks.task2.grader import Task2Grader
32
  from data.data_loader import load_contracts
33
 
 
34
  def check_openenv_yaml():
35
  import yaml
36
+ with open("openenv.yaml") as f: spec = yaml.safe_load(f)
 
37
  assert "name" in spec
38
+ assert len(spec.get("tasks", [])) >= 3
 
39
  assert "observation_space" in spec
40
  assert "action_space" in spec
41
  assert "reward" in spec
42
 
 
43
  def check_pydantic_models():
44
  from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult, StateResult
45
+ obs = Observation(task_id="t1", contract_name="C", contract_description="D", available_actions=["submit"])
 
 
 
 
46
  assert obs.task_id == "t1"
47
+ action = Action(action_type=ActionType.LIST_FUNCTIONS); assert action.action_type == ActionType.LIST_FUNCTIONS
48
+ action2 = Action(action_type=ActionType.SUBMIT_PROPERTY); assert action2.action_type == ActionType.SUBMIT_PROPERTY
49
+ reward = Reward(value=1.0, reason="test"); assert reward.value == 1.0
50
+ step = StepResult(observation=obs, reward=reward, done=False); assert not step.done
51
+ reset = ResetResult(observation=obs); assert reset.observation.task_id == "t1"
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def check_data_loading():
54
+ from data.data_loader import load_contracts, get_all_vulnerable_entries, get_all_property_entries
55
  contracts = load_contracts()
56
+ assert len(contracts) >= 1
57
+ vuln_entries = get_all_vulnerable_entries(contracts)
58
+ assert len(vuln_entries) >= 3, f"Need >=3 vulnerable fns, got {len(vuln_entries)}"
59
+ prop_entries = get_all_property_entries(contracts)
60
+ assert len(prop_entries) >= 3, f"Need >=3 property fns, got {len(prop_entries)}"
61
+ for _, fn in prop_entries:
62
+ p = fn["property"]
63
+ assert "natural_language" in p
64
+ assert "key_phrases" in p
65
+ assert "bonus_phrases" in p
66
+ assert len(p["key_phrases"]) >= 2
67
+
68
+ def check_t1_env():
 
 
 
 
 
 
 
 
69
  from tasks.task1.environment import Task1Environment
70
  from env.schemas import Action, ActionType
71
  env = Task1Environment()
72
+ r = env.reset(seed=42); assert r.observation.task_id == "task1_vuln_detection"
73
+ s = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
74
+ assert isinstance(s.reward.value, float)
75
+ assert s.observation.step_count == 1
76
+ st = env.state(); assert st.target_function is not None
77
+
78
+ def check_t2_env():
79
+ from tasks.task2.environment import Task2Environment
80
+ from env.schemas import Action, ActionType
81
+ env = Task2Environment()
82
+ r = env.reset(seed=42)
83
+ assert r.observation.task_id == "task2_property_discovery"
84
+ assert "target_function" in r.observation.extra
85
+ # test each action type
86
+ for at in [ActionType.GET_FUNCTION_CODE, ActionType.GET_FUNCTION_NATSPEC,
87
+ ActionType.GET_FILE_NATSPEC, ActionType.GET_IO, ActionType.GET_RELATED_FUNCTIONS]:
88
+ s = env.step(Action(action_type=at)); assert s.reward.value < 0
89
+ s = env.step(Action(action_type=ActionType.GET_SIMILAR_RULE))
90
+ assert s.reward.value == -0.20
91
+
92
+ def check_t2_env_submit():
93
+ from tasks.task2.environment import Task2Environment
94
+ from data.data_loader import load_contracts, get_function_by_name
95
+ from env.schemas import Action, ActionType
96
+ env = Task2Environment()
97
+ r = env.reset(seed=42)
98
+ fn_name = r.observation.extra["target_function"]
99
+ contract = r.observation.contract_name
100
+ contracts = load_contracts()
101
+ gt_text = ""
102
+ for c in contracts:
103
+ if c["contract_name"] == contract:
104
+ fn = get_function_by_name(c, fn_name)
105
+ if fn and fn.get("property"):
106
+ gt_text = fn["property"]["natural_language"]
107
+ result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": gt_text}))
108
+ assert result.done
109
+ assert result.reward.value > 0, f"GT text should score >0, got {result.reward.value}"
110
+
111
+ def check_t2_one_submit_only():
112
+ from tasks.task2.environment import Task2Environment
113
+ from env.schemas import Action, ActionType
114
+ env = Task2Environment()
115
+ env.reset(seed=5)
116
+ env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": "test"}))
117
+ # Second submit must either fail (episode done β†’ RuntimeError) or return negative reward
118
+ try:
119
+ s2 = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": "test2"}))
120
+ # If it doesn't raise, the reward must be negative
121
+ assert s2.reward.value < 0, "Second submit should penalise"
122
+ except RuntimeError:
123
+ pass # expected
124
 
125
+ def check_t1_grader():
126
  from tasks.task1.grader import Task1Grader
127
  cases = [
128
+ ("withdraw", "Reentrancy vulnerability", "withdraw", "reentrancy", 1.0),
129
+ ("withdraw", "Reentrancy vulnerability", "withdraw", "something else", 0.5),
130
+ ("withdraw", "Reentrancy vulnerability", "deposit", "reentrancy", 0.0),
131
  ]
132
  for tf, issue, sf, sv, expected in cases:
133
  g = Task1Grader(tf, issue)
134
  score = g.grade_submission(sf, sv)
135
+ assert 0.0 <= score <= 1.0
136
  assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"
137
 
138
+ def check_t2_grader():
139
+ from tasks.task2.grader import Task2Grader
140
+ from data.data_loader import load_contracts, get_all_property_entries
141
+ contracts = load_contracts()
142
+ entries = get_all_property_entries(contracts)
143
+ for contract, fn in entries:
144
+ g = Task2Grader(fn["name"], fn["property"])
145
+ # Ground truth must score β‰₯ 0.65
146
+ gt_score = g.grade(fn["property"]["natural_language"])
147
+ assert gt_score >= 0.65, f"{fn['name']}: gt_score={gt_score} < 0.65"
148
+ # Empty must be 0.0
149
+ assert g.grade("") == 0.0
150
+ # Deterministic
151
+ assert g.grade("test text") == g.grade("test text")
152
+ # Score in [0,1]
153
+ assert 0.0 <= gt_score <= 1.0
154
+ # Reward maps correctly
155
+ assert abs(g.reward_for_score(gt_score) - gt_score * 5.0) < 0.01
156
 
157
  def check_reward_shaping():
158
+ from tasks.task2.environment import Task2Environment
 
159
  from env.schemas import Action, ActionType
160
+ env = Task2Environment()
161
  env.reset(seed=1)
162
+ rewards = {env.step(Action(action_type=at)).reward.value
163
+ for at in [ActionType.GET_FUNCTION_CODE, ActionType.GET_FILE_NATSPEC, ActionType.GET_IO]}
164
+ assert len(rewards) >= 2, f"Need multiple reward values, got {rewards}"
 
 
 
 
165
 
166
+ def check_t1_episode_boundary():
 
167
  from tasks.task1.environment import Task1Environment
168
  from env.schemas import Action, ActionType
169
  env = Task1Environment()
170
  env.reset(seed=2)
171
+ env.step(Action(action_type=ActionType.SUBMIT,
172
+ params={"function_name": "withdraw", "vulnerability_type": "test"}))
 
173
  try:
174
  env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
175
+ raise AssertionError("Should raise RuntimeError after done")
176
  except RuntimeError:
177
+ pass
 
178
 
179
  def check_repeated_query_penalty():
180
  from tasks.task1.environment import Task1Environment
181
  from env.schemas import Action, ActionType
182
+ env = Task1Environment(); env.reset(seed=3)
 
183
  env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
184
  r = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
185
+ assert r.reward.value == -0.40
 
186
 
187
+ def check_t2_repeated_penalty():
188
+ from tasks.task2.environment import Task2Environment
189
+ from env.schemas import Action, ActionType
190
+ env = Task2Environment(); env.reset(seed=3)
191
+ env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
192
+ r = env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
193
+ assert r.reward.value == -0.40
194
 
195
+ def check_task_placeholders():
196
+ from tasks.task3 import __all__ as t3
197
 
198
+ def check_dockerfile():
199
  import os
200
+ assert os.path.exists("Dockerfile")
201
+ with open("Dockerfile") as f: c = f.read()
202
+ assert "7860" in c
203
+ assert "uvicorn" in c or "CMD" in c
 
 
204
 
205
  def check_inference_script():
206
  import os
207
+ assert os.path.exists("inference.py")
208
+ with open("inference.py") as f: c = f.read()
209
+ assert "HF_TOKEN" in c
210
+ assert "API_BASE_URL" in c
211
+ assert "MODEL_NAME" in c
212
+ assert "task2" in c.lower() or "Task2" in c or "TASK 2" in c
213
+
214
+ def check_baseline_json():
 
 
 
215
  import os
216
+ if not os.path.exists("baseline_scores.json"): return
217
+ with open("baseline_scores.json") as f: data = json.load(f)
 
 
218
  assert "tasks" in data
219
+ for t in data["tasks"]:
220
+ assert 0.0 <= t["avg_grader_score"] <= 1.0
 
+def check_similar_rule_lookup():
+    from data.data_loader import load_contracts, get_similar_rule
+    contracts = load_contracts()
+    sr = get_similar_rule(contracts, "SimpleVault", "withdraw")
+    assert sr is not None, "similar_rule should exist for withdraw"
+    assert "property_hint" in sr
+    assert "contract_name" in sr
+
+
+# ── Runner ────────────────────────────────────────────────────────────────────
+
+ALL_CHECKS = [
+    ("Python imports (T1 + T2)",             check_imports),
+    ("openenv.yaml format",                  check_openenv_yaml),
+    ("Pydantic models (incl T2 actions)",    check_pydantic_models),
+    ("Dataset: vuln + property entries",     check_data_loading),
+    ("Task 1: reset / step / state",         check_t1_env),
+    ("Task 2: reset + all 6 browse actions", check_t2_env),
+    ("Task 2: submit_property scores > 0",   check_t2_env_submit),
+    ("Task 2: one submit only",              check_t2_one_submit_only),
+    ("Task 1 grader: 0/0.5/1.0 rubric",      check_t1_grader),
+    ("Task 2 grader: all 11 properties",     check_t2_grader),
+    ("Reward shaping (multi-value)",         check_reward_shaping),
+    ("T1 episode boundary",                  check_t1_episode_boundary),
+    ("T1 repeated query penalty (-0.40)",    check_repeated_query_penalty),
+    ("T2 repeated query penalty (-0.40)",    check_t2_repeated_penalty),
+    ("Task 3 placeholder exists",            check_task_placeholders),
+    ("Dockerfile + port 7860",               check_dockerfile),
+    ("inference.py: creds + Task 2 code",    check_inference_script),
+    ("baseline_scores.json schema",          check_baseline_json),
+    ("similar_rule data lookup",             check_similar_rule_lookup),
+]


 
254
  def main():
255
+ print("=" * 64)
256
+ print("OpenEnv Pre-Submission Validation (Task 1 + Task 2)")
257
+ print("=" * 64)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
     print()
+    for name, fn in ALL_CHECKS:
         check(name, fn)

     passed = sum(1 for _, ok, _ in results if ok)
+    total = len(results)
+    failed = [(n, m) for n, ok, m in results if not ok]

+    print()
+    print("=" * 64)
     print(f"Results: {passed}/{total} checks passed")

     if failed:
         print("\nFailed checks:")
+        for n, m in failed:
+            print(f"  {FAIL} {n}: {m}")
+        print("\n❌ VALIDATION FAILED — fix the issues above before submitting.")
         sys.exit(1)
     else:
+        print("\n✅ ALL CHECKS PASSED — ready to submit!")
         sys.exit(0)


 if __name__ == "__main__":
     main()