Upload folder using huggingface_hub
- README.md +70 -12
- TECHNICAL_DEEP_DIVE.md +642 -0
- client.py +8 -0
- inference.py +61 -4
- models.py +21 -1
- server/api_specs.py +316 -1
- server/app.py +34 -1
- server/environment.py +263 -6
- server/error_injectors.py +261 -0
- server/response_specs.py +268 -0
- tests/test_environment.py +240 -0
- training/README.md +71 -0
- training/__init__.py +0 -0
- training/requirements.txt +6 -0
- training/train.py +309 -0
README.md
CHANGED
Built for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology 2026.
## Infinite Unique Scenarios

Unlike fixed-fixture environments where every episode presents the same scenario, this environment generates a unique broken request each episode. With 45 API spec templates across 9 domains and 15 error injection functions (including chained multi-step errors), the environment produces tens of thousands of distinct training scenarios. An agent cannot memorize answers after one run; it must learn a generalizable debugging strategy. This is critical for real RL training value: agents learn transferable skills rather than dataset-specific shortcuts.

## Why This Domain

## How It Works

1. On `reset()`, the environment picks a random API spec from 45 templates and injects 1-3 errors
2. The agent receives the broken request, headers, and the API specification
3. The agent submits a fix attempt via `step()`
4. The environment grades the attempt and returns structured feedback
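The injection in step 1 can be sketched in a few lines. This is an illustrative toy, not the environment's actual code; the spec shape and names here are assumptions:

```python
import copy
import random

# Hypothetical spec template: a valid example request plus its contract.
SPEC = {
    "name": "Create Customer",
    "required": ["email", "name"],
    "valid_example": {"email": "a@b.com", "name": "Ada", "plan": "pro"},
}

def make_broken_request(spec, rng):
    request = copy.deepcopy(spec["valid_example"])  # never mutate the template
    field = rng.choice(spec["required"])            # pick one required field
    del request[field]                              # inject a missing_required_field error
    ground_truth = {"error_type": "missing_required_field",
                    "affected_fields": [field]}     # kept server-side, never shown
    return request, ground_truth

rng = random.Random(7)                              # seeded -> reproducible episode
broken, ground_truth = make_broken_request(SPEC, rng)
assert ground_truth["affected_fields"][0] not in broken
assert SPEC["valid_example"]["plan"] == "pro"       # template itself untouched
```

Deep-copying the template before breaking it mirrors the rule that templates are never mutated across episodes.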
Each episode allows multiple attempts. Perfect answers on early steps earn full reward; later fixes are discounted.

| Task | Objective | Max Steps | Errors | Grading |
|------|-----------|-----------|--------|---------|
| easy | Identify error type and affected fields | 3 | 1 | Deterministic: 0.6 x type_match + 0.4 x fields_match |
| classify | Identify ALL error types across multiple errors | 4 | 2-3 | Deterministic: 0.6 x Jaccard(types) + 0.4 x Jaccard(fields) |
| medium | Fix the broken request | 5 | 1 | Deterministic: per-field validation against spec |
| headers | Fix header-level errors (auth, content-type, tokens) | 4 | 1 | Deterministic: 0.7 x header_fix + 0.3 x type_match |
| response | Validate API response for issues | 4 | 1-2 | Deterministic: 0.5 x Jaccard(issues) + 0.3 x Jaccard(fields) + 0.2 x status_code |
| hard | Fix request + explain for developers | 7 | 2-3 | 70% deterministic fix + 30% LLM-as-judge explanation (gpt-4o-mini) |
## Error Types

| malformed_json_value | Corrupted field value | `{broken` as a value |
| invalid_enum_value | Value not in allowed list | `currency: "xyz"` |
| datetime_format_error | Wrong date format | `04/01/2026` instead of ISO 8601 |
| wrong_content_type | Wrong Content-Type header | `text/plain` instead of `application/json` |
| expired_auth_token | Expired or invalid auth token | `Bearer expired_token_2024` |
| wrong_status_code | Wrong HTTP status code in response | `200` instead of `201` for resource creation |
| redirect_loop | Redirect configuration error | Version upgrade redirect loop |
| rate_limit_headers | Rate limit exceeded headers | `X-RateLimit-Remaining: 0` |
## API Spec Domains (45 templates)

| Domain | Count | Examples |
|--------|-------|---------|
| Messaging (Twilio-like) | 5 | Send SMS, Send Email, Create Webhook |
| E-Commerce | 5 | Create Order, Process Payment, Create Shipping Label |
| Calendar and Auth | 5 | Create Event, OAuth Token, Create API Key |
| Analytics/Monitoring | 5 | Create Dashboard, Add Metric, Create Alert |
| DevOps/Infrastructure | 5 | Create Deployment, Scale Service, Create DNS Record |
| AI/ML APIs | 5 | Submit Inference, Create Fine-tune Job, Upload Dataset |
## Action Space

The agent sends an `APIDebugAction` with these fields (all optional, submit what's needed for the current task):

| Field | Type | Used In | Description |
|-------|------|---------|-------------|
| error_type | string | easy, headers, hard | Diagnosed error type |
| error_types | list[string] | classify | All diagnosed error types (multi-error) |
| affected_fields | list[string] | easy, classify, response, hard | Fields affected by the error |
| fixed_request | string (JSON) | medium, hard | Corrected request body |
| fixed_headers | dict | medium, headers, hard | Corrected HTTP headers |
| explanation | string | hard | Developer-facing explanation |
| response_issues | list[string] | response | Issue types found in API response |
| expected_status_code | int | response | Correct HTTP status code |
## Observation Space

The environment returns an `APIDebugObservation` with:

| Field | Type | Description |
|-------|------|-------------|
| task | string | Current task: easy, classify, medium, headers, response, hard |
| api_name | string | Name of the API (e.g. "Create Customer") |
| http_method | string | HTTP method of the request |
| endpoint | string | API endpoint path |
| broken_request | string (JSON) | The malformed request body |
| broken_headers | dict | HTTP headers sent with the request |
| api_spec | string (JSON) | API specification with required fields and types |
| response_body | string (JSON) | Server response body (response task only) |
| response_status_code | int | HTTP status code of response (response task only) |
| error_count | int | Number of errors injected |
| step_number | int | Current step in the episode |
| max_steps | int | Maximum steps allowed |
Scores from running `inference.py` against the live HF Space (3 episodes per task):

| Task | Episodes | Qwen2.5-72B-Instruct | gpt-4o-mini |
|------|----------|----------------------|-------------|
| easy | 3 | 0.999 | 0.667 |
| classify | 3 | -- | -- |
| medium | 3 | 0.999 | 0.999 |
| headers | 3 | -- | -- |
| response | 3 | -- | -- |
| hard | 3 | 0.780 | 0.760 |

Hard task uses LLM-as-judge (gpt-4o-mini) for explanation quality scoring, which is stricter than a heuristic baseline. The agent must fix 2-3 simultaneous errors and provide a developer-facing explanation to score high. Larger models perform better on the hard task, showing meaningful difficulty progression.
## Chained Multi-Step Errors

The hard task supports chained error scenarios where errors depend on each other. Fixing one error reveals the next, simulating real-world API debugging:

| Chain Pattern | Gate Error | Body Errors |
|---------------|-----------|-------------|
| auth_gate | missing_auth_header, expired_auth_token | (any body error) |
| content_type_gate | wrong_content_type | wrong_field_type, malformed_json_value, invalid_enum_value |
| method_chain | wrong_http_method | missing_required_field, extra_unknown_field, null_value_in_required |
| rate_limit_chain | rate_limit_headers | expired_auth_token, missing_required_field |
| redirect_chain | redirect_loop | wrong_field_type, datetime_format_error, invalid_email_format |
## GRPO Training with Curriculum Learning

The `training/` directory contains a GRPO training script that trains a small LLM (Qwen 0.5B) using reward signals from the live environment:

```bash
pip install -r training/requirements.txt
python training/train.py
```

The training script auto-promotes through 6 difficulty levels based on rolling average reward:

| Level | Task | Threshold | Max Turns |
|-------|------|-----------|-----------|
| 1 | easy | 0.7 | 3 |
| 2 | classify | 0.6 | 4 |
| 3 | medium | 0.6 | 5 |
| 4 | headers | 0.5 | 4 |
| 5 | response | 0.5 | 4 |
| 6 | hard | -- | 7 |

The environment also supports `task="auto"`, which lets the environment itself manage curriculum progression based on session history.
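The promotion rule can be sketched as follows; the window size, class names, and reset-on-promotion behavior are illustrative assumptions, not the actual `train.py` logic:

```python
from collections import deque

# Levels and thresholds from the table above; hard (None) is terminal.
LEVELS = [("easy", 0.7), ("classify", 0.6), ("medium", 0.6),
          ("headers", 0.5), ("response", 0.5), ("hard", None)]

class Curriculum:
    def __init__(self, window=20):
        self.level = 0
        self.rewards = deque(maxlen=window)   # rolling reward window

    def record(self, reward):
        self.rewards.append(reward)
        threshold = LEVELS[self.level][1]
        avg = sum(self.rewards) / len(self.rewards)
        full_window = len(self.rewards) == self.rewards.maxlen
        if threshold is not None and full_window and avg >= threshold:
            self.level += 1          # promote to the next task
            self.rewards.clear()     # start a fresh window on the new task

    @property
    def task(self):
        return LEVELS[self.level][0]

cur = Curriculum(window=5)
for _ in range(5):
    cur.record(0.9)                  # consistently high reward on easy
assert cur.task == "classify"        # promoted past level 1
```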
## Setup

### Prerequisites

```
├── __init__.py
├── tests/
│   ├── __init__.py
│   └── test_environment.py      # 109 unit tests
├── training/
│   ├── __init__.py
│   ├── train.py                 # GRPO training with 6-level curriculum
│   ├── requirements.txt
│   └── README.md
└── server/
    ├── __init__.py
    ├── app.py                   # FastAPI app via create_app()
    ├── environment.py           # Core logic: 6 tasks, graders, LLM judge, auto-curriculum
    ├── api_specs.py             # 45 API spec templates across 9 domains
    ├── error_injectors.py       # 15 error types + 5 chain patterns
    ├── response_specs.py        # 8 response templates + 5 issue injection types
    └── validators.py            # Field type validation helpers
```
TECHNICAL_DEEP_DIVE.md
ADDED
# API Debug Environment - Technical Deep Dive

A complete behind-the-scenes breakdown of the environment: what it does, how every piece works at the code level, what makes it unique, the bugs that almost killed the submission, and where it can go next.

---

## Table of Contents

1. [What We Built](#what-we-built)
2. [Terminology Reference](#terminology-reference)
3. [Architecture Overview](#architecture-overview)
4. [Code-Level Walkthrough](#code-level-walkthrough)
5. [The Grading System](#the-grading-system)
6. [Reward Shaping and Step Decay](#reward-shaping-and-step-decay)
7. [LLM-as-Judge: How It Actually Works](#llm-as-judge-how-it-actually-works)
8. [Error Injection System](#error-injection-system)
9. [Infinite Unique Scenarios](#infinite-unique-scenarios)
10. [The Bugs That Almost Killed It](#the-bugs-that-almost-killed-it)
11. [Unique Features and Benefits](#unique-features-and-benefits)
12. [Baseline Results](#baseline-results)
13. [Scope for Advancement (Round 2)](#scope-for-advancement-round-2)
14. [Practical Applications: Where the Trained LLM Sits](#practical-applications-where-the-trained-llm-sits)

---
## What We Built

An RL (Reinforcement Learning) environment where an LLM agent learns to debug malformed API requests and validate API responses. The agent receives a broken HTTP request along with its API specification, and must handle 6 progressively harder tasks:

- **Easy**: Identify the error type and affected fields (single error)
- **Classify**: Identify ALL error types and fields across multiple simultaneous errors
- **Medium**: Fix the broken request body to match the spec
- **Headers**: Fix header-level errors (auth, content-type, expired tokens)
- **Response**: Validate API responses for wrong status codes, missing fields, data leaks
- **Hard**: Fix multi-error requests (including chained dependencies) and explain the fix

Built on the **OpenEnv framework** (by Meta/PyTorch), with 45 API specs across 9 domains, 15 error types, 5 chained error patterns, GRPO training pipeline, and 6-level curriculum learning.
---
## Terminology Reference

| Term | Meaning in this project |
|------|------------------------|
| **OpenEnv** | Meta/PyTorch framework for building RL environments. Provides the `Environment` base class, `create_app()` for FastAPI, `EnvClient` for agents. |
| **Episode** | One complete debugging session. Starts with `reset()`, ends when `done=True`. |
| **Step** | One agent attempt within an episode. The agent submits an action, gets observation + reward + feedback. |
| **Observation** | What the environment returns to the agent: broken request, spec, feedback, reward, done flag. |
| **Action** | What the agent sends: error type, affected fields, fixed request, headers, explanation. |
| **Ground truth** | The correct answer stored internally. Includes the original valid request, error type, and affected fields. Never exposed to the agent. |
| **Reward** | Float in (0, 1) representing how good the agent's attempt was. Higher = better. |
| **Step decay** | Multiplier that reduces reward for later steps: 1.0x at step 1, down to a 0.3x floor at step 8+. |
| **best_reward** | Highest reward achieved across all steps in an episode. Returned as the final score. |
| **Jaccard similarity** | Set similarity metric: size of intersection divided by size of union. Used for partial credit on field identification. |
| **LLM-as-judge** | Using an LLM (gpt-4o-mini) to score the quality of the agent's explanation. Only for the hard task. |
| **Heuristic fallback** | Keyword + length scoring when the LLM judge is unavailable. Ensures the hard task never blocks. |
| **Reward clamping** | `max(0.001, min(0.999, reward))` - keeps rewards in the open interval (0, 1) because the evaluator rejects boundary values. |
| **create_app()** | OpenEnv function that takes your Environment class and generates a FastAPI app with all required endpoints. |
| **EnvClient** | OpenEnv SDK class that agents use to connect. Handles the WebSocket, with auto-connect via `_ensure_connected()`. |
| **GRPO** | Group Relative Policy Optimization. An RL algorithm suitable for training LLMs with environment reward signals. |
| **Deterministic grading** | Scoring that produces the same result every time for the same input. No randomness, no LLM calls. |
| **Procedural generation** | Creating scenarios algorithmically (random spec + random error + random field) rather than from a fixed dataset. |
| **HF Space** | HuggingFace Spaces - cloud deployment for the environment. Runs a Docker container on port 7860. |
| **Phase 1** | Automated check: Docker builds, `openenv validate` passes, health endpoint responds. |
| **Phase 2** | Automated check: the evaluator's agent runs against the environment; scores must be in (0, 1), no crashes. |
| **Phase 3** | Human judges review environment quality, RL training value, code quality, documentation. |

---
---
## Architecture Overview

```
+--------------------+
|     LLM Agent      |
|  (Qwen2.5-72B /    |
|    any model)      |
+--------+-----------+
         |
  WebSocket / HTTP
         |
+--------v-----------+
|   OpenEnv SDK      |
|   create_app()     |
|   POST /reset      |
|   POST /step       |
|   WS /ws           |
|   GET /health      |
+--------+-----------+
         |
          +-------------+--------------+
          |             |              |
+---------v--+ +--------v---+ +--------v--------+
| api_specs  | | error_     | | environment.py  |
| .py        | | injectors  | | (core logic)    |
| 45 specs   | | .py        | | reset(), step() |
| 9 domains  | | 15 types   | | 6 graders       |
+------------+ +------------+ +---------+-------+
                                        |
                              +---------v-------+
                              | validators.py   |
                              | Field type      |
                              | checking,       |
                              | spec validation |
                              +-----------------+
```
**Key files:**

| File | Lines | Purpose |
|------|-------|---------|
| `server/environment.py` | 472 | Core environment: reset(), step(), graders, LLM judge, reward clamping |
| `server/api_specs.py` | 640 | 45 API spec templates across 9 domains |
| `server/error_injectors.py` | 389 | 15 error injection functions + multi-error injector |
| `server/validators.py` | 151 | 12 field type validators + request/header validation |
| `server/app.py` | 84 | FastAPI app via OpenEnv's `create_app()` + `/tasks` endpoint |
| `models.py` | 97 | Pydantic models: APIDebugAction (all fields optional), APIDebugObservation |
| `client.py` | 80 | EnvClient SDK: `_step_payload`, `_parse_result`, `_parse_state` |
| `inference.py` | 331 | Baseline agent: LLM prompting, JSON parsing, episode runner, structured logging |
| `tests/test_environment.py` | 579 | Unit tests covering all graders, edge cases, reward bounds |
---
## Code-Level Walkthrough

### 1. Server Startup (`server/app.py`)

```python
app = create_app(
    APIDebugEnvironment,      # Our environment class
    APIDebugAction,           # What the agent sends
    APIDebugObservation,      # What the environment returns
    env_name="api_debug",
    max_concurrent_envs=10,   # Supports 10 parallel sessions
)
```

`create_app()` is from the OpenEnv SDK. It auto-generates all endpoints: `/reset`, `/step`, `/state`, `/schema`, `/ws`, `/health`. We just hand it our classes and it wires everything up. The WebSocket endpoint at `/ws` is what the evaluator's agent connects to.

We also added a custom `/tasks` endpoint that returns all task configs, error types, and the spec count. This helps external agents understand what the environment offers without reading the docs.
### 2. Episode Start: `reset()` (`server/environment.py:77-158`)

When an agent calls `reset(task="medium")`:

1. **Initialize RNG**: If a seed is provided, the episode is reproducible. Otherwise random.
2. **Load task config**: `TASK_CONFIG["medium"]` gives `{"max_steps": 5, "error_count": 1}`
3. **Pick random spec**: `get_random_spec(self.rng)` picks one of the 45 API templates
4. **Deep copy the valid example**: We never mutate the template itself
5. **Inject errors**: For easy/medium, picks 1 random error type. For hard, picks 2-3 distinct error types using `rng.sample()` (no duplicates)
6. **Build observation**: Returns the broken request, headers, API spec (field types + required fields), error count, and a message like "Debug this POST /v1/customers request. It contains 1 error(s). You have 5 steps."

The agent never sees the ground truth. It only sees the broken request and the spec.
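Steps 1 and 5 can be illustrated with a self-contained snippet (names are ours, not the environment's):

```python
import random

# Illustrative pool of error types; the real list lives in error_injectors.py.
ERROR_TYPES = ["missing_required_field", "wrong_field_type", "invalid_enum_value",
               "datetime_format_error", "missing_auth_header"]

def pick_errors(seed, count):
    rng = random.Random(seed)              # step 1: seeded RNG -> reproducible episode
    return rng.sample(ERROR_TYPES, count)  # step 5: distinct types, no duplicates

a = pick_errors(seed=42, count=3)
b = pick_errors(seed=42, count=3)
assert a == b                  # same seed -> same scenario
assert len(set(a)) == 3        # rng.sample guarantees distinct error types
```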
### 3. Agent Action: `step()` (`server/environment.py:160-213`)

When the agent submits a fix attempt:

1. **Increment step counter**: Tracks which step we're on
2. **Guard against stepping after done**: Returns 0 reward if the episode already ended
3. **Route to the correct grader**: `_grade_easy()`, `_grade_medium()`, or `_grade_hard()` based on the task
4. **Apply step decay**: `multiplier = max(1.0 - 0.1 * (step - 1), 0.3)` - step 1 gets full reward, step 8+ gets the 30% floor
5. **Clamp reward**: `max(0.001, min(0.999, reward))` - open interval (0, 1) because the evaluator rejects exactly 0.0 or 1.0
6. **Track best reward**: `self.best_reward = max(self.best_reward, reward)` - at episode end, returns the best reward across all steps
7. **Check termination**: Episode ends if `raw_score >= 0.95` (near-perfect) or all steps are exhausted
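Steps 4-6 can be reproduced in a few lines, using the formulas quoted above; variable names are ours:

```python
def shape_reward(raw_score, step):
    multiplier = max(1.0 - 0.1 * (step - 1), 0.3)    # step 4: decay, 0.3 floor
    reward = raw_score * multiplier
    return max(0.001, min(0.999, reward))            # step 5: clamp to (0, 1)

best_reward = 0.0
for step, raw in enumerate([0.5, 0.8, 1.0], start=1):
    best_reward = max(best_reward, shape_reward(raw, step))  # step 6: track best

assert shape_reward(1.0, 1) == 0.999    # perfect first try still clamps below 1.0
assert shape_reward(0.0, 3) == 0.001    # total miss still clamps above 0.0
assert abs(best_reward - 0.8) < 1e-9    # 1.0 * 0.8 multiplier at step 3
```

Note that by this formula the multiplier at step 7 is still 0.4; the 0.3 floor first binds at step 8.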
### 4. Pydantic Models (`models.py`)

```python
class APIDebugAction(Action):
    error_type: Optional[str]                # "missing_required_field"
    affected_fields: Optional[List[str]]     # ["email"]
    fixed_request: Optional[str]             # JSON string of corrected body
    fixed_headers: Optional[Dict[str, str]]  # {"Authorization": "Bearer ..."}
    explanation: Optional[str]               # Developer-facing text
```

All fields are Optional. The agent submits only what's needed for the current task:

- Easy: `error_type` + `affected_fields`
- Medium: `fixed_request` + `fixed_headers`
- Hard: All five fields

```python
class APIDebugObservation(Observation):
    task, api_name, http_method, endpoint,     # Context
    broken_request, broken_headers, api_spec,  # The problem
    error_count, step_number, max_steps,       # Episode state
    feedback, message,                         # Grader output
    done, reward                               # From the Observation base
```
### 5. Client SDK (`client.py`)

The client implements three abstract methods from `EnvClient`:

- **`_step_payload(action)`**: Converts `APIDebugAction` to a JSON dict, only including non-None fields
- **`_parse_result(payload)`**: Converts the server's JSON response into a `StepResult[APIDebugObservation]`
- **`_parse_state(payload)`**: Converts server state to an `episode_id` + `step_count`

The `EnvClient` base class handles WebSocket connection management. It has `_ensure_connected()`, which auto-calls `connect()` if the WebSocket is not yet open. This means you can call `env.reset()` directly without explicitly opening a connection first.
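What `_step_payload` does can be sketched with a stand-in dataclass (the real model is Pydantic; this is illustrative):

```python
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional

# Stand-in for the real Pydantic APIDebugAction model.
@dataclass
class APIDebugAction:
    error_type: Optional[str] = None
    affected_fields: Optional[List[str]] = None
    fixed_request: Optional[str] = None
    fixed_headers: Optional[Dict[str, str]] = None
    explanation: Optional[str] = None

def step_payload(action):
    # Serialize the action, dropping None fields so the server only
    # sees what the agent actually submitted.
    return {k: v for k, v in asdict(action).items() if v is not None}

payload = step_payload(APIDebugAction(error_type="missing_required_field",
                                      affected_fields=["email"]))
assert payload == {"error_type": "missing_required_field",
                   "affected_fields": ["email"]}
```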
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## The Grading System
|
| 204 |
+
|
| 205 |
+
### Easy Grader (`_grade_easy`, line 223)
|
| 206 |
+
|
| 207 |
+
Fully deterministic. Two components:
|
| 208 |
+
|
| 209 |
+
| Component | Weight | How it works |
|
| 210 |
+
|-----------|--------|-------------|
|
| 211 |
+
| Error type match | 0.6 | Exact string match against ground truth error type(s) |
|
| 212 |
+
| Affected fields | 0.4 | Jaccard similarity: `|intersection| / |union|` of predicted vs actual fields |
|
| 213 |
+
|
| 214 |
+
Jaccard similarity gives partial credit. If the ground truth is `["email", "name"]` and the agent says `["email", "phone"]`, the intersection is `{"email"}`, union is `{"email", "name", "phone"}`, Jaccard = 1/3, so fields score = 0.4 * 0.33 = 0.13.
|
| 215 |
+
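
A minimal sketch of that scoring rule (function names are illustrative):

```python
def jaccard(predicted, actual):
    p, a = set(predicted), set(actual)
    return len(p & a) / len(p | a) if (p | a) else 1.0

def grade_easy(pred_type, pred_fields, true_type, true_fields):
    type_score = 1.0 if pred_type == true_type else 0.0
    return 0.6 * type_score + 0.4 * jaccard(pred_fields, true_fields)

grade_easy("missing_required_field", ["email", "phone"],
           "missing_required_field", ["email", "name"])   # 0.6 + 0.4/3
```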

### Medium Grader (`_grade_medium`, line 264)

Fully deterministic per-field validation:

1. Parse the agent's `fixed_request` as JSON (fail = 0.0)
2. Check every required field is present and non-null
3. Check every present field has the correct type (using `validators.py`)
4. Check no unknown fields are present
5. Score = `passed_checks / total_checks`

If the original error was `missing_auth_header`, headers are also validated (80% body + 20% headers blend).
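
The check-counting logic above can be sketched like this; the `spec` layout (a `required` list plus a `types` map) is an assumption for illustration, not the real spec format:

```python
import json

def grade_medium(fixed_request, spec):
    # spec layout is illustrative: a "required" list plus a "types" map.
    try:
        body = json.loads(fixed_request)
    except (json.JSONDecodeError, TypeError):
        return 0.0
    if not isinstance(body, dict):
        return 0.0
    checks = [body.get(f) is not None for f in spec["required"]]   # present and non-null
    for field, value in body.items():
        ok = field in spec["types"] and isinstance(value, spec["types"][field])
        checks.append(ok)                                          # known field with correct type
    return sum(checks) / len(checks) if checks else 0.0
```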

### Hard Grader (`_grade_hard`, line 307)

70% deterministic + 30% LLM-judged:

```python
total = 0.7 * fix_score + 0.3 * explain_score
```

The fix portion reuses the medium grader exactly. The explanation goes through `_score_explanation()`, which tries the LLM judge first, then falls back to a heuristic.

---

## Reward Shaping and Step Decay

The reward formula:

```
reward = raw_score * max(1.0 - 0.1 * (step - 1), 0.3)
```

| Step | Multiplier | Effect |
|------|-----------|--------|
| 1 | 1.0x | Full reward for first-try solutions |
| 2 | 0.9x | Slight penalty |
| 3 | 0.8x | |
| 4 | 0.7x | |
| 5 | 0.6x | |
| 6 | 0.5x | |
| 7 | 0.4x | |
| 8+ | 0.3x | Floor - agent still gets credit for late fixes |

**Why this matters for RL training:**
- The agent is incentivized to solve problems quickly (higher reward on step 1)
- But it's not punished to zero for needing multiple attempts (0.3x floor)
- The multi-turn feedback loop means the agent can learn from structured feedback between steps
- At episode end, `best_reward` is returned - the best score across all attempts
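
The decay schedule follows directly from the formula:

```python
def step_multiplier(step):
    # Linear decay of 0.1 per extra step, floored at 0.3.
    return max(1.0 - 0.1 * (step - 1), 0.3)

[round(step_multiplier(s), 2) for s in (1, 2, 7, 8, 20)]   # [1.0, 0.9, 0.4, 0.3, 0.3]
```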

**Reward clamping** (`environment.py:194`):
```python
reward = max(0.001, min(0.999, reward))
```

This was added after Submission #3 failed. The evaluator's score range check requires scores to lie in the strictly open interval (0, 1). Without this, a completely wrong answer scored 0.0 and a perfect first-step answer scored 1.0, both of which the evaluator rejected.

---

## LLM-as-Judge: How It Actually Works

### The Judge Call (`_llm_judge_explanation`, line 349)

The judge receives:
- The API name, method, and endpoint
- The **actual ground truth** errors (type + affected fields) - the judge knows the right answer
- The agent's explanation text

The judge scores on three weighted criteria:

| Criterion | Max Score | What it evaluates |
|-----------|----------|-------------------|
| Root cause identification | 0.4 | Did the agent correctly name the error types and affected fields? |
| Fix guidance | 0.3 | Does the explanation describe the correct remediation? |
| Developer clarity | 0.3 | Is the explanation actionable and clear for a real developer? |

The judge returns a single JSON object: `{"score": 0.85}`.

**Key design decisions:**
- **10-second timeout** (`timeout=10`): Prevents blocking `step()` if the LLM is slow. The grader must respond quickly.
- **temperature=0.0**: Deterministic judge output for consistency
- **max_tokens=50**: The judge only needs to return a score, not a long response

### Heuristic Fallback (`_heuristic_score_explanation`, line 398)

If the LLM judge fails (network error, timeout, bad response), we fall back to:

```python
keyword_score = min(keyword_hits / 6.0, 1.0)   # counts hits from 23 debugging keywords
length_score = ...                             # based on len(explanation); sweet spot: 50-500 chars
final = 0.5 * keyword_score + 0.5 * length_score
```

Keywords include: "because", "should", "missing", "type", "format", "expected", "invalid", "authorization", "schema", "endpoint", "method", "payload", etc.

This ensures the hard task never gets stuck. Even if the LLM judge is completely unavailable, agents still get meaningful (if less precise) scores for reasonable explanations.
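
A sketch of the fallback, assuming the keyword subset listed above and a simple in/out length bonus (the real scorer uses 23 keywords and a finer length curve):

```python
KEYWORDS = {
    "because", "should", "missing", "type", "format", "expected",
    "invalid", "authorization", "schema", "endpoint", "method", "payload",
}

def heuristic_score(explanation):
    text = explanation.lower()
    keyword_score = min(sum(kw in text for kw in KEYWORDS) / 6.0, 1.0)
    # Length scoring is an assumption here: full credit in the 50-500 char sweet spot.
    length_score = 1.0 if 50 <= len(explanation) <= 500 else 0.5
    return 0.5 * keyword_score + 0.5 * length_score
```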

### How multiple valid paths are handled

This is the question Alexa asked. The answer:

- **For the fix portion (70%)**: Grading validates against the **spec**, not against a single golden answer. Any request that has all required fields with correct types and no unknown fields gets full credit. Two completely different valid fixes both score 1.0 on the fix portion.
- **For the explanation (30%)**: The LLM judge evaluates whether the agent **identified the actual injected errors**, not whether it matched specific phrasing. An explanation that says "the email field was missing" and one that says "a required field (email) was not included in the request body" both get credit for root cause identification.

---

## Error Injection System

### The 10 Error Types (`server/error_injectors.py`)

Each injector is a pure function: `(request, headers, spec, rng) -> (broken_request, broken_headers, ground_truth)`
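
A minimal sketch of injector #1 under this signature; the `spec` layout (`{"required": [...]}`) is assumed for illustration:

```python
import random

def inject_missing_required_field(request, headers, spec, rng):
    # Sketch of injector #1; the spec layout ({"required": [...]}) is assumed.
    field = rng.choice(spec["required"])
    broken = {k: v for k, v in request.items() if k != field}
    ground_truth = {"error_type": "missing_required_field",
                    "affected_fields": [field]}
    return broken, dict(headers), ground_truth

rng = random.Random(7)
broken, hdrs, truth = inject_missing_required_field(
    {"email": "a@b.com", "name": "Ada"},
    {"Authorization": "Bearer x"},
    {"required": ["email", "name"]},
    rng,
)
```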

| # | Error Type | What it does | Code location |
|---|-----------|-------------|---------------|
| 1 | `missing_required_field` | Removes a random required field from the request body | Line 38 |
| 2 | `wrong_field_type` | Changes a field's value to the wrong type (int to string, etc.) | Line 62 |
| 3 | `invalid_email_format` | Corrupts an email field (e.g. `user@` or `@domain.com`) | Line 103 |
| 4 | `missing_auth_header` | Removes the Authorization header | Line 131 |
| 5 | `extra_unknown_field` | Adds a field not in the spec (`debug_mode: true`, `_private`, etc.) | Line 159 |
| 6 | `null_value_in_required` | Sets a required field to `null` | Line 185 |
| 7 | `wrong_http_method` | Records that the wrong HTTP method was shown | Line 209 |
| 8 | `malformed_json_value` | Replaces a field's value with broken JSON fragments (`{broken`, `NaN`) | Line 235 |
| 9 | `invalid_enum_value` | Uses a value not in the allowed enum list | Line 270 |
| 10 | `datetime_format_error` | Replaces ISO 8601 datetime with a wrong format (`04/01/2026`) | Line 297 |

### Multi-Error Injection (`inject_multiple_errors`, line 370)

For hard tasks, we inject 2-3 errors simultaneously:

```python
chosen_types = rng.sample(ERROR_TYPES, min(count, len(ERROR_TYPES)))
for err_type in chosen_types:
    broken_req, broken_hdrs, gt = injector(broken_req, broken_hdrs, spec, rng)
    all_truths.append(gt)
```

`rng.sample()` picks without replacement, so the agent never sees two of the same error type in one episode. Errors are applied sequentially to the same request, so they can compound (e.g., a field gets removed AND a type gets changed on another field).

### Fallback Handling

Some injectors need specific field types that might not exist in the chosen spec. For example, `inject_invalid_email_format` needs an email field. If the spec has no email fields, it falls back to `inject_missing_required_field` instead. Same for `inject_invalid_enum_value` (falls back to `inject_wrong_field_type`) and `inject_datetime_format_error`.

---

## Infinite Unique Scenarios

The combinatorial space:

- **30 API specs** across 6 domains
- **10 error types** (each with random field selection within the spec)
- **Multiple bad values per error type** (e.g., 5 bad email formats, 5 malformed JSON fragments, 5 bad datetime formats)
- **Random field selection** within each spec (which required field gets removed, which gets type-changed)
- **Hard mode**: 2-3 errors from different types, applied sequentially

Conservative estimate: 30 specs x 10 error types x 3 field choices x 5 bad values = **4,500+ unique easy/medium scenarios**. Hard mode with combinations: significantly more.
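
The arithmetic checks out, and hard mode adds the ways of choosing 2 or 3 distinct error types on top:

```python
from math import comb

easy_medium = 30 * 10 * 3 * 5                  # specs x error types x field choices x bad values
hard_type_combos = comb(10, 2) + comb(10, 3)   # distinct 2- or 3-error-type combinations
print(easy_medium, hard_type_combos)           # 4500 165
```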

**Why this matters for RL**: An agent cannot memorize answers after one training run. It must learn a generalizable debugging strategy. If your environment has fixed scenarios, the agent overfits to those specific cases. With procedural generation, the agent has to learn the underlying skill.

---

## The Bugs That Almost Killed It

### Submission #2: "inference.py raised an unhandled exception"

**What happened**: The evaluator runs `inference.py` in their own runtime. Our script had two unprotected lines in `main()`:

1. `OpenAI(api_key=None)` - When `HF_TOKEN` is not set in the evaluator's environment, `os.getenv("HF_TOKEN")` returns `None`. The OpenAI client constructor crashes immediately with `OpenAIError: The api_key client option must be set` when given `None`. This was confirmed by running it locally.

2. `await APIDebugEnv.from_docker_image(IMAGE_NAME)` - If Docker isn't available or the image doesn't exist, this throws an unhandled exception.

**Fix**:
- Added fallback chain: `API_KEY = HF_TOKEN or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")`
- Wrapped `from_docker_image()` in try/except with fallback to direct URL connection
- Added try/except around each episode so one failure doesn't crash the whole run
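
The credential fallback amounts to first-non-empty-wins; a sketch (`resolve_api_key` is a hypothetical helper name, not in the repo):

```python
def resolve_api_key(env):
    # First non-empty candidate wins; None means no key is available.
    for name in ("HF_TOKEN", "OPENAI_API_KEY", "API_KEY"):
        if env.get(name):
            return env[name]
    return None

resolve_api_key({"OPENAI_API_KEY": "sk-test"})   # 'sk-test'
```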

**Misdiagnosis**: Another AI agent suggested the bug was a missing `async with` context manager for the WebSocket connection. This was wrong. We verified by reading the OpenEnv SDK source: `EnvClient._ensure_connected()` auto-calls `connect()` when the WebSocket is None. No explicit context manager is needed.

### Submission #3: "One or more task scores are out of range"

**What happened**: The evaluator requires scores in the strictly open interval (0, 1). Our graders returned:
- Exactly `0.0` for completely wrong answers (empty action, invalid JSON, etc.)
- Exactly `1.0` for perfect first-step answers (1.0 raw score x 1.0 step multiplier)

Both boundary values were rejected.

**Fix**: Added reward clamping in the `environment.py` `step()` method:
```python
reward = max(0.001, min(0.999, reward))
```

Also added score clamping in `inference.py`:
```python
score = min(max(score, 0.001), 0.999)
```

Updated 7 unit tests that previously asserted `== 0.0` or `== 1.0` to use `<= 0.01` and `>= 0.99`.

---

## Unique Features and Benefits

### 1. Procedural Generation (Not Fixed Datasets)
Most RL environments have a fixed set of scenarios. Our environment generates new scenarios every episode. 30 specs x 10 error types x random field selection = thousands of unique episodes. The agent genuinely learns a skill, not a lookup table.

### 2. Multi-Turn with Structured Feedback
The agent doesn't just get "right" or "wrong". It receives structured feedback like:
```
Validation: 3/5 checks passed.
email: PRESENT
name: MISSING
email type: VALID (email)
amount type: INVALID (expected integer)
debug_mode: UNKNOWN FIELD (not in spec)
```
This feedback loop lets the agent iterate and improve within an episode.

### 3. Progressive Difficulty
Three clearly separated difficulty levels:
- Easy: Classification only (identify the error)
- Medium: Fixing (produce correct JSON)
- Hard: Fixing + reasoning (explain like a developer)

Each level builds on the previous, and the hard task genuinely tests whether the agent can reason about what went wrong, not just pattern-match.

### 4. Deterministic + LLM Hybrid Grading
Easy and medium tasks are 100% deterministic. No variance, no API calls, reproducible results. Hard tasks are 70% deterministic (the fix portion) + 30% LLM-judged (explanation quality). This hybrid approach means:
- Most of the score is stable and reproducible
- The subjective part (explanation) uses an actual LLM to judge quality
- If the LLM judge fails, a keyword heuristic ensures the system never blocks

### 5. Reward Shaping for Efficient Problem-Solving
Step decay encourages solving on the first try. The 0.3x floor means late fixes are still rewarded. `best_reward` tracking means the agent's best attempt is what counts.

### 6. Real-World Domain Relevance
API debugging is a real developer pain point. Calendar Gym research showed malformed tool arguments caused >50% of agent failures. This environment directly trains agents to handle that failure mode.

### 7. Concurrent Session Support
`SUPPORTS_CONCURRENT_SESSIONS = True` + `max_concurrent_envs=10` means the environment can train 10 agents in parallel on the same deployment. Each session has its own state.

### 8. 79 Unit Tests (579 lines)
Comprehensive test coverage:
- All three graders (easy, medium, hard)
- Step decay multiplier values at each step
- Reward bounds (never < 0.001, never > 0.999)
- Episode termination conditions
- Seeded reproducibility
- Edge cases (empty actions, invalid JSON, non-dict JSON, extra fields)
- Best reward tracking
- Heuristic explanation scorer

---

## Baseline Results

Scores from running `inference.py` against the live HF Space (3 episodes per task):

| Task | Episodes | Qwen2.5-72B-Instruct | gpt-4o-mini |
|------|----------|----------------------|-------------|
| easy | 3 | 0.999 | 0.667 |
| medium | 3 | 0.999 | 0.999 |
| hard | 3 | 0.780 | 0.760 |
| **overall** | **9** | **0.926** | **0.809** |

**Key takeaway**: Larger models perform better on hard tasks (explanation quality + multi-error fixing), showing meaningful difficulty progression. The environment is not trivially solvable but also not impossibly hard.

---

## Implemented Advancements (Round 2)

All five advancement items from the original roadmap have been implemented:

### 1. GRPO Training Pipeline (IMPLEMENTED)
**What**: Full GRPO training loop in `training/train.py` that trains Qwen 0.5B on the live environment using TRL's GRPOTrainer with vLLM colocate mode.
**How it works**: The model generates JSON debugging attempts, the environment grades them via its deterministic graders, and GRPO updates the policy to prefer higher-scoring responses. The rollout function connects to the live HF Space via WebSocket, runs multi-turn episodes, and returns prompt_ids, completion_ids, logprobs, and env_reward.
**Key config**: `max_completion_length=128`, `gradient_accumulation_steps=16`, `vllm_gpu_memory_utilization=0.3`. Runs on a free Colab T4 GPU.

### 2. Expanded API Specs and Domains (IMPLEMENTED)
**What**: Expanded from 30 specs / 6 domains to 45 specs / 9 domains.
**New domains**: Analytics/Monitoring (dashboards, metrics, alerts), DevOps/Infrastructure (deployments, DNS, load balancers), AI/ML APIs (inference, fine-tuning, embeddings).
**Impact**: 50% more scenario diversity for training generalization. Each domain uses realistic field types, headers, and validation rules.

### 3. Chained Multi-Step Error Scenarios (IMPLEMENTED)
**What**: 5 chain patterns where fixing one error reveals the next, simulating real-world API debugging.
**Chain patterns**:
- **auth_gate**: Missing/expired auth blocks body error visibility
- **content_type_gate**: Wrong content type masks type/value errors
- **method_chain**: Wrong HTTP method hides field-level errors
- **rate_limit_chain**: Rate limit headers combined with auth/field errors
- **redirect_chain**: Redirect loops combined with type/format errors

**How it works**: `inject_chained_errors()` picks a random chain pattern, applies the gate error first, then injects body errors from the pattern's pool. The hard task uses chained errors 50% of the time.

### 4. Response Validation Task (IMPLEMENTED)
**What**: A 6th task where the agent receives a request-response pair and identifies response issues.
**Issue types**: wrong_status_code, missing_response_field, wrong_response_type, extra_response_field (data leak detection), inconsistent_error_format.
**Grading**: 0.5 x Jaccard(issue_types) + 0.3 x Jaccard(affected_fields) + 0.2 x status_code_match.
**8 response templates** covering: Create, List, Update, Delete, Batch, Authentication, File Upload, Search operations.
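
The response-task grading formula as a sketch (argument names are illustrative):

```python
def jaccard(a, b):
    a, b = set(a), set(b)
    return len(a & b) / len(a | b) if (a | b) else 1.0

def grade_response(pred_issues, true_issues, pred_fields, true_fields,
                   pred_status, true_status):
    # Weighted blend from the grading formula above.
    return (0.5 * jaccard(pred_issues, true_issues)
            + 0.3 * jaccard(pred_fields, true_fields)
            + 0.2 * (1.0 if pred_status == true_status else 0.0))
```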

### 5. Curriculum Learning (IMPLEMENTED)
**What**: Both training-side and environment-side curriculum learning.
**Training side** (`training/train.py`): 6-level curriculum that auto-promotes through easy -> classify -> medium -> headers -> response -> hard based on rolling average reward exceeding thresholds (0.7, 0.6, 0.6, 0.5, 0.5).
**Environment side** (`task="auto"`): The environment itself tracks per-session reward history and auto-promotes, so any client can benefit from adaptive difficulty without implementing curriculum logic.
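
The promotion rule on both sides can be sketched as follows (the rolling `window` size is an assumption):

```python
LEVELS = ["easy", "classify", "medium", "headers", "response", "hard"]
THRESHOLDS = [0.7, 0.6, 0.6, 0.5, 0.5]   # promotion bar between consecutive levels

def next_level(level, recent_rewards, window=20):
    # Promote when the rolling average clears the bar; `window` is an assumed size.
    i = LEVELS.index(level)
    if i >= len(THRESHOLDS):
        return level                      # already at the top level
    rolling = recent_rewards[-window:]
    if rolling and sum(rolling) / len(rolling) > THRESHOLDS[i]:
        return LEVELS[i + 1]
    return level
```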

### Scope for Further Advancement

- **GraphQL and gRPC protocols**: Add non-REST API specs for cross-protocol debugging
- **OAuth flow simulation**: Multi-step auth flows with token refresh, scope validation
- **Response body fixing**: Agent generates the correct response body, not just identifies issues
- **Multi-agent debugging**: Two agents collaborate on different aspects (headers vs body)
- **Real-world API replay**: Import failed requests from production logs for training data

---

## Practical Applications: Where the Trained LLM Sits

An LLM trained on this environment learns a specific skill: given a broken API request and a spec, diagnose the error, fix the request, and explain what went wrong. Here's how that skill translates to real-world developer tooling:

### 1. IDE Integration (Copilot-Style API Debugger)

```
Developer writes code          Trained LLM               API Server
        |                           |                        |
        |--- makes API call ------->|                        |
        |                           |--- forwards request -->|
        |                           |<-- 400/422 error ------|
        |                           |                        |
        |        [LLM analyzes request                       |
        |         vs API spec, identifies                    |
        |         error, generates fix]                      |
        |                           |                        |
        |<-- "Your 'amount' field   |                        |
        |     is a string but the   |                        |
        |     API expects integer.  |                        |
        |     Here's the fix: ..."  |                        |
```

**Where it sits**: As a VS Code / JetBrains extension plugin. Intercepts failed API calls in the developer's HTTP client (like Postman, Thunder Client, or `fetch`/`requests` in code), compares the request against known API specs, and suggests fixes inline.

**Developer experience**: The developer hits "Send" on an API request and gets a 400 error. Instead of reading the error response and manually debugging, the extension pops up: "Missing required field `email`. The spec requires it as type `email`. Here's the corrected request." One click to apply the fix.

### 2. API Gateway Middleware (Pre-Request Validation Layer)

```
Client App          API Gateway + Trained LLM          Backend API
    |                           |                           |
    |--- POST /v1/users ------->|                           |
    |    {bad request}          |                           |
    |                           |-- [LLM validates against  |
    |                           |    API schema before      |
    |                           |    forwarding]            |
    |                           |                           |
    |<-- 422 + fix hint --------|                           |
    |    "Field 'email'         |                           |
    |     is malformed.         |                           |
    |     Expected format:      |                           |
    |     user@domain.com"      |                           |
```

**Where it sits**: As a middleware layer in an API gateway (Kong, AWS API Gateway, Nginx). Before the request reaches the backend, the LLM validates it against the spec and returns human-readable fix suggestions instead of cryptic validation errors.

**Developer experience**: Instead of getting `{"error": "validation_error", "detail": [{"loc": ["body", "email"], "msg": "value is not a valid email address"}]}`, the developer gets: "The `email` field contains `user@` which is not a valid email. A valid email must have a domain (e.g., `user@example.com`). The `amount` field is a string but should be an integer. Send `2500` instead of `"2500"`."

### 3. CI/CD Pipeline Integration (Contract Testing)

```
Developer pushes code
        |
        v
CI Pipeline runs
        |
        v
API Contract Tests (using trained LLM)
        |
        |--- Replays recent API calls against updated spec
        |--- LLM identifies breaking changes
        |--- Generates migration guide
        |
        v
PR Comment: "3 API calls in your test suite
will break with the new spec. Here are the fixes..."
```

**Where it sits**: As a CI step that runs after API spec changes. The trained LLM compares existing API calls (from test suites, logs, or recorded traffic) against the updated spec and flags what will break.

**Developer experience**: A developer updates an API schema (adds a required field). The CI pipeline catches that 15 existing test calls are now invalid and generates the exact fix for each one.

### 4. Production Error Analysis (Log-Based Debugging)

```
Production System          Error Aggregator          Trained LLM
        |                         |                       |
        |--- 400/422 errors ----->|                       |
        |--- request logs ------->|                       |
        |                         |--- batch analysis --->|
        |                         |                       |
        |                         |<-- "Top 3 error       |
        |                         |     patterns:         |
        |                         |     1. 40% of failures|
        |                         |        are datetime   |
        |                         |        format errors  |
        |                         |        in /v1/events  |
        |                         |     2. ..."           |
```

**Where it sits**: Connected to error aggregation tools (Sentry, Datadog, PagerDuty). Analyzes batches of 4xx errors, groups them by root cause, and suggests API spec improvements or client-side fixes.

**Developer experience**: The oncall engineer gets a Slack alert: "87 new 422 errors on POST /v1/subscriptions in the last hour. Root cause: mobile clients sending `start_date` as `MM/DD/YYYY` instead of ISO 8601. Suggested fix: add a format hint to the error response, or accept both formats in the endpoint."

### 5. SDK/Documentation Generator

```
API Spec (OpenAPI/Swagger)
        |
        v
Trained LLM analyzes common error patterns
        |
        v
Auto-generated:
- "Common mistakes" section per endpoint
- Request validation examples
- Error handling code snippets
- Migration guides between API versions
```

**Where it sits**: As part of the API documentation pipeline. The LLM, having been trained on thousands of debugging scenarios, knows which errors are most common per endpoint type and generates preventive documentation.

### Key Insight

The environment trains a skill that's useful at **every layer of the API stack** - from the developer's IDE to the API gateway to production monitoring. The core capability (understand spec, diagnose broken request, suggest fix) is the same; only the integration point changes. A model trained on this environment could power any of these tools, because the underlying reasoning is identical: compare request against spec, find the mismatch, produce the fix.

client.py
CHANGED

@@ -28,6 +28,8 @@ class APIDebugEnv(EnvClient[APIDebugAction, APIDebugObservation, State]):
         payload = {}
         if action.error_type is not None:
             payload["error_type"] = action.error_type
+        if action.error_types is not None:
+            payload["error_types"] = action.error_types
         if action.affected_fields is not None:
             payload["affected_fields"] = action.affected_fields
         if action.fixed_request is not None:
@@ -36,6 +38,10 @@ class APIDebugEnv(EnvClient[APIDebugAction, APIDebugObservation, State]):
             payload["fixed_headers"] = action.fixed_headers
         if action.explanation is not None:
             payload["explanation"] = action.explanation
+        if action.response_issues is not None:
+            payload["response_issues"] = action.response_issues
+        if action.expected_status_code is not None:
+            payload["expected_status_code"] = action.expected_status_code
         return payload

     def _parse_result(self, payload: Dict[str, Any]) -> StepResult[APIDebugObservation]:
@@ -60,6 +66,8 @@ class APIDebugEnv(EnvClient[APIDebugAction, APIDebugObservation, State]):
             error_count=obs_data.get("error_count", 1),
             step_number=obs_data.get("step_number", 0),
             max_steps=obs_data.get("max_steps", 3),
+            response_body=obs_data.get("response_body", ""),
+            response_status_code=obs_data.get("response_status_code", 0),
             feedback=obs_data.get("feedback", ""),
             message=obs_data.get("message", ""),
             done=payload.get("done", False),
inference.py
CHANGED
|
@@ -34,9 +34,9 @@ ENV_URL = os.getenv("ENV_URL") or "https://avichauhan-api-debug-env.hf.space"
|
|
| 34 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 35 |
|
| 36 |
# Task configuration
|
| 37 |
-
TASKS = ["easy", "medium", "hard"]
|
| 38 |
EPISODES_PER_TASK = 3
|
| 39 |
-
MAX_STEPS = {"easy": 3, "medium": 5, "hard": 7}
|
| 40 |
BENCHMARK_NAME = "api_debug"
|
| 41 |
|
| 42 |
|
|
@@ -87,7 +87,21 @@ SYSTEM_PROMPTS = {
|
|
| 87 |
missing_required_field, wrong_field_type, invalid_email_format,
|
| 88 |
missing_auth_header, extra_unknown_field, null_value_in_required,
|
| 89 |
wrong_http_method, malformed_json_value, invalid_enum_value,
|
| 90 |
-
datetime_format_error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
""").strip(),
|
| 92 |
|
| 93 |
"medium": textwrap.dedent("""
|
|
@@ -100,6 +114,33 @@ SYSTEM_PROMPTS = {
|
|
| 100 |
The fixed_request must be a valid JSON string. Include all required fields with correct types.
|
| 101 |
""").strip(),
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
"hard": textwrap.dedent("""
|
| 104 |
You are an API debugging expert. You receive a broken API request with multiple errors.
|
| 105 |
Your job: diagnose the errors, fix the request, and explain the fix for a developer.
|
|
@@ -126,10 +167,14 @@ def build_user_prompt(obs, step_num: int) -> str:
|
|
| 126 |
f"API: {obs.http_method} {obs.endpoint} ({obs.api_name})",
|
| 127 |
f"Error count: {obs.error_count}",
|
| 128 |
f"Step {step_num}/{obs.max_steps}",
|
| 129 |
-
f"\
|
| 130 |
f"\nRequest headers: {json.dumps(obs.broken_headers)}",
|
| 131 |
f"\nAPI Specification:\n{obs.api_spec}",
|
| 132 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
if obs.feedback:
|
| 134 |
parts.append(f"\nFeedback from previous attempt:\n{obs.feedback}")
|
| 135 |
return "\n".join(parts)
|
|
@@ -182,10 +227,13 @@ def build_action(data: dict) -> APIDebugAction:
|
|
| 182 |
|
| 183 |
return APIDebugAction(
|
| 184 |
error_type=data.get("error_type"),
|
|
|
|
| 185 |
affected_fields=data.get("affected_fields"),
|
| 186 |
fixed_request=fixed_req,
|
| 187 |
fixed_headers=data.get("fixed_headers"),
|
| 188 |
explanation=data.get("explanation"),
|
|
|
|
|
|
|
| 189 |
)
|
| 190 |
|
| 191 |
|
|
@@ -264,9 +312,18 @@ def _action_summary(action: APIDebugAction, task: str) -> str:
|
|
| 264 |
"""Short summary of the action for logging."""
|
| 265 |
if task == "easy":
|
| 266 |
return f"diagnose:{action.error_type or 'none'}"
|
|
|
|
|
|
|
|
|
|
| 267 |
elif task == "medium":
|
| 268 |
fix_len = len(action.fixed_request or "")
|
| 269 |
return f"fix:len={fix_len}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
else:
|
| 271 |
fix_len = len(action.fixed_request or "")
|
| 272 |
exp_len = len(action.explanation or "")
|
|
|
|
| 34 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 35 |
|
| 36 |
# Task configuration
|
| 37 |
+
TASKS = ["easy", "classify", "medium", "headers", "response", "hard"]
|
| 38 |
EPISODES_PER_TASK = 3
|
| 39 |
+
MAX_STEPS = {"easy": 3, "classify": 4, "medium": 5, "headers": 4, "response": 4, "hard": 7}
|
| 40 |
BENCHMARK_NAME = "api_debug"
|
| 41 |
|
| 42 |
|
|
|
|
```diff
     missing_required_field, wrong_field_type, invalid_email_format,
     missing_auth_header, extra_unknown_field, null_value_in_required,
     wrong_http_method, malformed_json_value, invalid_enum_value,
+    datetime_format_error, wrong_content_type, expired_auth_token
+    """).strip(),
+
+    "classify": textwrap.dedent("""
+        You are an API debugging expert. You receive a broken API request with MULTIPLE errors.
+        Your job: identify ALL error types and ALL affected fields.
+
+        Respond with ONLY a JSON object in this format:
+        {"error_types": ["type1", "type2"], "affected_fields": ["field1", "field2"]}
+
+        Valid error types:
+        missing_required_field, wrong_field_type, invalid_email_format,
+        missing_auth_header, extra_unknown_field, null_value_in_required,
+        wrong_http_method, malformed_json_value, invalid_enum_value,
+        datetime_format_error, wrong_content_type, expired_auth_token
     """).strip(),

     "medium": textwrap.dedent("""
```

```diff
         The fixed_request must be a valid JSON string. Include all required fields with correct types.
     """).strip(),

+    "headers": textwrap.dedent("""
+        You are an API debugging expert. You receive a broken API request with header-level errors.
+        Your job: identify the header error type and provide the corrected headers.
+
+        Respond with ONLY a JSON object in this format:
+        {"error_type": "<type>", "fixed_headers": {"Header-Name": "correct-value"}}
+
+        Valid header error types:
+        missing_auth_header, wrong_content_type, expired_auth_token
+
+        Common headers: Authorization (Bearer token), Content-Type (application/json)
+    """).strip(),
+
+    "response": textwrap.dedent("""
+        You are an API response validation expert. You receive an API request, its specification,
+        and the server's response. Your job: identify issues in the response.
+
+        Respond with ONLY a JSON object in this format:
+        {"response_issues": ["issue_type1", "issue_type2"], "affected_fields": ["field1"], "expected_status_code": 200}
+
+        Valid response issue types:
+        wrong_status_code, missing_response_field, wrong_response_type,
+        extra_response_field, inconsistent_error_format
+
+        Only include expected_status_code if you detect a wrong_status_code issue.
+    """).strip(),
+
     "hard": textwrap.dedent("""
         You are an API debugging expert. You receive a broken API request with multiple errors.
         Your job: diagnose the errors, fix the request, and explain the fix for a developer.
```
```diff
         f"API: {obs.http_method} {obs.endpoint} ({obs.api_name})",
         f"Error count: {obs.error_count}",
         f"Step {step_num}/{obs.max_steps}",
+        f"\nRequest body:\n{obs.broken_request}",
         f"\nRequest headers: {json.dumps(obs.broken_headers)}",
         f"\nAPI Specification:\n{obs.api_spec}",
     ]
+    # Include response data for response validation task
+    if obs.response_body:
+        parts.append(f"\nResponse status code: {obs.response_status_code}")
+        parts.append(f"\nResponse body:\n{obs.response_body}")
     if obs.feedback:
         parts.append(f"\nFeedback from previous attempt:\n{obs.feedback}")
     return "\n".join(parts)
```
```diff

     return APIDebugAction(
         error_type=data.get("error_type"),
+        error_types=data.get("error_types"),
         affected_fields=data.get("affected_fields"),
         fixed_request=fixed_req,
         fixed_headers=data.get("fixed_headers"),
         explanation=data.get("explanation"),
+        response_issues=data.get("response_issues"),
+        expected_status_code=data.get("expected_status_code"),
     )

```
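`build_action` above assumes `data` is already a parsed dict. Prompts that demand "ONLY a JSON object" still frequently come back wrapped in markdown fences or prose, so a tolerant extraction step usually precedes it. A sketch of that step; the helper name is hypothetical and the real pipeline may differ:

```python
import json
import re

def extract_json_object(reply: str):
    """Pull the first {...} block out of a model reply and parse it.

    Hypothetical helper illustrating the step before build_action: strip
    markdown fences, then parse the outermost brace-delimited span.
    Returns None when no valid JSON object can be recovered.
    """
    # Remove common markdown code fences (``` or ```json).
    cleaned = re.sub(r"```(?:json)?", "", reply)
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end < start:
        return None
    try:
        return json.loads(cleaned[start : end + 1])
    except json.JSONDecodeError:
        return None
```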
```diff
     """Short summary of the action for logging."""
     if task == "easy":
         return f"diagnose:{action.error_type or 'none'}"
+    elif task == "classify":
+        types = action.error_types or [action.error_type or "none"]
+        return f"classify:{','.join(str(t) for t in types)}"
     elif task == "medium":
         fix_len = len(action.fixed_request or "")
         return f"fix:len={fix_len}"
+    elif task == "headers":
+        hdr_count = len(action.fixed_headers or {})
+        return f"headers:{action.error_type or 'none'}+fix:{hdr_count}"
+    elif task == "response":
+        issues = action.response_issues or []
+        return f"response:{','.join(issues) or 'none'}+status:{action.expected_status_code or 'none'}"
     else:
         fix_len = len(action.fixed_request or "")
         exp_len = len(action.explanation or "")
```
models.py
CHANGED
```diff
@@ -22,6 +22,10 @@ class APIDebugAction(Action):
         default=None,
         description="Diagnosed error type, e.g. 'missing_required_field'"
     )
     affected_fields: Optional[List[str]] = Field(
         default=None,
         description="List of field names affected by the error"
```

```diff
@@ -38,6 +42,14 @@ class APIDebugAction(Action):
         default=None,
         description="Developer-facing explanation of the fix (hard task only)"
     )


 class APIDebugObservation(Observation):
```

```diff
@@ -48,7 +60,7 @@ class APIDebugObservation(Observation):

     task: str = Field(
         default="easy",
-        description="Current task
     )
     api_name: str = Field(
         default="",
```

```diff
@@ -86,6 +98,14 @@ class APIDebugObservation(Observation):
         default=3,
         description="Maximum steps allowed for this task"
     )
     feedback: str = Field(
         default="",
         description="Structured validation feedback from the last action"
```
```diff
         default=None,
         description="Diagnosed error type, e.g. 'missing_required_field'"
     )
+    error_types: Optional[List[str]] = Field(
+        default=None,
+        description="All diagnosed error types (for classify task with multiple errors)"
+    )
     affected_fields: Optional[List[str]] = Field(
         default=None,
         description="List of field names affected by the error"
```

```diff
         default=None,
         description="Developer-facing explanation of the fix (hard task only)"
     )
+    response_issues: Optional[List[str]] = Field(
+        default=None,
+        description="Issues found in the API response (response task only)"
+    )
+    expected_status_code: Optional[int] = Field(
+        default=None,
+        description="Correct HTTP status code for the response (response task only)"
+    )


 class APIDebugObservation(Observation):
```

```diff

     task: str = Field(
         default="easy",
+        description="Current task: easy, classify, medium, headers, hard, response"
     )
     api_name: str = Field(
         default="",
```

```diff
         default=3,
         description="Maximum steps allowed for this task"
     )
+    response_body: str = Field(
+        default="",
+        description="JSON string of the API response to validate (response task only)"
+    )
+    response_status_code: int = Field(
+        default=0,
+        description="HTTP status code of the response (response task only)"
+    )
     feedback: str = Field(
         default="",
         description="Structured validation feedback from the last action"
```
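With the new fields in place, an action for the response task carries only the keys that task needs. A dependency-free sketch of constructing and serializing one; the real `APIDebugAction` is a pydantic model inheriting from the OpenEnv `Action` base, so this dataclass mirror is illustrative only:

```python
from dataclasses import dataclass, asdict
from typing import List, Optional

# Standalone mirror of the APIDebugAction fields shown above. The real class
# is a pydantic model; a dataclass keeps this sketch dependency-free.
@dataclass
class APIDebugActionSketch:
    error_type: Optional[str] = None
    error_types: Optional[List[str]] = None
    affected_fields: Optional[List[str]] = None
    response_issues: Optional[List[str]] = None
    expected_status_code: Optional[int] = None

    def to_payload(self) -> dict:
        """Serialize, dropping unset fields (mirrors pydantic's exclude_none)."""
        return {k: v for k, v in asdict(self).items() if v is not None}

# A response-task action only sets the response-validation fields.
action = APIDebugActionSketch(
    response_issues=["wrong_status_code"],
    expected_status_code=201,
)
```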
server/api_specs.py
CHANGED
```diff
@@ -623,7 +623,319 @@ CALENDAR_AUTH_SPECS = [
 ]


-#
 ALL_SPECS = (
     PAYMENT_SPECS
     + USER_SPECS
```
```diff
@@ -631,6 +943,9 @@ ALL_SPECS = (
     + MESSAGING_SPECS
     + ECOMMERCE_SPECS
     + CALENDAR_AUTH_SPECS
 )

```
| 623 |
]
|
| 624 |
|
| 625 |
|
| 626 |
+
# =========================================================================
|
| 627 |
+
# Domain 7: Analytics and Monitoring
|
| 628 |
+
# =========================================================================
|
| 629 |
+
|
| 630 |
+
ANALYTICS_SPECS = [
|
| 631 |
+
_spec(
|
| 632 |
+
api_name="Create Dashboard",
|
| 633 |
+
http_method="POST",
|
| 634 |
+
endpoint="/api/dashboards",
|
| 635 |
+
required_fields=["name", "workspace_id"],
|
| 636 |
+
optional_fields=["description", "layout", "shared"],
|
| 637 |
+
field_types={
|
| 638 |
+
"name": "string",
|
| 639 |
+
"workspace_id": "string",
|
| 640 |
+
"description": "string",
|
| 641 |
+
"layout": "enum:grid,freeform,list",
|
| 642 |
+
"shared": "boolean",
|
| 643 |
+
},
|
| 644 |
+
valid_example={
|
| 645 |
+
"name": "API Latency Overview",
|
| 646 |
+
"workspace_id": "ws_prod_001",
|
| 647 |
+
},
|
| 648 |
+
),
|
| 649 |
+
_spec(
|
| 650 |
+
api_name="Add Metric",
|
| 651 |
+
http_method="POST",
|
| 652 |
+
endpoint="/api/metrics",
|
| 653 |
+
required_fields=["name", "type", "value"],
|
| 654 |
+
optional_fields=["tags", "timestamp", "unit"],
|
| 655 |
+
field_types={
|
| 656 |
+
"name": "string",
|
| 657 |
+
"type": "enum:counter,gauge,histogram,summary",
|
| 658 |
+
"value": "float",
|
| 659 |
+
"tags": "array",
|
| 660 |
+
"timestamp": "datetime",
|
| 661 |
+
"unit": "string",
|
| 662 |
+
},
|
| 663 |
+
valid_example={
|
| 664 |
+
"name": "api.request.duration",
|
| 665 |
+
"type": "histogram",
|
| 666 |
+
"value": 245.7,
|
| 667 |
+
},
|
| 668 |
+
),
|
| 669 |
+
_spec(
|
| 670 |
+
api_name="Create Alert Rule",
|
| 671 |
+
http_method="POST",
|
| 672 |
+
endpoint="/api/alerts/rules",
|
| 673 |
+
required_fields=["name", "metric", "threshold", "condition"],
|
| 674 |
+
optional_fields=["description", "severity", "notification_channels"],
|
| 675 |
+
field_types={
|
| 676 |
+
"name": "string",
|
| 677 |
+
"metric": "string",
|
| 678 |
+
"threshold": "float",
|
| 679 |
+
"condition": "enum:above,below,equals",
|
| 680 |
+
"description": "string",
|
| 681 |
+
"severity": "enum:critical,warning,info",
|
| 682 |
+
"notification_channels": "array",
|
| 683 |
+
},
|
| 684 |
+
valid_example={
|
| 685 |
+
"name": "High Latency Alert",
|
| 686 |
+
"metric": "api.request.duration",
|
| 687 |
+
"threshold": 500.0,
|
| 688 |
+
"condition": "above",
|
| 689 |
+
},
|
| 690 |
+
),
|
| 691 |
+
_spec(
|
| 692 |
+
api_name="Log Event",
|
| 693 |
+
http_method="POST",
|
| 694 |
+
endpoint="/api/logs",
|
| 695 |
+
required_fields=["level", "message", "service"],
|
| 696 |
+
optional_fields=["timestamp", "trace_id", "metadata"],
|
| 697 |
+
field_types={
|
| 698 |
+
"level": "enum:debug,info,warn,error,fatal",
|
| 699 |
+
"message": "string",
|
| 700 |
+
"service": "string",
|
| 701 |
+
"timestamp": "datetime",
|
| 702 |
+
"trace_id": "string",
|
| 703 |
+
"metadata": "object",
|
| 704 |
+
},
|
| 705 |
+
valid_example={
|
| 706 |
+
"level": "error",
|
| 707 |
+
"message": "Connection timeout to database",
|
| 708 |
+
"service": "payment-service",
|
| 709 |
+
},
|
| 710 |
+
),
|
| 711 |
+
_spec(
|
| 712 |
+
api_name="Query Logs",
|
| 713 |
+
http_method="POST",
|
| 714 |
+
endpoint="/api/logs/search",
|
| 715 |
+
required_fields=["query", "start_time", "end_time"],
|
| 716 |
+
optional_fields=["limit", "service_filter", "level_filter"],
|
| 717 |
+
field_types={
|
| 718 |
+
"query": "string",
|
| 719 |
+
"start_time": "datetime",
|
| 720 |
+
"end_time": "datetime",
|
| 721 |
+
"limit": "integer",
|
| 722 |
+
"service_filter": "string",
|
| 723 |
+
"level_filter": "enum:debug,info,warn,error,fatal",
|
| 724 |
+
},
|
| 725 |
+
valid_example={
|
| 726 |
+
"query": "timeout OR connection refused",
|
| 727 |
+
"start_time": "2026-04-01T00:00:00Z",
|
| 728 |
+
"end_time": "2026-04-01T23:59:59Z",
|
| 729 |
+
},
|
| 730 |
+
),
|
| 731 |
+
]
|
| 732 |
+
|
| 733 |
+
# =========================================================================
|
| 734 |
+
# Domain 8: DevOps and Infrastructure
|
| 735 |
+
# =========================================================================
|
| 736 |
+
|
| 737 |
+
DEVOPS_SPECS = [
|
| 738 |
+
_spec(
|
| 739 |
+
api_name="Create Deployment",
|
| 740 |
+
http_method="POST",
|
| 741 |
+
endpoint="/api/deployments",
|
| 742 |
+
required_fields=["service_name", "image", "environment"],
|
| 743 |
+
optional_fields=["replicas", "cpu_limit", "memory_limit", "env_vars"],
|
| 744 |
+
field_types={
|
| 745 |
+
"service_name": "string",
|
| 746 |
+
"image": "string",
|
| 747 |
+
"environment": "enum:staging,production,development",
|
| 748 |
+
"replicas": "integer",
|
| 749 |
+
"cpu_limit": "string",
|
| 750 |
+
"memory_limit": "string",
|
| 751 |
+
"env_vars": "object",
|
| 752 |
+
},
|
| 753 |
+
valid_example={
|
| 754 |
+
"service_name": "api-gateway",
|
| 755 |
+
"image": "registry.io/api-gateway:v2.1.0",
|
| 756 |
+
"environment": "production",
|
| 757 |
+
},
|
| 758 |
+
),
|
| 759 |
+
_spec(
|
| 760 |
+
api_name="Scale Service",
|
| 761 |
+
http_method="PATCH",
|
| 762 |
+
endpoint="/api/services/{service_id}/scale",
|
| 763 |
+
required_fields=["service_id", "replicas"],
|
| 764 |
+
optional_fields=["min_replicas", "max_replicas"],
|
| 765 |
+
field_types={
|
| 766 |
+
"service_id": "string",
|
| 767 |
+
"replicas": "integer",
|
| 768 |
+
"min_replicas": "integer",
|
| 769 |
+
"max_replicas": "integer",
|
| 770 |
+
},
|
| 771 |
+
valid_example={
|
| 772 |
+
"service_id": "svc_api_gateway",
|
| 773 |
+
"replicas": 5,
|
| 774 |
+
},
|
| 775 |
+
),
|
| 776 |
+
_spec(
|
| 777 |
+
api_name="Create DNS Record",
|
| 778 |
+
http_method="POST",
|
| 779 |
+
endpoint="/api/dns/records",
|
| 780 |
+
required_fields=["domain", "type", "value"],
|
| 781 |
+
optional_fields=["ttl", "priority"],
|
| 782 |
+
field_types={
|
| 783 |
+
"domain": "string",
|
| 784 |
+
"type": "enum:A,AAAA,CNAME,MX,TXT,NS",
|
| 785 |
+
"value": "string",
|
| 786 |
+
"ttl": "integer",
|
| 787 |
+
"priority": "integer",
|
| 788 |
+
},
|
| 789 |
+
valid_example={
|
| 790 |
+
"domain": "api.example.com",
|
| 791 |
+
"type": "A",
|
| 792 |
+
"value": "203.0.113.50",
|
| 793 |
+
},
|
| 794 |
+
),
|
| 795 |
+
_spec(
|
| 796 |
+
api_name="Add SSL Certificate",
|
| 797 |
+
http_method="POST",
|
| 798 |
+
endpoint="/api/certificates",
|
| 799 |
+
required_fields=["domain", "certificate", "private_key"],
|
| 800 |
+
optional_fields=["chain", "auto_renew"],
|
| 801 |
+
field_types={
|
| 802 |
+
"domain": "string",
|
| 803 |
+
"certificate": "string",
|
| 804 |
+
"private_key": "string",
|
| 805 |
+
"chain": "string",
|
| 806 |
+
"auto_renew": "boolean",
|
| 807 |
+
},
|
| 808 |
+
valid_example={
|
| 809 |
+
"domain": "api.example.com",
|
| 810 |
+
"certificate": "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----",
|
| 811 |
+
"private_key": "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----",
|
| 812 |
+
},
|
| 813 |
+
),
|
| 814 |
+
_spec(
|
| 815 |
+
api_name="Create Load Balancer",
|
| 816 |
+
http_method="POST",
|
| 817 |
+
endpoint="/api/load-balancers",
|
| 818 |
+
required_fields=["name", "algorithm", "targets"],
|
| 819 |
+
optional_fields=["health_check_path", "health_check_interval", "sticky_sessions"],
|
| 820 |
+
field_types={
|
| 821 |
+
"name": "string",
|
| 822 |
+
"algorithm": "enum:round_robin,least_connections,ip_hash,weighted",
|
| 823 |
+
"targets": "array",
|
| 824 |
+
"health_check_path": "string",
|
| 825 |
+
"health_check_interval": "integer",
|
| 826 |
+
"sticky_sessions": "boolean",
|
| 827 |
+
},
|
| 828 |
+
valid_example={
|
| 829 |
+
"name": "api-lb-prod",
|
| 830 |
+
"algorithm": "round_robin",
|
| 831 |
+
"targets": [
|
| 832 |
+
{"host": "10.0.1.1", "port": 8080},
|
| 833 |
+
{"host": "10.0.1.2", "port": 8080},
|
| 834 |
+
],
|
| 835 |
+
},
|
| 836 |
+
),
|
| 837 |
+
]
|
| 838 |
+
|
| 839 |
+
# =========================================================================
|
| 840 |
+
# Domain 9: AI/ML APIs
|
| 841 |
+
# =========================================================================
|
| 842 |
+
|
| 843 |
+
AI_ML_SPECS = [
|
| 844 |
+
_spec(
|
| 845 |
+
api_name="Submit Inference",
|
| 846 |
+
http_method="POST",
|
| 847 |
+
endpoint="/api/inference",
|
| 848 |
+
required_fields=["model_id", "inputs"],
|
| 849 |
+
optional_fields=["parameters", "stream", "timeout"],
|
| 850 |
+
field_types={
|
| 851 |
+
"model_id": "string",
|
| 852 |
+
"inputs": "string",
|
| 853 |
+
"parameters": "object",
|
| 854 |
+
"stream": "boolean",
|
| 855 |
+
"timeout": "integer",
|
| 856 |
+
},
|
| 857 |
+
valid_example={
|
| 858 |
+
"model_id": "meta-llama/Llama-3-8B-Instruct",
|
| 859 |
+
"inputs": "Explain reinforcement learning in one sentence.",
|
| 860 |
+
},
|
| 861 |
+
),
|
| 862 |
+
_spec(
|
| 863 |
+
api_name="Create Fine-tune Job",
|
| 864 |
+
http_method="POST",
|
| 865 |
+
endpoint="/api/fine-tune",
|
| 866 |
+
required_fields=["base_model", "dataset_id", "num_epochs"],
|
| 867 |
+
optional_fields=["learning_rate", "batch_size", "validation_split"],
|
| 868 |
+
field_types={
|
| 869 |
+
"base_model": "string",
|
| 870 |
+
"dataset_id": "string",
|
| 871 |
+
"num_epochs": "integer",
|
| 872 |
+
"learning_rate": "float",
|
| 873 |
+
"batch_size": "integer",
|
| 874 |
+
"validation_split": "float",
|
| 875 |
+
},
|
| 876 |
+
valid_example={
|
| 877 |
+
"base_model": "Qwen/Qwen2.5-0.5B",
|
| 878 |
+
"dataset_id": "ds_api_debug_v1",
|
| 879 |
+
"num_epochs": 3,
|
| 880 |
+
},
|
| 881 |
+
),
|
| 882 |
+
_spec(
|
| 883 |
+
api_name="Upload Dataset",
|
| 884 |
+
http_method="POST",
|
| 885 |
+
endpoint="/api/datasets",
|
| 886 |
+
required_fields=["name", "format", "source_url"],
|
| 887 |
+
optional_fields=["description", "license", "tags"],
|
| 888 |
+
field_types={
|
| 889 |
+
"name": "string",
|
| 890 |
+
"format": "enum:json,csv,parquet,arrow",
|
| 891 |
+
"source_url": "url",
|
| 892 |
+
"description": "string",
|
| 893 |
+
"license": "string",
|
| 894 |
+
"tags": "array",
|
| 895 |
+
},
|
| 896 |
+
valid_example={
|
| 897 |
+
"name": "api-debug-training-v1",
|
| 898 |
+
"format": "json",
|
| 899 |
+
"source_url": "https://storage.example.com/datasets/api_debug.json",
|
| 900 |
+
},
|
| 901 |
+
),
|
| 902 |
+
_spec(
|
| 903 |
+
api_name="Create Embedding",
|
| 904 |
+
http_method="POST",
|
| 905 |
+
endpoint="/api/embeddings",
|
| 906 |
+
required_fields=["model_id", "input"],
|
| 907 |
+
optional_fields=["encoding_format", "dimensions"],
|
| 908 |
+
field_types={
|
| 909 |
+
"model_id": "string",
|
| 910 |
+
"input": "string",
|
| 911 |
+
"encoding_format": "enum:float,base64",
|
| 912 |
+
"dimensions": "integer",
|
| 913 |
+
},
|
| 914 |
+
valid_example={
|
| 915 |
+
"model_id": "BAAI/bge-small-en-v1.5",
|
| 916 |
+
"input": "API debugging is a critical developer skill.",
|
| 917 |
+
},
|
| 918 |
+
),
|
| 919 |
+
_spec(
|
| 920 |
+
api_name="List Models",
|
| 921 |
+
http_method="GET",
|
| 922 |
+
endpoint="/api/models",
|
| 923 |
+
required_fields=["task"],
|
| 924 |
+
optional_fields=["library", "sort", "limit"],
|
| 925 |
+
field_types={
|
| 926 |
+
"task": "enum:text-generation,text-classification,embeddings,image-classification",
|
| 927 |
+
"library": "string",
|
| 928 |
+
"sort": "enum:downloads,likes,trending",
|
| 929 |
+
"limit": "integer",
|
| 930 |
+
},
|
| 931 |
+
valid_example={
|
| 932 |
+
"task": "text-generation",
|
| 933 |
+
},
|
| 934 |
+
),
|
| 935 |
+
]
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
# All 45 specs in a single flat list
|
| 939 |
ALL_SPECS = (
|
| 940 |
PAYMENT_SPECS
|
| 941 |
+ USER_SPECS
|
|
|
|
| 943 |
+ MESSAGING_SPECS
|
| 944 |
+ ECOMMERCE_SPECS
|
| 945 |
+ CALENDAR_AUTH_SPECS
|
| 946 |
+
+ ANALYTICS_SPECS
|
| 947 |
+
+ DEVOPS_SPECS
|
| 948 |
+
+ AI_ML_SPECS
|
| 949 |
)
|
| 950 |
|
| 951 |
|
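Each `_spec(...)` entry couples required/optional field lists with a compact type vocabulary (`"integer"`, `"float"`, `"enum:a,b,c"`, ...). A minimal sketch of a validator that consumes a spec-shaped dict; the real validation lives in `server/validators.py` and is more complete, so this standalone mirror is illustrative only:

```python
def validate_request(spec: dict, request: dict) -> list:
    """Return a list of problems found in `request` against a spec-style dict.

    Illustrative sketch: handles missing required fields, unknown fields,
    enum values, and integer typing; the real validators cover more types.
    """
    problems = []
    for name in spec["required_fields"]:
        if name not in request:
            problems.append(f"missing_required_field: {name}")
    for name, value in request.items():
        ftype = spec["field_types"].get(name)
        if ftype is None:
            problems.append(f"extra_unknown_field: {name}")
        elif ftype.startswith("enum:"):
            allowed = ftype.split(":", 1)[1].split(",")
            if value not in allowed:
                problems.append(f"invalid_enum_value: {name}")
        elif ftype == "integer" and not isinstance(value, int):
            problems.append(f"wrong_field_type: {name}")
    return problems

# Spec dict shaped like the _spec(...) entries above.
spec = {
    "required_fields": ["name", "type", "value"],
    "field_types": {"name": "string", "type": "enum:counter,gauge", "value": "float"},
}
```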
server/app.py
CHANGED
```diff
@@ -36,6 +36,13 @@ def list_tasks():
             "grading": "deterministic",
             "description": "Identify the error type and affected fields",
         },
         {
             "name": "medium",
             "max_steps": 5,
```

```diff
@@ -43,6 +50,20 @@ def list_tasks():
             "grading": "deterministic",
             "description": "Fix the broken request to match the API spec",
         },
         {
             "name": "hard",
             "max_steps": 7,
```

```diff
@@ -51,6 +72,13 @@ def list_tasks():
             "description": "Fix the request and explain the fix for developers",
         },
     ],
     "error_types": [
         "missing_required_field",
         "wrong_field_type",
```

```diff
@@ -62,8 +90,13 @@ def list_tasks():
         "malformed_json_value",
         "invalid_enum_value",
         "datetime_format_error",
     ],
-    "api_spec_count":
 })

```
```diff
             "grading": "deterministic",
             "description": "Identify the error type and affected fields",
         },
+        {
+            "name": "classify",
+            "max_steps": 4,
+            "error_count": "2-3",
+            "grading": "deterministic",
+            "description": "Identify ALL error types and affected fields across multiple errors",
+        },
         {
             "name": "medium",
             "max_steps": 5,
```

```diff
             "grading": "deterministic",
             "description": "Fix the broken request to match the API spec",
         },
+        {
+            "name": "headers",
+            "max_steps": 4,
+            "error_count": 1,
+            "grading": "deterministic",
+            "description": "Fix request headers (auth, content-type, tokens)",
+        },
+        {
+            "name": "response",
+            "max_steps": 4,
+            "error_count": "1-2",
+            "grading": "deterministic",
+            "description": "Validate API response: identify wrong status codes, missing fields, type errors, data leaks",
+        },
         {
             "name": "hard",
             "max_steps": 7,
```

```diff
             "description": "Fix the request and explain the fix for developers",
         },
     ],
+    "response_issue_types": [
+        "wrong_status_code",
+        "missing_response_field",
+        "wrong_response_type",
+        "extra_response_field",
+        "inconsistent_error_format",
+    ],
     "error_types": [
         "missing_required_field",
         "wrong_field_type",
```

```diff
         "malformed_json_value",
         "invalid_enum_value",
         "datetime_format_error",
+        "wrong_content_type",
+        "expired_auth_token",
+        "wrong_status_code",
+        "redirect_loop",
+        "rate_limit_headers",
     ],
+    "api_spec_count": 45,
 })

```
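The task listing above is plain JSON, so a client can discover each task's step budget instead of hard-coding it. A sketch of that lookup against a payload shaped like the handler's output; only `name` and `max_steps` are reproduced here, and the HTTP fetch itself is left out:

```python
# Payload shaped like the /tasks listing above (name and max_steps only;
# the real response carries more keys per task).
TASKS_PAYLOAD = {
    "tasks": [
        {"name": "easy", "max_steps": 3},
        {"name": "classify", "max_steps": 4},
        {"name": "medium", "max_steps": 5},
        {"name": "headers", "max_steps": 4},
        {"name": "response", "max_steps": 4},
        {"name": "hard", "max_steps": 7},
    ],
}

def step_budget(payload: dict, task: str) -> int:
    """Look up max_steps for a task name from the tasks payload."""
    for entry in payload["tasks"]:
        if entry["name"] == task:
            return entry["max_steps"]
    raise KeyError(task)
```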
server/environment.py
CHANGED
```diff
@@ -2,10 +2,11 @@
 Core environment for the API Debug Environment.

 Implements the OpenEnv Environment interface with:
--
 - Multi-turn episodes with structured feedback
-- Deterministic grading for easy/medium, LLM-as-judge for hard
 - Step reward decay to encourage efficient debugging
 """

 import copy
```

```diff
@@ -26,9 +27,12 @@ except ImportError:
 from .api_specs import get_random_spec
 from .error_injectors import (
     ERROR_TYPES,
     inject_error,
     inject_multiple_errors,
 )
 from .validators import (
     validate_field_type,
     validate_headers_against_spec,
```

```diff
@@ -39,8 +43,11 @@ from .validators import (
 # Task configuration: max steps and error count per difficulty
 TASK_CONFIG = {
     "easy": {"max_steps": 3, "error_count": 1},
     "medium": {"max_steps": 5, "error_count": 1},
     "hard": {"max_steps": 7, "min_errors": 2, "max_errors": 3},
 }

```

```diff
@@ -58,6 +65,18 @@ class APIDebugEnvironment(Environment):

     SUPPORTS_CONCURRENT_SESSIONS: bool = True

     def __init__(self):
         super().__init__()
         self._state = State(episode_id=str(uuid4()), step_count=0)
```

```diff
@@ -73,6 +92,13 @@ class APIDebugEnvironment(Environment):
         self.rng = random.Random()
         # For wrong_http_method error: the method shown to the agent
         self.shown_http_method = ""

     def reset(
         self,
```

```diff
@@ -86,7 +112,7 @@ class APIDebugEnvironment(Environment):
         Args:
             seed: Random seed for reproducible episodes.
             episode_id: Custom episode identifier.
-            task: Difficulty level (easy, medium, hard).
         """
         # Initialize RNG
         if seed is not None:
```

```diff
@@ -94,8 +120,11 @@ class APIDebugEnvironment(Environment):
         else:
             self.rng = random.Random()

-        # Validate task
-
         config = TASK_CONFIG[self.task]
         self.max_steps = config["max_steps"]
         self.current_step = 0
```

```diff
@@ -113,14 +142,70 @@ class APIDebugEnvironment(Environment):
         valid_request = copy.deepcopy(self.spec["valid_example"])
         valid_headers = copy.deepcopy(self.spec["required_headers"])

         # Inject errors based on difficulty
         if self.task == "hard":
             error_count = self.rng.randint(config["min_errors"], config["max_errors"])
             self.broken_request, self.broken_headers, self.ground_truths = (
                 inject_multiple_errors(
                     valid_request, valid_headers, self.spec, self.rng, error_count
                 )
             )
         else:
             error_type = self.rng.choice(ERROR_TYPES)
             self.broken_request, self.broken_headers, gt = inject_error(
```

```diff
@@ -181,8 +266,14 @@ class APIDebugEnvironment(Environment):
         # Grade based on task type
         if self.task == "easy":
             raw_score, feedback = self._grade_easy(action)
         elif self.task == "medium":
             raw_score, feedback = self._grade_medium(action)
         else:
             raw_score, feedback = self._grade_hard(action)

```

```diff
@@ -205,6 +296,9 @@
             self.episode_done = True
             # Return best reward achieved during the episode
             reward = self.best_reward

         return self._make_observation(
             feedback=feedback,
```

```diff
@@ -216,6 +310,18 @@
     def state(self) -> State:
         return self._state

     # =====================================================================
     # Grading methods
     # =====================================================================
```

```diff
@@ -261,6 +367,61 @@

         return round(score, 4), "; ".join(parts)

     def _grade_medium(self, action: APIDebugAction) -> Tuple[float, str]:
         """Grade request fix. Fully deterministic per-field validation.

```

```diff
@@ -304,6 +465,97 @@

         return round(total_score, 4), feedback

     def _grade_hard(self, action: APIDebugAction) -> Tuple[float, str]:
         """Grade fix + explanation. 70% deterministic fix, 30% explanation.

```

```diff
@@ -453,7 +705,7 @@
         remaining = self.max_steps - self.current_step
         msg = f"{remaining} step(s) remaining. Use the feedback to improve."

-
             task=self.task,
             api_name=self.spec.get("api_name", ""),
             http_method=self.shown_http_method,
```

```diff
@@ -469,3 +721,8 @@
             done=done,
             reward=reward,
         )
```
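The `_grade_easy` hunk above pairs with the scoring formula stated in the README: 0.6 × type match + 0.4 × field match. A sketch of that formula under the assumption that field match is a Jaccard overlap; the real grader's exact field-matching rule may differ:

```python
def grade_easy(pred_type, pred_fields, true_type, true_fields):
    """Sketch of the easy-task score: 0.6 * type_match + 0.4 * fields_match.

    Assumption: fields_match is taken as Jaccard overlap between the
    predicted and ground-truth field sets; the real grader may differ.
    """
    type_score = 1.0 if pred_type == true_type else 0.0
    pred, true = set(pred_fields or []), set(true_fields or [])
    if not pred and not true:
        fields_score = 1.0
    else:
        fields_score = len(pred & true) / len(pred | true)
    return round(0.6 * type_score + 0.4 * fields_score, 4)
```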
```diff
 Core environment for the API Debug Environment.

 Implements the OpenEnv Environment interface with:
+- 5 task difficulty levels (easy, classify, medium, headers, hard)
 - Multi-turn episodes with structured feedback
+- Deterministic grading for easy/classify/medium/headers, LLM-as-judge for hard
 - Step reward decay to encourage efficient debugging
+- Auto-curriculum (task="auto") that promotes based on rolling reward
 """

 import copy
```

```diff
 from .api_specs import get_random_spec
 from .error_injectors import (
     ERROR_TYPES,
+    HEADER_ERROR_TYPES,
+    inject_chained_errors,
     inject_error,
     inject_multiple_errors,
 )
+from .response_specs import get_random_response_template, inject_response_issues
 from .validators import (
     validate_field_type,
     validate_headers_against_spec,
```

```diff
 # Task configuration: max steps and error count per difficulty
 TASK_CONFIG = {
     "easy": {"max_steps": 3, "error_count": 1},
+    "classify": {"max_steps": 4, "min_errors": 2, "max_errors": 3},
     "medium": {"max_steps": 5, "error_count": 1},
+    "headers": {"max_steps": 4, "error_count": 1},
     "hard": {"max_steps": 7, "min_errors": 2, "max_errors": 3},
+    "response": {"max_steps": 4, "min_issues": 1, "max_issues": 2},
 }

```

```diff

     SUPPORTS_CONCURRENT_SESSIONS: bool = True

+    # Curriculum thresholds for task="auto" mode
+    # When rolling avg reward exceeds threshold, promote to next task
+    AUTO_CURRICULUM = {
+        "easy": {"next": "classify", "threshold": 0.7},
+        "classify": {"next": "medium", "threshold": 0.6},
+        "medium": {"next": "headers", "threshold": 0.6},
+        "headers": {"next": "response", "threshold": 0.5},
+        "response": {"next": "hard", "threshold": 0.5},
+        "hard": {"next": None, "threshold": None},
+    }
+    AUTO_WINDOW = 10
+
     def __init__(self):
         super().__init__()
         self._state = State(episode_id=str(uuid4()), step_count=0)
```

```diff
         self.rng = random.Random()
         # For wrong_http_method error: the method shown to the agent
         self.shown_http_method = ""
+        # Response task state
+        self.response_body: Dict[str, Any] = {}
+        self.response_status_code: int = 0
+        self.response_template: Dict[str, Any] = {}
```
|
| 99 |
+
# Curriculum state for task="auto"
|
| 100 |
+
self._auto_task = "easy"
|
| 101 |
+
self._auto_rewards: List[float] = []
|
| 102 |
|
| 103 |
def reset(
|
| 104 |
self,
|
|
|
|
| 112 |
Args:
|
| 113 |
seed: Random seed for reproducible episodes.
|
| 114 |
episode_id: Custom episode identifier.
|
| 115 |
+
task: Difficulty level (easy, classify, medium, headers, hard, auto).
|
| 116 |
"""
|
| 117 |
# Initialize RNG
|
| 118 |
if seed is not None:
|
|
|
|
| 120 |
else:
|
| 121 |
self.rng = random.Random()
|
| 122 |
|
| 123 |
+
# Validate task -- "auto" uses curriculum to pick difficulty
|
| 124 |
+
if task == "auto":
|
| 125 |
+
self.task = self._auto_task
|
| 126 |
+
else:
|
| 127 |
+
self.task = task if task in TASK_CONFIG else "easy"
|
| 128 |
config = TASK_CONFIG[self.task]
|
| 129 |
self.max_steps = config["max_steps"]
|
| 130 |
self.current_step = 0
|
|
|
|
| 142 |
valid_request = copy.deepcopy(self.spec["valid_example"])
|
| 143 |
valid_headers = copy.deepcopy(self.spec["required_headers"])
|
| 144 |
|
| 145 |
+
# Response task has a completely different setup: broken response, not request
|
| 146 |
+
if self.task == "response":
|
| 147 |
+
issue_count = self.rng.randint(config["min_issues"], config["max_issues"])
|
| 148 |
+
self.response_template = get_random_response_template(self.rng)
|
| 149 |
+
self.response_body, self.response_status_code, self.ground_truths = (
|
| 150 |
+
inject_response_issues(self.response_template, self.rng, issue_count)
|
| 151 |
+
)
|
| 152 |
+
# For response task, the request is correct -- agent examines the response
|
| 153 |
+
self.broken_request = valid_request
|
| 154 |
+
self.broken_headers = valid_headers
|
| 155 |
+
self.shown_http_method = self.spec["http_method"]
|
| 156 |
+
error_count = len(self.ground_truths)
|
| 157 |
+
return APIDebugObservation(
|
| 158 |
+
task=self.task,
|
| 159 |
+
api_name=self.spec["api_name"],
|
| 160 |
+
http_method=self.shown_http_method,
|
| 161 |
+
endpoint=self.spec["endpoint"],
|
| 162 |
+
broken_request=json.dumps(self.broken_request, indent=2),
|
| 163 |
+
broken_headers=self.broken_headers,
|
| 164 |
+
api_spec=self._build_spec_string(),
|
| 165 |
+
response_body=json.dumps(self.response_body, indent=2),
|
| 166 |
+
response_status_code=self.response_status_code,
|
| 167 |
+
error_count=error_count,
|
| 168 |
+
step_number=0,
|
| 169 |
+
max_steps=self.max_steps,
|
| 170 |
+
feedback="",
|
| 171 |
+
message=(
|
| 172 |
+
f"Validate the response from {self.shown_http_method} {self.spec['endpoint']}. "
|
| 173 |
+
f"The response has {error_count} issue(s). "
|
| 174 |
+
f"You have {self.max_steps} steps."
|
| 175 |
+
),
|
| 176 |
+
done=False,
|
| 177 |
+
reward=0.0,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
# Inject errors based on difficulty
|
| 181 |
if self.task == "hard":
|
| 182 |
+
error_count = self.rng.randint(config["min_errors"], config["max_errors"])
|
| 183 |
+
# 50% chance of chained errors (header gate + body errors)
|
| 184 |
+
if self.rng.random() < 0.5:
|
| 185 |
+
self.broken_request, self.broken_headers, self.ground_truths = (
|
| 186 |
+
inject_chained_errors(
|
| 187 |
+
valid_request, valid_headers, self.spec, self.rng, error_count
|
| 188 |
+
)
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
self.broken_request, self.broken_headers, self.ground_truths = (
|
| 192 |
+
inject_multiple_errors(
|
| 193 |
+
valid_request, valid_headers, self.spec, self.rng, error_count
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
+
elif self.task == "classify":
|
| 197 |
error_count = self.rng.randint(config["min_errors"], config["max_errors"])
|
| 198 |
self.broken_request, self.broken_headers, self.ground_truths = (
|
| 199 |
inject_multiple_errors(
|
| 200 |
valid_request, valid_headers, self.spec, self.rng, error_count
|
| 201 |
)
|
| 202 |
)
|
| 203 |
+
elif self.task == "headers":
|
| 204 |
+
error_type = self.rng.choice(HEADER_ERROR_TYPES)
|
| 205 |
+
self.broken_request, self.broken_headers, gt = inject_error(
|
| 206 |
+
error_type, valid_request, valid_headers, self.spec, self.rng
|
| 207 |
+
)
|
| 208 |
+
self.ground_truths = [gt]
|
| 209 |
else:
|
| 210 |
error_type = self.rng.choice(ERROR_TYPES)
|
| 211 |
self.broken_request, self.broken_headers, gt = inject_error(
|
|
|
|
         # Grade based on task type
         if self.task == "easy":
             raw_score, feedback = self._grade_easy(action)
+        elif self.task == "classify":
+            raw_score, feedback = self._grade_classify(action)
         elif self.task == "medium":
             raw_score, feedback = self._grade_medium(action)
+        elif self.task == "headers":
+            raw_score, feedback = self._grade_headers(action)
+        elif self.task == "response":
+            raw_score, feedback = self._grade_response(action)
         else:
             raw_score, feedback = self._grade_hard(action)

             self.episode_done = True
             # Return best reward achieved during the episode
             reward = self.best_reward
+            # Track for auto-curriculum promotion
+            self._auto_rewards.append(reward)
+            self._maybe_auto_promote()

         return self._make_observation(
             feedback=feedback,

     def state(self) -> State:
         return self._state

+    def _maybe_auto_promote(self):
+        """Check if auto-curriculum should promote to next difficulty."""
+        config = self.AUTO_CURRICULUM.get(self._auto_task)
+        if not config or config["next"] is None or config["threshold"] is None:
+            return
+        if len(self._auto_rewards) < self.AUTO_WINDOW:
+            return
+        avg = sum(self._auto_rewards[-self.AUTO_WINDOW:]) / self.AUTO_WINDOW
+        if avg >= config["threshold"]:
+            self._auto_task = config["next"]
+            self._auto_rewards.clear()
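The rolling-window promotion in `_maybe_auto_promote` is easy to sanity-check in isolation. A minimal standalone sketch — the `maybe_promote` helper and the two-entry dict are illustrative, not part of the module:

```python
# Standalone sketch of the rolling-window promotion logic.
# AUTO_CURRICULUM / AUTO_WINDOW mirror the class attributes in the diff;
# maybe_promote is a hypothetical free function for illustration only.
AUTO_CURRICULUM = {
    "easy": {"next": "classify", "threshold": 0.7},
    "classify": {"next": "medium", "threshold": 0.6},
}
AUTO_WINDOW = 10

def maybe_promote(task, rewards):
    cfg = AUTO_CURRICULUM.get(task)
    # Not enough episodes yet, or nothing to promote to: stay put
    if not cfg or cfg["next"] is None or len(rewards) < AUTO_WINDOW:
        return task, rewards
    if sum(rewards[-AUTO_WINDOW:]) / AUTO_WINDOW >= cfg["threshold"]:
        return cfg["next"], []  # promote and reset the reward window
    return task, rewards

task, history = maybe_promote("easy", [0.8] * 10)  # avg 0.8 >= 0.7 -> "classify"
```

Note the window reset on promotion: the agent must re-earn the rolling average at the new difficulty rather than coasting on old rewards.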
     # =====================================================================
     # Grading methods
     # =====================================================================

         return round(score, 4), "; ".join(parts)

+    def _grade_classify(self, action: APIDebugAction) -> Tuple[float, str]:
+        """Grade multi-error classification. Fully deterministic.
+
+        Like easy but the agent must identify ALL error types and ALL
+        affected fields across multiple injected errors.
+
+        Scoring: 0.6 for error types (Jaccard) + 0.4 for affected fields (Jaccard).
+        Accepts either error_types (list) or error_type (single) from the agent.
+        """
+        score = 0.0
+        parts = []
+
+        gt_types = {gt["error_type"] for gt in self.ground_truths}
+        gt_fields: set = set()
+        for gt in self.ground_truths:
+            gt_fields.update(gt.get("affected_fields", []))
+
+        # Accept error_types (list) or fall back to error_type (single)
+        agent_types = set(action.error_types or [])
+        if not agent_types and action.error_type:
+            agent_types = {action.error_type}
+
+        # Error types Jaccard (0.6 weight)
+        if gt_types and agent_types:
+            intersection = gt_types & agent_types
+            union = gt_types | agent_types
+            jaccard = len(intersection) / len(union) if union else 0.0
+            score += 0.6 * jaccard
+            parts.append(
+                f"error_types: {len(intersection)}/{len(gt_types)} correct, "
+                f"{len(agent_types - gt_types)} extra"
+            )
+        elif not agent_types:
+            parts.append("error_types: MISSING (none provided)")
+        else:
+            parts.append("error_types: INCORRECT (0 matches)")
+
+        # Affected fields Jaccard (0.4 weight)
+        agent_fields = set(action.affected_fields or [])
+        if gt_fields and agent_fields:
+            intersection = gt_fields & agent_fields
+            union = gt_fields | agent_fields
+            jaccard = len(intersection) / len(union) if union else 0.0
+            score += 0.4 * jaccard
+            parts.append(
+                f"affected_fields: {len(intersection)}/{len(gt_fields)} correct, "
+                f"{len(agent_fields - gt_fields)} extra"
+            )
+        elif not agent_fields:
+            parts.append("affected_fields: MISSING (none provided)")
+        else:
+            parts.append("affected_fields: INCORRECT (0 matches)")
+
+        return round(score, 4), "; ".join(parts)
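As a sanity check on the 0.6/0.4 Jaccard weighting used by `_grade_classify`, here is a standalone sketch — `classify_score` is a hypothetical helper for illustration, not part of the environment:

```python
def classify_score(gt_types, agent_types, gt_fields, agent_fields):
    # Jaccard = |intersection| / |union|, as in _grade_classify
    def jaccard(a, b):
        union = a | b
        return len(a & b) / len(union) if union else 0.0
    return round(0.6 * jaccard(gt_types, agent_types)
                 + 0.4 * jaccard(gt_fields, agent_fields), 4)

# One of two error types found, both affected fields found:
# 0.6 * (1/2) + 0.4 * 1.0 = 0.7
score = classify_score(
    {"missing_required_field", "wrong_field_type"}, {"missing_required_field"},
    {"email", "age"}, {"email", "age"},
)
```

Because Jaccard penalizes both misses and extras, an agent cannot inflate its score by listing every known error type: spurious guesses grow the union and shrink the ratio.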
     def _grade_medium(self, action: APIDebugAction) -> Tuple[float, str]:
         """Grade request fix. Fully deterministic per-field validation.

         return round(total_score, 4), feedback

+    def _grade_headers(self, action: APIDebugAction) -> Tuple[float, str]:
+        """Grade header fix. Fully deterministic.
+
+        The agent must provide corrected headers that match the spec.
+        Also awards partial credit for identifying the error type.
+
+        Scoring: 0.7 for correct headers + 0.3 for error type identification.
+        """
+        score = 0.0
+        parts = []
+
+        # Error type identification (0.3 weight)
+        gt_types = {gt["error_type"] for gt in self.ground_truths}
+        if action.error_type and action.error_type in gt_types:
+            score += 0.3
+            parts.append("error_type: CORRECT")
+        else:
+            given = action.error_type or "(none)"
+            parts.append(f"error_type: INCORRECT (you said '{given}')")
+
+        # Header fix validation (0.7 weight)
+        if action.fixed_headers:
+            header_score, header_feedback = validate_headers_against_spec(
+                action.fixed_headers, self.spec
+            )
+            score += 0.7 * header_score
+            parts.append(header_feedback)
+        else:
+            parts.append("Headers: NOT PROVIDED (header fix needed)")
+
+        return round(score, 4), "; ".join(parts)

+    def _grade_response(self, action: APIDebugAction) -> Tuple[float, str]:
+        """Grade response validation. Fully deterministic.
+
+        Agent must identify issue types and, for wrong_status_code, provide
+        the correct status code.
+
+        Scoring: 0.5 for issue type identification (Jaccard) +
+                 0.3 for affected field identification (Jaccard) +
+                 0.2 for correct status code (if applicable).
+        """
+        score = 0.0
+        parts = []
+
+        gt_issue_types = {gt["issue_type"] for gt in self.ground_truths}
+        gt_fields = {gt.get("affected_field", "") for gt in self.ground_truths} - {""}
+
+        # Issue type identification (0.5 weight)
+        predicted_issues = set(action.response_issues or [])
+        if predicted_issues and gt_issue_types:
+            intersection = predicted_issues & gt_issue_types
+            union = predicted_issues | gt_issue_types
+            jaccard = len(intersection) / len(union) if union else 0.0
+            score += 0.5 * jaccard
+            parts.append(
+                f"Issue types: {len(intersection)}/{len(gt_issue_types)} correct "
+                f"(Jaccard={jaccard:.2f})"
+            )
+        else:
+            parts.append(
+                "Issue types: NOT PROVIDED" if not predicted_issues
+                else "Issue types: NONE CORRECT"
+            )
+
+        # Affected field identification via affected_fields (0.3 weight)
+        predicted_fields = set(action.affected_fields or [])
+        if predicted_fields and gt_fields:
+            intersection = predicted_fields & gt_fields
+            union = predicted_fields | gt_fields
+            jaccard = len(intersection) / len(union) if union else 0.0
+            score += 0.3 * jaccard
+            parts.append(f"Affected fields: {len(intersection)}/{len(gt_fields)} correct")
+        else:
+            parts.append(
+                "Affected fields: NOT PROVIDED" if not predicted_fields
+                else "Affected fields: NONE CORRECT"
+            )
+
+        # Status code check (0.2 weight) -- only if wrong_status_code is a ground truth
+        has_status_issue = any(
+            gt["issue_type"] == "wrong_status_code" for gt in self.ground_truths
+        )
+        if has_status_issue:
+            correct_status = None
+            for gt in self.ground_truths:
+                if gt["issue_type"] == "wrong_status_code":
+                    correct_status = int(gt.get("correct_value", 0))
+                    break
+            if action.expected_status_code and action.expected_status_code == correct_status:
+                score += 0.2
+                parts.append(f"Status code: CORRECT ({correct_status})")
+            else:
+                given = action.expected_status_code or "(none)"
+                parts.append(
+                    f"Status code: INCORRECT (you said {given}, expected {correct_status})"
+                )
+        else:
+            # No status code issue -- redistribute 0.2 to issue types
+            score += 0.2 * (
+                len(predicted_issues & gt_issue_types) / len(gt_issue_types)
+                if gt_issue_types else 0.0
+            )
+            parts.append("Status code: N/A (no status code issue)")
+
+        return round(score, 4), "; ".join(parts)

     def _grade_hard(self, action: APIDebugAction) -> Tuple[float, str]:
         """Grade fix + explanation. 70% deterministic fix, 30% explanation.

         remaining = self.max_steps - self.current_step
         msg = f"{remaining} step(s) remaining. Use the feedback to improve."

+        obs = APIDebugObservation(
             task=self.task,
             api_name=self.spec.get("api_name", ""),
             http_method=self.shown_http_method,

             done=done,
             reward=reward,
         )
+        # Include response data for response task
+        if self.task == "response":
+            obs.response_body = json.dumps(self.response_body, indent=2)
+            obs.response_status_code = self.response_status_code
+        return obs
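The 0.2 status-code weight in `_grade_response` is redistributed to issue-type accuracy when no status-code issue was injected, so the maximum score stays 1.0 in both cases. A standalone sketch of just that branch — `status_weight` is a hypothetical helper, not part of the environment:

```python
def status_weight(predicted, gt_types, has_status_issue, code_ok):
    # Mirrors the 0.2-weight branch of _grade_response:
    # with a status issue, the weight hinges on the agent's status code;
    # without one, it is redistributed to issue-type recall.
    if has_status_issue:
        return 0.2 if code_ok else 0.0
    if not gt_types:
        return 0.0
    return 0.2 * len(predicted & gt_types) / len(gt_types)

# No status issue injected; all issue types identified -> full 0.2 back
w = status_weight({"missing_response_field"}, {"missing_response_field"}, False, False)
```

Without the redistribution, episodes that happen not to include a `wrong_status_code` issue would cap out at 0.8, skewing rewards across episodes.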
server/error_injectors.py
CHANGED

@@ -324,10 +324,178 @@ def inject_datetime_format_error(
     )

 # =========================================================================
 # Registry and helpers
 # =========================================================================

 ERROR_TYPES = [
     "missing_required_field",
     "wrong_field_type",

@@ -339,6 +507,11 @@ ERROR_TYPES = [
     "malformed_json_value",
     "invalid_enum_value",
     "datetime_format_error",
 ]

 INJECTOR_MAP = {

@@ -352,6 +525,11 @@ INJECTOR_MAP = {
     "malformed_json_value": inject_malformed_json_value,
     "invalid_enum_value": inject_invalid_enum_value,
     "datetime_format_error": inject_datetime_format_error,
 }

@@ -367,6 +545,89 @@ def inject_error(
     return injector(request, headers, spec, rng)

 def inject_multiple_errors(
     request: Dict[str, Any],
     headers: Dict[str, str],
     )

+# =========================================================================
+# 11. wrong_content_type
+# =========================================================================
+
+def inject_wrong_content_type(
+    request: Dict[str, Any],
+    headers: Dict[str, str],
+    spec: Dict[str, Any],
+    rng: random_module.Random,
+) -> InjectorResult:
+    """Change Content-Type to an incorrect value."""
+    broken_headers = copy.deepcopy(headers)
+    wrong_types = [
+        "text/plain",
+        "application/xml",
+        "multipart/form-data",
+        "text/html",
+        "application/x-www-form-urlencoded",
+    ]
+    # Overwrite (or add) Content-Type with an incorrect value
+    broken_headers["Content-Type"] = rng.choice(wrong_types)
+    return request, broken_headers, _ground_truth(
+        "wrong_content_type", ["Content-Type"], request, headers
+    )
+
+
+# =========================================================================
+# 12. expired_auth_token
+# =========================================================================
+
+def inject_expired_token(
+    request: Dict[str, Any],
+    headers: Dict[str, str],
+    spec: Dict[str, Any],
+    rng: random_module.Random,
+) -> InjectorResult:
+    """Replace the Authorization token with an expired/malformed one."""
+    broken_headers = copy.deepcopy(headers)
+    bad_tokens = [
+        "Bearer expired_token_abc123",
+        "Bearer ",
+        "Basic dXNlcjpwYXNz",
+        "Token invalid",
+        "Bearer eyJhbGciOiJub25lIn0.e30.",
+    ]
+    if "Authorization" in broken_headers:
+        broken_headers["Authorization"] = rng.choice(bad_tokens)
+        return request, broken_headers, _ground_truth(
+            "expired_auth_token", ["Authorization"], request, headers
+        )
+    # If no auth header in spec, inject wrong content type instead
+    return inject_wrong_content_type(request, headers, spec, rng)
+
+
+# =========================================================================
+# 13. wrong_status_code (for response validation / chained scenarios)
+# =========================================================================
+
+def inject_wrong_status_code(
+    request: Dict[str, Any],
+    headers: Dict[str, str],
+    spec: Dict[str, Any],
+    rng: random_module.Random,
+) -> InjectorResult:
+    """Record that the wrong HTTP status code would be returned.
+
+    Simulates a server returning an unexpected status code.
+    The ground truth stores the wrong code and the expected code.
+    """
+    correct_status = 200 if spec["http_method"] == "GET" else 201
+    wrong_codes = [
+        (301, "Moved Permanently - resource redirected"),
+        (302, "Found - temporary redirect to different endpoint"),
+        (400, "Bad Request - but request is actually valid"),
+        (403, "Forbidden - incorrect permissions applied"),
+        (404, "Not Found - wrong endpoint routing"),
+        (429, "Too Many Requests - rate limit misconfigured"),
+        (500, "Internal Server Error - server-side issue"),
+        (502, "Bad Gateway - upstream service down"),
+        (503, "Service Unavailable - maintenance mode"),
+    ]
+    wrong_status, description = rng.choice(wrong_codes)
+    gt = _ground_truth("wrong_status_code", ["status_code"], request, headers)
+    gt["wrong_status"] = wrong_status
+    gt["correct_status"] = correct_status
+    gt["description"] = description
+    return request, headers, gt
+
+
+# =========================================================================
+# 14. redirect_loop
+# =========================================================================
+
+def inject_redirect_loop(
+    request: Dict[str, Any],
+    headers: Dict[str, str],
+    spec: Dict[str, Any],
+    rng: random_module.Random,
+) -> InjectorResult:
+    """Simulate a redirect chain issue.
+
+    The agent must identify that the endpoint redirects and provide
+    the correct target endpoint.
+    """
+    redirect_scenarios = [
+        {
+            "from": spec["endpoint"],
+            "to": spec["endpoint"].rstrip("/") + "/v2",
+            "reason": "API version upgrade - v1 redirects to v2",
+        },
+        {
+            "from": spec["endpoint"],
+            "to": spec["endpoint"].replace("/api/", "/api/v2/"),
+            "reason": "Base path migration",
+        },
+        {
+            "from": spec["endpoint"],
+            "to": spec["endpoint"] + "?format=json",
+            "reason": "Content negotiation redirect",
+        },
+    ]
+    scenario = rng.choice(redirect_scenarios)
+    gt = _ground_truth("redirect_loop", ["endpoint"], request, headers)
+    gt["redirect_from"] = scenario["from"]
+    gt["redirect_to"] = scenario["to"]
+    gt["reason"] = scenario["reason"]
+    return request, headers, gt
+
+
+# =========================================================================
+# 15. rate_limit_headers
+# =========================================================================
+
+def inject_rate_limit_headers(
+    request: Dict[str, Any],
+    headers: Dict[str, str],
+    spec: Dict[str, Any],
+    rng: random_module.Random,
+) -> InjectorResult:
+    """Inject missing or wrong rate limit headers.
+
+    Real APIs return headers like X-RateLimit-Limit and Retry-After.
+    The agent must identify the rate limiting issue and provide correct headers.
+    """
+    broken_headers = copy.deepcopy(headers)
+    # Add rate limit headers that indicate the client is being throttled
+    broken_headers["X-RateLimit-Remaining"] = "0"
+    broken_headers["X-RateLimit-Reset"] = "1712000000"
+    broken_headers["Retry-After"] = "60"
+
+    gt = _ground_truth(
+        "rate_limit_headers",
+        ["X-RateLimit-Remaining", "Retry-After"],
+        request, headers,
+    )
+    gt["issue"] = "Client is rate-limited, must wait or reduce request frequency"
+    return request, broken_headers, gt
+
+
 # =========================================================================
 # Registry and helpers
 # =========================================================================

+# Header-only error types (used by the headers task)
+HEADER_ERROR_TYPES = [
+    "missing_auth_header",
+    "wrong_content_type",
+    "expired_auth_token",
+]
+
 ERROR_TYPES = [
     "missing_required_field",
     "wrong_field_type",

     "malformed_json_value",
     "invalid_enum_value",
     "datetime_format_error",
+    "wrong_content_type",
+    "expired_auth_token",
+    "wrong_status_code",
+    "redirect_loop",
+    "rate_limit_headers",
 ]

 INJECTOR_MAP = {

     "malformed_json_value": inject_malformed_json_value,
     "invalid_enum_value": inject_invalid_enum_value,
     "datetime_format_error": inject_datetime_format_error,
+    "wrong_content_type": inject_wrong_content_type,
+    "expired_auth_token": inject_expired_token,
+    "wrong_status_code": inject_wrong_status_code,
+    "redirect_loop": inject_redirect_loop,
+    "rate_limit_headers": inject_rate_limit_headers,
 }

     return injector(request, headers, spec, rng)

+# Chain patterns for realistic multi-step debugging scenarios
+CHAIN_PATTERNS = [
+    # Pattern 1: Auth gate -> body errors
+    # Real-world: API returns 401 first, body validation only runs after auth passes
+    {
+        "name": "auth_gate",
+        "gate_types": ["missing_auth_header", "expired_auth_token"],
+        "body_pool": None,  # uses all body types
+    },
+    # Pattern 2: Content-type gate -> type mismatches
+    # Real-world: Wrong Content-Type causes the parser to misinterpret the body
+    {
+        "name": "content_type_gate",
+        "gate_types": ["wrong_content_type"],
+        "body_pool": ["wrong_field_type", "malformed_json_value", "invalid_enum_value"],
+    },
+    # Pattern 3: Method + endpoint chain
+    # Real-world: Wrong method returns 405, then wrong fields for the correct method
+    {
+        "name": "method_chain",
+        "gate_types": ["wrong_http_method"],
+        "body_pool": ["missing_required_field", "extra_unknown_field", "null_value_in_required"],
+    },
+    # Pattern 4: Rate limit + auth
+    # Real-world: Rate limited, and when retrying the token has expired
+    {
+        "name": "rate_limit_chain",
+        "gate_types": ["rate_limit_headers"],
+        "body_pool": ["expired_auth_token", "missing_required_field"],
+    },
+    # Pattern 5: Redirect + body errors
+    # Real-world: Endpoint moved, client follows redirect but sends wrong body format
+    {
+        "name": "redirect_chain",
+        "gate_types": ["redirect_loop"],
+        "body_pool": ["wrong_field_type", "datetime_format_error", "invalid_email_format"],
+    },
+]
+
+
+def inject_chained_errors(
+    request: Dict[str, Any],
+    headers: Dict[str, str],
+    spec: Dict[str, Any],
+    rng: random_module.Random,
+    count: int = 2,
+) -> Tuple[Dict[str, Any], Dict[str, str], List[GroundTruth]]:
+    """Inject errors in a realistic dependency chain.
+
+    Picks a chain pattern, injects the gate error first, then body errors.
+    Ground truths are ordered: gate errors first, body errors second.
+    This ordering lets the environment progressively reveal errors.
+    """
+    broken_req = copy.deepcopy(request)
+    broken_hdrs = copy.deepcopy(headers)
+    chain: List[GroundTruth] = []
+
+    # Pick a random chain pattern
+    pattern = rng.choice(CHAIN_PATTERNS)
+
+    # Inject the gate error
+    gate_type = rng.choice(pattern["gate_types"])
+    injector = INJECTOR_MAP[gate_type]
+    broken_req, broken_hdrs, gt = injector(broken_req, broken_hdrs, spec, rng)
+    chain.append(gt)
+
+    # Inject body errors from the pattern's pool (or all body types)
+    body_pool = pattern["body_pool"]
+    if body_pool is None:
+        body_pool = [t for t in ERROR_TYPES if t not in HEADER_ERROR_TYPES
+                     and t not in ("wrong_status_code", "redirect_loop", "rate_limit_headers")]
+
+    body_count = max(1, count - 1)
+    available = [t for t in body_pool if t in INJECTOR_MAP]
+    chosen = rng.sample(available, min(body_count, len(available)))
+    for err_type in chosen:
+        injector = INJECTOR_MAP[err_type]
+        broken_req, broken_hdrs, gt = injector(broken_req, broken_hdrs, spec, rng)
+        chain.append(gt)
+
+    return broken_req, broken_hdrs, chain
+
+
 def inject_multiple_errors(
     request: Dict[str, Any],
     headers: Dict[str, str],
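The gate-first ordering in `inject_chained_errors` can be illustrated without the real injectors. A minimal standalone sketch — the pattern literal and variable names are illustrative, not the module's data:

```python
import random

# One "gate" error always heads the chain; body errors follow from the pool.
pattern = {
    "gate_types": ["wrong_content_type"],
    "body_pool": ["wrong_field_type", "invalid_enum_value"],
}
rng = random.Random(42)
chain = [rng.choice(pattern["gate_types"])]   # the gate error comes first
chain += rng.sample(pattern["body_pool"], 1)  # then count - 1 body errors
# chain[0] is the gate type regardless of seed, matching the ground-truth order
```

Because the ground truths preserve this order, the environment can reveal the gate error first and hold back body errors until the gate is fixed, mirroring how a real API short-circuits validation.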
server/response_specs.py
ADDED
@@ -0,0 +1,268 @@
+"""
+Response templates for the response validation task.
+
+Each template defines what a correct API response looks like for a given
+request type. The environment generates a broken response by injecting
+issues (wrong status code, missing fields, wrong types, extra fields).
+
+Response issue types:
+- wrong_status_code: Response has incorrect HTTP status code
+- missing_response_field: Required field missing from response body
+- wrong_response_type: Field present but wrong data type
+- extra_response_field: Unexpected field in response (data leak risk)
+- inconsistent_error_format: Error response doesn't follow spec format
+"""
+
+import copy
+import random
+from typing import Any, Dict, List, Tuple
+
+# Response issue types the agent must identify
+RESPONSE_ISSUE_TYPES = [
+    "wrong_status_code",
+    "missing_response_field",
+    "wrong_response_type",
+    "extra_response_field",
+    "inconsistent_error_format",
+]
+
+
+# Maps API operation type to expected success response
+RESPONSE_TEMPLATES = [
+    {
+        "name": "Create Resource",
+        "success_status": 201,
+        "success_body": {
+            "id": "res_abc123",
+            "status": "created",
+            "created_at": "2025-01-15T10:30:00Z",
+        },
+        "required_response_fields": ["id", "status", "created_at"],
+        "field_types": {"id": "string", "status": "string", "created_at": "string"},
+        "error_status": 400,
+        "error_body": {
+            "error": {"code": "VALIDATION_ERROR", "message": "Invalid input", "details": []},
+        },
+    },
+    {
+        "name": "List Resources",
+        "success_status": 200,
+        "success_body": {
+            "data": [{"id": "res_1", "name": "Item 1"}, {"id": "res_2", "name": "Item 2"}],
+            "total": 2,
+            "page": 1,
+            "per_page": 20,
+        },
+        "required_response_fields": ["data", "total", "page", "per_page"],
+        "field_types": {"data": "array", "total": "integer", "page": "integer", "per_page": "integer"},
+        "error_status": 401,
+        "error_body": {
+            "error": {"code": "UNAUTHORIZED", "message": "Invalid API key", "details": []},
+        },
+    },
+    {
+        "name": "Update Resource",
+        "success_status": 200,
+        "success_body": {
+            "id": "res_abc123",
+            "status": "updated",
+            "updated_at": "2025-01-15T12:00:00Z",
+        },
+        "required_response_fields": ["id", "status", "updated_at"],
+        "field_types": {"id": "string", "status": "string", "updated_at": "string"},
+        "error_status": 404,
+        "error_body": {
+            "error": {"code": "NOT_FOUND", "message": "Resource not found", "details": []},
+        },
+    },
+    {
+        "name": "Delete Resource",
+        "success_status": 204,
+        "success_body": {},
+        "required_response_fields": [],
+        "field_types": {},
+        "error_status": 403,
+        "error_body": {
+            "error": {"code": "FORBIDDEN", "message": "Insufficient permissions", "details": []},
+        },
+    },
+    {
+        "name": "Batch Operation",
+        "success_status": 200,
+        "success_body": {
+            "processed": 5,
+            "failed": 0,
+            "results": [
+                {"id": "item_1", "status": "success"},
+                {"id": "item_2", "status": "success"},
+            ],
+        },
+        "required_response_fields": ["processed", "failed", "results"],
+        "field_types": {"processed": "integer", "failed": "integer", "results": "array"},
+        "error_status": 422,
+        "error_body": {
+            "error": {"code": "UNPROCESSABLE", "message": "Batch validation failed", "details": []},
+        },
+    },
+    {
+        "name": "Authentication",
+        "success_status": 200,
+        "success_body": {
+            "access_token": "eyJhbGciOiJIUzI1NiJ9.token",
+            "token_type": "Bearer",
+            "expires_in": 3600,
+        },
+        "required_response_fields": ["access_token", "token_type", "expires_in"],
+        "field_types": {"access_token": "string", "token_type": "string", "expires_in": "integer"},
+        "error_status": 401,
+        "error_body": {
+            "error": {"code": "INVALID_CREDENTIALS", "message": "Bad credentials", "details": []},
+        },
+    },
+    {
+        "name": "File Upload",
+        "success_status": 201,
+        "success_body": {
+            "file_id": "file_xyz789",
+            "filename": "report.pdf",
+            "size_bytes": 1048576,
+            "url": "https://cdn.example.com/files/file_xyz789",
+        },
+        "required_response_fields": ["file_id", "filename", "size_bytes", "url"],
+        "field_types": {"file_id": "string", "filename": "string", "size_bytes": "integer", "url": "string"},
+        "error_status": 413,
+        "error_body": {
+            "error": {"code": "PAYLOAD_TOO_LARGE", "message": "File exceeds 10MB limit", "details": []},
+        },
+    },
+    {
+        "name": "Search Query",
+        "success_status": 200,
+        "success_body": {
+            "query": "test search",
+            "results": [{"id": "doc_1", "score": 0.95, "title": "Test Document"}],
+            "total_results": 1,
+            "search_time_ms": 42,
+        },
+        "required_response_fields": ["query", "results", "total_results", "search_time_ms"],
+        "field_types": {"query": "string", "results": "array", "total_results": "integer", "search_time_ms": "integer"},
+        "error_status": 400,
+        "error_body": {
+            "error": {"code": "INVALID_QUERY", "message": "Query syntax error", "details": []},
+        },
+    },
+]
+
+
+def get_random_response_template(rng: random.Random) -> Dict[str, Any]:
+    """Pick a random response template."""
+    return copy.deepcopy(rng.choice(RESPONSE_TEMPLATES))
+
+
+def inject_response_issues(
+    template: Dict[str, Any],
+    rng: random.Random,
+    issue_count: int = 1,
+) -> Tuple[Dict[str, Any], int, List[Dict[str, str]]]:
+    """Inject issues into a response and return (broken_body, broken_status, ground_truths).
+
+    Each ground truth has: issue_type, description, affected_field (if applicable).
+    """
+    # Decide if we break a success response or an error response
+    use_success = rng.random() < 0.6
+    if use_success:
+        body = copy.deepcopy(template["success_body"])
+        status = template["success_status"]
+    else:
+        body = copy.deepcopy(template["error_body"])
+        status = template["error_status"]
+
+    ground_truths: List[Dict[str, str]] = []
+    available_issues = list(RESPONSE_ISSUE_TYPES)
+    rng.shuffle(available_issues)
+
+    injected = 0
+    for issue_type in available_issues:
+        if injected >= issue_count:
+            break
+
+        if issue_type == "wrong_status_code":
+            wrong_codes = [200, 201, 204, 301, 400, 401, 403, 404, 422, 429, 500, 502, 503]
+            wrong_codes = [c for c in wrong_codes if c != status]
+            old_status = status
+            status = rng.choice(wrong_codes)
+            ground_truths.append({
+                "issue_type": "wrong_status_code",
+                "description": f"Expected status {old_status}, got {status}",
+                "affected_field": "status_code",
+                "correct_value": str(old_status),
+            })
+            injected += 1
+
+        elif issue_type == "missing_response_field" and template["required_response_fields"]:
+            fields = list(template["required_response_fields"])
+            rng.shuffle(fields)
+            field = fields[0]
+            if field in body:
+                del body[field]
+                ground_truths.append({
+                    "issue_type": "missing_response_field",
+                    "description": f"Required field '{field}' missing from response",
+                    "affected_field": field,
+                })
+                injected += 1
+
+        elif issue_type == "wrong_response_type" and template["field_types"]:
+            typed_fields = [f for f in template["field_types"] if f in body]
+            if typed_fields:
+                field = rng.choice(typed_fields)
+                original_type = template["field_types"][field]
+                # Replace with wrong type
+                if original_type == "string":
+                    body[field] = 12345
+                elif original_type == "integer":
+                    body[field] = "not_a_number"
+                elif original_type == "array":
+                    body[field] = "should_be_array"
+                else:
+                    body[field] = [1, 2, 3]
+                ground_truths.append({
+                    "issue_type": "wrong_response_type",
+                    "description": f"Field '{field}' should be {original_type}",
+                    "affected_field": field,
+                })
+                injected += 1
+
+        elif issue_type == "extra_response_field":
+            leak_fields = {
+                "internal_id": "int_9f8a2b",
+                "debug_trace": "stack trace at line 42",
+                "db_query": "SELECT * FROM users WHERE id=123",
+                "server_ip": "10.0.0.42",
+                "session_token": "sess_leaked_abc",
+            }
+            field, value = rng.choice(list(leak_fields.items()))
+            body[field] = value
+            ground_truths.append({
+                "issue_type": "extra_response_field",
+                "description": f"Unexpected field '{field}' in response (potential data leak)",
+                "affected_field": field,
+            })
+            injected += 1
+
+        elif issue_type == "inconsistent_error_format" and not use_success:
+            # Break the error format -- flatten it or use wrong keys
+            variants = [
+                {"msg": body.get("error", {}).get("message", "error"), "err_code": "UNKNOWN"},
+                {"error_message": "Something went wrong", "status": "error"},
+                {"errors": [{"msg": "bad request"}]},
+            ]
+            body = rng.choice(variants)
+            ground_truths.append({
+                "issue_type": "inconsistent_error_format",
+                "description": "Error response doesn't follow standard format: {error: {code, message, details}}",
+                "affected_field": "error",
+            })
+            injected += 1
+
+    return body, status, ground_truths
tests/test_environment.py
CHANGED
@@ -182,6 +182,88 @@ class TestEasyGrader:
         assert "CORRECT" in obs.feedback
 
 
+# ---------------------------------------------------------------------------
+# TestClassifyGrader
+# ---------------------------------------------------------------------------
+
+class TestClassifyGrader:
+    """Tests for the classify task: multi-error classification."""
+
+    def test_all_types_all_fields_scores_high(self):
+        env = make_env("classify", seed=42)
+        gt_types = [gt["error_type"] for gt in env.ground_truths]
+        gt_fields = []
+        for gt in env.ground_truths:
+            gt_fields.extend(gt.get("affected_fields", []))
+        obs = env.step(APIDebugAction(
+            error_types=gt_types,
+            affected_fields=gt_fields,
+        ))
+        assert obs.reward >= 0.9
+
+    def test_partial_types_gives_partial_score(self):
+        env = make_env("classify", seed=42)
+        first_type = env.ground_truths[0]["error_type"]
+        obs = env.step(APIDebugAction(error_types=[first_type]))
+        # Got 1 of 2-3 types correct, no fields -> partial score
+        assert 0.1 < obs.reward < 0.7
+
+    def test_empty_action_scores_near_0(self):
+        env = make_env("classify", seed=42)
+        obs = env.step(APIDebugAction())
+        assert obs.reward <= 0.01
+
+    def test_single_error_type_field_accepted(self):
+        """Accepts error_type (singular) as fallback for classify."""
+        env = make_env("classify", seed=42)
+        first_type = env.ground_truths[0]["error_type"]
+        first_fields = env.ground_truths[0].get("affected_fields", [])
+        obs = env.step(APIDebugAction(
+            error_type=first_type,
+            affected_fields=first_fields,
+        ))
+        assert obs.reward > 0.1
+
+    def test_wrong_types_score_low(self):
+        env = make_env("classify", seed=42)
+        obs = env.step(APIDebugAction(
+            error_types=["nonexistent_error"],
+            affected_fields=["fake_field"],
+        ))
+        assert obs.reward <= 0.01
+
+    def test_extra_types_reduces_jaccard(self):
+        env = make_env("classify", seed=42)
+        gt_types = [gt["error_type"] for gt in env.ground_truths]
+        obs = env.step(APIDebugAction(
+            error_types=gt_types + ["extra_fake_error"],
+        ))
+        # Extra type reduces Jaccard, but correct ones still score
+        assert obs.reward > 0.1
+
+    def test_max_steps_is_4(self):
+        env = make_env("classify", seed=42)
+        assert env.max_steps == 4
+
+    def test_multiple_ground_truths(self):
+        env = make_env("classify", seed=42)
+        assert len(env.ground_truths) >= 2
+
+    def test_feedback_contains_type_info(self):
+        env = make_env("classify", seed=42)
+        gt_types = [gt["error_type"] for gt in env.ground_truths]
+        obs = env.step(APIDebugAction(error_types=gt_types))
+        assert "error_types" in obs.feedback
+
+    def test_reward_in_valid_range(self):
+        env = make_env("classify", seed=42)
+        obs = env.step(APIDebugAction(
+            error_types=["missing_required_field"],
+            affected_fields=["email"],
+        ))
+        assert 0.001 <= obs.reward <= 0.999
+
+
 # ---------------------------------------------------------------------------
 # TestMediumGrader
 # ---------------------------------------------------------------------------
@@ -271,6 +353,164 @@ class TestMediumGrader:
         assert 0.0 <= obs.reward <= 1.0
 
 
+# ---------------------------------------------------------------------------
+# TestHeadersGrader
+# ---------------------------------------------------------------------------
+
+class TestHeadersGrader:
+    """Tests for the headers task: header-focused debugging."""
+
+    def test_correct_headers_and_type_scores_high(self):
+        env = make_env("headers", seed=42)
+        gt = env.ground_truths[0]
+        obs = env.step(APIDebugAction(
+            error_type=gt["error_type"],
+            fixed_headers=gt["valid_headers"],
+        ))
+        assert obs.reward >= 0.9
+
+    def test_correct_headers_no_type_scores_partial(self):
+        env = make_env("headers", seed=42)
+        gt = env.ground_truths[0]
+        obs = env.step(APIDebugAction(
+            fixed_headers=gt["valid_headers"],
+        ))
+        # 0.7 for headers, 0 for type = ~0.7
+        assert 0.5 < obs.reward < 0.9
+
+    def test_correct_type_no_headers_scores_low(self):
+        env = make_env("headers", seed=42)
+        gt = env.ground_truths[0]
+        obs = env.step(APIDebugAction(
+            error_type=gt["error_type"],
+        ))
+        # 0.3 for type, 0 for headers
+        assert obs.reward < 0.5
+
+    def test_empty_action_scores_near_0(self):
+        env = make_env("headers", seed=42)
+        obs = env.step(APIDebugAction())
+        assert obs.reward <= 0.01
+
+    def test_max_steps_is_4(self):
+        env = make_env("headers", seed=42)
+        assert env.max_steps == 4
+
+    def test_header_error_type_is_header_related(self):
+        env = make_env("headers", seed=42)
+        header_types = {"missing_auth_header", "wrong_content_type", "expired_auth_token"}
+        for gt in env.ground_truths:
+            assert gt["error_type"] in header_types
+
+    def test_feedback_contains_header_info(self):
+        env = make_env("headers", seed=42)
+        gt = env.ground_truths[0]
+        obs = env.step(APIDebugAction(fixed_headers=gt["valid_headers"]))
+        assert "Header" in obs.feedback or "error_type" in obs.feedback
+
+    def test_wrong_headers_score_low(self):
+        env = make_env("headers", seed=42)
+        obs = env.step(APIDebugAction(
+            fixed_headers={"X-Wrong": "value"},
+        ))
+        assert obs.reward < 0.5
+
+    def test_reward_in_valid_range(self):
+        env = make_env("headers", seed=42)
+        obs = env.step(APIDebugAction(
+            fixed_headers={"Authorization": "Bearer test"},
+        ))
+        assert 0.001 <= obs.reward <= 0.999
+
+    def test_single_ground_truth(self):
+        env = make_env("headers", seed=42)
+        assert len(env.ground_truths) == 1
+
+
+# ---------------------------------------------------------------------------
+# TestResponseGrader
+# ---------------------------------------------------------------------------
+
+class TestResponseGrader:
+    """Tests for the response validation task."""
+
+    def test_correct_issues_scores_high(self):
+        env = make_env("response", seed=42)
+        gt_issues = [gt["issue_type"] for gt in env.ground_truths]
+        gt_fields = [gt.get("affected_field", "") for gt in env.ground_truths if gt.get("affected_field")]
+        # Find correct status if applicable
+        expected_status = None
+        for gt in env.ground_truths:
+            if gt["issue_type"] == "wrong_status_code":
+                expected_status = int(gt.get("correct_value", 0))
+        obs = env.step(APIDebugAction(
+            response_issues=gt_issues,
+            affected_fields=gt_fields,
+            expected_status_code=expected_status,
+        ))
+        assert obs.reward >= 0.7
+
+    def test_partial_issues_scores_partial(self):
+        env = make_env("response", seed=10)
+        # Only provide one issue type even if there are more
+        gt = env.ground_truths[0]
+        obs = env.step(APIDebugAction(
+            response_issues=[gt["issue_type"]],
+        ))
+        assert 0.001 <= obs.reward <= 0.999
+
+    def test_empty_action_scores_near_0(self):
+        env = make_env("response", seed=42)
+        obs = env.step(APIDebugAction())
+        assert obs.reward <= 0.01
+
+    def test_max_steps_is_4(self):
+        env = make_env("response", seed=42)
+        assert env.max_steps == 4
+
+    def test_observation_has_response_body(self):
+        env = APIDebugEnvironment()
+        obs = env.reset(task="response", seed=42)
+        assert obs.response_body != ""
+        assert obs.response_status_code > 0
+
+    def test_observation_has_correct_task(self):
+        env = APIDebugEnvironment()
+        obs = env.reset(task="response", seed=42)
+        assert obs.task == "response"
+
+    def test_ground_truth_has_issue_type(self):
+        env = make_env("response", seed=42)
+        valid_types = {
+            "wrong_status_code", "missing_response_field",
+            "wrong_response_type", "extra_response_field",
+            "inconsistent_error_format",
+        }
+        for gt in env.ground_truths:
+            assert gt["issue_type"] in valid_types
+
+    def test_wrong_issues_score_low(self):
+        env = make_env("response", seed=42)
+        obs = env.step(APIDebugAction(
+            response_issues=["nonexistent_issue"],
+        ))
+        assert obs.reward < 0.3
+
+    def test_reward_in_valid_range(self):
+        env = make_env("response", seed=42)
+        obs = env.step(APIDebugAction(
+            response_issues=["wrong_status_code"],
+            expected_status_code=200,
+        ))
+        assert 0.001 <= obs.reward <= 0.999
+
+    def test_response_body_is_valid_json(self):
+        env = APIDebugEnvironment()
+        obs = env.reset(task="response", seed=42)
+        parsed = json.loads(obs.response_body)
+        assert isinstance(parsed, dict)
+
+
 # ---------------------------------------------------------------------------
 # TestHardGrader
 # ---------------------------------------------------------------------------
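The comment in `test_extra_types_reduces_jaccard` above indicates the classify grader scores predicted error types by Jaccard overlap with the ground-truth set. A standalone sketch of how such a score behaves (the function here is ours for illustration; the real grader lives in the server code):

```python
def jaccard(predicted, truth):
    """Set overlap score: |intersection| / |union| (1.0 when both are empty)."""
    p, t = set(predicted), set(truth)
    if not p and not t:
        return 1.0
    return len(p & t) / len(p | t)

# A correct-but-overlong prediction still scores, just lower:
full = jaccard(["missing_auth_header", "wrong_field_type"],
               ["missing_auth_header", "wrong_field_type"])
extra = jaccard(["missing_auth_header", "wrong_field_type", "invalid_enum_value"],
                ["missing_auth_header", "wrong_field_type"])
```

This matches the test expectations: adding a fake error type lowers the score without zeroing out credit for the correct ones.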
training/README.md
ADDED
@@ -0,0 +1,71 @@
+# Training with GRPO on API Debug Environment
+
+Trains a small LLM using **GRPO** (Group Relative Policy Optimization)
+on the live API Debug Environment with **curriculum learning**.
+
+## What is GRPO?
+
+For each prompt, GRPO:
+1. Generates multiple completions (debug attempts)
+2. Scores each with the environment's grader (reward signal)
+3. Updates the model to prefer higher-scoring responses
+
+Over thousands of episodes, the LLM learns to debug API requests
+purely from reward signals -- no labelled data needed.
+
+## Curriculum Learning
+
+The training auto-promotes through difficulty levels:
+
+| Level | Task | Threshold | Max Turns | Skill |
+|-------|------|-----------|-----------|-------|
+| 1 | easy | 0.7 avg reward | 3 | Identify single error type + fields |
+| 2 | classify | 0.6 avg reward | 4 | Identify ALL error types + fields |
+| 3 | medium | 0.6 avg reward | 5 | Fix the broken request body |
+| 4 | headers | 0.5 avg reward | 4 | Fix header-level errors |
+| 5 | response | 0.5 avg reward | 4 | Validate API response issues |
+| 6 | hard | -- | 7 | Fix mixed errors + explain reasoning |
+
+Promotion happens when the rolling average reward (window=10) exceeds
+the threshold for the current level.
+
+## Architecture
+```
+Dataset prompt ("Debug this broken API request.")
+        |
+GRPOTrainer calls rollout_func()
+        |
+rollout_func() connects to live HF Space via WebSocket
+        |
+env.reset(task=current_task) -> broken API request
+        |
+LLM generates JSON response -> env.step(action) -> reward
+        |   (repeat up to max_turns)
+Returns: prompt_ids, completion_ids, logprobs, env_reward
+        |
+reward_from_env() extracts env_reward
+        |
+GRPO updates model weights
+        |
+maybe_promote() checks if agent should advance to next task
+```
+
+## Run on Google Colab (free T4 GPU)
+```python
+# Cell 1 -- Install
+!pip install trl>=0.26.0 transformers torch datasets openenv-core openai
+
+# Cell 2 -- Clone repo
+!git clone https://github.com/Avi-chauhan/api-debug-env.git
+%cd api-debug-env
+
+# Cell 3 -- Train
+!python training/train.py
+```
+
+## Requirements
+
+- GPU: T4 or better (free Colab works)
+- RAM: 8GB+
+- The live HF Space must be running:
+  https://huggingface.co/spaces/avichauhan/api-debug-env
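The promotion rule from the curriculum table can be sketched as a standalone check. This mirrors `CURRICULUM`, `WINDOW_SIZE`, and `maybe_promote()` in `training/train.py`, but the exact signatures here are illustrative:

```python
from collections import deque

# Task graph with per-level promotion thresholds (matches the README table).
CURRICULUM = {
    "easy": {"next": "classify", "threshold": 0.7},
    "classify": {"next": "medium", "threshold": 0.6},
    "medium": {"next": "headers", "threshold": 0.6},
    "headers": {"next": "response", "threshold": 0.5},
    "response": {"next": "hard", "threshold": 0.5},
    "hard": {"next": None, "threshold": None},
}

def maybe_promote(task, rewards, window=10):
    """Advance to the next task once the rolling average reward clears the threshold."""
    level = CURRICULUM[task]
    if level["next"] is None or len(rewards) < window:
        return task
    recent = list(rewards)[-window:]
    if sum(recent) / window >= level["threshold"]:
        return level["next"]
    return task

# Ten episodes averaging 0.8 on "easy" clears the 0.7 threshold.
next_task = maybe_promote("easy", deque([0.8] * 10, maxlen=50))
```

Requiring a full window before promoting prevents a single lucky episode from skipping a level.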
training/__init__.py
ADDED
File without changes
training/requirements.txt
ADDED
@@ -0,0 +1,6 @@
+trl>=0.26.0
+transformers
+torch
+datasets
+openenv-core
+openai
training/train.py
ADDED
@@ -0,0 +1,309 @@
+import os
+os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
+
+"""
+GRPO Training on API Debug Environment
+=======================================
+Trains a small LLM (Qwen 0.5B) to debug malformed API requests using
+reward signals from the live HuggingFace Space environment.
+
+Supports curriculum learning: starts on easy task, promotes to classify
+and medium as the agent improves.
+
+Run on Colab (free T4 GPU):
+    pip install -r training/requirements.txt
+    python training/train.py
+"""
+
+import json
+import re
+import sys
+
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from datasets import Dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import GRPOConfig, GRPOTrainer
+from trl.experimental.openenv import generate_rollout_completions
+
+from client import APIDebugEnv
+from models import APIDebugAction
+
+# -- GPU check ----------------------------------------------------------------
+print(f"GPU available : {torch.cuda.is_available()}")
+print(f"GPU name      : {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None (CPU)'}")
+
+has_gpu = torch.cuda.is_available()
+supports_bf16 = has_gpu and torch.cuda.is_bf16_supported()
+
+# -- Config -------------------------------------------------------------------
+MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
+ENV_URL = "https://avichauhan-api-debug-env.hf.space"
+MAX_TURNS = 3  # easy task: 3 steps max
+NUM_SAMPLES = 64
+
+# -- Curriculum state ---------------------------------------------------------
+# Tracks which task the agent is currently training on.
+# Promotes when rolling average reward exceeds threshold.
+CURRICULUM = {
+    "easy": {"next": "classify", "threshold": 0.7, "max_turns": 3},
+    "classify": {"next": "medium", "threshold": 0.6, "max_turns": 4},
+    "medium": {"next": "headers", "threshold": 0.6, "max_turns": 5},
+    "headers": {"next": "response", "threshold": 0.5, "max_turns": 4},
+    "response": {"next": "hard", "threshold": 0.5, "max_turns": 4},
+    "hard": {"next": None, "threshold": None, "max_turns": 7},
+}
+current_task = "easy"
+recent_rewards: list[float] = []
+WINDOW_SIZE = 10
+
+SYSTEM_PROMPT = """You are an API debugging expert. You receive a broken API request and its specification.
+Your job: identify the error type and the affected fields.
+
+Respond with ONLY a JSON object in this format:
+{"error_type": "<type>", "affected_fields": ["field1", "field2"]}
+
+Valid error types:
+missing_required_field, wrong_field_type, invalid_email_format,
+missing_auth_header, extra_unknown_field, null_value_in_required,
+wrong_http_method, malformed_json_value, invalid_enum_value,
+datetime_format_error, wrong_content_type, expired_auth_token"""
+
+CLASSIFY_PROMPT = """You are an API debugging expert. This request has MULTIPLE errors.
+Identify ALL error types and ALL affected fields.
+
+Respond with ONLY a JSON object:
|
| 78 |
+
{"error_types": ["type1", "type2"], "affected_fields": ["field1", "field2"]}"""
|
| 79 |
+
|
| 80 |
+
MEDIUM_PROMPT = """You are an API debugging expert. Fix the broken request to match the API spec.
|
| 81 |
+
|
| 82 |
+
Respond with ONLY a JSON object:
|
| 83 |
+
{"fixed_request": {"field": "value"}, "fixed_headers": {"Header": "value"}}"""
|
| 84 |
+
|
| 85 |
+
HEADERS_PROMPT = """You are an API debugging expert. This request has ONLY header-level errors.
|
| 86 |
+
Identify the error type and fix the headers to match the API spec.
|
| 87 |
+
|
| 88 |
+
Respond with ONLY a JSON object:
|
| 89 |
+
{"error_type": "<type>", "fixed_headers": {"Header-Name": "correct_value"}}
|
| 90 |
+
|
| 91 |
+
Common header error types: wrong_content_type, expired_auth_token, missing_auth_header"""
|
| 92 |
+
|
| 93 |
+
RESPONSE_PROMPT = """You are an API response validation expert. You receive an API request, its spec, and the server response.
|
| 94 |
+
Identify issues in the response: wrong status codes, missing fields, wrong types, extra fields, inconsistent error format.
|
| 95 |
+
|
| 96 |
+
Respond with ONLY a JSON object:
|
| 97 |
+
{"response_issues": ["issue_type1"], "affected_fields": ["field1"], "expected_status_code": 200}
|
| 98 |
+
|
| 99 |
+
Valid issue types: wrong_status_code, missing_response_field, wrong_response_type, extra_response_field, inconsistent_error_format"""
|
| 100 |
+
|
| 101 |
+
HARD_PROMPT = """You are an API debugging expert. This request has MULTIPLE errors across headers and body.
|
| 102 |
+
Some errors are chained -- fixing one may reveal others. Fix everything and explain your reasoning.
|
| 103 |
+
|
| 104 |
+
Respond with ONLY a JSON object:
|
| 105 |
+
{"fixed_request": {"field": "value"}, "fixed_headers": {"Header": "value"}, "explanation": "why each fix was needed"}"""
|
| 106 |
+
|
| 107 |
+
TASK_PROMPTS = {
|
| 108 |
+
"easy": SYSTEM_PROMPT,
|
| 109 |
+
"classify": CLASSIFY_PROMPT,
|
| 110 |
+
"medium": MEDIUM_PROMPT,
|
| 111 |
+
"headers": HEADERS_PROMPT,
|
| 112 |
+
"response": RESPONSE_PROMPT,
|
| 113 |
+
"hard": HARD_PROMPT,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
# -- Environment client -------------------------------------------------------
|
| 117 |
+
print(f"Connecting to environment: {ENV_URL}")
|
| 118 |
+
env_client = APIDebugEnv(base_url=ENV_URL)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# -- JSON parser (reused from inference.py) -----------------------------------
|
| 122 |
+
def parse_llm_response(text: str) -> dict:
|
| 123 |
+
if not text:
|
| 124 |
+
return {}
|
| 125 |
+
try:
|
| 126 |
+
return json.loads(text)
|
| 127 |
+
except json.JSONDecodeError:
|
| 128 |
+
pass
|
| 129 |
+
code_block = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
|
| 130 |
+
if code_block:
|
| 131 |
+
try:
|
| 132 |
+
return json.loads(code_block.group(1))
|
| 133 |
+
except json.JSONDecodeError:
|
| 134 |
+
pass
|
| 135 |
+
brace_match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
|
| 136 |
+
if brace_match:
|
| 137 |
+
try:
|
| 138 |
+
return json.loads(brace_match.group(0))
|
| 139 |
+
except json.JSONDecodeError:
|
| 140 |
+
pass
|
| 141 |
+
return {}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def build_action(data: dict) -> APIDebugAction:
|
| 145 |
+
fixed_req = data.get("fixed_request")
|
| 146 |
+
if isinstance(fixed_req, dict):
|
| 147 |
+
fixed_req = json.dumps(fixed_req)
|
| 148 |
+
return APIDebugAction(
|
| 149 |
+
error_type=data.get("error_type"),
|
| 150 |
+
error_types=data.get("error_types"),
|
| 151 |
+
affected_fields=data.get("affected_fields"),
|
| 152 |
+
fixed_request=fixed_req,
|
| 153 |
+
fixed_headers=data.get("fixed_headers"),
|
| 154 |
+
response_issues=data.get("response_issues"),
|
| 155 |
+
expected_status_code=data.get("expected_status_code"),
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# -- Curriculum learning ------------------------------------------------------
|
| 160 |
+
def maybe_promote():
|
| 161 |
+
"""Check if agent should be promoted to next difficulty."""
|
| 162 |
+
global current_task
|
| 163 |
+
config = CURRICULUM[current_task]
|
| 164 |
+
if config["next"] is None or config["threshold"] is None:
|
| 165 |
+
return
|
| 166 |
+
if len(recent_rewards) < WINDOW_SIZE:
|
| 167 |
+
return
|
| 168 |
+
avg = sum(recent_rewards[-WINDOW_SIZE:]) / WINDOW_SIZE
|
| 169 |
+
if avg >= config["threshold"]:
|
| 170 |
+
old_task = current_task
|
| 171 |
+
current_task = config["next"]
|
| 172 |
+
recent_rewards.clear()
|
| 173 |
+
print(f"[CURRICULUM] Promoted: {old_task} -> {current_task} (avg_reward={avg:.3f})")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# -- Rollout function ---------------------------------------------------------
|
| 177 |
+
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict:
|
| 178 |
+
tokenizer = trainer.processing_class
|
| 179 |
+
all_prompt_ids = []
|
| 180 |
+
all_completion_ids = []
|
| 181 |
+
all_logprobs = []
|
| 182 |
+
all_rewards = []
|
| 183 |
+
|
| 184 |
+
task = current_task
|
| 185 |
+
max_turns = CURRICULUM[task]["max_turns"]
|
| 186 |
+
system_prompt = TASK_PROMPTS[task]
|
| 187 |
+
|
| 188 |
+
for base_prompt in prompts:
|
| 189 |
+
with env_client.sync() as env:
|
| 190 |
+
obs = env.reset(task=task)
|
| 191 |
+
episode_reward = 0.0
|
| 192 |
+
episode_prompt_ids = []
|
| 193 |
+
episode_comp_ids = []
|
| 194 |
+
episode_logprobs = []
|
| 195 |
+
|
| 196 |
+
for turn in range(max_turns):
|
| 197 |
+
messages = [
|
| 198 |
+
{"role": "system", "content": system_prompt},
|
| 199 |
+
{"role": "user", "content": (
|
| 200 |
+
f"{base_prompt}\n\n"
|
| 201 |
+
f"API: {obs.observation.http_method} {obs.observation.endpoint} "
|
| 202 |
+
f"({obs.observation.api_name})\n"
|
| 203 |
+
f"Error count: {obs.observation.error_count}\n"
|
| 204 |
+
f"Step {turn + 1}/{max_turns}\n\n"
|
| 205 |
+
f"Broken request:\n{obs.observation.broken_request}\n\n"
|
| 206 |
+
f"Headers: {json.dumps(obs.observation.broken_headers)}\n\n"
|
| 207 |
+
f"API Spec:\n{obs.observation.api_spec}\n"
|
| 208 |
+
+ (f"\nFeedback:\n{obs.observation.feedback}" if obs.observation.feedback else "")
|
| 209 |
+
)},
|
| 210 |
+
]
|
| 211 |
+
prompt_text = tokenizer.apply_chat_template(
|
| 212 |
+
messages,
|
| 213 |
+
add_generation_prompt=True,
|
| 214 |
+
tokenize=False,
|
| 215 |
+
enable_thinking=False,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
outputs = generate_rollout_completions(trainer, [prompt_text])[0]
|
| 219 |
+
completion_text = tokenizer.decode(
|
| 220 |
+
outputs["completion_ids"], skip_special_tokens=True
|
| 221 |
+
).strip()
|
| 222 |
+
|
| 223 |
+
episode_prompt_ids.extend(outputs["prompt_ids"])
|
| 224 |
+
episode_comp_ids.extend(outputs["completion_ids"])
|
| 225 |
+
episode_logprobs.extend(outputs["logprobs"])
|
| 226 |
+
|
| 227 |
+
# Parse LLM output into action
|
| 228 |
+
parsed = parse_llm_response(completion_text)
|
| 229 |
+
action = build_action(parsed)
|
| 230 |
+
obs = env.step(action)
|
| 231 |
+
episode_reward = float(obs.reward or 0.0)
|
| 232 |
+
|
| 233 |
+
if obs.done:
|
| 234 |
+
break
|
| 235 |
+
|
| 236 |
+
all_prompt_ids.append(episode_prompt_ids)
|
| 237 |
+
all_completion_ids.append(episode_comp_ids)
|
| 238 |
+
all_logprobs.append(episode_logprobs)
|
| 239 |
+
all_rewards.append(episode_reward)
|
| 240 |
+
|
| 241 |
+
# Track for curriculum
|
| 242 |
+
recent_rewards.append(episode_reward)
|
| 243 |
+
|
| 244 |
+
# Check if agent should be promoted
|
| 245 |
+
maybe_promote()
|
| 246 |
+
|
| 247 |
+
return {
|
| 248 |
+
"prompt_ids": all_prompt_ids,
|
| 249 |
+
"completion_ids": all_completion_ids,
|
| 250 |
+
"logprobs": all_logprobs,
|
| 251 |
+
"env_reward": all_rewards,
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# -- Reward function ----------------------------------------------------------
|
| 256 |
+
def reward_from_env(completions, **kwargs):
|
| 257 |
+
env_rewards = kwargs.get("env_reward", [])
|
| 258 |
+
return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# -- Dataset ------------------------------------------------------------------
|
| 262 |
+
dataset = Dataset.from_dict({
|
| 263 |
+
"prompt": ["Debug this broken API request."] * NUM_SAMPLES
|
| 264 |
+
})
|
| 265 |
+
|
| 266 |
+
# -- Trainer ------------------------------------------------------------------
|
| 267 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 268 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, attn_implementation="eager")
|
| 269 |
+
|
| 270 |
+
grpo_args = GRPOConfig(
|
| 271 |
+
use_vllm=True,
|
| 272 |
+
vllm_mode="colocate",
|
| 273 |
+
num_train_epochs=1,
|
| 274 |
+
num_generations=2,
|
| 275 |
+
max_completion_length=128,
|
| 276 |
+
per_device_train_batch_size=1,
|
| 277 |
+
gradient_accumulation_steps=16,
|
| 278 |
+
learning_rate=5e-6,
|
| 279 |
+
output_dir="./outputs/api-debug-grpo",
|
| 280 |
+
logging_steps=1,
|
| 281 |
+
report_to="none",
|
| 282 |
+
bf16=supports_bf16,
|
| 283 |
+
fp16=has_gpu and not supports_bf16,
|
| 284 |
+
no_cuda=not has_gpu,
|
| 285 |
+
gradient_checkpointing=True,
|
| 286 |
+
vllm_gpu_memory_utilization=0.3,
|
| 287 |
+
dataloader_pin_memory=False,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
trainer = GRPOTrainer(
|
| 291 |
+
model=model,
|
| 292 |
+
processing_class=tokenizer,
|
| 293 |
+
reward_funcs=reward_from_env,
|
| 294 |
+
train_dataset=dataset,
|
| 295 |
+
rollout_func=rollout_func,
|
| 296 |
+
args=grpo_args,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
if __name__ == "__main__":
|
| 300 |
+
print("Starting GRPO training on API Debug Environment...")
|
| 301 |
+
print(f"Model : {MODEL_ID}")
|
| 302 |
+
print(f"Environment: {ENV_URL}")
|
| 303 |
+
print(f"Episodes : {NUM_SAMPLES}")
|
| 304 |
+
print(f"Task : {current_task} (with curriculum learning)")
|
| 305 |
+
print(f"bf16 : {supports_bf16}")
|
| 306 |
+
print(f"fp16 : {has_gpu and not supports_bf16}")
|
| 307 |
+
trainer.train()
|
| 308 |
+
print(f"Training complete! Final task: {current_task}")
|
| 309 |
+
print("Model saved to ./outputs/api-debug-grpo")
|