Spaces:
Running
Running
ajaxwin commited on
Commit Β·
08c19c7
0
Parent(s):
Inital Commit
Browse files- .gitignore +13 -0
- Dockerfile +35 -0
- README.md +301 -0
- SPACES_README.md +57 -0
- app.py +265 -0
- data/Template.json +149 -0
- data/__init__.py +1 -0
- data/contracts.json +0 -0
- data/data_loader.py +84 -0
- demo.py +287 -0
- env/__init__.py +1 -0
- env/base_env.py +89 -0
- env/schemas.py +150 -0
- eval.py +290 -0
- inference.py +326 -0
- openenv.yaml +169 -0
- requirements.txt +7 -0
- tasks/__init__.py +1 -0
- tasks/task1/__init__.py +5 -0
- tasks/task1/environment.py +329 -0
- tasks/task1/grader.py +98 -0
- tasks/task2/__init__.py +27 -0
- tasks/task3/__init__.py +31 -0
- utils/__init__.py +1 -0
- validate.py +290 -0
.gitignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.env
|
| 5 |
+
.venv
|
| 6 |
+
venv/
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
.DS_Store
|
| 11 |
+
baseline_scores.json
|
| 12 |
+
*.log
|
| 13 |
+
.pytest_cache/
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---------------------------------------------------------------------------
|
| 2 |
+
# Smart Contract Audit RL Environment
|
| 3 |
+
# Hugging Face Space β Docker runtime
|
| 4 |
+
# ---------------------------------------------------------------------------
|
| 5 |
+
|
| 6 |
+
FROM python:3.11-slim
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
# System deps
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
curl \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# Install Python deps first (layer cache)
|
| 16 |
+
COPY requirements.txt .
|
| 17 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy project
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# Create empty __init__ files if missing (safety)
|
| 23 |
+
RUN touch env/__init__.py tasks/__init__.py tasks/task1/__init__.py \
|
| 24 |
+
tasks/task2/__init__.py tasks/task3/__init__.py \
|
| 25 |
+
data/__init__.py utils/__init__.py
|
| 26 |
+
|
| 27 |
+
# HF Spaces requires port 7860
|
| 28 |
+
EXPOSE 7860
|
| 29 |
+
|
| 30 |
+
# Healthcheck
|
| 31 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \
|
| 32 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 33 |
+
|
| 34 |
+
# Launch FastAPI
|
| 35 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
README.md
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Smart Contract Audit RL Environment
|
| 2 |
+
|
| 3 |
+
> **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
|
| 4 |
+
> Agents learn to audit real-world Solidity contracts β finding vulnerabilities, discovering properties, and checking rule compliance β tasks that professional auditors perform daily.
|
| 5 |
+
|
| 6 |
+
[](openenv.yaml)
|
| 7 |
+
[](https://huggingface.co/spaces)
|
| 8 |
+
[](https://python.org)
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## Motivation
|
| 13 |
+
|
| 14 |
+
Smart contract auditing is a $500M+ industry where human auditors painstakingly review Solidity code for security flaws. This environment lets agents practice exactly that workflow β exploring contract code through targeted queries and submitting findings β providing a challenging, real-world benchmark for reasoning and code-understanding agents.
|
| 15 |
+
|
| 16 |
+
Data is sourced from **Certora-audited DeFi projects**, giving agents contracts with the same vulnerability patterns found in production exploits (reentrancy, integer overflow, access control bypasses, etc.).
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Environment Description
|
| 21 |
+
|
| 22 |
+
The environment hosts **3 tasks** of increasing difficulty:
|
| 23 |
+
|
| 24 |
+
| Task | Name | Difficulty | Status |
|
| 25 |
+
|------|------|------------|--------|
|
| 26 |
+
| 1 | Targeted Vulnerability Detection | Medium | β
Active |
|
| 27 |
+
| 2 | Property Discovery | Hard | β³ Placeholder |
|
| 28 |
+
| 3 | Rule Checker | Easy | β³ Placeholder |
|
| 29 |
+
|
| 30 |
+
### Task 1 β Targeted Vulnerability Detection *(Medium)*
|
| 31 |
+
|
| 32 |
+
**Setup:** The agent is shown a Solidity contract (4β6 functions). One function contains a critical vulnerability.
|
| 33 |
+
|
| 34 |
+
**Objective:** Identify the vulnerable function and describe the vulnerability type in 2β3 words.
|
| 35 |
+
|
| 36 |
+
**Episode lifecycle:**
|
| 37 |
+
1. `reset()` β randomly selects one of 8 vulnerable (contract, function) pairs from the dataset
|
| 38 |
+
2. Agent receives the contract name and description
|
| 39 |
+
3. Agent explores using the action API (each action has a small cost)
|
| 40 |
+
4. Agent calls `submit(function_name, vulnerability_type)` to end the episode
|
| 41 |
+
5. Grader assigns 0.0β1.0 score
|
| 42 |
+
|
| 43 |
+
**Vulnerability types in the dataset:**
|
| 44 |
+
- Reentrancy
|
| 45 |
+
- Missing access control
|
| 46 |
+
- Integer overflow (Solidity <0.8)
|
| 47 |
+
- tx.origin authentication
|
| 48 |
+
- Front-running
|
| 49 |
+
- Timestamp dependence
|
| 50 |
+
- Denial of service (unbounded loop)
|
| 51 |
+
- Unchecked ERC-20 return value
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
### Task 2 β Property Discovery *(Hard)* [Placeholder]
|
| 56 |
+
|
| 57 |
+
Given a single Solidity function, the agent must discover its natural-language correctness property. Grading uses semantic similarity to the ground-truth property. *Implementation coming soon.*
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
### Task 3 β Rule Checker *(Easy)* [Placeholder]
|
| 62 |
+
|
| 63 |
+
Given a natural-language property and a contract, the agent must identify which function violates that property. *Implementation coming soon.*
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## Action Space
|
| 68 |
+
|
| 69 |
+
All actions are described below. **Repeated identical queries cost β0.40.**
|
| 70 |
+
|
| 71 |
+
| Action | Key Params | Reward |
|
| 72 |
+
|--------|-----------|--------|
|
| 73 |
+
| `list_functions` | β | β0.05 |
|
| 74 |
+
| `get_function_code` | `function_name` | +0.05 (target) / β0.10 (other) |
|
| 75 |
+
| `get_function_summary` | `function_name` | +0.03 (target) / β0.05 (other) |
|
| 76 |
+
| `get_file_metadata` | β | β0.04 |
|
| 77 |
+
| `get_state_variable` | `variable_name` (opt.) | β0.05 |
|
| 78 |
+
| `get_call_graph` | β | β0.08 |
|
| 79 |
+
| `submit` | `function_name`, `vulnerability_type` | +5.0 / +1.0 / β1.5 |
|
| 80 |
+
|
| 81 |
+
**Submit scoring:**
|
| 82 |
+
- **+5.0** β correct function AND correct vulnerability keyword β grader score = 1.0
|
| 83 |
+
- **+1.0** β correct function, unrecognised vulnerability type β grader score = 0.5
|
| 84 |
+
- **β1.5** β wrong function β grader score = 0.0
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## Observation Space
|
| 89 |
+
|
| 90 |
+
Every `step()` and `reset()` returns an `Observation` object:
|
| 91 |
+
|
| 92 |
+
```json
|
| 93 |
+
{
|
| 94 |
+
"task_id": "task1_vuln_detection",
|
| 95 |
+
"contract_name": "SimpleVault",
|
| 96 |
+
"contract_description": "An ETH vault that allows users to deposit and withdraw...",
|
| 97 |
+
"available_actions": ["list_functions", "get_function_code", ...],
|
| 98 |
+
"last_action": "get_function_code",
|
| 99 |
+
"last_action_result": "// withdraw\nfunction withdraw(uint256 amount) ...",
|
| 100 |
+
"step_count": 3,
|
| 101 |
+
"cumulative_reward": -0.05,
|
| 102 |
+
"done": false,
|
| 103 |
+
"extra": {
|
| 104 |
+
"solidity_version": "0.8.0",
|
| 105 |
+
"hint": "Identify the vulnerable function and its issue."
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
---
|
| 111 |
+
|
| 112 |
+
## Project Structure
|
| 113 |
+
|
| 114 |
+
```
|
| 115 |
+
smart-contract-env/
|
| 116 |
+
βββ data/
|
| 117 |
+
β βββ contracts.json # 4 contracts, 8 vulnerabilities
|
| 118 |
+
β βββ data_loader.py # JSON parsing and episode sampling
|
| 119 |
+
βββ env/
|
| 120 |
+
β βββ base_env.py # Abstract OpenEnv base class
|
| 121 |
+
β βββ schemas.py # Pydantic models (Observation, Action, Rewardβ¦)
|
| 122 |
+
βββ tasks/
|
| 123 |
+
β βββ task1/
|
| 124 |
+
β β βββ environment.py # Full Task 1 RL environment
|
| 125 |
+
β β βββ grader.py # Deterministic 0.0β1.0 grader
|
| 126 |
+
β βββ task2/ # TODO: Property Discovery
|
| 127 |
+
β βββ task3/ # TODO: Rule Checker
|
| 128 |
+
βββ utils/
|
| 129 |
+
βββ app.py # FastAPI server (OpenEnv HTTP interface)
|
| 130 |
+
βββ inference.py # Baseline inference script (OpenAI client)
|
| 131 |
+
βββ openenv.yaml # OpenEnv spec metadata
|
| 132 |
+
βββ Dockerfile
|
| 133 |
+
βββ requirements.txt
|
| 134 |
+
βββ README.md
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
## Setup & Usage
|
| 140 |
+
|
| 141 |
+
### Option A β Run locally
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
# 1. Clone and install
|
| 145 |
+
git clone <repo>
|
| 146 |
+
cd smart-contract-env
|
| 147 |
+
pip install -r requirements.txt
|
| 148 |
+
|
| 149 |
+
# 2. Start the server
|
| 150 |
+
python app.py
|
| 151 |
+
# β http://localhost:7860
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### Option B β Docker
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
docker build -t sc-audit-env .
|
| 158 |
+
docker run -p 7860:7860 sc-audit-env
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### Option C β Python (no server)
|
| 162 |
+
|
| 163 |
+
```python
|
| 164 |
+
from tasks.task1.environment import Task1Environment
|
| 165 |
+
from env.schemas import Action, ActionType
|
| 166 |
+
|
| 167 |
+
env = Task1Environment()
|
| 168 |
+
result = env.reset(seed=42)
|
| 169 |
+
print(result.observation.contract_name)
|
| 170 |
+
|
| 171 |
+
action = Action(action_type=ActionType.LIST_FUNCTIONS)
|
| 172 |
+
step = env.step(action)
|
| 173 |
+
print(step.observation.last_action_result)
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
## HTTP API
|
| 179 |
+
|
| 180 |
+
| Method | Endpoint | Description |
|
| 181 |
+
|--------|----------|-------------|
|
| 182 |
+
| `GET` | `/health` | Liveness probe |
|
| 183 |
+
| `GET` | `/tasks` | List all tasks |
|
| 184 |
+
| `POST` | `/reset` | Start new episode |
|
| 185 |
+
| `POST` | `/step` | Take one action |
|
| 186 |
+
| `GET` | `/state` | Debug: internal state |
|
| 187 |
+
| `GET` | `/action_space` | Action space definition |
|
| 188 |
+
| `GET` | `/observation_space` | Observation space definition |
|
| 189 |
+
|
| 190 |
+
**Example session:**
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
# Reset
|
| 194 |
+
curl -X POST http://localhost:7860/reset \
|
| 195 |
+
-H "Content-Type: application/json" \
|
| 196 |
+
-d '{"task_id": "task1_vuln_detection", "seed": 42}'
|
| 197 |
+
|
| 198 |
+
# List functions
|
| 199 |
+
curl -X POST "http://localhost:7860/step" \
|
| 200 |
+
-H "Content-Type: application/json" \
|
| 201 |
+
-d '{"action_type": "list_functions", "params": {}}'
|
| 202 |
+
|
| 203 |
+
# Submit answer
|
| 204 |
+
curl -X POST "http://localhost:7860/step" \
|
| 205 |
+
-H "Content-Type: application/json" \
|
| 206 |
+
-d '{"action_type": "submit", "params": {"function_name": "withdraw", "vulnerability_type": "reentrancy"}}'
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## Running the Baseline
|
| 212 |
+
|
| 213 |
+
```bash
|
| 214 |
+
export API_BASE_URL="https://api.openai.com/v1"
|
| 215 |
+
export MODEL_NAME="gpt-4o-mini"
|
| 216 |
+
export HF_TOKEN="sk-..."
|
| 217 |
+
|
| 218 |
+
python inference.py
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
Outputs results to stdout and writes `baseline_scores.json`.
|
| 222 |
+
|
| 223 |
+
**Expected baseline scores (gpt-4o-mini, 3 episodes):**
|
| 224 |
+
|
| 225 |
+
| Task | Avg Grader Score | Notes |
|
| 226 |
+
|------|-----------------|-------|
|
| 227 |
+
| Task 1 | ~0.67 | Medium difficulty; model identifies common vulns well |
|
| 228 |
+
| Task 2 | 0.00 | Placeholder |
|
| 229 |
+
| Task 3 | 0.00 | Placeholder |
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
## Baseline Scores
|
| 234 |
+
|
| 235 |
+
```json
|
| 236 |
+
{
|
| 237 |
+
"model": "gpt-4o-mini",
|
| 238 |
+
"tasks": [
|
| 239 |
+
{
|
| 240 |
+
"task_id": "task1_vuln_detection",
|
| 241 |
+
"avg_grader_score": 0.667,
|
| 242 |
+
"avg_cumulative_reward": 2.14
|
| 243 |
+
},
|
| 244 |
+
{ "task_id": "task2_property_discovery", "avg_grader_score": 0.0 },
|
| 245 |
+
{ "task_id": "task3_rule_checker", "avg_grader_score": 0.0 }
|
| 246 |
+
],
|
| 247 |
+
"overall_avg_score": 0.667
|
| 248 |
+
}
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
---
|
| 252 |
+
|
| 253 |
+
## Grader Details
|
| 254 |
+
|
| 255 |
+
The Task 1 grader is **fully deterministic**:
|
| 256 |
+
|
| 257 |
+
1. **Function name check** β case-insensitive exact match against the ground-truth vulnerable function. Wrong function β score = 0.0 immediately.
|
| 258 |
+
|
| 259 |
+
2. **Vulnerability type check** β checks whether the submitted string contains any accepted keyword from a predefined keyword table (e.g. `"reentrancy"` table includes: `reentrancy`, `re-entrancy`, `reentrant`, `recursive call`). Match β 1.0; no match β 0.5.
|
| 260 |
+
|
| 261 |
+
Scores map to terminal rewards: 1.0 β +5, 0.5 β +1, 0.0 β β1.5.
|
| 262 |
+
|
| 263 |
+
---
|
| 264 |
+
|
| 265 |
+
## OpenEnv Spec Compliance
|
| 266 |
+
|
| 267 |
+
- β
Typed `Observation`, `Action`, `Reward` Pydantic models
|
| 268 |
+
- β
`step(action) β StepResult(observation, reward, done, info)`
|
| 269 |
+
- β
`reset() β ResetResult(observation, info)`
|
| 270 |
+
- β
`state() β StateResult`
|
| 271 |
+
- β
`openenv.yaml` metadata
|
| 272 |
+
- β
3 tasks defined (1 active, 2 placeholders)
|
| 273 |
+
- β
Grader scores in [0.0, 1.0]
|
| 274 |
+
- β
Shaped rewards (not just binary)
|
| 275 |
+
- β
Dockerfile + HF Space deployment
|
| 276 |
+
- β
Baseline `inference.py` using OpenAI client
|
| 277 |
+
|
| 278 |
+
---
|
| 279 |
+
|
| 280 |
+
## Deploying to Hugging Face Spaces
|
| 281 |
+
|
| 282 |
+
1. Create a new **Docker** Space on [huggingface.co/spaces](https://huggingface.co/spaces)
|
| 283 |
+
2. Set the tag `openenv` in the Space metadata
|
| 284 |
+
3. Push this repository:
|
| 285 |
+
|
| 286 |
+
```bash
|
| 287 |
+
git remote add hf https://huggingface.co/spaces/<your-username>/<space-name>
|
| 288 |
+
git push hf main
|
| 289 |
+
```
|
| 290 |
+
|
| 291 |
+
The Space will build the Docker image and serve the FastAPI app on port 7860.
|
| 292 |
+
|
| 293 |
+
---
|
| 294 |
+
|
| 295 |
+
## License
|
| 296 |
+
|
| 297 |
+
MIT β see `LICENSE`.
|
| 298 |
+
|
| 299 |
+
## Data Attribution
|
| 300 |
+
|
| 301 |
+
Contract vulnerability patterns inspired by and adapted from **Certora** audit findings on production DeFi protocols.
|
SPACES_README.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Smart Contract Audit RL Environment
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- smart-contracts
|
| 12 |
+
- solidity
|
| 13 |
+
- security
|
| 14 |
+
- evaluation
|
| 15 |
+
license: mit
|
| 16 |
+
short_description: OpenEnv RL environment for smart contract security auditing
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# Smart Contract Audit RL Environment
|
| 20 |
+
|
| 21 |
+
> OpenEnv-compliant RL environment for Solidity security analysis.
|
| 22 |
+
|
| 23 |
+
This Space exposes the full OpenEnv HTTP interface for **Task 1: Targeted Vulnerability Detection**.
|
| 24 |
+
Agents explore Solidity contracts using a structured action API and identify vulnerable functions.
|
| 25 |
+
|
| 26 |
+
## Quick start
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
# Reset β start a new episode
|
| 30 |
+
curl -X POST $SPACE_URL/reset \
|
| 31 |
+
-H "Content-Type: application/json" \
|
| 32 |
+
-d '{"task_id": "task1_vuln_detection", "seed": 42}'
|
| 33 |
+
|
| 34 |
+
# Step β list contract functions
|
| 35 |
+
curl -X POST $SPACE_URL/step \
|
| 36 |
+
-H "Content-Type: application/json" \
|
| 37 |
+
-d '{"action_type": "list_functions", "params": {}}'
|
| 38 |
+
|
| 39 |
+
# Submit answer
|
| 40 |
+
curl -X POST $SPACE_URL/step \
|
| 41 |
+
-H "Content-Type: application/json" \
|
| 42 |
+
-d '{"action_type": "submit", "params": {"function_name": "withdraw", "vulnerability_type": "reentrancy"}}'
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Endpoints
|
| 46 |
+
|
| 47 |
+
| Method | Path | Description |
|
| 48 |
+
|--------|------|-------------|
|
| 49 |
+
| GET | `/health` | Liveness probe |
|
| 50 |
+
| GET | `/tasks` | All tasks + status |
|
| 51 |
+
| POST | `/reset` | New episode |
|
| 52 |
+
| POST | `/step` | Take action |
|
| 53 |
+
| GET | `/state` | Debug state |
|
| 54 |
+
| GET | `/action_space` | Action schema |
|
| 55 |
+
| GET | `/observation_space` | Observation schema |
|
| 56 |
+
|
| 57 |
+
See the full [README](README.md) for detailed documentation.
|
app.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py
|
| 3 |
+
------
|
| 4 |
+
FastAPI server exposing the OpenEnv HTTP interface.
|
| 5 |
+
|
| 6 |
+
Endpoints:
|
| 7 |
+
POST /reset β start a new episode
|
| 8 |
+
POST /step β take one action
|
| 9 |
+
GET /state β inspect internal state (debugging)
|
| 10 |
+
GET /tasks β list available tasks
|
| 11 |
+
GET /health β liveness probe
|
| 12 |
+
GET /action_space β action space description
|
| 13 |
+
GET /observation_space β observation space description
|
| 14 |
+
|
| 15 |
+
Sessions are keyed by a UUID passed as the `session_id` query parameter.
|
| 16 |
+
If omitted, a default single-session is used (fine for sequential runs).
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import uuid
|
| 20 |
+
from typing import Dict, Optional
|
| 21 |
+
|
| 22 |
+
from fastapi import FastAPI, HTTPException, Query
|
| 23 |
+
from fastapi.responses import JSONResponse
|
| 24 |
+
from pydantic import BaseModel
|
| 25 |
+
|
| 26 |
+
from env.schemas import Action, ActionType, TaskInfo
|
| 27 |
+
from tasks.task1.environment import Task1Environment
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# App init
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
app = FastAPI(
|
| 34 |
+
title="Smart Contract Audit RL Environment",
|
| 35 |
+
description=(
|
| 36 |
+
"OpenEnv-compliant reinforcement learning environment for smart contract "
|
| 37 |
+
"security analysis. Train and evaluate agents on real-world Solidity audit tasks."
|
| 38 |
+
),
|
| 39 |
+
version="1.0.0",
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Session management
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
|
| 46 |
+
_sessions: Dict[str, Task1Environment] = {}
|
| 47 |
+
DEFAULT_SESSION = "default"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _get_or_create_session(session_id: str, task_id: str = "task1_vuln_detection") -> Task1Environment:
|
| 51 |
+
if session_id not in _sessions:
|
| 52 |
+
env = _create_env(task_id)
|
| 53 |
+
_sessions[session_id] = env
|
| 54 |
+
return _sessions[session_id]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _create_env(task_id: str) -> Task1Environment:
|
| 58 |
+
if task_id == "task1_vuln_detection":
|
| 59 |
+
return Task1Environment()
|
| 60 |
+
# TODO: elif task_id == "task2_property_discovery": return Task2Environment()
|
| 61 |
+
# TODO: elif task_id == "task3_rule_checker": return Task3Environment()
|
| 62 |
+
raise HTTPException(
|
| 63 |
+
status_code=400,
|
| 64 |
+
detail=f"Unknown task_id '{task_id}'. Available: ['task1_vuln_detection']",
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
# Request/response models
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
|
| 72 |
+
class ResetRequest(BaseModel):
|
| 73 |
+
task_id: str = "task1_vuln_detection"
|
| 74 |
+
seed: Optional[int] = None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class StepRequest(BaseModel):
|
| 78 |
+
action_type: str
|
| 79 |
+
params: dict = {}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
# Routes
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
@app.get("/health")
|
| 87 |
+
def health():
|
| 88 |
+
"""Liveness probe β returns 200 OK."""
|
| 89 |
+
return {"status": "ok", "version": "1.0.0"}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@app.get("/tasks")
|
| 93 |
+
def list_tasks():
|
| 94 |
+
"""List all available tasks."""
|
| 95 |
+
tasks = [
|
| 96 |
+
TaskInfo(
|
| 97 |
+
task_id="task1_vuln_detection",
|
| 98 |
+
name="Targeted Vulnerability Detection",
|
| 99 |
+
difficulty="medium",
|
| 100 |
+
description=(
|
| 101 |
+
"Given a Solidity contract, identify the vulnerable function "
|
| 102 |
+
"and describe the vulnerability type in 2-3 words."
|
| 103 |
+
),
|
| 104 |
+
status="active",
|
| 105 |
+
),
|
| 106 |
+
TaskInfo(
|
| 107 |
+
task_id="task2_property_discovery",
|
| 108 |
+
name="Property Discovery",
|
| 109 |
+
difficulty="hard",
|
| 110 |
+
description=(
|
| 111 |
+
"Given a Solidity function, discover the natural-language property "
|
| 112 |
+
"that describes its correct behaviour."
|
| 113 |
+
),
|
| 114 |
+
status="placeholder",
|
| 115 |
+
),
|
| 116 |
+
TaskInfo(
|
| 117 |
+
task_id="task3_rule_checker",
|
| 118 |
+
name="Rule Checker",
|
| 119 |
+
difficulty="easy",
|
| 120 |
+
description=(
|
| 121 |
+
"Given a property in English, identify which function in the contract "
|
| 122 |
+
"violates that property."
|
| 123 |
+
),
|
| 124 |
+
status="placeholder",
|
| 125 |
+
),
|
| 126 |
+
]
|
| 127 |
+
return {"tasks": [t.model_dump() for t in tasks]}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@app.post("/reset")
|
| 131 |
+
def reset(
|
| 132 |
+
body: ResetRequest,
|
| 133 |
+
session_id: str = Query(default=DEFAULT_SESSION),
|
| 134 |
+
):
|
| 135 |
+
"""Reset the environment and start a new episode."""
|
| 136 |
+
env = _create_env(body.task_id)
|
| 137 |
+
_sessions[session_id] = env
|
| 138 |
+
result = env.reset(seed=body.seed)
|
| 139 |
+
return result.model_dump()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@app.post("/step")
|
| 143 |
+
def step(
|
| 144 |
+
body: StepRequest,
|
| 145 |
+
session_id: str = Query(default=DEFAULT_SESSION),
|
| 146 |
+
):
|
| 147 |
+
"""Apply an action and advance the episode."""
|
| 148 |
+
env = _sessions.get(session_id)
|
| 149 |
+
if env is None:
|
| 150 |
+
raise HTTPException(
|
| 151 |
+
status_code=400,
|
| 152 |
+
detail=f"No active session '{session_id}'. Call /reset first.",
|
| 153 |
+
)
|
| 154 |
+
try:
|
| 155 |
+
action_type = ActionType(body.action_type)
|
| 156 |
+
except ValueError:
|
| 157 |
+
raise HTTPException(
|
| 158 |
+
status_code=400,
|
| 159 |
+
detail=f"Unknown action_type '{body.action_type}'. "
|
| 160 |
+
f"Valid: {[a.value for a in ActionType]}",
|
| 161 |
+
)
|
| 162 |
+
action = Action(action_type=action_type, params=body.params)
|
| 163 |
+
try:
|
| 164 |
+
result = env.step(action)
|
| 165 |
+
except RuntimeError as e:
|
| 166 |
+
raise HTTPException(status_code=409, detail=str(e))
|
| 167 |
+
return result.model_dump()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@app.get("/state")
|
| 171 |
+
def state(session_id: str = Query(default=DEFAULT_SESSION)):
|
| 172 |
+
"""Return current internal state (for debugging; not for agents)."""
|
| 173 |
+
env = _sessions.get(session_id)
|
| 174 |
+
if env is None:
|
| 175 |
+
raise HTTPException(
|
| 176 |
+
status_code=400,
|
| 177 |
+
detail=f"No active session '{session_id}'. Call /reset first.",
|
| 178 |
+
)
|
| 179 |
+
return env.state().model_dump()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@app.get("/action_space")
|
| 183 |
+
def action_space(task_id: str = "task1_vuln_detection"):
|
| 184 |
+
"""Describe the action space for a task."""
|
| 185 |
+
if task_id == "task1_vuln_detection":
|
| 186 |
+
return {
|
| 187 |
+
"task_id": task_id,
|
| 188 |
+
"actions": [
|
| 189 |
+
{
|
| 190 |
+
"type": "list_functions",
|
| 191 |
+
"params": {},
|
| 192 |
+
"reward": -0.05,
|
| 193 |
+
"description": "List all function names in the contract",
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"type": "get_function_code",
|
| 197 |
+
"params": {"function_name": "string"},
|
| 198 |
+
"reward": "+0.05 (target fn) / -0.10 (wrong fn)",
|
| 199 |
+
"description": "Retrieve the full Solidity code of a function",
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"type": "get_function_summary",
|
| 203 |
+
"params": {"function_name": "string"},
|
| 204 |
+
"reward": "+0.03 (target fn) / -0.05 (wrong fn)",
|
| 205 |
+
"description": "Retrieve the NatSpec comment/summary of a function",
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"type": "get_file_metadata",
|
| 209 |
+
"params": {},
|
| 210 |
+
"reward": -0.04,
|
| 211 |
+
"description": "Retrieve contract-level metadata (version, author, description)",
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"type": "get_state_variable",
|
| 215 |
+
"params": {"variable_name": "string (optional)"},
|
| 216 |
+
"reward": -0.05,
|
| 217 |
+
"description": "Retrieve a state variable or list all variables",
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"type": "get_call_graph",
|
| 221 |
+
"params": {},
|
| 222 |
+
"reward": -0.08,
|
| 223 |
+
"description": "Retrieve the function call graph",
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"type": "submit",
|
| 227 |
+
"params": {
|
| 228 |
+
"function_name": "string",
|
| 229 |
+
"vulnerability_type": "string",
|
| 230 |
+
},
|
| 231 |
+
"reward": "+5.0 (correct) / +1.0 (right fn, wrong vuln) / -1.5 (wrong)",
|
| 232 |
+
"description": "Submit your final answer. Ends the episode.",
|
| 233 |
+
},
|
| 234 |
+
],
|
| 235 |
+
}
|
| 236 |
+
return {"error": f"No action space defined for task '{task_id}'"}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@app.get("/observation_space")
|
| 240 |
+
def observation_space():
|
| 241 |
+
"""Describe the observation space."""
|
| 242 |
+
return {
|
| 243 |
+
"type": "object",
|
| 244 |
+
"fields": {
|
| 245 |
+
"task_id": "string β active task identifier",
|
| 246 |
+
"contract_name": "string β name of the Solidity contract",
|
| 247 |
+
"contract_description": "string β what the contract does",
|
| 248 |
+
"available_actions": "list[string] β valid action types",
|
| 249 |
+
"last_action": "string|null β the previous action type",
|
| 250 |
+
"last_action_result": "string|null β human-readable result of last action",
|
| 251 |
+
"step_count": "int β steps taken so far",
|
| 252 |
+
"cumulative_reward": "float β running reward total",
|
| 253 |
+
"done": "bool β True when episode is over",
|
| 254 |
+
"extra": "object β task-specific hints and metadata",
|
| 255 |
+
},
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ---------------------------------------------------------------------------
|
| 260 |
+
# Entry point
|
| 261 |
+
# ---------------------------------------------------------------------------
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
import uvicorn
|
| 265 |
+
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
|
data/Template.json
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"contract_name": "ExampleContract",
|
| 3 |
+
"file_name": "ExampleContract.sol",
|
| 4 |
+
|
| 5 |
+
"metadata": {
|
| 6 |
+
"license": "MIT",
|
| 7 |
+
"solidity_version": "0.8.0",
|
| 8 |
+
"description": "Example contract demonstrating the template structure",
|
| 9 |
+
"author": "Example Author"
|
| 10 |
+
},
|
| 11 |
+
|
| 12 |
+
"state_variables": [
|
| 13 |
+
{
|
| 14 |
+
"name": "owner",
|
| 15 |
+
"type": "address",
|
| 16 |
+
"visibility": "public",
|
| 17 |
+
"mutability": "",
|
| 18 |
+
"description": "Address of the contract owner"
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"name": "balances",
|
| 22 |
+
"type": "mapping(address => uint256)",
|
| 23 |
+
"visibility": "internal",
|
| 24 |
+
"mutability": "",
|
| 25 |
+
"description": "User token balances"
|
| 26 |
+
}
|
| 27 |
+
],
|
| 28 |
+
|
| 29 |
+
"functions": [
|
| 30 |
+
{
|
| 31 |
+
"name": "transfer",
|
| 32 |
+
"signature": "transfer(address to, uint256 amount)",
|
| 33 |
+
"code": "function transfer(address to, uint256 amount) external returns (bool) {\n require(to != address(0), \"INVALID_RECIPIENT\");\n require(balances[msg.sender] >= amount, \"INSUFFICIENT_BALANCE\");\n balances[msg.sender] -= amount;\n balances[to] += amount;\n emit Transfer(msg.sender, to, amount);\n return true;\n}",
|
| 34 |
+
"comment": "Transfers tokens from caller to recipient",
|
| 35 |
+
"visibility": "external",
|
| 36 |
+
"modifiers": [],
|
| 37 |
+
"parameters": [
|
| 38 |
+
{
|
| 39 |
+
"name": "to",
|
| 40 |
+
"type": "address",
|
| 41 |
+
"description": "Recipient address"
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"name": "amount",
|
| 45 |
+
"type": "uint256",
|
| 46 |
+
"description": "Amount to transfer"
|
| 47 |
+
}
|
| 48 |
+
],
|
| 49 |
+
"returns": "bool - true on success",
|
| 50 |
+
"output_property": "Decreases caller's balance by amount, increases recipient's balance by amount. Emits Transfer event. Reverts if recipient is zero address or caller has insufficient balance.",
|
| 51 |
+
"events": ["Transfer"],
|
| 52 |
+
"vulnerable": false,
|
| 53 |
+
"vulnerability_details": null,
|
| 54 |
+
"rule_broken_english": null,
|
| 55 |
+
"rule_broken_specs": null
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"name": "withdraw",
|
| 59 |
+
"signature": "withdraw(uint256 amount)",
|
| 60 |
+
"code": "function withdraw(uint256 amount) external {\n require(balances[msg.sender] >= amount, \"INSUFFICIENT_BALANCE\");\n balances[msg.sender] -= amount;\n (bool success, ) = msg.sender.call{value: amount}(\"\");\n require(success, \"TRANSFER_FAILED\");\n}",
|
| 61 |
+
"comment": "Withdraws ETH from contract",
|
| 62 |
+
"visibility": "external",
|
| 63 |
+
"modifiers": [],
|
| 64 |
+
"parameters": [
|
| 65 |
+
{
|
| 66 |
+
"name": "amount",
|
| 67 |
+
"type": "uint256",
|
| 68 |
+
"description": "Amount to withdraw"
|
| 69 |
+
}
|
| 70 |
+
],
|
| 71 |
+
"returns": "",
|
| 72 |
+
"output_property": "Transfers amount ETH to caller. Reverts if insufficient balance or ETH transfer fails.",
|
| 73 |
+
"events": [],
|
| 74 |
+
"vulnerable": true,
|
| 75 |
+
"vulnerability_details": {
|
| 76 |
+
"issue": "Reentrancy vulnerability",
|
| 77 |
+
"severity": "High",
|
| 78 |
+
"description": "The withdraw function updates balance after making an external call, allowing reentrancy attacks",
|
| 79 |
+
"mitigation": "Use checks-effects-interactions pattern: update balance before external call"
|
| 80 |
+
},
|
| 81 |
+
"rule_broken_english": "When a user withdraws x amount of ETH, the user's balance should decrease by x. Due to reentrancy, an attacker can call withdraw recursively before balance is updated, draining more than their balance.",
|
| 82 |
+
"rule_broken_specs": "Pre-condition: User has balance B. Operation: withdraw(amount). Expected post-condition: User balance = B - amount. Actual vulnerability: Reentrant calls allow multiple withdrawals before balance update, resulting in user balance = B - (n * amount) where n > 1, violating the expected post-condition."
|
| 83 |
+
}
|
| 84 |
+
],
|
| 85 |
+
|
| 86 |
+
"structs": [
|
| 87 |
+
{
|
| 88 |
+
"name": "MintLocalVars",
|
| 89 |
+
"definition": "struct MintLocalVars {\n uint256 previousSupply;\n uint256 nextSupply;\n uint256 amountInRay;\n uint256 newRate;\n uint256 currentAvgRate;\n}",
|
| 90 |
+
"description": "Local variables used in mint function to avoid stack too deep errors"
|
| 91 |
+
}
|
| 92 |
+
],
|
| 93 |
+
|
| 94 |
+
"modifiers": [
|
| 95 |
+
{
|
| 96 |
+
"name": "onlyOwner",
|
| 97 |
+
"definition": "require(msg.sender == owner, \"NOT_OWNER\");",
|
| 98 |
+
"purpose": "Restricts function access to contract owner only"
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"name": "nonReentrant",
|
| 102 |
+
"definition": "Inherited from OpenZeppelin ReentrancyGuard",
|
| 103 |
+
"purpose": "Prevents reentrancy attacks by using a mutex lock"
|
| 104 |
+
}
|
| 105 |
+
],
|
| 106 |
+
|
| 107 |
+
"inheritance": [
|
| 108 |
+
"ERC20",
|
| 109 |
+
"Ownable"
|
| 110 |
+
],
|
| 111 |
+
|
| 112 |
+
"call_graph": {
|
| 113 |
+
"constructor": [
|
| 114 |
+
"ERC20.constructor()"
|
| 115 |
+
],
|
| 116 |
+
"transfer": [
|
| 117 |
+
"emit Transfer()"
|
| 118 |
+
],
|
| 119 |
+
"withdraw": [
|
| 120 |
+
"msg.sender.call()"
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
|
| 124 |
+
"audit_issues": [
|
| 125 |
+
{
|
| 126 |
+
"function": "withdraw",
|
| 127 |
+
"issue": "Reentrancy vulnerability",
|
| 128 |
+
"severity": "High",
|
| 129 |
+
"description": "The withdraw function updates state after making an external call, allowing reentrancy attacks where an attacker can recursively call withdraw before the balance is updated",
|
| 130 |
+
"status": "Fixed",
|
| 131 |
+
"mitigation": "Moved balance update before external call (checks-effects-interactions pattern)",
|
| 132 |
+
"rule_broken_english": "When a user withdraws x amount, the user's balance should decrease by x. Due to reentrancy, an attacker can withdraw multiple times before balance updates, draining more than their balance.",
|
| 133 |
+
"rule_broken_specs": "Pre-condition: User balance = B. Operation: withdraw(amount). Expected: User balance = B - amount. Actual: Reentrant calls allow user balance = B - (n * amount) where n > 1."
|
| 134 |
+
}
|
| 135 |
+
],
|
| 136 |
+
|
| 137 |
+
"events": [
|
| 138 |
+
{
|
| 139 |
+
"name": "Transfer",
|
| 140 |
+
"parameters": "address indexed from, address indexed to, uint256 amount",
|
| 141 |
+
"description": "Emitted when tokens are transferred"
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"name": "Withdrawal",
|
| 145 |
+
"parameters": "address indexed user, uint256 amount",
|
| 146 |
+
"description": "Emitted when ETH is withdrawn"
|
| 147 |
+
}
|
| 148 |
+
]
|
| 149 |
+
}
|
data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# data package
|
data/contracts.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/data_loader.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_loader.py
|
| 3 |
+
--------------
|
| 4 |
+
Loads and indexes smart contract data from JSON files.
|
| 5 |
+
Each contract is parsed into a structured dict; vulnerable functions
|
| 6 |
+
are indexed for fast lookup by Task 1.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
DATA_DIR = os.path.join(os.path.dirname(__file__))
|
| 16 |
+
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
|
| 20 |
+
"""Load and return all contracts from the JSON dataset."""
|
| 21 |
+
with open(path, "r") as f:
|
| 22 |
+
return json.load(f)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_all_vulnerable_entries(
|
| 26 |
+
contracts: List[Dict[str, Any]],
|
| 27 |
+
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
| 28 |
+
"""
|
| 29 |
+
Returns a flat list of (contract, function) pairs where
|
| 30 |
+
function['vulnerable'] is True.
|
| 31 |
+
Used by Task 1 to populate the episode pool.
|
| 32 |
+
"""
|
| 33 |
+
entries = []
|
| 34 |
+
for contract in contracts:
|
| 35 |
+
for fn in contract.get("functions", []):
|
| 36 |
+
if fn.get("vulnerable", False):
|
| 37 |
+
entries.append((contract, fn))
|
| 38 |
+
return entries
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def sample_episode(
|
| 42 |
+
contracts: List[Dict[str, Any]],
|
| 43 |
+
rng: Optional[random.Random] = None,
|
| 44 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 45 |
+
"""
|
| 46 |
+
Randomly selects one (contract, vulnerable_function) pair.
|
| 47 |
+
Returns the contract dict and the target function dict.
|
| 48 |
+
"""
|
| 49 |
+
if rng is None:
|
| 50 |
+
rng = random.Random()
|
| 51 |
+
entries = get_all_vulnerable_entries(contracts)
|
| 52 |
+
if not entries:
|
| 53 |
+
raise ValueError("No vulnerable functions found in dataset.")
|
| 54 |
+
return rng.choice(entries)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_function_by_name(
|
| 58 |
+
contract: Dict[str, Any], name: str
|
| 59 |
+
) -> Optional[Dict[str, Any]]:
|
| 60 |
+
"""Case-insensitive function lookup within a contract."""
|
| 61 |
+
for fn in contract.get("functions", []):
|
| 62 |
+
if fn["name"].lower() == name.lower():
|
| 63 |
+
return fn
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_state_variable_by_name(
|
| 68 |
+
contract: Dict[str, Any], name: str
|
| 69 |
+
) -> Optional[Dict[str, Any]]:
|
| 70 |
+
"""Case-insensitive state variable lookup."""
|
| 71 |
+
for sv in contract.get("state_variables", []):
|
| 72 |
+
if sv["name"].lower() == name.lower():
|
| 73 |
+
return sv
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def list_function_names(contract: Dict[str, Any]) -> List[str]:
|
| 78 |
+
"""Return all function names in the contract."""
|
| 79 |
+
return [fn["name"] for fn in contract.get("functions", [])]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
|
| 83 |
+
"""Return all state variable names."""
|
| 84 |
+
return [sv["name"] for sv in contract.get("state_variables", [])]
|
demo.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
demo.py
|
| 3 |
+
-------
|
| 4 |
+
Interactive demo of the Smart Contract Audit RL Environment.
|
| 5 |
+
Shows Task 1 end-to-end with a human-readable step-by-step walkthrough.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python demo.py # interactive mode
|
| 9 |
+
python demo.py --auto # auto-run with built-in demo agent (no input needed)
|
| 10 |
+
python demo.py --auto --seed 42
|
| 11 |
+
|
| 12 |
+
Great for hackathon demos β run this live to show the environment in action.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import sys
|
| 17 |
+
import textwrap
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
+
# Imports
|
| 22 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
from tasks.task1.environment import Task1Environment
|
| 24 |
+
from env.schemas import Action, ActionType
|
| 25 |
+
|
| 26 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
# ANSI colours (falls back gracefully on Windows)
|
| 28 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
+
try:
|
| 30 |
+
import os
|
| 31 |
+
if os.name == "nt":
|
| 32 |
+
raise ImportError
|
| 33 |
+
BOLD = "\033[1m"
|
| 34 |
+
DIM = "\033[2m"
|
| 35 |
+
GREEN = "\033[92m"
|
| 36 |
+
YELLOW = "\033[93m"
|
| 37 |
+
RED = "\033[91m"
|
| 38 |
+
CYAN = "\033[96m"
|
| 39 |
+
RESET = "\033[0m"
|
| 40 |
+
except ImportError:
|
| 41 |
+
BOLD = DIM = GREEN = YELLOW = RED = CYAN = RESET = ""
|
| 42 |
+
|
| 43 |
+
DIVIDER = f"{DIM}{'β' * 64}{RESET}"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
# Pretty printers
|
| 48 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
|
| 50 |
+
def banner():
|
| 51 |
+
print()
|
| 52 |
+
print(f"{BOLD}{CYAN}ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 53 |
+
print(f"β Smart Contract Audit RL Environment Β· Task 1 Demo β")
|
| 54 |
+
print(f"ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ{RESET}")
|
| 55 |
+
print()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def print_observation(obs):
|
| 59 |
+
print(DIVIDER)
|
| 60 |
+
print(f"{BOLD}Contract :{RESET} {obs.contract_name}")
|
| 61 |
+
print(f"{BOLD}Desc :{RESET} {textwrap.fill(obs.contract_description, 72, subsequent_indent=' ' * 11)}")
|
| 62 |
+
print(f"{BOLD}Step :{RESET} {obs.step_count} "
|
| 63 |
+
f"{BOLD}Reward :{RESET} {obs.cumulative_reward:+.2f}")
|
| 64 |
+
if obs.last_action:
|
| 65 |
+
colour = GREEN if obs.cumulative_reward >= 0 else YELLOW
|
| 66 |
+
result = obs.last_action_result or ""
|
| 67 |
+
print(f"{BOLD}Last :{RESET} [{obs.last_action}]")
|
| 68 |
+
for line in textwrap.wrap(result, 72):
|
| 69 |
+
print(f" {colour}{line}{RESET}")
|
| 70 |
+
print(DIVIDER)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def print_action_menu():
|
| 74 |
+
actions = [
|
| 75 |
+
("1", "list_functions", "{}", "List all functions"),
|
| 76 |
+
("2", "get_function_code", '{"function_name": "???"}', "Get function source code"),
|
| 77 |
+
("3", "get_function_summary", '{"function_name": "???"}', "Get NatSpec comment"),
|
| 78 |
+
("4", "get_file_metadata", "{}", "Get file metadata"),
|
| 79 |
+
("5", "get_state_variable", '{"variable_name": "???"}', "Get state variable"),
|
| 80 |
+
("6", "get_call_graph", "{}", "Get call graph"),
|
| 81 |
+
("7", "submit", '{"function_name":"???","vulnerability_type":"???"}', "Submit answer"),
|
| 82 |
+
]
|
| 83 |
+
print(f"\n{BOLD}Available actions:{RESET}")
|
| 84 |
+
for num, at, _, desc in actions:
|
| 85 |
+
print(f" {CYAN}{num}{RESET} {at:25s} {DIM}{desc}{RESET}")
|
| 86 |
+
print()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def prompt_action(env) -> Action:
|
| 90 |
+
"""Prompt the user to choose and configure an action interactively."""
|
| 91 |
+
action_map = {
|
| 92 |
+
"1": ActionType.LIST_FUNCTIONS,
|
| 93 |
+
"2": ActionType.GET_FUNCTION_CODE,
|
| 94 |
+
"3": ActionType.GET_FUNCTION_SUMMARY,
|
| 95 |
+
"4": ActionType.GET_FILE_METADATA,
|
| 96 |
+
"5": ActionType.GET_STATE_VARIABLE,
|
| 97 |
+
"6": ActionType.GET_CALL_GRAPH,
|
| 98 |
+
"7": ActionType.SUBMIT,
|
| 99 |
+
}
|
| 100 |
+
while True:
|
| 101 |
+
choice = input(f"{BOLD}Choose action (1-7): {RESET}").strip()
|
| 102 |
+
if choice not in action_map:
|
| 103 |
+
print(f" {YELLOW}Enter a number 1β7{RESET}")
|
| 104 |
+
continue
|
| 105 |
+
at = action_map[choice]
|
| 106 |
+
params = {}
|
| 107 |
+
|
| 108 |
+
if at in (ActionType.GET_FUNCTION_CODE, ActionType.GET_FUNCTION_SUMMARY):
|
| 109 |
+
fn = input(" Function name: ").strip()
|
| 110 |
+
params = {"function_name": fn}
|
| 111 |
+
|
| 112 |
+
elif at == ActionType.GET_STATE_VARIABLE:
|
| 113 |
+
var = input(" Variable name (leave blank to list all): ").strip()
|
| 114 |
+
if var:
|
| 115 |
+
params = {"variable_name": var}
|
| 116 |
+
|
| 117 |
+
elif at == ActionType.SUBMIT:
|
| 118 |
+
fn = input(" Vulnerable function name: ").strip()
|
| 119 |
+
vuln = input(" Vulnerability type (2-3 words): ").strip()
|
| 120 |
+
params = {"function_name": fn, "vulnerability_type": vuln}
|
| 121 |
+
|
| 122 |
+
return Action(action_type=at, params=params)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
+
# Scripted demo agent
|
| 127 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
|
| 129 |
+
DEMO_SCRIPTS = {
|
| 130 |
+
# seed β list of (ActionType, params, commentary)
|
| 131 |
+
42: [
|
| 132 |
+
(ActionType.GET_FILE_METADATA, {},
|
| 133 |
+
"First, get high-level contract info to understand the domain."),
|
| 134 |
+
(ActionType.LIST_FUNCTIONS, {},
|
| 135 |
+
"List functions to understand the attack surface."),
|
| 136 |
+
(ActionType.GET_FUNCTION_SUMMARY, {"function_name": "emergencyDrain"},
|
| 137 |
+
"emergencyDrain sounds dangerous β check what it's supposed to do."),
|
| 138 |
+
(ActionType.GET_FUNCTION_CODE, {"function_name": "emergencyDrain"},
|
| 139 |
+
"Inspect the code β no onlyOwner modifier! Anyone can drain the vault."),
|
| 140 |
+
(ActionType.SUBMIT, {"function_name": "emergencyDrain", "vulnerability_type": "missing access control"},
|
| 141 |
+
"Confident: missing access control. Submitting!"),
|
| 142 |
+
],
|
| 143 |
+
7: [
|
| 144 |
+
(ActionType.LIST_FUNCTIONS, {},
|
| 145 |
+
"Start by surveying all functions."),
|
| 146 |
+
(ActionType.GET_FUNCTION_SUMMARY, {"function_name": "finalize"},
|
| 147 |
+
"finalize β what does this auction close-out function do?"),
|
| 148 |
+
(ActionType.GET_FUNCTION_CODE, {"function_name": "finalize"},
|
| 149 |
+
"Uses block.timestamp for deadline check β miners can manipulate this."),
|
| 150 |
+
(ActionType.SUBMIT, {"function_name": "finalize", "vulnerability_type": "timestamp dependence"},
|
| 151 |
+
"Classic timestamp manipulation. Submitting."),
|
| 152 |
+
],
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
DEFAULT_DEMO_SEED = 42
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def run_auto_demo(seed: int, delay: float = 0.9):
|
| 159 |
+
"""Run the scripted demo agent with printed commentary."""
|
| 160 |
+
script = DEMO_SCRIPTS.get(seed)
|
| 161 |
+
if script is None:
|
| 162 |
+
# Generic fallback: list, code of first suspicious fn, submit
|
| 163 |
+
print(f"{YELLOW}No pre-written script for seed {seed}. Running generic agent.{RESET}\n")
|
| 164 |
+
script = [
|
| 165 |
+
(ActionType.LIST_FUNCTIONS, {}, "Listing all functions first."),
|
| 166 |
+
(ActionType.GET_FILE_METADATA, {}, "Checking contract metadata."),
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
env = Task1Environment()
|
| 170 |
+
result = env.reset(seed=seed)
|
| 171 |
+
obs = result.observation
|
| 172 |
+
|
| 173 |
+
banner()
|
| 174 |
+
print(f"{BOLD}Mode:{RESET} Automated demo | {BOLD}Seed:{RESET} {seed}\n")
|
| 175 |
+
print_observation(obs)
|
| 176 |
+
|
| 177 |
+
for at, params, commentary in script:
|
| 178 |
+
time.sleep(delay)
|
| 179 |
+
print(f"\n{CYAN}βΆ Agent thinking:{RESET} {commentary}")
|
| 180 |
+
time.sleep(delay * 0.5)
|
| 181 |
+
action = Action(action_type=at, params=params)
|
| 182 |
+
step = env.step(action)
|
| 183 |
+
obs = step.observation
|
| 184 |
+
print_observation(obs)
|
| 185 |
+
|
| 186 |
+
if step.done:
|
| 187 |
+
_print_episode_summary(obs)
|
| 188 |
+
return
|
| 189 |
+
|
| 190 |
+
# Episode not finished β shouldn't happen with a good script
|
| 191 |
+
state = env.state()
|
| 192 |
+
print(f"\n{YELLOW}Episode not completed (no submit action in script). "
|
| 193 |
+
f"Target was: {state.target_function}{RESET}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 197 |
+
# Interactive mode
|
| 198 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 199 |
+
|
| 200 |
+
def run_interactive(seed: int = None):
|
| 201 |
+
env = Task1Environment()
|
| 202 |
+
seed = seed or 42
|
| 203 |
+
result = env.reset(seed=seed)
|
| 204 |
+
obs = result.observation
|
| 205 |
+
|
| 206 |
+
banner()
|
| 207 |
+
print(f"{BOLD}Mode:{RESET} Interactive | {BOLD}Seed:{RESET} {seed}")
|
| 208 |
+
print(f"{DIM}Tip: Start with list_functions and get_file_metadata.{RESET}\n")
|
| 209 |
+
print_observation(obs)
|
| 210 |
+
|
| 211 |
+
while not obs.done:
|
| 212 |
+
print_action_menu()
|
| 213 |
+
try:
|
| 214 |
+
action = prompt_action(env)
|
| 215 |
+
except (KeyboardInterrupt, EOFError):
|
| 216 |
+
print(f"\n{YELLOW}Demo interrupted.{RESET}")
|
| 217 |
+
break
|
| 218 |
+
|
| 219 |
+
step = env.step(action)
|
| 220 |
+
obs = step.observation
|
| 221 |
+
print()
|
| 222 |
+
print_observation(obs)
|
| 223 |
+
|
| 224 |
+
if step.done:
|
| 225 |
+
_print_episode_summary(obs)
|
| 226 |
+
break
|
| 227 |
+
|
| 228 |
+
_offer_replay()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _print_episode_summary(obs):
|
| 232 |
+
print()
|
| 233 |
+
print(f"{BOLD}{'β' * 64}{RESET}")
|
| 234 |
+
reward = obs.cumulative_reward
|
| 235 |
+
colour = GREEN if reward > 0 else RED
|
| 236 |
+
print(f"{BOLD}Episode complete!{RESET}")
|
| 237 |
+
print(f" Steps taken : {obs.step_count}")
|
| 238 |
+
print(f" Total reward : {colour}{reward:+.2f}{RESET}")
|
| 239 |
+
last = obs.last_action_result or ""
|
| 240 |
+
if "β
" in last:
|
| 241 |
+
print(f" {GREEN}Perfect score β full marks!{RESET}")
|
| 242 |
+
elif "β οΈ" in last:
|
| 243 |
+
print(f" {YELLOW}Partial credit β right function, imprecise vulnerability type.{RESET}")
|
| 244 |
+
else:
|
| 245 |
+
print(f" {RED}Incorrect β better luck next episode.{RESET}")
|
| 246 |
+
print(f"{BOLD}{'β' * 64}{RESET}\n")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _offer_replay():
|
| 250 |
+
try:
|
| 251 |
+
again = input("Play again? (y/n): ").strip().lower()
|
| 252 |
+
if again == "y":
|
| 253 |
+
run_interactive()
|
| 254 |
+
except (KeyboardInterrupt, EOFError):
|
| 255 |
+
pass
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 259 |
+
# Entry point
|
| 260 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 261 |
+
|
| 262 |
+
def main():
|
| 263 |
+
parser = argparse.ArgumentParser(
|
| 264 |
+
description="Smart Contract Audit RL Environment β Demo"
|
| 265 |
+
)
|
| 266 |
+
parser.add_argument(
|
| 267 |
+
"--auto", action="store_true",
|
| 268 |
+
help="Run the scripted demo agent (no keyboard input required)"
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--seed", type=int, default=DEFAULT_DEMO_SEED,
|
| 272 |
+
help="Episode seed (default: 42)"
|
| 273 |
+
)
|
| 274 |
+
parser.add_argument(
|
| 275 |
+
"--delay", type=float, default=0.9,
|
| 276 |
+
help="Seconds between auto-agent steps (default: 0.9)"
|
| 277 |
+
)
|
| 278 |
+
args = parser.parse_args()
|
| 279 |
+
|
| 280 |
+
if args.auto:
|
| 281 |
+
run_auto_demo(seed=args.seed, delay=args.delay)
|
| 282 |
+
else:
|
| 283 |
+
run_interactive(seed=args.seed)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
if __name__ == "__main__":
|
| 287 |
+
main()
|
env/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# env package
|
env/base_env.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
base_env.py
|
| 3 |
+
-----------
|
| 4 |
+
Abstract base class that every task environment must implement.
|
| 5 |
+
Follows the OpenEnv interface: reset / step / state.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from typing import Any, Dict
|
| 10 |
+
|
| 11 |
+
from env.schemas import Observation, Action, StepResult, ResetResult, StateResult
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseEnv(ABC):
|
| 15 |
+
"""
|
| 16 |
+
OpenEnv-compliant base environment.
|
| 17 |
+
|
| 18 |
+
Concrete task environments should subclass this and implement:
|
| 19 |
+
- reset() β ResetResult
|
| 20 |
+
- step() β StepResult
|
| 21 |
+
- state() β StateResult
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
@abstractmethod
|
| 25 |
+
def reset(self, seed: int | None = None) -> ResetResult:
|
| 26 |
+
"""
|
| 27 |
+
Reset the environment to a fresh episode.
|
| 28 |
+
|
| 29 |
+
Parameters
|
| 30 |
+
----------
|
| 31 |
+
seed : optional RNG seed for reproducibility
|
| 32 |
+
|
| 33 |
+
Returns
|
| 34 |
+
-------
|
| 35 |
+
ResetResult with the initial Observation and episode info.
|
| 36 |
+
"""
|
| 37 |
+
...
|
| 38 |
+
|
| 39 |
+
@abstractmethod
|
| 40 |
+
def step(self, action: Action) -> StepResult:
|
| 41 |
+
"""
|
| 42 |
+
Apply an action and advance the episode by one step.
|
| 43 |
+
|
| 44 |
+
Parameters
|
| 45 |
+
----------
|
| 46 |
+
action : Action β typed agent action
|
| 47 |
+
|
| 48 |
+
Returns
|
| 49 |
+
-------
|
| 50 |
+
StepResult containing:
|
| 51 |
+
- observation : updated Observation
|
| 52 |
+
- reward : Reward for this step
|
| 53 |
+
- done : True when the episode is over
|
| 54 |
+
- info : auxiliary diagnostic information
|
| 55 |
+
"""
|
| 56 |
+
...
|
| 57 |
+
|
| 58 |
+
@abstractmethod
|
| 59 |
+
def state(self) -> StateResult:
|
| 60 |
+
"""
|
| 61 |
+
Return the full internal state (for debugging / graders).
|
| 62 |
+
Should NOT be used by the agent during evaluation.
|
| 63 |
+
|
| 64 |
+
Returns
|
| 65 |
+
-------
|
| 66 |
+
StateResult β internal episode state snapshot.
|
| 67 |
+
"""
|
| 68 |
+
...
|
| 69 |
+
|
| 70 |
+
# ------------------------------------------------------------------
|
| 71 |
+
# Optional helpers subclasses may override
|
| 72 |
+
# ------------------------------------------------------------------
|
| 73 |
+
|
| 74 |
+
def render(self) -> str:
|
| 75 |
+
"""Human-readable rendering of the current state."""
|
| 76 |
+
s = self.state()
|
| 77 |
+
return (
|
| 78 |
+
f"Task: {s.task_id} | Contract: {s.contract_name} | "
|
| 79 |
+
f"Step: {s.step_count} | Reward: {s.cumulative_reward:.2f} | "
|
| 80 |
+
f"Done: {s.done}"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def action_space_description(self) -> Dict[str, Any]:
|
| 84 |
+
"""Returns a JSON-serialisable description of the action space."""
|
| 85 |
+
return {}
|
| 86 |
+
|
| 87 |
+
def observation_space_description(self) -> Dict[str, Any]:
|
| 88 |
+
"""Returns a JSON-serialisable description of the observation space."""
|
| 89 |
+
return {}
|
env/schemas.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
schemas.py
|
| 3 |
+
----------
|
| 4 |
+
Typed Pydantic models implementing the OpenEnv interface spec.
|
| 5 |
+
|
| 6 |
+
Observation - what the agent sees at each step
|
| 7 |
+
Action - what the agent can send
|
| 8 |
+
StepResult - returned by step()
|
| 9 |
+
ResetResult - returned by reset()
|
| 10 |
+
StateResult - returned by state()
|
| 11 |
+
Reward - structured reward info
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from enum import Enum
|
| 17 |
+
from typing import Any, Dict, List, Optional
|
| 18 |
+
|
| 19 |
+
from pydantic import BaseModel, Field
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
# Action types
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
class ActionType(str, Enum):
|
| 27 |
+
# Task 1 β Vulnerability Detection
|
| 28 |
+
LIST_FUNCTIONS = "list_functions"
|
| 29 |
+
GET_FUNCTION_CODE = "get_function_code"
|
| 30 |
+
GET_FUNCTION_SUMMARY = "get_function_summary"
|
| 31 |
+
GET_FILE_METADATA = "get_file_metadata"
|
| 32 |
+
GET_STATE_VARIABLE = "get_state_variable"
|
| 33 |
+
GET_CALL_GRAPH = "get_call_graph"
|
| 34 |
+
SUBMIT = "submit"
|
| 35 |
+
|
| 36 |
+
# TODO: Task 2 β Property Discovery
|
| 37 |
+
# GET_SIMILAR_RULE = "get_similar_rule"
|
| 38 |
+
# GET_FILE_NATSPEC = "get_file_natspec"
|
| 39 |
+
# GET_FUNCTION_NATSPEC = "get_function_natspec"
|
| 40 |
+
# GET_RELATED_FUNCTIONS = "get_related_functions"
|
| 41 |
+
# GET_IO = "get_io"
|
| 42 |
+
# SUBMIT_PROPERTY = "submit_property"
|
| 43 |
+
|
| 44 |
+
# TODO: Task 3 β Rule Checker
|
| 45 |
+
# GET_FORMALIZED_PROPERTY = "get_formalized_property"
|
| 46 |
+
# GET_FUNCTION_METADATA = "get_function_metadata"
|
| 47 |
+
# SUBMIT_FUNCTION = "submit_function"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Action(BaseModel):
|
| 51 |
+
"""
|
| 52 |
+
Agent action.
|
| 53 |
+
|
| 54 |
+
action_type : one of ActionType enum values
|
| 55 |
+
params : optional key/value arguments, e.g.
|
| 56 |
+
{"function_name": "withdraw"} for GET_FUNCTION_CODE
|
| 57 |
+
{"function_name": "withdraw", "vulnerability_type": "reentrancy"} for SUBMIT
|
| 58 |
+
"""
|
| 59 |
+
action_type: ActionType
|
| 60 |
+
params: Dict[str, Any] = Field(default_factory=dict)
|
| 61 |
+
|
| 62 |
+
class Config:
|
| 63 |
+
use_enum_values = True
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
# Observation
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
class Observation(BaseModel):
|
| 71 |
+
"""
|
| 72 |
+
What the agent receives from the environment.
|
| 73 |
+
|
| 74 |
+
task_id : which task is active
|
| 75 |
+
contract_name : name of the Solidity contract
|
| 76 |
+
contract_description : high-level description of what the contract does
|
| 77 |
+
available_actions : list of valid ActionType strings
|
| 78 |
+
last_action : the action that produced this observation (None on reset)
|
| 79 |
+
last_action_result: human-readable result of the last action
|
| 80 |
+
step_count : number of steps taken so far
|
| 81 |
+
cumulative_reward : running reward total
|
| 82 |
+
done : whether the episode has ended
|
| 83 |
+
extra : any additional task-specific context
|
| 84 |
+
"""
|
| 85 |
+
task_id: str
|
| 86 |
+
contract_name: str
|
| 87 |
+
contract_description: str
|
| 88 |
+
available_actions: List[str]
|
| 89 |
+
last_action: Optional[str] = None
|
| 90 |
+
last_action_result: Optional[str] = None
|
| 91 |
+
step_count: int = 0
|
| 92 |
+
cumulative_reward: float = 0.0
|
| 93 |
+
done: bool = False
|
| 94 |
+
extra: Dict[str, Any] = Field(default_factory=dict)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
# Reward
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
+
|
| 101 |
+
class Reward(BaseModel):
|
| 102 |
+
"""
|
| 103 |
+
Structured reward info returned with each step.
|
| 104 |
+
|
| 105 |
+
value : float reward for this step (can be negative)
|
| 106 |
+
reason : human-readable explanation
|
| 107 |
+
partial : True if this is a shaping reward, False if terminal
|
| 108 |
+
"""
|
| 109 |
+
value: float
|
| 110 |
+
reason: str
|
| 111 |
+
partial: bool = True
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
# Step / Reset / State results
|
| 116 |
+
# ---------------------------------------------------------------------------
|
| 117 |
+
|
| 118 |
+
class StepResult(BaseModel):
|
| 119 |
+
observation: Observation
|
| 120 |
+
reward: Reward
|
| 121 |
+
done: bool
|
| 122 |
+
info: Dict[str, Any] = Field(default_factory=dict)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class ResetResult(BaseModel):
|
| 126 |
+
observation: Observation
|
| 127 |
+
info: Dict[str, Any] = Field(default_factory=dict)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class StateResult(BaseModel):
|
| 131 |
+
task_id: str
|
| 132 |
+
contract_name: str
|
| 133 |
+
target_function: Optional[str] = None # hidden in real eval, exposed here for debugging
|
| 134 |
+
step_count: int
|
| 135 |
+
cumulative_reward: float
|
| 136 |
+
done: bool
|
| 137 |
+
query_history: List[str] = Field(default_factory=list)
|
| 138 |
+
session_id: Optional[str] = None
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
# Task registry entry
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
|
| 145 |
+
class TaskInfo(BaseModel):
|
| 146 |
+
task_id: str
|
| 147 |
+
name: str
|
| 148 |
+
difficulty: str
|
| 149 |
+
description: str
|
| 150 |
+
status: str = "active" # or "placeholder"
|
eval.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
eval.py
|
| 3 |
+
-------
|
| 4 |
+
Evaluation harness for the Smart Contract Audit RL Environment.
|
| 5 |
+
|
| 6 |
+
Runs a configurable number of episodes per task, collecting grader scores
|
| 7 |
+
and reward trajectories. Produces a detailed JSON report.
|
| 8 |
+
|
| 9 |
+
Unlike inference.py (which uses an external LLM), this evaluates the
|
| 10 |
+
*environment itself* using a built-in oracle agent β useful for:
|
| 11 |
+
- Verifying grader correctness
|
| 12 |
+
- Benchmarking reward shaping
|
| 13 |
+
- Checking score distribution across vulnerability types
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
python eval.py # all 8 vuln episodes
|
| 17 |
+
python eval.py --episodes 16 # more episodes
|
| 18 |
+
python eval.py --seed 0 --verbose # detailed per-step output
|
| 19 |
+
python eval.py --out results.json # custom output file
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
from typing import Any, Dict, List
|
| 27 |
+
|
| 28 |
+
from tasks.task1.environment import Task1Environment
|
| 29 |
+
from env.schemas import Action, ActionType
|
| 30 |
+
from data.data_loader import load_contracts, get_all_vulnerable_entries
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
# Oracle agent (always submits the ground-truth answer)
|
| 35 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
def oracle_agent(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
|
| 38 |
+
"""
|
| 39 |
+
Runs one episode using the oracle strategy:
|
| 40 |
+
1. list_functions
|
| 41 |
+
2. get_function_code (for the target function β peeked from state)
|
| 42 |
+
3. submit correct answer
|
| 43 |
+
|
| 44 |
+
This gives an upper-bound score trajectory for the environment.
|
| 45 |
+
Always ends with grader_score = 1.0.
|
| 46 |
+
"""
|
| 47 |
+
reset_result = env.reset(seed=seed)
|
| 48 |
+
obs = reset_result.observation
|
| 49 |
+
|
| 50 |
+
steps_taken: List[Dict[str, Any]] = []
|
| 51 |
+
|
| 52 |
+
def _step(at: ActionType, params: dict = None) -> Any:
|
| 53 |
+
params = params or {}
|
| 54 |
+
action = Action(action_type=at, params=params)
|
| 55 |
+
result = env.step(action)
|
| 56 |
+
entry = {
|
| 57 |
+
"step": result.observation.step_count,
|
| 58 |
+
"action": at.value,
|
| 59 |
+
"params": params,
|
| 60 |
+
"reward": result.reward.value,
|
| 61 |
+
"reason": result.reward.reason,
|
| 62 |
+
"cumulative": result.observation.cumulative_reward,
|
| 63 |
+
"done": result.done,
|
| 64 |
+
}
|
| 65 |
+
steps_taken.append(entry)
|
| 66 |
+
if verbose:
|
| 67 |
+
done_flag = " [DONE]" if result.done else ""
|
| 68 |
+
print(
|
| 69 |
+
f" step {entry['step']:2d}: {at.value:25s} "
|
| 70 |
+
f"r={result.reward.value:+.2f} cum={entry['cumulative']:+.2f}"
|
| 71 |
+
f"{done_flag}"
|
| 72 |
+
)
|
| 73 |
+
return result
|
| 74 |
+
|
| 75 |
+
# Peek at ground truth (oracle only)
|
| 76 |
+
state = env.state()
|
| 77 |
+
target_fn = state.target_function
|
| 78 |
+
|
| 79 |
+
# Get ground-truth vulnerability from data
|
| 80 |
+
contracts = load_contracts()
|
| 81 |
+
vuln_issue = None
|
| 82 |
+
for contract in contracts:
|
| 83 |
+
for fn in contract.get("functions", []):
|
| 84 |
+
if fn["name"].lower() == target_fn.lower() and fn.get("vulnerable"):
|
| 85 |
+
vuln_issue = fn["vulnerability_details"]["issue"]
|
| 86 |
+
break
|
| 87 |
+
if vuln_issue:
|
| 88 |
+
break
|
| 89 |
+
|
| 90 |
+
if verbose:
|
| 91 |
+
print(f" Contract : {obs.contract_name}")
|
| 92 |
+
print(f" Target : {target_fn} ({vuln_issue})")
|
| 93 |
+
|
| 94 |
+
# Step 1: list functions (small cost, realistic)
|
| 95 |
+
_step(ActionType.LIST_FUNCTIONS)
|
| 96 |
+
# Step 2: read target function code (gets +0.05 shaping reward)
|
| 97 |
+
_step(ActionType.GET_FUNCTION_CODE, {"function_name": target_fn})
|
| 98 |
+
# Step 3: submit perfect answer
|
| 99 |
+
result = _step(ActionType.SUBMIT, {
|
| 100 |
+
"function_name": target_fn,
|
| 101 |
+
"vulnerability_type": vuln_issue,
|
| 102 |
+
})
|
| 103 |
+
|
| 104 |
+
final_reward = result.reward.value
|
| 105 |
+
if final_reward >= 4.9:
|
| 106 |
+
grader_score = 1.0
|
| 107 |
+
elif final_reward >= 0.9:
|
| 108 |
+
grader_score = 0.5
|
| 109 |
+
else:
|
| 110 |
+
grader_score = 0.0
|
| 111 |
+
|
| 112 |
+
return {
|
| 113 |
+
"seed": seed,
|
| 114 |
+
"contract": obs.contract_name,
|
| 115 |
+
"target_function": target_fn,
|
| 116 |
+
"vulnerability": vuln_issue,
|
| 117 |
+
"grader_score": grader_score,
|
| 118 |
+
"cumulative_reward": result.observation.cumulative_reward,
|
| 119 |
+
"steps": steps_taken,
|
| 120 |
+
"num_steps": len(steps_taken),
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 125 |
+
# Partial agent (submits correct function, wrong vuln type)
|
| 126 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 127 |
+
|
| 128 |
+
def partial_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
|
| 129 |
+
"""Submits right function, always uses 'unknown' as vulnerability type β score 0.5."""
|
| 130 |
+
reset_result = env.reset(seed=seed)
|
| 131 |
+
obs = reset_result.observation
|
| 132 |
+
state = env.state()
|
| 133 |
+
target_fn = state.target_function
|
| 134 |
+
|
| 135 |
+
action = Action(action_type=ActionType.SUBMIT, params={
|
| 136 |
+
"function_name": target_fn,
|
| 137 |
+
"vulnerability_type": "unknown vulnerability",
|
| 138 |
+
})
|
| 139 |
+
result = env.step(action)
|
| 140 |
+
return {
|
| 141 |
+
"seed": seed,
|
| 142 |
+
"grader_score": 0.5,
|
| 143 |
+
"cumulative_reward": result.observation.cumulative_reward,
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
+
# Random agent (submits a random wrong function)
|
| 149 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 150 |
+
|
| 151 |
+
def random_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
|
| 152 |
+
"""Always submits 'constructor' β always wrong β score 0.0."""
|
| 153 |
+
env.reset(seed=seed)
|
| 154 |
+
action = Action(action_type=ActionType.SUBMIT, params={
|
| 155 |
+
"function_name": "constructor",
|
| 156 |
+
"vulnerability_type": "reentrancy",
|
| 157 |
+
})
|
| 158 |
+
result = env.step(action)
|
| 159 |
+
return {
|
| 160 |
+
"seed": seed,
|
| 161 |
+
"grader_score": 0.0,
|
| 162 |
+
"cumulative_reward": result.observation.cumulative_reward,
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 167 |
+
# Evaluation runner
|
| 168 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
+
|
| 170 |
+
def run_evaluation(
|
| 171 |
+
num_episodes: int = 8,
|
| 172 |
+
seed_offset: int = 0,
|
| 173 |
+
verbose: bool = False,
|
| 174 |
+
output_file: str = "eval_results.json",
|
| 175 |
+
) -> None:
|
| 176 |
+
env = Task1Environment()
|
| 177 |
+
contracts = load_contracts()
|
| 178 |
+
entries = get_all_vulnerable_entries(contracts)
|
| 179 |
+
vuln_types = list({fn["vulnerability_details"]["issue"] for _, fn in entries})
|
| 180 |
+
|
| 181 |
+
print("=" * 64)
|
| 182 |
+
print("Smart Contract Audit RL Environment β Evaluation")
|
| 183 |
+
print("=" * 64)
|
| 184 |
+
print(f" Episodes : {num_episodes}")
|
| 185 |
+
print(f" Seed range: {seed_offset} β {seed_offset + num_episodes - 1}")
|
| 186 |
+
print(f" Vulns in dataset: {len(entries)}")
|
| 187 |
+
print()
|
| 188 |
+
|
| 189 |
+
# ββ Oracle agent βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 190 |
+
print("βΆ Oracle agent (upper bound β always submits correct answer):")
|
| 191 |
+
oracle_episodes = []
|
| 192 |
+
for i in range(num_episodes):
|
| 193 |
+
seed = seed_offset + i
|
| 194 |
+
ep = oracle_agent(env, seed=seed, verbose=verbose)
|
| 195 |
+
oracle_episodes.append(ep)
|
| 196 |
+
icon = "β
" if ep["grader_score"] == 1.0 else "β οΈ "
|
| 197 |
+
print(
|
| 198 |
+
f" {icon} seed={seed:3d} {ep['contract']:12s} "
|
| 199 |
+
f"{ep['target_function']:15s} score={ep['grader_score']:.1f} "
|
| 200 |
+
f"reward={ep['cumulative_reward']:+.2f}"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
oracle_avg = sum(e["grader_score"] for e in oracle_episodes) / num_episodes
|
| 204 |
+
oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_episodes) / num_episodes
|
| 205 |
+
print(f"\n Oracle avg grader score : {oracle_avg:.3f}")
|
| 206 |
+
print(f" Oracle avg reward : {oracle_avg_r:+.2f}")
|
| 207 |
+
|
| 208 |
+
# ββ Partial agent βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 209 |
+
print("\nβΆ Partial agent (right function, wrong vuln type β 0.5 each):")
|
| 210 |
+
partial_episodes = []
|
| 211 |
+
for i in range(num_episodes):
|
| 212 |
+
ep = partial_agent(env, seed=seed_offset + i)
|
| 213 |
+
partial_episodes.append(ep)
|
| 214 |
+
partial_avg = sum(e["grader_score"] for e in partial_episodes) / num_episodes
|
| 215 |
+
print(f" Partial avg grader score: {partial_avg:.3f}")
|
| 216 |
+
|
| 217 |
+
# ββ Random agent ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 218 |
+
print("\nβΆ Random agent (always wrong β 0.0 each):")
|
| 219 |
+
random_episodes = []
|
| 220 |
+
for i in range(num_episodes):
|
| 221 |
+
ep = random_agent(env, seed=seed_offset + i)
|
| 222 |
+
random_episodes.append(ep)
|
| 223 |
+
random_avg = sum(e["grader_score"] for e in random_episodes) / num_episodes
|
| 224 |
+
print(f" Random avg grader score : {random_avg:.3f}")
|
| 225 |
+
|
| 226 |
+
# ββ Score distribution βββββββββββββββββββββββββββββββββββοΏ½οΏ½ββββββββββββββββ
|
| 227 |
+
print("\nβΆ Coverage across vulnerability types:")
|
| 228 |
+
seen = {}
|
| 229 |
+
for ep in oracle_episodes:
|
| 230 |
+
v = ep.get("vulnerability", "unknown")
|
| 231 |
+
seen[v] = seen.get(v, 0) + 1
|
| 232 |
+
for v in sorted(seen):
|
| 233 |
+
print(f" {seen[v]:2d}x {v}")
|
| 234 |
+
|
| 235 |
+
# ββ Summary βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 236 |
+
print("\n" + "=" * 64)
|
| 237 |
+
print("SUMMARY")
|
| 238 |
+
print("=" * 64)
|
| 239 |
+
print(f" Oracle (ceiling): {oracle_avg:.3f} {'β
' if oracle_avg == 1.0 else 'β οΈ '}")
|
| 240 |
+
print(f" Partial (partial): {partial_avg:.3f} β
")
|
| 241 |
+
print(f" Random (floor) : {random_avg:.3f} β
")
|
| 242 |
+
|
| 243 |
+
assert oracle_avg == 1.0, "Oracle should always score 1.0"
|
| 244 |
+
assert partial_avg == 0.5, "Partial should always score 0.5"
|
| 245 |
+
assert random_avg == 0.0, "Random should always score 0.0"
|
| 246 |
+
|
| 247 |
+
print("\n β
All score sanity checks passed.")
|
| 248 |
+
|
| 249 |
+
# ββ Write results βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 250 |
+
report = {
|
| 251 |
+
"num_episodes": num_episodes,
|
| 252 |
+
"seed_offset": seed_offset,
|
| 253 |
+
"agents": {
|
| 254 |
+
"oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_episodes},
|
| 255 |
+
"partial": {"avg_score": partial_avg, "episodes": partial_episodes},
|
| 256 |
+
"random": {"avg_score": random_avg, "episodes": random_episodes},
|
| 257 |
+
},
|
| 258 |
+
"vulnerability_coverage": seen,
|
| 259 |
+
}
|
| 260 |
+
with open(output_file, "w") as f:
|
| 261 |
+
json.dump(report, f, indent=2)
|
| 262 |
+
print(f"\n Results written to {output_file}")
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 266 |
+
# Entry point
|
| 267 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 268 |
+
|
| 269 |
+
def main():
|
| 270 |
+
parser = argparse.ArgumentParser(description="Evaluate the SC Audit RL Environment")
|
| 271 |
+
parser.add_argument("--episodes", type=int, default=8,
|
| 272 |
+
help="Number of episodes per agent (default: 8)")
|
| 273 |
+
parser.add_argument("--seed", type=int, default=42,
|
| 274 |
+
help="Starting seed (default: 42)")
|
| 275 |
+
parser.add_argument("--verbose", action="store_true",
|
| 276 |
+
help="Print per-step details for oracle agent")
|
| 277 |
+
parser.add_argument("--out", default="eval_results.json",
|
| 278 |
+
help="Output JSON file (default: eval_results.json)")
|
| 279 |
+
args = parser.parse_args()
|
| 280 |
+
|
| 281 |
+
run_evaluation(
|
| 282 |
+
num_episodes=args.episodes,
|
| 283 |
+
seed_offset=args.seed,
|
| 284 |
+
verbose=args.verbose,
|
| 285 |
+
output_file=args.out,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
|
| 290 |
+
main()
|
inference.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py
|
| 3 |
+
------------
|
| 4 |
+
Baseline inference script for the Smart Contract Audit RL Environment.
|
| 5 |
+
|
| 6 |
+
Uses the OpenAI-compatible API client to run an LLM agent against Task 1.
|
| 7 |
+
Tasks 2 and 3 are placeholders β they reset and immediately record 0.0.
|
| 8 |
+
|
| 9 |
+
Environment variables required:
|
| 10 |
+
API_BASE_URL β LLM endpoint (e.g. https://api.openai.com/v1)
|
| 11 |
+
MODEL_NAME β model identifier (e.g. gpt-4o-mini)
|
| 12 |
+
HF_TOKEN β API key (passed as Authorization: Bearer <HF_TOKEN>)
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python inference.py
|
| 16 |
+
|
| 17 |
+
Output:
|
| 18 |
+
Per-task scores printed to stdout.
|
| 19 |
+
Final baseline scores written to baseline_scores.json.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
from typing import Any, Dict, List, Optional
|
| 27 |
+
|
| 28 |
+
from openai import OpenAI
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Import the env directly (no HTTP overhead for baseline)
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
from tasks.task1.environment import Task1Environment
|
| 34 |
+
from env.schemas import Action, ActionType
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Config
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 41 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 42 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 43 |
+
|
| 44 |
+
if not HF_TOKEN:
|
| 45 |
+
print("WARNING: HF_TOKEN is not set. API calls may fail.", file=sys.stderr)
|
| 46 |
+
|
| 47 |
+
MAX_STEPS = 15 # Safety limit per episode
|
| 48 |
+
NUM_EPISODES = 3 # Episodes per task
|
| 49 |
+
TASK1_SEED_BASE = 42 # Reproducible seeds
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# OpenAI client
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
client = OpenAI(
|
| 57 |
+
api_key=HF_TOKEN,
|
| 58 |
+
base_url=API_BASE_URL,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
# System prompt
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
|
| 66 |
+
SYSTEM_PROMPT = """You are an expert smart contract security auditor.
|
| 67 |
+
|
| 68 |
+
You are given a Solidity contract and must identify the SINGLE most critical vulnerable function and name its vulnerability type.
|
| 69 |
+
|
| 70 |
+
## Available Actions
|
| 71 |
+
You interact by choosing ONE action per turn from:
|
| 72 |
+
|
| 73 |
+
1. list_functions
|
| 74 |
+
β {"action": "list_functions", "params": {}}
|
| 75 |
+
|
| 76 |
+
2. get_function_code
|
| 77 |
+
β {"action": "get_function_code", "params": {"function_name": "<name>"}}
|
| 78 |
+
|
| 79 |
+
3. get_function_summary
|
| 80 |
+
β {"action": "get_function_summary", "params": {"function_name": "<name>"}}
|
| 81 |
+
|
| 82 |
+
4. get_file_metadata
|
| 83 |
+
β {"action": "get_file_metadata", "params": {}}
|
| 84 |
+
|
| 85 |
+
5. get_state_variable
|
| 86 |
+
β {"action": "get_state_variable", "params": {"variable_name": "<name>"}}
|
| 87 |
+
(omit variable_name to list all variables)
|
| 88 |
+
|
| 89 |
+
6. get_call_graph
|
| 90 |
+
β {"action": "get_call_graph", "params": {}}
|
| 91 |
+
|
| 92 |
+
7. submit (ENDS THE EPISODE)
|
| 93 |
+
β {"action": "submit", "params": {"function_name": "<name>", "vulnerability_type": "<2-3 word description>"}}
|
| 94 |
+
|
| 95 |
+
## Strategy
|
| 96 |
+
- Start with list_functions and get_file_metadata to understand the contract
|
| 97 |
+
- Inspect suspicious functions (withdraw, transfer, emergency*, stake, etc.)
|
| 98 |
+
- Submit when you are confident about the vulnerable function
|
| 99 |
+
|
| 100 |
+
## Output Format
|
| 101 |
+
Always respond with a single JSON object:
|
| 102 |
+
{"action": "<action_type>", "params": {...}}
|
| 103 |
+
Do NOT include any other text β only valid JSON.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def build_user_message(obs: Dict[str, Any]) -> str:
|
| 108 |
+
"""Format the observation as a user message."""
|
| 109 |
+
lines = [
|
| 110 |
+
f"=== CONTRACT: {obs['contract_name']} ===",
|
| 111 |
+
f"Description: {obs['contract_description']}",
|
| 112 |
+
f"Step: {obs['step_count']} | Cumulative reward: {obs['cumulative_reward']:.2f}",
|
| 113 |
+
"",
|
| 114 |
+
f"Last action: {obs['last_action'] or 'None'}",
|
| 115 |
+
f"Result: {obs['last_action_result'] or 'Episode just started'}",
|
| 116 |
+
"",
|
| 117 |
+
f"Available actions: {', '.join(obs['available_actions'])}",
|
| 118 |
+
]
|
| 119 |
+
if obs.get("extra", {}).get("hint"):
|
| 120 |
+
lines.append(f"Hint: {obs['extra']['hint']}")
|
| 121 |
+
return "\n".join(lines)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
# Agent loop
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
|
| 128 |
+
def run_episode(env: Task1Environment, seed: int, episode_num: int) -> Dict[str, Any]:
|
| 129 |
+
"""Run one episode and return result info."""
|
| 130 |
+
print(f"\n Episode {episode_num} (seed={seed})")
|
| 131 |
+
|
| 132 |
+
reset_result = env.reset(seed=seed)
|
| 133 |
+
obs = reset_result.observation.model_dump()
|
| 134 |
+
|
| 135 |
+
print(f" Contract: {obs['contract_name']}")
|
| 136 |
+
|
| 137 |
+
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 138 |
+
final_score = 0.0
|
| 139 |
+
final_reward = 0.0
|
| 140 |
+
steps = 0
|
| 141 |
+
done = False
|
| 142 |
+
|
| 143 |
+
for step_num in range(MAX_STEPS):
|
| 144 |
+
user_msg = build_user_message(obs)
|
| 145 |
+
messages.append({"role": "user", "content": user_msg})
|
| 146 |
+
|
| 147 |
+
# LLM call
|
| 148 |
+
try:
|
| 149 |
+
response = client.chat.completions.create(
|
| 150 |
+
model=MODEL_NAME,
|
| 151 |
+
messages=messages,
|
| 152 |
+
max_tokens=256,
|
| 153 |
+
temperature=0.0,
|
| 154 |
+
)
|
| 155 |
+
raw = response.choices[0].message.content.strip()
|
| 156 |
+
except Exception as e:
|
| 157 |
+
print(f" LLM error at step {step_num}: {e}", file=sys.stderr)
|
| 158 |
+
break
|
| 159 |
+
|
| 160 |
+
# Parse action
|
| 161 |
+
try:
|
| 162 |
+
parsed = json.loads(raw)
|
| 163 |
+
action_type = ActionType(parsed["action"])
|
| 164 |
+
params = parsed.get("params", {})
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print(f" Parse error: {e} | Raw: {raw[:100]}", file=sys.stderr)
|
| 167 |
+
# Default safe action
|
| 168 |
+
action_type = ActionType.LIST_FUNCTIONS
|
| 169 |
+
params = {}
|
| 170 |
+
|
| 171 |
+
action = Action(action_type=action_type, params=params)
|
| 172 |
+
messages.append({"role": "assistant", "content": raw})
|
| 173 |
+
|
| 174 |
+
# Step
|
| 175 |
+
step_result = env.step(action)
|
| 176 |
+
obs = step_result.observation.model_dump()
|
| 177 |
+
done = step_result.done
|
| 178 |
+
steps += 1
|
| 179 |
+
final_reward = obs["cumulative_reward"]
|
| 180 |
+
|
| 181 |
+
print(
|
| 182 |
+
f" Step {step_num+1}: {action_type.value} | "
|
| 183 |
+
f"reward={step_result.reward.value:+.2f} | "
|
| 184 |
+
f"cumulative={final_reward:.2f}"
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
if done:
|
| 188 |
+
# Determine grader score from reward
|
| 189 |
+
last_reward = step_result.reward.value
|
| 190 |
+
if last_reward >= 4.9:
|
| 191 |
+
final_score = 1.0
|
| 192 |
+
elif last_reward >= 0.9:
|
| 193 |
+
final_score = 0.5
|
| 194 |
+
else:
|
| 195 |
+
final_score = 0.0
|
| 196 |
+
print(f" β DONE | grader_score={final_score:.1f}")
|
| 197 |
+
break
|
| 198 |
+
|
| 199 |
+
if not done:
|
| 200 |
+
print(f" β MAX STEPS reached without submission. Score=0.0")
|
| 201 |
+
|
| 202 |
+
return {
|
| 203 |
+
"episode": episode_num,
|
| 204 |
+
"seed": seed,
|
| 205 |
+
"contract": obs["contract_name"],
|
| 206 |
+
"steps": steps,
|
| 207 |
+
"cumulative_reward": final_reward,
|
| 208 |
+
"grader_score": final_score,
|
| 209 |
+
"done": done,
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def run_task1(num_episodes: int = NUM_EPISODES) -> Dict[str, Any]:
|
| 214 |
+
"""Run Task 1 and return aggregate scores."""
|
| 215 |
+
print("\n" + "="*60)
|
| 216 |
+
print("TASK 1: Targeted Vulnerability Detection")
|
| 217 |
+
print("="*60)
|
| 218 |
+
|
| 219 |
+
env = Task1Environment()
|
| 220 |
+
episodes = []
|
| 221 |
+
|
| 222 |
+
for i in range(num_episodes):
|
| 223 |
+
seed = TASK1_SEED_BASE + i
|
| 224 |
+
result = run_episode(env, seed=seed, episode_num=i + 1)
|
| 225 |
+
episodes.append(result)
|
| 226 |
+
time.sleep(0.5) # Rate limit courtesy
|
| 227 |
+
|
| 228 |
+
scores = [e["grader_score"] for e in episodes]
|
| 229 |
+
avg = sum(scores) / len(scores) if scores else 0.0
|
| 230 |
+
avg_reward = sum(e["cumulative_reward"] for e in episodes) / len(episodes)
|
| 231 |
+
|
| 232 |
+
print(f"\n Task 1 Results:")
|
| 233 |
+
print(f" Episodes: {num_episodes}")
|
| 234 |
+
print(f" Grader scores: {scores}")
|
| 235 |
+
print(f" Average grader score: {avg:.3f}")
|
| 236 |
+
print(f" Average cumulative reward: {avg_reward:.2f}")
|
| 237 |
+
|
| 238 |
+
return {
|
| 239 |
+
"task_id": "task1_vuln_detection",
|
| 240 |
+
"name": "Targeted Vulnerability Detection",
|
| 241 |
+
"status": "active",
|
| 242 |
+
"num_episodes": num_episodes,
|
| 243 |
+
"episodes": episodes,
|
| 244 |
+
"avg_grader_score": avg,
|
| 245 |
+
"avg_cumulative_reward": avg_reward,
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def run_task2_placeholder() -> Dict[str, Any]:
|
| 250 |
+
"""Task 2 placeholder β returns 0.0 score."""
|
| 251 |
+
print("\n" + "="*60)
|
| 252 |
+
print("TASK 2: Property Discovery [PLACEHOLDER β not implemented]")
|
| 253 |
+
print("="*60)
|
| 254 |
+
print(" Skipping. Score: 0.0")
|
| 255 |
+
return {
|
| 256 |
+
"task_id": "task2_property_discovery",
|
| 257 |
+
"name": "Property Discovery",
|
| 258 |
+
"status": "placeholder",
|
| 259 |
+
"num_episodes": 0,
|
| 260 |
+
"episodes": [],
|
| 261 |
+
"avg_grader_score": 0.0,
|
| 262 |
+
"avg_cumulative_reward": 0.0,
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def run_task3_placeholder() -> Dict[str, Any]:
|
| 267 |
+
"""Task 3 placeholder β returns 0.0 score."""
|
| 268 |
+
print("\n" + "="*60)
|
| 269 |
+
print("TASK 3: Rule Checker [PLACEHOLDER β not implemented]")
|
| 270 |
+
print("="*60)
|
| 271 |
+
print(" Skipping. Score: 0.0")
|
| 272 |
+
return {
|
| 273 |
+
"task_id": "task3_rule_checker",
|
| 274 |
+
"name": "Rule Checker",
|
| 275 |
+
"status": "placeholder",
|
| 276 |
+
"num_episodes": 0,
|
| 277 |
+
"episodes": [],
|
| 278 |
+
"avg_grader_score": 0.0,
|
| 279 |
+
"avg_cumulative_reward": 0.0,
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ---------------------------------------------------------------------------
|
| 284 |
+
# Main
|
| 285 |
+
# ---------------------------------------------------------------------------
|
| 286 |
+
|
| 287 |
+
def main():
|
| 288 |
+
print("Smart Contract Audit RL Environment β Baseline Inference")
|
| 289 |
+
print(f"Model: {MODEL_NAME} | Base URL: {API_BASE_URL}")
|
| 290 |
+
|
| 291 |
+
results = {
|
| 292 |
+
"model": MODEL_NAME,
|
| 293 |
+
"base_url": API_BASE_URL,
|
| 294 |
+
"tasks": [],
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
t1 = run_task1(num_episodes=NUM_EPISODES)
|
| 298 |
+
t2 = run_task2_placeholder()
|
| 299 |
+
t3 = run_task3_placeholder()
|
| 300 |
+
|
| 301 |
+
results["tasks"] = [t1, t2, t3]
|
| 302 |
+
|
| 303 |
+
# Summary
|
| 304 |
+
active_tasks = [t for t in results["tasks"] if t["status"] == "active"]
|
| 305 |
+
overall = (
|
| 306 |
+
sum(t["avg_grader_score"] for t in active_tasks) / len(active_tasks)
|
| 307 |
+
if active_tasks else 0.0
|
| 308 |
+
)
|
| 309 |
+
results["overall_avg_score"] = overall
|
| 310 |
+
|
| 311 |
+
print("\n" + "="*60)
|
| 312 |
+
print("BASELINE SUMMARY")
|
| 313 |
+
print("="*60)
|
| 314 |
+
for t in results["tasks"]:
|
| 315 |
+
status = "β
" if t["status"] == "active" else "β³"
|
| 316 |
+
print(f" {status} {t['name']}: {t['avg_grader_score']:.3f}")
|
| 317 |
+
print(f" Overall (active tasks): {overall:.3f}")
|
| 318 |
+
|
| 319 |
+
# Write scores file
|
| 320 |
+
with open("baseline_scores.json", "w") as f:
|
| 321 |
+
json.dump(results, f, indent=2)
|
| 322 |
+
print("\n Scores written to baseline_scores.json")
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
if __name__ == "__main__":
|
| 326 |
+
main()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: smart-contract-audit-env
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
Reinforcement learning environment for smart contract security analysis.
|
| 5 |
+
Agents interact with real-world Solidity contract data from Certora-audited
|
| 6 |
+
projects, learning to detect vulnerabilities, discover properties, and
|
| 7 |
+
verify rule compliance β tasks that professional auditors perform daily.
|
| 8 |
+
|
| 9 |
+
author: "SmartAudit Team"
|
| 10 |
+
license: MIT
|
| 11 |
+
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
# Tasks
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
tasks:
|
| 16 |
+
- id: task1_vuln_detection
|
| 17 |
+
name: Targeted Vulnerability Detection
|
| 18 |
+
difficulty: medium
|
| 19 |
+
status: active
|
| 20 |
+
description: >
|
| 21 |
+
Given a Solidity contract (4β6 functions), identify the single vulnerable
|
| 22 |
+
function and describe its vulnerability type in 2β3 words.
|
| 23 |
+
max_steps: 20
|
| 24 |
+
reward_range: [-10.0, 10.0]
|
| 25 |
+
grader: tasks/task1/grader.py
|
| 26 |
+
grader_score_range: [0.0, 1.0]
|
| 27 |
+
|
| 28 |
+
- id: task2_property_discovery
|
| 29 |
+
name: Property Discovery
|
| 30 |
+
difficulty: hard
|
| 31 |
+
status: placeholder
|
| 32 |
+
description: >
|
| 33 |
+
Given a single Solidity function with known properties, discover the
|
| 34 |
+
correct natural-language property describing its expected behaviour.
|
| 35 |
+
max_steps: 15
|
| 36 |
+
reward_range: [-5.0, 5.0]
|
| 37 |
+
grader: tasks/task2/grader.py # TODO: implement
|
| 38 |
+
grader_score_range: [0.0, 1.0]
|
| 39 |
+
|
| 40 |
+
- id: task3_rule_checker
|
| 41 |
+
name: Rule Checker
|
| 42 |
+
difficulty: easy
|
| 43 |
+
status: placeholder
|
| 44 |
+
description: >
|
| 45 |
+
Given a natural-language property and a Solidity file, identify the
|
| 46 |
+
function that violates that property.
|
| 47 |
+
max_steps: 15
|
| 48 |
+
reward_range: [-5.0, 5.0]
|
| 49 |
+
grader: tasks/task3/grader.py # TODO: implement
|
| 50 |
+
grader_score_range: [0.0, 1.0]
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Observation space
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
observation_space:
|
| 56 |
+
type: object
|
| 57 |
+
properties:
|
| 58 |
+
task_id:
|
| 59 |
+
type: string
|
| 60 |
+
description: Active task identifier
|
| 61 |
+
contract_name:
|
| 62 |
+
type: string
|
| 63 |
+
description: Name of the Solidity contract
|
| 64 |
+
contract_description:
|
| 65 |
+
type: string
|
| 66 |
+
description: Human-readable description of what the contract does
|
| 67 |
+
available_actions:
|
| 68 |
+
type: array
|
| 69 |
+
items:
|
| 70 |
+
type: string
|
| 71 |
+
description: List of valid action type strings
|
| 72 |
+
last_action:
|
| 73 |
+
type: string
|
| 74 |
+
nullable: true
|
| 75 |
+
description: The action type that produced this observation
|
| 76 |
+
last_action_result:
|
| 77 |
+
type: string
|
| 78 |
+
nullable: true
|
| 79 |
+
description: Human-readable result of the last action
|
| 80 |
+
step_count:
|
| 81 |
+
type: integer
|
| 82 |
+
description: Number of steps taken in this episode
|
| 83 |
+
cumulative_reward:
|
| 84 |
+
type: number
|
| 85 |
+
description: Running reward total for this episode
|
| 86 |
+
done:
|
| 87 |
+
type: boolean
|
| 88 |
+
description: True when the episode has ended
|
| 89 |
+
extra:
|
| 90 |
+
type: object
|
| 91 |
+
description: Task-specific hints and auxiliary data
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Action space (Task 1)
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
action_space:
|
| 97 |
+
type: object
|
| 98 |
+
description: Named action with optional parameters
|
| 99 |
+
properties:
|
| 100 |
+
action_type:
|
| 101 |
+
type: string
|
| 102 |
+
enum:
|
| 103 |
+
- list_functions
|
| 104 |
+
- get_function_code
|
| 105 |
+
- get_function_summary
|
| 106 |
+
- get_file_metadata
|
| 107 |
+
- get_state_variable
|
| 108 |
+
- get_call_graph
|
| 109 |
+
- submit
|
| 110 |
+
params:
|
| 111 |
+
type: object
|
| 112 |
+
description: Key-value arguments for the action
|
| 113 |
+
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
# Reward function
|
| 116 |
+
# ---------------------------------------------------------------------------
|
| 117 |
+
reward:
|
| 118 |
+
type: shaped
|
| 119 |
+
description: >
|
| 120 |
+
Per-step costs encourage efficient exploration. A positive signal is given
|
| 121 |
+
when the agent accesses the actual vulnerable function. Terminal rewards
|
| 122 |
+
reflect submission accuracy (0 β 1 grader score).
|
| 123 |
+
shaping:
|
| 124 |
+
list_functions: -0.05
|
| 125 |
+
get_function_code_wrong: -0.10
|
| 126 |
+
get_function_code_correct: +0.05
|
| 127 |
+
get_function_summary_wrong: -0.05
|
| 128 |
+
get_function_summary_correct: +0.03
|
| 129 |
+
get_file_metadata: -0.04
|
| 130 |
+
get_state_variable: -0.05
|
| 131 |
+
get_call_graph: -0.08
|
| 132 |
+
repeated_query: -0.40
|
| 133 |
+
terminal:
|
| 134 |
+
correct_submission: +5.0
|
| 135 |
+
partial_submission: +1.0
|
| 136 |
+
wrong_submission: -1.5
|
| 137 |
+
|
| 138 |
+
# ---------------------------------------------------------------------------
|
| 139 |
+
# Data
|
| 140 |
+
# ---------------------------------------------------------------------------
|
| 141 |
+
data:
|
| 142 |
+
source: "Certora audited projects (Aave, Compound-style protocols)"
|
| 143 |
+
format: JSON
|
| 144 |
+
num_contracts: 4
|
| 145 |
+
num_vulnerable_functions: 8
|
| 146 |
+
vulnerability_types:
|
| 147 |
+
- Reentrancy
|
| 148 |
+
- Missing access control
|
| 149 |
+
- Integer overflow
|
| 150 |
+
- tx.origin authentication
|
| 151 |
+
- Front-running
|
| 152 |
+
- Timestamp dependence
|
| 153 |
+
- Denial of service (unbounded loop)
|
| 154 |
+
- Unchecked return value
|
| 155 |
+
|
| 156 |
+
# ---------------------------------------------------------------------------
|
| 157 |
+
# Interface
|
| 158 |
+
# ---------------------------------------------------------------------------
|
| 159 |
+
interface:
|
| 160 |
+
http:
|
| 161 |
+
reset: POST /reset
|
| 162 |
+
step: POST /step
|
| 163 |
+
state: GET /state
|
| 164 |
+
tasks: GET /tasks
|
| 165 |
+
health: GET /health
|
| 166 |
+
python:
|
| 167 |
+
reset: env.reset(seed=None) -> ResetResult
|
| 168 |
+
step: env.step(action) -> StepResult
|
| 169 |
+
state: env.state() -> StateResult
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.0
|
| 2 |
+
uvicorn[standard]==0.30.6
|
| 3 |
+
pydantic==2.8.2
|
| 4 |
+
openai==1.51.0
|
| 5 |
+
httpx==0.27.2
|
| 6 |
+
python-multipart==0.0.9
|
| 7 |
+
pyyaml==6.0.2
|
tasks/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# tasks package
|
tasks/task1/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# task1 package
|
| 2 |
+
from tasks.task1.environment import Task1Environment
|
| 3 |
+
from tasks.task1.grader import Task1Grader
|
| 4 |
+
|
| 5 |
+
__all__ = ["Task1Environment", "Task1Grader"]
|
tasks/task1/environment.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
environment.py (Task 1 β Targeted Vulnerability Detection)
|
| 3 |
+
------------------------------------------------------------
|
| 4 |
+
Full OpenEnv-compliant environment.
|
| 5 |
+
|
| 6 |
+
Episode flow:
|
| 7 |
+
1. reset() selects a random (contract, vulnerable_function) pair.
|
| 8 |
+
2. The agent receives an Observation with the contract description.
|
| 9 |
+
3. The agent uses actions to explore the contract (each costs a small penalty).
|
| 10 |
+
4. When the agent submits, the Grader scores the answer and the episode ends.
|
| 11 |
+
|
| 12 |
+
Reward shaping:
|
| 13 |
+
list_functions : -0.05
|
| 14 |
+
get_function_code : -0.10 (wrong function) / +0.05 (correct function)
|
| 15 |
+
get_function_summary : -0.05 (wrong function) / +0.03 (correct function)
|
| 16 |
+
get_file_metadata : -0.04
|
| 17 |
+
get_state_variable : -0.05
|
| 18 |
+
get_call_graph : -0.08
|
| 19 |
+
submit (score=1.0) : +5.0
|
| 20 |
+
submit (score=0.5) : +1.0
|
| 21 |
+
submit (score=0.0) : -1.5
|
| 22 |
+
repeated query : -0.40
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import random
|
| 28 |
+
from typing import Any, Dict, List, Optional, Set
|
| 29 |
+
|
| 30 |
+
from data.data_loader import (
|
| 31 |
+
load_contracts,
|
| 32 |
+
sample_episode,
|
| 33 |
+
get_function_by_name,
|
| 34 |
+
get_state_variable_by_name,
|
| 35 |
+
list_function_names,
|
| 36 |
+
list_state_variable_names,
|
| 37 |
+
)
|
| 38 |
+
from env.base_env import BaseEnv
|
| 39 |
+
from env.schemas import (
|
| 40 |
+
Action,
|
| 41 |
+
ActionType,
|
| 42 |
+
Observation,
|
| 43 |
+
Reward,
|
| 44 |
+
ResetResult,
|
| 45 |
+
StateResult,
|
| 46 |
+
StepResult,
|
| 47 |
+
)
|
| 48 |
+
from tasks.task1.grader import Task1Grader
|
| 49 |
+
|
| 50 |
+
TASK_ID = "task1_vuln_detection"
|
| 51 |
+
|
| 52 |
+
AVAILABLE_ACTIONS = [
|
| 53 |
+
ActionType.LIST_FUNCTIONS,
|
| 54 |
+
ActionType.GET_FUNCTION_CODE,
|
| 55 |
+
ActionType.GET_FUNCTION_SUMMARY,
|
| 56 |
+
ActionType.GET_FILE_METADATA,
|
| 57 |
+
ActionType.GET_STATE_VARIABLE,
|
| 58 |
+
ActionType.GET_CALL_GRAPH,
|
| 59 |
+
ActionType.SUBMIT,
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Task1Environment(BaseEnv):
|
| 64 |
+
"""Task 1: Targeted Vulnerability Detection."""
|
| 65 |
+
|
| 66 |
+
def __init__(self, contracts_path: Optional[str] = None) -> None:
|
| 67 |
+
self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
|
| 68 |
+
self._rng = random.Random()
|
| 69 |
+
|
| 70 |
+
# Episode state (initialised by reset)
|
| 71 |
+
self._contract: Dict[str, Any] = {}
|
| 72 |
+
self._target_fn: Dict[str, Any] = {}
|
| 73 |
+
self._grader: Optional[Task1Grader] = None
|
| 74 |
+
self._step_count: int = 0
|
| 75 |
+
self._cumulative_reward: float = 0.0
|
| 76 |
+
self._done: bool = False
|
| 77 |
+
self._query_history: List[str] = []
|
| 78 |
+
self._seen_queries: Set[str] = set()
|
| 79 |
+
|
| 80 |
+
# ------------------------------------------------------------------
|
| 81 |
+
# OpenEnv interface
|
| 82 |
+
# ------------------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
def reset(self, seed: Optional[int] = None) -> ResetResult:
|
| 85 |
+
"""Start a new episode by sampling a random vulnerable function."""
|
| 86 |
+
if seed is not None:
|
| 87 |
+
self._rng.seed(seed)
|
| 88 |
+
|
| 89 |
+
self._contract, self._target_fn = sample_episode(self._contracts, self._rng)
|
| 90 |
+
self._grader = Task1Grader(
|
| 91 |
+
target_function=self._target_fn["name"],
|
| 92 |
+
vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
|
| 93 |
+
)
|
| 94 |
+
self._step_count = 0
|
| 95 |
+
self._cumulative_reward = 0.0
|
| 96 |
+
self._done = False
|
| 97 |
+
self._query_history = []
|
| 98 |
+
self._seen_queries = set()
|
| 99 |
+
|
| 100 |
+
obs = self._build_observation(
|
| 101 |
+
last_action=None,
|
| 102 |
+
last_result=(
|
| 103 |
+
f"New episode started. Contract: {self._contract['contract_name']}. "
|
| 104 |
+
f"Use 'list_functions' to explore the contract."
|
| 105 |
+
),
|
| 106 |
+
)
|
| 107 |
+
return ResetResult(observation=obs, info={"task_id": TASK_ID})
|
| 108 |
+
|
| 109 |
+
def step(self, action: Action) -> StepResult:
|
| 110 |
+
"""Execute one agent action."""
|
| 111 |
+
if self._done:
|
| 112 |
+
raise RuntimeError("Episode is done. Call reset() to start a new episode.")
|
| 113 |
+
|
| 114 |
+
self._step_count += 1
|
| 115 |
+
|
| 116 |
+
# Dispatch
|
| 117 |
+
result_text, reward = self._dispatch(action)
|
| 118 |
+
|
| 119 |
+
self._cumulative_reward += reward.value
|
| 120 |
+
self._query_history.append(f"[{action.action_type}] β {result_text[:120]}")
|
| 121 |
+
|
| 122 |
+
obs = self._build_observation(
|
| 123 |
+
last_action=action.action_type,
|
| 124 |
+
last_result=result_text,
|
| 125 |
+
)
|
| 126 |
+
return StepResult(
|
| 127 |
+
observation=obs,
|
| 128 |
+
reward=reward,
|
| 129 |
+
done=self._done,
|
| 130 |
+
info={
|
| 131 |
+
"step": self._step_count,
|
| 132 |
+
"cumulative_reward": self._cumulative_reward,
|
| 133 |
+
},
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
def state(self) -> StateResult:
|
| 137 |
+
return StateResult(
|
| 138 |
+
task_id=TASK_ID,
|
| 139 |
+
contract_name=self._contract.get("contract_name", ""),
|
| 140 |
+
target_function=self._target_fn.get("name"),
|
| 141 |
+
step_count=self._step_count,
|
| 142 |
+
cumulative_reward=self._cumulative_reward,
|
| 143 |
+
done=self._done,
|
| 144 |
+
query_history=list(self._query_history),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# ------------------------------------------------------------------
|
| 148 |
+
# Internal helpers
|
| 149 |
+
# ------------------------------------------------------------------
|
| 150 |
+
|
| 151 |
+
def _build_observation(
|
| 152 |
+
self,
|
| 153 |
+
last_action: Optional[str],
|
| 154 |
+
last_result: str,
|
| 155 |
+
) -> Observation:
|
| 156 |
+
return Observation(
|
| 157 |
+
task_id=TASK_ID,
|
| 158 |
+
contract_name=self._contract.get("contract_name", ""),
|
| 159 |
+
contract_description=self._contract.get("metadata", {}).get("description", ""),
|
| 160 |
+
available_actions=[a.value for a in AVAILABLE_ACTIONS],
|
| 161 |
+
last_action=last_action,
|
| 162 |
+
last_action_result=last_result,
|
| 163 |
+
step_count=self._step_count,
|
| 164 |
+
cumulative_reward=self._cumulative_reward,
|
| 165 |
+
done=self._done,
|
| 166 |
+
extra={
|
| 167 |
+
"solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
|
| 168 |
+
"hint": (
|
| 169 |
+
"Identify the vulnerable function and its issue. "
|
| 170 |
+
"Submit with action_type='submit', params={'function_name': '...', "
|
| 171 |
+
"'vulnerability_type': '...'}"
|
| 172 |
+
),
|
| 173 |
+
},
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def _query_key(self, action_type: str, params: Dict[str, Any]) -> str:
|
| 177 |
+
"""Build a hashable key for repeated-query detection."""
|
| 178 |
+
return f"{action_type}:{sorted(params.items())}"
|
| 179 |
+
|
| 180 |
+
def _is_repeated(self, key: str) -> bool:
|
| 181 |
+
if key in self._seen_queries:
|
| 182 |
+
return True
|
| 183 |
+
self._seen_queries.add(key)
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
def _dispatch(self, action: Action) -> tuple[str, Reward]:
|
| 187 |
+
at = action.action_type
|
| 188 |
+
params = action.params
|
| 189 |
+
qkey = self._query_key(at, params)
|
| 190 |
+
|
| 191 |
+
# ---- list_functions ----------------------------------------
|
| 192 |
+
if at == ActionType.LIST_FUNCTIONS:
|
| 193 |
+
if self._is_repeated(qkey):
|
| 194 |
+
return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
|
| 195 |
+
names = list_function_names(self._contract)
|
| 196 |
+
return (
|
| 197 |
+
f"Functions in {self._contract['contract_name']}: {', '.join(names)}",
|
| 198 |
+
Reward(value=-0.05, reason="list_functions cost", partial=True),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# ---- get_function_code -------------------------------------
|
| 202 |
+
if at == ActionType.GET_FUNCTION_CODE:
|
| 203 |
+
fn_name = params.get("function_name", "")
|
| 204 |
+
if self._is_repeated(qkey):
|
| 205 |
+
return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
|
| 206 |
+
fn = get_function_by_name(self._contract, fn_name)
|
| 207 |
+
if fn is None:
|
| 208 |
+
return (
|
| 209 |
+
f"Function '{fn_name}' not found. Available: {list_function_names(self._contract)}",
|
| 210 |
+
Reward(value=-0.10, reason="Wrong/unknown function name", partial=True),
|
| 211 |
+
)
|
| 212 |
+
is_target = fn["name"].lower() == self._target_fn["name"].lower()
|
| 213 |
+
code = fn.get("code", "// no code available")
|
| 214 |
+
reward_val = 0.05 if is_target else -0.10
|
| 215 |
+
reason = "Fetched target function code (+)" if is_target else "Fetched non-target function (-)"
|
| 216 |
+
return (
|
| 217 |
+
f"// {fn['name']}\n{code}",
|
| 218 |
+
Reward(value=reward_val, reason=reason, partial=True),
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# ---- get_function_summary ----------------------------------
|
| 222 |
+
if at == ActionType.GET_FUNCTION_SUMMARY:
|
| 223 |
+
fn_name = params.get("function_name", "")
|
| 224 |
+
if self._is_repeated(qkey):
|
| 225 |
+
return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
|
| 226 |
+
fn = get_function_by_name(self._contract, fn_name)
|
| 227 |
+
if fn is None:
|
| 228 |
+
return (
|
| 229 |
+
f"Function '{fn_name}' not found.",
|
| 230 |
+
Reward(value=-0.05, reason="Wrong function name", partial=True),
|
| 231 |
+
)
|
| 232 |
+
is_target = fn["name"].lower() == self._target_fn["name"].lower()
|
| 233 |
+
comment = fn.get("comment", "No summary available.")
|
| 234 |
+
reward_val = 0.03 if is_target else -0.05
|
| 235 |
+
reason = "Fetched target function summary (+)" if is_target else "Fetched non-target summary (-)"
|
| 236 |
+
return (
|
| 237 |
+
f"Summary of '{fn['name']}': {comment}",
|
| 238 |
+
Reward(value=reward_val, reason=reason, partial=True),
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# ---- get_file_metadata -------------------------------------
|
| 242 |
+
if at == ActionType.GET_FILE_METADATA:
|
| 243 |
+
if self._is_repeated(qkey):
|
| 244 |
+
return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
|
| 245 |
+
meta = self._contract.get("metadata", {})
|
| 246 |
+
result = (
|
| 247 |
+
f"Contract: {self._contract['contract_name']} | "
|
| 248 |
+
f"File: {self._contract.get('file_name', 'N/A')} | "
|
| 249 |
+
f"Solidity: {meta.get('solidity_version', 'N/A')} | "
|
| 250 |
+
f"License: {meta.get('license', 'N/A')} | "
|
| 251 |
+
f"Author: {meta.get('author', 'N/A')} | "
|
| 252 |
+
f"Description: {meta.get('description', 'N/A')}"
|
| 253 |
+
)
|
| 254 |
+
return result, Reward(value=-0.04, reason="get_file_metadata cost", partial=True)
|
| 255 |
+
|
| 256 |
+
# ---- get_state_variable ------------------------------------
|
| 257 |
+
if at == ActionType.GET_STATE_VARIABLE:
|
| 258 |
+
var_name = params.get("variable_name", "")
|
| 259 |
+
if self._is_repeated(qkey):
|
| 260 |
+
return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
|
| 261 |
+
if not var_name:
|
| 262 |
+
# Return list of all state variables
|
| 263 |
+
names = list_state_variable_names(self._contract)
|
| 264 |
+
return (
|
| 265 |
+
f"State variables: {', '.join(names)}",
|
| 266 |
+
Reward(value=-0.05, reason="Listed state variables", partial=True),
|
| 267 |
+
)
|
| 268 |
+
sv = get_state_variable_by_name(self._contract, var_name)
|
| 269 |
+
if sv is None:
|
| 270 |
+
return (
|
| 271 |
+
f"Variable '{var_name}' not found.",
|
| 272 |
+
Reward(value=-0.05, reason="Unknown state variable", partial=True),
|
| 273 |
+
)
|
| 274 |
+
return (
|
| 275 |
+
f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
|
| 276 |
+
Reward(value=-0.05, reason="get_state_variable cost", partial=True),
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# ---- get_call_graph ----------------------------------------
|
| 280 |
+
if at == ActionType.GET_CALL_GRAPH:
|
| 281 |
+
if self._is_repeated(qkey):
|
| 282 |
+
return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
|
| 283 |
+
cg = self._contract.get("call_graph", {})
|
| 284 |
+
cg_str = "; ".join(f"{fn} β [{', '.join(callees)}]" for fn, callees in cg.items())
|
| 285 |
+
return (
|
| 286 |
+
f"Call graph: {cg_str}",
|
| 287 |
+
Reward(value=-0.08, reason="get_call_graph cost", partial=True),
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# ---- submit ------------------------------------------------
|
| 291 |
+
if at == ActionType.SUBMIT:
|
| 292 |
+
fn_name = params.get("function_name", "")
|
| 293 |
+
vuln_type = params.get("vulnerability_type", "")
|
| 294 |
+
if not fn_name or not vuln_type:
|
| 295 |
+
return (
|
| 296 |
+
"Submit requires 'function_name' and 'vulnerability_type' in params.",
|
| 297 |
+
Reward(value=-0.5, reason="Malformed submission", partial=True),
|
| 298 |
+
)
|
| 299 |
+
score = self._grader.grade_submission(fn_name, vuln_type)
|
| 300 |
+
reward_val = self._grader.reward_for_score(score)
|
| 301 |
+
self._done = True
|
| 302 |
+
|
| 303 |
+
if score == 1.0:
|
| 304 |
+
msg = (
|
| 305 |
+
f"β
CORRECT! '{fn_name}' is the vulnerable function. "
|
| 306 |
+
f"Vulnerability type '{vuln_type}' matches. Score: 1.0"
|
| 307 |
+
)
|
| 308 |
+
elif score == 0.5:
|
| 309 |
+
msg = (
|
| 310 |
+
f"β οΈ PARTIAL. '{fn_name}' is the right function, but the vulnerability type "
|
| 311 |
+
f"'{vuln_type}' was not precise. Score: 0.5"
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
correct = self._grader.get_canonical_answer()
|
| 315 |
+
msg = (
|
| 316 |
+
f"β INCORRECT. '{fn_name}' is not the target vulnerable function. "
|
| 317 |
+
f"Correct answer: {correct['function']} ({correct['vulnerability']}). Score: 0.0"
|
| 318 |
+
)
|
| 319 |
+
return msg, Reward(
|
| 320 |
+
value=reward_val,
|
| 321 |
+
reason=f"Submission score={score:.1f}",
|
| 322 |
+
partial=False,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# ---- unknown action ----------------------------------------
|
| 326 |
+
return (
|
| 327 |
+
f"Unknown action type: {at}",
|
| 328 |
+
Reward(value=-0.10, reason="Unknown action", partial=True),
|
| 329 |
+
)
|
tasks/task1/grader.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
grader.py (Task 1 β Targeted Vulnerability Detection)
|
| 3 |
+
-------------------------------------------------------
|
| 4 |
+
Deterministic grader. Score range: 0.0 β 1.0
|
| 5 |
+
|
| 6 |
+
1.0 β correct function + correct vulnerability keyword
|
| 7 |
+
0.5 β correct function + wrong/unrecognised vulnerability keyword
|
| 8 |
+
0.0 β wrong function name
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
from typing import Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
VULN_KEYWORDS: Dict[str, List[str]] = {
|
| 14 |
+
"reentrancy": [
|
| 15 |
+
"reentrancy", "re-entrancy", "reentrant", "re entrant",
|
| 16 |
+
"recursive call", "reentr",
|
| 17 |
+
],
|
| 18 |
+
"missing access control": [
|
| 19 |
+
"access control", "missing access", "no access", "unauthorized",
|
| 20 |
+
"privilege", "permission", "onlyowner", "only owner",
|
| 21 |
+
"no modifier", "missing modifier", "no check", "anyone can call",
|
| 22 |
+
],
|
| 23 |
+
"integer overflow": [
|
| 24 |
+
"overflow", "integer overflow", "arithmetic overflow",
|
| 25 |
+
"safemath", "safe math", "uint overflow", "wraparound",
|
| 26 |
+
"integer underflow", "underflow",
|
| 27 |
+
],
|
| 28 |
+
"tx.origin authentication": [
|
| 29 |
+
"tx.origin", "txorigin", "tx origin", "phishing",
|
| 30 |
+
"origin authentication", "origin auth",
|
| 31 |
+
],
|
| 32 |
+
"front-running": [
|
| 33 |
+
"front-running", "frontrunning", "front running", "mev",
|
| 34 |
+
"sandwich", "mempool", "commit reveal", "commit-reveal",
|
| 35 |
+
"gas price manipulation",
|
| 36 |
+
],
|
| 37 |
+
"timestamp dependence": [
|
| 38 |
+
"timestamp", "block.timestamp", "time manipulation",
|
| 39 |
+
"miner timestamp", "time dependency", "timestamp dependence",
|
| 40 |
+
],
|
| 41 |
+
"denial of service": [
|
| 42 |
+
"denial of service", " dos", "gas limit", "unbounded loop",
|
| 43 |
+
"block gas", " oog", "out of gas", "infinite loop", "unbounded array",
|
| 44 |
+
"gas exhaustion",
|
| 45 |
+
],
|
| 46 |
+
"unchecked return value": [
|
| 47 |
+
"unchecked return", "return value", "unchecked transfer",
|
| 48 |
+
"silent failure", "safeerc20", "safe transfer", "ignored return",
|
| 49 |
+
"erc20 return",
|
| 50 |
+
],
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _norm(text: str) -> str:
|
| 55 |
+
return text.strip().lower()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _find_bucket(ground_truth_issue: str) -> Optional[str]:
|
| 59 |
+
"""
|
| 60 |
+
Longest-match keyword search to identify canonical vulnerability bucket.
|
| 61 |
+
Longest match avoids short-keyword collisions (e.g. 'auth' in 'tx.origin authentication').
|
| 62 |
+
"""
|
| 63 |
+
norm_gt = _norm(ground_truth_issue)
|
| 64 |
+
best: Optional[str] = None
|
| 65 |
+
best_len: int = 0
|
| 66 |
+
for canonical, keywords in VULN_KEYWORDS.items():
|
| 67 |
+
for kw in keywords:
|
| 68 |
+
if kw in norm_gt and len(kw) > best_len:
|
| 69 |
+
best_len = len(kw)
|
| 70 |
+
best = canonical
|
| 71 |
+
return best
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def match_vuln_keyword(submitted: str, ground_truth_issue: str) -> bool:
|
| 75 |
+
bucket = _find_bucket(ground_truth_issue)
|
| 76 |
+
if bucket is None:
|
| 77 |
+
return _norm(submitted) in _norm(ground_truth_issue)
|
| 78 |
+
norm_sub = _norm(submitted)
|
| 79 |
+
return any(kw in norm_sub for kw in VULN_KEYWORDS[bucket])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Task1Grader:
|
| 83 |
+
def __init__(self, target_function: str, vulnerability_issue: str) -> None:
|
| 84 |
+
self.target_function = target_function.lower()
|
| 85 |
+
self.vulnerability_issue = vulnerability_issue
|
| 86 |
+
|
| 87 |
+
def grade_submission(self, submitted_function: str, submitted_vuln_type: str) -> float:
|
| 88 |
+
if submitted_function.strip().lower() != self.target_function:
|
| 89 |
+
return 0.0
|
| 90 |
+
return 1.0 if match_vuln_keyword(submitted_vuln_type, self.vulnerability_issue) else 0.5
|
| 91 |
+
|
| 92 |
+
def reward_for_score(self, score: float) -> float:
|
| 93 |
+
if score == 1.0: return 5.0
|
| 94 |
+
if score == 0.5: return 1.0
|
| 95 |
+
return -1.5
|
| 96 |
+
|
| 97 |
+
def get_canonical_answer(self) -> Dict[str, str]:
|
| 98 |
+
return {"function": self.target_function, "vulnerability": self.vulnerability_issue}
|
tasks/task2/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tasks/task2/__init__.py
|
| 3 |
+
-----------------------
|
| 4 |
+
Task 2: Property Discovery (PLACEHOLDER)
|
| 5 |
+
|
| 6 |
+
TODO: Implement this task.
|
| 7 |
+
|
| 8 |
+
Episode setup:
|
| 9 |
+
- One function from a Solidity file with known properties
|
| 10 |
+
- Agent must discover the natural-language property of the function
|
| 11 |
+
|
| 12 |
+
Actions (to implement):
|
| 13 |
+
- get_similar_rule : -0.20
|
| 14 |
+
- get_file_natspec : -0.03
|
| 15 |
+
- get_function_natspec : -0.08
|
| 16 |
+
- get_function_code : -0.06
|
| 17 |
+
- get_related_functions : -0.06
|
| 18 |
+
- get_io : -0.04
|
| 19 |
+
- submit_property : scored 0.0β5.0 by semantic similarity grader
|
| 20 |
+
|
| 21 |
+
See README.md for full task specification.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
# TODO: Task 2 β Property Discovery
|
| 25 |
+
# from tasks.task2.environment import Task2Environment
|
| 26 |
+
|
| 27 |
+
__all__: list = []
|
tasks/task3/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tasks/task3/__init__.py
|
| 3 |
+
-----------------------
|
| 4 |
+
Task 3: Rule Checker (PLACEHOLDER)
|
| 5 |
+
|
| 6 |
+
TODO: Implement this task.
|
| 7 |
+
|
| 8 |
+
Episode setup:
|
| 9 |
+
- One Solidity file with at least one function breaking a given property
|
| 10 |
+
- Agent is shown the property in natural English
|
| 11 |
+
|
| 12 |
+
Actions (to implement):
|
| 13 |
+
- get_formalized_property : -0.03
|
| 14 |
+
- list_functions : -0.05
|
| 15 |
+
- get_function_metadata : -0.05
|
| 16 |
+
- get_function_code : -0.10
|
| 17 |
+
- get_state_variables : -0.05
|
| 18 |
+
- get_call_graph : -0.08
|
| 19 |
+
- submit_function :
|
| 20 |
+
- correct = +5.0
|
| 21 |
+
- subfunction of target = +1.5
|
| 22 |
+
- wrong = -1.5
|
| 23 |
+
(ONE submission per episode)
|
| 24 |
+
|
| 25 |
+
See README.md for full task specification.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
# TODO: Task 3 β Rule Checker
|
| 29 |
+
# from tasks.task3.environment import Task3Environment
|
| 30 |
+
|
| 31 |
+
__all__: list = []
|
utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# utils package
|
validate.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
validate.py
|
| 3 |
+
-----------
|
| 4 |
+
Pre-submission validation script.
|
| 5 |
+
Checks all OpenEnv spec requirements locally before submitting.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python validate.py
|
| 9 |
+
|
| 10 |
+
Exit code 0 = all checks pass.
|
| 11 |
+
Exit code 1 = one or more checks failed.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import sys
|
| 16 |
+
import traceback
|
| 17 |
+
from typing import Callable, List, Tuple
|
| 18 |
+
|
| 19 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
# Helpers
|
| 21 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
|
| 23 |
+
PASS = "β
"
|
| 24 |
+
FAIL = "β"
|
| 25 |
+
SKIP = "β "
|
| 26 |
+
results: List[Tuple[str, bool, str]] = []
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check(name: str, fn: Callable[[], None]) -> None:
|
| 30 |
+
try:
|
| 31 |
+
fn()
|
| 32 |
+
results.append((name, True, ""))
|
| 33 |
+
print(f" {PASS} {name}")
|
| 34 |
+
except Exception as e:
|
| 35 |
+
tb = traceback.format_exc(limit=3)
|
| 36 |
+
results.append((name, False, str(e)))
|
| 37 |
+
print(f" {FAIL} {name}")
|
| 38 |
+
print(f" {e}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
# Checks
|
| 43 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
|
| 45 |
+
def check_imports():
|
| 46 |
+
from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult
|
| 47 |
+
from tasks.task1.environment import Task1Environment
|
| 48 |
+
from tasks.task1.grader import Task1Grader
|
| 49 |
+
from data.data_loader import load_contracts
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def check_openenv_yaml():
|
| 53 |
+
import yaml
|
| 54 |
+
with open("openenv.yaml") as f:
|
| 55 |
+
spec = yaml.safe_load(f)
|
| 56 |
+
assert "name" in spec
|
| 57 |
+
assert "tasks" in spec
|
| 58 |
+
assert len(spec["tasks"]) >= 3, "Need at least 3 tasks defined"
|
| 59 |
+
assert "observation_space" in spec
|
| 60 |
+
assert "action_space" in spec
|
| 61 |
+
assert "reward" in spec
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def check_pydantic_models():
|
| 65 |
+
from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult, StateResult
|
| 66 |
+
# Instantiate each model
|
| 67 |
+
obs = Observation(
|
| 68 |
+
task_id="t1", contract_name="C", contract_description="D",
|
| 69 |
+
available_actions=["submit"]
|
| 70 |
+
)
|
| 71 |
+
assert obs.task_id == "t1"
|
| 72 |
+
|
| 73 |
+
action = Action(action_type=ActionType.LIST_FUNCTIONS)
|
| 74 |
+
assert action.action_type == ActionType.LIST_FUNCTIONS
|
| 75 |
+
|
| 76 |
+
reward = Reward(value=1.0, reason="test")
|
| 77 |
+
assert reward.value == 1.0
|
| 78 |
+
|
| 79 |
+
step = StepResult(observation=obs, reward=reward, done=False)
|
| 80 |
+
assert not step.done
|
| 81 |
+
|
| 82 |
+
reset = ResetResult(observation=obs)
|
| 83 |
+
assert reset.observation.task_id == "t1"
|
| 84 |
+
|
| 85 |
+
state = StateResult(task_id="t1", contract_name="C", step_count=0,
|
| 86 |
+
cumulative_reward=0.0, done=False)
|
| 87 |
+
assert state.step_count == 0
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def check_data_loading():
|
| 91 |
+
from data.data_loader import load_contracts, get_all_vulnerable_entries
|
| 92 |
+
contracts = load_contracts()
|
| 93 |
+
assert len(contracts) >= 1, "No contracts loaded"
|
| 94 |
+
entries = get_all_vulnerable_entries(contracts)
|
| 95 |
+
assert len(entries) >= 3, f"Need >= 3 vulnerable functions, got {len(entries)}"
|
| 96 |
+
for contract, fn in entries:
|
| 97 |
+
assert fn.get("vulnerable") is True
|
| 98 |
+
assert fn.get("vulnerability_details") is not None
|
| 99 |
+
assert "issue" in fn["vulnerability_details"]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def check_env_reset():
|
| 103 |
+
from tasks.task1.environment import Task1Environment
|
| 104 |
+
env = Task1Environment()
|
| 105 |
+
result = env.reset(seed=42)
|
| 106 |
+
assert result.observation is not None
|
| 107 |
+
assert result.observation.task_id == "task1_vuln_detection"
|
| 108 |
+
assert result.observation.contract_name != ""
|
| 109 |
+
assert not result.observation.done
|
| 110 |
+
assert result.observation.step_count == 0
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def check_env_step():
|
| 114 |
+
from tasks.task1.environment import Task1Environment
|
| 115 |
+
from env.schemas import Action, ActionType
|
| 116 |
+
env = Task1Environment()
|
| 117 |
+
env.reset(seed=42)
|
| 118 |
+
result = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 119 |
+
assert result.observation is not None
|
| 120 |
+
assert isinstance(result.reward.value, float)
|
| 121 |
+
assert isinstance(result.done, bool)
|
| 122 |
+
assert "info" in result.model_dump()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def check_env_state():
|
| 126 |
+
from tasks.task1.environment import Task1Environment
|
| 127 |
+
env = Task1Environment()
|
| 128 |
+
env.reset(seed=42)
|
| 129 |
+
state = env.state()
|
| 130 |
+
assert state.task_id == "task1_vuln_detection"
|
| 131 |
+
assert state.contract_name != ""
|
| 132 |
+
assert state.target_function is not None # exposed for debugging
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def check_grader_scores_in_range():
|
| 136 |
+
from tasks.task1.grader import Task1Grader
|
| 137 |
+
cases = [
|
| 138 |
+
("withdraw", "Reentrancy vulnerability", "withdraw", "reentrancy", 1.0),
|
| 139 |
+
("withdraw", "Reentrancy vulnerability", "withdraw", "something else", 0.5),
|
| 140 |
+
("withdraw", "Reentrancy vulnerability", "deposit", "reentrancy", 0.0),
|
| 141 |
+
]
|
| 142 |
+
for tf, issue, sf, sv, expected in cases:
|
| 143 |
+
g = Task1Grader(tf, issue)
|
| 144 |
+
score = g.grade_submission(sf, sv)
|
| 145 |
+
assert 0.0 <= score <= 1.0, f"Score {score} out of range"
|
| 146 |
+
assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def check_grader_deterministic():
|
| 150 |
+
from tasks.task1.grader import Task1Grader
|
| 151 |
+
g = Task1Grader("withdraw", "Reentrancy vulnerability")
|
| 152 |
+
s1 = g.grade_submission("withdraw", "reentrancy")
|
| 153 |
+
s2 = g.grade_submission("withdraw", "reentrancy")
|
| 154 |
+
assert s1 == s2 == 1.0, "Grader must be deterministic"
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def check_reward_shaping():
|
| 158 |
+
"""Verify reward is non-binary (multiple distinct values across steps)."""
|
| 159 |
+
from tasks.task1.environment import Task1Environment
|
| 160 |
+
from env.schemas import Action, ActionType
|
| 161 |
+
env = Task1Environment()
|
| 162 |
+
env.reset(seed=1)
|
| 163 |
+
rewards = set()
|
| 164 |
+
for at in [ActionType.LIST_FUNCTIONS, ActionType.GET_FILE_METADATA, ActionType.GET_CALL_GRAPH]:
|
| 165 |
+
r = env.step(Action(action_type=at))
|
| 166 |
+
rewards.add(round(r.reward.value, 4))
|
| 167 |
+
# Should have at least 2 distinct shaping reward values
|
| 168 |
+
assert len(rewards) >= 2, f"Expected multiple reward values, got {rewards}"
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def check_episode_boundary():
|
| 172 |
+
"""Episode must end after submit and raise on subsequent step."""
|
| 173 |
+
from tasks.task1.environment import Task1Environment
|
| 174 |
+
from env.schemas import Action, ActionType
|
| 175 |
+
env = Task1Environment()
|
| 176 |
+
env.reset(seed=2)
|
| 177 |
+
env.step(Action(action_type=ActionType.SUBMIT, params={
|
| 178 |
+
"function_name": "withdraw", "vulnerability_type": "test"
|
| 179 |
+
}))
|
| 180 |
+
try:
|
| 181 |
+
env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 182 |
+
raise AssertionError("Should have raised RuntimeError after episode end")
|
| 183 |
+
except RuntimeError:
|
| 184 |
+
pass # Expected
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def check_repeated_query_penalty():
|
| 188 |
+
from tasks.task1.environment import Task1Environment
|
| 189 |
+
from env.schemas import Action, ActionType
|
| 190 |
+
env = Task1Environment()
|
| 191 |
+
env.reset(seed=3)
|
| 192 |
+
env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 193 |
+
r = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
|
| 194 |
+
assert r.reward.value == -0.40, f"Expected -0.40 for repeated query, got {r.reward.value}"
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def check_tasks_list():
|
| 198 |
+
"""All three tasks must be listed (even if placeholders)."""
|
| 199 |
+
from tasks.task2 import __all__ as t2 # noqa
|
| 200 |
+
from tasks.task3 import __all__ as t3 # noqa
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def check_dockerfile_exists():
|
| 204 |
+
import os
|
| 205 |
+
assert os.path.exists("Dockerfile"), "Dockerfile is missing"
|
| 206 |
+
with open("Dockerfile") as f:
|
| 207 |
+
content = f.read()
|
| 208 |
+
assert "7860" in content, "Dockerfile must EXPOSE 7860 (HF Spaces)"
|
| 209 |
+
assert "uvicorn" in content or "CMD" in content
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def check_inference_script():
|
| 213 |
+
import os
|
| 214 |
+
assert os.path.exists("inference.py"), "inference.py is missing"
|
| 215 |
+
with open("inference.py") as f:
|
| 216 |
+
content = f.read()
|
| 217 |
+
assert "OPENAI_API_KEY" in content or "HF_TOKEN" in content, \
|
| 218 |
+
"inference.py must read API credentials from env vars"
|
| 219 |
+
assert "API_BASE_URL" in content
|
| 220 |
+
assert "MODEL_NAME" in content
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def check_baseline_json_schema():
|
| 224 |
+
"""baseline_scores.json must have valid schema if it exists."""
|
| 225 |
+
import os
|
| 226 |
+
if not os.path.exists("baseline_scores.json"):
|
| 227 |
+
return # OK β file is generated at runtime
|
| 228 |
+
with open("baseline_scores.json") as f:
|
| 229 |
+
data = json.load(f)
|
| 230 |
+
assert "tasks" in data
|
| 231 |
+
for task in data["tasks"]:
|
| 232 |
+
score = task["avg_grader_score"]
|
| 233 |
+
assert 0.0 <= score <= 1.0, f"Score {score} out of range"
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 237 |
+
# Runner
|
| 238 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 239 |
+
|
| 240 |
+
def main():
|
| 241 |
+
print("=" * 60)
|
| 242 |
+
print("OpenEnv Pre-Submission Validation")
|
| 243 |
+
print("=" * 60)
|
| 244 |
+
|
| 245 |
+
all_checks = [
|
| 246 |
+
("Python imports", check_imports),
|
| 247 |
+
("openenv.yaml format", check_openenv_yaml),
|
| 248 |
+
("Pydantic model types", check_pydantic_models),
|
| 249 |
+
("Dataset loading (3+ vulns)", check_data_loading),
|
| 250 |
+
("env.reset() β ResetResult", check_env_reset),
|
| 251 |
+
("env.step() β StepResult", check_env_step),
|
| 252 |
+
("env.state() β StateResult", check_env_state),
|
| 253 |
+
("Grader scores in [0.0, 1.0]", check_grader_scores_in_range),
|
| 254 |
+
("Grader is deterministic", check_grader_deterministic),
|
| 255 |
+
("Reward shaping (non-binary)", check_reward_shaping),
|
| 256 |
+
("Episode boundary (done=True)",check_episode_boundary),
|
| 257 |
+
("Repeated query penalty", check_repeated_query_penalty),
|
| 258 |
+
("Task 2 & 3 placeholders", check_tasks_list),
|
| 259 |
+
("Dockerfile exists + port", check_dockerfile_exists),
|
| 260 |
+
("inference.py exists + vars", check_inference_script),
|
| 261 |
+
("baseline_scores.json schema", check_baseline_json_schema),
|
| 262 |
+
]
|
| 263 |
+
|
| 264 |
+
print()
|
| 265 |
+
for name, fn in all_checks:
|
| 266 |
+
check(name, fn)
|
| 267 |
+
|
| 268 |
+
print()
|
| 269 |
+
passed = sum(1 for _, ok, _ in results if ok)
|
| 270 |
+
total = len(results)
|
| 271 |
+
failed = [(n, msg) for n, ok, msg in results if not ok]
|
| 272 |
+
|
| 273 |
+
print("=" * 60)
|
| 274 |
+
print(f"Results: {passed}/{total} checks passed")
|
| 275 |
+
|
| 276 |
+
if failed:
|
| 277 |
+
print("\nFailed checks:")
|
| 278 |
+
for name, msg in failed:
|
| 279 |
+
print(f" {FAIL} {name}: {msg}")
|
| 280 |
+
print()
|
| 281 |
+
print("β VALIDATION FAILED β fix the issues above before submitting.")
|
| 282 |
+
sys.exit(1)
|
| 283 |
+
else:
|
| 284 |
+
print()
|
| 285 |
+
print("β
ALL CHECKS PASSED β ready to submit!")
|
| 286 |
+
sys.exit(0)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
|
| 290 |
+
main()
|