ajaxwin committed
Commit 08c19c7 · 0 Parent(s):

Initial Commit
.gitignore ADDED
@@ -0,0 +1,13 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ .env
+ .venv
+ venv/
+ *.egg-info/
+ dist/
+ build/
+ .DS_Store
+ baseline_scores.json
+ *.log
+ .pytest_cache/
Dockerfile ADDED
@@ -0,0 +1,35 @@
+ # ---------------------------------------------------------------------------
+ # Smart Contract Audit RL Environment
+ # Hugging Face Space — Docker runtime
+ # ---------------------------------------------------------------------------
+
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # System deps
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Python deps first (layer cache)
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy project
+ COPY . .
+
+ # Create empty __init__ files if missing (safety)
+ RUN touch env/__init__.py tasks/__init__.py tasks/task1/__init__.py \
+     tasks/task2/__init__.py tasks/task3/__init__.py \
+     data/__init__.py utils/__init__.py
+
+ # HF Spaces requires port 7860
+ EXPOSE 7860
+
+ # Healthcheck
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \
+     CMD curl -f http://localhost:7860/health || exit 1
+
+ # Launch FastAPI
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
README.md ADDED
@@ -0,0 +1,301 @@
+ # Smart Contract Audit RL Environment
+
+ > **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
+ > Agents learn to audit real-world Solidity contracts — finding vulnerabilities, discovering properties, and checking rule compliance — tasks that professional auditors perform daily.
+
+ [![OpenEnv Spec](https://img.shields.io/badge/OpenEnv-1.0-blue)](openenv.yaml)
+ [![HF Space](https://img.shields.io/badge/HuggingFace-Space-yellow)](https://huggingface.co/spaces)
+ [![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-brightgreen)](https://python.org)
+
+ ---
+
+ ## Motivation
+
+ Smart contract auditing is a $500M+ industry where human auditors painstakingly review Solidity code for security flaws. This environment lets agents practice exactly that workflow — exploring contract code through targeted queries and submitting findings — providing a challenging, real-world benchmark for reasoning and code-understanding agents.
+
+ Data is sourced from **Certora-audited DeFi projects**, giving agents contracts with the same vulnerability patterns found in production exploits (reentrancy, integer overflow, access control bypasses, etc.).
+
+ ---
+
+ ## Environment Description
+
+ The environment hosts **3 tasks** of varying difficulty:
+
+ | Task | Name | Difficulty | Status |
+ |------|------|------------|--------|
+ | 1 | Targeted Vulnerability Detection | Medium | ✅ Active |
+ | 2 | Property Discovery | Hard | ⏳ Placeholder |
+ | 3 | Rule Checker | Easy | ⏳ Placeholder |
+
+ ### Task 1 — Targeted Vulnerability Detection *(Medium)*
+
+ **Setup:** The agent is shown a Solidity contract (4–6 functions). One function contains a critical vulnerability.
+
+ **Objective:** Identify the vulnerable function and describe the vulnerability type in 2–3 words.
+
+ **Episode lifecycle:**
+ 1. `reset()` — randomly selects one of 8 vulnerable (contract, function) pairs from the dataset
+ 2. Agent receives the contract name and description
+ 3. Agent explores using the action API (each action has a small cost)
+ 4. Agent calls `submit(function_name, vulnerability_type)` to end the episode
+ 5. Grader assigns a 0.0–1.0 score
+
+ **Vulnerability types in the dataset:**
+ - Reentrancy
+ - Missing access control
+ - Integer overflow (Solidity <0.8)
+ - tx.origin authentication
+ - Front-running
+ - Timestamp dependence
+ - Denial of service (unbounded loop)
+ - Unchecked ERC-20 return value
+
+ ---
+
+ ### Task 2 — Property Discovery *(Hard)* [Placeholder]
+
+ Given a single Solidity function, the agent must discover its natural-language correctness property. Grading uses semantic similarity to the ground-truth property. *Implementation coming soon.*
+
+ ---
+
+ ### Task 3 — Rule Checker *(Easy)* [Placeholder]
+
+ Given a natural-language property and a contract, the agent must identify which function violates that property. *Implementation coming soon.*
+
+ ---
+
+ ## Action Space
+
+ All actions are described below. **Repeated identical queries cost −0.40.**
+
+ | Action | Key Params | Reward |
+ |--------|-----------|--------|
+ | `list_functions` | — | −0.05 |
+ | `get_function_code` | `function_name` | +0.05 (target) / −0.10 (other) |
+ | `get_function_summary` | `function_name` | +0.03 (target) / −0.05 (other) |
+ | `get_file_metadata` | — | −0.04 |
+ | `get_state_variable` | `variable_name` (opt.) | −0.05 |
+ | `get_call_graph` | — | −0.08 |
+ | `submit` | `function_name`, `vulnerability_type` | +5.0 / +1.0 / −1.5 |
+
+ **Submit scoring:**
+ - **+5.0** — correct function AND correct vulnerability keyword → grader score = 1.0
+ - **+1.0** — correct function, unrecognised vulnerability type → grader score = 0.5
+ - **−1.5** — wrong function → grader score = 0.0
+
+ ---
+
+ ## Observation Space
+
+ Every `step()` and `reset()` returns an `Observation` object:
+
+ ```json
+ {
+   "task_id": "task1_vuln_detection",
+   "contract_name": "SimpleVault",
+   "contract_description": "An ETH vault that allows users to deposit and withdraw...",
+   "available_actions": ["list_functions", "get_function_code", ...],
+   "last_action": "get_function_code",
+   "last_action_result": "// withdraw\nfunction withdraw(uint256 amount) ...",
+   "step_count": 3,
+   "cumulative_reward": -0.05,
+   "done": false,
+   "extra": {
+     "solidity_version": "0.8.0",
+     "hint": "Identify the vulnerable function and its issue."
+   }
+ }
+ ```
+
+ ---
+
+ ## Project Structure
+
+ ```
+ smart-contract-env/
+ ├── data/
+ │   ├── contracts.json        # 4 contracts, 8 vulnerabilities
+ │   └── data_loader.py        # JSON parsing and episode sampling
+ ├── env/
+ │   ├── base_env.py           # Abstract OpenEnv base class
+ │   └── schemas.py            # Pydantic models (Observation, Action, Reward…)
+ ├── tasks/
+ │   ├── task1/
+ │   │   ├── environment.py    # Full Task 1 RL environment
+ │   │   └── grader.py         # Deterministic 0.0–1.0 grader
+ │   ├── task2/                # TODO: Property Discovery
+ │   └── task3/                # TODO: Rule Checker
+ ├── utils/
+ ├── app.py                    # FastAPI server (OpenEnv HTTP interface)
+ ├── inference.py              # Baseline inference script (OpenAI client)
+ ├── openenv.yaml              # OpenEnv spec metadata
+ ├── Dockerfile
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ---
+
+ ## Setup & Usage
+
+ ### Option A — Run locally
+
+ ```bash
+ # 1. Clone and install
+ git clone <repo>
+ cd smart-contract-env
+ pip install -r requirements.txt
+
+ # 2. Start the server
+ python app.py
+ # → http://localhost:7860
+ ```
+
+ ### Option B — Docker
+
+ ```bash
+ docker build -t sc-audit-env .
+ docker run -p 7860:7860 sc-audit-env
+ ```
+
+ ### Option C — Python (no server)
+
+ ```python
+ from tasks.task1.environment import Task1Environment
+ from env.schemas import Action, ActionType
+
+ env = Task1Environment()
+ result = env.reset(seed=42)
+ print(result.observation.contract_name)
+
+ action = Action(action_type=ActionType.LIST_FUNCTIONS)
+ step = env.step(action)
+ print(step.observation.last_action_result)
+ ```
+
+ ---
+
+ ## HTTP API
+
+ | Method | Endpoint | Description |
+ |--------|----------|-------------|
+ | `GET` | `/health` | Liveness probe |
+ | `GET` | `/tasks` | List all tasks |
+ | `POST` | `/reset` | Start a new episode |
+ | `POST` | `/step` | Take one action |
+ | `GET` | `/state` | Debug: internal state |
+ | `GET` | `/action_space` | Action space definition |
+ | `GET` | `/observation_space` | Observation space definition |
+
+ **Example session:**
+
+ ```bash
+ # Reset
+ curl -X POST http://localhost:7860/reset \
+   -H "Content-Type: application/json" \
+   -d '{"task_id": "task1_vuln_detection", "seed": 42}'
+
+ # List functions
+ curl -X POST "http://localhost:7860/step" \
+   -H "Content-Type: application/json" \
+   -d '{"action_type": "list_functions", "params": {}}'
+
+ # Submit answer
+ curl -X POST "http://localhost:7860/step" \
+   -H "Content-Type: application/json" \
+   -d '{"action_type": "submit", "params": {"function_name": "withdraw", "vulnerability_type": "reentrancy"}}'
+ ```
+
+ ---
+
+ ## Running the Baseline
+
+ ```bash
+ export API_BASE_URL="https://api.openai.com/v1"
+ export MODEL_NAME="gpt-4o-mini"
+ export HF_TOKEN="sk-..."
+
+ python inference.py
+ ```
+
+ Outputs results to stdout and writes `baseline_scores.json`.
+
+ **Expected baseline scores (gpt-4o-mini, 3 episodes):**
+
+ | Task | Avg Grader Score | Notes |
+ |------|-----------------|-------|
+ | Task 1 | ~0.67 | Medium difficulty; the model identifies common vulnerabilities well |
+ | Task 2 | 0.00 | Placeholder |
+ | Task 3 | 0.00 | Placeholder |
+
+ ---
+
+ ## Baseline Scores
+
+ ```json
+ {
+   "model": "gpt-4o-mini",
+   "tasks": [
+     {
+       "task_id": "task1_vuln_detection",
+       "avg_grader_score": 0.667,
+       "avg_cumulative_reward": 2.14
+     },
+     { "task_id": "task2_property_discovery", "avg_grader_score": 0.0 },
+     { "task_id": "task3_rule_checker", "avg_grader_score": 0.0 }
+   ],
+   "overall_avg_score": 0.667
+ }
+ ```
+
+ ---
+
+ ## Grader Details
+
+ The Task 1 grader is **fully deterministic**:
+
+ 1. **Function name check** — case-insensitive exact match against the ground-truth vulnerable function. Wrong function → score = 0.0 immediately.
+
+ 2. **Vulnerability type check** — checks whether the submitted string contains any accepted keyword from a predefined keyword table (e.g. the `"reentrancy"` entry includes: `reentrancy`, `re-entrancy`, `reentrant`, `recursive call`). Match → 1.0; no match → 0.5.
+
+ Scores map to terminal rewards: 1.0 → +5, 0.5 → +1, 0.0 → −1.5.
+
+ ---
+
+ ## OpenEnv Spec Compliance
+
+ - ✅ Typed `Observation`, `Action`, `Reward` Pydantic models
+ - ✅ `step(action) → StepResult(observation, reward, done, info)`
+ - ✅ `reset() → ResetResult(observation, info)`
+ - ✅ `state() → StateResult`
+ - ✅ `openenv.yaml` metadata
+ - ✅ 3 tasks defined (1 active, 2 placeholders)
+ - ✅ Grader scores in [0.0, 1.0]
+ - ✅ Shaped rewards (not just binary)
+ - ✅ Dockerfile + HF Space deployment
+ - ✅ Baseline `inference.py` using the OpenAI client
+
+ ---
+
+ ## Deploying to Hugging Face Spaces
+
+ 1. Create a new **Docker** Space on [huggingface.co/spaces](https://huggingface.co/spaces)
+ 2. Set the tag `openenv` in the Space metadata
+ 3. Push this repository:
+
+ ```bash
+ git remote add hf https://huggingface.co/spaces/<your-username>/<space-name>
+ git push hf main
+ ```
+
+ The Space will build the Docker image and serve the FastAPI app on port 7860.
+
+ ---
+
+ ## License
+
+ MIT — see `LICENSE`.
+
+ ## Data Attribution
+
+ Contract vulnerability patterns are inspired by and adapted from **Certora** audit findings on production DeFi protocols.
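The two-step deterministic grading described in the README can be sketched as follows. This is a minimal illustration of the scheme, not the repo's actual `grader.py`; `KEYWORDS`, `grade`, and `REWARD_MAP` are assumed names, and only the `reentrancy` keyword list is taken from the README.

```python
# Minimal sketch of the deterministic Task 1 grader described above.
# KEYWORDS / grade / REWARD_MAP are illustrative names (assumptions),
# not the repository's actual grader.py.

KEYWORDS = {
    # The "reentrancy" entry mirrors the README's example keyword table.
    "reentrancy": ["reentrancy", "re-entrancy", "reentrant", "recursive call"],
    # Other vulnerability types would each get their own keyword list.
    "missing access control": ["access control", "unauthorized", "onlyowner"],
}

def grade(submitted_fn: str, submitted_vuln: str,
          target_fn: str, target_vuln: str) -> float:
    """Return a grader score in {0.0, 0.5, 1.0}."""
    # Step 1: case-insensitive exact match on the function name.
    if submitted_fn.lower() != target_fn.lower():
        return 0.0
    # Step 2: keyword containment check on the vulnerability description.
    accepted = KEYWORDS.get(target_vuln, [target_vuln])
    text = submitted_vuln.lower()
    return 1.0 if any(kw in text for kw in accepted) else 0.5

# Scores map to terminal rewards: 1.0 -> +5.0, 0.5 -> +1.0, 0.0 -> -1.5.
REWARD_MAP = {1.0: 5.0, 0.5: 1.0, 0.0: -1.5}
```

Because every rule is a string comparison, the same submission always yields the same score, which is what makes baseline runs reproducible.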
SPACES_README.md ADDED
@@ -0,0 +1,57 @@
+ ---
+ title: Smart Contract Audit RL Environment
+ emoji: 🔍
+ colorFrom: blue
+ colorTo: indigo
+ sdk: docker
+ app_port: 7860
+ tags:
+   - openenv
+   - reinforcement-learning
+   - smart-contracts
+   - solidity
+   - security
+   - evaluation
+ license: mit
+ short_description: OpenEnv RL environment for smart contract security auditing
+ ---
+
+ # Smart Contract Audit RL Environment
+
+ > OpenEnv-compliant RL environment for Solidity security analysis.
+
+ This Space exposes the full OpenEnv HTTP interface for **Task 1: Targeted Vulnerability Detection**.
+ Agents explore Solidity contracts using a structured action API and identify vulnerable functions.
+
+ ## Quick start
+
+ ```bash
+ # Reset — start a new episode
+ curl -X POST $SPACE_URL/reset \
+   -H "Content-Type: application/json" \
+   -d '{"task_id": "task1_vuln_detection", "seed": 42}'
+
+ # Step — list contract functions
+ curl -X POST $SPACE_URL/step \
+   -H "Content-Type: application/json" \
+   -d '{"action_type": "list_functions", "params": {}}'
+
+ # Submit answer
+ curl -X POST $SPACE_URL/step \
+   -H "Content-Type: application/json" \
+   -d '{"action_type": "submit", "params": {"function_name": "withdraw", "vulnerability_type": "reentrancy"}}'
+ ```
+
+ ## Endpoints
+
+ | Method | Path | Description |
+ |--------|------|-------------|
+ | GET | `/health` | Liveness probe |
+ | GET | `/tasks` | All tasks + status |
+ | POST | `/reset` | New episode |
+ | POST | `/step` | Take an action |
+ | GET | `/state` | Debug state |
+ | GET | `/action_space` | Action schema |
+ | GET | `/observation_space` | Observation schema |
+
+ See the full [README](README.md) for detailed documentation.
app.py ADDED
@@ -0,0 +1,265 @@
+ """
+ app.py
+ ------
+ FastAPI server exposing the OpenEnv HTTP interface.
+
+ Endpoints:
+     POST /reset             – start a new episode
+     POST /step              – take one action
+     GET  /state             – inspect internal state (debugging)
+     GET  /tasks             – list available tasks
+     GET  /health            – liveness probe
+     GET  /action_space      – action space description
+     GET  /observation_space – observation space description
+
+ Sessions are keyed by a UUID passed as the `session_id` query parameter.
+ If omitted, a default single session is used (fine for sequential runs).
+ """
+
+ import uuid
+ from typing import Dict, Optional
+
+ from fastapi import FastAPI, HTTPException, Query
+ from fastapi.responses import JSONResponse
+ from pydantic import BaseModel
+
+ from env.schemas import Action, ActionType, TaskInfo
+ from tasks.task1.environment import Task1Environment
+
+ # ---------------------------------------------------------------------------
+ # App init
+ # ---------------------------------------------------------------------------
+
+ app = FastAPI(
+     title="Smart Contract Audit RL Environment",
+     description=(
+         "OpenEnv-compliant reinforcement learning environment for smart contract "
+         "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
+     ),
+     version="1.0.0",
+ )
+
+ # ---------------------------------------------------------------------------
+ # Session management
+ # ---------------------------------------------------------------------------
+
+ _sessions: Dict[str, Task1Environment] = {}
+ DEFAULT_SESSION = "default"
+
+
+ def _get_or_create_session(session_id: str, task_id: str = "task1_vuln_detection") -> Task1Environment:
+     if session_id not in _sessions:
+         env = _create_env(task_id)
+         _sessions[session_id] = env
+     return _sessions[session_id]
+
+
+ def _create_env(task_id: str) -> Task1Environment:
+     if task_id == "task1_vuln_detection":
+         return Task1Environment()
+     # TODO: elif task_id == "task2_property_discovery": return Task2Environment()
+     # TODO: elif task_id == "task3_rule_checker": return Task3Environment()
+     raise HTTPException(
+         status_code=400,
+         detail=f"Unknown task_id '{task_id}'. Available: ['task1_vuln_detection']",
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # Request/response models
+ # ---------------------------------------------------------------------------
+
+ class ResetRequest(BaseModel):
+     task_id: str = "task1_vuln_detection"
+     seed: Optional[int] = None
+
+
+ class StepRequest(BaseModel):
+     action_type: str
+     params: dict = {}
+
+
+ # ---------------------------------------------------------------------------
+ # Routes
+ # ---------------------------------------------------------------------------
+
+ @app.get("/health")
+ def health():
+     """Liveness probe — returns 200 OK."""
+     return {"status": "ok", "version": "1.0.0"}
+
+
+ @app.get("/tasks")
+ def list_tasks():
+     """List all available tasks."""
+     tasks = [
+         TaskInfo(
+             task_id="task1_vuln_detection",
+             name="Targeted Vulnerability Detection",
+             difficulty="medium",
+             description=(
+                 "Given a Solidity contract, identify the vulnerable function "
+                 "and describe the vulnerability type in 2-3 words."
+             ),
+             status="active",
+         ),
+         TaskInfo(
+             task_id="task2_property_discovery",
+             name="Property Discovery",
+             difficulty="hard",
+             description=(
+                 "Given a Solidity function, discover the natural-language property "
+                 "that describes its correct behaviour."
+             ),
+             status="placeholder",
+         ),
+         TaskInfo(
+             task_id="task3_rule_checker",
+             name="Rule Checker",
+             difficulty="easy",
+             description=(
+                 "Given a property in English, identify which function in the contract "
+                 "violates that property."
+             ),
+             status="placeholder",
+         ),
+     ]
+     return {"tasks": [t.model_dump() for t in tasks]}
+
+
+ @app.post("/reset")
+ def reset(
+     body: ResetRequest,
+     session_id: str = Query(default=DEFAULT_SESSION),
+ ):
+     """Reset the environment and start a new episode."""
+     env = _create_env(body.task_id)
+     _sessions[session_id] = env
+     result = env.reset(seed=body.seed)
+     return result.model_dump()
+
+
+ @app.post("/step")
+ def step(
+     body: StepRequest,
+     session_id: str = Query(default=DEFAULT_SESSION),
+ ):
+     """Apply an action and advance the episode."""
+     env = _sessions.get(session_id)
+     if env is None:
+         raise HTTPException(
+             status_code=400,
+             detail=f"No active session '{session_id}'. Call /reset first.",
+         )
+     try:
+         action_type = ActionType(body.action_type)
+     except ValueError:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Unknown action_type '{body.action_type}'. "
+                    f"Valid: {[a.value for a in ActionType]}",
+         )
+     action = Action(action_type=action_type, params=body.params)
+     try:
+         result = env.step(action)
+     except RuntimeError as e:
+         raise HTTPException(status_code=409, detail=str(e))
+     return result.model_dump()
+
+
+ @app.get("/state")
+ def state(session_id: str = Query(default=DEFAULT_SESSION)):
+     """Return current internal state (for debugging; not for agents)."""
+     env = _sessions.get(session_id)
+     if env is None:
+         raise HTTPException(
+             status_code=400,
+             detail=f"No active session '{session_id}'. Call /reset first.",
+         )
+     return env.state().model_dump()
+
+
+ @app.get("/action_space")
+ def action_space(task_id: str = "task1_vuln_detection"):
+     """Describe the action space for a task."""
+     if task_id == "task1_vuln_detection":
+         return {
+             "task_id": task_id,
+             "actions": [
+                 {
+                     "type": "list_functions",
+                     "params": {},
+                     "reward": -0.05,
+                     "description": "List all function names in the contract",
+                 },
+                 {
+                     "type": "get_function_code",
+                     "params": {"function_name": "string"},
+                     "reward": "+0.05 (target fn) / -0.10 (wrong fn)",
+                     "description": "Retrieve the full Solidity code of a function",
+                 },
+                 {
+                     "type": "get_function_summary",
+                     "params": {"function_name": "string"},
+                     "reward": "+0.03 (target fn) / -0.05 (wrong fn)",
+                     "description": "Retrieve the NatSpec comment/summary of a function",
+                 },
+                 {
+                     "type": "get_file_metadata",
+                     "params": {},
+                     "reward": -0.04,
+                     "description": "Retrieve contract-level metadata (version, author, description)",
+                 },
+                 {
+                     "type": "get_state_variable",
+                     "params": {"variable_name": "string (optional)"},
+                     "reward": -0.05,
+                     "description": "Retrieve a state variable or list all variables",
+                 },
+                 {
+                     "type": "get_call_graph",
+                     "params": {},
+                     "reward": -0.08,
+                     "description": "Retrieve the function call graph",
+                 },
+                 {
+                     "type": "submit",
+                     "params": {
+                         "function_name": "string",
+                         "vulnerability_type": "string",
+                     },
+                     "reward": "+5.0 (correct) / +1.0 (right fn, wrong vuln) / -1.5 (wrong)",
+                     "description": "Submit your final answer. Ends the episode.",
+                 },
+             ],
+         }
+     return {"error": f"No action space defined for task '{task_id}'"}
+
+
+ @app.get("/observation_space")
+ def observation_space():
+     """Describe the observation space."""
+     return {
+         "type": "object",
+         "fields": {
+             "task_id": "string – active task identifier",
+             "contract_name": "string – name of the Solidity contract",
+             "contract_description": "string – what the contract does",
+             "available_actions": "list[string] – valid action types",
+             "last_action": "string|null – the previous action type",
+             "last_action_result": "string|null – human-readable result of the last action",
+             "step_count": "int – steps taken so far",
+             "cumulative_reward": "float – running reward total",
+             "done": "bool – True when the episode is over",
+             "extra": "object – task-specific hints and metadata",
+         },
+     }
+
+
+ # ---------------------------------------------------------------------------
+ # Entry point
+ # ---------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
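Because episodes are keyed by the `session_id` query parameter, independent agents can run concurrent episodes against one server. A minimal client-side sketch of that convention follows; the helper names (`build_reset`, `build_step`) and the base URL are assumptions for illustration, and actually sending the requests (e.g. with `requests.post(url, params=..., json=...)`) is left to the caller.

```python
# Sketch of a client for the session-keyed HTTP API above.
# build_reset / build_step are hypothetical helpers: they only assemble
# the (url, query params, JSON body) triple each endpoint expects.
import uuid

BASE_URL = "http://localhost:7860"  # assumption: server running locally

def build_reset(task_id="task1_vuln_detection", seed=None, session_id=None):
    """Return (url, query_params, json_body) for POST /reset."""
    sid = session_id or str(uuid.uuid4())  # fresh UUID per episode
    return (f"{BASE_URL}/reset",
            {"session_id": sid},
            {"task_id": task_id, "seed": seed})

def build_step(action_type, params=None, session_id="default"):
    """Return (url, query_params, json_body) for POST /step."""
    return (f"{BASE_URL}/step",
            {"session_id": session_id},
            {"action_type": action_type, "params": params or {}})

# Two agents can hold separate episodes by using distinct session ids:
url_a, q_a, body_a = build_step("list_functions", session_id="agent-a")
url_b, q_b, body_b = build_step("get_call_graph", session_id="agent-b")
```

Note that `/reset` always creates a fresh environment for the given session, so an agent that forgets its `session_id` simply overwrites the `default` session.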
data/Template.json ADDED
@@ -0,0 +1,149 @@
+ {
+   "contract_name": "ExampleContract",
+   "file_name": "ExampleContract.sol",
+
+   "metadata": {
+     "license": "MIT",
+     "solidity_version": "0.8.0",
+     "description": "Example contract demonstrating the template structure",
+     "author": "Example Author"
+   },
+
+   "state_variables": [
+     {
+       "name": "owner",
+       "type": "address",
+       "visibility": "public",
+       "mutability": "",
+       "description": "Address of the contract owner"
+     },
+     {
+       "name": "balances",
+       "type": "mapping(address => uint256)",
+       "visibility": "internal",
+       "mutability": "",
+       "description": "User token balances"
+     }
+   ],
+
+   "functions": [
+     {
+       "name": "transfer",
+       "signature": "transfer(address to, uint256 amount)",
+       "code": "function transfer(address to, uint256 amount) external returns (bool) {\n    require(to != address(0), \"INVALID_RECIPIENT\");\n    require(balances[msg.sender] >= amount, \"INSUFFICIENT_BALANCE\");\n    balances[msg.sender] -= amount;\n    balances[to] += amount;\n    emit Transfer(msg.sender, to, amount);\n    return true;\n}",
+       "comment": "Transfers tokens from caller to recipient",
+       "visibility": "external",
+       "modifiers": [],
+       "parameters": [
+         {
+           "name": "to",
+           "type": "address",
+           "description": "Recipient address"
+         },
+         {
+           "name": "amount",
+           "type": "uint256",
+           "description": "Amount to transfer"
+         }
+       ],
+       "returns": "bool - true on success",
+       "output_property": "Decreases caller's balance by amount, increases recipient's balance by amount. Emits Transfer event. Reverts if recipient is zero address or caller has insufficient balance.",
+       "events": ["Transfer"],
+       "vulnerable": false,
+       "vulnerability_details": null,
+       "rule_broken_english": null,
+       "rule_broken_specs": null
+     },
+     {
+       "name": "withdraw",
+       "signature": "withdraw(uint256 amount)",
+       "code": "function withdraw(uint256 amount) external {\n    require(balances[msg.sender] >= amount, \"INSUFFICIENT_BALANCE\");\n    balances[msg.sender] -= amount;\n    (bool success, ) = msg.sender.call{value: amount}(\"\");\n    require(success, \"TRANSFER_FAILED\");\n}",
+       "comment": "Withdraws ETH from contract",
+       "visibility": "external",
+       "modifiers": [],
+       "parameters": [
+         {
+           "name": "amount",
+           "type": "uint256",
+           "description": "Amount to withdraw"
+         }
+       ],
+       "returns": "",
+       "output_property": "Transfers amount ETH to caller. Reverts if insufficient balance or ETH transfer fails.",
+       "events": [],
+       "vulnerable": true,
+       "vulnerability_details": {
+         "issue": "Reentrancy vulnerability",
+         "severity": "High",
+         "description": "The withdraw function updates balance after making an external call, allowing reentrancy attacks",
+         "mitigation": "Use checks-effects-interactions pattern: update balance before external call"
+       },
+       "rule_broken_english": "When a user withdraws x amount of ETH, the user's balance should decrease by x. Due to reentrancy, an attacker can call withdraw recursively before balance is updated, draining more than their balance.",
+       "rule_broken_specs": "Pre-condition: User has balance B. Operation: withdraw(amount). Expected post-condition: User balance = B - amount. Actual vulnerability: Reentrant calls allow multiple withdrawals before balance update, resulting in user balance = B - (n * amount) where n > 1, violating the expected post-condition."
+     }
+   ],
+
+   "structs": [
+     {
+       "name": "MintLocalVars",
+       "definition": "struct MintLocalVars {\n    uint256 previousSupply;\n    uint256 nextSupply;\n    uint256 amountInRay;\n    uint256 newRate;\n    uint256 currentAvgRate;\n}",
+       "description": "Local variables used in mint function to avoid stack too deep errors"
+     }
+   ],
+
+   "modifiers": [
+     {
+       "name": "onlyOwner",
+       "definition": "require(msg.sender == owner, \"NOT_OWNER\");",
+       "purpose": "Restricts function access to contract owner only"
+     },
+     {
+       "name": "nonReentrant",
+       "definition": "Inherited from OpenZeppelin ReentrancyGuard",
+       "purpose": "Prevents reentrancy attacks by using a mutex lock"
+     }
+   ],
+
+   "inheritance": [
+     "ERC20",
+     "Ownable"
+   ],
+
+   "call_graph": {
+     "constructor": [
+       "ERC20.constructor()"
+     ],
+     "transfer": [
+       "emit Transfer()"
+     ],
+     "withdraw": [
+       "msg.sender.call()"
+     ]
+   },
+
+   "audit_issues": [
+     {
+       "function": "withdraw",
+       "issue": "Reentrancy vulnerability",
+       "severity": "High",
+       "description": "The withdraw function updates state after making an external call, allowing reentrancy attacks where an attacker can recursively call withdraw before the balance is updated",
+       "status": "Fixed",
+       "mitigation": "Moved balance update before external call (checks-effects-interactions pattern)",
+       "rule_broken_english": "When a user withdraws x amount, the user's balance should decrease by x. Due to reentrancy, an attacker can withdraw multiple times before balance updates, draining more than their balance.",
+       "rule_broken_specs": "Pre-condition: User balance = B. Operation: withdraw(amount). Expected: User balance = B - amount. Actual: Reentrant calls allow user balance = B - (n * amount) where n > 1."
+     }
+   ],
+
+   "events": [
+     {
+       "name": "Transfer",
+       "parameters": "address indexed from, address indexed to, uint256 amount",
+       "description": "Emitted when tokens are transferred"
+     },
+     {
+       "name": "Withdrawal",
+       "parameters": "address indexed user, uint256 amount",
+       "description": "Emitted when ETH is withdrawn"
+     }
+   ]
+ }
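New contract files following this template can be sanity-checked before being added to the dataset. The validator below is a minimal sketch, not part of the repository; the required-field lists are read off the template above, and `validate_contract` is a hypothetical helper name.

```python
# Hypothetical sanity check for contract JSON files shaped like Template.json.
# Field lists mirror the template above; this is not the repo's own tooling.

REQUIRED_TOP_LEVEL = ["contract_name", "file_name", "metadata",
                      "state_variables", "functions"]
REQUIRED_FN_FIELDS = ["name", "signature", "code", "vulnerable"]

def validate_contract(doc: dict) -> list:
    """Return a list of human-readable problems; an empty list means valid."""
    problems = [f"missing top-level key: {k}"
                for k in REQUIRED_TOP_LEVEL if k not in doc]
    for i, fn in enumerate(doc.get("functions", [])):
        problems += [f"functions[{i}] missing: {k}"
                     for k in REQUIRED_FN_FIELDS if k not in fn]
        # Per the template, vulnerable functions must carry their details.
        if fn.get("vulnerable") and not fn.get("vulnerability_details"):
            problems.append(f"functions[{i}] vulnerable but has no "
                            f"vulnerability_details")
    return problems
```

Running such a check at load time keeps malformed entries out of the Task 1 episode pool.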
data/__init__.py ADDED
@@ -0,0 +1 @@
+ # data package
data/contracts.json ADDED
The diff for this file is too large to render. See raw diff
 
data/data_loader.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ data_loader.py
3
+ --------------
4
+ Loads and indexes smart contract data from JSON files.
5
+ Each contract is parsed into a structured dict; vulnerable functions
6
+ are indexed for fast lookup by Task 1.
7
+ """
8
+
9
+ import json
10
+ import os
11
+ import random
12
+ from typing import Any, Dict, List, Optional, Tuple
13
+
14
+
15
+ DATA_DIR = os.path.join(os.path.dirname(__file__))
16
+ DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
17
+
18
+
19
+ def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
20
+ """Load and return all contracts from the JSON dataset."""
21
+ with open(path, "r") as f:
22
+ return json.load(f)
23
+
24
+
25
+ def get_all_vulnerable_entries(
26
+ contracts: List[Dict[str, Any]],
27
+ ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
28
+ """
29
+ Returns a flat list of (contract, function) pairs where
30
+ function['vulnerable'] is True.
31
+ Used by Task 1 to populate the episode pool.
32
+ """
33
+ entries = []
34
+ for contract in contracts:
35
+ for fn in contract.get("functions", []):
36
+ if fn.get("vulnerable", False):
37
+ entries.append((contract, fn))
38
+ return entries
39
+
40
+
41
+ def sample_episode(
42
+ contracts: List[Dict[str, Any]],
43
+ rng: Optional[random.Random] = None,
44
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
45
+ """
46
+ Randomly selects one (contract, vulnerable_function) pair.
47
+ Returns the contract dict and the target function dict.
48
+ """
49
+ if rng is None:
50
+ rng = random.Random()
51
+ entries = get_all_vulnerable_entries(contracts)
52
+ if not entries:
53
+ raise ValueError("No vulnerable functions found in dataset.")
54
+ return rng.choice(entries)
55
+
56
+
57
+ def get_function_by_name(
58
+ contract: Dict[str, Any], name: str
59
+ ) -> Optional[Dict[str, Any]]:
60
+ """Case-insensitive function lookup within a contract."""
61
+ for fn in contract.get("functions", []):
62
+ if fn["name"].lower() == name.lower():
63
+ return fn
64
+ return None
65
+
66
+
67
+ def get_state_variable_by_name(
68
+ contract: Dict[str, Any], name: str
69
+ ) -> Optional[Dict[str, Any]]:
70
+ """Case-insensitive state variable lookup."""
71
+ for sv in contract.get("state_variables", []):
72
+ if sv["name"].lower() == name.lower():
73
+ return sv
74
+ return None
75
+
76
+
77
+ def list_function_names(contract: Dict[str, Any]) -> List[str]:
78
+ """Return all function names in the contract."""
79
+ return [fn["name"] for fn in contract.get("functions", [])]
80
+
81
+
82
+ def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
83
+ """Return all state variable names."""
84
+ return [sv["name"] for sv in contract.get("state_variables", [])]
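The lookup helpers above can be exercised without touching `contracts.json`. Below is a minimal sketch run against an in-memory contract record; the record shape mirrors what the loader expects, and the field values are purely illustrative.

```python
# In-memory stand-in for one entry of contracts.json (illustrative values).
contracts = [
    {
        "name": "Vault",
        "functions": [
            {"name": "deposit", "vulnerable": False},
            {"name": "withdraw", "vulnerable": True},
        ],
    }
]

def get_all_vulnerable_entries(contracts):
    # Flatten to (contract, function) pairs where the function is vulnerable.
    return [
        (c, fn)
        for c in contracts
        for fn in c.get("functions", [])
        if fn.get("vulnerable", False)
    ]

def get_function_by_name(contract, name):
    # Case-insensitive lookup, matching the loader's behaviour.
    for fn in contract.get("functions", []):
        if fn["name"].lower() == name.lower():
            return fn
    return None

entries = get_all_vulnerable_entries(contracts)
fn = get_function_by_name(contracts[0], "WITHDRAW")
print(len(entries), fn["name"])  # -> 1 withdraw
```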
demo.py ADDED
@@ -0,0 +1,287 @@
1
+ """
2
+ demo.py
3
+ -------
4
+ Interactive demo of the Smart Contract Audit RL Environment.
5
+ Shows Task 1 end-to-end with a human-readable step-by-step walkthrough.
6
+
7
+ Usage:
8
+ python demo.py # interactive mode
9
+ python demo.py --auto # auto-run with built-in demo agent (no input needed)
10
+ python demo.py --auto --seed 42
11
+
12
+ Great for hackathon demos — run this live to show the environment in action.
13
+ """
14
+
15
+ import argparse
16
+ import sys
17
+ import textwrap
18
+ import time
19
+
20
+ # ─────────────────────────────────────────────────────────────────────────────
21
+ # Imports
22
+ # ─────────────────────────────────────────────────────────────────────────────
23
+ from tasks.task1.environment import Task1Environment
24
+ from env.schemas import Action, ActionType
25
+
26
+ # ─────────────────────────────────────────────────────────────────────────────
27
+ # ANSI colours (falls back gracefully on Windows)
28
+ # ─────────────────────────────────────────────────────────────────────────────
29
+ try:
30
+ import os
31
+ if os.name == "nt":
32
+ raise ImportError
33
+ BOLD = "\033[1m"
34
+ DIM = "\033[2m"
35
+ GREEN = "\033[92m"
36
+ YELLOW = "\033[93m"
37
+ RED = "\033[91m"
38
+ CYAN = "\033[96m"
39
+ RESET = "\033[0m"
40
+ except ImportError:
41
+ BOLD = DIM = GREEN = YELLOW = RED = CYAN = RESET = ""
42
+
43
+ DIVIDER = f"{DIM}{'─' * 64}{RESET}"
44
+
45
+
46
+ # ─────────────────────────────────────────────────────────────────────────────
47
+ # Pretty printers
48
+ # ─────────────────────────────────────────────────────────────────────────────
49
+
50
+ def banner():
51
+ print()
52
+ print(f"{BOLD}{CYAN}╔══════════════════════════════════════════════════════════╗")
53
+ print(f"║    Smart Contract Audit RL Environment · Task 1 Demo     ║")
54
+ print(f"╚══════════════════════════════════════════════════════════╝{RESET}")
55
+ print()
56
+
57
+
58
+ def print_observation(obs):
59
+ print(DIVIDER)
60
+ print(f"{BOLD}Contract :{RESET} {obs.contract_name}")
61
+ print(f"{BOLD}Desc :{RESET} {textwrap.fill(obs.contract_description, 72, subsequent_indent=' ' * 11)}")
62
+ print(f"{BOLD}Step :{RESET} {obs.step_count} "
63
+ f"{BOLD}Reward :{RESET} {obs.cumulative_reward:+.2f}")
64
+ if obs.last_action:
65
+ colour = GREEN if obs.cumulative_reward >= 0 else YELLOW
66
+ result = obs.last_action_result or ""
67
+ print(f"{BOLD}Last :{RESET} [{obs.last_action}]")
68
+ for line in textwrap.wrap(result, 72):
69
+ print(f" {colour}{line}{RESET}")
70
+ print(DIVIDER)
71
+
72
+
73
+ def print_action_menu():
74
+ actions = [
75
+ ("1", "list_functions", "{}", "List all functions"),
76
+ ("2", "get_function_code", '{"function_name": "???"}', "Get function source code"),
77
+ ("3", "get_function_summary", '{"function_name": "???"}', "Get NatSpec comment"),
78
+ ("4", "get_file_metadata", "{}", "Get file metadata"),
79
+ ("5", "get_state_variable", '{"variable_name": "???"}', "Get state variable"),
80
+ ("6", "get_call_graph", "{}", "Get call graph"),
81
+ ("7", "submit", '{"function_name":"???","vulnerability_type":"???"}', "Submit answer"),
82
+ ]
83
+ print(f"\n{BOLD}Available actions:{RESET}")
84
+ for num, at, _, desc in actions:
85
+ print(f" {CYAN}{num}{RESET} {at:25s} {DIM}{desc}{RESET}")
86
+ print()
87
+
88
+
89
+ def prompt_action(env) -> Action:
90
+ """Prompt the user to choose and configure an action interactively."""
91
+ action_map = {
92
+ "1": ActionType.LIST_FUNCTIONS,
93
+ "2": ActionType.GET_FUNCTION_CODE,
94
+ "3": ActionType.GET_FUNCTION_SUMMARY,
95
+ "4": ActionType.GET_FILE_METADATA,
96
+ "5": ActionType.GET_STATE_VARIABLE,
97
+ "6": ActionType.GET_CALL_GRAPH,
98
+ "7": ActionType.SUBMIT,
99
+ }
100
+ while True:
101
+ choice = input(f"{BOLD}Choose action (1-7): {RESET}").strip()
102
+ if choice not in action_map:
103
+ print(f" {YELLOW}Enter a number 1–7{RESET}")
104
+ continue
105
+ at = action_map[choice]
106
+ params = {}
107
+
108
+ if at in (ActionType.GET_FUNCTION_CODE, ActionType.GET_FUNCTION_SUMMARY):
109
+ fn = input(" Function name: ").strip()
110
+ params = {"function_name": fn}
111
+
112
+ elif at == ActionType.GET_STATE_VARIABLE:
113
+ var = input(" Variable name (leave blank to list all): ").strip()
114
+ if var:
115
+ params = {"variable_name": var}
116
+
117
+ elif at == ActionType.SUBMIT:
118
+ fn = input(" Vulnerable function name: ").strip()
119
+ vuln = input(" Vulnerability type (2-3 words): ").strip()
120
+ params = {"function_name": fn, "vulnerability_type": vuln}
121
+
122
+ return Action(action_type=at, params=params)
123
+
124
+
125
+ # ─────────────────────────────────────────────────────────────────────────────
126
+ # Scripted demo agent
127
+ # ─────────────────────────────────────────────────────────────────────────────
128
+
129
+ DEMO_SCRIPTS = {
130
+ # seed β†’ list of (ActionType, params, commentary)
131
+ 42: [
132
+ (ActionType.GET_FILE_METADATA, {},
133
+ "First, get high-level contract info to understand the domain."),
134
+ (ActionType.LIST_FUNCTIONS, {},
135
+ "List functions to understand the attack surface."),
136
+ (ActionType.GET_FUNCTION_SUMMARY, {"function_name": "emergencyDrain"},
137
+ "emergencyDrain sounds dangerous — check what it's supposed to do."),
138
+ (ActionType.GET_FUNCTION_CODE, {"function_name": "emergencyDrain"},
139
+ "Inspect the code — no onlyOwner modifier! Anyone can drain the vault."),
140
+ (ActionType.SUBMIT, {"function_name": "emergencyDrain", "vulnerability_type": "missing access control"},
141
+ "Confident: missing access control. Submitting!"),
142
+ ],
143
+ 7: [
144
+ (ActionType.LIST_FUNCTIONS, {},
145
+ "Start by surveying all functions."),
146
+ (ActionType.GET_FUNCTION_SUMMARY, {"function_name": "finalize"},
147
+ "finalize — what does this auction close-out function do?"),
148
+ (ActionType.GET_FUNCTION_CODE, {"function_name": "finalize"},
149
+ "Uses block.timestamp for deadline check — miners can manipulate this."),
150
+ (ActionType.SUBMIT, {"function_name": "finalize", "vulnerability_type": "timestamp dependence"},
151
+ "Classic timestamp manipulation. Submitting."),
152
+ ],
153
+ }
154
+
155
+ DEFAULT_DEMO_SEED = 42
156
+
157
+
158
+ def run_auto_demo(seed: int, delay: float = 0.9):
159
+ """Run the scripted demo agent with printed commentary."""
160
+ script = DEMO_SCRIPTS.get(seed)
161
+ if script is None:
162
+ # Generic fallback: list, code of first suspicious fn, submit
163
+ print(f"{YELLOW}No pre-written script for seed {seed}. Running generic agent.{RESET}\n")
164
+ script = [
165
+ (ActionType.LIST_FUNCTIONS, {}, "Listing all functions first."),
166
+ (ActionType.GET_FILE_METADATA, {}, "Checking contract metadata."),
167
+ ]
168
+
169
+ env = Task1Environment()
170
+ result = env.reset(seed=seed)
171
+ obs = result.observation
172
+
173
+ banner()
174
+ print(f"{BOLD}Mode:{RESET} Automated demo | {BOLD}Seed:{RESET} {seed}\n")
175
+ print_observation(obs)
176
+
177
+ for at, params, commentary in script:
178
+ time.sleep(delay)
179
+ print(f"\n{CYAN}▶ Agent thinking:{RESET} {commentary}")
180
+ time.sleep(delay * 0.5)
181
+ action = Action(action_type=at, params=params)
182
+ step = env.step(action)
183
+ obs = step.observation
184
+ print_observation(obs)
185
+
186
+ if step.done:
187
+ _print_episode_summary(obs)
188
+ return
189
+
190
+ # Episode not finished β€” shouldn't happen with a good script
191
+ state = env.state()
192
+ print(f"\n{YELLOW}Episode not completed (no submit action in script). "
193
+ f"Target was: {state.target_function}{RESET}")
194
+
195
+
196
+ # ─────────────────────────────────────────────────────────────────────────────
197
+ # Interactive mode
198
+ # ─────────────────────────────────────────────────────────────────────────────
199
+
200
+ def run_interactive(seed: int | None = None):
201
+ env = Task1Environment()
202
+ seed = 42 if seed is None else seed # don't clobber an explicit seed of 0
203
+ result = env.reset(seed=seed)
204
+ obs = result.observation
205
+
206
+ banner()
207
+ print(f"{BOLD}Mode:{RESET} Interactive | {BOLD}Seed:{RESET} {seed}")
208
+ print(f"{DIM}Tip: Start with list_functions and get_file_metadata.{RESET}\n")
209
+ print_observation(obs)
210
+
211
+ while not obs.done:
212
+ print_action_menu()
213
+ try:
214
+ action = prompt_action(env)
215
+ except (KeyboardInterrupt, EOFError):
216
+ print(f"\n{YELLOW}Demo interrupted.{RESET}")
217
+ break
218
+
219
+ step = env.step(action)
220
+ obs = step.observation
221
+ print()
222
+ print_observation(obs)
223
+
224
+ if step.done:
225
+ _print_episode_summary(obs)
226
+ break
227
+
228
+ _offer_replay()
229
+
230
+
231
+ def _print_episode_summary(obs):
232
+ print()
233
+ print(f"{BOLD}{'═' * 64}{RESET}")
234
+ reward = obs.cumulative_reward
235
+ colour = GREEN if reward > 0 else RED
236
+ print(f"{BOLD}Episode complete!{RESET}")
237
+ print(f" Steps taken : {obs.step_count}")
238
+ print(f" Total reward : {colour}{reward:+.2f}{RESET}")
239
+ last = obs.last_action_result or ""
240
+ if "✅" in last:
241
+ print(f" {GREEN}Perfect score β€” full marks!{RESET}")
242
+ elif "⚠️" in last:
243
+ print(f" {YELLOW}Partial credit β€” right function, imprecise vulnerability type.{RESET}")
244
+ else:
245
+ print(f" {RED}Incorrect β€” better luck next episode.{RESET}")
246
+ print(f"{BOLD}{'═' * 64}{RESET}\n")
247
+
248
+
249
+ def _offer_replay():
250
+ try:
251
+ again = input("Play again? (y/n): ").strip().lower()
252
+ if again == "y":
253
+ run_interactive()
254
+ except (KeyboardInterrupt, EOFError):
255
+ pass
256
+
257
+
258
+ # ─────────────────────────────────────────────────────────────────────────────
259
+ # Entry point
260
+ # ─────────────────────────────────────────────────────────────────────────────
261
+
262
+ def main():
263
+ parser = argparse.ArgumentParser(
264
+ description="Smart Contract Audit RL Environment β€” Demo"
265
+ )
266
+ parser.add_argument(
267
+ "--auto", action="store_true",
268
+ help="Run the scripted demo agent (no keyboard input required)"
269
+ )
270
+ parser.add_argument(
271
+ "--seed", type=int, default=DEFAULT_DEMO_SEED,
272
+ help="Episode seed (default: 42)"
273
+ )
274
+ parser.add_argument(
275
+ "--delay", type=float, default=0.9,
276
+ help="Seconds between auto-agent steps (default: 0.9)"
277
+ )
278
+ args = parser.parse_args()
279
+
280
+ if args.auto:
281
+ run_auto_demo(seed=args.seed, delay=args.delay)
282
+ else:
283
+ run_interactive(seed=args.seed)
284
+
285
+
286
+ if __name__ == "__main__":
287
+ main()
env/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # env package
env/base_env.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ base_env.py
3
+ -----------
4
+ Abstract base class that every task environment must implement.
5
+ Follows the OpenEnv interface: reset / step / state.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Dict
10
+
11
+ from env.schemas import Observation, Action, StepResult, ResetResult, StateResult
12
+
13
+
14
+ class BaseEnv(ABC):
15
+ """
16
+ OpenEnv-compliant base environment.
17
+
18
+ Concrete task environments should subclass this and implement:
19
+ - reset() β†’ ResetResult
20
+ - step() β†’ StepResult
21
+ - state() β†’ StateResult
22
+ """
23
+
24
+ @abstractmethod
25
+ def reset(self, seed: int | None = None) -> ResetResult:
26
+ """
27
+ Reset the environment to a fresh episode.
28
+
29
+ Parameters
30
+ ----------
31
+ seed : optional RNG seed for reproducibility
32
+
33
+ Returns
34
+ -------
35
+ ResetResult with the initial Observation and episode info.
36
+ """
37
+ ...
38
+
39
+ @abstractmethod
40
+ def step(self, action: Action) -> StepResult:
41
+ """
42
+ Apply an action and advance the episode by one step.
43
+
44
+ Parameters
45
+ ----------
46
+ action : Action – typed agent action
47
+
48
+ Returns
49
+ -------
50
+ StepResult containing:
51
+ - observation : updated Observation
52
+ - reward : Reward for this step
53
+ - done : True when the episode is over
54
+ - info : auxiliary diagnostic information
55
+ """
56
+ ...
57
+
58
+ @abstractmethod
59
+ def state(self) -> StateResult:
60
+ """
61
+ Return the full internal state (for debugging / graders).
62
+ Should NOT be used by the agent during evaluation.
63
+
64
+ Returns
65
+ -------
66
+ StateResult – internal episode state snapshot.
67
+ """
68
+ ...
69
+
70
+ # ------------------------------------------------------------------
71
+ # Optional helpers subclasses may override
72
+ # ------------------------------------------------------------------
73
+
74
+ def render(self) -> str:
75
+ """Human-readable rendering of the current state."""
76
+ s = self.state()
77
+ return (
78
+ f"Task: {s.task_id} | Contract: {s.contract_name} | "
79
+ f"Step: {s.step_count} | Reward: {s.cumulative_reward:.2f} | "
80
+ f"Done: {s.done}"
81
+ )
82
+
83
+ def action_space_description(self) -> Dict[str, Any]:
84
+ """Returns a JSON-serialisable description of the action space."""
85
+ return {}
86
+
87
+ def observation_space_description(self) -> Dict[str, Any]:
88
+ """Returns a JSON-serialisable description of the observation space."""
89
+ return {}
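The reset / step / state contract above can be sketched structurally. The example below uses plain dicts in place of the Pydantic result models so it stays dependency-free, and `CountingEnv` is a made-up toy environment, not part of the repo.

```python
from abc import ABC, abstractmethod

# Structural sketch of the BaseEnv contract with dicts standing in
# for Observation / StepResult (the real code uses Pydantic models).
class BaseEnv(ABC):
    @abstractmethod
    def reset(self, seed=None): ...
    @abstractmethod
    def step(self, action): ...
    @abstractmethod
    def state(self): ...

class CountingEnv(BaseEnv):
    def reset(self, seed=None):
        self.steps = 0
        self.done = False
        return {"observation": {"step_count": 0}, "info": {}}

    def step(self, action):
        self.steps += 1
        # A "submit" action terminates the episode, as in Task 1.
        self.done = action.get("action_type") == "submit"
        return {
            "observation": {"step_count": self.steps},
            "reward": 1.0 if self.done else 0.0,
            "done": self.done,
        }

    def state(self):
        # Internal snapshot for debugging/graders, not for the agent.
        return {"step_count": self.steps, "done": self.done}

env = CountingEnv()
env.reset(seed=0)
env.step({"action_type": "list_functions"})
result = env.step({"action_type": "submit"})
print(result["done"], env.state()["step_count"])  # -> True 2
```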
env/schemas.py ADDED
@@ -0,0 +1,150 @@
1
+ """
2
+ schemas.py
3
+ ----------
4
+ Typed Pydantic models implementing the OpenEnv interface spec.
5
+
6
+ Observation - what the agent sees at each step
7
+ Action - what the agent can send
8
+ StepResult - returned by step()
9
+ ResetResult - returned by reset()
10
+ StateResult - returned by state()
11
+ Reward - structured reward info
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from enum import Enum
17
+ from typing import Any, Dict, List, Optional
18
+
19
+ from pydantic import BaseModel, Field
20
+
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # Action types
24
+ # ---------------------------------------------------------------------------
25
+
26
+ class ActionType(str, Enum):
27
+ # Task 1 – Vulnerability Detection
28
+ LIST_FUNCTIONS = "list_functions"
29
+ GET_FUNCTION_CODE = "get_function_code"
30
+ GET_FUNCTION_SUMMARY = "get_function_summary"
31
+ GET_FILE_METADATA = "get_file_metadata"
32
+ GET_STATE_VARIABLE = "get_state_variable"
33
+ GET_CALL_GRAPH = "get_call_graph"
34
+ SUBMIT = "submit"
35
+
36
+ # TODO: Task 2 – Property Discovery
37
+ # GET_SIMILAR_RULE = "get_similar_rule"
38
+ # GET_FILE_NATSPEC = "get_file_natspec"
39
+ # GET_FUNCTION_NATSPEC = "get_function_natspec"
40
+ # GET_RELATED_FUNCTIONS = "get_related_functions"
41
+ # GET_IO = "get_io"
42
+ # SUBMIT_PROPERTY = "submit_property"
43
+
44
+ # TODO: Task 3 – Rule Checker
45
+ # GET_FORMALIZED_PROPERTY = "get_formalized_property"
46
+ # GET_FUNCTION_METADATA = "get_function_metadata"
47
+ # SUBMIT_FUNCTION = "submit_function"
48
+
49
+
50
+ class Action(BaseModel):
51
+ """
52
+ Agent action.
53
+
54
+ action_type : one of ActionType enum values
55
+ params : optional key/value arguments, e.g.
56
+ {"function_name": "withdraw"} for GET_FUNCTION_CODE
57
+ {"function_name": "withdraw", "vulnerability_type": "reentrancy"} for SUBMIT
58
+ """
59
+ action_type: ActionType
60
+ params: Dict[str, Any] = Field(default_factory=dict)
61
+
62
+ class Config:
63
+ use_enum_values = True
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Observation
68
+ # ---------------------------------------------------------------------------
69
+
70
+ class Observation(BaseModel):
71
+ """
72
+ What the agent receives from the environment.
73
+
74
+ task_id : which task is active
75
+ contract_name : name of the Solidity contract
76
+ contract_description : high-level description of what the contract does
77
+ available_actions : list of valid ActionType strings
78
+ last_action : the action that produced this observation (None on reset)
79
+ last_action_result: human-readable result of the last action
80
+ step_count : number of steps taken so far
81
+ cumulative_reward : running reward total
82
+ done : whether the episode has ended
83
+ extra : any additional task-specific context
84
+ """
85
+ task_id: str
86
+ contract_name: str
87
+ contract_description: str
88
+ available_actions: List[str]
89
+ last_action: Optional[str] = None
90
+ last_action_result: Optional[str] = None
91
+ step_count: int = 0
92
+ cumulative_reward: float = 0.0
93
+ done: bool = False
94
+ extra: Dict[str, Any] = Field(default_factory=dict)
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Reward
99
+ # ---------------------------------------------------------------------------
100
+
101
+ class Reward(BaseModel):
102
+ """
103
+ Structured reward info returned with each step.
104
+
105
+ value : float reward for this step (can be negative)
106
+ reason : human-readable explanation
107
+ partial : True if this is a shaping reward, False if terminal
108
+ """
109
+ value: float
110
+ reason: str
111
+ partial: bool = True
112
+
113
+
114
+ # ---------------------------------------------------------------------------
115
+ # Step / Reset / State results
116
+ # ---------------------------------------------------------------------------
117
+
118
+ class StepResult(BaseModel):
119
+ observation: Observation
120
+ reward: Reward
121
+ done: bool
122
+ info: Dict[str, Any] = Field(default_factory=dict)
123
+
124
+
125
+ class ResetResult(BaseModel):
126
+ observation: Observation
127
+ info: Dict[str, Any] = Field(default_factory=dict)
128
+
129
+
130
+ class StateResult(BaseModel):
131
+ task_id: str
132
+ contract_name: str
133
+ target_function: Optional[str] = None # hidden in real eval, exposed here for debugging
134
+ step_count: int
135
+ cumulative_reward: float
136
+ done: bool
137
+ query_history: List[str] = Field(default_factory=list)
138
+ session_id: Optional[str] = None
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # Task registry entry
143
+ # ---------------------------------------------------------------------------
144
+
145
+ class TaskInfo(BaseModel):
146
+ task_id: str
147
+ name: str
148
+ difficulty: str
149
+ description: str
150
+ status: str = "active" # or "placeholder"
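An `Action` pairs an enum-valued `action_type` with a free-form `params` dict. The sketch below mirrors that validation with the stdlib `Enum` only, so it runs without Pydantic; the action names copy the Task 1 `ActionType` values, while `validate_action` is an illustrative helper, not repo code.

```python
from enum import Enum

class ActionType(str, Enum):
    # Subset of the Task 1 actions from env/schemas.py
    LIST_FUNCTIONS = "list_functions"
    GET_FUNCTION_CODE = "get_function_code"
    SUBMIT = "submit"

def validate_action(payload: dict) -> dict:
    # Raises ValueError on an unknown action_type, much as Pydantic
    # rejects an invalid enum value.
    return {
        "action_type": ActionType(payload["action_type"]),
        "params": payload.get("params", {}),
    }

action = validate_action({
    "action_type": "submit",
    "params": {"function_name": "withdraw", "vulnerability_type": "reentrancy"},
})
print(action["action_type"].name)  # -> SUBMIT
```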
eval.py ADDED
@@ -0,0 +1,290 @@
1
+ """
2
+ eval.py
3
+ -------
4
+ Evaluation harness for the Smart Contract Audit RL Environment.
5
+
6
+ Runs a configurable number of episodes per task, collecting grader scores
7
+ and reward trajectories. Produces a detailed JSON report.
8
+
9
+ Unlike inference.py (which uses an external LLM), this evaluates the
10
+ *environment itself* using a built-in oracle agent — useful for:
11
+ - Verifying grader correctness
12
+ - Benchmarking reward shaping
13
+ - Checking score distribution across vulnerability types
14
+
15
+ Usage:
16
+ python eval.py # all 8 vuln episodes
17
+ python eval.py --episodes 16 # more episodes
18
+ python eval.py --seed 0 --verbose # detailed per-step output
19
+ python eval.py --out results.json # custom output file
20
+ """
21
+
22
+ import argparse
23
+ import json
24
+ import sys
25
+ import time
26
+ from typing import Any, Dict, List
27
+
28
+ from tasks.task1.environment import Task1Environment
29
+ from env.schemas import Action, ActionType
30
+ from data.data_loader import load_contracts, get_all_vulnerable_entries
31
+
32
+
33
+ # ─────────────────────────────────────────────────────────────────────────────
34
+ # Oracle agent (always submits the ground-truth answer)
35
+ # ─────────────────────────────────────────────────────────────────────────────
36
+
37
+ def oracle_agent(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
38
+ """
39
+ Runs one episode using the oracle strategy:
40
+ 1. list_functions
41
+ 2. get_function_code (for the target function — peeked from state)
42
+ 3. submit correct answer
43
+
44
+ This gives an upper-bound score trajectory for the environment.
45
+ Always ends with grader_score = 1.0.
46
+ """
47
+ reset_result = env.reset(seed=seed)
48
+ obs = reset_result.observation
49
+
50
+ steps_taken: List[Dict[str, Any]] = []
51
+
52
+ def _step(at: ActionType, params: dict | None = None) -> Any:
53
+ params = params or {}
54
+ action = Action(action_type=at, params=params)
55
+ result = env.step(action)
56
+ entry = {
57
+ "step": result.observation.step_count,
58
+ "action": at.value,
59
+ "params": params,
60
+ "reward": result.reward.value,
61
+ "reason": result.reward.reason,
62
+ "cumulative": result.observation.cumulative_reward,
63
+ "done": result.done,
64
+ }
65
+ steps_taken.append(entry)
66
+ if verbose:
67
+ done_flag = " [DONE]" if result.done else ""
68
+ print(
69
+ f" step {entry['step']:2d}: {at.value:25s} "
70
+ f"r={result.reward.value:+.2f} cum={entry['cumulative']:+.2f}"
71
+ f"{done_flag}"
72
+ )
73
+ return result
74
+
75
+ # Peek at ground truth (oracle only)
76
+ state = env.state()
77
+ target_fn = state.target_function
78
+
79
+ # Get ground-truth vulnerability from data
80
+ contracts = load_contracts()
81
+ vuln_issue = None
82
+ for contract in contracts:
83
+ for fn in contract.get("functions", []):
84
+ if fn["name"].lower() == target_fn.lower() and fn.get("vulnerable"):
85
+ vuln_issue = fn["vulnerability_details"]["issue"]
86
+ break
87
+ if vuln_issue:
88
+ break
89
+
90
+ if verbose:
91
+ print(f" Contract : {obs.contract_name}")
92
+ print(f" Target : {target_fn} ({vuln_issue})")
93
+
94
+ # Step 1: list functions (small cost, realistic)
95
+ _step(ActionType.LIST_FUNCTIONS)
96
+ # Step 2: read target function code (gets +0.05 shaping reward)
97
+ _step(ActionType.GET_FUNCTION_CODE, {"function_name": target_fn})
98
+ # Step 3: submit perfect answer
99
+ result = _step(ActionType.SUBMIT, {
100
+ "function_name": target_fn,
101
+ "vulnerability_type": vuln_issue,
102
+ })
103
+
104
+ final_reward = result.reward.value
105
+ if final_reward >= 4.9:
106
+ grader_score = 1.0
107
+ elif final_reward >= 0.9:
108
+ grader_score = 0.5
109
+ else:
110
+ grader_score = 0.0
111
+
112
+ return {
113
+ "seed": seed,
114
+ "contract": obs.contract_name,
115
+ "target_function": target_fn,
116
+ "vulnerability": vuln_issue,
117
+ "grader_score": grader_score,
118
+ "cumulative_reward": result.observation.cumulative_reward,
119
+ "steps": steps_taken,
120
+ "num_steps": len(steps_taken),
121
+ }
122
+
123
+
124
+ # ─────────────────────────────────────────────────────────────────────────────
125
+ # Partial agent (submits correct function, wrong vuln type)
126
+ # ─────────────────────────────────────────────────────────────────────────────
127
+
128
+ def partial_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
129
+ """Submits right function, always uses 'unknown' as vulnerability type β†’ score 0.5."""
130
+ reset_result = env.reset(seed=seed)
131
+ obs = reset_result.observation
132
+ state = env.state()
133
+ target_fn = state.target_function
134
+
135
+ action = Action(action_type=ActionType.SUBMIT, params={
136
+ "function_name": target_fn,
137
+ "vulnerability_type": "unknown vulnerability",
138
+ })
139
+ result = env.step(action)
140
+ return {
141
+ "seed": seed,
142
+ "grader_score": 0.5,
143
+ "cumulative_reward": result.observation.cumulative_reward,
144
+ }
145
+
146
+
147
+ # ─────────────────────────────────────────────────────────────────────────────
148
+ # Random agent (submits a random wrong function)
149
+ # ─────────────────────────────────────────────────────────────────────────────
150
+
151
+ def random_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
152
+ """Always submits 'constructor' β€” always wrong β†’ score 0.0."""
153
+ env.reset(seed=seed)
154
+ action = Action(action_type=ActionType.SUBMIT, params={
155
+ "function_name": "constructor",
156
+ "vulnerability_type": "reentrancy",
157
+ })
158
+ result = env.step(action)
159
+ return {
160
+ "seed": seed,
161
+ "grader_score": 0.0,
162
+ "cumulative_reward": result.observation.cumulative_reward,
163
+ }
164
+
165
+
166
+ # ─────────────────────────────────────────────────────────────────────────────
167
+ # Evaluation runner
168
+ # ─────────────────────────────────────────────────────────────────────────────
169
+
170
+ def run_evaluation(
171
+ num_episodes: int = 8,
172
+ seed_offset: int = 0,
173
+ verbose: bool = False,
174
+ output_file: str = "eval_results.json",
175
+ ) -> None:
176
+ env = Task1Environment()
177
+ contracts = load_contracts()
178
+ entries = get_all_vulnerable_entries(contracts)
179
+ vuln_types = list({fn["vulnerability_details"]["issue"] for _, fn in entries})
180
+
181
+ print("=" * 64)
182
+ print("Smart Contract Audit RL Environment — Evaluation")
183
+ print("=" * 64)
184
+ print(f" Episodes : {num_episodes}")
185
+ print(f" Seed range: {seed_offset} – {seed_offset + num_episodes - 1}")
186
+ print(f" Vulns in dataset: {len(entries)}")
187
+ print()
188
+
189
+ # ── Oracle agent ─────────────────────────────────────────────────────────
190
+ print("▶ Oracle agent (upper bound — always submits correct answer):")
191
+ oracle_episodes = []
192
+ for i in range(num_episodes):
193
+ seed = seed_offset + i
194
+ ep = oracle_agent(env, seed=seed, verbose=verbose)
195
+ oracle_episodes.append(ep)
196
+ icon = "✅" if ep["grader_score"] == 1.0 else "⚠️ "
197
+ print(
198
+ f" {icon} seed={seed:3d} {ep['contract']:12s} "
199
+ f"{ep['target_function']:15s} score={ep['grader_score']:.1f} "
200
+ f"reward={ep['cumulative_reward']:+.2f}"
201
+ )
202
+
203
+ oracle_avg = sum(e["grader_score"] for e in oracle_episodes) / num_episodes
204
+ oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_episodes) / num_episodes
205
+ print(f"\n Oracle avg grader score : {oracle_avg:.3f}")
206
+ print(f" Oracle avg reward : {oracle_avg_r:+.2f}")
207
+
208
+ # ── Partial agent ─────────────────────────────────────────────────────────
209
+ print("\n▶ Partial agent (right function, wrong vuln type → 0.5 each):")
210
+ partial_episodes = []
211
+ for i in range(num_episodes):
212
+ ep = partial_agent(env, seed=seed_offset + i)
213
+ partial_episodes.append(ep)
214
+ partial_avg = sum(e["grader_score"] for e in partial_episodes) / num_episodes
215
+ print(f" Partial avg grader score: {partial_avg:.3f}")
216
+
217
+ # ── Random agent ──────────────────────────────────────────────────────────
218
+ print("\n▶ Random agent (always wrong → 0.0 each):")
219
+ random_episodes = []
220
+ for i in range(num_episodes):
221
+ ep = random_agent(env, seed=seed_offset + i)
222
+ random_episodes.append(ep)
223
+ random_avg = sum(e["grader_score"] for e in random_episodes) / num_episodes
224
+ print(f" Random avg grader score : {random_avg:.3f}")
225
+
226
+ # ── Coverage across vulnerability types ──────────────────────────────────
227
+ print("\n▶ Coverage across vulnerability types:")
228
+ seen = {}
229
+ for ep in oracle_episodes:
230
+ v = ep.get("vulnerability", "unknown")
231
+ seen[v] = seen.get(v, 0) + 1
232
+ for v in sorted(seen):
233
+ print(f" {seen[v]:2d}x {v}")
234
+
235
+ # ── Summary ───────────────────────────────────────────────────────────────
236
+ print("\n" + "=" * 64)
237
+ print("SUMMARY")
238
+ print("=" * 64)
239
+ print(f" Oracle (ceiling): {oracle_avg:.3f} {'✅' if oracle_avg == 1.0 else '⚠️ '}")
240
+ print(f" Partial (partial): {partial_avg:.3f} ✅")
241
+ print(f" Random (floor) : {random_avg:.3f} ✅")
242
+
243
+ assert oracle_avg == 1.0, "Oracle should always score 1.0"
244
+ assert partial_avg == 0.5, "Partial should always score 0.5"
245
+ assert random_avg == 0.0, "Random should always score 0.0"
246
+
247
+ print("\n ✅ All score sanity checks passed.")
248
+
249
+ # ── Write results ─────────────────────────────────────────────────────────
250
+ report = {
251
+ "num_episodes": num_episodes,
252
+ "seed_offset": seed_offset,
253
+ "agents": {
254
+ "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_episodes},
255
+ "partial": {"avg_score": partial_avg, "episodes": partial_episodes},
256
+ "random": {"avg_score": random_avg, "episodes": random_episodes},
257
+ },
258
+ "vulnerability_coverage": seen,
259
+ }
260
+ with open(output_file, "w") as f:
261
+ json.dump(report, f, indent=2)
262
+ print(f"\n Results written to {output_file}")
263
+
264
+
265
+ # ─────────────────────────────────────────────────────────────────────────────
266
+ # Entry point
267
+ # ─────────────────────────────────────────────────────────────────────────────
268
+
269
+ def main():
270
+ parser = argparse.ArgumentParser(description="Evaluate the SC Audit RL Environment")
271
+ parser.add_argument("--episodes", type=int, default=8,
272
+ help="Number of episodes per agent (default: 8)")
273
+ parser.add_argument("--seed", type=int, default=42,
274
+ help="Starting seed (default: 42)")
275
+ parser.add_argument("--verbose", action="store_true",
276
+ help="Print per-step details for oracle agent")
277
+ parser.add_argument("--out", default="eval_results.json",
278
+ help="Output JSON file (default: eval_results.json)")
279
+ args = parser.parse_args()
280
+
281
+ run_evaluation(
282
+ num_episodes=args.episodes,
283
+ seed_offset=args.seed,
284
+ verbose=args.verbose,
285
+ output_file=args.out,
286
+ )
287
+
288
+
289
+ if __name__ == "__main__":
290
+ main()
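The score bands asserted above (1.0 / 0.5 / 0.0) correspond to terminal rewards in the environment. A minimal sketch of that mapping, assuming terminal rewards of +5.0 (correct), +1.0 (partial) and -1.5 (wrong) as in the reward table; `grader_score_from_reward` is a name invented here for illustration:

```python
# Terminal rewards are +5.0 (correct), +1.0 (partial), -1.5 (wrong), so
# thresholds at 4.9 and 0.9 recover the grader score from the last step reward.
def grader_score_from_reward(last_reward: float) -> float:
    if last_reward >= 4.9:
        return 1.0
    if last_reward >= 0.9:
        return 0.5
    return 0.0

print(grader_score_from_reward(5.0))   # 1.0
print(grader_score_from_reward(1.0))   # 0.5
print(grader_score_from_reward(-1.5))  # 0.0
```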
inference.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ inference.py
3
+ ------------
4
+ Baseline inference script for the Smart Contract Audit RL Environment.
5
+
6
+ Uses the OpenAI-compatible API client to run an LLM agent against Task 1.
7
+ Tasks 2 and 3 are placeholders — they are skipped and score 0.0.
8
+
9
+ Environment variables required:
10
+ API_BASE_URL – LLM endpoint (e.g. https://api.openai.com/v1)
11
+ MODEL_NAME – model identifier (e.g. gpt-4o-mini)
12
+ HF_TOKEN – API key (passed as Authorization: Bearer <HF_TOKEN>)
13
+
14
+ Usage:
15
+ python inference.py
16
+
17
+ Output:
18
+ Per-task scores printed to stdout.
19
+ Final baseline scores written to baseline_scores.json.
20
+ """
21
+
22
+ import json
23
+ import os
24
+ import sys
25
+ import time
26
+ from typing import Any, Dict, List, Optional
27
+
28
+ from openai import OpenAI
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Import the env directly (no HTTP overhead for baseline)
32
+ # ---------------------------------------------------------------------------
33
+ from tasks.task1.environment import Task1Environment
34
+ from env.schemas import Action, ActionType
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Config
38
+ # ---------------------------------------------------------------------------
39
+
40
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
41
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
42
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
43
+
44
+ if not HF_TOKEN:
45
+ print("WARNING: HF_TOKEN is not set. API calls may fail.", file=sys.stderr)
46
+
47
+ MAX_STEPS = 15 # Safety limit per episode
48
+ NUM_EPISODES = 3 # Episodes per task
49
+ TASK1_SEED_BASE = 42 # Reproducible seeds
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # OpenAI client
54
+ # ---------------------------------------------------------------------------
55
+
56
+ client = OpenAI(
57
+ api_key=HF_TOKEN,
58
+ base_url=API_BASE_URL,
59
+ )
60
+
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # System prompt
64
+ # ---------------------------------------------------------------------------
65
+
66
+ SYSTEM_PROMPT = """You are an expert smart contract security auditor.
67
+
68
+ You are given a Solidity contract and must identify the SINGLE most critical vulnerable function and name its vulnerability type.
69
+
70
+ ## Available Actions
71
+ You interact by choosing ONE action per turn from:
72
+
73
+ 1. list_functions
74
+ β†’ {"action": "list_functions", "params": {}}
75
+
76
+ 2. get_function_code
77
+ β†’ {"action": "get_function_code", "params": {"function_name": "<name>"}}
78
+
79
+ 3. get_function_summary
80
+ β†’ {"action": "get_function_summary", "params": {"function_name": "<name>"}}
81
+
82
+ 4. get_file_metadata
83
+ β†’ {"action": "get_file_metadata", "params": {}}
84
+
85
+ 5. get_state_variable
86
+ β†’ {"action": "get_state_variable", "params": {"variable_name": "<name>"}}
87
+ (omit variable_name to list all variables)
88
+
89
+ 6. get_call_graph
90
+ β†’ {"action": "get_call_graph", "params": {}}
91
+
92
+ 7. submit (ENDS THE EPISODE)
93
+ β†’ {"action": "submit", "params": {"function_name": "<name>", "vulnerability_type": "<2-3 word description>"}}
94
+
95
+ ## Strategy
96
+ - Start with list_functions and get_file_metadata to understand the contract
97
+ - Inspect suspicious functions (withdraw, transfer, emergency*, stake, etc.)
98
+ - Submit when you are confident about the vulnerable function
99
+
100
+ ## Output Format
101
+ Always respond with a single JSON object:
102
+ {"action": "<action_type>", "params": {...}}
103
+ Do NOT include any other text — only valid JSON.
104
+ """
105
+
106
+
107
+ def build_user_message(obs: Dict[str, Any]) -> str:
108
+ """Format the observation as a user message."""
109
+ lines = [
110
+ f"=== CONTRACT: {obs['contract_name']} ===",
111
+ f"Description: {obs['contract_description']}",
112
+ f"Step: {obs['step_count']} | Cumulative reward: {obs['cumulative_reward']:.2f}",
113
+ "",
114
+ f"Last action: {obs['last_action'] or 'None'}",
115
+ f"Result: {obs['last_action_result'] or 'Episode just started'}",
116
+ "",
117
+ f"Available actions: {', '.join(obs['available_actions'])}",
118
+ ]
119
+ if obs.get("extra", {}).get("hint"):
120
+ lines.append(f"Hint: {obs['extra']['hint']}")
121
+ return "\n".join(lines)
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Agent loop
126
+ # ---------------------------------------------------------------------------
127
+
128
+ def run_episode(env: Task1Environment, seed: int, episode_num: int) -> Dict[str, Any]:
129
+ """Run one episode and return result info."""
130
+ print(f"\n Episode {episode_num} (seed={seed})")
131
+
132
+ reset_result = env.reset(seed=seed)
133
+ obs = reset_result.observation.model_dump()
134
+
135
+ print(f" Contract: {obs['contract_name']}")
136
+
137
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
138
+ final_score = 0.0
139
+ final_reward = 0.0
140
+ steps = 0
141
+ done = False
142
+
143
+ for step_num in range(MAX_STEPS):
144
+ user_msg = build_user_message(obs)
145
+ messages.append({"role": "user", "content": user_msg})
146
+
147
+ # LLM call
148
+ try:
149
+ response = client.chat.completions.create(
150
+ model=MODEL_NAME,
151
+ messages=messages,
152
+ max_tokens=256,
153
+ temperature=0.0,
154
+ )
155
+ raw = response.choices[0].message.content.strip()
156
+ except Exception as e:
157
+ print(f" LLM error at step {step_num}: {e}", file=sys.stderr)
158
+ break
159
+
160
+ # Parse action
161
+ try:
162
+ parsed = json.loads(raw)
163
+ action_type = ActionType(parsed["action"])
164
+ params = parsed.get("params", {})
165
+ except Exception as e:
166
+ print(f" Parse error: {e} | Raw: {raw[:100]}", file=sys.stderr)
167
+ # Default safe action
168
+ action_type = ActionType.LIST_FUNCTIONS
169
+ params = {}
170
+
171
+ action = Action(action_type=action_type, params=params)
172
+ messages.append({"role": "assistant", "content": raw})
173
+
174
+ # Step
175
+ step_result = env.step(action)
176
+ obs = step_result.observation.model_dump()
177
+ done = step_result.done
178
+ steps += 1
179
+ final_reward = obs["cumulative_reward"]
180
+
181
+ print(
182
+ f" Step {step_num+1}: {action_type.value} | "
183
+ f"reward={step_result.reward.value:+.2f} | "
184
+ f"cumulative={final_reward:.2f}"
185
+ )
186
+
187
+ if done:
188
+ # Determine grader score from reward
189
+ last_reward = step_result.reward.value
190
+ if last_reward >= 4.9:
191
+ final_score = 1.0
192
+ elif last_reward >= 0.9:
193
+ final_score = 0.5
194
+ else:
195
+ final_score = 0.0
196
+ print(f" β†’ DONE | grader_score={final_score:.1f}")
197
+ break
198
+
199
+ if not done:
200
+ print(f" β†’ MAX STEPS reached without submission. Score=0.0")
201
+
202
+ return {
203
+ "episode": episode_num,
204
+ "seed": seed,
205
+ "contract": obs["contract_name"],
206
+ "steps": steps,
207
+ "cumulative_reward": final_reward,
208
+ "grader_score": final_score,
209
+ "done": done,
210
+ }
211
+
212
+
213
+ def run_task1(num_episodes: int = NUM_EPISODES) -> Dict[str, Any]:
214
+ """Run Task 1 and return aggregate scores."""
215
+ print("\n" + "="*60)
216
+ print("TASK 1: Targeted Vulnerability Detection")
217
+ print("="*60)
218
+
219
+ env = Task1Environment()
220
+ episodes = []
221
+
222
+ for i in range(num_episodes):
223
+ seed = TASK1_SEED_BASE + i
224
+ result = run_episode(env, seed=seed, episode_num=i + 1)
225
+ episodes.append(result)
226
+ time.sleep(0.5) # Rate limit courtesy
227
+
228
+ scores = [e["grader_score"] for e in episodes]
229
+ avg = sum(scores) / len(scores) if scores else 0.0
230
+ avg_reward = sum(e["cumulative_reward"] for e in episodes) / len(episodes)
231
+
232
+ print(f"\n Task 1 Results:")
233
+ print(f" Episodes: {num_episodes}")
234
+ print(f" Grader scores: {scores}")
235
+ print(f" Average grader score: {avg:.3f}")
236
+ print(f" Average cumulative reward: {avg_reward:.2f}")
237
+
238
+ return {
239
+ "task_id": "task1_vuln_detection",
240
+ "name": "Targeted Vulnerability Detection",
241
+ "status": "active",
242
+ "num_episodes": num_episodes,
243
+ "episodes": episodes,
244
+ "avg_grader_score": avg,
245
+ "avg_cumulative_reward": avg_reward,
246
+ }
247
+
248
+
249
+ def run_task2_placeholder() -> Dict[str, Any]:
250
+ """Task 2 placeholder β€” returns 0.0 score."""
251
+ print("\n" + "="*60)
252
+ print("TASK 2: Property Discovery [PLACEHOLDER β€” not implemented]")
253
+ print("="*60)
254
+ print(" Skipping. Score: 0.0")
255
+ return {
256
+ "task_id": "task2_property_discovery",
257
+ "name": "Property Discovery",
258
+ "status": "placeholder",
259
+ "num_episodes": 0,
260
+ "episodes": [],
261
+ "avg_grader_score": 0.0,
262
+ "avg_cumulative_reward": 0.0,
263
+ }
264
+
265
+
266
+ def run_task3_placeholder() -> Dict[str, Any]:
267
+ """Task 3 placeholder β€” returns 0.0 score."""
268
+ print("\n" + "="*60)
269
+ print("TASK 3: Rule Checker [PLACEHOLDER β€” not implemented]")
270
+ print("="*60)
271
+ print(" Skipping. Score: 0.0")
272
+ return {
273
+ "task_id": "task3_rule_checker",
274
+ "name": "Rule Checker",
275
+ "status": "placeholder",
276
+ "num_episodes": 0,
277
+ "episodes": [],
278
+ "avg_grader_score": 0.0,
279
+ "avg_cumulative_reward": 0.0,
280
+ }
281
+
282
+
283
+ # ---------------------------------------------------------------------------
284
+ # Main
285
+ # ---------------------------------------------------------------------------
286
+
287
+ def main():
288
+ print("Smart Contract Audit RL Environment β€” Baseline Inference")
289
+ print(f"Model: {MODEL_NAME} | Base URL: {API_BASE_URL}")
290
+
291
+ results = {
292
+ "model": MODEL_NAME,
293
+ "base_url": API_BASE_URL,
294
+ "tasks": [],
295
+ }
296
+
297
+ t1 = run_task1(num_episodes=NUM_EPISODES)
298
+ t2 = run_task2_placeholder()
299
+ t3 = run_task3_placeholder()
300
+
301
+ results["tasks"] = [t1, t2, t3]
302
+
303
+ # Summary
304
+ active_tasks = [t for t in results["tasks"] if t["status"] == "active"]
305
+ overall = (
306
+ sum(t["avg_grader_score"] for t in active_tasks) / len(active_tasks)
307
+ if active_tasks else 0.0
308
+ )
309
+ results["overall_avg_score"] = overall
310
+
311
+ print("\n" + "="*60)
312
+ print("BASELINE SUMMARY")
313
+ print("="*60)
314
+ for t in results["tasks"]:
315
+ status = "βœ…" if t["status"] == "active" else "⏳"
316
+ print(f" {status} {t['name']}: {t['avg_grader_score']:.3f}")
317
+ print(f" Overall (active tasks): {overall:.3f}")
318
+
319
+ # Write scores file
320
+ with open("baseline_scores.json", "w") as f:
321
+ json.dump(results, f, indent=2)
322
+ print("\n Scores written to baseline_scores.json")
323
+
324
+
325
+ if __name__ == "__main__":
326
+ main()
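The agent loop above falls back to `list_functions` whenever the model reply is not bare JSON. A slightly more forgiving parser (a sketch, not part of the script; `extract_action` is a hypothetical helper) first pulls out the first `{...}` span, which tolerates replies wrapped in prose or code fences:

```python
import json
import re

def extract_action(raw: str) -> tuple:
    # Grab the first {...} span so fenced or prose-wrapped replies still parse.
    match = re.search(r"\{.*\}", raw, re.DOTALL)
    if match:
        try:
            parsed = json.loads(match.group(0))
            return parsed["action"], parsed.get("params", {})
        except (json.JSONDecodeError, KeyError, TypeError):
            pass
    return "list_functions", {}  # same safe default as run_episode

print(extract_action('```json\n{"action": "list_functions", "params": {}}\n```'))
```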
openenv.yaml ADDED
@@ -0,0 +1,169 @@
1
+ name: smart-contract-audit-env
2
+ version: "1.0.0"
3
+ description: >
4
+ Reinforcement learning environment for smart contract security analysis.
5
+ Agents interact with real-world Solidity contract data from Certora-audited
6
+ projects, learning to detect vulnerabilities, discover properties, and
7
+ verify rule compliance — tasks that professional auditors perform daily.
8
+
9
+ author: "SmartAudit Team"
10
+ license: MIT
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Tasks
14
+ # ---------------------------------------------------------------------------
15
+ tasks:
16
+ - id: task1_vuln_detection
17
+ name: Targeted Vulnerability Detection
18
+ difficulty: medium
19
+ status: active
20
+ description: >
21
+ Given a Solidity contract (4–6 functions), identify the single vulnerable
22
+ function and describe its vulnerability type in 2–3 words.
23
+ max_steps: 20
24
+ reward_range: [-10.0, 10.0]
25
+ grader: tasks/task1/grader.py
26
+ grader_score_range: [0.0, 1.0]
27
+
28
+ - id: task2_property_discovery
29
+ name: Property Discovery
30
+ difficulty: hard
31
+ status: placeholder
32
+ description: >
33
+ Given a single Solidity function with known properties, discover the
34
+ correct natural-language property describing its expected behaviour.
35
+ max_steps: 15
36
+ reward_range: [-5.0, 5.0]
37
+ grader: tasks/task2/grader.py # TODO: implement
38
+ grader_score_range: [0.0, 1.0]
39
+
40
+ - id: task3_rule_checker
41
+ name: Rule Checker
42
+ difficulty: easy
43
+ status: placeholder
44
+ description: >
45
+ Given a natural-language property and a Solidity file, identify the
46
+ function that violates that property.
47
+ max_steps: 15
48
+ reward_range: [-5.0, 5.0]
49
+ grader: tasks/task3/grader.py # TODO: implement
50
+ grader_score_range: [0.0, 1.0]
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Observation space
54
+ # ---------------------------------------------------------------------------
55
+ observation_space:
56
+ type: object
57
+ properties:
58
+ task_id:
59
+ type: string
60
+ description: Active task identifier
61
+ contract_name:
62
+ type: string
63
+ description: Name of the Solidity contract
64
+ contract_description:
65
+ type: string
66
+ description: Human-readable description of what the contract does
67
+ available_actions:
68
+ type: array
69
+ items:
70
+ type: string
71
+ description: List of valid action type strings
72
+ last_action:
73
+ type: string
74
+ nullable: true
75
+ description: The action type that produced this observation
76
+ last_action_result:
77
+ type: string
78
+ nullable: true
79
+ description: Human-readable result of the last action
80
+ step_count:
81
+ type: integer
82
+ description: Number of steps taken in this episode
83
+ cumulative_reward:
84
+ type: number
85
+ description: Running reward total for this episode
86
+ done:
87
+ type: boolean
88
+ description: True when the episode has ended
89
+ extra:
90
+ type: object
91
+ description: Task-specific hints and auxiliary data
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Action space (Task 1)
95
+ # ---------------------------------------------------------------------------
96
+ action_space:
97
+ type: object
98
+ description: Named action with optional parameters
99
+ properties:
100
+ action_type:
101
+ type: string
102
+ enum:
103
+ - list_functions
104
+ - get_function_code
105
+ - get_function_summary
106
+ - get_file_metadata
107
+ - get_state_variable
108
+ - get_call_graph
109
+ - submit
110
+ params:
111
+ type: object
112
+ description: Key-value arguments for the action
113
+
114
+ # ---------------------------------------------------------------------------
115
+ # Reward function
116
+ # ---------------------------------------------------------------------------
117
+ reward:
118
+ type: shaped
119
+ description: >
120
+ Per-step costs encourage efficient exploration. A positive signal is given
121
+ when the agent accesses the actual vulnerable function. Terminal rewards
122
+ reflect submission accuracy (0 → 1 grader score).
123
+ shaping:
124
+ list_functions: -0.05
125
+ get_function_code_wrong: -0.10
126
+ get_function_code_correct: +0.05
127
+ get_function_summary_wrong: -0.05
128
+ get_function_summary_correct: +0.03
129
+ get_file_metadata: -0.04
130
+ get_state_variable: -0.05
131
+ get_call_graph: -0.08
132
+ repeated_query: -0.40
133
+ terminal:
134
+ correct_submission: +5.0
135
+ partial_submission: +1.0
136
+ wrong_submission: -1.5
137
+
138
+ # ---------------------------------------------------------------------------
139
+ # Data
140
+ # ---------------------------------------------------------------------------
141
+ data:
142
+ source: "Certora audited projects (Aave, Compound-style protocols)"
143
+ format: JSON
144
+ num_contracts: 4
145
+ num_vulnerable_functions: 8
146
+ vulnerability_types:
147
+ - Reentrancy
148
+ - Missing access control
149
+ - Integer overflow
150
+ - tx.origin authentication
151
+ - Front-running
152
+ - Timestamp dependence
153
+ - Denial of service (unbounded loop)
154
+ - Unchecked return value
155
+
156
+ # ---------------------------------------------------------------------------
157
+ # Interface
158
+ # ---------------------------------------------------------------------------
159
+ interface:
160
+ http:
161
+ reset: POST /reset
162
+ step: POST /step
163
+ state: GET /state
164
+ tasks: GET /tasks
165
+ health: GET /health
166
+ python:
167
+ reset: env.reset(seed=None) -> ResetResult
168
+ step: env.step(action) -> StepResult
169
+ state: env.state() -> StateResult
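As a worked example of the shaped reward above (shaping and terminal values copied from the `reward` section; the trajectory itself is hypothetical), an agent that lists functions, reads metadata, opens one wrong function and then the vulnerable one pays only a small exploration cost before the terminal bonus:

```python
# Shaping values from the reward table; the trajectory is illustrative.
SHAPING = {
    "list_functions": -0.05,
    "get_file_metadata": -0.04,
    "get_function_code_wrong": -0.10,
    "get_function_code_correct": +0.05,
}
TERMINAL = {"correct_submission": 5.0, "partial_submission": 1.0, "wrong_submission": -1.5}

trajectory = ["list_functions", "get_file_metadata",
              "get_function_code_wrong", "get_function_code_correct"]
exploration = sum(SHAPING[a] for a in trajectory)
total = exploration + TERMINAL["correct_submission"]
print(round(exploration, 2))  # -0.14
print(round(total, 2))        # 4.86
```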
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ fastapi==0.115.0
2
+ uvicorn[standard]==0.30.6
3
+ pydantic==2.8.2
4
+ openai==1.51.0
5
+ httpx==0.27.2
6
+ python-multipart==0.0.9
7
+ pyyaml==6.0.2
tasks/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # tasks package
tasks/task1/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # task1 package
2
+ from tasks.task1.environment import Task1Environment
3
+ from tasks.task1.grader import Task1Grader
4
+
5
+ __all__ = ["Task1Environment", "Task1Grader"]
tasks/task1/environment.py ADDED
@@ -0,0 +1,329 @@
1
+ """
2
+ environment.py (Task 1 – Targeted Vulnerability Detection)
3
+ ------------------------------------------------------------
4
+ Full OpenEnv-compliant environment.
5
+
6
+ Episode flow:
7
+ 1. reset() selects a random (contract, vulnerable_function) pair.
8
+ 2. The agent receives an Observation with the contract description.
9
+ 3. The agent uses actions to explore the contract (each costs a small penalty).
10
+ 4. When the agent submits, the Grader scores the answer and the episode ends.
11
+
12
+ Reward shaping:
13
+ list_functions : -0.05
14
+ get_function_code : -0.10 (wrong function) / +0.05 (correct function)
15
+ get_function_summary : -0.05 (wrong function) / +0.03 (correct function)
16
+ get_file_metadata : -0.04
17
+ get_state_variable : -0.05
18
+ get_call_graph : -0.08
19
+ submit (score=1.0) : +5.0
20
+ submit (score=0.5) : +1.0
21
+ submit (score=0.0) : -1.5
22
+ repeated query : -0.40
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import random
28
+ from typing import Any, Dict, List, Optional, Set
29
+
30
+ from data.data_loader import (
31
+ load_contracts,
32
+ sample_episode,
33
+ get_function_by_name,
34
+ get_state_variable_by_name,
35
+ list_function_names,
36
+ list_state_variable_names,
37
+ )
38
+ from env.base_env import BaseEnv
39
+ from env.schemas import (
40
+ Action,
41
+ ActionType,
42
+ Observation,
43
+ Reward,
44
+ ResetResult,
45
+ StateResult,
46
+ StepResult,
47
+ )
48
+ from tasks.task1.grader import Task1Grader
49
+
50
+ TASK_ID = "task1_vuln_detection"
51
+
52
+ AVAILABLE_ACTIONS = [
53
+ ActionType.LIST_FUNCTIONS,
54
+ ActionType.GET_FUNCTION_CODE,
55
+ ActionType.GET_FUNCTION_SUMMARY,
56
+ ActionType.GET_FILE_METADATA,
57
+ ActionType.GET_STATE_VARIABLE,
58
+ ActionType.GET_CALL_GRAPH,
59
+ ActionType.SUBMIT,
60
+ ]
61
+
62
+
63
+ class Task1Environment(BaseEnv):
64
+ """Task 1: Targeted Vulnerability Detection."""
65
+
66
+ def __init__(self, contracts_path: Optional[str] = None) -> None:
67
+ self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
68
+ self._rng = random.Random()
69
+
70
+ # Episode state (initialised by reset)
71
+ self._contract: Dict[str, Any] = {}
72
+ self._target_fn: Dict[str, Any] = {}
73
+ self._grader: Optional[Task1Grader] = None
74
+ self._step_count: int = 0
75
+ self._cumulative_reward: float = 0.0
76
+ self._done: bool = False
77
+ self._query_history: List[str] = []
78
+ self._seen_queries: Set[str] = set()
79
+
80
+ # ------------------------------------------------------------------
81
+ # OpenEnv interface
82
+ # ------------------------------------------------------------------
83
+
84
+ def reset(self, seed: Optional[int] = None) -> ResetResult:
85
+ """Start a new episode by sampling a random vulnerable function."""
86
+ if seed is not None:
87
+ self._rng.seed(seed)
88
+
89
+ self._contract, self._target_fn = sample_episode(self._contracts, self._rng)
90
+ self._grader = Task1Grader(
91
+ target_function=self._target_fn["name"],
92
+ vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
93
+ )
94
+ self._step_count = 0
95
+ self._cumulative_reward = 0.0
96
+ self._done = False
97
+ self._query_history = []
98
+ self._seen_queries = set()
99
+
100
+ obs = self._build_observation(
101
+ last_action=None,
102
+ last_result=(
103
+ f"New episode started. Contract: {self._contract['contract_name']}. "
104
+ f"Use 'list_functions' to explore the contract."
105
+ ),
106
+ )
107
+ return ResetResult(observation=obs, info={"task_id": TASK_ID})
108
+
109
+ def step(self, action: Action) -> StepResult:
110
+ """Execute one agent action."""
111
+ if self._done:
112
+ raise RuntimeError("Episode is done. Call reset() to start a new episode.")
113
+
114
+ self._step_count += 1
115
+
116
+ # Dispatch
117
+ result_text, reward = self._dispatch(action)
118
+
119
+ self._cumulative_reward += reward.value
120
+ self._query_history.append(f"[{action.action_type}] → {result_text[:120]}")
121
+
122
+ obs = self._build_observation(
123
+ last_action=action.action_type,
124
+ last_result=result_text,
125
+ )
126
+ return StepResult(
127
+ observation=obs,
128
+ reward=reward,
129
+ done=self._done,
130
+ info={
131
+ "step": self._step_count,
132
+ "cumulative_reward": self._cumulative_reward,
133
+ },
134
+ )
135
+
136
+ def state(self) -> StateResult:
137
+ return StateResult(
138
+ task_id=TASK_ID,
139
+ contract_name=self._contract.get("contract_name", ""),
140
+ target_function=self._target_fn.get("name"),
141
+ step_count=self._step_count,
142
+ cumulative_reward=self._cumulative_reward,
143
+ done=self._done,
144
+ query_history=list(self._query_history),
145
+ )
146
+
147
+ # ------------------------------------------------------------------
148
+ # Internal helpers
149
+ # ------------------------------------------------------------------
150
+
151
+ def _build_observation(
152
+ self,
153
+ last_action: Optional[str],
154
+ last_result: str,
155
+ ) -> Observation:
156
+ return Observation(
157
+ task_id=TASK_ID,
158
+ contract_name=self._contract.get("contract_name", ""),
159
+ contract_description=self._contract.get("metadata", {}).get("description", ""),
160
+ available_actions=[a.value for a in AVAILABLE_ACTIONS],
161
+ last_action=last_action,
162
+ last_action_result=last_result,
163
+ step_count=self._step_count,
164
+ cumulative_reward=self._cumulative_reward,
165
+ done=self._done,
166
+ extra={
167
+ "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
168
+ "hint": (
169
+ "Identify the vulnerable function and its issue. "
170
+ "Submit with action_type='submit', params={'function_name': '...', "
171
+ "'vulnerability_type': '...'}"
172
+ ),
173
+ },
174
+ )
175
+
176
+ def _query_key(self, action_type: str, params: Dict[str, Any]) -> str:
177
+ """Build a hashable key for repeated-query detection."""
178
+ return f"{action_type}:{sorted(params.items())}"
179
+
180
+ def _is_repeated(self, key: str) -> bool:
181
+ if key in self._seen_queries:
182
+ return True
183
+ self._seen_queries.add(key)
184
+ return False
185
+
186
+ def _dispatch(self, action: Action) -> tuple[str, Reward]:
187
+ at = action.action_type
188
+ params = action.params
189
+ qkey = self._query_key(at, params)
190
+
191
+ # ---- list_functions ----------------------------------------
192
+ if at == ActionType.LIST_FUNCTIONS:
193
+ if self._is_repeated(qkey):
194
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
195
+ names = list_function_names(self._contract)
196
+ return (
197
+ f"Functions in {self._contract['contract_name']}: {', '.join(names)}",
198
+ Reward(value=-0.05, reason="list_functions cost", partial=True),
199
+ )
200
+
201
+ # ---- get_function_code -------------------------------------
202
+ if at == ActionType.GET_FUNCTION_CODE:
203
+ fn_name = params.get("function_name", "")
204
+ if self._is_repeated(qkey):
205
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
206
+ fn = get_function_by_name(self._contract, fn_name)
207
+ if fn is None:
208
+ return (
209
+ f"Function '{fn_name}' not found. Available: {list_function_names(self._contract)}",
210
+ Reward(value=-0.10, reason="Wrong/unknown function name", partial=True),
211
+ )
212
+ is_target = fn["name"].lower() == self._target_fn["name"].lower()
213
+ code = fn.get("code", "// no code available")
214
+ reward_val = 0.05 if is_target else -0.10
215
+ reason = "Fetched target function code (+)" if is_target else "Fetched non-target function (-)"
216
+ return (
217
+ f"// {fn['name']}\n{code}",
218
+ Reward(value=reward_val, reason=reason, partial=True),
219
+ )
220
+
221
+ # ---- get_function_summary ----------------------------------
222
+ if at == ActionType.GET_FUNCTION_SUMMARY:
223
+ fn_name = params.get("function_name", "")
224
+ if self._is_repeated(qkey):
225
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
226
+ fn = get_function_by_name(self._contract, fn_name)
227
+ if fn is None:
228
+ return (
229
+ f"Function '{fn_name}' not found.",
230
+ Reward(value=-0.05, reason="Wrong function name", partial=True),
231
+ )
232
+ is_target = fn["name"].lower() == self._target_fn["name"].lower()
233
+ comment = fn.get("comment", "No summary available.")
234
+ reward_val = 0.03 if is_target else -0.05
235
+ reason = "Fetched target function summary (+)" if is_target else "Fetched non-target summary (-)"
236
+ return (
237
+ f"Summary of '{fn['name']}': {comment}",
238
+ Reward(value=reward_val, reason=reason, partial=True),
239
+ )
240
+
241
+ # ---- get_file_metadata -------------------------------------
242
+ if at == ActionType.GET_FILE_METADATA:
243
+ if self._is_repeated(qkey):
244
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
245
+ meta = self._contract.get("metadata", {})
246
+ result = (
247
+ f"Contract: {self._contract['contract_name']} | "
248
+ f"File: {self._contract.get('file_name', 'N/A')} | "
249
+ f"Solidity: {meta.get('solidity_version', 'N/A')} | "
250
+ f"License: {meta.get('license', 'N/A')} | "
251
+ f"Author: {meta.get('author', 'N/A')} | "
252
+ f"Description: {meta.get('description', 'N/A')}"
253
+ )
254
+ return result, Reward(value=-0.04, reason="get_file_metadata cost", partial=True)
255
+
256
+ # ---- get_state_variable ------------------------------------
257
+ if at == ActionType.GET_STATE_VARIABLE:
258
+ var_name = params.get("variable_name", "")
259
+ if self._is_repeated(qkey):
260
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
261
+ if not var_name:
262
+ # Return list of all state variables
263
+ names = list_state_variable_names(self._contract)
264
+ return (
265
+ f"State variables: {', '.join(names)}",
266
+ Reward(value=-0.05, reason="Listed state variables", partial=True),
267
+ )
268
+ sv = get_state_variable_by_name(self._contract, var_name)
269
+ if sv is None:
270
+ return (
271
+ f"Variable '{var_name}' not found.",
272
+ Reward(value=-0.05, reason="Unknown state variable", partial=True),
273
+ )
274
+ return (
275
+ f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
276
+ Reward(value=-0.05, reason="get_state_variable cost", partial=True),
277
+ )
278
+
279
+ # ---- get_call_graph ----------------------------------------
280
+ if at == ActionType.GET_CALL_GRAPH:
281
+ if self._is_repeated(qkey):
282
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
283
+ cg = self._contract.get("call_graph", {})
284
+ cg_str = "; ".join(f"{fn} → [{', '.join(callees)}]" for fn, callees in cg.items())
285
+ return (
286
+ f"Call graph: {cg_str}",
287
+ Reward(value=-0.08, reason="get_call_graph cost", partial=True),
288
+ )
289
+
290
+ # ---- submit ------------------------------------------------
291
+ if at == ActionType.SUBMIT:
292
+ fn_name = params.get("function_name", "")
293
+ vuln_type = params.get("vulnerability_type", "")
294
+ if not fn_name or not vuln_type:
295
+ return (
296
+ "Submit requires 'function_name' and 'vulnerability_type' in params.",
297
+ Reward(value=-0.5, reason="Malformed submission", partial=True),
298
+ )
299
+ score = self._grader.grade_submission(fn_name, vuln_type)
300
+ reward_val = self._grader.reward_for_score(score)
301
+ self._done = True
302
+
303
+ if score == 1.0:
304
+ msg = (
305
+ f"βœ… CORRECT! '{fn_name}' is the vulnerable function. "
306
+ f"Vulnerability type '{vuln_type}' matches. Score: 1.0"
307
+ )
308
+ elif score == 0.5:
309
+ msg = (
310
+ f"⚠️ PARTIAL. '{fn_name}' is the right function, but the vulnerability type "
311
+ f"'{vuln_type}' was not precise. Score: 0.5"
312
+ )
313
+ else:
314
+ correct = self._grader.get_canonical_answer()
315
+ msg = (
316
+ f"❌ INCORRECT. '{fn_name}' is not the target vulnerable function. "
317
+ f"Correct answer: {correct['function']} ({correct['vulnerability']}). Score: 0.0"
318
+ )
319
+ return msg, Reward(
320
+ value=reward_val,
321
+ reason=f"Submission score={score:.1f}",
322
+ partial=False,
323
+ )
324
+
325
+ # ---- unknown action ----------------------------------------
326
+ return (
327
+ f"Unknown action type: {at}",
328
+ Reward(value=-0.10, reason="Unknown action", partial=True),
329
+ )
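The `_query_key` / `_is_repeated` pair in the environment can be illustrated in isolation (a standalone sketch; `RepeatDetector` is a name invented here): keys combine the action type with the sorted params, so the same lookup with params in a different order still counts as a repeat.

```python
class RepeatDetector:
    """Standalone version of the repeated-query check in Task1Environment."""

    def __init__(self) -> None:
        self._seen = set()

    def key(self, action_type: str, params: dict) -> str:
        # Sorting the items makes the key insensitive to param ordering.
        return f"{action_type}:{sorted(params.items())}"

    def is_repeated(self, action_type: str, params: dict) -> bool:
        k = self.key(action_type, params)
        if k in self._seen:
            return True
        self._seen.add(k)
        return False

d = RepeatDetector()
print(d.is_repeated("get_function_code", {"function_name": "withdraw"}))  # False
print(d.is_repeated("get_function_code", {"function_name": "withdraw"}))  # True
```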
tasks/task1/grader.py ADDED
@@ -0,0 +1,98 @@
+ """
+ grader.py (Task 1 – Targeted Vulnerability Detection)
+ -------------------------------------------------------
+ Deterministic grader. Score range: 0.0 – 1.0
+
+ 1.0 – correct function + correct vulnerability keyword
+ 0.5 – correct function + wrong/unrecognised vulnerability keyword
+ 0.0 – wrong function name
+ """
+ from __future__ import annotations
+ from typing import Dict, List, Optional
+
+ VULN_KEYWORDS: Dict[str, List[str]] = {
+ "reentrancy": [
+ "reentrancy", "re-entrancy", "reentrant", "re entrant",
+ "recursive call", "reentr",
+ ],
+ "missing access control": [
+ "access control", "missing access", "no access", "unauthorized",
+ "privilege", "permission", "onlyowner", "only owner",
+ "no modifier", "missing modifier", "no check", "anyone can call",
+ ],
+ "integer overflow": [
+ "overflow", "integer overflow", "arithmetic overflow",
+ "safemath", "safe math", "uint overflow", "wraparound",
+ "integer underflow", "underflow",
+ ],
+ "tx.origin authentication": [
+ "tx.origin", "txorigin", "tx origin", "phishing",
+ "origin authentication", "origin auth",
+ ],
+ "front-running": [
+ "front-running", "frontrunning", "front running", "mev",
+ "sandwich", "mempool", "commit reveal", "commit-reveal",
+ "gas price manipulation",
+ ],
+ "timestamp dependence": [
+ "timestamp", "block.timestamp", "time manipulation",
+ "miner timestamp", "time dependency", "timestamp dependence",
+ ],
+ "denial of service": [
+ "denial of service", " dos", "gas limit", "unbounded loop",
+ "block gas", " oog", "out of gas", "infinite loop", "unbounded array",
+ "gas exhaustion",
+ ],
+ "unchecked return value": [
+ "unchecked return", "return value", "unchecked transfer",
+ "silent failure", "safeerc20", "safe transfer", "ignored return",
+ "erc20 return",
+ ],
+ }
+
+
+ def _norm(text: str) -> str:
+ return text.strip().lower()
+
+
+ def _find_bucket(ground_truth_issue: str) -> Optional[str]:
+ """
+ Longest-match keyword search to identify canonical vulnerability bucket.
+ Longest match avoids short-keyword collisions (e.g. 'auth' in 'tx.origin authentication').
+ """
+ norm_gt = _norm(ground_truth_issue)
+ best: Optional[str] = None
+ best_len: int = 0
+ for canonical, keywords in VULN_KEYWORDS.items():
+ for kw in keywords:
+ if kw in norm_gt and len(kw) > best_len:
+ best_len = len(kw)
+ best = canonical
+ return best
+
+
+ def match_vuln_keyword(submitted: str, ground_truth_issue: str) -> bool:
+ bucket = _find_bucket(ground_truth_issue)
+ if bucket is None:
+ return _norm(submitted) in _norm(ground_truth_issue)
+ norm_sub = _norm(submitted)
+ return any(kw in norm_sub for kw in VULN_KEYWORDS[bucket])
+
+
+ class Task1Grader:
+ def __init__(self, target_function: str, vulnerability_issue: str) -> None:
+ self.target_function = target_function.lower()
+ self.vulnerability_issue = vulnerability_issue
+
+ def grade_submission(self, submitted_function: str, submitted_vuln_type: str) -> float:
+ if submitted_function.strip().lower() != self.target_function:
+ return 0.0
+ return 1.0 if match_vuln_keyword(submitted_vuln_type, self.vulnerability_issue) else 0.5
+
+ def reward_for_score(self, score: float) -> float:
+ if score == 1.0:
+ return 5.0
+ if score == 0.5:
+ return 1.0
+ return -1.5
+
+ def get_canonical_answer(self) -> Dict[str, str]:
+ return {"function": self.target_function, "vulnerability": self.vulnerability_issue}
tasks/task2/__init__.py ADDED
@@ -0,0 +1,27 @@
+ """
+ tasks/task2/__init__.py
+ -----------------------
+ Task 2: Property Discovery (PLACEHOLDER)
+
+ TODO: Implement this task.
+
+ Episode setup:
+ - One function from a Solidity file with known properties
+ - Agent must discover the natural-language property of the function
+
+ Actions (to implement):
+ - get_similar_rule : -0.20
+ - get_file_natspec : -0.03
+ - get_function_natspec : -0.08
+ - get_function_code : -0.06
+ - get_related_functions : -0.06
+ - get_io : -0.04
+ - submit_property : scored 0.0–5.0 by semantic similarity grader
+
+ See README.md for full task specification.
+ """
+
+ # TODO: Task 2 – Property Discovery
+ # from tasks.task2.environment import Task2Environment
+
+ __all__: list = []
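The per-action costs in the docstring could be encoded as a simple lookup table once Task 2 is implemented. A hypothetical sketch (the `TASK2_ACTION_COSTS` name and the -0.10 unknown-action fallback are assumptions; only the action names and values come from the docstring above):

```python
# Hypothetical cost table for the planned Task 2 actions; values are the
# shaping rewards listed in the placeholder docstring above.
TASK2_ACTION_COSTS = {
    "get_similar_rule": -0.20,
    "get_file_natspec": -0.03,
    "get_function_natspec": -0.08,
    "get_function_code": -0.06,
    "get_related_functions": -0.06,
    "get_io": -0.04,
}


def shaping_reward(action: str) -> float:
    # Unknown actions fall back to a penalty, mirroring Task 1's
    # unknown-action handling (assumed -0.10 here).
    return TASK2_ACTION_COSTS.get(action, -0.10)
```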
tasks/task3/__init__.py ADDED
@@ -0,0 +1,31 @@
+ """
+ tasks/task3/__init__.py
+ -----------------------
+ Task 3: Rule Checker (PLACEHOLDER)
+
+ TODO: Implement this task.
+
+ Episode setup:
+ - One Solidity file with at least one function breaking a given property
+ - Agent is shown the property in natural English
+
+ Actions (to implement):
+ - get_formalized_property : -0.03
+ - list_functions : -0.05
+ - get_function_metadata : -0.05
+ - get_function_code : -0.10
+ - get_state_variables : -0.05
+ - get_call_graph : -0.08
+ - submit_function :
+ - correct = +5.0
+ - subfunction of target = +1.5
+ - wrong = -1.5
+ (ONE submission per episode)
+
+ See README.md for full task specification.
+ """
+
+ # TODO: Task 3 – Rule Checker
+ # from tasks.task3.environment import Task3Environment
+
+ __all__: list = []
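The three `submit_function` reward tiers could be implemented against the contract's call graph: a callee of the target counts as a "subfunction". A hypothetical sketch (the `grade_task3` name and the one-level callee lookup are assumptions; the tier values come from the docstring above):

```python
# Hypothetical grader for Task 3's single submit_function action,
# implementing the three reward tiers from the placeholder docstring:
#   exact match            -> +5.0
#   callee of the target   -> +1.5 (treated as "subfunction of target")
#   anything else          -> -1.5
def grade_task3(submitted: str, target: str, call_graph: dict) -> float:
    sub, tgt = submitted.strip().lower(), target.lower()
    if sub == tgt:
        return 5.0
    if sub in (c.lower() for c in call_graph.get(target, [])):
        return 1.5
    return -1.5


cg = {"withdraw": ["_sendFunds", "_updateBalance"]}
```

A deeper design question for the real implementation is whether "subfunction" should mean direct callees only (as sketched) or the transitive closure of the call graph.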
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # utils package
validate.py ADDED
@@ -0,0 +1,290 @@
+ """
+ validate.py
+ -----------
+ Pre-submission validation script.
+ Checks all OpenEnv spec requirements locally before submitting.
+
+ Usage:
+ python validate.py
+
+ Exit code 0 = all checks pass.
+ Exit code 1 = one or more checks failed.
+ """
+
+ import json
+ import sys
+ from typing import Callable, List, Tuple
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Helpers
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ PASS = "✅"
+ FAIL = "❌"
+ SKIP = "⏭ "
+ results: List[Tuple[str, bool, str]] = []
+
+
+ def check(name: str, fn: Callable[[], None]) -> None:
+ try:
+ fn()
+ results.append((name, True, ""))
+ print(f" {PASS} {name}")
+ except Exception as e:
+ results.append((name, False, str(e)))
+ print(f" {FAIL} {name}")
+ print(f" {e}")
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Checks
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ def check_imports():
+ from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult
+ from tasks.task1.environment import Task1Environment
+ from tasks.task1.grader import Task1Grader
+ from data.data_loader import load_contracts
+
+
+ def check_openenv_yaml():
+ import yaml
+ with open("openenv.yaml") as f:
+ spec = yaml.safe_load(f)
+ assert "name" in spec
+ assert "tasks" in spec
+ assert len(spec["tasks"]) >= 3, "Need at least 3 tasks defined"
+ assert "observation_space" in spec
+ assert "action_space" in spec
+ assert "reward" in spec
+
+
+ def check_pydantic_models():
+ from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult, StateResult
+ # Instantiate each model
+ obs = Observation(
+ task_id="t1", contract_name="C", contract_description="D",
+ available_actions=["submit"]
+ )
+ assert obs.task_id == "t1"
+
+ action = Action(action_type=ActionType.LIST_FUNCTIONS)
+ assert action.action_type == ActionType.LIST_FUNCTIONS
+
+ reward = Reward(value=1.0, reason="test")
+ assert reward.value == 1.0
+
+ step = StepResult(observation=obs, reward=reward, done=False)
+ assert not step.done
+
+ reset = ResetResult(observation=obs)
+ assert reset.observation.task_id == "t1"
+
+ state = StateResult(task_id="t1", contract_name="C", step_count=0,
+ cumulative_reward=0.0, done=False)
+ assert state.step_count == 0
+
+
+ def check_data_loading():
+ from data.data_loader import load_contracts, get_all_vulnerable_entries
+ contracts = load_contracts()
+ assert len(contracts) >= 1, "No contracts loaded"
+ entries = get_all_vulnerable_entries(contracts)
+ assert len(entries) >= 3, f"Need >= 3 vulnerable functions, got {len(entries)}"
+ for contract, fn in entries:
+ assert fn.get("vulnerable") is True
+ assert fn.get("vulnerability_details") is not None
+ assert "issue" in fn["vulnerability_details"]
+
+
+ def check_env_reset():
+ from tasks.task1.environment import Task1Environment
+ env = Task1Environment()
+ result = env.reset(seed=42)
+ assert result.observation is not None
+ assert result.observation.task_id == "task1_vuln_detection"
+ assert result.observation.contract_name != ""
+ assert not result.observation.done
+ assert result.observation.step_count == 0
+
+
+ def check_env_step():
+ from tasks.task1.environment import Task1Environment
+ from env.schemas import Action, ActionType
+ env = Task1Environment()
+ env.reset(seed=42)
+ result = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+ assert result.observation is not None
+ assert isinstance(result.reward.value, float)
+ assert isinstance(result.done, bool)
+ assert "info" in result.model_dump()
+
+
+ def check_env_state():
+ from tasks.task1.environment import Task1Environment
+ env = Task1Environment()
+ env.reset(seed=42)
+ state = env.state()
+ assert state.task_id == "task1_vuln_detection"
+ assert state.contract_name != ""
+ assert state.target_function is not None # exposed for debugging
+
+
+ def check_grader_scores_in_range():
+ from tasks.task1.grader import Task1Grader
+ cases = [
+ ("withdraw", "Reentrancy vulnerability", "withdraw", "reentrancy", 1.0),
+ ("withdraw", "Reentrancy vulnerability", "withdraw", "something else", 0.5),
+ ("withdraw", "Reentrancy vulnerability", "deposit", "reentrancy", 0.0),
+ ]
+ for tf, issue, sf, sv, expected in cases:
+ g = Task1Grader(tf, issue)
+ score = g.grade_submission(sf, sv)
+ assert 0.0 <= score <= 1.0, f"Score {score} out of range"
+ assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"
+
+
+ def check_grader_deterministic():
+ from tasks.task1.grader import Task1Grader
+ g = Task1Grader("withdraw", "Reentrancy vulnerability")
+ s1 = g.grade_submission("withdraw", "reentrancy")
+ s2 = g.grade_submission("withdraw", "reentrancy")
+ assert s1 == s2 == 1.0, "Grader must be deterministic"
+
+
+ def check_reward_shaping():
+ """Verify reward is non-binary (multiple distinct values across steps)."""
+ from tasks.task1.environment import Task1Environment
+ from env.schemas import Action, ActionType
+ env = Task1Environment()
+ env.reset(seed=1)
+ rewards = set()
+ for at in [ActionType.LIST_FUNCTIONS, ActionType.GET_FILE_METADATA, ActionType.GET_CALL_GRAPH]:
+ r = env.step(Action(action_type=at))
+ rewards.add(round(r.reward.value, 4))
+ # Should have at least 2 distinct shaping reward values
+ assert len(rewards) >= 2, f"Expected multiple reward values, got {rewards}"
+
+
+ def check_episode_boundary():
+ """Episode must end after submit and raise on subsequent step."""
+ from tasks.task1.environment import Task1Environment
+ from env.schemas import Action, ActionType
+ env = Task1Environment()
+ env.reset(seed=2)
+ env.step(Action(action_type=ActionType.SUBMIT, params={
+ "function_name": "withdraw", "vulnerability_type": "test"
+ }))
+ try:
+ env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+ raise AssertionError("Should have raised RuntimeError after episode end")
+ except RuntimeError:
+ pass # Expected
+
+
+ def check_repeated_query_penalty():
+ from tasks.task1.environment import Task1Environment
+ from env.schemas import Action, ActionType
+ env = Task1Environment()
+ env.reset(seed=3)
+ env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+ r = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+ assert r.reward.value == -0.40, f"Expected -0.40 for repeated query, got {r.reward.value}"
+
+
+ def check_tasks_list():
+ """All three tasks must be listed (even if placeholders)."""
+ from tasks.task2 import __all__ as t2 # noqa
+ from tasks.task3 import __all__ as t3 # noqa
+
+
+ def check_dockerfile_exists():
+ import os
+ assert os.path.exists("Dockerfile"), "Dockerfile is missing"
+ with open("Dockerfile") as f:
+ content = f.read()
+ assert "7860" in content, "Dockerfile must EXPOSE 7860 (HF Spaces)"
+ assert "uvicorn" in content or "CMD" in content
+
+
+ def check_inference_script():
+ import os
+ assert os.path.exists("inference.py"), "inference.py is missing"
+ with open("inference.py") as f:
+ content = f.read()
+ assert "OPENAI_API_KEY" in content or "HF_TOKEN" in content, \
+ "inference.py must read API credentials from env vars"
+ assert "API_BASE_URL" in content
+ assert "MODEL_NAME" in content
+
+
+ def check_baseline_json_schema():
+ """baseline_scores.json must have valid schema if it exists."""
+ import os
+ if not os.path.exists("baseline_scores.json"):
+ return # OK — file is generated at runtime
+ with open("baseline_scores.json") as f:
+ data = json.load(f)
+ assert "tasks" in data
+ for task in data["tasks"]:
+ score = task["avg_grader_score"]
+ assert 0.0 <= score <= 1.0, f"Score {score} out of range"
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Runner
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ def main():
+ print("=" * 60)
+ print("OpenEnv Pre-Submission Validation")
+ print("=" * 60)
+
+ all_checks = [
+ ("Python imports", check_imports),
+ ("openenv.yaml format", check_openenv_yaml),
+ ("Pydantic model types", check_pydantic_models),
+ ("Dataset loading (3+ vulns)", check_data_loading),
+ ("env.reset() → ResetResult", check_env_reset),
+ ("env.step() → StepResult", check_env_step),
+ ("env.state() → StateResult", check_env_state),
+ ("Grader scores in [0.0, 1.0]", check_grader_scores_in_range),
+ ("Grader is deterministic", check_grader_deterministic),
+ ("Reward shaping (non-binary)", check_reward_shaping),
+ ("Episode boundary (done=True)", check_episode_boundary),
+ ("Repeated query penalty", check_repeated_query_penalty),
+ ("Task 2 & 3 placeholders", check_tasks_list),
+ ("Dockerfile exists + port", check_dockerfile_exists),
+ ("inference.py exists + vars", check_inference_script),
+ ("baseline_scores.json schema", check_baseline_json_schema),
+ ]
+
+ print()
+ for name, fn in all_checks:
+ check(name, fn)
+
+ print()
+ passed = sum(1 for _, ok, _ in results if ok)
+ total = len(results)
+ failed = [(n, msg) for n, ok, msg in results if not ok]
+
+ print("=" * 60)
+ print(f"Results: {passed}/{total} checks passed")
+
+ if failed:
+ print("\nFailed checks:")
+ for name, msg in failed:
+ print(f" {FAIL} {name}: {msg}")
+ print()
+ print("❌ VALIDATION FAILED — fix the issues above before submitting.")
+ sys.exit(1)
+ else:
+ print()
+ print("✅ ALL CHECKS PASSED — ready to submit!")
+ sys.exit(0)
+
+
+ if __name__ == "__main__":
+ main()