Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +36 -0
- README.md +345 -4
- __init__.py +4 -0
- client.py +5 -0
- dataqa_env/__init__.py +19 -0
- dataqa_env/client.py +37 -0
- dataqa_env/models.py +77 -0
- dataqa_env/server/Dockerfile +33 -0
- dataqa_env/server/__init__.py +0 -0
- dataqa_env/server/app.py +39 -0
- dataqa_env/server/environment.py +623 -0
- dataqa_env/server/gradio_ui.py +568 -0
- dataqa_env/server/tasks.py +1159 -0
- inference.py +376 -0
- models.py +4 -0
- openenv.yaml +6 -0
- openenv_dataqa_env.egg-info/PKG-INFO +13 -0
- openenv_dataqa_env.egg-info/SOURCES.txt +15 -0
- openenv_dataqa_env.egg-info/dependency_links.txt +1 -0
- openenv_dataqa_env.egg-info/entry_points.txt +2 -0
- openenv_dataqa_env.egg-info/requires.txt +9 -0
- openenv_dataqa_env.egg-info/top_level.txt +1 -0
- pyproject.toml +32 -0
- scripts/prevalidation_script.sh +185 -0
- scripts/sample_inference_script.py +188 -0
- server/__init__.py +1 -0
- server/app.py +13 -0
- tests/__init__.py +0 -0
- tests/test_environment.py +455 -0
- tests/test_extensibility.py +215 -0
- tests/test_inference.py +191 -0
- tests/test_tasks.py +212 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system deps
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
git curl \
|
| 8 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Install uv for fast dependency management
|
| 11 |
+
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 12 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 13 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx
|
| 14 |
+
|
| 15 |
+
# Copy project files
|
| 16 |
+
COPY pyproject.toml /app/
|
| 17 |
+
COPY openenv.yaml /app/
|
| 18 |
+
COPY dataqa_env/ /app/dataqa_env/
|
| 19 |
+
COPY inference.py /app/
|
| 20 |
+
COPY README.md /app/
|
| 21 |
+
|
| 22 |
+
# Install dependencies
|
| 23 |
+
RUN uv sync --no-editable 2>/dev/null || pip install -e .
|
| 24 |
+
|
| 25 |
+
# Set environment
|
| 26 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 27 |
+
ENV PYTHONPATH="/app:$PYTHONPATH"
|
| 28 |
+
|
| 29 |
+
# Health check — HF Spaces uses port 8000
|
| 30 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 31 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 32 |
+
|
| 33 |
+
EXPOSE 8000
|
| 34 |
+
|
| 35 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 36 |
+
CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,351 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DataQA Environment Server
|
| 3 |
+
emoji: "\U0001F50D"
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: gray
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
base_path: /web
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# DataQA Environment
|
| 15 |
+
|
| 16 |
+
A two-phase OpenEnv RL environment for **Data Quality Assurance** — an LLM agent inspects corrupted datasets, identifies all planted quality issues, and proposes data repairs.
|
| 17 |
+
|
| 18 |
+
### Demo: Agent Trajectory Replay
|
| 19 |
+
|
| 20 |
+
```
|
| 21 |
+
EASY TASK (Step 2) — All 6 issues found + 5 fixes proposed
|
| 22 |
+
Reward: 0.87 | Identify: 1.00 | Fix: 0.67
|
| 23 |
+
✓ row:4 name: empty → "David Kim"
|
| 24 |
+
✓ row:7 salary: "seventy-five thousand" → "75000"
|
| 25 |
+
✓ row:9 salary: "5000" → "73000"
|
| 26 |
+
✓ row:15 email: mismatch → "oscar.rivera@company.com"
|
| 27 |
+
✓ row:18 start_date: "2027-06-15" → "2022-01-19"
|
| 28 |
+
✓ row:21 duplicate row detected
|
| 29 |
+
|
| 30 |
+
HARD TASK — ML experiment metadata
|
| 31 |
+
Step 1: Found 5/10, missed hard issues → Reward: 0.69
|
| 32 |
+
Step 2: Found 10/10 + 5 fixes proposed → Reward: 0.77
|
| 33 |
+
Issues requiring ML knowledge:
|
| 34 |
+
• val_loss < train_loss (data leakage signal)
|
| 35 |
+
• resnet18 using 42.5GB GPU (impossible)
|
| 36 |
+
• 350 epochs on ImageNet in 30 min (impossible)
|
| 37 |
+
• wav2vec2 at 98.5% accuracy (exceeds SOTA)
|
| 38 |
+
|
| 39 |
+
ALIGNMENT TASK — NVIDIA HelpSteer data (hardest)
|
| 40 |
+
Step 1: Found 7/12, missed subtle issues → Reward: 0.58
|
| 41 |
+
Step 2: Found 12/12 + 3 fixes proposed → Reward: 0.72
|
| 42 |
+
Issues requiring deep reasoning:
|
| 43 |
+
• Cerasus vs Prunus serrulata (wrong taxonomic name)
|
| 44 |
+
• $400.3M at Sotheby's vs $450.3M at Christie's (close but wrong)
|
| 45 |
+
• "does NOT learn via backprop" then describes backprop (self-contradiction)
|
| 46 |
+
• Fake Nature paper by "Dr. Sarah Chen" (hallucinated citation)
|
| 47 |
+
• "use bare except everywhere" rated helpfulness=3 (harmful advice)
|
| 48 |
+
• [SYSTEM] prompt leaked in response (pipeline contamination)
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
> The interactive replay UI with color-coded dataset visualization is available on the HF Space.
|
| 52 |
+
|
| 53 |
+
## Motivation
|
| 54 |
+
|
| 55 |
+
Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies — before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
|
| 56 |
+
|
| 57 |
+
DataQA turns this into a **two-phase RL challenge**:
|
| 58 |
+
1. **Identify** — systematically inspect corrupted data and pinpoint every planted issue
|
| 59 |
+
2. **Fix** — propose corrected values by reasoning about schema, constraints, and context
|
| 60 |
+
|
| 61 |
+
This creates a rich multi-step decision problem where agents must explore datasets strategically, distinguish subtle anomalies from noise, and reason about what the correct data should be.
|
| 62 |
+
|
| 63 |
+
## Environment API
|
| 64 |
+
|
| 65 |
+
| Endpoint | Method | Description |
|
| 66 |
+
|----------|--------|-------------|
|
| 67 |
+
| `/reset` | POST | Start a new episode with a corrupted dataset |
|
| 68 |
+
| `/step` | POST | Submit identified issues + proposed fixes |
|
| 69 |
+
| `/state` | GET | Get current episode state |
|
| 70 |
+
| `/health` | GET | Health check |
|
| 71 |
+
|
| 72 |
+
## Tasks
|
| 73 |
+
|
| 74 |
+
| Task | Issues | Difficulty | Domain | Description |
|
| 75 |
+
|------|--------|-----------|--------|-------------|
|
| 76 |
+
| `easy` | 6 | Beginner | HR/Employee data (21 rows) | Nulls, wrong types, duplicates, out-of-range, email-name mismatch, future dates |
|
| 77 |
+
| `medium` | 8 | Intermediate | E-commerce orders (31 rows) | Inconsistent totals, invalid categories, duplicate keys, wrong date formats, invalid country codes, future-date deliveries |
|
| 78 |
+
| `hard` | 10 | Advanced | ML experiment metadata (31 rows) | Data leakage signals, unreasonable GPU memory, impossibly fast training, SOTA-exceeding accuracy, timestamp ordering, whitespace-only fields |
|
| 79 |
+
| `alignment` | 12 | Expert | LLM alignment data (30 rows, NVIDIA HelpSteer) | See below |
|
| 80 |
+
|
| 81 |
+
**Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
|
| 82 |
+
|
| 83 |
+
### Alignment Task: LLM Training Data Quality (Expert)
|
| 84 |
+
|
| 85 |
+
Built on **real data from [NVIDIA HelpSteer](https://huggingface.co/datasets/nvidia/HelpSteer)** — 30 human-annotated prompt-response pairs with quality scores (helpfulness, correctness, coherence, complexity, verbosity on 0-4 scale).
|
| 86 |
+
|
| 87 |
+
This task targets a critical real-world problem: **catching quality issues in LLM fine-tuning data before it corrupts model training**. The 12 planted issues represent failure modes actually seen in production data pipelines:
|
| 88 |
+
|
| 89 |
+
| Issue | Difficulty | Why It's Hard |
|
| 90 |
+
|---|---|---|
|
| 91 |
+
| Subtle factual error (*Cerasus* vs *Prunus serrulata*) | 3.0 | Old taxonomic synonym — sounds plausible, requires domain knowledge |
|
| 92 |
+
| Plausible wrong numbers ($400.3M at Sotheby's vs $450.3M at Christie's) | 3.0 | Right painting, wrong price by $50M and wrong auction house |
|
| 93 |
+
| Self-contradictory reasoning ("does NOT learn via backprop" then describes backprop) | 3.0 | Response negates its own conclusion — trains confused models |
|
| 94 |
+
| Hallucinated citation (fake Nature paper by fake Dr. Sarah Chen) | 3.0 | Fabricated study with specific fake statistics — most dangerous for training |
|
| 95 |
+
| Harmful coding advice ("use bare except everywhere") with high quality scores | 3.0 | Teaches dangerous practices if used for fine-tuning |
|
| 96 |
+
| Leaked system prompt (`[SYSTEM] You are a helpful AI...`) in response | 2.5 | Data pipeline failed to strip prompt template |
|
| 97 |
+
| Semantic near-duplicate prompt (rephrased, not exact copy) | 2.5 | Requires semantic similarity detection, not just string matching |
|
| 98 |
+
| Score inflation (helpfulness=4 for a 4-word answer) | 2.5 | Score-content mismatch requires understanding rating criteria |
|
| 99 |
+
| Truncated response (cut mid-sentence) | 2.5 | `max_length` truncation without sentence boundary detection |
|
| 100 |
+
| Response in French for English prompt | 2.0 | Language contamination from multilingual training data |
|
| 101 |
+
| Response plagiarized from another row | 2.0 | Data pipeline shuffling/dedup failure |
|
| 102 |
+
| Whitespace-only prompt | 2.0 | Empty training example from pipeline artifact |
|
| 103 |
+
|
| 104 |
+
These issues are designed to challenge frontier models — they require factual recall, semantic reasoning, cross-row comparison, and understanding of what makes training data harmful.
|
| 105 |
+
|
| 106 |
+
## Two-Phase Action Space
|
| 107 |
+
|
| 108 |
+
### Phase 1: Identify Issues
|
| 109 |
+
|
| 110 |
+
Submit issues in format: `row:<row_number>,col:<column_name>,issue:<issue_type>`
|
| 111 |
+
|
| 112 |
+
- `row_number`: 1-indexed data row position (after header)
|
| 113 |
+
- `column_name`: Exact column header name, lowercase
|
| 114 |
+
- `issue_type`: One of the supported types below
|
| 115 |
+
|
| 116 |
+
### Phase 2: Propose Fixes
|
| 117 |
+
|
| 118 |
+
Submit fixes in format: `row:<row_number>,col:<column_name>,fix:<corrected_value>`
|
| 119 |
+
|
| 120 |
+
The agent proposes the **correct value** that should replace the corrupted data. Fixes are graded against the original clean dataset.
|
| 121 |
+
|
| 122 |
+
Both phases can be submitted in the same step or across multiple steps.
|
| 123 |
+
|
| 124 |
+
**Supported Issue Types:**
|
| 125 |
+
|
| 126 |
+
| Type | Description | Example |
|
| 127 |
+
|------|-------------|---------|
|
| 128 |
+
| `missing_value` | Null, empty, or whitespace-only | Empty name field |
|
| 129 |
+
| `wrong_type` | Value doesn't match expected type | Salary as "seventy-five thousand" |
|
| 130 |
+
| `duplicate_row` | Exact duplicate or duplicate key | Two rows with same employee_id |
|
| 131 |
+
| `out_of_range` | Value outside valid range | Salary of 5000 when min is 50000 |
|
| 132 |
+
| `format_violation` | Wrong format or invalid enum | Date as DD/MM/YYYY instead of YYYY-MM-DD |
|
| 133 |
+
| `inconsistent_value` | Computed field mismatch, logical inconsistency | total != qty * price |
|
| 134 |
+
| `statistical_outlier` | Unreasonable value given context | resnet18 using 42.5GB GPU |
|
| 135 |
+
| `referential_integrity` | Foreign key violation | (available for custom tasks) |
|
| 136 |
+
|
| 137 |
+
## Observation Space
|
| 138 |
+
|
| 139 |
+
| Field | Type | Description |
|
| 140 |
+
|-------|------|-------------|
|
| 141 |
+
| `dataset_csv` | str | The corrupted dataset in CSV format |
|
| 142 |
+
| `schema_description` | str | Column types, ranges, and constraints |
|
| 143 |
+
| `validation_rules` | str | Business rules the data must satisfy |
|
| 144 |
+
| `task_description` | str | Task context and instructions |
|
| 145 |
+
| `feedback` | str | Per-step results: TP/FP/FN, precision/recall, fix scores |
|
| 146 |
+
| `num_issues_hint` | int | Exact count of planted issues |
|
| 147 |
+
| `max_steps` | int | Maximum attempts allowed |
|
| 148 |
+
| `done` | bool | Whether episode has terminated |
|
| 149 |
+
| `reward` | float | Best combined reward so far (0.0-1.0) |
|
| 150 |
+
|
| 151 |
+
**Observation Metadata** (per step):
|
| 152 |
+
- Identify: `identify_f1`, `identify_score`, `precision`, `recall`, `tp`, `fp`, `fn`
|
| 153 |
+
- Fix: `fix_score`, `fixes_correct`, `fixes_partial`, `fixes_wrong`, `fixes_attempted`
|
| 154 |
+
- Combined: `combined_reward`, `difficulty_found`, `difficulty_missed`
|
| 155 |
+
|
| 156 |
+
## Reward Function
|
| 157 |
+
|
| 158 |
+
### Combined Reward
|
| 159 |
+
|
| 160 |
+
```
|
| 161 |
+
combined_reward = 0.6 * identify_score + 0.4 * fix_score
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
If no fixes are submitted, `combined_reward = identify_score` (no penalty — backward compatible).
|
| 165 |
+
|
| 166 |
+
### Identify Score (Difficulty-Weighted F1)
|
| 167 |
+
|
| 168 |
+
Each planted issue has a **difficulty weight** (1.0-3.0):
|
| 169 |
+
|
| 170 |
+
| Weight | Category | Examples |
|
| 171 |
+
|--------|----------|----------|
|
| 172 |
+
| 1.0 | Easy | Missing values, obvious out-of-range, wrong type |
|
| 173 |
+
| 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
|
| 174 |
+
| 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
|
| 175 |
+
|
| 176 |
+
- **Weighted Recall** = (difficulty of found issues) / (total difficulty)
|
| 177 |
+
- **Weighted Precision** = penalizes false positives proportional to average difficulty
|
| 178 |
+
- **Weighted F1** = harmonic mean
|
| 179 |
+
|
| 180 |
+
### Fix Score (Difficulty-Weighted Quality)
|
| 181 |
+
|
| 182 |
+
Each proposed fix is compared against the original clean value:
|
| 183 |
+
|
| 184 |
+
| Fix Quality | Score | Description |
|
| 185 |
+
|-------------|-------|-------------|
|
| 186 |
+
| Exact match | 1.0 | Case-insensitive, whitespace-stripped match |
|
| 187 |
+
| Numeric close | 0.8 | Within 1% of correct numeric value |
|
| 188 |
+
| Correct cell | 0.1 | Right location, wrong value |
|
| 189 |
+
| Non-issue cell | 0.0 | Fix targets a cell with no issue |
|
| 190 |
+
|
| 191 |
+
Fix score = (sum of best fix score per issue × difficulty weight) / (total difficulty weight)
|
| 192 |
+
|
| 193 |
+
### Reward Properties
|
| 194 |
+
|
| 195 |
+
- **Per-step partial progress**: reward increases as more issues are found/fixed
|
| 196 |
+
- **Difficulty-aware**: finding subtle issues earns more than obvious ones
|
| 197 |
+
- **Penalizes bad behavior**: false positives reduce score, fixing non-issues earns nothing
|
| 198 |
+
- **Monotonically non-decreasing**: best score across all steps is the final reward
|
| 199 |
+
- **Always in [0.0, 1.0]**: meets hackathon requirement
|
| 200 |
+
|
| 201 |
+
### Episode Boundaries
|
| 202 |
+
|
| 203 |
+
- Each task allows up to 3 steps (attempts)
|
| 204 |
+
- Episode ends when F1 >= 0.999 (perfect identification) or max steps reached
|
| 205 |
+
- Agent receives detailed feedback after each step to improve on next attempt
|
| 206 |
+
|
| 207 |
+
## Baseline Scores
|
| 208 |
+
|
| 209 |
+
Baseline agent uses Qwen2.5-72B-Instruct via HuggingFace Router:
|
| 210 |
+
|
| 211 |
+
| Task | Identify Score | Fix Score | Combined | Notes |
|
| 212 |
+
|------|---------------|-----------|----------|-------|
|
| 213 |
+
| `easy` | 0.7-1.0 | 0.5-0.9 | 0.6-1.0 | Most LLMs find obvious issues reliably |
|
| 214 |
+
| `medium` | 0.5-0.8 | 0.3-0.6 | 0.4-0.7 | Cross-column reasoning challenges models |
|
| 215 |
+
| `hard` | 0.3-0.6 | 0.2-0.4 | 0.3-0.5 | ML domain knowledge and subtle patterns |
|
| 216 |
+
|
| 217 |
+
Scores vary by model. The hard task is designed to challenge frontier models.
|
| 218 |
+
|
| 219 |
+
## Extensibility
|
| 220 |
+
|
| 221 |
+
### Custom Contamination Rules
|
| 222 |
+
|
| 223 |
+
```python
|
| 224 |
+
from dataqa_env import register_contamination_rule
|
| 225 |
+
from dataqa_env.server.tasks import PlantedIssue
|
| 226 |
+
|
| 227 |
+
def swap_digits(rows, header, col_idx, row_idx, rng):
|
| 228 |
+
val = rows[row_idx][col_idx]
|
| 229 |
+
corrupted = val[::-1]
|
| 230 |
+
issue = PlantedIssue(
|
| 231 |
+
row=row_idx + 1, col=header[col_idx],
|
| 232 |
+
issue_type="format_violation",
|
| 233 |
+
description=f"Digits swapped in {header[col_idx]}",
|
| 234 |
+
difficulty=2.0,
|
| 235 |
+
)
|
| 236 |
+
return corrupted, issue
|
| 237 |
+
|
| 238 |
+
register_contamination_rule("swap_digits", swap_digits)
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
### Custom Tasks from Config
|
| 242 |
+
|
| 243 |
+
```python
|
| 244 |
+
from dataqa_env import create_task_from_config, register_task
|
| 245 |
+
|
| 246 |
+
task = create_task_from_config(
|
| 247 |
+
task_id="custom",
|
| 248 |
+
name="Custom Validation",
|
| 249 |
+
description="Find quality issues in this dataset.",
|
| 250 |
+
schema_description="id: int, name: str, score: int (0-100)",
|
| 251 |
+
validation_rules="No missing values. Scores must be 0-100.",
|
| 252 |
+
clean_csv="id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92",
|
| 253 |
+
contaminations=[
|
| 254 |
+
{"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
|
| 255 |
+
{"rule": "negative_value", "row": 2, "col": 2, "difficulty": 1.5},
|
| 256 |
+
],
|
| 257 |
+
)
|
| 258 |
+
register_task("custom", lambda seed: task)
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
### Built-in Contamination Rules
|
| 262 |
+
|
| 263 |
+
| Rule | Effect | Default Difficulty |
|
| 264 |
+
|------|--------|--------------------|
|
| 265 |
+
| `missing_value` | Sets field to empty string | 1.0 |
|
| 266 |
+
| `whitespace_value` | Sets field to single space | 2.5 |
|
| 267 |
+
| `wrong_type_text` | Replaces with random text | 1.0 |
|
| 268 |
+
| `negative_value` | Negates numeric value | 1.0 |
|
| 269 |
+
|
| 270 |
+
## Setup & Quick Start
|
| 271 |
+
|
| 272 |
+
```bash
|
| 273 |
+
# Install
|
| 274 |
+
pip install -e .
|
| 275 |
+
|
| 276 |
+
# Run server locally
|
| 277 |
+
uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
|
| 278 |
+
|
| 279 |
+
# Run inference (set your API credentials)
|
| 280 |
+
API_BASE_URL=https://router.huggingface.co/v1 \
|
| 281 |
+
MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \
|
| 282 |
+
HF_TOKEN=your-token \
|
| 283 |
+
python inference.py
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
## Docker
|
| 287 |
+
|
| 288 |
+
```bash
|
| 289 |
+
docker build -t dataqa-env .
|
| 290 |
+
docker run -p 8000:8000 dataqa-env
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
## Testing
|
| 294 |
+
|
| 295 |
+
```bash
|
| 296 |
+
pip install -e ".[dev]"
|
| 297 |
+
pytest tests/ -v
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
118 tests covering:
|
| 301 |
+
- Task creation, corruption, and difficulty weights
|
| 302 |
+
- Issue key and fix parsing (standard, lenient, edge cases)
|
| 303 |
+
- F1, weighted reward, and fix quality computation
|
| 304 |
+
- Full environment lifecycle (identify-only and identify+fix)
|
| 305 |
+
- Combined reward calculation and weight verification
|
| 306 |
+
- Inference script parsing and prompt building
|
| 307 |
+
- Structured log format ([START], [STEP], [END])
|
| 308 |
+
- Score bounds (0.0-1.0), best-score monotonicity
|
| 309 |
+
- Extensibility API (custom rules, custom tasks)
|
| 310 |
+
|
| 311 |
+
## Validation
|
| 312 |
+
|
| 313 |
+
```bash
|
| 314 |
+
# OpenEnv spec validation
|
| 315 |
+
openenv validate .
|
| 316 |
+
|
| 317 |
+
# Pre-submission validation (requires HF Space URL)
|
| 318 |
+
./prevalidation_script.sh https://your-space.hf.space
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
## Environment Variables
|
| 322 |
+
|
| 323 |
+
| Variable | Description | Default |
|
| 324 |
+
|----------|-------------|---------|
|
| 325 |
+
| `API_BASE_URL` | LLM API endpoint | `https://router.huggingface.co/v1` |
|
| 326 |
+
| `MODEL_NAME` | Model identifier | `Qwen/Qwen2.5-72B-Instruct` |
|
| 327 |
+
| `HF_TOKEN` | HuggingFace token / API key | - |
|
| 328 |
+
| `ENV_URL` | Environment server URL | `http://localhost:8000` |
|
| 329 |
+
|
| 330 |
+
## Architecture
|
| 331 |
+
|
| 332 |
+
```
|
| 333 |
+
dataqa_env/
|
| 334 |
+
├── __init__.py # Public API + extensibility exports
|
| 335 |
+
├── models.py # Pydantic: DataQAAction (issues + fixes), DataQAObservation, DataQAState
|
| 336 |
+
├── client.py # EnvClient for WebSocket connections
|
| 337 |
+
├── server/
|
| 338 |
+
│ ├── environment.py # Two-phase DataQAEnvironment (identify + fix + combined reward)
|
| 339 |
+
│ ├── tasks.py # Task definitions + contamination rules + extensibility API
|
| 340 |
+
│ ├── app.py # FastAPI server (via openenv-core create_app)
|
| 341 |
+
│ └── Dockerfile
|
| 342 |
+
tests/
|
| 343 |
+
├── test_tasks.py # Task creation, corruption, difficulty weights
|
| 344 |
+
├── test_environment.py # Identify scoring, fix grading, combined reward, lifecycle
|
| 345 |
+
├── test_inference.py # LLM response parsing, fix parsing, prompt building, log format
|
| 346 |
+
└── test_extensibility.py # Custom rules, custom tasks, registration API
|
| 347 |
+
inference.py # Two-phase baseline agent (identify → fix)
|
| 348 |
+
openenv.yaml # OpenEnv/HF Spaces spec
|
| 349 |
+
pyproject.toml # Package metadata and dependencies
|
| 350 |
+
Dockerfile # Production container
|
| 351 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Root-level package for OpenEnv compatibility."""
|
| 2 |
+
from dataqa_env import DataQAEnv, DataQAAction, DataQAObservation, DataQAState
|
| 3 |
+
|
| 4 |
+
__all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
|
client.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Root-level client for OpenEnv compatibility."""
|
| 2 |
+
from dataqa_env.client import DataQAEnv
|
| 3 |
+
from dataqa_env.models import DataQAAction, DataQAObservation, DataQAState
|
| 4 |
+
|
| 5 |
+
__all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
|
dataqa_env/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .client import DataQAEnv
|
| 2 |
+
from .models import DataQAAction, DataQAObservation, DataQAState
|
| 3 |
+
from .server.tasks import (
|
| 4 |
+
create_task_from_config,
|
| 5 |
+
register_task,
|
| 6 |
+
register_contamination_rule,
|
| 7 |
+
CONTAMINATION_RULES,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"DataQAEnv",
|
| 12 |
+
"DataQAAction",
|
| 13 |
+
"DataQAObservation",
|
| 14 |
+
"DataQAState",
|
| 15 |
+
"create_task_from_config",
|
| 16 |
+
"register_task",
|
| 17 |
+
"register_contamination_rule",
|
| 18 |
+
"CONTAMINATION_RULES",
|
| 19 |
+
]
|
dataqa_env/client.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DataQAEnv Client
|
| 3 |
+
----------------
|
| 4 |
+
Client-side wrapper for the DataQA environment server.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from openenv.core.client_types import StepResult
|
| 10 |
+
from openenv.core.env_client import EnvClient
|
| 11 |
+
|
| 12 |
+
from .models import DataQAAction, DataQAObservation, DataQAState
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DataQAEnv(EnvClient[DataQAAction, DataQAObservation, DataQAState]):
|
| 16 |
+
|
| 17 |
+
def _step_payload(self, action: DataQAAction) -> dict:
|
| 18 |
+
return {"issues": action.issues, "task_id": action.task_id}
|
| 19 |
+
|
| 20 |
+
def _parse_result(self, payload: dict) -> StepResult[DataQAObservation]:
|
| 21 |
+
obs = DataQAObservation(**payload["observation"])
|
| 22 |
+
return StepResult(
|
| 23 |
+
observation=obs,
|
| 24 |
+
reward=payload.get("reward"),
|
| 25 |
+
done=bool(payload.get("done", False)),
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def _parse_state(self, payload: dict) -> DataQAState:
|
| 29 |
+
return DataQAState(
|
| 30 |
+
episode_id=payload.get("episode_id"),
|
| 31 |
+
step_count=payload.get("step_count", 0),
|
| 32 |
+
task_id=payload.get("task_id", ""),
|
| 33 |
+
current_step=payload.get("current_step", 0),
|
| 34 |
+
max_steps=payload.get("max_steps", 3),
|
| 35 |
+
best_score=payload.get("best_score", 0.0),
|
| 36 |
+
total_planted_issues=payload.get("total_planted_issues", 0),
|
| 37 |
+
)
|
dataqa_env/models.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DataQA Environment Models
|
| 3 |
+
-------------------------
|
| 4 |
+
Action/Observation/State types for the Data Quality Assurance environment.
|
| 5 |
+
|
| 6 |
+
The agent receives a dataset with planted quality issues and must identify them.
|
| 7 |
+
Grading is based on F1 score (precision × recall) of correctly identified issues.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
from openenv.core.env_server.interfaces import Action, Observation, State
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DataQAAction(Action):
|
| 18 |
+
"""
|
| 19 |
+
Agent submits identified issues AND optional proposed fixes.
|
| 20 |
+
|
| 21 |
+
Two-phase action space:
|
| 22 |
+
Phase 1 (Identify): List issues in format "row:<N>,col:<name>,issue:<type>"
|
| 23 |
+
Phase 2 (Fix): List fixes in format "row:<N>,col:<name>,fix:<proposed_value>"
|
| 24 |
+
|
| 25 |
+
The agent can submit both in the same step or across multiple steps.
|
| 26 |
+
Combined reward = 0.6 * identify_score + 0.4 * fix_score
|
| 27 |
+
|
| 28 |
+
Supported issue types:
|
| 29 |
+
missing_value, wrong_type, duplicate_row, out_of_range,
|
| 30 |
+
format_violation, inconsistent_value, statistical_outlier,
|
| 31 |
+
referential_integrity
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
issues: List[str]
|
| 35 |
+
fixes: List[str] = []
|
| 36 |
+
# Include task_id so step() can reconstruct context in stateless HTTP mode
|
| 37 |
+
task_id: str = "easy"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class DataQAObservation(Observation):
|
| 41 |
+
"""
|
| 42 |
+
What the agent sees: a dataset, its schema/rules, and feedback.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
# The dataset as CSV text
|
| 46 |
+
dataset_csv: str = ""
|
| 47 |
+
|
| 48 |
+
# Schema description (column names, expected types, constraints)
|
| 49 |
+
schema_description: str = ""
|
| 50 |
+
|
| 51 |
+
# Validation rules in plain text
|
| 52 |
+
validation_rules: str = ""
|
| 53 |
+
|
| 54 |
+
# Task description
|
| 55 |
+
task_description: str = ""
|
| 56 |
+
|
| 57 |
+
# Feedback from previous step (empty on reset)
|
| 58 |
+
feedback: str = ""
|
| 59 |
+
|
| 60 |
+
# Current task ID
|
| 61 |
+
task_id: str = ""
|
| 62 |
+
|
| 63 |
+
# Number of planted issues (hint for the agent)
|
| 64 |
+
num_issues_hint: int = 0
|
| 65 |
+
|
| 66 |
+
# Max allowed steps for this task
|
| 67 |
+
max_steps: int = 3
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class DataQAState(State):
|
| 71 |
+
"""Tracks episode progress."""
|
| 72 |
+
|
| 73 |
+
task_id: str = ""
|
| 74 |
+
current_step: int = 0
|
| 75 |
+
max_steps: int = 3
|
| 76 |
+
best_score: float = 0.0
|
| 77 |
+
total_planted_issues: int = 0
|
dataqa_env/server/Dockerfile
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system deps
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
git curl \
|
| 8 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Install uv for fast dependency management
|
| 11 |
+
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 12 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 13 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx
|
| 14 |
+
|
| 15 |
+
# Copy project files
|
| 16 |
+
COPY . /app/env
|
| 17 |
+
|
| 18 |
+
WORKDIR /app/env
|
| 19 |
+
|
| 20 |
+
# Install dependencies
|
| 21 |
+
RUN uv sync --frozen --no-editable 2>/dev/null || uv sync --no-editable
|
| 22 |
+
|
| 23 |
+
# Set environment
|
| 24 |
+
ENV PATH="/app/env/.venv/bin:$PATH"
|
| 25 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 26 |
+
|
| 27 |
+
# Health check
|
| 28 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 29 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 30 |
+
|
| 31 |
+
EXPOSE 8000
|
| 32 |
+
|
| 33 |
+
CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
dataqa_env/server/__init__.py
ADDED
|
File without changes
|
dataqa_env/server/app.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for the DataQA Environment.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
uvicorn dataqa_env.server.app:app --reload --host 0.0.0.0 --port 8000
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from openenv.core.env_server.http_server import create_app
|
| 10 |
+
from .environment import DataQAEnvironment
|
| 11 |
+
from ..models import DataQAAction, DataQAObservation
|
| 12 |
+
except ImportError:
|
| 13 |
+
from openenv.core.env_server.http_server import create_app
|
| 14 |
+
from dataqa_env.server.environment import DataQAEnvironment
|
| 15 |
+
from dataqa_env.models import DataQAAction, DataQAObservation
|
| 16 |
+
|
| 17 |
+
app = create_app(
|
| 18 |
+
DataQAEnvironment, DataQAAction, DataQAObservation, env_name="dataqa_env"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@app.get("/")
|
| 23 |
+
def root():
|
| 24 |
+
"""Root endpoint — environment info."""
|
| 25 |
+
return {
|
| 26 |
+
"name": "DataQA Environment",
|
| 27 |
+
"description": "Two-phase data quality assurance environment: identify issues + propose fixes",
|
| 28 |
+
"tasks": ["easy", "medium", "hard", "alignment", "coding", "toolcalling"],
|
| 29 |
+
"endpoints": ["/health", "/reset", "/step", "/state"],
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main():
|
| 34 |
+
import uvicorn
|
| 35 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
main()
|
dataqa_env/server/environment.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DataQA Environment
|
| 3 |
+
------------------
|
| 4 |
+
Server-side environment for data quality assurance tasks.
|
| 5 |
+
|
| 6 |
+
Two-phase RL environment:
|
| 7 |
+
Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
|
| 8 |
+
Phase 2 (Fix): Agent proposes corrections for identified issues.
|
| 9 |
+
|
| 10 |
+
Combined reward = 0.6 * identify_score + 0.4 * fix_score
|
| 11 |
+
Both phases scored with difficulty-weighted metrics for rich per-step signal.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import re
|
| 17 |
+
import uuid
|
| 18 |
+
from typing import Any, Optional, Set
|
| 19 |
+
|
| 20 |
+
from openenv.core.env_server.interfaces import Action, Environment, Observation
|
| 21 |
+
|
| 22 |
+
from ..models import DataQAAction, DataQAObservation, DataQAState
|
| 23 |
+
from .tasks import PlantedIssue, Task, get_task, list_tasks
|
| 24 |
+
|
| 25 |
+
# Reward weights for the two phases
|
| 26 |
+
IDENTIFY_WEIGHT = 0.6
|
| 27 |
+
FIX_WEIGHT = 0.4
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_issue_key(raw: str) -> Optional[str]:
|
| 31 |
+
"""
|
| 32 |
+
Parse an agent-reported issue string into a normalized key.
|
| 33 |
+
Expected format: row:<N>,col:<name>,issue:<type>
|
| 34 |
+
Returns normalized key or None if unparseable.
|
| 35 |
+
"""
|
| 36 |
+
raw = raw.strip().lower()
|
| 37 |
+
row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
|
| 38 |
+
col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
|
| 39 |
+
issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
|
| 40 |
+
|
| 41 |
+
if row_match and col_match and issue_match:
|
| 42 |
+
return f"row:{row_match.group(1)},col:{col_match.group(1)},issue:{issue_match.group(1)}"
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def parse_fix(raw: str) -> Optional[tuple[int, str, str]]:
|
| 47 |
+
"""
|
| 48 |
+
Parse an agent-proposed fix into (row, col, proposed_value).
|
| 49 |
+
Expected format: row:<N>,col:<name>,fix:<value>
|
| 50 |
+
Returns (row, col, value) or None if unparseable.
|
| 51 |
+
"""
|
| 52 |
+
raw = raw.strip()
|
| 53 |
+
row_match = re.search(r"row\s*[:=]\s*(\d+)", raw, re.IGNORECASE)
|
| 54 |
+
col_match = re.search(r"col(?:umn)?\s*[:=]\s*([\w_]+)", raw, re.IGNORECASE)
|
| 55 |
+
fix_match = re.search(r"fix\s*[:=]\s*(.+?)$", raw, re.IGNORECASE)
|
| 56 |
+
|
| 57 |
+
if row_match and col_match and fix_match:
|
| 58 |
+
return (int(row_match.group(1)), col_match.group(1).lower(), fix_match.group(1).strip())
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
|
| 63 |
+
"""Compute precision, recall, and F1 score."""
|
| 64 |
+
if not reported_keys and not planted_keys:
|
| 65 |
+
return {"precision": 1.0, "recall": 1.0, "f1": 1.0, "tp": 0, "fp": 0, "fn": 0}
|
| 66 |
+
|
| 67 |
+
if not reported_keys:
|
| 68 |
+
return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "tp": 0, "fp": 0, "fn": len(planted_keys)}
|
| 69 |
+
|
| 70 |
+
if not planted_keys:
|
| 71 |
+
return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "tp": 0, "fp": len(reported_keys), "fn": 0}
|
| 72 |
+
|
| 73 |
+
tp = len(reported_keys & planted_keys)
|
| 74 |
+
fp = len(reported_keys - planted_keys)
|
| 75 |
+
fn = len(planted_keys - reported_keys)
|
| 76 |
+
|
| 77 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 78 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 79 |
+
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 80 |
+
|
| 81 |
+
return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def compute_weighted_reward(
|
| 85 |
+
reported_keys: Set[str],
|
| 86 |
+
planted_issues: list,
|
| 87 |
+
) -> dict:
|
| 88 |
+
"""
|
| 89 |
+
Compute difficulty-weighted reward for richer per-step signal.
|
| 90 |
+
|
| 91 |
+
Each planted issue has a difficulty weight (1.0-3.0). Finding harder issues
|
| 92 |
+
earns more reward. False positives incur a penalty scaled by average difficulty.
|
| 93 |
+
|
| 94 |
+
Returns dict with weighted_reward (0.0-1.0), plus per-issue breakdown.
|
| 95 |
+
"""
|
| 96 |
+
if not planted_issues and not reported_keys:
|
| 97 |
+
return {"weighted_reward": 1.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
|
| 98 |
+
|
| 99 |
+
planted_by_key = {issue.to_key(): issue for issue in planted_issues}
|
| 100 |
+
planted_keys = set(planted_by_key.keys())
|
| 101 |
+
|
| 102 |
+
if not reported_keys:
|
| 103 |
+
total_weight = sum(i.difficulty for i in planted_issues)
|
| 104 |
+
return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": total_weight}
|
| 105 |
+
|
| 106 |
+
if not planted_keys:
|
| 107 |
+
return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
|
| 108 |
+
|
| 109 |
+
found_keys = reported_keys & planted_keys
|
| 110 |
+
missed_keys = planted_keys - reported_keys
|
| 111 |
+
false_positive_count = len(reported_keys - planted_keys)
|
| 112 |
+
|
| 113 |
+
difficulty_found = sum(planted_by_key[k].difficulty for k in found_keys)
|
| 114 |
+
difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
|
| 115 |
+
total_weight = sum(i.difficulty for i in planted_issues)
|
| 116 |
+
|
| 117 |
+
weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
|
| 118 |
+
|
| 119 |
+
avg_difficulty = total_weight / len(planted_issues)
|
| 120 |
+
fp_penalty_weight = false_positive_count * avg_difficulty
|
| 121 |
+
weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
|
| 122 |
+
|
| 123 |
+
if (weighted_precision + weighted_recall) > 0:
|
| 124 |
+
weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
|
| 125 |
+
else:
|
| 126 |
+
weighted_reward = 0.0
|
| 127 |
+
|
| 128 |
+
return {
|
| 129 |
+
"weighted_reward": round(weighted_reward, 4),
|
| 130 |
+
"difficulty_found": round(difficulty_found, 2),
|
| 131 |
+
"difficulty_missed": round(difficulty_missed, 2),
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def grade_fixes(
|
| 136 |
+
fixes: list[tuple[int, str, str]],
|
| 137 |
+
task: Task,
|
| 138 |
+
) -> dict:
|
| 139 |
+
"""
|
| 140 |
+
Grade proposed fixes against the clean dataset.
|
| 141 |
+
|
| 142 |
+
For each fix (row, col, proposed_value), compare to the original clean value.
|
| 143 |
+
Scoring per fix:
|
| 144 |
+
- Exact match (case-insensitive, whitespace-stripped): 1.0
|
| 145 |
+
- Numeric close match (within 1%): 0.8
|
| 146 |
+
- Correct column but wrong value: 0.1
|
| 147 |
+
- Targets a non-issue cell: 0.0 (penalty)
|
| 148 |
+
|
| 149 |
+
Returns dict with fix_score (0.0-1.0), details per fix, and counts.
|
| 150 |
+
"""
|
| 151 |
+
if not fixes and not task.planted_issues:
|
| 152 |
+
return {"fix_score": 1.0, "fixes_correct": 0, "fixes_partial": 0,
|
| 153 |
+
"fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
|
| 154 |
+
|
| 155 |
+
if not fixes:
|
| 156 |
+
return {"fix_score": 0.0, "fixes_correct": 0, "fixes_partial": 0,
|
| 157 |
+
"fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
|
| 158 |
+
|
| 159 |
+
issue_map = task.get_planted_issue_map()
|
| 160 |
+
# Build set of (row, col) that are actual issues
|
| 161 |
+
issue_cells = {(issue.row, issue.col) for issue in task.planted_issues}
|
| 162 |
+
|
| 163 |
+
total_weight = sum(i.difficulty for i in task.planted_issues) if task.planted_issues else 1.0
|
| 164 |
+
earned_weight = 0.0
|
| 165 |
+
fixes_correct = 0
|
| 166 |
+
fixes_partial = 0
|
| 167 |
+
fixes_wrong = 0
|
| 168 |
+
fix_details = []
|
| 169 |
+
|
| 170 |
+
# Track which issues have been fixed (best fix wins)
|
| 171 |
+
fixed_issues: dict[tuple[int, str], float] = {}
|
| 172 |
+
|
| 173 |
+
for row, col, proposed in fixes:
|
| 174 |
+
clean_value = task.get_clean_value(row, col)
|
| 175 |
+
cell_key = (row, col)
|
| 176 |
+
|
| 177 |
+
if cell_key not in issue_cells:
|
| 178 |
+
# Fix targets a non-issue cell — no credit
|
| 179 |
+
fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "not an issue cell"})
|
| 180 |
+
fixes_wrong += 1
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
if clean_value is None:
|
| 184 |
+
fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "cell not found"})
|
| 185 |
+
fixes_wrong += 1
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
# Find the planted issue for this cell to get its difficulty weight
|
| 189 |
+
matching_issue = None
|
| 190 |
+
for issue in task.planted_issues:
|
| 191 |
+
if issue.row == row and issue.col == col:
|
| 192 |
+
matching_issue = issue
|
| 193 |
+
break
|
| 194 |
+
|
| 195 |
+
difficulty = matching_issue.difficulty if matching_issue else 1.0
|
| 196 |
+
|
| 197 |
+
# Score the fix using tiered grading:
|
| 198 |
+
# 1.0 = exact match with clean value
|
| 199 |
+
# 0.8 = valid fix (right type, in range, addresses the issue) but not exact
|
| 200 |
+
# 0.4 = partially valid (reasonable attempt, right direction)
|
| 201 |
+
# 0.1 = targets correct cell but fix doesn't address the issue
|
| 202 |
+
# 0.0 = makes things worse or targets non-issue cell
|
| 203 |
+
score = 0.0
|
| 204 |
+
reason = "wrong value"
|
| 205 |
+
issue_type = matching_issue.issue_type if matching_issue else ""
|
| 206 |
+
|
| 207 |
+
# Exact match (case-insensitive, whitespace-stripped)
|
| 208 |
+
if proposed.strip().lower() == clean_value.lower():
|
| 209 |
+
score = 1.0
|
| 210 |
+
reason = "exact match"
|
| 211 |
+
fixes_correct += 1
|
| 212 |
+
else:
|
| 213 |
+
# Grade by issue type — check if the fix is VALID even if not exact
|
| 214 |
+
proposed_stripped = proposed.strip()
|
| 215 |
+
|
| 216 |
+
if issue_type == "missing_value":
|
| 217 |
+
# Any non-empty value is a reasonable fix for a missing value
|
| 218 |
+
if proposed_stripped and proposed_stripped != " ":
|
| 219 |
+
score = 0.8
|
| 220 |
+
reason = "valid fix (non-empty value for missing field)"
|
| 221 |
+
fixes_partial += 1
|
| 222 |
+
else:
|
| 223 |
+
score = 0.0
|
| 224 |
+
reason = "fix is still empty"
|
| 225 |
+
fixes_wrong += 1
|
| 226 |
+
|
| 227 |
+
elif issue_type == "wrong_type":
|
| 228 |
+
# Check if the proposed value is the correct type
|
| 229 |
+
try:
|
| 230 |
+
float(proposed_stripped)
|
| 231 |
+
# Original was text, proposed is numeric — correct type fix
|
| 232 |
+
score = 0.8
|
| 233 |
+
reason = "valid fix (correct type)"
|
| 234 |
+
fixes_partial += 1
|
| 235 |
+
except ValueError:
|
| 236 |
+
score = 0.1
|
| 237 |
+
reason = "fix is still wrong type"
|
| 238 |
+
fixes_partial += 1
|
| 239 |
+
|
| 240 |
+
elif issue_type == "out_of_range":
|
| 241 |
+
# Check if proposed value is within a reasonable range
|
| 242 |
+
try:
|
| 243 |
+
proposed_num = float(proposed_stripped)
|
| 244 |
+
clean_num = float(clean_value)
|
| 245 |
+
# Within 50% of clean value = good estimate
|
| 246 |
+
if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.5:
|
| 247 |
+
score = 0.8
|
| 248 |
+
reason = "valid fix (in reasonable range)"
|
| 249 |
+
fixes_partial += 1
|
| 250 |
+
elif proposed_num > 0 and (clean_num > 0) == (proposed_num > 0):
|
| 251 |
+
# At least right sign/direction
|
| 252 |
+
score = 0.4
|
| 253 |
+
reason = "partially valid (right direction)"
|
| 254 |
+
fixes_partial += 1
|
| 255 |
+
else:
|
| 256 |
+
score = 0.1
|
| 257 |
+
reason = "fix still out of reasonable range"
|
| 258 |
+
fixes_partial += 1
|
| 259 |
+
except ValueError:
|
| 260 |
+
score = 0.1
|
| 261 |
+
reason = "correct cell, wrong value"
|
| 262 |
+
fixes_partial += 1
|
| 263 |
+
|
| 264 |
+
elif issue_type == "format_violation":
|
| 265 |
+
# Check if proposed value matches expected format
|
| 266 |
+
# For dates: YYYY-MM-DD pattern
|
| 267 |
+
if re.match(r"\d{4}-\d{2}-\d{2}", proposed_stripped):
|
| 268 |
+
score = 0.8
|
| 269 |
+
reason = "valid fix (correct format)"
|
| 270 |
+
fixes_partial += 1
|
| 271 |
+
elif proposed_stripped and proposed_stripped != clean_value:
|
| 272 |
+
score = 0.4
|
| 273 |
+
reason = "fix attempted but format unclear"
|
| 274 |
+
fixes_partial += 1
|
| 275 |
+
else:
|
| 276 |
+
score = 0.1
|
| 277 |
+
reason = "correct cell, wrong value"
|
| 278 |
+
fixes_partial += 1
|
| 279 |
+
|
| 280 |
+
elif issue_type in ("inconsistent_value", "statistical_outlier"):
|
| 281 |
+
# These require domain knowledge — any reasonable attempt gets partial credit
|
| 282 |
+
try:
|
| 283 |
+
proposed_num = float(proposed_stripped)
|
| 284 |
+
clean_num = float(clean_value)
|
| 285 |
+
# Within 20% = strong fix, within 50% = reasonable
|
| 286 |
+
if clean_num != 0:
|
| 287 |
+
pct_diff = abs(proposed_num - clean_num) / abs(clean_num)
|
| 288 |
+
if pct_diff <= 0.01:
|
| 289 |
+
score = 1.0
|
| 290 |
+
reason = "exact numeric match"
|
| 291 |
+
fixes_correct += 1
|
| 292 |
+
elif pct_diff <= 0.2:
|
| 293 |
+
score = 0.8
|
| 294 |
+
reason = "valid fix (within 20% of correct value)"
|
| 295 |
+
fixes_partial += 1
|
| 296 |
+
elif pct_diff <= 0.5:
|
| 297 |
+
score = 0.4
|
| 298 |
+
reason = "partially valid (right ballpark)"
|
| 299 |
+
fixes_partial += 1
|
| 300 |
+
else:
|
| 301 |
+
score = 0.1
|
| 302 |
+
reason = "correct cell, value not close"
|
| 303 |
+
fixes_partial += 1
|
| 304 |
+
else:
|
| 305 |
+
score = 0.4
|
| 306 |
+
reason = "numeric fix attempted"
|
| 307 |
+
fixes_partial += 1
|
| 308 |
+
except ValueError:
|
| 309 |
+
# Non-numeric fix for text fields — check similarity
|
| 310 |
+
if len(proposed_stripped) > 10 and proposed_stripped != clean_value:
|
| 311 |
+
score = 0.4
|
| 312 |
+
reason = "text fix attempted (cannot verify automatically)"
|
| 313 |
+
fixes_partial += 1
|
| 314 |
+
else:
|
| 315 |
+
score = 0.1
|
| 316 |
+
reason = "correct cell, wrong value"
|
| 317 |
+
fixes_partial += 1
|
| 318 |
+
|
| 319 |
+
else:
|
| 320 |
+
# Fallback: numeric close match or partial credit
|
| 321 |
+
try:
|
| 322 |
+
proposed_num = float(proposed_stripped)
|
| 323 |
+
clean_num = float(clean_value)
|
| 324 |
+
if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
|
| 325 |
+
score = 0.8
|
| 326 |
+
reason = "numeric close match"
|
| 327 |
+
fixes_partial += 1
|
| 328 |
+
else:
|
| 329 |
+
score = 0.1
|
| 330 |
+
reason = "correct cell, wrong value"
|
| 331 |
+
fixes_partial += 1
|
| 332 |
+
except (ValueError, ZeroDivisionError):
|
| 333 |
+
score = 0.1
|
| 334 |
+
reason = "correct cell, wrong value"
|
| 335 |
+
fixes_partial += 1
|
| 336 |
+
|
| 337 |
+
# Keep best fix per cell
|
| 338 |
+
if cell_key not in fixed_issues or score > fixed_issues[cell_key]:
|
| 339 |
+
fixed_issues[cell_key] = score
|
| 340 |
+
|
| 341 |
+
fix_details.append({"row": row, "col": col, "score": score, "reason": reason})
|
| 342 |
+
|
| 343 |
+
# Compute fix score: weighted sum of best fix per issue / total weight
|
| 344 |
+
for issue in task.planted_issues:
|
| 345 |
+
cell_key = (issue.row, issue.col)
|
| 346 |
+
if cell_key in fixed_issues:
|
| 347 |
+
earned_weight += issue.difficulty * fixed_issues[cell_key]
|
| 348 |
+
|
| 349 |
+
fix_score = earned_weight / total_weight if total_weight > 0 else 0.0
|
| 350 |
+
fix_score = min(max(fix_score, 0.0), 1.0)
|
| 351 |
+
|
| 352 |
+
return {
|
| 353 |
+
"fix_score": round(fix_score, 4),
|
| 354 |
+
"fixes_correct": fixes_correct,
|
| 355 |
+
"fixes_partial": fixes_partial,
|
| 356 |
+
"fixes_wrong": fixes_wrong,
|
| 357 |
+
"fixes_attempted": len(fixes),
|
| 358 |
+
"fix_details": fix_details,
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class DataQAEnvironment(Environment):
|
| 363 |
+
"""
|
| 364 |
+
Data Quality Assurance environment — two-phase identify + fix.
|
| 365 |
+
|
| 366 |
+
Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
|
| 367 |
+
Phase 2 (Fix): Agent proposes corrections for identified issues.
|
| 368 |
+
|
| 369 |
+
Combined reward = 0.6 * identify_score + 0.4 * fix_score
|
| 370 |
+
Both phases use difficulty-weighted scoring for rich per-step reward signals.
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 374 |
+
|
| 375 |
+
def __init__(self):
|
| 376 |
+
self._state = DataQAState()
|
| 377 |
+
self._current_task: Optional[Task] = None
|
| 378 |
+
self._planted_keys: Set[str] = set()
|
| 379 |
+
self._best_score: float = 0.0
|
| 380 |
+
|
| 381 |
+
def reset(
|
| 382 |
+
self,
|
| 383 |
+
seed: Optional[int] = None,
|
| 384 |
+
episode_id: Optional[str] = None,
|
| 385 |
+
**kwargs: Any,
|
| 386 |
+
) -> Observation:
|
| 387 |
+
task_id = kwargs.get("task_id", "easy")
|
| 388 |
+
task_seed = seed if seed is not None else 42
|
| 389 |
+
|
| 390 |
+
self._current_task = get_task(task_id, seed=task_seed)
|
| 391 |
+
self._planted_keys = {issue.to_key() for issue in self._current_task.planted_issues}
|
| 392 |
+
self._best_score = 0.0
|
| 393 |
+
|
| 394 |
+
ep_id = episode_id or str(uuid.uuid4())
|
| 395 |
+
self._state = DataQAState(
|
| 396 |
+
episode_id=ep_id,
|
| 397 |
+
step_count=0,
|
| 398 |
+
task_id=task_id,
|
| 399 |
+
current_step=0,
|
| 400 |
+
max_steps=self._current_task.max_steps,
|
| 401 |
+
best_score=0.0,
|
| 402 |
+
total_planted_issues=len(self._current_task.planted_issues),
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
return DataQAObservation(
|
| 406 |
+
dataset_csv=self._current_task.corrupted_csv,
|
| 407 |
+
schema_description=self._current_task.schema_description,
|
| 408 |
+
validation_rules=self._current_task.validation_rules,
|
| 409 |
+
task_description=self._current_task.description,
|
| 410 |
+
feedback=(
|
| 411 |
+
"Environment reset. Inspect the dataset and report all quality issues.\n"
|
| 412 |
+
"You can also propose fixes in format: row:<N>,col:<name>,fix:<corrected_value>\n"
|
| 413 |
+
"Combined reward = 0.6 * identify_score + 0.4 * fix_score"
|
| 414 |
+
),
|
| 415 |
+
task_id=task_id,
|
| 416 |
+
num_issues_hint=len(self._current_task.planted_issues),
|
| 417 |
+
max_steps=self._current_task.max_steps,
|
| 418 |
+
done=False,
|
| 419 |
+
reward=0.0,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
def step(
|
| 423 |
+
self,
|
| 424 |
+
action: Action,
|
| 425 |
+
timeout_s: Optional[float] = None,
|
| 426 |
+
**kwargs: Any,
|
| 427 |
+
) -> Observation:
|
| 428 |
+
if not isinstance(action, DataQAAction):
|
| 429 |
+
raise ValueError(f"Expected DataQAAction, got {type(action)}")
|
| 430 |
+
|
| 431 |
+
# Auto-reset in stateless HTTP mode
|
| 432 |
+
if self._current_task is None:
|
| 433 |
+
self.reset(task_id=action.task_id)
|
| 434 |
+
|
| 435 |
+
self._state.step_count += 1
|
| 436 |
+
self._state.current_step += 1
|
| 437 |
+
|
| 438 |
+
# ── Phase 1: Parse and score issue identification ──
|
| 439 |
+
reported_keys: Set[str] = set()
|
| 440 |
+
parse_errors: list[str] = []
|
| 441 |
+
for raw_issue in action.issues:
|
| 442 |
+
key = parse_issue_key(raw_issue)
|
| 443 |
+
if key:
|
| 444 |
+
reported_keys.add(key)
|
| 445 |
+
else:
|
| 446 |
+
parse_errors.append(f"Could not parse issue: '{raw_issue}'")
|
| 447 |
+
|
| 448 |
+
metrics = compute_f1(reported_keys, self._planted_keys)
|
| 449 |
+
identify_f1 = metrics["f1"]
|
| 450 |
+
|
| 451 |
+
weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
|
| 452 |
+
identify_score = weighted["weighted_reward"]
|
| 453 |
+
|
| 454 |
+
# ── Phase 2: Parse and score proposed fixes ──
|
| 455 |
+
parsed_fixes: list[tuple[int, str, str]] = []
|
| 456 |
+
for raw_fix in action.fixes:
|
| 457 |
+
fix = parse_fix(raw_fix)
|
| 458 |
+
if fix:
|
| 459 |
+
parsed_fixes.append(fix)
|
| 460 |
+
else:
|
| 461 |
+
parse_errors.append(f"Could not parse fix: '{raw_fix}'")
|
| 462 |
+
|
| 463 |
+
fix_result = grade_fixes(parsed_fixes, self._current_task)
|
| 464 |
+
fix_score = fix_result["fix_score"]
|
| 465 |
+
|
| 466 |
+
# ── Combined reward ──
|
| 467 |
+
# If no fixes submitted, score is identify-only (no penalty for not fixing)
|
| 468 |
+
if action.fixes:
|
| 469 |
+
combined_reward = IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
|
| 470 |
+
else:
|
| 471 |
+
combined_reward = identify_score # backward compatible
|
| 472 |
+
|
| 473 |
+
self._best_score = max(self._best_score, combined_reward)
|
| 474 |
+
self._state.best_score = self._best_score
|
| 475 |
+
|
| 476 |
+
# ── Check if done ──
|
| 477 |
+
is_done = (
|
| 478 |
+
identify_f1 >= 0.999 # Perfect identification
|
| 479 |
+
or self._state.current_step >= self._state.max_steps
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
# ── Build feedback with actionable diagnostics ──
|
| 483 |
+
# Show the agent exactly which reported issues were correct (TP) and which were wrong (FP)
|
| 484 |
+
tp_keys = reported_keys & self._planted_keys
|
| 485 |
+
fp_keys = reported_keys - self._planted_keys
|
| 486 |
+
|
| 487 |
+
feedback_lines = [
|
| 488 |
+
f"Step {self._state.current_step}/{self._state.max_steps}",
|
| 489 |
+
"",
|
| 490 |
+
"--- Identification ---",
|
| 491 |
+
f"Issues reported: {len(reported_keys)}",
|
| 492 |
+
f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
|
| 493 |
+
f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {identify_f1:.3f}",
|
| 494 |
+
f"Identify score (weighted): {identify_score:.3f}",
|
| 495 |
+
]
|
| 496 |
+
|
| 497 |
+
# Show which reported issues were correct vs wrong (helps agent self-correct)
|
| 498 |
+
if tp_keys:
|
| 499 |
+
feedback_lines.append(f"Correct issues: {', '.join(sorted(tp_keys))}")
|
| 500 |
+
if fp_keys:
|
| 501 |
+
feedback_lines.append(f"Incorrect issues (false positives): {', '.join(sorted(fp_keys))}")
|
| 502 |
+
|
| 503 |
+
if action.fixes:
|
| 504 |
+
feedback_lines += [
|
| 505 |
+
"",
|
| 506 |
+
"--- Fix Proposals ---",
|
| 507 |
+
f"Fixes attempted: {fix_result['fixes_attempted']}",
|
| 508 |
+
f"Correct: {fix_result['fixes_correct']}, Partial: {fix_result['fixes_partial']}, Wrong: {fix_result['fixes_wrong']}",
|
| 509 |
+
f"Fix score: {fix_score:.3f}",
|
| 510 |
+
]
|
| 511 |
+
# Show per-fix feedback so agent knows which fixes worked
|
| 512 |
+
for detail in fix_result["fix_details"]:
|
| 513 |
+
status = "correct" if detail["score"] >= 0.99 else ("partial" if detail["score"] > 0 else "wrong")
|
| 514 |
+
feedback_lines.append(
|
| 515 |
+
f" row:{detail['row']},col:{detail['col']} -> {status} ({detail['reason']})"
|
| 516 |
+
)
|
| 517 |
+
feedback_lines.append(
|
| 518 |
+
f"\n--- Combined Reward: {combined_reward:.3f} (identify={identify_score:.3f} x {IDENTIFY_WEIGHT} + fix={fix_score:.3f} x {FIX_WEIGHT}) ---"
|
| 519 |
+
)
|
| 520 |
+
else:
|
| 521 |
+
feedback_lines += [
|
| 522 |
+
"",
|
| 523 |
+
"Tip: Submit fixes with format row:<N>,col:<name>,fix:<value> for bonus reward.",
|
| 524 |
+
]
|
| 525 |
+
|
| 526 |
+
if parse_errors:
|
| 527 |
+
feedback_lines.append(f"\nParse errors ({len(parse_errors)}): {'; '.join(parse_errors[:5])}")
|
| 528 |
+
|
| 529 |
+
if not is_done:
|
| 530 |
+
if metrics["fn"] > 0:
|
| 531 |
+
feedback_lines.append(
|
| 532 |
+
f"\nYou missed {metrics['fn']} issue(s). Review the dataset carefully."
|
| 533 |
+
)
|
| 534 |
+
if metrics["fp"] > 0:
|
| 535 |
+
feedback_lines.append(
|
| 536 |
+
f"Remove the {metrics['fp']} false positive(s) listed above and look for real issues."
|
| 537 |
+
)
|
| 538 |
+
feedback_lines.append("You can submit again with updated issues and/or fixes.")
|
| 539 |
+
else:
|
| 540 |
+
feedback_lines.append(f"\nTask complete! Final best reward: {self._best_score:.3f}")
|
| 541 |
+
|
| 542 |
+
# ── Flag items for human review ──
|
| 543 |
+
# In a production data QA pipeline, these would go to a human reviewer.
|
| 544 |
+
# The grader flags cases where automated scoring has low confidence.
|
| 545 |
+
human_review_flags: list[dict] = []
|
| 546 |
+
|
| 547 |
+
# 1. False positives that target real columns — could be legitimate issues
|
| 548 |
+
# the task designer didn't plant (agent may be smarter than the grader)
|
| 549 |
+
issue_map = self._current_task.get_planted_issue_map()
|
| 550 |
+
valid_issue_types = {"missing_value", "wrong_type", "duplicate_row", "out_of_range",
|
| 551 |
+
"format_violation", "inconsistent_value", "statistical_outlier",
|
| 552 |
+
"referential_integrity"}
|
| 553 |
+
for fp_key in fp_keys:
|
| 554 |
+
parts = fp_key.split(",")
|
| 555 |
+
itype = parts[2].split(":")[1] if len(parts) >= 3 else ""
|
| 556 |
+
if itype in valid_issue_types:
|
| 557 |
+
human_review_flags.append({
|
| 558 |
+
"item": fp_key,
|
| 559 |
+
"reason": "Agent reported this issue but it's not in ground truth — may be a real issue the grader missed",
|
| 560 |
+
"type": "possible_unplanted_issue",
|
| 561 |
+
})
|
| 562 |
+
|
| 563 |
+
# 2. Partial fix matches — fix was close but not exact, human should verify
|
| 564 |
+
for detail in fix_result["fix_details"]:
|
| 565 |
+
if 0 < detail["score"] < 0.99:
|
| 566 |
+
human_review_flags.append({
|
| 567 |
+
"item": f"row:{detail['row']},col:{detail['col']}",
|
| 568 |
+
"reason": f"Fix scored {detail['score']:.2f} ({detail['reason']}) — human should verify if acceptable",
|
| 569 |
+
"type": "partial_fix",
|
| 570 |
+
})
|
| 571 |
+
|
| 572 |
+
# 3. High-difficulty issues that were missed — flag for training data review
|
| 573 |
+
planted_by_key = {i.to_key(): i for i in self._current_task.planted_issues}
|
| 574 |
+
fn_keys = self._planted_keys - reported_keys
|
| 575 |
+
for fn_key in fn_keys:
|
| 576 |
+
issue = planted_by_key.get(fn_key)
|
| 577 |
+
if issue and issue.difficulty >= 2.5:
|
| 578 |
+
human_review_flags.append({
|
| 579 |
+
"item": fn_key,
|
| 580 |
+
"reason": f"High-difficulty issue (difficulty={issue.difficulty}) missed — {issue.description}",
|
| 581 |
+
"type": "missed_hard_issue",
|
| 582 |
+
})
|
| 583 |
+
|
| 584 |
+
if human_review_flags:
|
| 585 |
+
feedback_lines.append(f"\n--- Flagged for Human Review ({len(human_review_flags)}) ---")
|
| 586 |
+
for flag in human_review_flags:
|
| 587 |
+
feedback_lines.append(f" [{flag['type']}] {flag['item']}: {flag['reason']}")
|
| 588 |
+
|
| 589 |
+
return DataQAObservation(
|
| 590 |
+
dataset_csv=self._current_task.corrupted_csv,
|
| 591 |
+
schema_description=self._current_task.schema_description,
|
| 592 |
+
validation_rules=self._current_task.validation_rules,
|
| 593 |
+
task_description=self._current_task.description,
|
| 594 |
+
feedback="\n".join(feedback_lines),
|
| 595 |
+
task_id=self._current_task.task_id,
|
| 596 |
+
num_issues_hint=len(self._current_task.planted_issues),
|
| 597 |
+
max_steps=self._state.max_steps,
|
| 598 |
+
done=is_done,
|
| 599 |
+
reward=self._best_score,
|
| 600 |
+
metadata={
|
| 601 |
+
"identify_f1": identify_f1,
|
| 602 |
+
"identify_score": identify_score,
|
| 603 |
+
"fix_score": fix_score,
|
| 604 |
+
"combined_reward": combined_reward,
|
| 605 |
+
"precision": metrics["precision"],
|
| 606 |
+
"recall": metrics["recall"],
|
| 607 |
+
"tp": metrics["tp"],
|
| 608 |
+
"fp": metrics["fp"],
|
| 609 |
+
"fn": metrics["fn"],
|
| 610 |
+
"difficulty_found": weighted["difficulty_found"],
|
| 611 |
+
"difficulty_missed": weighted["difficulty_missed"],
|
| 612 |
+
"fixes_correct": fix_result["fixes_correct"],
|
| 613 |
+
"fixes_partial": fix_result["fixes_partial"],
|
| 614 |
+
"fixes_wrong": fix_result["fixes_wrong"],
|
| 615 |
+
"fixes_attempted": fix_result["fixes_attempted"],
|
| 616 |
+
"fix_details": fix_result["fix_details"],
|
| 617 |
+
"human_review_flags": human_review_flags,
|
| 618 |
+
},
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
@property
|
| 622 |
+
def state(self) -> DataQAState:
|
| 623 |
+
return self._state
|
dataqa_env/server/gradio_ui.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI — Agent Trajectory Replay Viewer for DataQA.
|
| 3 |
+
|
| 4 |
+
Designed for judges: zero clicks needed, auto-plays on load.
|
| 5 |
+
Tab per task, step slider, prominent metric cards, color-coded dataset.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import csv
|
| 11 |
+
import io
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
|
| 15 |
+
from .environment import DataQAEnvironment, parse_issue_key
|
| 16 |
+
from .tasks import list_tasks, PlantedIssue
|
| 17 |
+
from ..models import DataQAAction
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ── Pre-built agent trajectories (simulates baseline agent) ──
|
| 21 |
+
|
| 22 |
+
AGENT_TRAJECTORIES = {
|
| 23 |
+
# Demo trajectories: fixes are ONLY proposed where the correct value
|
| 24 |
+
# is logically inferrable (computable, format conversion, or deducible from context).
|
| 25 |
+
# Ambiguous fixes (any valid salary, any past date) are NOT proposed.
|
| 26 |
+
"easy": [
|
| 27 |
+
{
|
| 28 |
+
"issues": [
|
| 29 |
+
"row:4,col:name,issue:missing_value",
|
| 30 |
+
"row:7,col:salary,issue:wrong_type",
|
| 31 |
+
"row:9,col:salary,issue:out_of_range",
|
| 32 |
+
"row:18,col:start_date,issue:out_of_range",
|
| 33 |
+
"row:3,col:email,issue:format_violation", # FP
|
| 34 |
+
],
|
| 35 |
+
"fixes": [],
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"issues": [
|
| 39 |
+
"row:4,col:name,issue:missing_value",
|
| 40 |
+
"row:7,col:salary,issue:wrong_type",
|
| 41 |
+
"row:9,col:salary,issue:out_of_range",
|
| 42 |
+
"row:21,col:employee_id,issue:duplicate_row",
|
| 43 |
+
"row:15,col:email,issue:inconsistent_value",
|
| 44 |
+
"row:18,col:start_date,issue:out_of_range",
|
| 45 |
+
],
|
| 46 |
+
"fixes": [
|
| 47 |
+
# Inferrable: name "David Kim" deduced from email david.kim@company.com
|
| 48 |
+
"row:4,col:name,fix:David Kim",
|
| 49 |
+
# Inferrable: "seventy-five thousand" is clearly 75000
|
| 50 |
+
"row:7,col:salary,fix:75000",
|
| 51 |
+
# Inferrable: email must match name pattern oscar.rivera@company.com
|
| 52 |
+
"row:15,col:email,fix:oscar.rivera@company.com",
|
| 53 |
+
# NOT proposed: row:9 salary (any valid salary 50000-150000 works)
|
| 54 |
+
# NOT proposed: row:18 start_date (any past date works)
|
| 55 |
+
# NOT proposed: row:21 duplicate (remove or reassign — ambiguous)
|
| 56 |
+
],
|
| 57 |
+
},
|
| 58 |
+
],
|
| 59 |
+
"medium": [
|
| 60 |
+
{
|
| 61 |
+
"issues": [
|
| 62 |
+
"row:5,col:total,issue:inconsistent_value",
|
| 63 |
+
"row:10,col:category,issue:format_violation",
|
| 64 |
+
"row:14,col:product_name,issue:missing_value",
|
| 65 |
+
"row:17,col:quantity,issue:out_of_range",
|
| 66 |
+
"row:19,col:order_id,issue:duplicate_row",
|
| 67 |
+
"row:12,col:order_date,issue:format_violation",
|
| 68 |
+
"row:24,col:shipping_country,issue:format_violation",
|
| 69 |
+
],
|
| 70 |
+
"fixes": [],
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"issues": [
|
| 74 |
+
"row:5,col:total,issue:inconsistent_value",
|
| 75 |
+
"row:10,col:category,issue:format_violation",
|
| 76 |
+
"row:14,col:product_name,issue:missing_value",
|
| 77 |
+
"row:17,col:quantity,issue:out_of_range",
|
| 78 |
+
"row:19,col:order_id,issue:duplicate_row",
|
| 79 |
+
"row:12,col:order_date,issue:format_violation",
|
| 80 |
+
"row:24,col:shipping_country,issue:format_violation",
|
| 81 |
+
"row:29,col:order_date,issue:inconsistent_value",
|
| 82 |
+
],
|
| 83 |
+
"fixes": [
|
| 84 |
+
# Inferrable: total = qty(1) * price(42.00) = 42.00
|
| 85 |
+
"row:5,col:total,fix:42.00",
|
| 86 |
+
# Inferrable: "Fitness" is closest to "Sports" in allowed categories
|
| 87 |
+
"row:10,col:category,fix:Sports",
|
| 88 |
+
# Inferrable: 26/01/2024 reformatted to YYYY-MM-DD
|
| 89 |
+
"row:12,col:order_date,fix:2024-01-26",
|
| 90 |
+
# NOT proposed: row:14 product_name (any product name works)
|
| 91 |
+
# NOT proposed: row:17 quantity (any positive int)
|
| 92 |
+
# NOT proposed: row:19 duplicate order_id (reassign — ambiguous)
|
| 93 |
+
# NOT proposed: row:24 country (could be any valid ISO code)
|
| 94 |
+
# NOT proposed: row:29 future date (any past date works)
|
| 95 |
+
],
|
| 96 |
+
},
|
| 97 |
+
],
|
| 98 |
+
"hard": [
|
| 99 |
+
{
|
| 100 |
+
"issues": [
|
| 101 |
+
"row:14,col:training_time_hours,issue:out_of_range",
|
| 102 |
+
"row:13,col:learning_rate,issue:out_of_range",
|
| 103 |
+
"row:15,col:model_name,issue:missing_value",
|
| 104 |
+
"row:9,col:batch_size,issue:format_violation",
|
| 105 |
+
"row:10,col:train_size,issue:inconsistent_value",
|
| 106 |
+
],
|
| 107 |
+
"fixes": [],
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"issues": [
|
| 111 |
+
"row:14,col:training_time_hours,issue:out_of_range",
|
| 112 |
+
"row:13,col:learning_rate,issue:out_of_range",
|
| 113 |
+
"row:15,col:model_name,issue:missing_value",
|
| 114 |
+
"row:9,col:batch_size,issue:format_violation",
|
| 115 |
+
"row:10,col:train_size,issue:inconsistent_value",
|
| 116 |
+
"row:5,col:val_loss,issue:inconsistent_value",
|
| 117 |
+
"row:7,col:gpu_memory_gb,issue:statistical_outlier",
|
| 118 |
+
"row:11,col:timestamp,issue:inconsistent_value",
|
| 119 |
+
"row:9,col:training_time_hours,issue:statistical_outlier",
|
| 120 |
+
"row:12,col:test_accuracy,issue:statistical_outlier",
|
| 121 |
+
],
|
| 122 |
+
"fixes": [
|
| 123 |
+
# Inferrable: batch_size 250 → nearest power of 2 = 256
|
| 124 |
+
"row:9,col:batch_size,fix:256",
|
| 125 |
+
# Inferrable: negative time -72.0 → absolute value 72.0
|
| 126 |
+
"row:14,col:training_time_hours,fix:72.0",
|
| 127 |
+
# NOT proposed: row:13 LR (any valid LR 1e-7 to 1.0)
|
| 128 |
+
# NOT proposed: row:15 model_name (could be any model)
|
| 129 |
+
# NOT proposed: row:5 val_loss (any val >= train_loss)
|
| 130 |
+
# NOT proposed: row:7 GPU memory (any reasonable value)
|
| 131 |
+
# NOT proposed: row:10 train_size (any value > test_size)
|
| 132 |
+
# NOT proposed: row:11 timestamp (any date after prev)
|
| 133 |
+
# NOT proposed: row:9 training_time (any reasonable hours)
|
| 134 |
+
# NOT proposed: row:12 test_accuracy (any < SOTA)
|
| 135 |
+
],
|
| 136 |
+
},
|
| 137 |
+
],
|
| 138 |
+
"alignment": [
|
| 139 |
+
{
|
| 140 |
+
"issues": [
|
| 141 |
+
"row:6,col:response,issue:inconsistent_value",
|
| 142 |
+
"row:15,col:response,issue:inconsistent_value",
|
| 143 |
+
"row:28,col:prompt,issue:missing_value",
|
| 144 |
+
"row:20,col:response,issue:inconsistent_value",
|
| 145 |
+
"row:7,col:prompt,issue:duplicate_row",
|
| 146 |
+
"row:25,col:response,issue:missing_value",
|
| 147 |
+
"row:3,col:response,issue:inconsistent_value",
|
| 148 |
+
],
|
| 149 |
+
"fixes": [],
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"issues": [
|
| 153 |
+
"row:3,col:response,issue:inconsistent_value",
|
| 154 |
+
"row:4,col:response,issue:inconsistent_value",
|
| 155 |
+
"row:6,col:response,issue:inconsistent_value",
|
| 156 |
+
"row:7,col:prompt,issue:duplicate_row",
|
| 157 |
+
"row:8,col:response,issue:inconsistent_value",
|
| 158 |
+
"row:11,col:response,issue:inconsistent_value",
|
| 159 |
+
"row:15,col:response,issue:inconsistent_value",
|
| 160 |
+
"row:17,col:helpfulness,issue:inconsistent_value",
|
| 161 |
+
"row:20,col:response,issue:inconsistent_value",
|
| 162 |
+
"row:25,col:response,issue:missing_value",
|
| 163 |
+
"row:28,col:prompt,issue:missing_value",
|
| 164 |
+
"row:29,col:response,issue:inconsistent_value",
|
| 165 |
+
],
|
| 166 |
+
"fixes": [
|
| 167 |
+
# Inferrable: Salvator Mundi facts are well-known ($450.3M at Christie's)
|
| 168 |
+
"row:4,col:response,fix:The most expensive painting ever sold at auction is Salvator Mundi by Leonardo da Vinci. It was sold for $450.3 million at Christie's in New York City in 2017.",
|
| 169 |
+
# Inferrable: strip leaked [SYSTEM] prompt prefix
|
| 170 |
+
"row:3,col:response,fix:Kitsch is art or design that is overly sentimental or ornate while camp is a style that is over-the-top and exaggerated often used in satire or irony.",
|
| 171 |
+
# NOT proposed: row:6 wrong scientific name (need taxonomy knowledge)
|
| 172 |
+
# NOT proposed: row:8 harmful advice (need to write safe version)
|
| 173 |
+
# NOT proposed: row:11 self-contradiction (need to rewrite coherently)
|
| 174 |
+
# NOT proposed: row:15 French response (need English translation)
|
| 175 |
+
# NOT proposed: row:29 hallucinated citation (need factual replacement)
|
| 176 |
+
],
|
| 177 |
+
},
|
| 178 |
+
],
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ── HTML rendering ──
|
| 183 |
+
|
| 184 |
+
def _metric_card(label: str, value: str, color: str = "#333") -> str:
|
| 185 |
+
return (
|
| 186 |
+
f'<div style="text-align:center;padding:12px 16px;background:#f8f9fa;'
|
| 187 |
+
f'border-radius:8px;min-width:100px;">'
|
| 188 |
+
f'<div style="font-size:11px;color:#666;text-transform:uppercase;letter-spacing:1px;">{label}</div>'
|
| 189 |
+
f'<div style="font-size:28px;font-weight:700;color:{color};margin-top:2px;">{value}</div>'
|
| 190 |
+
f'</div>'
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _csv_to_html(
|
| 195 |
+
csv_text: str,
|
| 196 |
+
planted: list[PlantedIssue],
|
| 197 |
+
correct: set[tuple[int, str]],
|
| 198 |
+
fp: set[tuple[int, str]],
|
| 199 |
+
missed: set[tuple[int, str]],
|
| 200 |
+
fixed: dict[tuple[int, str], str],
|
| 201 |
+
fix_values: dict[tuple[int, str], str] | None = None,
|
| 202 |
+
) -> str:
|
| 203 |
+
"""Render CSV as HTML with color-coded cells and inline fix proposals."""
|
| 204 |
+
fix_values = fix_values or {}
|
| 205 |
+
desc_map = {(i.row, i.col): i for i in planted}
|
| 206 |
+
reader = csv.reader(io.StringIO(csv_text.strip()))
|
| 207 |
+
rows = list(reader)
|
| 208 |
+
if not rows:
|
| 209 |
+
return ""
|
| 210 |
+
|
| 211 |
+
header = rows[0]
|
| 212 |
+
header_lower = [h.strip().lower() for h in header]
|
| 213 |
+
data = rows[1:]
|
| 214 |
+
|
| 215 |
+
t = ['<table style="border-collapse:collapse;width:100%;font-size:12px;font-family:\'SF Mono\',monospace;">']
|
| 216 |
+
t.append('<tr>')
|
| 217 |
+
t.append('<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">Row</th>')
|
| 218 |
+
for h in header:
|
| 219 |
+
t.append(f'<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">{h}</th>')
|
| 220 |
+
t.append('</tr>')
|
| 221 |
+
|
| 222 |
+
for i, row in enumerate(data):
|
| 223 |
+
rn = i + 1
|
| 224 |
+
bg = "#fff" if i % 2 == 0 else "#f8f9fa"
|
| 225 |
+
t.append(f'<tr style="background:{bg};">')
|
| 226 |
+
t.append(f'<td style="border:1px solid #dee2e6;padding:4px 8px;color:#adb5bd;text-align:center;font-size:11px;">{rn}</td>')
|
| 227 |
+
for j, val in enumerate(row):
|
| 228 |
+
col = header_lower[j] if j < len(header_lower) else ""
|
| 229 |
+
ck = (rn, col)
|
| 230 |
+
s = "border:1px solid #dee2e6;padding:4px 8px;"
|
| 231 |
+
tip = ""
|
| 232 |
+
badge = ""
|
| 233 |
+
|
| 234 |
+
issue = desc_map.get(ck)
|
| 235 |
+
|
| 236 |
+
if ck in correct:
|
| 237 |
+
s += "background:#d4edda;"
|
| 238 |
+
tip = f"FOUND: {issue.description}" if issue else ""
|
| 239 |
+
badge = '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">TP</span>'
|
| 240 |
+
elif ck in fp:
|
| 241 |
+
s += "background:#f8d7da;"
|
| 242 |
+
badge = '<span style="font-size:9px;background:#dc3545;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">FP</span>'
|
| 243 |
+
elif ck in missed:
|
| 244 |
+
s += "background:#fff3cd;"
|
| 245 |
+
tip = f"MISSED: {issue.description}" if issue else ""
|
| 246 |
+
badge = '<span style="font-size:9px;background:#856404;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">MISS</span>'
|
| 247 |
+
|
| 248 |
+
fx = fixed.get(ck)
|
| 249 |
+
proposed = fix_values.get(ck)
|
| 250 |
+
if fx == "correct":
|
| 251 |
+
s += "box-shadow:inset 0 0 0 2px #28a745;"
|
| 252 |
+
badge += '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:2px;">FIX</span>'
|
| 253 |
+
elif fx == "partial":
|
| 254 |
+
s += "box-shadow:inset 0 0 0 2px #ffc107;"
|
| 255 |
+
badge += '<span style="font-size:9px;background:#ffc107;color:#333;padding:1px 4px;border-radius:3px;margin-left:2px;">~FIX</span>'
|
| 256 |
+
|
| 257 |
+
dv = val if val.strip() else '<em style="color:#dc3545;font-style:italic;">empty</em>'
|
| 258 |
+
|
| 259 |
+
# Show proposed fix value below the corrupted value
|
| 260 |
+
fix_line = ""
|
| 261 |
+
if proposed is not None:
|
| 262 |
+
fix_color = "#28a745" if fx == "correct" else ("#b8860b" if fx == "partial" else "#dc3545")
|
| 263 |
+
fix_line = (
|
| 264 |
+
f'<div style="font-size:10px;color:{fix_color};margin-top:2px;'
|
| 265 |
+
f'border-top:1px dashed {fix_color};padding-top:2px;">'
|
| 266 |
+
f'\u2192 {proposed}</div>'
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
t.append(f'<td style="{s}" title="{tip}">{dv}{badge}{fix_line}</td>')
|
| 270 |
+
t.append('</tr>')
|
| 271 |
+
t.append('</table>')
|
| 272 |
+
return "".join(t)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
LEGEND_HTML = (
|
| 276 |
+
'<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:10px;font-size:11px;">'
|
| 277 |
+
'<span style="background:#d4edda;padding:2px 8px;border-radius:4px;">Found (TP)</span>'
|
| 278 |
+
'<span style="background:#f8d7da;padding:2px 8px;border-radius:4px;">False Positive</span>'
|
| 279 |
+
'<span style="background:#fff3cd;padding:2px 8px;border-radius:4px;">Missed</span>'
|
| 280 |
+
'<span style="box-shadow:inset 0 0 0 2px #28a745;padding:2px 8px;border-radius:4px;">Fix Correct</span>'
|
| 281 |
+
'<span style="box-shadow:inset 0 0 0 2px #ffc107;padding:2px 8px;border-radius:4px;">Fix Partial</span>'
|
| 282 |
+
'</div>'
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# ── Core replay logic ──
|
| 287 |
+
|
| 288 |
+
def _replay_task(task_id: str) -> list[dict]:
|
| 289 |
+
"""Run the agent trajectory and collect per-step data."""
|
| 290 |
+
env = DataQAEnvironment()
|
| 291 |
+
obs = env.reset(task_id=task_id)
|
| 292 |
+
task = env._current_task
|
| 293 |
+
planted_keys = {i.to_key() for i in task.planted_issues}
|
| 294 |
+
steps_data = []
|
| 295 |
+
|
| 296 |
+
# Step 0: initial state
|
| 297 |
+
steps_data.append({
|
| 298 |
+
"label": "Initial — corrupted dataset",
|
| 299 |
+
"html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
|
| 300 |
+
"metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
|
| 301 |
+
"identify": 0.0, "fix": 0.0, "fixes_correct": 0},
|
| 302 |
+
"feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
|
| 303 |
+
})
|
| 304 |
+
|
| 305 |
+
trajectory = AGENT_TRAJECTORIES.get(task_id, [])
|
| 306 |
+
for i, step_data in enumerate(trajectory):
|
| 307 |
+
action = DataQAAction(
|
| 308 |
+
issues=step_data["issues"],
|
| 309 |
+
fixes=step_data.get("fixes", []),
|
| 310 |
+
task_id=task_id,
|
| 311 |
+
)
|
| 312 |
+
obs = env.step(action)
|
| 313 |
+
|
| 314 |
+
reported_keys = set()
|
| 315 |
+
for iss in step_data["issues"]:
|
| 316 |
+
key = parse_issue_key(iss)
|
| 317 |
+
if key:
|
| 318 |
+
reported_keys.add(key)
|
| 319 |
+
|
| 320 |
+
tp_keys = reported_keys & planted_keys
|
| 321 |
+
fp_keys = reported_keys - planted_keys
|
| 322 |
+
fn_keys = planted_keys - reported_keys
|
| 323 |
+
|
| 324 |
+
correct = {_kc(k) for k in tp_keys}
|
| 325 |
+
fp = {_kc(k) for k in fp_keys}
|
| 326 |
+
missed = {_kc(k) for k in fn_keys} if obs.done else set()
|
| 327 |
+
|
| 328 |
+
fixed: dict[tuple[int, str], str] = {}
|
| 329 |
+
for d in obs.metadata.get("fix_details", []):
|
| 330 |
+
c = (d["row"], d["col"])
|
| 331 |
+
fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
|
| 332 |
+
|
| 333 |
+
# Extract proposed fix values from the raw fix strings
|
| 334 |
+
fix_values: dict[tuple[int, str], str] = {}
|
| 335 |
+
from .environment import parse_fix
|
| 336 |
+
for raw_fix in step_data.get("fixes", []):
|
| 337 |
+
parsed = parse_fix(raw_fix)
|
| 338 |
+
if parsed:
|
| 339 |
+
row, col, val = parsed
|
| 340 |
+
fix_values[(row, col)] = val
|
| 341 |
+
|
| 342 |
+
html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values)
|
| 343 |
+
|
| 344 |
+
has_fixes = bool(step_data.get("fixes"))
|
| 345 |
+
if has_fixes:
|
| 346 |
+
label = f"Step {i+1} — identify + fix"
|
| 347 |
+
else:
|
| 348 |
+
label = f"Step {i+1} — identify only"
|
| 349 |
+
|
| 350 |
+
steps_data.append({
|
| 351 |
+
"label": label,
|
| 352 |
+
"html": html,
|
| 353 |
+
"metrics": {
|
| 354 |
+
"reward": obs.reward,
|
| 355 |
+
"tp": obs.metadata["tp"],
|
| 356 |
+
"fp": obs.metadata["fp"],
|
| 357 |
+
"fn": obs.metadata["fn"],
|
| 358 |
+
"identify": obs.metadata["identify_score"],
|
| 359 |
+
"fix": obs.metadata["fix_score"],
|
| 360 |
+
"fixes_correct": obs.metadata["fixes_correct"],
|
| 361 |
+
},
|
| 362 |
+
"feedback": obs.feedback,
|
| 363 |
+
})
|
| 364 |
+
|
| 365 |
+
return steps_data
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _kc(key: str) -> tuple[int, str]:
|
| 369 |
+
parts = key.split(",")
|
| 370 |
+
return (int(parts[0].split(":")[1]), parts[1].split(":")[1])
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
# ── Gradio app ──
|
| 374 |
+
|
| 375 |
+
def build_gradio_ui():
|
| 376 |
+
# Pre-compute all replays at startup
|
| 377 |
+
all_replays: dict[str, list[dict]] = {}
|
| 378 |
+
for tid in list_tasks():
|
| 379 |
+
all_replays[tid] = _replay_task(tid)
|
| 380 |
+
|
| 381 |
+
def show_step(task_id: str, step_idx: int):
|
| 382 |
+
replay = all_replays.get(task_id, [])
|
| 383 |
+
step_idx = int(step_idx)
|
| 384 |
+
if step_idx >= len(replay):
|
| 385 |
+
step_idx = len(replay) - 1
|
| 386 |
+
sd = replay[step_idx]
|
| 387 |
+
m = sd["metrics"]
|
| 388 |
+
|
| 389 |
+
# Reward color
|
| 390 |
+
r = m["reward"]
|
| 391 |
+
rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
|
| 392 |
+
|
| 393 |
+
cards = (
|
| 394 |
+
'<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
|
| 395 |
+
+ _metric_card("Reward", f"{r:.2f}", rc)
|
| 396 |
+
+ _metric_card("Found", str(m["tp"]), "#28a745")
|
| 397 |
+
+ _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
|
| 398 |
+
+ _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
|
| 399 |
+
+ _metric_card("Identify", f"{m['identify']:.2f}", "#333")
|
| 400 |
+
+ _metric_card("Fix", f"{m['fix']:.2f}", "#333")
|
| 401 |
+
+ '</div>'
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
full_html = (
|
| 405 |
+
f'<div style="font-size:14px;font-weight:600;margin-bottom:8px;color:#495057;">'
|
| 406 |
+
f'{sd["label"]}</div>'
|
| 407 |
+
+ cards + sd["html"] + LEGEND_HTML
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
return full_html, sd["feedback"]
|
| 411 |
+
|
| 412 |
+
def on_task_change(task_id):
|
| 413 |
+
replay = all_replays.get(task_id, [])
|
| 414 |
+
max_step = len(replay) - 1
|
| 415 |
+
html, fb = show_step(task_id, 0)
|
| 416 |
+
return (
|
| 417 |
+
gr.update(maximum=max_step, value=0),
|
| 418 |
+
html,
|
| 419 |
+
fb,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
def on_step_change(task_id, step_idx):
|
| 423 |
+
html, fb = show_step(task_id, step_idx)
|
| 424 |
+
return html, fb
|
| 425 |
+
|
| 426 |
+
# ── Live agent runner (connects to the env server) ──
|
| 427 |
+
|
| 428 |
+
live_env = DataQAEnvironment()
|
| 429 |
+
live_state: dict = {"obs": None, "task_id": "easy", "steps": []}
|
| 430 |
+
|
| 431 |
+
def live_reset(task_id):
|
| 432 |
+
obs = live_env.reset(task_id=task_id)
|
| 433 |
+
task = live_env._current_task
|
| 434 |
+
live_state["obs"] = obs
|
| 435 |
+
live_state["task_id"] = task_id
|
| 436 |
+
live_state["steps"] = []
|
| 437 |
+
html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
|
| 438 |
+
info = f"**{task.name}** — {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
|
| 439 |
+
return html, info, "", "0.000"
|
| 440 |
+
|
| 441 |
+
def live_step(issues_text, fixes_text):
|
| 442 |
+
if live_state["obs"] is None:
|
| 443 |
+
return "Reset first.", "", "", ""
|
| 444 |
+
obs = live_state["obs"]
|
| 445 |
+
task = live_env._current_task
|
| 446 |
+
planted_keys = {i.to_key() for i in task.planted_issues}
|
| 447 |
+
|
| 448 |
+
issues = [l.strip() for l in issues_text.strip().split("\n") if l.strip()]
|
| 449 |
+
fixes = [l.strip() for l in fixes_text.strip().split("\n") if l.strip()] if fixes_text.strip() else []
|
| 450 |
+
|
| 451 |
+
action = DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"])
|
| 452 |
+
obs = live_env.step(action)
|
| 453 |
+
live_state["obs"] = obs
|
| 454 |
+
|
| 455 |
+
reported_keys = set()
|
| 456 |
+
for iss in issues:
|
| 457 |
+
key = parse_issue_key(iss)
|
| 458 |
+
if key:
|
| 459 |
+
reported_keys.add(key)
|
| 460 |
+
|
| 461 |
+
tp_keys = reported_keys & planted_keys
|
| 462 |
+
fp_keys = reported_keys - planted_keys
|
| 463 |
+
fn_keys = planted_keys - reported_keys
|
| 464 |
+
|
| 465 |
+
correct = {_kc(k) for k in tp_keys}
|
| 466 |
+
fp_set = {_kc(k) for k in fp_keys}
|
| 467 |
+
missed = {_kc(k) for k in fn_keys} if obs.done else set()
|
| 468 |
+
|
| 469 |
+
fixed: dict[tuple[int, str], str] = {}
|
| 470 |
+
for d in obs.metadata.get("fix_details", []):
|
| 471 |
+
c = (d["row"], d["col"])
|
| 472 |
+
fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
|
| 473 |
+
|
| 474 |
+
from .environment import parse_fix
|
| 475 |
+
fix_values: dict[tuple[int, str], str] = {}
|
| 476 |
+
for raw in fixes:
|
| 477 |
+
parsed = parse_fix(raw)
|
| 478 |
+
if parsed:
|
| 479 |
+
fix_values[(parsed[0], parsed[1])] = parsed[2]
|
| 480 |
+
|
| 481 |
+
html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)
|
| 482 |
+
|
| 483 |
+
m = obs.metadata
|
| 484 |
+
r = obs.reward
|
| 485 |
+
rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
|
| 486 |
+
cards = (
|
| 487 |
+
'<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
|
| 488 |
+
+ _metric_card("Reward", f"{r:.2f}", rc)
|
| 489 |
+
+ _metric_card("Found", str(m["tp"]), "#28a745")
|
| 490 |
+
+ _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
|
| 491 |
+
+ _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
|
| 492 |
+
+ '</div>'
|
| 493 |
+
)
|
| 494 |
+
full_html = cards + html + LEGEND_HTML
|
| 495 |
+
return full_html, obs.feedback, f"{r:.3f}", ""
|
| 496 |
+
|
| 497 |
+
# ── Build the UI ──
|
| 498 |
+
|
| 499 |
+
with gr.Blocks(title="DataQA Environment") as demo:
|
| 500 |
+
gr.Markdown(
|
| 501 |
+
"# DataQA — Data Quality Assurance Environment\n"
|
| 502 |
+
"Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
with gr.Tabs():
|
| 506 |
+
# ── Tab 1: Demo replay ──
|
| 507 |
+
with gr.Tab("Demo (Baseline Agent)"):
|
| 508 |
+
gr.Markdown(
|
| 509 |
+
"*Replay of the baseline Qwen-72B agent. "
|
| 510 |
+
"Use the slider to step through the agent's trajectory.*"
|
| 511 |
+
)
|
| 512 |
+
with gr.Row():
|
| 513 |
+
task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
|
| 514 |
+
step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)
|
| 515 |
+
|
| 516 |
+
viz_html = gr.HTML()
|
| 517 |
+
feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)
|
| 518 |
+
|
| 519 |
+
task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
|
| 520 |
+
step_slider.change(on_step_change, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
|
| 521 |
+
demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
|
| 522 |
+
|
| 523 |
+
# ── Tab 2: Try your own agent ──
|
| 524 |
+
with gr.Tab("Try Your Own Agent"):
|
| 525 |
+
gr.Markdown(
|
| 526 |
+
"*Submit your own issues and fixes to see how the environment scores them. "
|
| 527 |
+
"This is the same environment the baseline agent talks to.*"
|
| 528 |
+
)
|
| 529 |
+
with gr.Row():
|
| 530 |
+
live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
|
| 531 |
+
live_reset_btn = gr.Button("Reset", variant="primary", scale=1)
|
| 532 |
+
|
| 533 |
+
with gr.Row():
|
| 534 |
+
live_info = gr.Markdown()
|
| 535 |
+
live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)
|
| 536 |
+
|
| 537 |
+
live_viz = gr.HTML()
|
| 538 |
+
|
| 539 |
+
with gr.Row():
|
| 540 |
+
live_issues = gr.Textbox(
|
| 541 |
+
label="Issues (one per line)",
|
| 542 |
+
placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
|
| 543 |
+
lines=5,
|
| 544 |
+
)
|
| 545 |
+
live_fixes = gr.Textbox(
|
| 546 |
+
label="Fixes (one per line, optional)",
|
| 547 |
+
placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
|
| 548 |
+
lines=5,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
live_step_btn = gr.Button("Submit Step", variant="primary")
|
| 552 |
+
live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)
|
| 553 |
+
|
| 554 |
+
live_reset_btn.click(
|
| 555 |
+
live_reset, inputs=[live_task_dd],
|
| 556 |
+
outputs=[live_viz, live_info, live_feedback, live_reward],
|
| 557 |
+
)
|
| 558 |
+
live_step_btn.click(
|
| 559 |
+
live_step, inputs=[live_issues, live_fixes],
|
| 560 |
+
outputs=[live_viz, live_feedback, live_reward, live_issues],
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
return demo
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
if __name__ == "__main__":
|
| 567 |
+
demo = build_gradio_ui()
|
| 568 |
+
demo.launch()
|
dataqa_env/server/tasks.py
ADDED
|
@@ -0,0 +1,1159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task definitions for the DataQA environment.
|
| 3 |
+
|
| 4 |
+
Each task provides:
|
| 5 |
+
- A clean dataset (CSV)
|
| 6 |
+
- A schema + validation rules
|
| 7 |
+
- A set of planted issues (ground truth)
|
| 8 |
+
- A function to inject those issues into the clean data
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import csv
|
| 14 |
+
import io
|
| 15 |
+
import random
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import List, Set
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class PlantedIssue:
|
| 22 |
+
"""A single planted data quality issue."""
|
| 23 |
+
|
| 24 |
+
row: int
|
| 25 |
+
col: str
|
| 26 |
+
issue_type: str
|
| 27 |
+
description: str
|
| 28 |
+
difficulty: float = 1.0 # 1.0=easy, 2.0=medium, 3.0=hard (for weighted reward)
|
| 29 |
+
|
| 30 |
+
def to_key(self) -> str:
|
| 31 |
+
return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class Task:
|
| 36 |
+
task_id: str
|
| 37 |
+
name: str
|
| 38 |
+
description: str
|
| 39 |
+
schema_description: str
|
| 40 |
+
validation_rules: str
|
| 41 |
+
clean_csv: str
|
| 42 |
+
planted_issues: List[PlantedIssue] = field(default_factory=list)
|
| 43 |
+
corrupted_csv: str = ""
|
| 44 |
+
max_steps: int = 3
|
| 45 |
+
|
| 46 |
+
def get_clean_value(self, row: int, col: str) -> str | None:
|
| 47 |
+
"""
|
| 48 |
+
Look up the original clean value for a given (row, col).
|
| 49 |
+
Row is 1-indexed (data row after header).
|
| 50 |
+
Returns None if row/col is out of bounds or column not found.
|
| 51 |
+
"""
|
| 52 |
+
rows = _csv_to_rows(self.clean_csv)
|
| 53 |
+
if len(rows) < 2:
|
| 54 |
+
return None
|
| 55 |
+
header = [h.strip().lower() for h in rows[0]]
|
| 56 |
+
if col.lower() not in header:
|
| 57 |
+
return None
|
| 58 |
+
col_idx = header.index(col.lower())
|
| 59 |
+
data_row_idx = row # row is 1-indexed, rows[0] is header, so rows[row] is the data row
|
| 60 |
+
if data_row_idx < 1 or data_row_idx >= len(rows):
|
| 61 |
+
return None
|
| 62 |
+
return rows[data_row_idx][col_idx].strip()
|
| 63 |
+
|
| 64 |
+
def get_planted_issue_map(self) -> dict:
|
| 65 |
+
"""Return dict mapping issue key -> PlantedIssue for quick lookups."""
|
| 66 |
+
return {issue.to_key(): issue for issue in self.planted_issues}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _csv_to_rows(csv_text: str) -> List[List[str]]:
|
| 70 |
+
reader = csv.reader(io.StringIO(csv_text.strip()))
|
| 71 |
+
return [row for row in reader]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _rows_to_csv(rows: List[List[str]]) -> str:
|
| 75 |
+
output = io.StringIO()
|
| 76 |
+
writer = csv.writer(output)
|
| 77 |
+
writer.writerows(rows)
|
| 78 |
+
return output.getvalue()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# TASK 1: Easy — Employee directory with obvious issues
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
def create_task_easy(seed: int = 42) -> Task:
|
| 86 |
+
rng = random.Random(seed)
|
| 87 |
+
|
| 88 |
+
clean_csv = """employee_id,name,email,department,salary,start_date
|
| 89 |
+
101,Alice Chen,alice.chen@company.com,Engineering,95000,2022-03-15
|
| 90 |
+
102,Bob Martinez,bob.martinez@company.com,Marketing,72000,2021-07-01
|
| 91 |
+
103,Carol Davis,carol.davis@company.com,Engineering,98000,2020-11-20
|
| 92 |
+
104,David Kim,david.kim@company.com,Sales,68000,2023-01-10
|
| 93 |
+
105,Eve Johnson,eve.johnson@company.com,HR,71000,2022-06-05
|
| 94 |
+
106,Frank Wilson,frank.wilson@company.com,Engineering,102000,2019-08-12
|
| 95 |
+
107,Grace Lee,grace.lee@company.com,Marketing,75000,2021-12-01
|
| 96 |
+
108,Hank Brown,hank.brown@company.com,Sales,65000,2023-04-18
|
| 97 |
+
109,Iris Patel,iris.patel@company.com,HR,73000,2020-02-28
|
| 98 |
+
110,Jack Taylor,jack.taylor@company.com,Engineering,97000,2022-09-14
|
| 99 |
+
111,Kevin Zhang,kevin.zhang@company.com,Engineering,91000,2021-05-22
|
| 100 |
+
112,Laura Adams,laura.adams@company.com,Sales,69000,2022-11-03
|
| 101 |
+
113,Mike Torres,mike.torres@company.com,Marketing,74000,2020-08-17
|
| 102 |
+
114,Nina Sharma,nina.sharma@company.com,HR,76000,2019-04-30
|
| 103 |
+
115,Oscar Rivera,oscar.rivera@company.com,Engineering,105000,2018-12-10
|
| 104 |
+
116,Paula Green,paula.green@company.com,Sales,67000,2023-06-25
|
| 105 |
+
117,Quinn Murphy,quinn.murphy@company.com,Marketing,78000,2021-03-08
|
| 106 |
+
118,Rosa Diaz,rosa.diaz@company.com,Engineering,99000,2022-01-19
|
| 107 |
+
119,Sam Cooper,sam.cooper@company.com,HR,70000,2020-10-05
|
| 108 |
+
120,Tara Singh,tara.singh@company.com,Sales,66000,2023-02-14"""
|
| 109 |
+
|
| 110 |
+
schema_desc = """Columns:
|
| 111 |
+
- employee_id: integer, unique, range 100-999
|
| 112 |
+
- name: string, non-empty, format "FirstName LastName"
|
| 113 |
+
- email: string, valid email format, must match pattern firstname.lastname@company.com
|
| 114 |
+
- department: string, one of [Engineering, Marketing, Sales, HR]
|
| 115 |
+
- salary: integer, range 50000-150000
|
| 116 |
+
- start_date: string, format YYYY-MM-DD, must be between 2015-01-01 and 2025-12-31"""
|
| 117 |
+
|
| 118 |
+
rules = """1. No missing values in any column
|
| 119 |
+
2. employee_id must be unique
|
| 120 |
+
3. email must follow the pattern: lowercase(firstname).lowercase(lastname)@company.com
|
| 121 |
+
4. salary must be within the valid range
|
| 122 |
+
5. No duplicate rows"""
|
| 123 |
+
|
| 124 |
+
rows = _csv_to_rows(clean_csv)
|
| 125 |
+
header = rows[0]
|
| 126 |
+
data = rows[1:]
|
| 127 |
+
issues: List[PlantedIssue] = []
|
| 128 |
+
|
| 129 |
+
# Issue 1: Missing value - null out a name (easy to spot)
|
| 130 |
+
r = 3 # row index in data (0-based), displayed as row 4 in CSV
|
| 131 |
+
data[r][1] = ""
|
| 132 |
+
issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
|
| 133 |
+
description="Empty name field", difficulty=1.0))
|
| 134 |
+
|
| 135 |
+
# Issue 2: Wrong type - salary as text (easy to spot)
|
| 136 |
+
r = 6
|
| 137 |
+
data[r][4] = "seventy-five thousand"
|
| 138 |
+
issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
|
| 139 |
+
description="Salary is text instead of integer", difficulty=1.0))
|
| 140 |
+
|
| 141 |
+
# Issue 3: Duplicate row (moderate — requires cross-row comparison)
|
| 142 |
+
dup_source = 1
|
| 143 |
+
data.append(list(data[dup_source]))
|
| 144 |
+
issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
|
| 145 |
+
description=f"Exact duplicate of row {dup_source + 1}", difficulty=1.5))
|
| 146 |
+
|
| 147 |
+
# Issue 4: Out of range salary (easy to spot)
|
| 148 |
+
r = 8
|
| 149 |
+
data[r][4] = "5000"
|
| 150 |
+
issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
|
| 151 |
+
description="Salary 5000 is below minimum 50000", difficulty=1.0))
|
| 152 |
+
|
| 153 |
+
# Issue 5: Email doesn't match name pattern (moderate — cross-column check)
|
| 154 |
+
r = 14 # Oscar Rivera -> email should be oscar.rivera@company.com
|
| 155 |
+
data[r][2] = "john.doe@company.com"
|
| 156 |
+
issues.append(PlantedIssue(row=r + 1, col="email", issue_type="inconsistent_value",
|
| 157 |
+
description="Email john.doe@company.com doesn't match name Oscar Rivera",
|
| 158 |
+
difficulty=1.5))
|
| 159 |
+
|
| 160 |
+
# Issue 6: Future start date (requires knowing current date context)
|
| 161 |
+
r = 17 # Rosa Diaz
|
| 162 |
+
data[r][5] = "2027-06-15"
|
| 163 |
+
issues.append(PlantedIssue(row=r + 1, col="start_date", issue_type="out_of_range",
|
| 164 |
+
description="Start date 2027-06-15 is in the future (beyond 2025-12-31)",
|
| 165 |
+
difficulty=1.5))
|
| 166 |
+
|
| 167 |
+
corrupted = _rows_to_csv([header] + data)
|
| 168 |
+
|
| 169 |
+
return Task(
|
| 170 |
+
task_id="easy",
|
| 171 |
+
name="Employee Directory Validation",
|
| 172 |
+
description=(
|
| 173 |
+
"You are given an employee directory dataset. "
|
| 174 |
+
"Find all data quality issues based on the schema and validation rules. "
|
| 175 |
+
"Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
|
| 176 |
+
),
|
| 177 |
+
schema_description=schema_desc,
|
| 178 |
+
validation_rules=rules,
|
| 179 |
+
clean_csv=clean_csv,
|
| 180 |
+
planted_issues=issues,
|
| 181 |
+
corrupted_csv=corrupted,
|
| 182 |
+
max_steps=3,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ---------------------------------------------------------------------------
|
| 187 |
+
# TASK 2: Medium — E-commerce orders with moderate issues
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
|
| 190 |
+
def create_task_medium(seed: int = 42) -> Task:
|
| 191 |
+
rng = random.Random(seed)
|
| 192 |
+
|
| 193 |
+
clean_csv = """order_id,customer_id,product_name,category,quantity,unit_price,order_date,shipping_country,status,total
|
| 194 |
+
ORD-001,CUST-100,Wireless Mouse,Electronics,2,29.99,2024-01-15,US,delivered,59.98
|
| 195 |
+
ORD-002,CUST-101,Python Cookbook,Books,1,45.50,2024-01-16,UK,delivered,45.50
|
| 196 |
+
ORD-003,CUST-102,USB-C Hub,Electronics,1,35.00,2024-01-17,US,shipped,35.00
|
| 197 |
+
ORD-004,CUST-103,Yoga Mat,Sports,1,25.99,2024-01-18,CA,delivered,25.99
|
| 198 |
+
ORD-005,CUST-104,Desk Lamp,Home,1,42.00,2024-01-19,US,processing,42.00
|
| 199 |
+
ORD-006,CUST-105,Running Shoes,Sports,1,89.99,2024-01-20,DE,delivered,89.99
|
| 200 |
+
ORD-007,CUST-106,Mechanical Keyboard,Electronics,1,129.99,2024-01-21,US,shipped,129.99
|
| 201 |
+
ORD-008,CUST-100,Monitor Stand,Home,1,55.00,2024-01-22,US,delivered,55.00
|
| 202 |
+
ORD-009,CUST-107,Data Science Handbook,Books,2,39.99,2024-01-23,UK,delivered,79.98
|
| 203 |
+
ORD-010,CUST-108,Resistance Bands,Sports,3,12.99,2024-01-24,CA,shipped,38.97
|
| 204 |
+
ORD-011,CUST-109,Webcam HD,Electronics,1,65.00,2024-01-25,US,delivered,65.00
|
| 205 |
+
ORD-012,CUST-110,Standing Desk,Home,1,299.99,2024-01-26,US,processing,299.99
|
| 206 |
+
ORD-013,CUST-111,Tennis Racket,Sports,1,75.00,2024-01-27,AU,delivered,75.00
|
| 207 |
+
ORD-014,CUST-112,LED Strip Lights,Home,2,18.50,2024-01-28,US,shipped,37.00
|
| 208 |
+
ORD-015,CUST-113,AI Textbook,Books,1,59.99,2024-01-29,DE,delivered,59.99
|
| 209 |
+
ORD-016,CUST-114,Bluetooth Speaker,Electronics,1,49.99,2024-01-30,UK,delivered,49.99
|
| 210 |
+
ORD-017,CUST-115,Jump Rope,Sports,2,8.99,2024-01-31,US,shipped,17.98
|
| 211 |
+
ORD-018,CUST-116,Coffee Table Book,Books,1,32.00,2024-02-01,CA,delivered,32.00
|
| 212 |
+
ORD-019,CUST-117,Ergonomic Chair,Home,1,450.00,2024-02-02,US,processing,450.00
|
| 213 |
+
ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.99
|
| 214 |
+
ORD-021,CUST-119,Laptop Sleeve,Electronics,1,24.99,2024-02-04,US,delivered,24.99
|
| 215 |
+
ORD-022,CUST-120,Hiking Backpack,Sports,1,65.00,2024-02-05,CA,shipped,65.00
|
| 216 |
+
ORD-023,CUST-121,Machine Learning Book,Books,1,54.99,2024-02-06,UK,delivered,54.99
|
| 217 |
+
ORD-024,CUST-122,Plant Pot Set,Home,3,15.00,2024-02-07,US,delivered,45.00
|
| 218 |
+
ORD-025,CUST-123,Noise Cancelling Headphones,Electronics,1,199.99,2024-02-08,DE,shipped,199.99
|
| 219 |
+
ORD-026,CUST-124,Basketball,Sports,1,29.99,2024-02-09,US,delivered,29.99
|
| 220 |
+
ORD-027,CUST-125,Cookbook Collection,Books,2,22.50,2024-02-10,AU,delivered,45.00
|
| 221 |
+
ORD-028,CUST-126,Smart Plug,Home,4,12.99,2024-02-11,US,processing,51.96
|
| 222 |
+
ORD-029,CUST-127,Wireless Charger,Electronics,1,34.99,2024-02-12,UK,delivered,34.99
|
| 223 |
+
ORD-030,CUST-128,Dumbbells Set,Sports,1,89.00,2024-02-13,US,shipped,89.00"""
|
| 224 |
+
|
| 225 |
+
schema_desc = """Columns:
|
| 226 |
+
- order_id: string, unique, format ORD-NNN
|
| 227 |
+
- customer_id: string, format CUST-NNN
|
| 228 |
+
- product_name: string, non-empty
|
| 229 |
+
- category: string, one of [Electronics, Books, Sports, Home]
|
| 230 |
+
- quantity: integer, range 1-100
|
| 231 |
+
- unit_price: float, range 0.01-10000.00
|
| 232 |
+
- order_date: string, format YYYY-MM-DD
|
| 233 |
+
- shipping_country: string, ISO 2-letter country code
|
| 234 |
+
- status: string, one of [processing, shipped, delivered, cancelled, returned]
|
| 235 |
+
- total: float, must equal quantity * unit_price"""
|
| 236 |
+
|
| 237 |
+
rules = """1. No missing values in any column
|
| 238 |
+
2. order_id must be unique
|
| 239 |
+
3. total must equal quantity * unit_price (tolerance: 0.01)
|
| 240 |
+
4. order_date must be in valid chronological order for sequential order_ids
|
| 241 |
+
5. category must be from the allowed set
|
| 242 |
+
6. All monetary values must have at most 2 decimal places
|
| 243 |
+
7. shipping_country must be a valid ISO 2-letter code"""
|
| 244 |
+
|
| 245 |
+
rows = _csv_to_rows(clean_csv)
|
| 246 |
+
header = rows[0]
|
| 247 |
+
data = rows[1:]
|
| 248 |
+
issues: List[PlantedIssue] = []
|
| 249 |
+
|
| 250 |
+
# Issue 1: total doesn't match quantity * unit_price (requires cross-column check)
|
| 251 |
+
r = 4 # ORD-005
|
| 252 |
+
data[r][9] = "84.00" # should be 42.00 (qty=1, price=42.00)
|
| 253 |
+
issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
|
| 254 |
+
description="total (84.00) != quantity (1) * unit_price (42.00)", difficulty=2.0))
|
| 255 |
+
|
| 256 |
+
# Issue 2: Invalid category (requires knowing the allowed set)
|
| 257 |
+
r = 9 # ORD-010
|
| 258 |
+
data[r][3] = "Fitness" # should be Sports
|
| 259 |
+
issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
|
| 260 |
+
description="'Fitness' is not in allowed categories", difficulty=1.5))
|
| 261 |
+
|
| 262 |
+
# Issue 3: Missing value in product_name (easy to spot)
|
| 263 |
+
r = 13 # ORD-014
|
| 264 |
+
data[r][2] = ""
|
| 265 |
+
issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
|
| 266 |
+
description="Empty product_name", difficulty=1.0))
|
| 267 |
+
|
| 268 |
+
# Issue 4: Out of range quantity (easy to spot)
|
| 269 |
+
r = 16 # ORD-017
|
| 270 |
+
data[r][4] = "-1"
|
| 271 |
+
issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
|
| 272 |
+
description="Negative quantity", difficulty=1.0))
|
| 273 |
+
|
| 274 |
+
# Issue 5: Duplicate order_id (requires cross-row comparison)
|
| 275 |
+
r = 18 # ORD-019
|
| 276 |
+
data[r][0] = "ORD-003"
|
| 277 |
+
issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
|
| 278 |
+
description="Duplicate order_id ORD-003", difficulty=1.5))
|
| 279 |
+
|
| 280 |
+
# Issue 6: Wrong date format (moderate — format mismatch)
|
| 281 |
+
r = 11 # ORD-012
|
| 282 |
+
data[r][6] = "26/01/2024"
|
| 283 |
+
issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
|
| 284 |
+
description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
|
| 285 |
+
|
| 286 |
+
# Issue 7: Invalid country code (requires ISO knowledge)
|
| 287 |
+
r = 23 # ORD-024
|
| 288 |
+
data[r][7] = "XX" # not a valid ISO country code
|
| 289 |
+
issues.append(PlantedIssue(row=r + 1, col="shipping_country", issue_type="format_violation",
|
| 290 |
+
description="'XX' is not a valid ISO 2-letter country code", difficulty=1.5))
|
| 291 |
+
|
| 292 |
+
# Issue 8: Status-date inconsistency — order from Feb 13 still "processing" is suspicious
|
| 293 |
+
# but more importantly: delivered order with a future date
|
| 294 |
+
r = 28 # ORD-029
|
| 295 |
+
data[r][6] = "2025-12-25" # future date but status is "delivered"
|
| 296 |
+
issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="inconsistent_value",
|
| 297 |
+
description="Order date 2025-12-25 is in the future but status is 'delivered'",
|
| 298 |
+
difficulty=2.0))
|
| 299 |
+
|
| 300 |
+
corrupted = _rows_to_csv([header] + data)
|
| 301 |
+
|
| 302 |
+
return Task(
|
| 303 |
+
task_id="medium",
|
| 304 |
+
name="E-commerce Orders Validation",
|
| 305 |
+
description=(
|
| 306 |
+
"You are given an e-commerce orders dataset. "
|
| 307 |
+
"Find all data quality issues based on the schema and validation rules. "
|
| 308 |
+
"Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
|
| 309 |
+
),
|
| 310 |
+
schema_description=schema_desc,
|
| 311 |
+
validation_rules=rules,
|
| 312 |
+
clean_csv=clean_csv,
|
| 313 |
+
planted_issues=issues,
|
| 314 |
+
corrupted_csv=corrupted,
|
| 315 |
+
max_steps=3,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
+
# TASK 3: Hard — ML training metadata with subtle issues
|
| 321 |
+
# ---------------------------------------------------------------------------
|
| 322 |
+
|
| 323 |
+
def create_task_hard(seed: int = 42) -> Task:
|
| 324 |
+
rng = random.Random(seed)
|
| 325 |
+
|
| 326 |
+
clean_csv = """experiment_id,model_name,dataset,train_size,val_size,test_size,learning_rate,batch_size,epochs,train_loss,val_loss,test_accuracy,gpu_memory_gb,training_time_hours,timestamp
|
| 327 |
+
EXP-001,resnet50,imagenet-1k,1281167,50000,100000,0.001,256,90,0.85,1.12,76.3,12.4,48.5,2024-03-01T10:00:00
|
| 328 |
+
EXP-002,bert-base,squad-v2,130319,11873,8862,0.00003,32,3,0.45,0.52,81.2,7.8,2.1,2024-03-02T14:30:00
|
| 329 |
+
EXP-003,gpt2-small,openwebtext,8013769,100000,100000,0.0003,64,1,3.12,3.28,0.0,14.2,72.0,2024-03-03T09:15:00
|
| 330 |
+
EXP-004,vit-base,imagenet-1k,1281167,50000,100000,0.001,512,300,0.72,0.98,79.8,15.6,96.0,2024-03-05T08:00:00
|
| 331 |
+
EXP-005,distilbert,mnli,392702,9815,9796,0.00005,16,5,0.28,0.35,84.6,5.2,1.5,2024-03-06T11:00:00
|
| 332 |
+
EXP-006,llama2-7b,alpaca-52k,51760,500,500,0.00002,4,3,1.05,1.18,0.0,38.5,8.2,2024-03-07T16:00:00
|
| 333 |
+
EXP-007,resnet18,cifar10,50000,5000,10000,0.01,128,200,0.15,0.28,93.5,3.2,1.8,2024-03-08T10:30:00
|
| 334 |
+
EXP-008,t5-small,cnn-dailymail,287113,13368,11490,0.0001,16,10,1.45,1.62,0.0,6.8,4.5,2024-03-09T13:00:00
|
| 335 |
+
EXP-009,efficientnet-b0,imagenet-1k,1281167,50000,100000,0.005,256,350,0.68,0.89,77.1,8.4,36.0,2024-03-10T07:45:00
|
| 336 |
+
EXP-010,roberta-large,sst2,67349,872,1821,0.00001,8,10,0.08,0.12,95.1,14.8,3.2,2024-03-11T15:00:00
|
| 337 |
+
EXP-011,yolov5-m,coco-2017,118287,5000,40670,0.01,32,300,0.032,0.045,0.0,10.2,24.0,2024-03-12T09:00:00
|
| 338 |
+
EXP-012,wav2vec2,librispeech,281241,5567,2620,0.0001,8,20,0.92,1.05,0.0,12.6,15.0,2024-03-13T11:30:00
|
| 339 |
+
EXP-013,clip-base,cc3m,2818102,15000,15000,0.00001,256,32,2.15,2.38,0.0,22.4,48.0,2024-03-14T08:00:00
|
| 340 |
+
EXP-014,detr,coco-2017,118287,5000,40670,0.0001,4,500,1.85,2.12,0.0,16.0,72.0,2024-03-15T10:00:00
|
| 341 |
+
EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0,7.4,6.5,2024-03-16T14:00:00
|
| 342 |
+
EXP-016,mobilenet-v3,imagenet-1k,1281167,50000,100000,0.004,128,150,0.92,1.05,72.8,4.1,18.0,2024-03-17T08:30:00
|
| 343 |
+
EXP-017,albert-base,mnli,392702,9815,9796,0.00002,32,5,0.32,0.41,83.1,6.2,1.8,2024-03-18T11:00:00
|
| 344 |
+
EXP-018,gpt-neo-1.3b,pile-subset,1500000,50000,50000,0.0002,8,2,2.85,2.98,0.0,18.5,36.0,2024-03-19T14:00:00
|
| 345 |
+
EXP-019,swin-tiny,imagenet-1k,1281167,50000,100000,0.001,256,300,0.78,0.95,78.2,8.6,42.0,2024-03-20T09:00:00
|
| 346 |
+
EXP-020,deberta-large,squad-v2,130319,11873,8862,0.00001,16,5,0.35,0.42,85.7,15.2,4.5,2024-03-21T10:30:00
|
| 347 |
+
EXP-021,yolov8-s,coco-2017,118287,5000,40670,0.01,64,200,0.028,0.038,0.0,6.8,16.0,2024-03-22T13:00:00
|
| 348 |
+
EXP-022,bart-base,xsum,204045,11332,11334,0.0001,32,10,1.22,1.38,0.0,8.4,6.2,2024-03-23T15:30:00
|
| 349 |
+
EXP-023,convnext-tiny,imagenet-1k,1281167,50000,100000,0.002,256,300,0.74,0.92,79.5,7.2,38.0,2024-03-24T08:00:00
|
| 350 |
+
EXP-024,xlm-roberta,xnli,392702,2490,5010,0.00002,16,10,0.41,0.48,82.3,12.4,5.8,2024-03-25T11:00:00
|
| 351 |
+
EXP-025,stable-diffusion,laion-400m,400000000,10000,10000,0.0001,4,1,0.45,0.52,0.0,24.0,168.0,2024-03-26T09:00:00
|
| 352 |
+
EXP-026,phi-2,dolly-15k,15011,500,500,0.00005,8,3,0.82,0.95,0.0,10.2,2.5,2024-03-27T14:00:00
|
| 353 |
+
EXP-027,dino-v2,imagenet-1k,1281167,50000,100000,0.0005,64,100,0.42,0.58,0.0,11.8,28.0,2024-03-28T10:00:00
|
| 354 |
+
EXP-028,electra-small,glue-mrpc,3668,408,1725,0.0001,32,10,0.38,0.44,87.2,3.8,0.8,2024-03-29T16:00:00
|
| 355 |
+
EXP-029,sam-base,sa-1b,11000000,50000,50000,0.0001,4,1,0.95,1.08,0.0,16.4,96.0,2024-03-30T08:00:00
|
| 356 |
+
EXP-030,llama2-13b,oasst1,84437,4401,4401,0.00001,2,3,0.78,0.88,0.0,52.0,12.0,2024-03-31T12:00:00"""
|
| 357 |
+
|
| 358 |
+
schema_desc = """Columns:
|
| 359 |
+
- experiment_id: string, unique, format EXP-NNN
|
| 360 |
+
- model_name: string, non-empty
|
| 361 |
+
- dataset: string, non-empty
|
| 362 |
+
- train_size: integer, positive, must be > val_size and > test_size
|
| 363 |
+
- val_size: integer, positive
|
| 364 |
+
- test_size: integer, positive
|
| 365 |
+
- learning_rate: float, range 1e-7 to 1.0
|
| 366 |
+
- batch_size: integer, must be power of 2, range 1-1024
|
| 367 |
+
- epochs: integer, positive, range 1-1000
|
| 368 |
+
- train_loss: float, non-negative
|
| 369 |
+
- val_loss: float, non-negative, typically >= train_loss (if not, may indicate data leakage)
|
| 370 |
+
- test_accuracy: float, range 0-100 (percentage), 0.0 is valid for generative models
|
| 371 |
+
- gpu_memory_gb: float, positive
|
| 372 |
+
- training_time_hours: float, positive
|
| 373 |
+
- timestamp: string, ISO 8601 format, chronological order by experiment_id"""
|
| 374 |
+
|
| 375 |
+
rules = """1. No missing values
|
| 376 |
+
2. experiment_id must be unique
|
| 377 |
+
3. val_loss should be >= train_loss (if val_loss < train_loss significantly, flag as potential data leakage)
|
| 378 |
+
4. batch_size must be a power of 2
|
| 379 |
+
5. train_size must be larger than both val_size and test_size
|
| 380 |
+
6. learning_rate must be within valid range
|
| 381 |
+
7. gpu_memory_gb should be reasonable for the model size (e.g., resnet18 shouldn't need 40GB)
|
| 382 |
+
8. training_time should be proportional to dataset size and epochs (flag major inconsistencies)
|
| 383 |
+
9. timestamps must be in chronological order"""
|
| 384 |
+
|
| 385 |
+
rows = _csv_to_rows(clean_csv)
|
| 386 |
+
header = rows[0]
|
| 387 |
+
data = rows[1:]
|
| 388 |
+
issues: List[PlantedIssue] = []
|
| 389 |
+
|
| 390 |
+
# Issue 1: Data leakage signal — val_loss much lower than train_loss (hard — requires ML knowledge)
|
| 391 |
+
r = 4 # EXP-005
|
| 392 |
+
data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
|
| 393 |
+
issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
|
| 394 |
+
description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage",
|
| 395 |
+
difficulty=3.0))
|
| 396 |
+
|
| 397 |
+
# Issue 2: Batch size not power of 2 (moderate — domain convention)
|
| 398 |
+
r = 8 # EXP-009
|
| 399 |
+
data[r][7] = "250" # not a power of 2
|
| 400 |
+
issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
|
| 401 |
+
description="batch_size 250 is not a power of 2", difficulty=2.0))
|
| 402 |
+
|
| 403 |
+
# Issue 3: GPU memory unreasonable for model (hard — requires model size reasoning)
|
| 404 |
+
r = 6 # EXP-007 resnet18 on cifar10
|
| 405 |
+
data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
|
| 406 |
+
issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
|
| 407 |
+
description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable",
|
| 408 |
+
difficulty=3.0))
|
| 409 |
+
|
| 410 |
+
# Issue 4: Timestamp out of order (moderate — requires sequential comparison)
|
| 411 |
+
r = 10 # EXP-011
|
| 412 |
+
data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
|
| 413 |
+
issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
|
| 414 |
+
description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11",
|
| 415 |
+
difficulty=2.0))
|
| 416 |
+
|
| 417 |
+
# Issue 5: Train size smaller than test size (moderate — cross-column logic)
|
| 418 |
+
r = 9 # EXP-010
|
| 419 |
+
data[r][3] = "500" # train_size=500 but test_size=1821
|
| 420 |
+
issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
|
| 421 |
+
description="train_size (500) is smaller than test_size (1821)",
|
| 422 |
+
difficulty=2.0))
|
| 423 |
+
|
| 424 |
+
# Issue 6: Negative training time (easy to spot)
|
| 425 |
+
r = 13 # EXP-014
|
| 426 |
+
data[r][13] = "-72.0"
|
| 427 |
+
issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
|
| 428 |
+
description="Negative training time", difficulty=1.0))
|
| 429 |
+
|
| 430 |
+
# Issue 7: Learning rate out of range (easy to spot)
|
| 431 |
+
r = 12 # EXP-013
|
| 432 |
+
data[r][6] = "2.5" # way too high
|
| 433 |
+
issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
|
| 434 |
+
description="Learning rate 2.5 exceeds maximum of 1.0", difficulty=1.5))
|
| 435 |
+
|
| 436 |
+
# Issue 8: Missing model name (hard — whitespace-only is subtle)
|
| 437 |
+
r = 14 # EXP-015
|
| 438 |
+
data[r][1] = " "
|
| 439 |
+
issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
|
| 440 |
+
description="model_name is whitespace-only", difficulty=2.5))
|
| 441 |
+
|
| 442 |
+
# Issue 9: Training time impossibly fast for dataset size and epochs
|
| 443 |
+
# EXP-004: vit-base on imagenet-1k, 300 epochs, but only 96 hours is plausible.
|
| 444 |
+
# Let's make EXP-009: efficientnet-b0 on imagenet-1k, 350 epochs = should take ~40+ hours
|
| 445 |
+
# but we set it to 0.5 hours — impossible for 1.2M images * 350 epochs
|
| 446 |
+
r = 8 # EXP-009 (same row as batch_size issue, different column)
|
| 447 |
+
data[r][13] = "0.5" # 30 minutes for 350 epochs on imagenet? impossible
|
| 448 |
+
issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="statistical_outlier",
|
| 449 |
+
description="0.5 hours for 350 epochs on imagenet-1k (1.2M images) is impossibly fast",
|
| 450 |
+
difficulty=3.0))
|
| 451 |
+
|
| 452 |
+
# Issue 10: test_accuracy of 95.1% for roberta-large on SST-2 with train_size=500
|
| 453 |
+
# is suspiciously high — SOTA is ~96% with full dataset (67k). With only 500 training
|
| 454 |
+
# samples, 95.1% accuracy suggests data contamination or evaluation bug
|
| 455 |
+
r = 9 # EXP-010 (same row as train_size issue, different column)
|
| 456 |
+
# train_size is already corrupted to 500, but the test_accuracy 95.1 is from the
|
| 457 |
+
# original full-dataset run — this cross-column inconsistency is the real issue
|
| 458 |
+
# We don't modify the value — the inconsistency emerges from the train_size corruption
|
| 459 |
+
# So let's use a different row. EXP-001: resnet50 on imagenet, accuracy 76.3 is fine.
|
| 460 |
+
# Instead: EXP-012 wav2vec2 on librispeech — set test_accuracy to 98.5 (way too high
|
| 461 |
+
# for a speech model with only 20 epochs, SOTA is ~96% with much more training)
|
| 462 |
+
r = 11 # EXP-012
|
| 463 |
+
data[r][11] = "98.5" # wav2vec2 with 20 epochs shouldn't hit 98.5% — SOTA is ~96%
|
| 464 |
+
issues.append(PlantedIssue(row=r + 1, col="test_accuracy", issue_type="statistical_outlier",
|
| 465 |
+
description="test_accuracy 98.5% for wav2vec2 with only 20 epochs exceeds known SOTA (~96%), likely evaluation error",
|
| 466 |
+
difficulty=3.0))
|
| 467 |
+
|
| 468 |
+
corrupted = _rows_to_csv([header] + data)
|
| 469 |
+
|
| 470 |
+
return Task(
|
| 471 |
+
task_id="hard",
|
| 472 |
+
name="ML Experiment Metadata Validation",
|
| 473 |
+
description=(
|
| 474 |
+
"You are given an ML experiment tracking dataset. "
|
| 475 |
+
"Find all data quality issues based on the schema and validation rules. "
|
| 476 |
+
"This dataset contains subtle issues including potential data leakage signals, "
|
| 477 |
+
"unreasonable resource usage, and logical inconsistencies. "
|
| 478 |
+
"Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
|
| 479 |
+
),
|
| 480 |
+
schema_description=schema_desc,
|
| 481 |
+
validation_rules=rules,
|
| 482 |
+
clean_csv=clean_csv,
|
| 483 |
+
planted_issues=issues,
|
| 484 |
+
corrupted_csv=corrupted,
|
| 485 |
+
max_steps=3,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# ---------------------------------------------------------------------------
|
| 490 |
+
# TASK 4: Expert — LLM Alignment / Instruction-tuning data quality
|
| 491 |
+
# ---------------------------------------------------------------------------
|
| 492 |
+
|
| 493 |
+
def _build_alignment_csv() -> str:
|
| 494 |
+
"""Build alignment task CSV from real NVIDIA HelpSteer dataset.
|
| 495 |
+
|
| 496 |
+
Uses real prompt-response pairs with human-rated quality scores
|
| 497 |
+
from nvidia/HelpSteer (37K samples, human-annotated).
|
| 498 |
+
Scores: helpfulness, correctness, coherence, complexity, verbosity (0-4 each).
|
| 499 |
+
"""
|
| 500 |
+
# Real prompt-response pairs from nvidia/HelpSteer dataset
|
| 501 |
+
# Selected: diverse topics, high quality (helpfulness>=3, correctness>=3), moderate length
|
| 502 |
+
HELPSTEER_SAMPLES = [
|
| 503 |
+
("What is gelatin made of and how is it made?", "Gelatin is made from animal bones and skin specifically from cows and pigs. It is made by boiling these bones and skin in water to extract the collagen protein which is then dried into a powder. Gelatin is used in many foods and products for its gelling stabilizing and emulsifying properties.", 3, 3, 4, 2, 1),
|
| 504 |
+
("What are the ingredients of a cappuccino and how does it differ from other coffee drinks?", "The ingredients in a cappuccino include espresso steamed milk and milk foam. It has a strong coffee flavor with a creamy texture. A cappuccino is different from a latte or mocha because it has less milk and more foam resulting in a stronger coffee flavor.", 3, 3, 3, 1, 1),
|
| 505 |
+
("What is the difference between kitsch and camp?", "Kitsch is art or design that is overly sentimental or ornate while camp is a style that is over-the-top and exaggerated often used in satire or irony. Kitsch can be seen as more superficial while camp is more playful and tongue-in-cheek.", 3, 3, 3, 1, 1),
|
| 506 |
+
("What is the most expensive painting ever sold at auction?", "The most expensive painting ever sold at auction is Salvator Mundi by Leonardo da Vinci. It was sold for $450.3 million at Christie's in New York City in 2017.", 4, 4, 4, 1, 1),
|
| 507 |
+
("What is the average temperature of Venus?", "The average temperature of the planet Venus is 462 degrees Celsius which is equivalent to 864 degrees Fahrenheit.", 3, 3, 3, 1, 1),
|
| 508 |
+
("What is the scientific name for cherry blossom and what colors can it be?", "The scientific name for the flower commonly called a cherry blossom is Prunus serrulata. It can be found in white pink and red colors.", 3, 3, 4, 2, 1),
|
| 509 |
+
("What are the three main duties of a centurion in ancient Rome?", "The centurion was responsible for the military unit in the legion. They were responsible for safety and loyalty to Rome. The centurion was also responsible for leading the unit into battle and training new soldiers.", 3, 3, 3, 2, 1),
|
| 510 |
+
("What is the best way to defrost a refrigerator?", "There are a few ways to defrost a refrigerator. You can unplug the fridge take out all contents and leave the door open. Another way is to use a blow dryer to melt the ice. Make sure to use a cloth to soak up the water as it melts.", 3, 3, 3, 1, 1),
|
| 511 |
+
("What are common reasons for Nobel Prize nominations?", "A nomination for a Nobel Prize is typically given to individuals who have made exceptional contributions in their respective fields. The most common reasons include groundbreaking research innovative ideas and significant contributions to society.", 3, 3, 3, 2, 2),
|
| 512 |
+
("What is a culvert and how is it constructed?", "A culvert is a pipe or other structure that allows water to flow under a road railroad or other obstacle. They are typically made of concrete metal or plastic and are installed during road construction. Culverts can handle a wide range of water flows.", 3, 3, 3, 1, 1),
|
| 513 |
+
("What is the difference between morbidity and mortality rates?", "Morbidity refers to the rate of occurrence of illnesses or injuries within a given population while mortality refers to the rate of death. Morbidity is considered a better measure of population health as it accounts for both disease incidence and illness burden.", 4, 4, 4, 2, 3),
|
| 514 |
+
("What are the symptoms of menopause and how can they be managed?", "Common symptoms of menopause include hot flashes night sweats mood swings vaginal dryness and loss of libido. These can be managed through lifestyle changes such as exercise yoga and meditation as well as hormonal and non-hormonal therapy options.", 3, 3, 3, 2, 1),
|
| 515 |
+
("What are the 12 constellations of the zodiac?", "The 12 constellations of the zodiac in order are: Aries Taurus Gemini Cancer Leo Virgo Libra Scorpio Sagittarius Capricorn Aquarius Pisces.", 3, 3, 4, 1, 1),
|
| 516 |
+
("What is parole and how does it differ from other supervised release?", "Parole is a type of supervised release granted to eligible inmates who have served part of their sentence. Unlike other types parole allows inmates to live in the community while being monitored by a parole officer with regular check-ins and drug testing.", 4, 3, 4, 2, 2),
|
| 517 |
+
("What is the function of a fibroblast?", "Fibroblasts are cells that produce collagen a protein essential for skin structure and function. Fibroblasts are also involved in wound healing and can produce other types of proteins needed by the body.", 3, 3, 4, 1, 1),
|
| 518 |
+
("When was the first flight of the Wright Flyer?", "The Wright brothers made four brief flights on December 17 1903. The Flyer had a length of 40 feet and a wingspan of 40 feet 6 inches.", 4, 4, 4, 3, 4),
|
| 519 |
+
("What was the most destructive natural disaster in human history?", "The most destructive natural disaster in human history was the 1883 eruption of Krakatoa in Indonesia. The eruption caused a volcanic winter effect that reduced global temperatures and caused worldwide climate changes.", 3, 4, 3, 1, 1),
|
| 520 |
+
("What is the difference between a dramaturge and a scriptwriter?", "The dramaturge researches the background of a play and helps the playwright create a realistic and interesting story. The scriptwriter writes the actual script for the play.", 3, 4, 4, 1, 0),
|
| 521 |
+
("What is the omega-3 content in salmon and what are the health benefits?", "A portion of salmon typically contains around 2.5 grams of omega-3 fatty acids including EPA and DHA. Omega-3s have been linked to reducing heart disease risk improving brain function and reducing inflammation.", 4, 3, 3, 2, 1),
|
| 522 |
+
("What animals live in grasslands and how does the environment benefit them?", "Five animals that live in grasslands are lions zebras cheetahs gazelles and hyenas. These animals live in grasslands to access the food water and shade that grasslands provide.", 3, 3, 4, 1, 2),
|
| 523 |
+
("What is the nutritional value of squash?", "Squash is a good source of vitamins A and C as well as fiber and potassium. Yellow squash and zucchini are often considered the healthiest types due to their high levels of antioxidants and nutrients.", 3, 3, 3, 2, 2),
|
| 524 |
+
("What is a gobbler and where is it found?", "A gobbler is a type of turkey native to North America. Its scientific name is Meleagris gallopavo. Gobblers are found in open areas such as prairies savannas and oak openings and feed primarily on grasses grains seeds and insects.", 4, 3, 4, 1, 2),
|
| 525 |
+
("What is the most important thing a mother can teach her son?", "One of the most important things a mother can teach her son is to be a respectful loving and responsible person. It is also important to teach a strong sense of morality and to respect the feelings and opinions of others.", 3, 3, 3, 1, 2),
|
| 526 |
+
("What are some of the oldest cotton mills in the world?", "Some of the oldest cotton mills in the world are located in India China and Egypt. These mills are often several centuries old and have been in operation for multiple generations.", 3, 3, 3, 1, 1),
|
| 527 |
+
("What are challenges faced by immigrants to the US?", "Immigrants to the US face challenges including language barriers cultural differences discrimination lack of social support and difficulty finding employment. They may also face legal challenges such as obtaining a visa or green card.", 3, 3, 3, 2, 1),
|
| 528 |
+
("What is the average weight of a halibut and how do you cook it?", "The average weight of a halibut after 4 years is 10-12 pounds. Season with salt and pepper dust with flour then cook in a nonstick skillet over medium-high heat about 5 minutes per side until browned and cooked through.", 3, 3, 4, 2, 2),
|
| 529 |
+
("What was the typical diet of a soldier in World War 2?", "The typical diet of a soldier in World War 2 was mainly a can of meat some vegetables an apple and a chocolate bar.", 3, 3, 4, 1, 1),
|
| 530 |
+
("What are creative ways to use a sketch practically?", "You can use a sketch to plan and organize your thoughts and ideas. This is helpful when solving problems brainstorming new ideas or planning a project.", 3, 3, 4, 1, 1),
|
| 531 |
+
("What is the role of the middle class in society?", "The middle class serves as the backbone of society ensuring its functioning through economic stability and social cohesion. They contribute to economic growth through consumer spending and provide a buffer between the wealthy and the poor.", 3, 3, 4, 2, 1),
|
| 532 |
+
("What is equality and how can it be achieved?", "Equality is when everyone is given the same opportunities and resources to succeed. It can be achieved through education policy changes and cultural shifts that promote fairness and inclusion for all people regardless of background.", 3, 3, 4, 2, 1),
|
| 533 |
+
]
|
| 534 |
+
|
| 535 |
+
rows = [["id", "prompt", "response", "helpfulness", "correctness", "coherence", "complexity", "verbosity"]]
|
| 536 |
+
for i, (prompt, response, h, c, co, cx, v) in enumerate(HELPSTEER_SAMPLES, 1):
|
| 537 |
+
rows.append([str(i), prompt, response, str(h), str(c), str(co), str(cx), str(v)])
|
| 538 |
+
|
| 539 |
+
return _rows_to_csv(rows)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def create_task_alignment(seed: int = 42) -> Task:
|
| 543 |
+
rng = random.Random(seed)
|
| 544 |
+
|
| 545 |
+
clean_csv = _build_alignment_csv()
|
| 546 |
+
|
| 547 |
+
schema_desc = """Columns (from NVIDIA HelpSteer dataset — real human-annotated alignment data):
|
| 548 |
+
- id: integer, unique, sequential starting from 1
|
| 549 |
+
- prompt: string, non-empty, the input prompt/question given to the LLM
|
| 550 |
+
- response: string, non-empty, must directly and correctly address the prompt
|
| 551 |
+
- helpfulness: integer, 0-4 scale, how helpful the response is (higher=better)
|
| 552 |
+
- correctness: integer, 0-4 scale, factual accuracy of response (higher=better)
|
| 553 |
+
- coherence: integer, 0-4 scale, logical clarity and consistency (higher=better)
|
| 554 |
+
- complexity: integer, 0-4 scale, depth and sophistication of response (higher=better)
|
| 555 |
+
- verbosity: integer, 0-4 scale, appropriate level of detail (higher=more verbose)"""
|
| 556 |
+
|
| 557 |
+
rules = """1. No missing or empty values in any column
|
| 558 |
+
2. id must be unique and sequential
|
| 559 |
+
3. response must directly address the prompt (not a different topic)
|
| 560 |
+
4. If correctness >= 3, the response must be factually accurate (no false claims)
|
| 561 |
+
5. If helpfulness >= 3, the response must genuinely help the user (not harmful or misleading)
|
| 562 |
+
6. If coherence >= 3, the response must not contradict itself
|
| 563 |
+
7. Scores must be in range 0-4 and consistent with response quality
|
| 564 |
+
8. No duplicate prompts (same or near-identical questions)
|
| 565 |
+
9. Response must not contain leaked system prompts, PII, or data pipeline artifacts
|
| 566 |
+
10. Response must be in the same language as the prompt (English for this dataset)"""
|
| 567 |
+
|
| 568 |
+
rows = _csv_to_rows(clean_csv)
|
| 569 |
+
header = rows[0]
|
| 570 |
+
data = rows[1:]
|
| 571 |
+
issues: List[PlantedIssue] = []
|
| 572 |
+
|
| 573 |
+
# Issue 1: Subtle factual error — ALMOST correct but wrong on a key detail
|
| 574 |
+
# Row 6 about cherry blossoms — wrong scientific name (close but wrong genus)
|
| 575 |
+
r = 5 # id=6
|
| 576 |
+
data[r][2] = "The scientific name for the flower commonly called a cherry blossom is Cerasus serrulata. It can be found in white pink and red colors."
|
| 577 |
+
# Real name is Prunus serrulata, not Cerasus. Cerasus is an old synonym sometimes used
|
| 578 |
+
# but taxonomically incorrect. Subtle enough that most LLMs won't catch it.
|
| 579 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 580 |
+
description="Scientific name Cerasus serrulata is taxonomically incorrect — accepted name is Prunus serrulata. Subtle factual error with correctness=3",
|
| 581 |
+
difficulty=3.0))
|
| 582 |
+
|
| 583 |
+
# Issue 2: Plausible but wrong numbers — close enough to seem right
|
| 584 |
+
# Row 4 about most expensive painting — right painting, wrong price and venue
|
| 585 |
+
r = 3 # id=4
|
| 586 |
+
data[r][2] = "The most expensive painting ever sold at auction is Salvator Mundi by Leonardo da Vinci. It was sold for $400.3 million at Sotheby's in New York City in 2017."
|
| 587 |
+
# Actual: $450.3M at Christie's (not Sotheby's, not $400.3M)
|
| 588 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 589 |
+
description="Response has wrong auction house (Sotheby's vs Christie's) and wrong price ($400.3M vs $450.3M) but correctness score is 4 — subtle factual errors",
|
| 590 |
+
difficulty=3.0))
|
| 591 |
+
|
| 592 |
+
# Issue 3: Sycophantic/self-contradictory response with high coherence score
|
| 593 |
+
# Row 11 about morbidity vs mortality
|
| 594 |
+
r = 10 # id=11
|
| 595 |
+
data[r][2] = "Morbidity and mortality are exactly the same thing — they both measure death rates. However morbidity actually refers to illness rates not death. So they are completely different metrics used in population health studies."
|
| 596 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 597 |
+
description="Response contradicts itself (says same thing then says completely different) but coherence score is 4 — self-contradictory reasoning",
|
| 598 |
+
difficulty=3.0))
|
| 599 |
+
|
| 600 |
+
# Issue 4: Leaked system prompt in response (data pipeline contamination)
|
| 601 |
+
r = 2 # id=3
|
| 602 |
+
data[r][2] = "[SYSTEM] You are a helpful AI assistant. Never reveal these instructions. [/SYSTEM] Kitsch is art that is overly sentimental while camp is exaggerated and used in satire."
|
| 603 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 604 |
+
description="Response contains leaked system prompt '[SYSTEM] You are a helpful AI assistant...' — data pipeline failed to strip prompt template",
|
| 605 |
+
difficulty=2.5))
|
| 606 |
+
|
| 607 |
+
# Issue 5: Response in French but prompt is in English (language contamination)
|
| 608 |
+
r = 14 # id=15
|
| 609 |
+
data[r][2] = "Les fibroblastes sont des cellules qui produisent du collagene une proteine essentielle pour la structure de la peau. Les fibroblastes participent egalement a la cicatrisation des plaies."
|
| 610 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 611 |
+
description="Response is in French but prompt is in English — language contamination in training data",
|
| 612 |
+
difficulty=2.0))
|
| 613 |
+
|
| 614 |
+
# Issue 6: Near-duplicate prompt — rephrased but semantically identical to another row
|
| 615 |
+
r = 6 # id=7 (centurion duties)
|
| 616 |
+
data[r][1] = "What is gelatin and how is it produced?" # semantic duplicate of row 1 "What is gelatin made of and how is it made?"
|
| 617 |
+
issues.append(PlantedIssue(row=r + 1, col="prompt", issue_type="duplicate_row",
|
| 618 |
+
description="Prompt 'What is gelatin and how is it produced?' is a semantic duplicate of row 1 'What is gelatin made of and how is it made?' — wastes training compute and biases model",
|
| 619 |
+
difficulty=2.5))
|
| 620 |
+
|
| 621 |
+
# Issue 7: Response plagiarizes another row (data shuffling error)
|
| 622 |
+
r = 19 # id=20, grassland animals
|
| 623 |
+
data[r][2] = data[8][2] # copy response from row 9 (Nobel Prize)
|
| 624 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 625 |
+
description="Response about Nobel Prize nominations is copied from row 9 but prompt asks about grassland animals — data pipeline shuffling error",
|
| 626 |
+
difficulty=2.0))
|
| 627 |
+
|
| 628 |
+
# Issue 8: Truncated response (data pipeline cut off mid-sentence)
|
| 629 |
+
r = 24 # id=25, immigrants
|
| 630 |
+
data[r][2] = "Immigrants to the US face challenges including language barriers cultural differences discrimination lack of social support and difficulty finding"
|
| 631 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="missing_value",
|
| 632 |
+
description="Response truncated mid-sentence ('difficulty finding') — data pipeline max_length cut without sentence boundary detection",
|
| 633 |
+
difficulty=2.5))
|
| 634 |
+
|
| 635 |
+
# Issue 9: Harmful advice with high helpfulness score
|
| 636 |
+
r = 7 # id=8, defrosting fridge
|
| 637 |
+
data[r][2] = "The fastest way to defrost a refrigerator is to pour boiling water directly onto the ice and use a sharp knife to chip away at it. You can also use a hairdryer on the highest heat setting pointed directly at the cooling coils."
|
| 638 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 639 |
+
description="Response suggests dangerous methods (boiling water on ice sharp knife near coils) but helpfulness score is 3 — harmful advice in training data",
|
| 640 |
+
difficulty=3.0))
|
| 641 |
+
|
| 642 |
+
# Issue 10: Helpfulness score contradicts response quality
|
| 643 |
+
# Row 17 about most destructive disaster — response is extremely terse for a complex question
|
| 644 |
+
r = 16 # id=17
|
| 645 |
+
data[r][3] = "4" # helpfulness=4 but response is just 2 sentences for a nuanced historical question
|
| 646 |
+
data[r][4] = "4" # correctness=4 but the answer itself is debatable
|
| 647 |
+
data[r][2] = "The 1556 Shaanxi earthquake."
|
| 648 |
+
# This is arguably correct but gives no context, no detail — helpfulness=4 and correctness=4
|
| 649 |
+
# for a 4-word answer to "most destructive natural disaster" is clearly inflated
|
| 650 |
+
issues.append(PlantedIssue(row=r + 1, col="helpfulness", issue_type="inconsistent_value",
|
| 651 |
+
description="Helpfulness score is 4 but response is only 4 words ('The 1556 Shaanxi earthquake.') with no explanation — score inflated for an unhelpful response",
|
| 652 |
+
difficulty=2.5))
|
| 653 |
+
|
| 654 |
+
# Issue 11: Whitespace-only prompt (data pipeline artifact)
|
| 655 |
+
r = 27 # id=28
|
| 656 |
+
data[r][1] = " "
|
| 657 |
+
issues.append(PlantedIssue(row=r + 1, col="prompt", issue_type="missing_value",
|
| 658 |
+
description="Prompt is whitespace-only — unusable training example from data pipeline artifact",
|
| 659 |
+
difficulty=2.0))
|
| 660 |
+
|
| 661 |
+
# Issue 12: Hallucinated citation in response
|
| 662 |
+
r = 28 # id=29
|
| 663 |
+
data[r][2] = "According to a 2023 Nature paper by Dr. Sarah Chen at Stanford the middle class contributes exactly 67.3% of GDP in developed nations. Chen's longitudinal study of 50 countries proved this definitively."
|
| 664 |
+
issues.append(PlantedIssue(row=r + 1, col="response", issue_type="inconsistent_value",
|
| 665 |
+
description="Response contains hallucinated citation (fake Nature paper by fake Dr. Sarah Chen with fabricated statistic 67.3%) — training on this teaches model to generate convincing false citations",
|
| 666 |
+
difficulty=3.0))
|
| 667 |
+
|
| 668 |
+
corrupted = _rows_to_csv([header] + data)
|
| 669 |
+
|
| 670 |
+
return Task(
|
| 671 |
+
task_id="alignment",
|
| 672 |
+
name="LLM Alignment Data Quality Validation",
|
| 673 |
+
description=(
|
| 674 |
+
"You are given an LLM instruction-tuning dataset used for fine-tuning. "
|
| 675 |
+
"Find all data quality issues that would degrade model training. "
|
| 676 |
+
"Issues include: instruction-response mismatches, factual errors in 'good' labeled data, "
|
| 677 |
+
"wrong category labels, language mismatches, truncated responses, duplicate instructions, "
|
| 678 |
+
"hallucinated citations, and harmful advice labeled as 'good'. "
|
| 679 |
+
"Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
|
| 680 |
+
),
|
| 681 |
+
schema_description=schema_desc,
|
| 682 |
+
validation_rules=rules,
|
| 683 |
+
clean_csv=clean_csv,
|
| 684 |
+
planted_issues=issues,
|
| 685 |
+
corrupted_csv=corrupted,
|
| 686 |
+
max_steps=3,
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
# ---------------------------------------------------------------------------
|
| 691 |
+
# TASK 5: Coding — Code instruction/response quality issues
|
| 692 |
+
# ---------------------------------------------------------------------------
|
| 693 |
+
|
| 694 |
+
def create_task_coding(seed: int = 42) -> Task:
|
| 695 |
+
rng = random.Random(seed)
|
| 696 |
+
|
| 697 |
+
clean_csv = """id,instruction,language,difficulty,response,test_cases,is_correct,category
|
| 698 |
+
1,"Write a function to check if a number is prime.",python,easy,"def is_prime(n):\n if n < 2:\n return False\n for i in range(2, int(n**0.5) + 1):\n if n % i == 0:\n return False\n return True","assert is_prime(2)==True; assert is_prime(4)==False; assert is_prime(17)==True",true,algorithms
|
| 699 |
+
2,"Implement binary search on a sorted array.",python,medium,"def binary_search(arr, target):\n lo, hi = 0, len(arr) - 1\n while lo <= hi:\n mid = (lo + hi) // 2\n if arr[mid] == target:\n return mid\n elif arr[mid] < target:\n lo = mid + 1\n else:\n hi = mid - 1\n return -1","assert binary_search([1,3,5,7],5)==2; assert binary_search([1,3,5,7],4)==-1",true,algorithms
|
| 700 |
+
3,"Write a function to reverse a linked list.",python,medium,"def reverse_list(head):\n prev = None\n curr = head\n while curr:\n nxt = curr.next\n curr.next = prev\n prev = curr\n curr = nxt\n return prev","# requires ListNode class setup",true,data_structures
|
| 701 |
+
4,"Create a REST API endpoint that returns user data.",python,medium,"from flask import Flask, jsonify\napp = Flask(__name__)\n\n@app.route('/users/<int:uid>')\ndef get_user(uid):\n users = {1: {'name': 'Alice'}, 2: {'name': 'Bob'}}\n return jsonify(users.get(uid, {}))","curl localhost:5000/users/1 -> {'name': 'Alice'}",true,web
|
| 702 |
+
5,"Implement a stack using two queues.",python,medium,"from collections import deque\nclass Stack:\n def __init__(self):\n self.q1 = deque()\n self.q2 = deque()\n def push(self, x):\n self.q2.append(x)\n while self.q1:\n self.q2.append(self.q1.popleft())\n self.q1, self.q2 = self.q2, self.q1\n def pop(self):\n return self.q1.popleft()\n def top(self):\n return self.q1[0]","s=Stack(); s.push(1); s.push(2); assert s.top()==2; s.pop(); assert s.top()==1",true,data_structures
|
| 703 |
+
6,"Write a SQL query to find the second highest salary.",sql,easy,"SELECT MAX(salary) FROM employees WHERE salary < (SELECT MAX(salary) FROM employees)","Expected: returns second highest value",true,databases
|
| 704 |
+
7,"Implement merge sort.",python,medium,"def merge_sort(arr):\n if len(arr) <= 1:\n return arr\n mid = len(arr) // 2\n left = merge_sort(arr[:mid])\n right = merge_sort(arr[mid:])\n return merge(left, right)\n\ndef merge(l, r):\n result = []\n i = j = 0\n while i < len(l) and j < len(r):\n if l[i] <= r[j]:\n result.append(l[i]); i += 1\n else:\n result.append(r[j]); j += 1\n result.extend(l[i:])\n result.extend(r[j:])\n return result","assert merge_sort([3,1,4,1,5])==[1,1,3,4,5]",true,algorithms
|
| 705 |
+
8,"Write a function to validate an email address.",python,easy,"import re\ndef is_valid_email(email):\n pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'\n return bool(re.match(pattern, email))","assert is_valid_email('test@example.com')==True; assert is_valid_email('bad')==False",true,strings
|
| 706 |
+
9,"Implement a LRU cache with O(1) get and put.",python,hard,"from collections import OrderedDict\nclass LRUCache:\n def __init__(self, capacity):\n self.cache = OrderedDict()\n self.capacity = capacity\n def get(self, key):\n if key not in self.cache:\n return -1\n self.cache.move_to_end(key)\n return self.cache[key]\n def put(self, key, value):\n if key in self.cache:\n self.cache.move_to_end(key)\n self.cache[key] = value\n if len(self.cache) > self.capacity:\n self.cache.popitem(last=False)","c=LRUCache(2); c.put(1,1); c.put(2,2); assert c.get(1)==1; c.put(3,3); assert c.get(2)==-1",true,data_structures
|
| 707 |
+
10,"Write a function to find all permutations of a string.",python,medium,"def permutations(s):\n if len(s) <= 1:\n return [s]\n result = []\n for i, c in enumerate(s):\n for perm in permutations(s[:i] + s[i+1:]):\n result.append(c + perm)\n return result","assert sorted(permutations('ab'))==['ab','ba']",true,algorithms
|
| 708 |
+
11,"Implement depth-first search on a graph.",python,medium,"def dfs(graph, start, visited=None):\n if visited is None:\n visited = set()\n visited.add(start)\n for neighbor in graph.get(start, []):\n if neighbor not in visited:\n dfs(graph, neighbor, visited)\n return visited","g={'A':['B','C'],'B':['D'],'C':[],'D':[]}; assert dfs(g,'A')=={'A','B','C','D'}",true,algorithms
|
| 709 |
+
12,"Write a function to check balanced parentheses.",python,easy,"def is_balanced(s):\n stack = []\n mapping = {')':'(', '}':'{', ']':'['}\n for c in s:\n if c in mapping.values():\n stack.append(c)\n elif c in mapping:\n if not stack or stack[-1] != mapping[c]:\n return False\n stack.pop()\n return len(stack) == 0","assert is_balanced('([]){}')==True; assert is_balanced('([)]')==False",true,strings
|
| 710 |
+
13,"Create a decorator that caches function results.",python,medium,"from functools import wraps\ndef memoize(func):\n cache = {}\n @wraps(func)\n def wrapper(*args):\n if args not in cache:\n cache[args] = func(*args)\n return cache[args]\n return wrapper","@memoize\ndef fib(n): return n if n<2 else fib(n-1)+fib(n-2)\nassert fib(10)==55",true,design_patterns
|
| 711 |
+
14,"Implement quicksort.",python,medium,"def quicksort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr)//2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quicksort(left) + middle + quicksort(right)","assert quicksort([3,6,8,10,1,2,1])==[1,1,2,3,6,8,10]",true,algorithms
|
| 712 |
+
15,"Write a function to detect a cycle in a linked list.",python,medium,"def has_cycle(head):\n slow = fast = head\n while fast and fast.next:\n slow = slow.next\n fast = fast.next.next\n if slow == fast:\n return True\n return False","# requires ListNode class with cycle setup",true,data_structures
|
| 713 |
+
16,"Implement a trie (prefix tree).",python,hard,"class TrieNode:\n def __init__(self):\n self.children = {}\n self.is_end = False\n\nclass Trie:\n def __init__(self):\n self.root = TrieNode()\n def insert(self, word):\n node = self.root\n for c in word:\n if c not in node.children:\n node.children[c] = TrieNode()\n node = node.children[c]\n node.is_end = True\n def search(self, word):\n node = self.root\n for c in word:\n if c not in node.children:\n return False\n node = node.children[c]\n return node.is_end","t=Trie(); t.insert('apple'); assert t.search('apple')==True; assert t.search('app')==False",true,data_structures
|
| 714 |
+
17,"Write a function that flattens a nested list.",python,easy,"def flatten(lst):\n result = []\n for item in lst:\n if isinstance(item, list):\n result.extend(flatten(item))\n else:\n result.append(item)\n return result","assert flatten([1,[2,[3,4],5]])==[1,2,3,4,5]",true,algorithms
|
| 715 |
+
18,"Implement a basic calculator that evaluates +,-,*,/ with parentheses.",python,hard,"def calculate(s):\n def helper(tokens):\n stack = []\n num = 0\n sign = '+'\n while tokens:\n t = tokens.pop(0)\n if t.isdigit():\n num = num * 10 + int(t)\n if t == '(':\n num = helper(tokens)\n if t in '+-*/)' or not tokens:\n if sign == '+': stack.append(num)\n elif sign == '-': stack.append(-num)\n elif sign == '*': stack.append(stack.pop() * num)\n elif sign == '/': stack.append(int(stack.pop() / num))\n num = 0\n sign = t\n if t == ')':\n break\n return sum(stack)\n return helper(list(s.replace(' ', '')))","assert calculate('3+2*2')==7; assert calculate('(1+2)*3')==9",true,algorithms
|
| 716 |
+
19,"Write a thread-safe singleton pattern in Python.",python,hard,"import threading\nclass Singleton:\n _instance = None\n _lock = threading.Lock()\n def __new__(cls):\n if cls._instance is None:\n with cls._lock:\n if cls._instance is None:\n cls._instance = super().__new__(cls)\n return cls._instance","s1=Singleton(); s2=Singleton(); assert s1 is s2",true,design_patterns
|
| 717 |
+
20,"Implement Dijkstra's shortest path algorithm.",python,hard,"import heapq\ndef dijkstra(graph, start):\n dist = {node: float('inf') for node in graph}\n dist[start] = 0\n pq = [(0, start)]\n while pq:\n d, u = heapq.heappop(pq)\n if d > dist[u]:\n continue\n for v, w in graph[u]:\n if dist[u] + w < dist[v]:\n dist[v] = dist[u] + w\n heapq.heappush(pq, (dist[v], v))\n return dist","g={'A':[('B',1),('C',4)],'B':[('C',2)],'C':[]}; assert dijkstra(g,'A')=={'A':0,'B':1,'C':3}",true,algorithms"""
|
| 718 |
+
|
| 719 |
+
schema_desc = """Columns:
|
| 720 |
+
- id: integer, unique, sequential starting from 1
|
| 721 |
+
- instruction: string, non-empty, describes a coding task
|
| 722 |
+
- language: string, one of [python, javascript, sql, java, cpp, rust, go]
|
| 723 |
+
- difficulty: string, one of [easy, medium, hard]
|
| 724 |
+
- response: string, non-empty, contains code that solves the instruction
|
| 725 |
+
- test_cases: string, non-empty, contains assertions or test descriptions
|
| 726 |
+
- is_correct: boolean (true/false), whether the response correctly solves the instruction
|
| 727 |
+
- category: string, one of [algorithms, data_structures, strings, web, databases, design_patterns]"""
|
| 728 |
+
|
| 729 |
+
rules = """1. No missing values in any column
|
| 730 |
+
2. id must be unique and sequential
|
| 731 |
+
3. language must be a valid programming language from the allowed set
|
| 732 |
+
4. response code must be in the language specified by the language column
|
| 733 |
+
5. is_correct must be 'true' if and only if the code actually solves the problem correctly
|
| 734 |
+
6. difficulty must reflect the actual complexity of the task
|
| 735 |
+
7. response must be syntactically valid code (no truncation or syntax errors)
|
| 736 |
+
8. test_cases must be relevant to the instruction
|
| 737 |
+
9. No duplicate instructions (same problem stated differently counts as duplicate)
|
| 738 |
+
10. category must match the actual nature of the problem"""
|
| 739 |
+
|
| 740 |
+
rows = _csv_to_rows(clean_csv)
|
| 741 |
+
header = rows[0]
|
| 742 |
+
data = rows[1:]
|
| 743 |
+
issues: List[PlantedIssue] = []
|
| 744 |
+
|
| 745 |
+
# Issue 1: Response has syntax error but is_correct=true (difficulty 2.0)
|
| 746 |
+
# Row 3 (reverse linked list) — introduce unbalanced parenthesis
|
| 747 |
+
r = 2 # 0-indexed -> row 3
|
| 748 |
+
data[r][4] = "def reverse_list(head):\n prev = None\n curr = head\n while curr:\n nxt = curr.next\n curr.next = prev\n prev = curr\n curr = nxt\n return prev)" # extra closing paren
|
| 749 |
+
issues.append(PlantedIssue(
|
| 750 |
+
row=r + 1, col="response", issue_type="format_violation",
|
| 751 |
+
description="Syntax error: unbalanced parenthesis in response but is_correct=true",
|
| 752 |
+
difficulty=2.0))
|
| 753 |
+
|
| 754 |
+
# Issue 2: Wrong language — response is JavaScript but language says python (difficulty 2.5)
|
| 755 |
+
# Row 8 (email validation)
|
| 756 |
+
r = 7
|
| 757 |
+
data[r][4] = "function isValidEmail(email) {\n const pattern = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/;\n return pattern.test(email);\n}"
|
| 758 |
+
issues.append(PlantedIssue(
|
| 759 |
+
row=r + 1, col="response", issue_type="inconsistent_value",
|
| 760 |
+
description="Response is JavaScript but language column says python",
|
| 761 |
+
difficulty=2.5))
|
| 762 |
+
|
| 763 |
+
# Issue 3: Truncated response — code cut off mid-function (difficulty 2.0)
|
| 764 |
+
# Row 18 (basic calculator)
|
| 765 |
+
r = 17
|
| 766 |
+
data[r][4] = "def calculate(s):\n def helper(tokens):\n stack = []\n num = 0\n sign = '+'\n while tokens:\n t = tokens.pop(0)\n if t.isdigit():\n num = num" # truncated
|
| 767 |
+
issues.append(PlantedIssue(
|
| 768 |
+
row=r + 1, col="response", issue_type="format_violation",
|
| 769 |
+
description="Response truncated mid-expression — incomplete code",
|
| 770 |
+
difficulty=2.0))
|
| 771 |
+
|
| 772 |
+
# Issue 4: is_correct=true but code has logic bug (difficulty 3.0)
|
| 773 |
+
# Row 2 (binary search) — off-by-one: lo = mid instead of mid + 1
|
| 774 |
+
r = 1
|
| 775 |
+
data[r][4] = "def binary_search(arr, target):\n lo, hi = 0, len(arr) - 1\n while lo <= hi:\n mid = (lo + hi) // 2\n if arr[mid] == target:\n return mid\n elif arr[mid] < target:\n lo = mid\n else:\n hi = mid - 1\n return -1"
|
| 776 |
+
data[r][6] = "true" # claims correct but has infinite loop bug
|
| 777 |
+
issues.append(PlantedIssue(
|
| 778 |
+
row=r + 1, col="is_correct", issue_type="inconsistent_value",
|
| 779 |
+
description="is_correct=true but binary search has off-by-one bug (lo=mid causes infinite loop)",
|
| 780 |
+
difficulty=3.0))
|
| 781 |
+
|
| 782 |
+
# Issue 5: SQL response for a python-labeled task (difficulty 2.0)
|
| 783 |
+
# Row 6 is SQL task — change language to python but keep SQL response
|
| 784 |
+
r = 5
|
| 785 |
+
data[r][2] = "python"
|
| 786 |
+
issues.append(PlantedIssue(
|
| 787 |
+
row=r + 1, col="language", issue_type="inconsistent_value",
|
| 788 |
+
description="Language says python but response is SQL query",
|
| 789 |
+
difficulty=2.0))
|
| 790 |
+
|
| 791 |
+
# Issue 6: Difficulty mismatch — trivial problem labeled hard (difficulty 2.5)
|
| 792 |
+
# Row 17 (flatten nested list) is easy, change to hard
|
| 793 |
+
r = 16
|
| 794 |
+
data[r][3] = "hard"
|
| 795 |
+
issues.append(PlantedIssue(
|
| 796 |
+
row=r + 1, col="difficulty", issue_type="inconsistent_value",
|
| 797 |
+
description="Flatten nested list is a simple recursion but labeled as hard",
|
| 798 |
+
difficulty=2.5))
|
| 799 |
+
|
| 800 |
+
# Issue 7: Missing test cases — empty string (difficulty 1.0)
|
| 801 |
+
r = 12
|
| 802 |
+
data[r][5] = ""
|
| 803 |
+
issues.append(PlantedIssue(
|
| 804 |
+
row=r + 1, col="test_cases", issue_type="missing_value",
|
| 805 |
+
description="Empty test_cases field for memoize decorator",
|
| 806 |
+
difficulty=1.0))
|
| 807 |
+
|
| 808 |
+
# Issue 8: Security vulnerability in response rated is_correct=true (difficulty 3.0)
|
| 809 |
+
# Row 4 (REST API) — add eval() of user input
|
| 810 |
+
r = 3
|
| 811 |
+
data[r][4] = "from flask import Flask, jsonify, request\napp = Flask(__name__)\n\n@app.route('/users/<uid>')\ndef get_user(uid):\n users = {1: {'name': 'Alice'}, 2: {'name': 'Bob'}}\n user_id = eval(uid)\n return jsonify(users.get(user_id, {}))"
|
| 812 |
+
issues.append(PlantedIssue(
|
| 813 |
+
row=r + 1, col="response", issue_type="inconsistent_value",
|
| 814 |
+
description="Response uses eval() on user input — critical security vulnerability (code injection) but is_correct=true",
|
| 815 |
+
difficulty=3.0))
|
| 816 |
+
|
| 817 |
+
# Issue 9: Duplicate instruction — row 14 (quicksort) is semantically same as row 7 (merge sort)
|
| 818 |
+
# Change instruction to match merge sort
|
| 819 |
+
r = 13
|
| 820 |
+
data[r][1] = "Implement merge sort algorithm."
|
| 821 |
+
issues.append(PlantedIssue(
|
| 822 |
+
row=r + 1, col="instruction", issue_type="duplicate_row",
|
| 823 |
+
description="Instruction 'Implement merge sort algorithm' duplicates row 7 'Implement merge sort' (semantic duplicate)",
|
| 824 |
+
difficulty=2.5))
|
| 825 |
+
|
| 826 |
+
# Issue 10: Wrong category — Dijkstra labeled as design_patterns (difficulty 1.5)
|
| 827 |
+
r = 19
|
| 828 |
+
data[r][7] = "design_patterns"
|
| 829 |
+
issues.append(PlantedIssue(
|
| 830 |
+
row=r + 1, col="category", issue_type="inconsistent_value",
|
| 831 |
+
description="Dijkstra's algorithm categorized as design_patterns instead of algorithms",
|
| 832 |
+
difficulty=1.5))
|
| 833 |
+
|
| 834 |
+
corrupted = _rows_to_csv([header] + data)
|
| 835 |
+
|
| 836 |
+
return Task(
|
| 837 |
+
task_id="coding",
|
| 838 |
+
name="Code Quality Dataset Validation",
|
| 839 |
+
description=(
|
| 840 |
+
"You are given a coding instruction-response dataset used for LLM fine-tuning. "
|
| 841 |
+
"Find all data quality issues: incorrect labels, language mismatches, logic bugs, "
|
| 842 |
+
"syntax errors, security vulnerabilities, duplicate instructions, and missing fields. "
|
| 843 |
+
"Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
|
| 844 |
+
),
|
| 845 |
+
schema_description=schema_desc,
|
| 846 |
+
validation_rules=rules,
|
| 847 |
+
clean_csv=clean_csv,
|
| 848 |
+
planted_issues=issues,
|
| 849 |
+
corrupted_csv=corrupted,
|
| 850 |
+
max_steps=3,
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
# ---------------------------------------------------------------------------
|
| 855 |
+
# TASK 6: Tool-calling — Function definition and call quality issues
|
| 856 |
+
# ---------------------------------------------------------------------------
|
| 857 |
+
|
| 858 |
+
def create_task_toolcalling(seed: int = 42) -> Task:
|
| 859 |
+
rng = random.Random(seed)
|
| 860 |
+
|
| 861 |
+
clean_csv = """id,function_name,description,parameters_json,required_params,return_type,example_call,example_output,category
|
| 862 |
+
1,get_weather,"Get current weather for a location.","{""location"": ""string"", ""units"": ""string (celsius|fahrenheit)""}","location",object,"{""function"": ""get_weather"", ""arguments"": {""location"": ""San Francisco"", ""units"": ""celsius""}}","{""temp"": 18, ""condition"": ""cloudy""}",information
|
| 863 |
+
2,send_email,"Send an email to a recipient.","{""to"": ""string"", ""subject"": ""string"", ""body"": ""string"", ""cc"": ""string (optional)""}","to,subject,body",object,"{""function"": ""send_email"", ""arguments"": {""to"": ""alice@example.com"", ""subject"": ""Meeting"", ""body"": ""See you at 3pm""}}","{""status"": ""sent"", ""message_id"": ""msg_123""}",communication
|
| 864 |
+
3,search_database,"Query a database with filters.","{""query"": ""string"", ""table"": ""string"", ""limit"": ""integer (default 10)""}","query,table",array,"{""function"": ""search_database"", ""arguments"": {""query"": ""age > 25"", ""table"": ""users"", ""limit"": 5}}","[{""name"": ""Alice"", ""age"": 30}]",data
|
| 865 |
+
4,create_calendar_event,"Create a new calendar event.","{""title"": ""string"", ""start_time"": ""string (ISO 8601)"", ""end_time"": ""string (ISO 8601)"", ""attendees"": ""array of strings (optional)""}","title,start_time,end_time",object,"{""function"": ""create_calendar_event"", ""arguments"": {""title"": ""Team Sync"", ""start_time"": ""2024-03-15T10:00:00Z"", ""end_time"": ""2024-03-15T11:00:00Z""}}","{""event_id"": ""evt_456"", ""status"": ""created""}",scheduling
|
| 866 |
+
5,translate_text,"Translate text between languages.","{""text"": ""string"", ""source_lang"": ""string (ISO 639-1)"", ""target_lang"": ""string (ISO 639-1)""}","text,target_lang",object,"{""function"": ""translate_text"", ""arguments"": {""text"": ""Hello world"", ""source_lang"": ""en"", ""target_lang"": ""es""}}","{""translated"": ""Hola mundo"", ""confidence"": 0.95}",language
|
| 867 |
+
6,get_stock_price,"Get real-time stock price.","{""symbol"": ""string"", ""exchange"": ""string (optional, default NYSE)""}","symbol",object,"{""function"": ""get_stock_price"", ""arguments"": {""symbol"": ""AAPL""}}","{""price"": 178.52, ""currency"": ""USD"", ""change"": 2.3}",finance
|
| 868 |
+
7,upload_file,"Upload a file to cloud storage.","{""file_path"": ""string"", ""bucket"": ""string"", ""public"": ""boolean (default false)""}","file_path,bucket",object,"{""function"": ""upload_file"", ""arguments"": {""file_path"": ""/data/report.pdf"", ""bucket"": ""my-bucket""}}","{""url"": ""https://storage.example.com/my-bucket/report.pdf"", ""size_bytes"": 1048576}",storage
|
| 869 |
+
8,run_code,"Execute code in a sandboxed environment.","{""code"": ""string"", ""language"": ""string (python|javascript|ruby)"", ""timeout"": ""integer (seconds, default 30)""}","code,language",object,"{""function"": ""run_code"", ""arguments"": {""code"": ""print(2+2)"", ""language"": ""python""}}","{""stdout"": ""4\n"", ""exit_code"": 0}",execution
|
| 870 |
+
9,get_directions,"Get driving/walking directions.","{""origin"": ""string"", ""destination"": ""string"", ""mode"": ""string (driving|walking|transit)""}","origin,destination",object,"{""function"": ""get_directions"", ""arguments"": {""origin"": ""NYC"", ""destination"": ""Boston"", ""mode"": ""driving""}}","{""distance_km"": 346, ""duration_min"": 230, ""steps"": [""Take I-95 N...""]}",navigation
|
| 871 |
+
10,analyze_sentiment,"Analyze sentiment of text.","{""text"": ""string"", ""language"": ""string (optional, default en)""}","text",object,"{""function"": ""analyze_sentiment"", ""arguments"": {""text"": ""I love this product!""}}","{""sentiment"": ""positive"", ""score"": 0.92}",analysis
|
| 872 |
+
11,create_user,"Create a new user account.","{""username"": ""string"", ""email"": ""string"", ""role"": ""string (admin|user|viewer)""}","username,email,role",object,"{""function"": ""create_user"", ""arguments"": {""username"": ""jdoe"", ""email"": ""jdoe@example.com"", ""role"": ""user""}}","{""user_id"": ""usr_789"", ""created"": true}",account
|
| 873 |
+
12,generate_image,"Generate an image from a text prompt.","{""prompt"": ""string"", ""size"": ""string (256x256|512x512|1024x1024)"", ""style"": ""string (optional)""}","prompt",object,"{""function"": ""generate_image"", ""arguments"": {""prompt"": ""sunset over mountains"", ""size"": ""512x512""}}","{""image_url"": ""https://img.example.com/gen_001.png""}",creative
|
| 874 |
+
13,list_files,"List files in a directory.","{""path"": ""string"", ""recursive"": ""boolean (default false)"", ""pattern"": ""string (glob, optional)""}","path",array,"{""function"": ""list_files"", ""arguments"": {""path"": ""/home/user/docs""}}","[""report.pdf"", ""notes.txt""]",filesystem
|
| 875 |
+
14,set_reminder,"Set a timed reminder.","{""message"": ""string"", ""time"": ""string (ISO 8601)"", ""repeat"": ""string (none|daily|weekly, optional)""}","message,time",object,"{""function"": ""set_reminder"", ""arguments"": {""message"": ""Stand up and stretch"", ""time"": ""2024-03-15T15:00:00Z""}}","{""reminder_id"": ""rem_101"", ""status"": ""set""}",scheduling
|
| 876 |
+
15,convert_currency,"Convert between currencies.","{""amount"": ""number"", ""from_currency"": ""string (ISO 4217)"", ""to_currency"": ""string (ISO 4217)""}","amount,from_currency,to_currency",object,"{""function"": ""convert_currency"", ""arguments"": {""amount"": 100, ""from_currency"": ""USD"", ""to_currency"": ""EUR""}}","{""converted"": 91.5, ""rate"": 0.915}",finance
|
| 877 |
+
16,summarize_text,"Summarize a long text.","{""text"": ""string"", ""max_length"": ""integer (optional, default 100)""}","text",object,"{""function"": ""summarize_text"", ""arguments"": {""text"": ""Long article about climate change..."", ""max_length"": 50}}","{""summary"": ""Climate change poses significant challenges...""}",analysis
|
| 878 |
+
17,get_user_info,"Retrieve user profile information.","{""user_id"": ""string""}","user_id",object,"{""function"": ""get_user_info"", ""arguments"": {""user_id"": ""usr_789""}}","{""username"": ""jdoe"", ""email"": ""jdoe@example.com"", ""role"": ""user""}",account
|
| 879 |
+
18,compress_image,"Compress an image to reduce file size.","{""image_url"": ""string"", ""quality"": ""integer (1-100)"", ""format"": ""string (jpeg|png|webp)""}","image_url,quality",object,"{""function"": ""compress_image"", ""arguments"": {""image_url"": ""https://img.example.com/photo.png"", ""quality"": 80}}","{""compressed_url"": ""https://img.example.com/photo_compressed.png"", ""reduction"": ""65%""}",media
|
| 880 |
+
19,execute_trade,"Execute a stock trade.","{""symbol"": ""string"", ""action"": ""string (buy|sell)"", ""quantity"": ""integer"", ""order_type"": ""string (market|limit)"", ""limit_price"": ""number (required if order_type=limit)""}","symbol,action,quantity,order_type",object,"{""function"": ""execute_trade"", ""arguments"": {""symbol"": ""AAPL"", ""action"": ""buy"", ""quantity"": 10, ""order_type"": ""market""}}","{""trade_id"": ""trd_202"", ""status"": ""executed"", ""filled_price"": 178.52}",finance
|
| 881 |
+
20,parse_pdf,"Extract text content from a PDF.","{""url"": ""string"", ""pages"": ""string (optional, e.g. 1-5)""}","url",object,"{""function"": ""parse_pdf"", ""arguments"": {""url"": ""https://docs.example.com/report.pdf""}}","{""text"": ""Annual Report 2024..."", ""page_count"": 12}",data"""
|
| 882 |
+
|
| 883 |
+
schema_desc = """Columns:
|
| 884 |
+
- id: integer, unique, sequential starting from 1
|
| 885 |
+
- function_name: string, valid identifier (snake_case), unique
|
| 886 |
+
- description: string, non-empty, describes what the function does
|
| 887 |
+
- parameters_json: string, valid JSON-like parameter schema with types
|
| 888 |
+
- required_params: string, comma-separated parameter names that must be present in example_call
|
| 889 |
+
- return_type: string, one of [object, array, string, number, boolean]
|
| 890 |
+
- example_call: string, valid JSON with "function" matching function_name and "arguments" containing required params
|
| 891 |
+
- example_output: string, valid JSON matching return_type
|
| 892 |
+
- category: string, one of [information, communication, data, scheduling, language, finance, storage, execution, navigation, analysis, account, creative, filesystem, media]"""
|
| 893 |
+
|
| 894 |
+
rules = """1. No missing values in any column
|
| 895 |
+
2. id must be unique and sequential
|
| 896 |
+
3. function_name must be unique and match the "function" field in example_call
|
| 897 |
+
4. All required_params must appear as keys in the example_call arguments
|
| 898 |
+
5. Parameter types in parameters_json must match the actual values in example_call
|
| 899 |
+
6. return_type must match the type of example_output
|
| 900 |
+
7. example_call must be valid JSON
|
| 901 |
+
8. example_output must be valid JSON
|
| 902 |
+
9. description must accurately describe what the function does
|
| 903 |
+
10. No hallucinated parameters in example_call that are not defined in parameters_json"""
|
| 904 |
+
|
| 905 |
+
rows = _csv_to_rows(clean_csv)
|
| 906 |
+
header = rows[0]
|
| 907 |
+
data = rows[1:]
|
| 908 |
+
issues: List[PlantedIssue] = []
|
| 909 |
+
|
| 910 |
+
# Issue 1: Function name mismatch — example_call uses wrong function name (difficulty 2.0)
|
| 911 |
+
# Row 3 (search_database) — call says "query_database" instead
|
| 912 |
+
r = 2
|
| 913 |
+
data[r][6] = '{"function": "query_database", "arguments": {"query": "age > 25", "table": "users", "limit": 5}}'
|
| 914 |
+
issues.append(PlantedIssue(
|
| 915 |
+
row=r + 1, col="example_call", issue_type="inconsistent_value",
|
| 916 |
+
description="example_call function name 'query_database' doesn't match function_name 'search_database'",
|
| 917 |
+
difficulty=2.0))
|
| 918 |
+
|
| 919 |
+
# Issue 2: Missing required parameter in example_call (difficulty 2.5)
|
| 920 |
+
# Row 4 (create_calendar_event) — missing end_time which is required
|
| 921 |
+
r = 3
|
| 922 |
+
data[r][6] = '{"function": "create_calendar_event", "arguments": {"title": "Team Sync", "start_time": "2024-03-15T10:00:00Z"}}'
|
| 923 |
+
issues.append(PlantedIssue(
|
| 924 |
+
row=r + 1, col="example_call", issue_type="inconsistent_value",
|
| 925 |
+
description="Required parameter 'end_time' missing from example_call arguments",
|
| 926 |
+
difficulty=2.5))
|
| 927 |
+
|
| 928 |
+
# Issue 3: Hallucinated parameter — example_call has param not in schema (difficulty 3.0)
|
| 929 |
+
# Row 10 (analyze_sentiment) — add "model" param not in parameters_json
|
| 930 |
+
r = 9
|
| 931 |
+
data[r][6] = '{"function": "analyze_sentiment", "arguments": {"text": "I love this product!", "model": "gpt-4", "confidence_threshold": 0.8}}'
|
| 932 |
+
issues.append(PlantedIssue(
|
| 933 |
+
row=r + 1, col="example_call", issue_type="inconsistent_value",
|
| 934 |
+
description="Hallucinated parameters 'model' and 'confidence_threshold' not defined in parameters_json",
|
| 935 |
+
difficulty=3.0))
|
| 936 |
+
|
| 937 |
+
# Issue 4: Wrong return_type — returns object but labeled as array (difficulty 1.5)
|
| 938 |
+
# Row 6 (get_stock_price)
|
| 939 |
+
r = 5
|
| 940 |
+
data[r][5] = "array"
|
| 941 |
+
issues.append(PlantedIssue(
|
| 942 |
+
row=r + 1, col="return_type", issue_type="inconsistent_value",
|
| 943 |
+
description="return_type says 'array' but example_output is an object",
|
| 944 |
+
difficulty=1.5))
|
| 945 |
+
|
| 946 |
+
# Issue 5: Invalid JSON in example_call (difficulty 2.0)
|
| 947 |
+
# Row 12 (generate_image) — malformed JSON
|
| 948 |
+
r = 11
|
| 949 |
+
data[r][6] = '{"function": "generate_image", "arguments": {"prompt": "sunset over mountains", "size": "512x512"' # missing closing braces
|
| 950 |
+
issues.append(PlantedIssue(
|
| 951 |
+
row=r + 1, col="example_call", issue_type="format_violation",
|
| 952 |
+
description="Invalid JSON in example_call — missing closing braces",
|
| 953 |
+
difficulty=2.0))
|
| 954 |
+
|
| 955 |
+
# Issue 6: Parameter type mismatch — schema says integer but call passes string (difficulty 2.5)
|
| 956 |
+
# Row 18 (compress_image) — quality should be integer but passed as string "high"
|
| 957 |
+
r = 17
|
| 958 |
+
data[r][6] = '{"function": "compress_image", "arguments": {"image_url": "https://img.example.com/photo.png", "quality": "high"}}'
|
| 959 |
+
issues.append(PlantedIssue(
|
| 960 |
+
row=r + 1, col="example_call", issue_type="wrong_type",
|
| 961 |
+
description="Parameter 'quality' should be integer (1-100) but example_call passes string 'high'",
|
| 962 |
+
difficulty=2.5))
|
| 963 |
+
|
| 964 |
+
# Issue 7: Empty description (difficulty 1.0)
|
| 965 |
+
r = 15
|
| 966 |
+
data[r][2] = ""
|
| 967 |
+
issues.append(PlantedIssue(
|
| 968 |
+
row=r + 1, col="description", issue_type="missing_value",
|
| 969 |
+
description="Empty description field for summarize_text function",
|
| 970 |
+
difficulty=1.0))
|
| 971 |
+
|
| 972 |
+
# Issue 8: Duplicate function — row 17 (get_user_info) duplicates row 11 (create_user) in purpose
|
| 973 |
+
# Change function_name to create_user (duplicate)
|
| 974 |
+
r = 16
|
| 975 |
+
data[r][1] = "create_user"
|
| 976 |
+
issues.append(PlantedIssue(
|
| 977 |
+
row=r + 1, col="function_name", issue_type="duplicate_row",
|
| 978 |
+
description="Duplicate function_name 'create_user' — already defined in row 11",
|
| 979 |
+
difficulty=1.5))
|
| 980 |
+
|
| 981 |
+
# Issue 9: Dangerous function without safety description (difficulty 3.0)
|
| 982 |
+
# Row 19 (execute_trade) — description doesn't mention it's irreversible/dangerous
|
| 983 |
+
# Change description to misleadingly safe
|
| 984 |
+
r = 18
|
| 985 |
+
data[r][2] = "Preview a potential stock trade."
|
| 986 |
+
issues.append(PlantedIssue(
|
| 987 |
+
row=r + 1, col="description", issue_type="inconsistent_value",
|
| 988 |
+
description="Description says 'Preview a potential stock trade' but function actually executes trades (irreversible action mislabeled as preview)",
|
| 989 |
+
difficulty=3.0))
|
| 990 |
+
|
| 991 |
+
# Issue 10: Wrong category (difficulty 1.5)
|
| 992 |
+
# Row 8 (run_code) labeled as "scheduling" instead of "execution"
|
| 993 |
+
r = 7
|
| 994 |
+
data[r][8] = "scheduling"
|
| 995 |
+
issues.append(PlantedIssue(
|
| 996 |
+
row=r + 1, col="category", issue_type="inconsistent_value",
|
| 997 |
+
description="run_code categorized as 'scheduling' instead of 'execution'",
|
| 998 |
+
difficulty=1.5))
|
| 999 |
+
|
| 1000 |
+
corrupted = _rows_to_csv([header] + data)
|
| 1001 |
+
|
| 1002 |
+
return Task(
|
| 1003 |
+
task_id="toolcalling",
|
| 1004 |
+
name="Tool-Calling Dataset Validation",
|
| 1005 |
+
description=(
|
| 1006 |
+
"You are given a tool-calling/function-calling dataset used for LLM fine-tuning. "
|
| 1007 |
+
"Find all data quality issues: function name mismatches between definition and call, "
|
| 1008 |
+
"missing required parameters, hallucinated parameters, type mismatches, invalid JSON, "
|
| 1009 |
+
"duplicate functions, and misleading descriptions. "
|
| 1010 |
+
"Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
|
| 1011 |
+
),
|
| 1012 |
+
schema_description=schema_desc,
|
| 1013 |
+
validation_rules=rules,
|
| 1014 |
+
clean_csv=clean_csv,
|
| 1015 |
+
planted_issues=issues,
|
| 1016 |
+
corrupted_csv=corrupted,
|
| 1017 |
+
max_steps=3,
|
| 1018 |
+
)
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
# ---------------------------------------------------------------------------
|
| 1022 |
+
# Contamination rules for extensible task creation
|
| 1023 |
+
# ---------------------------------------------------------------------------
|
| 1024 |
+
|
| 1025 |
+
# Each contamination rule is a callable: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
|
| 1026 |
+
# Users can define their own and register them.
|
| 1027 |
+
|
| 1028 |
+
CONTAMINATION_RULES = {
|
| 1029 |
+
"missing_value": lambda rows, header, col_idx, row_idx, rng: (
|
| 1030 |
+
"",
|
| 1031 |
+
PlantedIssue(
|
| 1032 |
+
row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
|
| 1033 |
+
description=f"Empty {header[col_idx]} field", difficulty=1.0,
|
| 1034 |
+
),
|
| 1035 |
+
),
|
| 1036 |
+
"whitespace_value": lambda rows, header, col_idx, row_idx, rng: (
|
| 1037 |
+
" ",
|
| 1038 |
+
PlantedIssue(
|
| 1039 |
+
row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
|
| 1040 |
+
description=f"Whitespace-only {header[col_idx]} field", difficulty=2.5,
|
| 1041 |
+
),
|
| 1042 |
+
),
|
| 1043 |
+
"wrong_type_text": lambda rows, header, col_idx, row_idx, rng: (
|
| 1044 |
+
rng.choice(["not-a-number", "N/A", "null", "undefined"]),
|
| 1045 |
+
PlantedIssue(
|
| 1046 |
+
row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
|
| 1047 |
+
description=f"{header[col_idx]} is text instead of expected type", difficulty=1.0,
|
| 1048 |
+
),
|
| 1049 |
+
),
|
| 1050 |
+
"negative_value": lambda rows, header, col_idx, row_idx, rng: (
|
| 1051 |
+
str(-abs(float(rows[row_idx][col_idx]) if rows[row_idx][col_idx] else 1)),
|
| 1052 |
+
PlantedIssue(
|
| 1053 |
+
row=row_idx + 1, col=header[col_idx], issue_type="out_of_range",
|
| 1054 |
+
description=f"Negative {header[col_idx]}", difficulty=1.0,
|
| 1055 |
+
),
|
| 1056 |
+
),
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
|
| 1060 |
+
def create_task_from_config(
|
| 1061 |
+
task_id: str,
|
| 1062 |
+
name: str,
|
| 1063 |
+
description: str,
|
| 1064 |
+
schema_description: str,
|
| 1065 |
+
validation_rules: str,
|
| 1066 |
+
clean_csv: str,
|
| 1067 |
+
contaminations: List[dict],
|
| 1068 |
+
max_steps: int = 3,
|
| 1069 |
+
seed: int = 42,
|
| 1070 |
+
) -> Task:
|
| 1071 |
+
"""
|
| 1072 |
+
Create a custom task from a configuration dict.
|
| 1073 |
+
|
| 1074 |
+
Each contamination entry should have:
|
| 1075 |
+
- rule: str (key in CONTAMINATION_RULES) or callable
|
| 1076 |
+
- row: int (0-based row index in data)
|
| 1077 |
+
- col: int (column index in header)
|
| 1078 |
+
- difficulty: float (optional, overrides rule default)
|
| 1079 |
+
|
| 1080 |
+
Example:
|
| 1081 |
+
contaminations = [
|
| 1082 |
+
{"rule": "missing_value", "row": 2, "col": 1, "difficulty": 1.5},
|
| 1083 |
+
{"rule": "negative_value", "row": 5, "col": 4},
|
| 1084 |
+
]
|
| 1085 |
+
"""
|
| 1086 |
+
rng = random.Random(seed)
|
| 1087 |
+
rows = _csv_to_rows(clean_csv)
|
| 1088 |
+
header = rows[0]
|
| 1089 |
+
data = rows[1:]
|
| 1090 |
+
issues: List[PlantedIssue] = []
|
| 1091 |
+
|
| 1092 |
+
for spec in contaminations:
|
| 1093 |
+
rule = spec["rule"]
|
| 1094 |
+
row_idx = spec["row"]
|
| 1095 |
+
col_idx = spec["col"]
|
| 1096 |
+
|
| 1097 |
+
if callable(rule):
|
| 1098 |
+
new_val, issue = rule(data, header, col_idx, row_idx, rng)
|
| 1099 |
+
elif rule in CONTAMINATION_RULES:
|
| 1100 |
+
new_val, issue = CONTAMINATION_RULES[rule](data, header, col_idx, row_idx, rng)
|
| 1101 |
+
else:
|
| 1102 |
+
raise ValueError(f"Unknown contamination rule: {rule}. Available: {list(CONTAMINATION_RULES.keys())}")
|
| 1103 |
+
|
| 1104 |
+
data[row_idx][col_idx] = new_val
|
| 1105 |
+
if "difficulty" in spec:
|
| 1106 |
+
issue.difficulty = spec["difficulty"]
|
| 1107 |
+
issues.append(issue)
|
| 1108 |
+
|
| 1109 |
+
corrupted = _rows_to_csv([header] + data)
|
| 1110 |
+
|
| 1111 |
+
return Task(
|
| 1112 |
+
task_id=task_id,
|
| 1113 |
+
name=name,
|
| 1114 |
+
description=description,
|
| 1115 |
+
schema_description=schema_description,
|
| 1116 |
+
validation_rules=validation_rules,
|
| 1117 |
+
clean_csv=clean_csv,
|
| 1118 |
+
planted_issues=issues,
|
| 1119 |
+
corrupted_csv=corrupted,
|
| 1120 |
+
max_steps=max_steps,
|
| 1121 |
+
)
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
def register_task(task_id: str, factory_fn):
|
| 1125 |
+
"""Register a custom task factory. Factory must accept (seed: int) -> Task."""
|
| 1126 |
+
TASK_REGISTRY[task_id] = factory_fn
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
+
def register_contamination_rule(name: str, rule_fn):
|
| 1130 |
+
"""
|
| 1131 |
+
Register a custom contamination rule.
|
| 1132 |
+
|
| 1133 |
+
rule_fn signature: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
|
| 1134 |
+
"""
|
| 1135 |
+
CONTAMINATION_RULES[name] = rule_fn
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
# ---------------------------------------------------------------------------
|
| 1139 |
+
# Task registry
|
| 1140 |
+
# ---------------------------------------------------------------------------
|
| 1141 |
+
|
| 1142 |
+
TASK_REGISTRY = {
|
| 1143 |
+
"easy": create_task_easy,
|
| 1144 |
+
"medium": create_task_medium,
|
| 1145 |
+
"hard": create_task_hard,
|
| 1146 |
+
"alignment": create_task_alignment,
|
| 1147 |
+
"coding": create_task_coding,
|
| 1148 |
+
"toolcalling": create_task_toolcalling,
|
| 1149 |
+
}
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
def get_task(task_id: str, seed: int = 42) -> Task:
|
| 1153 |
+
if task_id not in TASK_REGISTRY:
|
| 1154 |
+
raise ValueError(f"Unknown task: {task_id}. Available: {list(TASK_REGISTRY.keys())}")
|
| 1155 |
+
return TASK_REGISTRY[task_id](seed=seed)
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
def list_tasks() -> List[str]:
|
| 1159 |
+
return list(TASK_REGISTRY.keys())
|
inference.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
DataQA Inference Script — Two-Phase Agent
|
| 4 |
+
------------------------------------------
|
| 5 |
+
LLM agent that plays the DataQA environment in two phases:
|
| 6 |
+
Phase 1: Identify all data quality issues
|
| 7 |
+
Phase 2: Propose fixes for identified issues
|
| 8 |
+
|
| 9 |
+
Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
|
| 10 |
+
|
| 11 |
+
Required environment variables:
|
| 12 |
+
API_BASE_URL - LLM API endpoint (e.g., https://router.huggingface.co/v1)
|
| 13 |
+
MODEL_NAME - Model identifier (e.g., Qwen/Qwen2.5-72B-Instruct)
|
| 14 |
+
HF_TOKEN - HuggingFace token / API key
|
| 15 |
+
|
| 16 |
+
STDOUT FORMAT (mandatory for evaluation):
|
| 17 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 18 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 19 |
+
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from typing import List, Optional
|
| 29 |
+
|
| 30 |
+
import requests
|
| 31 |
+
from openai import OpenAI
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Configuration
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 37 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 38 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 39 |
+
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
|
| 40 |
+
|
| 41 |
+
BENCHMARK = "dataqa_env"
|
| 42 |
+
TASKS = ["easy", "medium", "hard", "alignment", "coding", "toolcalling"]
|
| 43 |
+
MAX_STEPS_PER_TASK = 3
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Logging helpers (structured stdout — exact format required by evaluation)
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 51 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 55 |
+
error_val = error if error else "null"
|
| 56 |
+
done_val = str(done).lower()
|
| 57 |
+
print(
|
| 58 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 59 |
+
flush=True,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 64 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 65 |
+
print(
|
| 66 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
|
| 67 |
+
flush=True,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# Environment HTTP client
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
class EnvHTTPClient:
|
| 76 |
+
"""Minimal HTTP client for the DataQA environment."""
|
| 77 |
+
|
| 78 |
+
def __init__(self, base_url: str):
|
| 79 |
+
self.base_url = base_url.rstrip("/")
|
| 80 |
+
self.session = requests.Session()
|
| 81 |
+
|
| 82 |
+
def health(self) -> bool:
|
| 83 |
+
try:
|
| 84 |
+
r = self.session.get(f"{self.base_url}/health", timeout=10)
|
| 85 |
+
return r.status_code == 200
|
| 86 |
+
except Exception:
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
def reset(self, task_id: str = "easy") -> dict:
|
| 90 |
+
r = self.session.post(
|
| 91 |
+
f"{self.base_url}/reset",
|
| 92 |
+
json={"task_id": task_id},
|
| 93 |
+
timeout=30,
|
| 94 |
+
)
|
| 95 |
+
r.raise_for_status()
|
| 96 |
+
return r.json()
|
| 97 |
+
|
| 98 |
+
def step(self, issues: list[str], fixes: list[str], task_id: str = "easy") -> dict:
|
| 99 |
+
r = self.session.post(
|
| 100 |
+
f"{self.base_url}/step",
|
| 101 |
+
json={"action": {"issues": issues, "fixes": fixes, "task_id": task_id}},
|
| 102 |
+
timeout=30,
|
| 103 |
+
)
|
| 104 |
+
r.raise_for_status()
|
| 105 |
+
return r.json()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# LLM Prompts
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
IDENTIFY_SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
|
| 113 |
+
|
| 114 |
+
You will be given:
|
| 115 |
+
1. A dataset in CSV format
|
| 116 |
+
2. A schema describing expected column types and constraints
|
| 117 |
+
3. Validation rules that the data should satisfy
|
| 118 |
+
|
| 119 |
+
You must identify ALL data quality issues and report each one in EXACTLY this format:
|
| 120 |
+
row:<row_number>,col:<column_name>,issue:<issue_type>
|
| 121 |
+
|
| 122 |
+
Supported issue types:
|
| 123 |
+
- missing_value (null, empty, or whitespace-only)
|
| 124 |
+
- wrong_type (value doesn't match expected type)
|
| 125 |
+
- duplicate_row (exact duplicate or duplicate key)
|
| 126 |
+
- out_of_range (value outside valid range)
|
| 127 |
+
- format_violation (wrong format, invalid enum value)
|
| 128 |
+
- inconsistent_value (computed field doesn't match, logical inconsistency)
|
| 129 |
+
- statistical_outlier (value is unreasonable given context)
|
| 130 |
+
- referential_integrity (foreign key violation)
|
| 131 |
+
|
| 132 |
+
CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
|
| 133 |
+
- Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
|
| 134 |
+
- Row 1 = the FIRST data row after the header
|
| 135 |
+
- Row 2 = the SECOND data row after the header
|
| 136 |
+
- DO NOT use the employee_id, order_id, or experiment_id as the row number
|
| 137 |
+
- Column names must match exactly (use the CSV header names, lowercase)
|
| 138 |
+
- Check EVERY row and EVERY column systematically
|
| 139 |
+
- Consider cross-column consistency (e.g., total = quantity * price)
|
| 140 |
+
- Look for subtle issues like whitespace-only values, near-duplicates
|
| 141 |
+
- Report ALL issues you find, even if uncertain
|
| 142 |
+
|
| 143 |
+
Respond with ONLY the list of issues, one per line. No other text.
|
| 144 |
+
Example: row:3,col:salary,issue:missing_value"""
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
FIX_SYSTEM_PROMPT = """You are a data repair specialist. You have already identified data quality issues in a dataset. Now you must propose the correct values to fix each issue.
|
| 148 |
+
|
| 149 |
+
For each issue you identified, propose a fix in EXACTLY this format:
|
| 150 |
+
row:<row_number>,col:<column_name>,fix:<corrected_value>
|
| 151 |
+
|
| 152 |
+
Guidelines for proposing fixes:
|
| 153 |
+
- For missing_value: infer the correct value from context, schema, and other rows
|
| 154 |
+
- For wrong_type: convert to the correct type (e.g., "seventy-five thousand" → "75000")
|
| 155 |
+
- For out_of_range: propose a value within the valid range that makes sense in context
|
| 156 |
+
- For format_violation: correct the format (e.g., "26/01/2024" → "2024-01-26")
|
| 157 |
+
- For inconsistent_value: compute the correct value from related fields
|
| 158 |
+
- For duplicate_row: propose a corrected unique key or indicate removal
|
| 159 |
+
- For statistical_outlier: propose a reasonable value given the model/context
|
| 160 |
+
|
| 161 |
+
Use the schema, validation rules, and surrounding data to determine the correct fix.
|
| 162 |
+
Respond with ONLY the list of fixes, one per line. No other text.
|
| 163 |
+
Example: row:3,col:salary,fix:75000"""
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def build_user_prompt(observation: dict, include_fixes: bool = False) -> str:
|
| 167 |
+
obs = observation if isinstance(observation, dict) else observation
|
| 168 |
+
parts = []
|
| 169 |
+
|
| 170 |
+
if obs.get("task_description"):
|
| 171 |
+
parts.append(f"TASK: {obs['task_description']}")
|
| 172 |
+
|
| 173 |
+
parts.append(f"SCHEMA:\n{obs.get('schema_description', '')}")
|
| 174 |
+
parts.append(f"VALIDATION RULES:\n{obs.get('validation_rules', '')}")
|
| 175 |
+
parts.append(f"DATASET:\n{obs.get('dataset_csv', '')}")
|
| 176 |
+
|
| 177 |
+
hint = obs.get("num_issues_hint", 0)
|
| 178 |
+
if hint:
|
| 179 |
+
parts.append(f"HINT: There are exactly {hint} issues to find.")
|
| 180 |
+
|
| 181 |
+
feedback = obs.get("feedback", "")
|
| 182 |
+
if feedback and "reset" not in feedback.lower():
|
| 183 |
+
parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
|
| 184 |
+
|
| 185 |
+
if include_fixes:
|
| 186 |
+
parts.append(
|
| 187 |
+
"Now propose fixes for ALL issues. "
|
| 188 |
+
"Use format: row:<N>,col:<name>,fix:<corrected_value>"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
return "\n\n".join(parts)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def parse_llm_response(response: str) -> list[str]:
|
| 195 |
+
"""Extract issue lines from LLM response."""
|
| 196 |
+
issues = []
|
| 197 |
+
for line in response.strip().split("\n"):
|
| 198 |
+
line = line.strip()
|
| 199 |
+
if not line:
|
| 200 |
+
continue
|
| 201 |
+
line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
|
| 202 |
+
line = re.sub(r"^\s*[-*]\s*", "", line)
|
| 203 |
+
line = line.strip()
|
| 204 |
+
if "row" in line.lower() and "col" in line.lower():
|
| 205 |
+
match = re.search(
|
| 206 |
+
r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
|
| 207 |
+
line,
|
| 208 |
+
re.IGNORECASE,
|
| 209 |
+
)
|
| 210 |
+
if match:
|
| 211 |
+
normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
|
| 212 |
+
issues.append(normalized)
|
| 213 |
+
return issues
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def parse_fix_response(response: str) -> list[str]:
|
| 217 |
+
"""Extract fix lines from LLM response."""
|
| 218 |
+
fixes = []
|
| 219 |
+
for line in response.strip().split("\n"):
|
| 220 |
+
line = line.strip()
|
| 221 |
+
if not line:
|
| 222 |
+
continue
|
| 223 |
+
line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
|
| 224 |
+
line = re.sub(r"^\s*[-*]\s*", "", line)
|
| 225 |
+
line = line.strip()
|
| 226 |
+
if "row" in line.lower() and "fix" in line.lower():
|
| 227 |
+
match = re.search(
|
| 228 |
+
r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
|
| 229 |
+
line,
|
| 230 |
+
re.IGNORECASE,
|
| 231 |
+
)
|
| 232 |
+
if match:
|
| 233 |
+
normalized = f"row:{match.group(1)},col:{match.group(2).lower()},fix:{match.group(3).strip()}"
|
| 234 |
+
fixes.append(normalized)
|
| 235 |
+
return fixes
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def call_llm(client: OpenAI, system_prompt: str, user_prompt: str) -> str:
|
| 239 |
+
"""Call the LLM with retry on rate limit."""
|
| 240 |
+
for attempt in range(3):
|
| 241 |
+
try:
|
| 242 |
+
response = client.chat.completions.create(
|
| 243 |
+
model=MODEL_NAME,
|
| 244 |
+
messages=[
|
| 245 |
+
{"role": "system", "content": system_prompt},
|
| 246 |
+
{"role": "user", "content": user_prompt},
|
| 247 |
+
],
|
| 248 |
+
temperature=0.1,
|
| 249 |
+
max_tokens=2048,
|
| 250 |
+
)
|
| 251 |
+
return response.choices[0].message.content or ""
|
| 252 |
+
except Exception as e:
|
| 253 |
+
if "rate_limit" in str(e).lower() or "429" in str(e):
|
| 254 |
+
wait = 10 * (attempt + 1)
|
| 255 |
+
print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
|
| 256 |
+
time.sleep(wait)
|
| 257 |
+
else:
|
| 258 |
+
print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
|
| 259 |
+
return ""
|
| 260 |
+
return ""
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
| 264 |
+
"""
|
| 265 |
+
Run a single task with two-phase strategy:
|
| 266 |
+
Step 1: Identify issues only
|
| 267 |
+
Step 2: Identify + Fix (using feedback from step 1)
|
| 268 |
+
Step 3: Refined identify + fix (if needed)
|
| 269 |
+
"""
|
| 270 |
+
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 271 |
+
|
| 272 |
+
rewards: List[float] = []
|
| 273 |
+
steps_taken = 0
|
| 274 |
+
best_score = 0.0
|
| 275 |
+
success = False
|
| 276 |
+
|
| 277 |
+
try:
|
| 278 |
+
reset_response = env.reset(task_id=task_id)
|
| 279 |
+
observation = reset_response.get("observation", reset_response)
|
| 280 |
+
|
| 281 |
+
last_issues: list[str] = []
|
| 282 |
+
last_llm_output = ""
|
| 283 |
+
|
| 284 |
+
for step_num in range(1, MAX_STEPS_PER_TASK + 1):
|
| 285 |
+
error_msg = None
|
| 286 |
+
|
| 287 |
+
# ── Phase 1: Identify issues ──
|
| 288 |
+
user_prompt = build_user_prompt(observation)
|
| 289 |
+
identify_output = call_llm(client, IDENTIFY_SYSTEM_PROMPT, user_prompt)
|
| 290 |
+
issues = parse_llm_response(identify_output)
|
| 291 |
+
|
| 292 |
+
if not issues and not error_msg:
|
| 293 |
+
error_msg = "no issues parsed from LLM response"
|
| 294 |
+
|
| 295 |
+
# ── Phase 2: Propose fixes (from step 2 onward, or always if we have issues) ──
|
| 296 |
+
fixes: list[str] = []
|
| 297 |
+
if issues and step_num >= 2:
|
| 298 |
+
# Build a fix prompt that includes the identified issues
|
| 299 |
+
fix_prompt = build_user_prompt(observation, include_fixes=True)
|
| 300 |
+
fix_prompt += f"\n\nISSUES FOUND:\n" + "\n".join(issues)
|
| 301 |
+
fix_output = call_llm(client, FIX_SYSTEM_PROMPT, fix_prompt)
|
| 302 |
+
fixes = parse_fix_response(fix_output)
|
| 303 |
+
|
| 304 |
+
# ── Submit to environment ──
|
| 305 |
+
action_str = ";".join(issues[:5]) if issues else "none"
|
| 306 |
+
if fixes:
|
| 307 |
+
action_str += "|fixes:" + ";".join(fixes[:3])
|
| 308 |
+
|
| 309 |
+
step_response = env.step(issues, fixes, task_id=task_id)
|
| 310 |
+
observation = step_response.get("observation", step_response)
|
| 311 |
+
|
| 312 |
+
reward = float(step_response.get("reward", 0.0) or 0.0)
|
| 313 |
+
done = bool(step_response.get("done", False))
|
| 314 |
+
best_score = max(best_score, reward)
|
| 315 |
+
rewards.append(reward)
|
| 316 |
+
steps_taken = step_num
|
| 317 |
+
|
| 318 |
+
log_step(
|
| 319 |
+
step=step_num,
|
| 320 |
+
action=action_str,
|
| 321 |
+
reward=reward,
|
| 322 |
+
done=done,
|
| 323 |
+
error=error_msg,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
if done:
|
| 327 |
+
break
|
| 328 |
+
|
| 329 |
+
last_issues = issues
|
| 330 |
+
last_llm_output = identify_output
|
| 331 |
+
|
| 332 |
+
success = best_score >= 0.5
|
| 333 |
+
|
| 334 |
+
finally:
|
| 335 |
+
log_end(success=success, steps=steps_taken, score=best_score, rewards=rewards)
|
| 336 |
+
|
| 337 |
+
return best_score
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# ---------------------------------------------------------------------------
|
| 341 |
+
# Main
|
| 342 |
+
# ---------------------------------------------------------------------------
|
| 343 |
+
|
| 344 |
+
def main():
|
| 345 |
+
print(f"[DEBUG] DataQA Inference starting", file=sys.stderr, flush=True)
|
| 346 |
+
print(f"[DEBUG] ENV_URL={ENV_URL}", file=sys.stderr, flush=True)
|
| 347 |
+
print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
|
| 348 |
+
print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
|
| 349 |
+
|
| 350 |
+
env = EnvHTTPClient(ENV_URL)
|
| 351 |
+
llm_client = OpenAI(
|
| 352 |
+
base_url=API_BASE_URL,
|
| 353 |
+
api_key=API_KEY or "no-key",
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
if not env.health():
|
| 357 |
+
print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
|
| 358 |
+
sys.exit(1)
|
| 359 |
+
|
| 360 |
+
print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
|
| 361 |
+
|
| 362 |
+
scores = {}
|
| 363 |
+
for task_id in TASKS:
|
| 364 |
+
try:
|
| 365 |
+
score = run_task(llm_client, env, task_id)
|
| 366 |
+
scores[task_id] = score
|
| 367 |
+
except Exception as e:
|
| 368 |
+
print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
|
| 369 |
+
scores[task_id] = 0.0
|
| 370 |
+
|
| 371 |
+
avg_score = sum(scores.values()) / len(scores) if scores else 0.0
|
| 372 |
+
print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
if __name__ == "__main__":
|
| 376 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Root-level models for OpenEnv compatibility."""
|
| 2 |
+
from dataqa_env.models import DataQAAction, DataQAObservation, DataQAState
|
| 3 |
+
|
| 4 |
+
__all__ = ["DataQAAction", "DataQAObservation", "DataQAState"]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: dataqa_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: dataqa_env.server.app:app
|
| 6 |
+
port: 8000
|
openenv_dataqa_env.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-dataqa-env
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Data Quality Assurance Environment for OpenEnv - An LLM agent inspects datasets to find planted quality issues
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.2
|
| 7 |
+
Requires-Dist: fastapi>=0.115.0
|
| 8 |
+
Requires-Dist: pydantic>=2.0.0
|
| 9 |
+
Requires-Dist: uvicorn[standard]>=0.24.0
|
| 10 |
+
Requires-Dist: requests>=2.31.0
|
| 11 |
+
Provides-Extra: dev
|
| 12 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 13 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_dataqa_env.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
dataqa_env/__init__.py
|
| 4 |
+
dataqa_env/client.py
|
| 5 |
+
dataqa_env/models.py
|
| 6 |
+
dataqa_env/server/__init__.py
|
| 7 |
+
dataqa_env/server/app.py
|
| 8 |
+
dataqa_env/server/environment.py
|
| 9 |
+
dataqa_env/server/tasks.py
|
| 10 |
+
openenv_dataqa_env.egg-info/PKG-INFO
|
| 11 |
+
openenv_dataqa_env.egg-info/SOURCES.txt
|
| 12 |
+
openenv_dataqa_env.egg-info/dependency_links.txt
|
| 13 |
+
openenv_dataqa_env.egg-info/entry_points.txt
|
| 14 |
+
openenv_dataqa_env.egg-info/requires.txt
|
| 15 |
+
openenv_dataqa_env.egg-info/top_level.txt
|
openenv_dataqa_env.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_dataqa_env.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = dataqa_env.server.app:main
|
openenv_dataqa_env.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
uvicorn[standard]>=0.24.0
|
| 5 |
+
requests>=2.31.0
|
| 6 |
+
|
| 7 |
+
[dev]
|
| 8 |
+
pytest>=8.0.0
|
| 9 |
+
pytest-cov>=4.0.0
|
openenv_dataqa_env.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
dataqa_env
|
pyproject.toml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-dataqa-env"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Data Quality Assurance Environment for OpenEnv - An LLM agent inspects datasets to find planted quality issues"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core]>=0.2.2",
|
| 12 |
+
"fastapi>=0.115.0",
|
| 13 |
+
"pydantic>=2.0.0",
|
| 14 |
+
"uvicorn[standard]>=0.24.0",
|
| 15 |
+
"requests>=2.31.0",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[project.optional-dependencies]
|
| 19 |
+
dev = [
|
| 20 |
+
"pytest>=8.0.0",
|
| 21 |
+
"pytest-cov>=4.0.0",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[project.scripts]
|
| 25 |
+
server = "dataqa_env.server.app:main"
|
| 26 |
+
|
| 27 |
+
[tool.setuptools]
|
| 28 |
+
packages = ["dataqa_env", "dataqa_env.server"]
|
| 29 |
+
package-dir = { "dataqa_env" = "dataqa_env", "dataqa_env.server" = "dataqa_env/server" }
|
| 30 |
+
|
| 31 |
+
[tool.setuptools.package-data]
|
| 32 |
+
dataqa_env = ["**/*.yaml", "**/*.yml"]
|
scripts/prevalidation_script.sh
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
#
|
| 3 |
+
# validate-submission.sh — OpenEnv Submission Validator
|
| 4 |
+
#
|
| 5 |
+
# Checks that your HF Space is live, Docker image builds, and openenv validate passes.
|
| 6 |
+
#
|
| 7 |
+
# Prerequisites:
|
| 8 |
+
# - Docker: https://docs.docker.com/get-docker/
|
| 9 |
+
# - openenv-core: pip install openenv-core
|
| 10 |
+
# - curl (usually pre-installed)
|
| 11 |
+
#
|
| 12 |
+
# Run:
|
| 13 |
+
# curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
|
| 14 |
+
#
|
| 15 |
+
# Or download and run locally:
|
| 16 |
+
# chmod +x validate-submission.sh
|
| 17 |
+
# ./validate-submission.sh <ping_url> [repo_dir]
|
| 18 |
+
#
|
| 19 |
+
# Arguments:
|
| 20 |
+
# ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
|
| 21 |
+
# repo_dir Path to your repo (default: current directory)
|
| 22 |
+
#
|
| 23 |
+
# Examples:
|
| 24 |
+
# ./validate-submission.sh https://my-team.hf.space
|
| 25 |
+
# ./validate-submission.sh https://my-team.hf.space ./my-repo
|
| 26 |
+
#
|
| 27 |
+
|
| 28 |
+
set -uo pipefail
|
| 29 |
+
|
| 30 |
+
DOCKER_BUILD_TIMEOUT=600
|
| 31 |
+
if [ -t 1 ]; then
|
| 32 |
+
RED='\033[0;31m'
|
| 33 |
+
GREEN='\033[0;32m'
|
| 34 |
+
YELLOW='\033[1;33m'
|
| 35 |
+
BOLD='\033[1m'
|
| 36 |
+
NC='\033[0m'
|
| 37 |
+
else
|
| 38 |
+
RED='' GREEN='' YELLOW='' BOLD='' NC=''
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
run_with_timeout() {
|
| 42 |
+
local secs="$1"; shift
|
| 43 |
+
if command -v timeout &>/dev/null; then
|
| 44 |
+
timeout "$secs" "$@"
|
| 45 |
+
elif command -v gtimeout &>/dev/null; then
|
| 46 |
+
gtimeout "$secs" "$@"
|
| 47 |
+
else
|
| 48 |
+
"$@" &
|
| 49 |
+
local pid=$!
|
| 50 |
+
( sleep "$secs" && kill "$pid" 2>/dev/null ) &
|
| 51 |
+
local watcher=$!
|
| 52 |
+
wait "$pid" 2>/dev/null
|
| 53 |
+
local rc=$?
|
| 54 |
+
kill "$watcher" 2>/dev/null
|
| 55 |
+
wait "$watcher" 2>/dev/null
|
| 56 |
+
return $rc
|
| 57 |
+
fi
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
portable_mktemp() {
|
| 61 |
+
local prefix="${1:-validate}"
|
| 62 |
+
mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
CLEANUP_FILES=()
|
| 66 |
+
cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
|
| 67 |
+
trap cleanup EXIT
|
| 68 |
+
|
| 69 |
+
PING_URL="${1:-}"
|
| 70 |
+
REPO_DIR="${2:-.}"
|
| 71 |
+
|
| 72 |
+
if [ -z "$PING_URL" ]; then
|
| 73 |
+
printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
|
| 74 |
+
printf "\n"
|
| 75 |
+
printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
|
| 76 |
+
printf " repo_dir Path to your repo (default: current directory)\n"
|
| 77 |
+
exit 1
|
| 78 |
+
fi
|
| 79 |
+
|
| 80 |
+
if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
|
| 81 |
+
printf "Error: directory '%s' not found\n" "${2:-.}"
|
| 82 |
+
exit 1
|
| 83 |
+
fi
|
| 84 |
+
PING_URL="${PING_URL%/}"
|
| 85 |
+
export PING_URL
|
| 86 |
+
PASS=0
|
| 87 |
+
|
| 88 |
+
log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
|
| 89 |
+
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
|
| 90 |
+
fail() { log "${RED}FAILED${NC} -- $1"; }
|
| 91 |
+
hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
|
| 92 |
+
stop_at() {
|
| 93 |
+
printf "\n"
|
| 94 |
+
printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
|
| 95 |
+
exit 1
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
printf "\n"
|
| 99 |
+
printf "${BOLD}========================================${NC}\n"
|
| 100 |
+
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
|
| 101 |
+
printf "${BOLD}========================================${NC}\n"
|
| 102 |
+
log "Repo: $REPO_DIR"
|
| 103 |
+
log "Ping URL: $PING_URL"
|
| 104 |
+
printf "\n"
|
| 105 |
+
|
| 106 |
+
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
|
| 107 |
+
|
| 108 |
+
CURL_OUTPUT=$(portable_mktemp "validate-curl")
|
| 109 |
+
CLEANUP_FILES+=("$CURL_OUTPUT")
|
| 110 |
+
HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
|
| 111 |
+
-H "Content-Type: application/json" -d '{}' \
|
| 112 |
+
"$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
|
| 113 |
+
|
| 114 |
+
if [ "$HTTP_CODE" = "200" ]; then
|
| 115 |
+
pass "HF Space is live and responds to /reset"
|
| 116 |
+
elif [ "$HTTP_CODE" = "000" ]; then
|
| 117 |
+
fail "HF Space not reachable (connection failed or timed out)"
|
| 118 |
+
hint "Check your network connection and that the Space is running."
|
| 119 |
+
hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
|
| 120 |
+
stop_at "Step 1"
|
| 121 |
+
else
|
| 122 |
+
fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
|
| 123 |
+
hint "Make sure your Space is running and the URL is correct."
|
| 124 |
+
hint "Try opening $PING_URL in your browser first."
|
| 125 |
+
stop_at "Step 1"
|
| 126 |
+
fi
|
| 127 |
+
|
| 128 |
+
log "${BOLD}Step 2/3: Running docker build${NC} ..."
|
| 129 |
+
|
| 130 |
+
if ! command -v docker &>/dev/null; then
|
| 131 |
+
fail "docker command not found"
|
| 132 |
+
hint "Install Docker: https://docs.docker.com/get-docker/"
|
| 133 |
+
stop_at "Step 2"
|
| 134 |
+
fi
|
| 135 |
+
|
| 136 |
+
if [ -f "$REPO_DIR/Dockerfile" ]; then
|
| 137 |
+
DOCKER_CONTEXT="$REPO_DIR"
|
| 138 |
+
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
|
| 139 |
+
DOCKER_CONTEXT="$REPO_DIR/server"
|
| 140 |
+
else
|
| 141 |
+
fail "No Dockerfile found in repo root or server/ directory"
|
| 142 |
+
stop_at "Step 2"
|
| 143 |
+
fi
|
| 144 |
+
|
| 145 |
+
log " Found Dockerfile in $DOCKER_CONTEXT"
|
| 146 |
+
|
| 147 |
+
BUILD_OK=false
|
| 148 |
+
BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
|
| 149 |
+
|
| 150 |
+
if [ "$BUILD_OK" = true ]; then
|
| 151 |
+
pass "Docker build succeeded"
|
| 152 |
+
else
|
| 153 |
+
fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
|
| 154 |
+
printf "%s\n" "$BUILD_OUTPUT" | tail -20
|
| 155 |
+
stop_at "Step 2"
|
| 156 |
+
fi
|
| 157 |
+
|
| 158 |
+
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
|
| 159 |
+
|
| 160 |
+
if ! command -v openenv &>/dev/null; then
|
| 161 |
+
fail "openenv command not found"
|
| 162 |
+
hint "Install it: pip install openenv-core"
|
| 163 |
+
stop_at "Step 3"
|
| 164 |
+
fi
|
| 165 |
+
|
| 166 |
+
VALIDATE_OK=false
|
| 167 |
+
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
|
| 168 |
+
|
| 169 |
+
if [ "$VALIDATE_OK" = true ]; then
|
| 170 |
+
pass "openenv validate passed"
|
| 171 |
+
[ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
|
| 172 |
+
else
|
| 173 |
+
fail "openenv validate failed"
|
| 174 |
+
printf "%s\n" "$VALIDATE_OUTPUT"
|
| 175 |
+
stop_at "Step 3"
|
| 176 |
+
fi
|
| 177 |
+
|
| 178 |
+
printf "\n"
|
| 179 |
+
printf "${BOLD}========================================${NC}\n"
|
| 180 |
+
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
|
| 181 |
+
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
|
| 182 |
+
printf "${BOLD}========================================${NC}\n"
|
| 183 |
+
printf "\n"
|
| 184 |
+
|
| 185 |
+
exit 0
|
scripts/sample_inference_script.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Script Example
|
| 3 |
+
===================================
|
| 4 |
+
MANDATORY
|
| 5 |
+
- Before submitting, ensure the following variables are defined in your environment configuration:
|
| 6 |
+
API_BASE_URL The API endpoint for the LLM.
|
| 7 |
+
MODEL_NAME The model identifier to use for inference.
|
| 8 |
+
HF_TOKEN Your Hugging Face / API key.
|
| 9 |
+
LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image()
|
| 10 |
+
method
|
| 11 |
+
|
| 12 |
+
- Defaults are set only for API_BASE_URL and MODEL_NAME
|
| 13 |
+
(and should reflect your active inference setup):
|
| 14 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
|
| 15 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "<your-active-model>")
|
| 16 |
+
|
| 17 |
+
- The inference script must be named `inference.py` and placed in the root directory of the project
|
| 18 |
+
- Participants must use OpenAI Client for all LLM calls using above variables
|
| 19 |
+
|
| 20 |
+
STDOUT FORMAT
|
| 21 |
+
- The script must emit exactly three line types to stdout, in this order:
|
| 22 |
+
|
| 23 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 24 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 25 |
+
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
|
| 26 |
+
|
| 27 |
+
Rules:
|
| 28 |
+
- One [START] line at episode begin.
|
| 29 |
+
- One [STEP] line per step, immediately after env.step() returns.
|
| 30 |
+
- One [END] line after env.close(), always emitted (even on exception).
|
| 31 |
+
- reward and rewards are formatted to 2 decimal places.
|
| 32 |
+
- done and success are lowercase booleans: true or false.
|
| 33 |
+
- error is the raw last_action_error string, or null if none.
|
| 34 |
+
- All fields on a single line with no newlines within a line.
|
| 35 |
+
- Each tasks should return score in [0, 1]
|
| 36 |
+
|
| 37 |
+
Example:
|
| 38 |
+
[START] task=click-test env=miniwob model=Qwen3-VL-30B
|
| 39 |
+
[STEP] step=1 action=click('123') reward=0.00 done=false error=null
|
| 40 |
+
[STEP] step=2 action=fill('456','text') reward=0.00 done=false error=null
|
| 41 |
+
[STEP] step=3 action=click('789') reward=1.00 done=true error=null
|
| 42 |
+
[END] success=true steps=3 score=1.00 rewards=0.00,0.00,1.00
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import asyncio
|
| 46 |
+
import os
|
| 47 |
+
import textwrap
|
| 48 |
+
from typing import List, Optional
|
| 49 |
+
|
| 50 |
+
from openai import OpenAI
|
| 51 |
+
|
| 52 |
+
from my_env_v4 import MyEnvV4Action, MyEnvV4Env
|
| 53 |
+
IMAGE_NAME = os.getenv("IMAGE_NAME") # If you are using docker image
|
| 54 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 55 |
+
|
| 56 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 57 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 58 |
+
TASK_NAME = os.getenv("MY_ENV_V4_TASK", "echo")
|
| 59 |
+
BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "my_env_v4")
|
| 60 |
+
MAX_STEPS = 8
|
| 61 |
+
TEMPERATURE = 0.7
|
| 62 |
+
MAX_TOKENS = 150
|
| 63 |
+
SUCCESS_SCORE_THRESHOLD = 0.1 # normalized score in [0, 1]
|
| 64 |
+
|
| 65 |
+
# Max possible reward: each token contributes 0.1, across all steps
|
| 66 |
+
_MAX_REWARD_PER_STEP = MAX_TOKENS * 0.1
|
| 67 |
+
MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP
|
| 68 |
+
|
| 69 |
+
SYSTEM_PROMPT = textwrap.dedent(
|
| 70 |
+
"""
|
| 71 |
+
You are interacting with a simple echo environment.
|
| 72 |
+
Each turn you must send a message. The environment will echo it back.
|
| 73 |
+
Reward is proportional to message length: reward = len(message) * 0.1
|
| 74 |
+
Your goal is to maximize total reward by sending meaningful, substantive messages.
|
| 75 |
+
Reply with exactly one message string — no quotes, no prefixes, just the message text.
|
| 76 |
+
"""
|
| 77 |
+
).strip()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 81 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 85 |
+
error_val = error if error else "null"
|
| 86 |
+
done_val = str(done).lower()
|
| 87 |
+
print(
|
| 88 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 89 |
+
flush=True,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 94 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 95 |
+
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def build_user_prompt(step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
|
| 99 |
+
history_block = "\n".join(history[-4:]) if history else "None"
|
| 100 |
+
return textwrap.dedent(
|
| 101 |
+
f"""
|
| 102 |
+
Step: {step}
|
| 103 |
+
Last echoed message: {last_echoed!r}
|
| 104 |
+
Last reward: {last_reward:.2f}
|
| 105 |
+
Previous steps:
|
| 106 |
+
{history_block}
|
| 107 |
+
Send your next message.
|
| 108 |
+
"""
|
| 109 |
+
).strip()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_model_message(client: OpenAI, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
|
| 113 |
+
user_prompt = build_user_prompt(step, last_echoed, last_reward, history)
|
| 114 |
+
try:
|
| 115 |
+
completion = client.chat.completions.create(
|
| 116 |
+
model=MODEL_NAME,
|
| 117 |
+
messages=[
|
| 118 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 119 |
+
{"role": "user", "content": user_prompt},
|
| 120 |
+
],
|
| 121 |
+
temperature=TEMPERATURE,
|
| 122 |
+
max_tokens=MAX_TOKENS,
|
| 123 |
+
stream=False,
|
| 124 |
+
)
|
| 125 |
+
text = (completion.choices[0].message.content or "").strip()
|
| 126 |
+
return text if text else "hello"
|
| 127 |
+
except Exception as exc:
|
| 128 |
+
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
| 129 |
+
return "hello"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
async def main() -> None:
|
| 133 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 134 |
+
|
| 135 |
+
env = await MyEnvV4Env.from_docker_image(IMAGE_NAME)
|
| 136 |
+
|
| 137 |
+
history: List[str] = []
|
| 138 |
+
rewards: List[float] = []
|
| 139 |
+
steps_taken = 0
|
| 140 |
+
score = 0.0
|
| 141 |
+
success = False
|
| 142 |
+
|
| 143 |
+
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
result = await env.reset() # OpenENV.reset()
|
| 147 |
+
last_echoed = result.observation.echoed_message
|
| 148 |
+
last_reward = 0.0
|
| 149 |
+
|
| 150 |
+
for step in range(1, MAX_STEPS + 1):
|
| 151 |
+
if result.done:
|
| 152 |
+
break
|
| 153 |
+
|
| 154 |
+
message = get_model_message(client, step, last_echoed, last_reward, history)
|
| 155 |
+
|
| 156 |
+
result = await env.step(MyEnvV4Action(message=message))
|
| 157 |
+
obs = result.observation
|
| 158 |
+
|
| 159 |
+
reward = result.reward or 0.0
|
| 160 |
+
done = result.done
|
| 161 |
+
error = None
|
| 162 |
+
|
| 163 |
+
rewards.append(reward)
|
| 164 |
+
steps_taken = step
|
| 165 |
+
last_echoed = obs.echoed_message
|
| 166 |
+
last_reward = reward
|
| 167 |
+
|
| 168 |
+
log_step(step=step, action=message, reward=reward, done=done, error=error)
|
| 169 |
+
|
| 170 |
+
history.append(f"Step {step}: {message!r} -> reward {reward:+.2f}")
|
| 171 |
+
|
| 172 |
+
if done:
|
| 173 |
+
break
|
| 174 |
+
|
| 175 |
+
score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
|
| 176 |
+
score = min(max(score, 0.0), 1.0) # clamp to [0, 1]
|
| 177 |
+
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 178 |
+
|
| 179 |
+
finally:
|
| 180 |
+
try:
|
| 181 |
+
await env.close()
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
|
| 184 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
if __name__ == "__main__":
|
| 188 |
+
asyncio.run(main())
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Root-level server package — delegates to dataqa_env.server."""
|
server/app.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entrypoint for openenv-core deployment. Delegates to dataqa_env.server.app."""
|
| 2 |
+
|
| 3 |
+
from dataqa_env.server.app import app # noqa: F401
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
"""Start the environment server."""
|
| 8 |
+
import uvicorn
|
| 9 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
main()
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the DataQA environment (reset, step, scoring, two-phase identify+fix)."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from dataqa_env.server.environment import (
|
| 5 |
+
DataQAEnvironment,
|
| 6 |
+
parse_issue_key,
|
| 7 |
+
parse_fix,
|
| 8 |
+
compute_f1,
|
| 9 |
+
compute_weighted_reward,
|
| 10 |
+
grade_fixes,
|
| 11 |
+
IDENTIFY_WEIGHT,
|
| 12 |
+
FIX_WEIGHT,
|
| 13 |
+
)
|
| 14 |
+
from dataqa_env.models import DataQAAction
|
| 15 |
+
from dataqa_env.server.tasks import PlantedIssue, create_task_easy, create_task_medium
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ──────────────────────────────────────────────────────
|
| 19 |
+
# Issue parsing
|
| 20 |
+
# ──────────────────────────────────────────────────────
|
| 21 |
+
|
| 22 |
+
class TestParseIssueKey:
|
| 23 |
+
def test_standard_format(self):
|
| 24 |
+
assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
|
| 25 |
+
|
| 26 |
+
def test_with_equals(self):
|
| 27 |
+
assert parse_issue_key("row=3,col=salary,issue=missing_value") == "row:3,col:salary,issue:missing_value"
|
| 28 |
+
|
| 29 |
+
def test_case_insensitive(self):
|
| 30 |
+
assert parse_issue_key("Row:3,Col:Salary,Issue:Missing_Value") == "row:3,col:salary,issue:missing_value"
|
| 31 |
+
|
| 32 |
+
def test_with_spaces(self):
|
| 33 |
+
assert parse_issue_key("row: 3, col: salary, issue: missing_value") == "row:3,col:salary,issue:missing_value"
|
| 34 |
+
|
| 35 |
+
def test_unparseable(self):
|
| 36 |
+
assert parse_issue_key("this is garbage") is None
|
| 37 |
+
|
| 38 |
+
def test_partial_match(self):
|
| 39 |
+
assert parse_issue_key("row:3,col:salary") is None
|
| 40 |
+
|
| 41 |
+
def test_empty_string(self):
|
| 42 |
+
assert parse_issue_key("") is None
|
| 43 |
+
|
| 44 |
+
def test_semicolon_separator(self):
|
| 45 |
+
result = parse_issue_key("row:3;col:salary;issue:missing_value")
|
| 46 |
+
assert result == "row:3,col:salary,issue:missing_value"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ──────────────────────────────────────────────────────
|
| 50 |
+
# Fix parsing
|
| 51 |
+
# ──────────────────────────────────────────────────────
|
| 52 |
+
|
| 53 |
+
class TestParseFix:
|
| 54 |
+
def test_standard_format(self):
|
| 55 |
+
result = parse_fix("row:4,col:name,fix:Alice Chen")
|
| 56 |
+
assert result == (4, "name", "Alice Chen")
|
| 57 |
+
|
| 58 |
+
def test_with_equals(self):
|
| 59 |
+
result = parse_fix("row=4,col=name,fix=Alice Chen")
|
| 60 |
+
assert result == (4, "name", "Alice Chen")
|
| 61 |
+
|
| 62 |
+
def test_numeric_fix(self):
|
| 63 |
+
result = parse_fix("row:7,col:salary,fix:75000")
|
| 64 |
+
assert result == (7, "salary", "75000")
|
| 65 |
+
|
| 66 |
+
def test_date_fix(self):
|
| 67 |
+
result = parse_fix("row:12,col:order_date,fix:2024-01-26")
|
| 68 |
+
assert result == (12, "order_date", "2024-01-26")
|
| 69 |
+
|
| 70 |
+
def test_case_insensitive(self):
|
| 71 |
+
result = parse_fix("Row:4,Col:Name,Fix:Alice Chen")
|
| 72 |
+
assert result == (4, "name", "Alice Chen")
|
| 73 |
+
|
| 74 |
+
def test_unparseable(self):
|
| 75 |
+
assert parse_fix("garbage") is None
|
| 76 |
+
assert parse_fix("row:4,col:name") is None
|
| 77 |
+
|
| 78 |
+
def test_fix_with_special_chars(self):
|
| 79 |
+
result = parse_fix("row:1,col:email,fix:alice.chen@company.com")
|
| 80 |
+
assert result == (1, "email", "alice.chen@company.com")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ──────────────────────────────────────────────────────
|
| 84 |
+
# F1 scoring
|
| 85 |
+
# ──────────────────────────────────────────────────────
|
| 86 |
+
|
| 87 |
+
class TestComputeF1:
|
| 88 |
+
def test_perfect_match(self):
|
| 89 |
+
keys = {"row:1,col:a,issue:missing_value"}
|
| 90 |
+
result = compute_f1(keys, keys)
|
| 91 |
+
assert result["f1"] == 1.0
|
| 92 |
+
|
| 93 |
+
def test_no_reported_no_planted(self):
|
| 94 |
+
result = compute_f1(set(), set())
|
| 95 |
+
assert result["f1"] == 1.0
|
| 96 |
+
|
| 97 |
+
def test_no_reported_some_planted(self):
|
| 98 |
+
planted = {"row:1,col:a,issue:missing_value"}
|
| 99 |
+
result = compute_f1(set(), planted)
|
| 100 |
+
assert result["f1"] == 0.0
|
| 101 |
+
assert result["fn"] == 1
|
| 102 |
+
|
| 103 |
+
def test_all_false_positives(self):
|
| 104 |
+
reported = {"row:99,col:x,issue:wrong_type"}
|
| 105 |
+
planted = {"row:1,col:a,issue:missing_value"}
|
| 106 |
+
result = compute_f1(reported, planted)
|
| 107 |
+
assert result["f1"] == 0.0
|
| 108 |
+
|
| 109 |
+
def test_partial_match(self):
|
| 110 |
+
reported = {"row:1,col:a,issue:missing_value", "row:2,col:b,issue:wrong_type"}
|
| 111 |
+
planted = {"row:1,col:a,issue:missing_value", "row:3,col:c,issue:duplicate_row"}
|
| 112 |
+
result = compute_f1(reported, planted)
|
| 113 |
+
assert result["tp"] == 1
|
| 114 |
+
assert result["fp"] == 1
|
| 115 |
+
assert result["fn"] == 1
|
| 116 |
+
assert 0 < result["f1"] < 1
|
| 117 |
+
|
| 118 |
+
def test_precision_recall_calculation(self):
|
| 119 |
+
reported = {"a", "b", "c"}
|
| 120 |
+
planted = {"a", "b", "d"}
|
| 121 |
+
result = compute_f1(reported, planted)
|
| 122 |
+
assert result["precision"] == pytest.approx(2 / 3)
|
| 123 |
+
assert result["recall"] == pytest.approx(2 / 3)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ──────────────────────────────────────────────────────
|
| 127 |
+
# Weighted reward
|
| 128 |
+
# ──────────────────────────────────────────────────────
|
| 129 |
+
|
| 130 |
+
class TestComputeWeightedReward:
|
| 131 |
+
def test_perfect_match(self):
|
| 132 |
+
issues = [
|
| 133 |
+
PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0),
|
| 134 |
+
PlantedIssue(row=2, col="b", issue_type="wrong_type", description="", difficulty=3.0),
|
| 135 |
+
]
|
| 136 |
+
reported = {i.to_key() for i in issues}
|
| 137 |
+
result = compute_weighted_reward(reported, issues)
|
| 138 |
+
assert result["weighted_reward"] == 1.0
|
| 139 |
+
|
| 140 |
+
def test_empty_both(self):
|
| 141 |
+
result = compute_weighted_reward(set(), [])
|
| 142 |
+
assert result["weighted_reward"] == 1.0
|
| 143 |
+
|
| 144 |
+
def test_no_reported(self):
|
| 145 |
+
issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
|
| 146 |
+
result = compute_weighted_reward(set(), issues)
|
| 147 |
+
assert result["weighted_reward"] == 0.0
|
| 148 |
+
|
| 149 |
+
def test_hard_issue_worth_more(self):
|
| 150 |
+
easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
|
| 151 |
+
hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
|
| 152 |
+
issues = [easy, hard]
|
| 153 |
+
hard_found = compute_weighted_reward({hard.to_key()}, issues)
|
| 154 |
+
easy_found = compute_weighted_reward({easy.to_key()}, issues)
|
| 155 |
+
assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
|
| 156 |
+
|
| 157 |
+
def test_false_positives_reduce_reward(self):
|
| 158 |
+
issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)]
|
| 159 |
+
correct = {issues[0].to_key()}
|
| 160 |
+
with_fp = correct | {"row:99,col:x,issue:wrong_type"}
|
| 161 |
+
r_correct = compute_weighted_reward(correct, issues)
|
| 162 |
+
r_with_fp = compute_weighted_reward(with_fp, issues)
|
| 163 |
+
assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ──────────────────────────────────────────────────────
|
| 167 |
+
# Fix grading
|
| 168 |
+
# ──────────────────────────────────────────────────────
|
| 169 |
+
|
| 170 |
+
class TestGradeFixes:
|
| 171 |
+
@pytest.fixture
|
| 172 |
+
def easy_task(self):
|
| 173 |
+
return create_task_easy()
|
| 174 |
+
|
| 175 |
+
def test_no_fixes_no_issues(self):
|
| 176 |
+
from dataqa_env.server.tasks import Task
|
| 177 |
+
task = Task(task_id="empty", name="", description="", schema_description="",
|
| 178 |
+
validation_rules="", clean_csv="a\n1")
|
| 179 |
+
result = grade_fixes([], task)
|
| 180 |
+
assert result["fix_score"] == 1.0
|
| 181 |
+
|
| 182 |
+
def test_no_fixes_submitted(self, easy_task):
|
| 183 |
+
result = grade_fixes([], easy_task)
|
| 184 |
+
assert result["fix_score"] == 0.0
|
| 185 |
+
assert result["fixes_attempted"] == 0
|
| 186 |
+
|
| 187 |
+
def test_exact_fix_for_missing_name(self, easy_task):
|
| 188 |
+
# Row 4 has empty name — clean value is "David Kim"
|
| 189 |
+
fixes = [(4, "name", "David Kim")]
|
| 190 |
+
result = grade_fixes(fixes, easy_task)
|
| 191 |
+
assert result["fix_score"] > 0.0
|
| 192 |
+
assert result["fixes_correct"] == 1
|
| 193 |
+
|
| 194 |
+
def test_exact_fix_for_wrong_type_salary(self, easy_task):
|
| 195 |
+
# Row 7 has "seventy-five thousand" — clean value is "75000"
|
| 196 |
+
fixes = [(7, "salary", "75000")]
|
| 197 |
+
result = grade_fixes(fixes, easy_task)
|
| 198 |
+
assert result["fixes_correct"] == 1
|
| 199 |
+
|
| 200 |
+
def test_numeric_close_match(self, easy_task):
|
| 201 |
+
# Row 9 has salary "5000" — clean value is "73000"
|
| 202 |
+
# Propose 73100 (within 1% of 73000)
|
| 203 |
+
fixes = [(9, "salary", "73100")]
|
| 204 |
+
result = grade_fixes(fixes, easy_task)
|
| 205 |
+
assert result["fixes_partial"] == 1
|
| 206 |
+
|
| 207 |
+
def test_wrong_value_for_issue_cell(self, easy_task):
|
| 208 |
+
# Row 4 name is empty — propose wrong name
|
| 209 |
+
fixes = [(4, "name", "Wrong Person")]
|
| 210 |
+
result = grade_fixes(fixes, easy_task)
|
| 211 |
+
assert result["fixes_partial"] == 1 # correct cell, wrong value
|
| 212 |
+
assert result["fix_score"] > 0.0 # gets partial credit
|
| 213 |
+
|
| 214 |
+
def test_fix_for_non_issue_cell(self, easy_task):
|
| 215 |
+
# Row 1 col name is fine — no issue there
|
| 216 |
+
fixes = [(1, "name", "Some Name")]
|
| 217 |
+
result = grade_fixes(fixes, easy_task)
|
| 218 |
+
assert result["fixes_wrong"] == 1
|
| 219 |
+
assert result["fix_score"] == 0.0
|
| 220 |
+
|
| 221 |
+
def test_multiple_fixes_best_wins(self, easy_task):
|
| 222 |
+
# Submit two fixes for same cell — best one should count
|
| 223 |
+
fixes = [
|
| 224 |
+
(4, "name", "Wrong Person"), # partial credit
|
| 225 |
+
(4, "name", "David Kim"), # exact match
|
| 226 |
+
]
|
| 227 |
+
result = grade_fixes(fixes, easy_task)
|
| 228 |
+
assert result["fixes_correct"] >= 1
|
| 229 |
+
|
| 230 |
+
def test_all_fixes_correct(self, easy_task):
|
| 231 |
+
# Fix most issues with exact values
|
| 232 |
+
fixes = [
|
| 233 |
+
(4, "name", "David Kim"),
|
| 234 |
+
(7, "salary", "75000"),
|
| 235 |
+
(9, "salary", "73000"),
|
| 236 |
+
(15, "email", "oscar.rivera@company.com"),
|
| 237 |
+
(18, "start_date", "2022-01-19"),
|
| 238 |
+
]
|
| 239 |
+
result = grade_fixes(fixes, easy_task)
|
| 240 |
+
assert result["fix_score"] > 0.7 # 5 out of 6 issues fixed (duplicate can't be fixed)
|
| 241 |
+
|
| 242 |
+
def test_fix_score_bounded(self, easy_task):
|
| 243 |
+
fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
|
| 244 |
+
result = grade_fixes(fixes, easy_task)
|
| 245 |
+
assert 0.0 <= result["fix_score"] <= 1.0
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ──────────────────────────────────────────────────────
|
| 249 |
+
# Full environment lifecycle
|
| 250 |
+
# ──────────────────────────────────────────────────────
|
| 251 |
+
|
| 252 |
+
class TestDataQAEnvironment:
|
| 253 |
+
@pytest.fixture
|
| 254 |
+
def env(self):
|
| 255 |
+
return DataQAEnvironment()
|
| 256 |
+
|
| 257 |
+
def test_reset_returns_observation(self, env):
|
| 258 |
+
obs = env.reset(task_id="easy")
|
| 259 |
+
assert obs.dataset_csv
|
| 260 |
+
assert obs.schema_description
|
| 261 |
+
assert obs.validation_rules
|
| 262 |
+
assert obs.task_description
|
| 263 |
+
assert obs.num_issues_hint == 6
|
| 264 |
+
assert obs.max_steps == 3
|
| 265 |
+
assert obs.done is False
|
| 266 |
+
assert obs.reward == 0.0
|
| 267 |
+
assert "fix" in obs.feedback.lower() # mentions fix phase
|
| 268 |
+
|
| 269 |
+
def test_reset_medium(self, env):
|
| 270 |
+
obs = env.reset(task_id="medium")
|
| 271 |
+
assert obs.num_issues_hint == 8
|
| 272 |
+
|
| 273 |
+
def test_reset_hard(self, env):
|
| 274 |
+
obs = env.reset(task_id="hard")
|
| 275 |
+
assert obs.num_issues_hint == 10
|
| 276 |
+
|
| 277 |
+
def test_step_identify_only(self, env):
|
| 278 |
+
"""Backward compatible: only issues, no fixes."""
|
| 279 |
+
env.reset(task_id="easy")
|
| 280 |
+
# Submit all 6 correct issues for easy task
|
| 281 |
+
action = DataQAAction(
|
| 282 |
+
issues=[
|
| 283 |
+
"row:4,col:name,issue:missing_value",
|
| 284 |
+
"row:7,col:salary,issue:wrong_type",
|
| 285 |
+
"row:21,col:employee_id,issue:duplicate_row",
|
| 286 |
+
"row:9,col:salary,issue:out_of_range",
|
| 287 |
+
"row:15,col:email,issue:inconsistent_value",
|
| 288 |
+
"row:18,col:start_date,issue:out_of_range",
|
| 289 |
+
],
|
| 290 |
+
task_id="easy",
|
| 291 |
+
)
|
| 292 |
+
obs = env.step(action)
|
| 293 |
+
assert obs.done is True
|
| 294 |
+
assert obs.reward >= 0.999 # identify-only uses identify_score directly
|
| 295 |
+
|
| 296 |
+
def test_step_with_fixes_increases_reward(self, env):
|
| 297 |
+
"""Submitting correct fixes should produce high combined reward."""
|
| 298 |
+
env.reset(task_id="easy")
|
| 299 |
+
# All 6 issues + 3 fixes
|
| 300 |
+
action = DataQAAction(
|
| 301 |
+
issues=[
|
| 302 |
+
"row:4,col:name,issue:missing_value",
|
| 303 |
+
"row:7,col:salary,issue:wrong_type",
|
| 304 |
+
"row:21,col:employee_id,issue:duplicate_row",
|
| 305 |
+
"row:9,col:salary,issue:out_of_range",
|
| 306 |
+
"row:15,col:email,issue:inconsistent_value",
|
| 307 |
+
"row:18,col:start_date,issue:out_of_range",
|
| 308 |
+
],
|
| 309 |
+
fixes=[
|
| 310 |
+
"row:4,col:name,fix:David Kim",
|
| 311 |
+
"row:7,col:salary,fix:75000",
|
| 312 |
+
"row:9,col:salary,fix:73000",
|
| 313 |
+
],
|
| 314 |
+
task_id="easy",
|
| 315 |
+
)
|
| 316 |
+
obs = env.step(action)
|
| 317 |
+
# Perfect identify + partial fixes -> high combined reward
|
| 318 |
+
assert obs.metadata["combined_reward"] > 0.7
|
| 319 |
+
|
| 320 |
+
def test_step_with_partial_issues(self, env):
|
| 321 |
+
env.reset(task_id="easy")
|
| 322 |
+
action = DataQAAction(
|
| 323 |
+
issues=["row:4,col:name,issue:missing_value"],
|
| 324 |
+
task_id="easy",
|
| 325 |
+
)
|
| 326 |
+
obs = env.step(action)
|
| 327 |
+
assert 0 < obs.reward < 1.0
|
| 328 |
+
assert obs.done is False
|
| 329 |
+
|
| 330 |
+
def test_step_with_no_issues(self, env):
|
| 331 |
+
env.reset(task_id="easy")
|
| 332 |
+
action = DataQAAction(issues=[], task_id="easy")
|
| 333 |
+
obs = env.step(action)
|
| 334 |
+
assert obs.reward == 0.0
|
| 335 |
+
|
| 336 |
+
def test_step_exhausts_max_steps(self, env):
|
| 337 |
+
env.reset(task_id="easy")
|
| 338 |
+
for _ in range(3):
|
| 339 |
+
action = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
|
| 340 |
+
obs = env.step(action)
|
| 341 |
+
assert obs.done is True
|
| 342 |
+
|
| 343 |
+
def test_auto_reset_on_step(self, env):
|
| 344 |
+
action = DataQAAction(
|
| 345 |
+
issues=["row:4,col:name,issue:missing_value"],
|
| 346 |
+
task_id="easy",
|
| 347 |
+
)
|
| 348 |
+
obs = env.step(action)
|
| 349 |
+
assert obs.task_id == "easy"
|
| 350 |
+
|
| 351 |
+
def test_state_tracking(self, env):
|
| 352 |
+
env.reset(task_id="easy")
|
| 353 |
+
assert env.state.task_id == "easy"
|
| 354 |
+
assert env.state.current_step == 0
|
| 355 |
+
assert env.state.best_score == 0.0
|
| 356 |
+
|
| 357 |
+
action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
|
| 358 |
+
env.step(action)
|
| 359 |
+
assert env.state.current_step == 1
|
| 360 |
+
assert env.state.best_score > 0.0
|
| 361 |
+
|
| 362 |
+
def test_best_score_monotonic(self, env):
|
| 363 |
+
env.reset(task_id="easy")
|
| 364 |
+
action1 = DataQAAction(
|
| 365 |
+
issues=["row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type"],
|
| 366 |
+
task_id="easy",
|
| 367 |
+
)
|
| 368 |
+
env.step(action1)
|
| 369 |
+
score_after_1 = env.state.best_score
|
| 370 |
+
|
| 371 |
+
action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
|
| 372 |
+
env.step(action2)
|
| 373 |
+
assert env.state.best_score >= score_after_1
|
| 374 |
+
|
| 375 |
+
def test_metadata_includes_both_phases(self, env):
|
| 376 |
+
env.reset(task_id="easy")
|
| 377 |
+
action = DataQAAction(
|
| 378 |
+
issues=["row:4,col:name,issue:missing_value"],
|
| 379 |
+
fixes=["row:4,col:name,fix:David Kim"],
|
| 380 |
+
task_id="easy",
|
| 381 |
+
)
|
| 382 |
+
obs = env.step(action)
|
| 383 |
+
m = obs.metadata
|
| 384 |
+
assert "identify_f1" in m
|
| 385 |
+
assert "identify_score" in m
|
| 386 |
+
assert "fix_score" in m
|
| 387 |
+
assert "combined_reward" in m
|
| 388 |
+
assert "tp" in m
|
| 389 |
+
assert "fixes_correct" in m
|
| 390 |
+
assert "fixes_attempted" in m
|
| 391 |
+
|
| 392 |
+
def test_parse_error_in_feedback(self, env):
|
| 393 |
+
env.reset(task_id="easy")
|
| 394 |
+
action = DataQAAction(issues=["garbage input"], task_id="easy")
|
| 395 |
+
obs = env.step(action)
|
| 396 |
+
assert "Parse error" in obs.feedback
|
| 397 |
+
|
| 398 |
+
def test_concurrent_sessions_flag(self):
|
| 399 |
+
assert DataQAEnvironment.SUPPORTS_CONCURRENT_SESSIONS is True
|
| 400 |
+
|
| 401 |
+
def test_reward_between_0_and_1(self, env):
|
| 402 |
+
"""Hackathon requirement: scores must be 0.0-1.0."""
|
| 403 |
+
env.reset(task_id="hard")
|
| 404 |
+
for _ in range(3):
|
| 405 |
+
action = DataQAAction(
|
| 406 |
+
issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
|
| 407 |
+
fixes=["row:1,col:x,fix:wrong"],
|
| 408 |
+
task_id="hard",
|
| 409 |
+
)
|
| 410 |
+
obs = env.step(action)
|
| 411 |
+
assert 0.0 <= obs.reward <= 1.0
|
| 412 |
+
|
| 413 |
+
def test_combined_reward_weights(self, env):
|
| 414 |
+
"""Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
|
| 415 |
+
env.reset(task_id="easy")
|
| 416 |
+
action = DataQAAction(
|
| 417 |
+
issues=["row:4,col:name,issue:missing_value"],
|
| 418 |
+
fixes=["row:4,col:name,fix:David Kim"],
|
| 419 |
+
task_id="easy",
|
| 420 |
+
)
|
| 421 |
+
obs = env.step(action)
|
| 422 |
+
m = obs.metadata
|
| 423 |
+
expected = IDENTIFY_WEIGHT * m["identify_score"] + FIX_WEIGHT * m["fix_score"]
|
| 424 |
+
assert abs(m["combined_reward"] - expected) < 0.01
|
| 425 |
+
|
| 426 |
+
def test_fix_feedback_shown_when_fixes_submitted(self, env):
|
| 427 |
+
env.reset(task_id="easy")
|
| 428 |
+
action = DataQAAction(
|
| 429 |
+
issues=["row:4,col:name,issue:missing_value"],
|
| 430 |
+
fixes=["row:4,col:name,fix:David Kim"],
|
| 431 |
+
task_id="easy",
|
| 432 |
+
)
|
| 433 |
+
obs = env.step(action)
|
| 434 |
+
assert "Fix Proposals" in obs.feedback
|
| 435 |
+
assert "Combined Reward" in obs.feedback
|
| 436 |
+
|
| 437 |
+
def test_no_fix_penalty_when_no_fixes_submitted(self, env):
|
| 438 |
+
"""If agent submits no fixes, reward = identify_score (no penalty)."""
|
| 439 |
+
env.reset(task_id="easy")
|
| 440 |
+
action = DataQAAction(
|
| 441 |
+
issues=[
|
| 442 |
+
"row:4,col:name,issue:missing_value",
|
| 443 |
+
"row:7,col:salary,issue:wrong_type",
|
| 444 |
+
"row:21,col:employee_id,issue:duplicate_row",
|
| 445 |
+
"row:9,col:salary,issue:out_of_range",
|
| 446 |
+
"row:15,col:email,issue:inconsistent_value",
|
| 447 |
+
"row:18,col:start_date,issue:out_of_range",
|
| 448 |
+
],
|
| 449 |
+
task_id="easy",
|
| 450 |
+
)
|
| 451 |
+
obs = env.step(action)
|
| 452 |
+
# identify_score should be ~1.0 since all 6 issues found
|
| 453 |
+
assert obs.reward >= 0.99
|
| 454 |
+
# combined_reward equals identify_score when no fixes
|
| 455 |
+
assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
|
tests/test_extensibility.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the extensibility API — custom tasks and contamination rules."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from dataqa_env.server.tasks import (
|
| 5 |
+
PlantedIssue,
|
| 6 |
+
create_task_from_config,
|
| 7 |
+
register_task,
|
| 8 |
+
register_contamination_rule,
|
| 9 |
+
CONTAMINATION_RULES,
|
| 10 |
+
get_task,
|
| 11 |
+
list_tasks,
|
| 12 |
+
)
|
| 13 |
+
from dataqa_env.server.environment import DataQAEnvironment, compute_weighted_reward
|
| 14 |
+
from dataqa_env.models import DataQAAction
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
SIMPLE_CSV = "id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92\n4,Dave,78"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestCreateTaskFromConfig:
|
| 21 |
+
def test_basic_creation(self):
|
| 22 |
+
task = create_task_from_config(
|
| 23 |
+
task_id="test_custom",
|
| 24 |
+
name="Test Task",
|
| 25 |
+
description="Test",
|
| 26 |
+
schema_description="id: int, name: str, score: int",
|
| 27 |
+
validation_rules="No missing values",
|
| 28 |
+
clean_csv=SIMPLE_CSV,
|
| 29 |
+
contaminations=[
|
| 30 |
+
{"rule": "missing_value", "row": 0, "col": 1},
|
| 31 |
+
],
|
| 32 |
+
)
|
| 33 |
+
assert task.task_id == "test_custom"
|
| 34 |
+
assert len(task.planted_issues) == 1
|
| 35 |
+
assert task.planted_issues[0].issue_type == "missing_value"
|
| 36 |
+
assert task.planted_issues[0].col == "name"
|
| 37 |
+
|
| 38 |
+
def test_multiple_contaminations(self):
|
| 39 |
+
task = create_task_from_config(
|
| 40 |
+
task_id="multi",
|
| 41 |
+
name="Multi",
|
| 42 |
+
description="Test",
|
| 43 |
+
schema_description="",
|
| 44 |
+
validation_rules="",
|
| 45 |
+
clean_csv=SIMPLE_CSV,
|
| 46 |
+
contaminations=[
|
| 47 |
+
{"rule": "missing_value", "row": 0, "col": 1},
|
| 48 |
+
{"rule": "missing_value", "row": 2, "col": 1},
|
| 49 |
+
],
|
| 50 |
+
)
|
| 51 |
+
assert len(task.planted_issues) == 2
|
| 52 |
+
|
| 53 |
+
def test_custom_difficulty_override(self):
|
| 54 |
+
task = create_task_from_config(
|
| 55 |
+
task_id="custom_diff",
|
| 56 |
+
name="Custom Difficulty",
|
| 57 |
+
description="Test",
|
| 58 |
+
schema_description="",
|
| 59 |
+
validation_rules="",
|
| 60 |
+
clean_csv=SIMPLE_CSV,
|
| 61 |
+
contaminations=[
|
| 62 |
+
{"rule": "missing_value", "row": 0, "col": 1, "difficulty": 2.5},
|
| 63 |
+
],
|
| 64 |
+
)
|
| 65 |
+
assert task.planted_issues[0].difficulty == 2.5
|
| 66 |
+
|
| 67 |
+
def test_callable_rule(self):
|
| 68 |
+
def custom_rule(rows, header, col_idx, row_idx, rng):
|
| 69 |
+
return "CORRUPTED", PlantedIssue(
|
| 70 |
+
row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
|
| 71 |
+
description="Custom corruption", difficulty=1.5,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
task = create_task_from_config(
|
| 75 |
+
task_id="callable",
|
| 76 |
+
name="Callable Rule",
|
| 77 |
+
description="Test",
|
| 78 |
+
schema_description="",
|
| 79 |
+
validation_rules="",
|
| 80 |
+
clean_csv=SIMPLE_CSV,
|
| 81 |
+
contaminations=[
|
| 82 |
+
{"rule": custom_rule, "row": 1, "col": 2},
|
| 83 |
+
],
|
| 84 |
+
)
|
| 85 |
+
assert task.planted_issues[0].issue_type == "wrong_type"
|
| 86 |
+
assert "CORRUPTED" in task.corrupted_csv
|
| 87 |
+
|
| 88 |
+
def test_unknown_rule_raises(self):
|
| 89 |
+
with pytest.raises(ValueError, match="Unknown contamination rule"):
|
| 90 |
+
create_task_from_config(
|
| 91 |
+
task_id="bad",
|
| 92 |
+
name="Bad",
|
| 93 |
+
description="",
|
| 94 |
+
schema_description="",
|
| 95 |
+
validation_rules="",
|
| 96 |
+
clean_csv=SIMPLE_CSV,
|
| 97 |
+
contaminations=[{"rule": "nonexistent_rule", "row": 0, "col": 0}],
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class TestRegisterContaminationRule:
|
| 102 |
+
def test_register_and_use(self):
|
| 103 |
+
def reverse_value(rows, header, col_idx, row_idx, rng):
|
| 104 |
+
val = rows[row_idx][col_idx]
|
| 105 |
+
return val[::-1], PlantedIssue(
|
| 106 |
+
row=row_idx + 1, col=header[col_idx], issue_type="format_violation",
|
| 107 |
+
description="Reversed value", difficulty=1.5,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
register_contamination_rule("reverse", reverse_value)
|
| 111 |
+
assert "reverse" in CONTAMINATION_RULES
|
| 112 |
+
|
| 113 |
+
task = create_task_from_config(
|
| 114 |
+
task_id="rev_test",
|
| 115 |
+
name="Reverse Test",
|
| 116 |
+
description="",
|
| 117 |
+
schema_description="",
|
| 118 |
+
validation_rules="",
|
| 119 |
+
clean_csv=SIMPLE_CSV,
|
| 120 |
+
contaminations=[{"rule": "reverse", "row": 0, "col": 1}],
|
| 121 |
+
)
|
| 122 |
+
assert task.planted_issues[0].issue_type == "format_violation"
|
| 123 |
+
# "Alice" reversed is "ecilA"
|
| 124 |
+
assert "ecilA" in task.corrupted_csv
|
| 125 |
+
|
| 126 |
+
# Cleanup
|
| 127 |
+
del CONTAMINATION_RULES["reverse"]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class TestRegisterTask:
|
| 131 |
+
def test_register_and_get(self):
|
| 132 |
+
task = create_task_from_config(
|
| 133 |
+
task_id="registered",
|
| 134 |
+
name="Registered Task",
|
| 135 |
+
description="Test registered task",
|
| 136 |
+
schema_description="id: int, name: str",
|
| 137 |
+
validation_rules="No missing values",
|
| 138 |
+
clean_csv=SIMPLE_CSV,
|
| 139 |
+
contaminations=[{"rule": "missing_value", "row": 1, "col": 1}],
|
| 140 |
+
)
|
| 141 |
+
register_task("registered", lambda seed: task)
|
| 142 |
+
assert "registered" in list_tasks()
|
| 143 |
+
|
| 144 |
+
fetched = get_task("registered")
|
| 145 |
+
assert fetched.task_id == "registered"
|
| 146 |
+
assert len(fetched.planted_issues) == 1
|
| 147 |
+
|
| 148 |
+
# Cleanup
|
| 149 |
+
from dataqa_env.server.tasks import TASK_REGISTRY
|
| 150 |
+
del TASK_REGISTRY["registered"]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TestCustomTaskInEnvironment:
|
| 154 |
+
def test_full_lifecycle_identify_only(self):
|
| 155 |
+
"""Custom task works end-to-end with identify-only."""
|
| 156 |
+
task = create_task_from_config(
|
| 157 |
+
task_id="e2e_custom",
|
| 158 |
+
name="E2E Custom",
|
| 159 |
+
description="End-to-end test",
|
| 160 |
+
schema_description="id: int, name: str, score: int",
|
| 161 |
+
validation_rules="No missing values",
|
| 162 |
+
clean_csv=SIMPLE_CSV,
|
| 163 |
+
contaminations=[
|
| 164 |
+
{"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
|
| 165 |
+
{"rule": "whitespace_value", "row": 2, "col": 1, "difficulty": 2.5},
|
| 166 |
+
],
|
| 167 |
+
)
|
| 168 |
+
register_task("e2e_custom", lambda seed: task)
|
| 169 |
+
|
| 170 |
+
env = DataQAEnvironment()
|
| 171 |
+
obs = env.reset(task_id="e2e_custom")
|
| 172 |
+
assert obs.num_issues_hint == 2
|
| 173 |
+
|
| 174 |
+
action = DataQAAction(
|
| 175 |
+
issues=[i.to_key() for i in task.planted_issues],
|
| 176 |
+
task_id="e2e_custom",
|
| 177 |
+
)
|
| 178 |
+
obs = env.step(action)
|
| 179 |
+
assert obs.done is True
|
| 180 |
+
assert obs.reward >= 0.999
|
| 181 |
+
|
| 182 |
+
from dataqa_env.server.tasks import TASK_REGISTRY
|
| 183 |
+
del TASK_REGISTRY["e2e_custom"]
|
| 184 |
+
|
| 185 |
+
def test_full_lifecycle_identify_and_fix(self):
|
| 186 |
+
"""Custom task works end-to-end with both identify and fix."""
|
| 187 |
+
task = create_task_from_config(
|
| 188 |
+
task_id="e2e_fix",
|
| 189 |
+
name="E2E Fix",
|
| 190 |
+
description="End-to-end test with fixes",
|
| 191 |
+
schema_description="id: int, name: str, score: int",
|
| 192 |
+
validation_rules="No missing values",
|
| 193 |
+
clean_csv=SIMPLE_CSV,
|
| 194 |
+
contaminations=[
|
| 195 |
+
{"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
|
| 196 |
+
],
|
| 197 |
+
)
|
| 198 |
+
register_task("e2e_fix", lambda seed: task)
|
| 199 |
+
|
| 200 |
+
env = DataQAEnvironment()
|
| 201 |
+
env.reset(task_id="e2e_fix")
|
| 202 |
+
|
| 203 |
+
# Submit issues + fix
|
| 204 |
+
action = DataQAAction(
|
| 205 |
+
issues=[task.planted_issues[0].to_key()],
|
| 206 |
+
fixes=["row:1,col:name,fix:Alice"], # clean value is "Alice"
|
| 207 |
+
task_id="e2e_fix",
|
| 208 |
+
)
|
| 209 |
+
obs = env.step(action)
|
| 210 |
+
assert obs.done is True
|
| 211 |
+
assert obs.metadata["fix_score"] > 0.0
|
| 212 |
+
assert obs.metadata["combined_reward"] > 0.0
|
| 213 |
+
|
| 214 |
+
from dataqa_env.server.tasks import TASK_REGISTRY
|
| 215 |
+
del TASK_REGISTRY["e2e_fix"]
|
tests/test_inference.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the inference script's parsing, prompt building, and log format."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 8 |
+
from inference import parse_llm_response, parse_fix_response, build_user_prompt, log_start, log_step, log_end
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TestParseLLMResponse:
|
| 12 |
+
def test_standard_format(self):
|
| 13 |
+
response = "row:1,col:name,issue:missing_value\nrow:2,col:salary,issue:wrong_type"
|
| 14 |
+
issues = parse_llm_response(response)
|
| 15 |
+
assert len(issues) == 2
|
| 16 |
+
assert "row:1,col:name,issue:missing_value" in issues
|
| 17 |
+
|
| 18 |
+
def test_numbered_list(self):
|
| 19 |
+
response = "1. row:1,col:name,issue:missing_value\n2. row:2,col:salary,issue:wrong_type"
|
| 20 |
+
issues = parse_llm_response(response)
|
| 21 |
+
assert len(issues) == 2
|
| 22 |
+
|
| 23 |
+
def test_bullet_list(self):
|
| 24 |
+
response = "- row:1,col:name,issue:missing_value\n* row:2,col:salary,issue:wrong_type"
|
| 25 |
+
issues = parse_llm_response(response)
|
| 26 |
+
assert len(issues) == 2
|
| 27 |
+
|
| 28 |
+
def test_equals_delimiter(self):
|
| 29 |
+
response = "row=1,col=name,issue=missing_value"
|
| 30 |
+
issues = parse_llm_response(response)
|
| 31 |
+
assert len(issues) == 1
|
| 32 |
+
assert issues[0] == "row:1,col:name,issue:missing_value"
|
| 33 |
+
|
| 34 |
+
def test_mixed_case(self):
|
| 35 |
+
response = "Row:1,Col:Name,Issue:Missing_Value"
|
| 36 |
+
issues = parse_llm_response(response)
|
| 37 |
+
assert len(issues) == 1
|
| 38 |
+
assert issues[0] == "row:1,col:name,issue:missing_value"
|
| 39 |
+
|
| 40 |
+
def test_empty_response(self):
|
| 41 |
+
assert parse_llm_response("") == []
|
| 42 |
+
assert parse_llm_response(" ") == []
|
| 43 |
+
|
| 44 |
+
def test_garbage_lines_skipped(self):
|
| 45 |
+
response = "Here are the issues:\nrow:1,col:name,issue:missing_value\nNo more issues."
|
| 46 |
+
issues = parse_llm_response(response)
|
| 47 |
+
assert len(issues) == 1
|
| 48 |
+
|
| 49 |
+
def test_deduplication_not_applied(self):
|
| 50 |
+
response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
|
| 51 |
+
issues = parse_llm_response(response)
|
| 52 |
+
assert len(issues) == 2
|
| 53 |
+
|
| 54 |
+
def test_with_column_variant(self):
|
| 55 |
+
response = "row:1,column:name,issue:missing_value"
|
| 56 |
+
issues = parse_llm_response(response)
|
| 57 |
+
assert len(issues) == 1
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TestParseFixResponse:
|
| 61 |
+
def test_standard_format(self):
|
| 62 |
+
response = "row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000"
|
| 63 |
+
fixes = parse_fix_response(response)
|
| 64 |
+
assert len(fixes) == 2
|
| 65 |
+
assert "row:4,col:name,fix:David Kim" in fixes
|
| 66 |
+
|
| 67 |
+
def test_numbered_list(self):
|
| 68 |
+
response = "1. row:4,col:name,fix:David Kim\n2. row:7,col:salary,fix:75000"
|
| 69 |
+
fixes = parse_fix_response(response)
|
| 70 |
+
assert len(fixes) == 2
|
| 71 |
+
|
| 72 |
+
def test_with_special_chars(self):
|
| 73 |
+
response = "row:1,col:email,fix:alice.chen@company.com"
|
| 74 |
+
fixes = parse_fix_response(response)
|
| 75 |
+
assert len(fixes) == 1
|
| 76 |
+
assert "alice.chen@company.com" in fixes[0]
|
| 77 |
+
|
| 78 |
+
def test_empty_response(self):
|
| 79 |
+
assert parse_fix_response("") == []
|
| 80 |
+
|
| 81 |
+
def test_date_fix(self):
|
| 82 |
+
response = "row:12,col:order_date,fix:2024-01-26"
|
| 83 |
+
fixes = parse_fix_response(response)
|
| 84 |
+
assert len(fixes) == 1
|
| 85 |
+
|
| 86 |
+
def test_ignores_issue_lines(self):
|
| 87 |
+
response = "row:4,col:name,issue:missing_value\nrow:4,col:name,fix:David Kim"
|
| 88 |
+
fixes = parse_fix_response(response)
|
| 89 |
+
assert len(fixes) == 1 # only the fix line
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class TestBuildUserPrompt:
|
| 93 |
+
def test_includes_all_fields(self):
|
| 94 |
+
obs = {
|
| 95 |
+
"task_description": "Find issues",
|
| 96 |
+
"schema_description": "col: int",
|
| 97 |
+
"validation_rules": "no nulls",
|
| 98 |
+
"dataset_csv": "a,b\n1,2",
|
| 99 |
+
"num_issues_hint": 3,
|
| 100 |
+
"feedback": "",
|
| 101 |
+
}
|
| 102 |
+
prompt = build_user_prompt(obs)
|
| 103 |
+
assert "Find issues" in prompt
|
| 104 |
+
assert "col: int" in prompt
|
| 105 |
+
assert "no nulls" in prompt
|
| 106 |
+
assert "a,b" in prompt
|
| 107 |
+
assert "3 issues" in prompt
|
| 108 |
+
|
| 109 |
+
def test_includes_feedback_on_retry(self):
|
| 110 |
+
obs = {
|
| 111 |
+
"task_description": "Find issues",
|
| 112 |
+
"schema_description": "",
|
| 113 |
+
"validation_rules": "",
|
| 114 |
+
"dataset_csv": "a\n1",
|
| 115 |
+
"num_issues_hint": 0,
|
| 116 |
+
"feedback": "Step 1/3: You missed 2 issues",
|
| 117 |
+
}
|
| 118 |
+
prompt = build_user_prompt(obs)
|
| 119 |
+
assert "FEEDBACK" in prompt
|
| 120 |
+
assert "missed 2" in prompt
|
| 121 |
+
|
| 122 |
+
def test_excludes_reset_feedback(self):
|
| 123 |
+
obs = {
|
| 124 |
+
"task_description": "",
|
| 125 |
+
"schema_description": "",
|
| 126 |
+
"validation_rules": "",
|
| 127 |
+
"dataset_csv": "",
|
| 128 |
+
"num_issues_hint": 0,
|
| 129 |
+
"feedback": "Environment reset. Start inspecting.",
|
| 130 |
+
}
|
| 131 |
+
prompt = build_user_prompt(obs)
|
| 132 |
+
assert "FEEDBACK" not in prompt
|
| 133 |
+
|
| 134 |
+
def test_include_fixes_flag(self):
|
| 135 |
+
obs = {
|
| 136 |
+
"task_description": "Find issues",
|
| 137 |
+
"schema_description": "",
|
| 138 |
+
"validation_rules": "",
|
| 139 |
+
"dataset_csv": "a\n1",
|
| 140 |
+
"num_issues_hint": 0,
|
| 141 |
+
"feedback": "",
|
| 142 |
+
}
|
| 143 |
+
prompt = build_user_prompt(obs, include_fixes=True)
|
| 144 |
+
assert "fix" in prompt.lower()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class TestLogFormat:
|
| 148 |
+
"""Verify stdout log format matches hackathon evaluation requirements."""
|
| 149 |
+
|
| 150 |
+
def test_log_start_format(self, capsys):
|
| 151 |
+
log_start(task="easy", env="dataqa_env", model="test-model")
|
| 152 |
+
out = capsys.readouterr().out.strip()
|
| 153 |
+
assert out == "[START] task=easy env=dataqa_env model=test-model"
|
| 154 |
+
|
| 155 |
+
def test_log_step_format(self, capsys):
|
| 156 |
+
log_step(step=1, action="row:1,col:name,issue:missing_value", reward=0.50, done=False, error=None)
|
| 157 |
+
out = capsys.readouterr().out.strip()
|
| 158 |
+
assert out == "[STEP] step=1 action=row:1,col:name,issue:missing_value reward=0.50 done=false error=null"
|
| 159 |
+
|
| 160 |
+
def test_log_step_with_error(self, capsys):
|
| 161 |
+
log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
|
| 162 |
+
out = capsys.readouterr().out.strip()
|
| 163 |
+
assert "error=timeout" in out
|
| 164 |
+
assert "done=true" in out
|
| 165 |
+
|
| 166 |
+
def test_log_end_format(self, capsys):
|
| 167 |
+
log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
|
| 168 |
+
out = capsys.readouterr().out.strip()
|
| 169 |
+
assert out == "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"
|
| 170 |
+
|
| 171 |
+
def test_log_end_failure(self, capsys):
|
| 172 |
+
log_end(success=False, steps=1, score=0.0, rewards=[0.0])
|
| 173 |
+
out = capsys.readouterr().out.strip()
|
| 174 |
+
assert "success=false" in out
|
| 175 |
+
assert "score=0.000" in out
|
| 176 |
+
|
| 177 |
+
def test_reward_format_2_decimal(self, capsys):
|
| 178 |
+
log_step(step=1, action="test", reward=0.123456, done=False, error=None)
|
| 179 |
+
out = capsys.readouterr().out.strip()
|
| 180 |
+
assert "reward=0.12" in out
|
| 181 |
+
|
| 182 |
+
def test_no_newlines_within_line(self, capsys):
|
| 183 |
+
log_start(task="easy", env="dataqa_env", model="model")
|
| 184 |
+
log_step(step=1, action="act", reward=0.0, done=False, error=None)
|
| 185 |
+
log_end(success=False, steps=1, score=0.0, rewards=[0.0])
|
| 186 |
+
out = capsys.readouterr().out
|
| 187 |
+
lines = [l for l in out.split("\n") if l.strip()]
|
| 188 |
+
assert len(lines) == 3
|
| 189 |
+
assert lines[0].startswith("[START]")
|
| 190 |
+
assert lines[1].startswith("[STEP]")
|
| 191 |
+
assert lines[2].startswith("[END]")
|
tests/test_tasks.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for task definitions, data corruption, and issue planting."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from dataqa_env.server.tasks import (
|
| 5 |
+
PlantedIssue,
|
| 6 |
+
Task,
|
| 7 |
+
create_task_easy,
|
| 8 |
+
create_task_medium,
|
| 9 |
+
create_task_hard,
|
| 10 |
+
get_task,
|
| 11 |
+
list_tasks,
|
| 12 |
+
_csv_to_rows,
|
| 13 |
+
_rows_to_csv,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestPlantedIssue:
|
| 18 |
+
def test_to_key(self):
|
| 19 |
+
issue = PlantedIssue(row=3, col="salary", issue_type="missing_value", description="test")
|
| 20 |
+
assert issue.to_key() == "row:3,col:salary,issue:missing_value"
|
| 21 |
+
|
| 22 |
+
def test_difficulty_default(self):
|
| 23 |
+
issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test")
|
| 24 |
+
assert issue.difficulty == 1.0
|
| 25 |
+
|
| 26 |
+
def test_difficulty_custom(self):
|
| 27 |
+
issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test", difficulty=3.0)
|
| 28 |
+
assert issue.difficulty == 3.0
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TestCSVHelpers:
|
| 32 |
+
def test_roundtrip(self):
|
| 33 |
+
csv_text = "a,b,c\n1,2,3\n4,5,6"
|
| 34 |
+
rows = _csv_to_rows(csv_text)
|
| 35 |
+
assert len(rows) == 3
|
| 36 |
+
result = _rows_to_csv(rows)
|
| 37 |
+
assert "1,2,3" in result
|
| 38 |
+
|
| 39 |
+
def test_empty_csv(self):
|
| 40 |
+
rows = _csv_to_rows("a,b\n")
|
| 41 |
+
assert len(rows) == 1 # header only
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TestTaskEasy:
|
| 45 |
+
@pytest.fixture
|
| 46 |
+
def task(self):
|
| 47 |
+
return create_task_easy()
|
| 48 |
+
|
| 49 |
+
def test_task_id(self, task):
|
| 50 |
+
assert task.task_id == "easy"
|
| 51 |
+
|
| 52 |
+
def test_has_6_issues(self, task):
|
| 53 |
+
assert len(task.planted_issues) == 6
|
| 54 |
+
|
| 55 |
+
def test_issue_types(self, task):
|
| 56 |
+
types = {i.issue_type for i in task.planted_issues}
|
| 57 |
+
assert "missing_value" in types
|
| 58 |
+
assert "wrong_type" in types
|
| 59 |
+
assert "duplicate_row" in types
|
| 60 |
+
assert "out_of_range" in types
|
| 61 |
+
assert "inconsistent_value" in types
|
| 62 |
+
|
| 63 |
+
def test_corrupted_csv_differs_from_clean(self, task):
|
| 64 |
+
assert task.corrupted_csv != task.clean_csv
|
| 65 |
+
|
| 66 |
+
def test_issue_keys_unique(self, task):
|
| 67 |
+
keys = [i.to_key() for i in task.planted_issues]
|
| 68 |
+
assert len(keys) == len(set(keys))
|
| 69 |
+
|
| 70 |
+
def test_max_steps(self, task):
|
| 71 |
+
assert task.max_steps == 3
|
| 72 |
+
|
| 73 |
+
def test_corrupted_csv_has_more_rows(self, task):
|
| 74 |
+
clean_rows = _csv_to_rows(task.clean_csv)
|
| 75 |
+
corrupt_rows = _csv_to_rows(task.corrupted_csv)
|
| 76 |
+
assert len(corrupt_rows) > len(clean_rows) # duplicate row added
|
| 77 |
+
|
| 78 |
+
def test_difficulty_weights(self, task):
|
| 79 |
+
for issue in task.planted_issues:
|
| 80 |
+
assert 1.0 <= issue.difficulty <= 3.0
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class TestTaskMedium:
|
| 84 |
+
@pytest.fixture
|
| 85 |
+
def task(self):
|
| 86 |
+
return create_task_medium()
|
| 87 |
+
|
| 88 |
+
def test_task_id(self, task):
|
| 89 |
+
assert task.task_id == "medium"
|
| 90 |
+
|
| 91 |
+
def test_has_8_issues(self, task):
|
| 92 |
+
assert len(task.planted_issues) == 8
|
| 93 |
+
|
| 94 |
+
def test_issue_types(self, task):
|
| 95 |
+
types = {i.issue_type for i in task.planted_issues}
|
| 96 |
+
assert "inconsistent_value" in types
|
| 97 |
+
assert "format_violation" in types
|
| 98 |
+
assert "missing_value" in types
|
| 99 |
+
|
| 100 |
+
def test_issue_keys_unique(self, task):
|
| 101 |
+
keys = [i.to_key() for i in task.planted_issues]
|
| 102 |
+
assert len(keys) == len(set(keys))
|
| 103 |
+
|
| 104 |
+
def test_difficulty_weights(self, task):
|
| 105 |
+
for issue in task.planted_issues:
|
| 106 |
+
assert 1.0 <= issue.difficulty <= 3.0
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TestTaskHard:
|
| 110 |
+
@pytest.fixture
|
| 111 |
+
def task(self):
|
| 112 |
+
return create_task_hard()
|
| 113 |
+
|
| 114 |
+
def test_task_id(self, task):
|
| 115 |
+
assert task.task_id == "hard"
|
| 116 |
+
|
| 117 |
+
def test_has_10_issues(self, task):
|
| 118 |
+
assert len(task.planted_issues) == 10
|
| 119 |
+
|
| 120 |
+
def test_issue_types(self, task):
|
| 121 |
+
types = {i.issue_type for i in task.planted_issues}
|
| 122 |
+
assert "inconsistent_value" in types
|
| 123 |
+
assert "format_violation" in types
|
| 124 |
+
assert "statistical_outlier" in types
|
| 125 |
+
assert "out_of_range" in types
|
| 126 |
+
assert "missing_value" in types
|
| 127 |
+
|
| 128 |
+
def test_has_high_difficulty_issues(self, task):
|
| 129 |
+
hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
|
| 130 |
+
assert len(hard_issues) >= 2 # data leakage, GPU outlier, whitespace
|
| 131 |
+
|
| 132 |
+
def test_issue_keys_unique(self, task):
|
| 133 |
+
keys = [i.to_key() for i in task.planted_issues]
|
| 134 |
+
assert len(keys) == len(set(keys))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class TestTaskAlignment:
|
| 138 |
+
@pytest.fixture
|
| 139 |
+
def task(self):
|
| 140 |
+
return create_task_hard() # reuse import, we'll import alignment below
|
| 141 |
+
|
| 142 |
+
def test_alignment_task(self):
|
| 143 |
+
from dataqa_env.server.tasks import get_task
|
| 144 |
+
task = get_task("alignment")
|
| 145 |
+
assert task.task_id == "alignment"
|
| 146 |
+
assert len(task.planted_issues) == 12
|
| 147 |
+
|
| 148 |
+
def test_alignment_issue_types(self):
|
| 149 |
+
from dataqa_env.server.tasks import get_task
|
| 150 |
+
task = get_task("alignment")
|
| 151 |
+
types = {i.issue_type for i in task.planted_issues}
|
| 152 |
+
assert "inconsistent_value" in types # factual errors, mismatches, hallucinations
|
| 153 |
+
assert "missing_value" in types # truncated, whitespace-only
|
| 154 |
+
assert "duplicate_row" in types # duplicate instruction
|
| 155 |
+
|
| 156 |
+
def test_alignment_has_high_difficulty(self):
|
| 157 |
+
from dataqa_env.server.tasks import get_task
|
| 158 |
+
task = get_task("alignment")
|
| 159 |
+
hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
|
| 160 |
+
assert len(hard_issues) >= 3 # hallucinated citation, harmful advice, factual error
|
| 161 |
+
|
| 162 |
+
def test_alignment_issue_keys_unique(self):
|
| 163 |
+
from dataqa_env.server.tasks import get_task
|
| 164 |
+
task = get_task("alignment")
|
| 165 |
+
keys = [i.to_key() for i in task.planted_issues]
|
| 166 |
+
assert len(keys) == len(set(keys))
|
| 167 |
+
|
| 168 |
+
def test_alignment_corrupted_differs(self):
|
| 169 |
+
from dataqa_env.server.tasks import get_task
|
| 170 |
+
task = get_task("alignment")
|
| 171 |
+
assert task.corrupted_csv != task.clean_csv
|
| 172 |
+
|
| 173 |
+
def test_alignment_in_env(self):
|
| 174 |
+
from dataqa_env.server.environment import DataQAEnvironment
|
| 175 |
+
from dataqa_env.models import DataQAAction
|
| 176 |
+
env = DataQAEnvironment()
|
| 177 |
+
obs = env.reset(task_id="alignment")
|
| 178 |
+
assert obs.num_issues_hint == 12
|
| 179 |
+
# Perfect submission
|
| 180 |
+
from dataqa_env.server.tasks import get_task
|
| 181 |
+
task = get_task("alignment")
|
| 182 |
+
action = DataQAAction(issues=[i.to_key() for i in task.planted_issues], task_id="alignment")
|
| 183 |
+
obs = env.step(action)
|
| 184 |
+
assert obs.reward >= 0.99
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class TestTaskRegistry:
|
| 188 |
+
def test_list_tasks(self):
|
| 189 |
+
tasks = list_tasks()
|
| 190 |
+
assert set(tasks) == {"easy", "medium", "hard", "alignment", "coding", "toolcalling"}
|
| 191 |
+
|
| 192 |
+
def test_get_task_easy(self):
|
| 193 |
+
task = get_task("easy")
|
| 194 |
+
assert task.task_id == "easy"
|
| 195 |
+
|
| 196 |
+
def test_get_task_medium(self):
|
| 197 |
+
task = get_task("medium")
|
| 198 |
+
assert task.task_id == "medium"
|
| 199 |
+
|
| 200 |
+
def test_get_task_hard(self):
|
| 201 |
+
task = get_task("hard")
|
| 202 |
+
assert task.task_id == "hard"
|
| 203 |
+
|
| 204 |
+
def test_get_task_unknown_raises(self):
|
| 205 |
+
with pytest.raises(ValueError, match="Unknown task"):
|
| 206 |
+
get_task("nonexistent")
|
| 207 |
+
|
| 208 |
+
def test_seed_determinism(self):
|
| 209 |
+
t1 = get_task("easy", seed=42)
|
| 210 |
+
t2 = get_task("easy", seed=42)
|
| 211 |
+
assert t1.corrupted_csv == t2.corrupted_csv
|
| 212 |
+
assert [i.to_key() for i in t1.planted_issues] == [i.to_key() for i in t2.planted_issues]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|