Commit d2d30e9
feat: initial implementation of Data Cleaning OpenEnv environment
Complete OpenEnv-compliant data cleaning environment with:
- 3 tasks (easy/medium/hard): fill missing values, fix formats+duplicates, full pipeline
- Synthetic dataset generation with fixed seed (fully reproducible, no external downloads)
- Deterministic programmatic graders with partial progress rewards
- FastAPI server exposing /health /reset /step /state endpoints
- Baseline inference script using OpenAI client
- Dockerfile for containerised deployment
- openenv.yaml manifest, README with full API/task documentation
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- .gitignore +11 -0
- Dockerfile +24 -0
- README.md +194 -0
- inference.py +200 -0
- models.py +42 -0
- openenv.yaml +73 -0
- pyproject.toml +22 -0
- requirements.txt +8 -0
- server/__init__.py +0 -0
- server/app.py +63 -0
- server/data_generator.py +197 -0
- server/environment.py +340 -0
- server/tasks/__init__.py +0 -0
- server/tasks/task1_missing.py +39 -0
- server/tasks/task2_format.py +68 -0
- server/tasks/task3_pipeline.py +104 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
__pycache__/
*.py[cod]
*.egg-info/
dist/
build/
.venv/
venv/
.env
*.env
baseline_scores.json
.DS_Store
Dockerfile
ADDED
@@ -0,0 +1,24 @@
FROM python:3.11-slim

# Non-root user for HuggingFace Spaces compatibility
RUN useradd -m -u 1000 appuser

WORKDIR /app

# Install dependencies first (layer cache)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy project files
COPY . .

# Switch to non-root
RUN chown -R appuser:appuser /app
USER appuser

EXPOSE 8000

HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md
ADDED
@@ -0,0 +1,194 @@
---
title: Data Cleaning Environment
emoji: 🧹
colorFrom: blue
colorTo: green
sdk: docker
pinned: false
app_port: 8000
tags:
  - openenv
  - rl
  - data-cleaning
---

# Data Cleaning OpenEnv

A **real-world data cleaning environment** for AI agent training, built for the Scaler × OpenEnv hackathon.

An agent interacts with a dirty DataFrame through a simple `reset() / step() / state()` API, learning to fix common data quality issues: missing values, duplicate rows, format inconsistencies, outliers, and dtype errors.

---

## Environment Description

Real-world datasets are rarely clean. Data engineers spend a significant fraction of their time:
- Filling missing values with appropriate strategies (median/mean/mode)
- Removing duplicate records
- Standardising inconsistent formats (phone numbers, dates, country names)
- Detecting and removing statistical outliers

This environment turns those tasks into a reinforcement learning challenge with deterministic, programmatic graders and a meaningful partial-progress reward signal.

---

## Action Space

Actions are JSON objects sent to `POST /step`:

| `operation`       | `column` | `params`                                                     | Description                        |
|-------------------|----------|--------------------------------------------------------------|------------------------------------|
| `fill_missing`    | required | `{"strategy": "median\|mean\|mode\|constant", "value": ...}` | Fill NaN values                    |
| `drop_duplicates` | —        | —                                                            | Remove duplicate rows              |
| `fix_format`      | required | —                                                            | Standardise phone/date/country col |
| `replace_value`   | required | `{"old": ..., "new": ...}`                                   | Replace a specific value           |
| `drop_outliers`   | required | —                                                            | Remove IQR outliers in numeric col |
| `fix_dtype`       | required | `{"dtype": "float\|int\|str"}`                               | Cast column to correct dtype       |

**Example:**
```json
{"operation": "fill_missing", "column": "salary", "params": {"strategy": "median"}}
{"operation": "drop_duplicates"}
{"operation": "fix_format", "column": "signup_date"}
```

---

## Observation Space

The `POST /step` and `POST /reset` responses return:

```json
{
  "observation": {
    "done": false,
    "reward": 0.05,
    "data_preview": "name,age,salary,...\n...",
    "data_shape": [100, 5],
    "missing_counts": {"salary": 18, "age": 20},
    "duplicate_count": 0,
    "dtype_issues": {},
    "task_description": "Task 1 (Easy) — Fill Missing Values\n...",
    "message": "Filled 20 missing values in 'age' using median.",
    "step_count": 1,
    "current_score": 0.25
  },
  "reward": 0.05,
  "done": false,
  "info": {}
}
```

---

## Tasks

### Task 1 — Fill Missing Values (Easy)
- **Dataset:** 100-row employee records (name, age, salary, department, experience)
- **Issues:** ~20% NaN in `age`, `salary`, `department`
- **Goal:** Fill all missing values
- **Grader:** `1.0 - remaining_nulls / original_nulls`
- **Max steps:** 20
- **Expected baseline score:** ~0.95

### Task 2 — Fix Formats + Remove Duplicates (Medium)
- **Dataset:** 200-row product catalog (product_id, price, phone, listed_date, …)
- **Issues:** Mixed phone formats, mixed date formats, 15 duplicate rows
- **Goal:** Standardise all formats and remove duplicates
- **Grader:** `0.35 × phone_score + 0.35 × date_score + 0.30 × dupe_score`
- **Max steps:** 30
- **Expected baseline score:** ~0.80

### Task 3 — Full Cleaning Pipeline (Hard)
- **Dataset:** 300-row customer database (name, age, purchase_amount, country, email, signup_date)
- **Issues:** Missing values (4 cols), 20 duplicates, outliers in `purchase_amount`, mixed country case, mixed date formats
- **Goal:** Clean all issues end-to-end
- **Grader:** `0.25 × null + 0.20 × dupe + 0.20 × outlier + 0.175 × country + 0.175 × date`
- **Max steps:** 40
- **Expected baseline score:** ~0.70

---

## Reward Function

| Scenario                   | Reward                         |
|----------------------------|--------------------------------|
| Progress (score improves)  | `new_score - old_score` (≥ 0)  |
| No effect                  | `-0.01`                        |
| Invalid operation          | `-0.05`                        |
| Episode completion (≥0.95) | `delta + 0.20` terminal bonus  |

Rewards are bounded to `[-0.05, 1.2]`. Partial rewards are emitted every step.
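For reference, a minimal sketch of the per-step rule behind this table, mirroring the dispatch logic in `server/environment.py`:

```python
# Per-step reward as implemented in server/environment.py
def step_reward(score_before: float, score_after: float,
                applied: bool, done: bool) -> float:
    if not applied:                    # invalid operation (unknown op, bad column, ...)
        reward = -0.05
    elif score_after <= score_before:  # valid but no grader progress
        reward = -0.01
    else:                              # partial progress, emitted every step
        reward = round(score_after - score_before, 4)
    if done and score_after >= 0.95:   # terminal completion bonus
        reward = round(reward + 0.2, 4)
    return reward
```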
---

## API Endpoints

| Method | Path      | Description                       |
|--------|-----------|-----------------------------------|
| GET    | `/health` | Health check → `{"status":"ok"}`  |
| POST   | `/reset`  | Start episode. Body: `{"task_id": 1\|2\|3}` (optional; default: round-robin) |
| POST   | `/step`   | Execute action. Body: action JSON |
| POST   | `/state`  | Get episode state                 |
| GET    | `/docs`   | Interactive Swagger UI            |
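A minimal Python round-trip against a locally running server (illustrative; any HTTP client works — `httpx` is already in `requirements.txt`):

```python
import httpx

BASE = "http://localhost:8000"

print(httpx.get(f"{BASE}/health").json())            # {'status': 'ok'}

reset = httpx.post(f"{BASE}/reset", json={"task_id": 1}).json()
print(reset["observation"]["data_shape"])            # e.g. [100, 5]

step = httpx.post(f"{BASE}/step", json={
    "operation": "fill_missing",
    "column": "age",
    "params": {"strategy": "median"},
}).json()
print(step["reward"], step["done"], step["observation"]["message"])
```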
---

## Setup & Usage

### Local (Python)
```bash
pip install -r requirements.txt
uvicorn server.app:app --host 0.0.0.0 --port 8000
```

### Docker
```bash
docker build -t data-cleaning-env .
docker run -p 8000:8000 data-cleaning-env
```

### Run Baseline Inference
```bash
export API_BASE_URL="https://api.openai.com/v1"
export MODEL_NAME="gpt-4o-mini"
export HF_TOKEN="your-api-key"
export ENV_URL="http://localhost:8000"

python inference.py
```

---

## Baseline Scores

| Task | Difficulty | Score  |
|------|------------|--------|
| 1    | Easy       | ~0.950 |
| 2    | Medium     | ~0.800 |
| 3    | Hard       | ~0.700 |
| avg  | —          | ~0.817 |

*(Scores produced by `gpt-4o-mini` with greedy decoding, temperature=0)*

---

## Project Structure

```
openenv-data-cleaning/
├── server/
│   ├── environment.py       # Core env: reset/step/state + action dispatcher
│   ├── app.py               # FastAPI HTTP API
│   ├── data_generator.py    # Synthetic dataset generation (fixed seed=42)
│   └── tasks/
│       ├── task1_missing.py    # Task 1: missing values dataset + grader
│       ├── task2_format.py     # Task 2: format + duplicates dataset + grader
│       └── task3_pipeline.py   # Task 3: full pipeline dataset + grader
├── models.py                # Pydantic models (Action, Observation, State)
├── inference.py             # Baseline inference script
├── openenv.yaml             # OpenEnv manifest
├── Dockerfile
├── requirements.txt
└── README.md
```
inference.py
ADDED
@@ -0,0 +1,200 @@
"""
Baseline inference script for the Data Cleaning OpenEnv environment.
Uses the OpenAI client against all 3 tasks and reports scores.

Required environment variables:
    API_BASE_URL — LLM API endpoint (OpenAI-compatible)
    MODEL_NAME   — model identifier
    HF_TOKEN     — API key
    ENV_URL      — environment server URL (default: http://localhost:8000)
"""

import json
import os
import re
import sys
import time

import httpx
from openai import OpenAI

# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------

API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")

if not HF_TOKEN:
    print("[WARNING] HF_TOKEN is not set — LLM calls may fail.", file=sys.stderr)

client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)

SYSTEM_PROMPT = """You are a data cleaning agent. You control a data cleaning environment
through JSON actions. Each turn you receive an observation JSON describing the current state
of a dataset (preview, missing counts, duplicate count, dtype issues, current score, etc.)
and a task description.

Your job is to pick the single best action to improve the dataset quality.

Respond ONLY with a valid JSON object — no markdown, no explanation, just the JSON.

Available operations and their required parameters:

1. fill_missing
   {"operation": "fill_missing", "column": "<col>", "params": {"strategy": "median|mean|mode|constant", "value": <only if constant>}}

2. drop_duplicates
   {"operation": "drop_duplicates"}

3. fix_format
   {"operation": "fix_format", "column": "phone|listed_date|signup_date|country"}

4. replace_value
   {"operation": "replace_value", "column": "<col>", "params": {"old": "<val>", "new": "<val>"}}

5. drop_outliers
   {"operation": "drop_outliers", "column": "<numeric_col>"}

6. fix_dtype
   {"operation": "fix_dtype", "column": "<col>", "params": {"dtype": "float|int|str"}}

Rules:
- Address the highest-impact issues first (missing values > duplicates > outliers > format).
- Do not repeat an operation that returned no effect (watch the 'message' field).
- Stop when current_score >= 0.95.
"""


# ------------------------------------------------------------------
# HTTP helpers
# ------------------------------------------------------------------

def api_post(path: str, payload: dict = None) -> dict:
    url = ENV_URL.rstrip("/") + path
    resp = httpx.post(url, json=payload or {}, timeout=30)
    resp.raise_for_status()
    return resp.json()


def api_get(path: str) -> dict:
    url = ENV_URL.rstrip("/") + path
    resp = httpx.get(url, timeout=10)
    resp.raise_for_status()
    return resp.json()


# ------------------------------------------------------------------
# Agent loop
# ------------------------------------------------------------------

def obs_to_text(obs: dict) -> str:
    lines = [
        f"current_score: {obs['current_score']}",
        f"step_count: {obs['step_count']}",
        f"data_shape: {obs['data_shape']}",
        f"duplicate_count: {obs['duplicate_count']}",
        f"missing_counts: {json.dumps(obs['missing_counts'])}",
        f"dtype_issues: {json.dumps(obs['dtype_issues'])}",
        f"message: {obs['message']}",
        "",
        "=== DATA PREVIEW (first 10 rows) ===",
        obs["data_preview"],
        "",
        "=== TASK DESCRIPTION ===",
        obs["task_description"],
    ]
    return "\n".join(lines)


def run_task(task_id: int) -> float:
    print(f"\n{'='*60}")
    print(f" Running Task {task_id}")
    print(f"{'='*60}")

    result = api_post("/reset", {"task_id": task_id})
    obs = result["observation"]
    history = []

    for step_num in range(1, 50):
        if obs["done"]:
            break

        obs_text = obs_to_text(obs)
        history.append({"role": "user", "content": obs_text})

        response = client.chat.completions.create(
            model = MODEL_NAME,
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history,
            temperature = 0.0,
            max_tokens = 256,
        )
        action_str = response.choices[0].message.content.strip()
        history.append({"role": "assistant", "content": action_str})

        # Parse action; fall back to extracting JSON from a markdown code fence
        try:
            action = json.loads(action_str)
        except json.JSONDecodeError:
            m = re.search(r"\{.*\}", action_str, re.DOTALL)
            if m:
                try:
                    action = json.loads(m.group())
                except Exception:
                    print(f"  Step {step_num}: Could not parse action JSON, aborting task.")
                    break
            else:
                print(f"  Step {step_num}: No JSON found in response, aborting task.")
                break

        print(f"  Step {step_num:2d} | score={obs['current_score']:.4f} | action={json.dumps(action)}")

        result = api_post("/step", action)
        obs = result["observation"]
        print(f"    → {obs['message']}")

        # Slight delay to stay within rate limits on free-tier endpoints
        time.sleep(0.3)

    final_score = obs["current_score"]
    print(f"\n  Task {task_id} final score: {final_score:.4f} (steps used: {obs['step_count']})")
    return final_score


# ------------------------------------------------------------------
# Main
# ------------------------------------------------------------------

def main():
    print("Data Cleaning OpenEnv — Baseline Inference")
    print(f"Model : {MODEL_NAME}")
    print(f"Env   : {ENV_URL}")

    # Smoke-test health endpoint
    health = api_get("/health")
    assert health.get("status") == "ok", f"Health check failed: {health}"
    print("Health check: OK\n")

    scores = {}
    for task_id in [1, 2, 3]:
        scores[f"task{task_id}"] = run_task(task_id)

    print("\n" + "="*60)
    print(" BASELINE RESULTS")
    print("="*60)
    for k, v in scores.items():
        print(f"  {k}: {v:.4f}")
    avg = sum(scores.values()) / len(scores)
    print(f"  average: {avg:.4f}")
    print("="*60)

    # Write scores to file for automated validators
    with open("baseline_scores.json", "w") as f:
        json.dump({"scores": scores, "average": avg}, f, indent=2)
    print("\nScores written to baseline_scores.json")


if __name__ == "__main__":
    main()
models.py
ADDED
@@ -0,0 +1,42 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel


class DataCleaningAction(BaseModel):
    """
    Action to apply to the current dirty DataFrame.

    operation choices:
        fill_missing    – fill NaN values in a column
        drop_duplicates – remove duplicate rows
        fix_format      – standardise string formats (phone, date, text)
        replace_value   – replace a specific value with another
        drop_outliers   – remove rows where column value is a statistical outlier
        fix_dtype       – cast a column to the correct dtype
    """
    operation: str
    column: Optional[str] = None
    params: Dict[str, Any] = {}


class DataCleaningObservation(BaseModel):
    done: bool
    reward: float
    data_preview: str          # First 10 rows as CSV string
    data_shape: List[int]      # [rows, cols]
    missing_counts: Dict[str, int]
    duplicate_count: int
    dtype_issues: Dict[str, str]
    task_description: str
    message: str
    step_count: int
    current_score: float       # Running grader score 0.0–1.0


class DataCleaningState(BaseModel):
    episode_id: str
    task_id: int
    step_count: int
    max_steps: int
    total_errors: int
    errors_remaining: int
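As a quick illustration of how these models behave on the wire — a minimal sketch using the Pydantic v2 API pinned in `requirements.txt`:

```python
from models import DataCleaningAction

# Validate an incoming /step payload into a typed action
action = DataCleaningAction.model_validate({
    "operation": "fill_missing",
    "column": "salary",
    "params": {"strategy": "median"},
})
print(action.operation, action.column)   # fill_missing salary
print(action.model_dump())               # back to a plain dict, e.g. for logging
```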
openenv.yaml
ADDED
@@ -0,0 +1,73 @@
name: data-cleaning-env
version: "0.1.0"
description: >
  A real-world data cleaning environment where an AI agent fixes missing
  values, duplicate rows, format inconsistencies, outliers, and dtype errors
  across three progressively harder tasks.

author: openenv-hackathon
tags:
  - openenv
  - data-cleaning
  - rl
  - real-world

tasks:
  - id: task1
    name: "Fill Missing Values"
    difficulty: easy
    max_steps: 20
    description: >
      Fill all NaN values in an employee records dataset.
      Columns with missing data: age, salary, department.

  - id: task2
    name: "Fix Formats and Remove Duplicates"
    difficulty: medium
    max_steps: 30
    description: >
      Standardise phone numbers (NNN-NNN-NNNN) and dates (YYYY-MM-DD)
      in a product catalog, and remove ~15 duplicate rows.

  - id: task3
    name: "Full Cleaning Pipeline"
    difficulty: hard
    max_steps: 40
    description: >
      End-to-end pipeline on a customer database: fill missing values,
      remove duplicates, drop outliers in purchase_amount, standardise
      country capitalisation, and fix mixed date formats.

api:
  health: GET /health
  reset: POST /reset
  step: POST /step
  state: POST /state
  docs: GET /docs

reward:
  range: [-0.05, 1.2]
  partial: true
  terminal_bonus: 0.2

observation_space:
  type: object
  fields:
    done: boolean
    reward: float
    data_preview: string      # First 10 rows as CSV
    data_shape: list          # [rows, cols]
    missing_counts: object    # {column: count}
    duplicate_count: integer
    dtype_issues: object      # {column: issue_description}
    task_description: string
    message: string
    step_count: integer
    current_score: float      # 0.0–1.0

action_space:
  type: object
  fields:
    operation: string   # fill_missing | drop_duplicates | fix_format | replace_value | drop_outliers | fix_dtype
    column: string      # optional depending on operation
    params: object      # optional operation parameters
pyproject.toml
ADDED
@@ -0,0 +1,22 @@
[project]
name = "data-cleaning-env"
version = "0.1.0"
description = "Real-world data cleaning environment for OpenEnv / Scaler hackathon"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.104.0",
    "uvicorn[standard]>=0.24.0",
    "pydantic>=2.0.0",
    "pandas>=2.0.0",
    "numpy>=1.24.0",
    "faker>=18.0.0",
    "openai>=1.0.0",
    "httpx>=0.25.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["server"]
requirements.txt
ADDED
@@ -0,0 +1,8 @@
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
pydantic>=2.0.0
pandas>=2.0.0
numpy>=1.24.0
faker>=18.0.0
openai>=1.0.0
httpx>=0.25.0
server/__init__.py
ADDED
File without changes
server/app.py
ADDED
@@ -0,0 +1,63 @@
"""
FastAPI application exposing the OpenEnv-compatible HTTP API.
Endpoints: GET /health, POST /reset, POST /step, POST /state, GET /docs
"""

from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from models import DataCleaningAction, DataCleaningObservation, DataCleaningState
from server.environment import DataCleaningEnvironment

app = FastAPI(
    title="Data Cleaning OpenEnv",
    description="A real-world data cleaning environment for AI agent training.",
    version="0.1.0",
)

# Single shared environment instance (stateful server)
env = DataCleaningEnvironment()


class ResetRequest(BaseModel):
    task_id: Optional[int] = None


class StepResponse(BaseModel):
    observation: DataCleaningObservation
    reward: float
    done: bool
    info: dict = {}


# ------------------------------------------------------------------
# Routes
# ------------------------------------------------------------------

@app.get("/health")
def health():
    return {"status": "ok"}


@app.post("/reset", response_model=StepResponse)
def reset(req: ResetRequest = ResetRequest()):
    try:
        obs = env.reset(task_id=req.task_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return StepResponse(observation=obs, reward=0.0, done=False)


@app.post("/step", response_model=StepResponse)
def step(action: DataCleaningAction):
    try:
        obs = env.step(action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return StepResponse(observation=obs, reward=obs.reward, done=obs.done)


@app.post("/state", response_model=DataCleaningState)
def state():
    return env.state()
server/data_generator.py
ADDED
@@ -0,0 +1,197 @@
"""
Synthetic dataset generation with a fixed seed for full reproducibility.
All datasets are generated purely from numpy/random — no external downloads.
"""

import random
import numpy as np
import pandas as pd

SEED = 42


# ---------------------------------------------------------------------------
# Task 1 — Employee records with missing values
# ---------------------------------------------------------------------------

def generate_task1_datasets():
    """Returns (dirty_df, clean_df) for Task 1."""
    rng = np.random.default_rng(SEED)
    random.seed(SEED)

    n = 100
    departments = ["Engineering", "Marketing", "Sales", "HR", "Finance"]
    first_names = ["Alice", "Bob", "Carol", "David", "Eve", "Frank", "Grace",
                   "Heidi", "Ivan", "Judy", "Karl", "Laura", "Mallory", "Niaj",
                   "Oscar", "Peggy", "Quinn", "Romeo", "Sybil", "Trent"]
    last_names = ["Smith", "Jones", "Brown", "Taylor", "Wilson", "Davis",
                  "Miller", "Anderson", "Thomas", "Jackson"]

    names = [f"{random.choice(first_names)} {random.choice(last_names)}" for _ in range(n)]
    ages = rng.integers(22, 60, size=n).astype(float)
    salaries = rng.integers(40_000, 120_000, size=n).astype(float)
    depts = rng.choice(departments, size=n)
    experience = rng.integers(0, 30, size=n).astype(float)

    clean_df = pd.DataFrame({
        "name": names,
        "age": ages,
        "salary": salaries,
        "department": depts,
        "experience": experience,
    })

    dirty_df = clean_df.copy()

    # Inject ~20% NaN into age and salary, ~10% into department
    for col, frac in [("age", 0.20), ("salary", 0.20), ("department", 0.10)]:
        idx = rng.choice(n, size=int(n * frac), replace=False)
        dirty_df.loc[idx, col] = np.nan

    return dirty_df.reset_index(drop=True), clean_df.reset_index(drop=True)


# ---------------------------------------------------------------------------
# Task 2 — Product catalog with format & duplicate issues
# ---------------------------------------------------------------------------

def _scramble_phone(phone: str, rng) -> str:
    digits = phone.replace("-", "")
    fmt = rng.integers(0, 3)
    if fmt == 0:
        return digits                          # 5551234567
    elif fmt == 1:
        return f"({digits[:3]}){digits[3:]}"   # (555)1234567
    else:
        return phone                           # 555-123-4567 (canonical)


def _scramble_date(date_str: str, rng) -> str:
    dt = pd.to_datetime(date_str)
    fmt = rng.integers(0, 3)
    if fmt == 0:
        return dt.strftime("%Y-%m-%d")
    elif fmt == 1:
        return dt.strftime("%b %d %Y")
    else:
        return dt.strftime("%d/%m/%Y")


def generate_task2_datasets():
    """Returns (dirty_df, clean_df) for Task 2."""
    rng = np.random.default_rng(SEED)
    random.seed(SEED)

    n = 200
    categories = ["Electronics", "Clothing", "Food", "Books", "Toys"]

    product_ids = [f"P{str(i).zfill(4)}" for i in range(1, n + 1)]
    product_names = [f"Product_{i}" for i in range(1, n + 1)]
    prices = np.round(rng.uniform(5.0, 500.0, size=n), 2)
    categories_col = rng.choice(categories, size=n)
    phones = [
        f"{rng.integers(100,999)}-{rng.integers(100,999)}-{rng.integers(1000,9999)}"
        for _ in range(n)
    ]
    days_offset = rng.integers(0, 1000, size=n)
    dates = [
        (pd.Timestamp("2020-01-01") + pd.Timedelta(days=int(d))).strftime("%Y-%m-%d")
        for d in days_offset
    ]

    clean_df = pd.DataFrame({
        "product_id": product_ids,
        "product_name": product_names,
        "price": prices,
        "category": categories_col,
        "phone": phones,
        "listed_date": dates,
    })

    dirty_df = clean_df.copy()

    # Scramble ~60% of phone formats
    phone_idx = rng.choice(n, size=int(n * 0.6), replace=False)
    dirty_df.loc[phone_idx, "phone"] = [
        _scramble_phone(dirty_df.loc[i, "phone"], rng) for i in phone_idx
    ]

    # Scramble ~60% of date formats
    date_idx = rng.choice(n, size=int(n * 0.6), replace=False)
    dirty_df.loc[date_idx, "listed_date"] = [
        _scramble_date(dirty_df.loc[i, "listed_date"], rng) for i in date_idx
    ]

    # Add 15 duplicate rows
    dup_idx = rng.choice(n, size=15, replace=False)
    dup_rows = dirty_df.iloc[dup_idx].copy()
    dirty_df = pd.concat([dirty_df, dup_rows], ignore_index=True)

    return dirty_df.reset_index(drop=True), clean_df.reset_index(drop=True)


# ---------------------------------------------------------------------------
# Task 3 — Customer database: full pipeline
# ---------------------------------------------------------------------------

def generate_task3_datasets():
    """Returns (dirty_df, clean_df) for Task 3."""
    rng = np.random.default_rng(SEED)
    random.seed(SEED)

    n = 300
    countries = ["USA", "UK", "Canada", "Australia", "Germany"]
    first_names = ["Alice", "Bob", "Carol", "David", "Eve", "Frank", "Grace",
                   "Heidi", "Ivan", "Judy"]
    last_names = ["Smith", "Jones", "Brown", "Taylor", "Wilson"]

    names = [f"{random.choice(first_names)} {random.choice(last_names)}" for _ in range(n)]
    ages = rng.integers(18, 75, size=n).astype(float)
    purchase_amounts = np.round(rng.uniform(10.0, 500.0, size=n), 2)
    countries_col = rng.choice(countries, size=n)
    emails = [f"user{i}@example.com" for i in range(1, n + 1)]
    days_offset = rng.integers(0, 730, size=n)
    signup_dates = [
        (pd.Timestamp("2022-01-01") + pd.Timedelta(days=int(d))).strftime("%Y-%m-%d")
        for d in days_offset
    ]

    clean_df = pd.DataFrame({
        "name": names,
        "age": ages,
        "purchase_amount": purchase_amounts,
        "country": countries_col,
        "email": emails,
        "signup_date": signup_dates,
    })

    dirty_df = clean_df.copy()

    # Missing values (~15% in age and purchase_amount, ~10% in country and signup_date)
    for col, frac in [("age", 0.15), ("purchase_amount", 0.15),
                      ("country", 0.10), ("signup_date", 0.10)]:
        idx = rng.choice(n, size=int(n * frac), replace=False)
        dirty_df.loc[idx, col] = np.nan

    # Outliers in purchase_amount (~3%)
    out_idx = rng.choice(n, size=int(n * 0.03), replace=False)
    dirty_df.loc[out_idx, "purchase_amount"] = (
        dirty_df.loc[out_idx, "purchase_amount"] * 10
    )

    # Mixed case in country (~40%)
    case_idx = rng.choice(n, size=int(n * 0.40), replace=False)
    dirty_df.loc[case_idx, "country"] = dirty_df.loc[case_idx, "country"].str.lower()

    # Mixed date formats (~50%) — only scramble non-null entries
    date_idx = rng.choice(n, size=int(n * 0.50), replace=False)
    valid_date_idx = [i for i in date_idx if pd.notna(dirty_df.loc[i, "signup_date"])]
    for i in valid_date_idx:
        dirty_df.loc[i, "signup_date"] = _scramble_date(dirty_df.loc[i, "signup_date"], rng)

    # 20 duplicate rows
    dup_idx = rng.choice(n, size=20, replace=False)
    dup_rows = dirty_df.iloc[dup_idx].copy()
    dirty_df = pd.concat([dirty_df, dup_rows], ignore_index=True)

    return dirty_df.reset_index(drop=True), clean_df.reset_index(drop=True)
server/environment.py
ADDED
@@ -0,0 +1,340 @@
"""
Core environment implementing reset / step / state.
Each call to reset() picks a task (round-robin: 1 → 2 → 3 → 1 …)
or a specific task_id can be forced via reset(task_id=N).
"""

import re
import uuid
import numpy as np
import pandas as pd
from typing import Any, Dict, Optional, Tuple

from models import DataCleaningAction, DataCleaningObservation, DataCleaningState
import server.tasks.task1_missing as t1
import server.tasks.task2_format as t2
import server.tasks.task3_pipeline as t3

TASK_MODULES = {1: t1, 2: t2, 3: t3}

PHONE_RE = re.compile(r"^\d{3}-\d{3}-\d{4}$")
DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
VALID_COUNTRIES = {"USA", "UK", "Canada", "Australia", "Germany"}


class DataCleaningEnvironment:

    def __init__(self):
        self._df: Optional[pd.DataFrame] = None
        self._clean_df: Optional[pd.DataFrame] = None
        self._meta: Any = None          # task-specific metadata
        self._task_id: int = 1
        self._episode_id: str = ""
        self._step_count: int = 0
        self._max_steps: int = 20
        self._total_errors: int = 0
        self._last_score: float = 0.0
        self._task_cycle: int = 0       # for round-robin default

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def reset(self, task_id: Optional[int] = None) -> DataCleaningObservation:
        if task_id is None:
            self._task_cycle = (self._task_cycle % 3) + 1
            task_id = self._task_cycle

        if task_id not in TASK_MODULES:
            raise ValueError(f"task_id must be 1, 2, or 3 — got {task_id}")

        mod = TASK_MODULES[task_id]
        self._task_id = task_id
        self._episode_id = str(uuid.uuid4())
        self._step_count = 0
        self._max_steps = mod.MAX_STEPS

        self._df, self._clean_df, self._meta = mod.load()

        self._last_score = self._compute_score()
        self._total_errors = self._count_errors()

        return self._build_obs(0.0, False, "Episode started. Begin cleaning.")

    def step(self, action: DataCleaningAction) -> DataCleaningObservation:
        if self._df is None:
            raise RuntimeError("Call reset() before step().")

        self._step_count += 1
        score_before = self._last_score

        message, applied = self._apply_action(action)

        score_after = self._compute_score()
        self._last_score = score_after

        delta = score_after - score_before
        if not applied:
            reward = -0.05
        elif delta <= 0:
            reward = -0.01
        else:
            reward = round(delta, 4)

        done = (score_after >= 0.95) or (self._step_count >= self._max_steps)
        if done and score_after >= 0.95:
            reward = round(reward + 0.2, 4)

        return self._build_obs(reward, done, message)

    def state(self) -> DataCleaningState:
        if self._df is None:
            return DataCleaningState(
                episode_id="", task_id=0, step_count=0,
                max_steps=0, total_errors=0, errors_remaining=0,
            )
        return DataCleaningState(
            episode_id = self._episode_id,
            task_id = self._task_id,
            step_count = self._step_count,
            max_steps = self._max_steps,
            total_errors = self._total_errors,
            errors_remaining = self._count_errors(),
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _compute_score(self) -> float:
        if self._task_id == 1:
            return t1.score(self._df, self._meta)
        elif self._task_id == 2:
            return t2.score(self._df, self._meta)
        else:
            return t3.score(self._df, self._meta)

    def _count_errors(self) -> int:
        if self._task_id == 1:
            return t1.count_errors(self._df)
        elif self._task_id == 2:
            return t2.count_errors(self._df, self._meta)
        else:
            return t3.count_errors(self._df, self._meta)

    def _build_obs(self, reward: float, done: bool, message: str) -> DataCleaningObservation:
        mod = TASK_MODULES[self._task_id]
        missing = {col: int(n) for col, n in self._df.isnull().sum().items() if n > 0}
        dupes = len(self._df) - len(self._df.drop_duplicates())
        dtype_issues = self._detect_dtype_issues()
        preview = self._df.head(10).to_csv(index=False)

        return DataCleaningObservation(
            done = done,
            reward = reward,
            data_preview = preview,
            data_shape = list(self._df.shape),
            missing_counts = missing,
            duplicate_count = dupes,
            dtype_issues = dtype_issues,
            task_description = mod.DESCRIPTION,
            message = message,
            step_count = self._step_count,
            current_score = self._last_score,
        )

    def _detect_dtype_issues(self) -> Dict[str, str]:
        issues: Dict[str, str] = {}
        for col in self._df.columns:
            series = self._df[col].dropna()
            if series.empty:
                continue
            if self._df[col].dtype == object:
                numeric_count = pd.to_numeric(series, errors="coerce").notna().sum()
                if numeric_count / len(series) > 0.8:
                    issues[col] = "stored as string but appears numeric"
        return issues

    # ------------------------------------------------------------------
    # Action dispatcher
    # ------------------------------------------------------------------

    def _apply_action(self, action: DataCleaningAction) -> Tuple[str, bool]:
        op = action.operation.strip().lower()
        col = action.column
        p = action.params or {}

        try:
            if op == "fill_missing":
                return self._fill_missing(col, p)
            elif op == "drop_duplicates":
                return self._drop_duplicates()
            elif op == "fix_format":
                return self._fix_format(col)
            elif op == "replace_value":
                return self._replace_value(col, p)
            elif op == "drop_outliers":
                return self._drop_outliers(col)
            elif op == "fix_dtype":
                return self._fix_dtype(col, p)
            else:
                return f"Unknown operation '{op}'. Choose from: fill_missing, drop_duplicates, fix_format, replace_value, drop_outliers, fix_dtype.", False
        except Exception as exc:
            return f"Operation failed: {exc}", False

    def _fill_missing(self, col, p) -> Tuple[str, bool]:
        if col is None or col not in self._df.columns:
            return f"Column '{col}' not found.", False
        n_before = int(self._df[col].isnull().sum())
        if n_before == 0:
            return f"No missing values in '{col}'.", False

        strategy = str(p.get("strategy", "median")).lower()
        if strategy == "median":
            fill_val = self._df[col].median(skipna=True)
        elif strategy == "mean":
            fill_val = self._df[col].mean(skipna=True)
        elif strategy == "mode":
            mode = self._df[col].mode(dropna=True)
            fill_val = mode.iloc[0] if not mode.empty else None
        elif strategy == "constant":
            fill_val = p.get("value")
        else:
            return f"Unknown strategy '{strategy}'.", False

        if fill_val is None:
            return "Could not determine fill value.", False

        self._df[col] = self._df[col].fillna(fill_val)
        n_after = int(self._df[col].isnull().sum())
        return f"Filled {n_before - n_after} missing values in '{col}' using {strategy}.", True

    def _drop_duplicates(self) -> Tuple[str, bool]:
        n_before = len(self._df)
        self._df = self._df.drop_duplicates().reset_index(drop=True)
        n_after = len(self._df)
        removed = n_before - n_after
        if removed == 0:
            return "No duplicate rows found.", False
        return f"Dropped {removed} duplicate rows.", True

    def _fix_format(self, col) -> Tuple[str, bool]:
        if col is None or col not in self._df.columns:
            return f"Column '{col}' not found.", False

        if col == "phone":
            return self._fix_phone(col)
        elif col in ("listed_date", "signup_date"):
            return self._fix_date(col)
        elif col == "country":
            return self._fix_country(col)
        else:
            return f"No format rule defined for column '{col}'.", False

    def _fix_phone(self, col) -> Tuple[str, bool]:
        def normalise(val):
            if pd.isna(val):
                return val
            digits = re.sub(r"\D", "", str(val))
            if len(digits) == 10:
                return f"{digits[:3]}-{digits[3:6]}-{digits[6:]}"
            return val

        before = (~self._df[col].str.match(PHONE_RE, na=False)).sum()
        self._df[col] = self._df[col].apply(normalise)
        after = (~self._df[col].str.match(PHONE_RE, na=False)).sum()
        fixed = int(before - after)
        if fixed == 0:
            return f"No phone format issues found in '{col}'.", False
        return f"Fixed {fixed} phone numbers in '{col}' to NNN-NNN-NNNN format.", True

    def _fix_date(self, col) -> Tuple[str, bool]:
        def normalise(val):
            if pd.isna(val):
                return val
            try:
                return pd.to_datetime(str(val), dayfirst=False).strftime("%Y-%m-%d")
            except Exception:
                try:
                    return pd.to_datetime(str(val), dayfirst=True).strftime("%Y-%m-%d")
                except Exception:
                    return val

        before = (~self._df[col].apply(
            lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
        )).sum()
        self._df[col] = self._df[col].apply(normalise)
        after = (~self._df[col].apply(
            lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
        )).sum()
        fixed = int(before - after)
        if fixed == 0:
            return f"No date format issues found in '{col}'.", False
        return f"Fixed {fixed} dates in '{col}' to YYYY-MM-DD format.", True

    def _fix_country(self, col) -> Tuple[str, bool]:
        def normalise(val):
            if pd.isna(val):
                return val
            mapping = {
                "usa": "USA", "uk": "UK", "canada": "Canada",
                "australia": "Australia", "germany": "Germany",
            }
            return mapping.get(str(val).strip().lower(), val)

        before = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum()
        self._df[col] = self._df[col].apply(normalise)
        after = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum()
        fixed = int(before - after)
        if fixed == 0:
            return "No country capitalisation issues found.", False
        return f"Fixed {fixed} country values to correct capitalisation.", True

    def _replace_value(self, col, p) -> Tuple[str, bool]:
        if col is None or col not in self._df.columns:
            return f"Column '{col}' not found.", False
        old = p.get("old")
        new = p.get("new")
        if old is None:
            return "params.old is required for replace_value.", False
        count = int((self._df[col] == old).sum())
        if count == 0:
            return f"Value '{old}' not found in '{col}'.", False
        self._df[col] = self._df[col].replace(old, new)
        return f"Replaced {count} occurrences of '{old}' with '{new}' in '{col}'.", True

    def _drop_outliers(self, col) -> Tuple[str, bool]:
        if col is None or col not in self._df.columns:
            return f"Column '{col}' not found.", False
        if not pd.api.types.is_numeric_dtype(self._df[col]):
            return f"'{col}' is not numeric.", False
        q1 = self._df[col].quantile(0.25)
        q3 = self._df[col].quantile(0.75)
        iqr = q3 - q1
        mask = (self._df[col] >= q1 - 3 * iqr) & (self._df[col] <= q3 + 3 * iqr)
        n_before = len(self._df)
        self._df = self._df[mask | self._df[col].isna()].reset_index(drop=True)
        removed = n_before - len(self._df)
        if removed == 0:
            return f"No outliers found in '{col}'.", False
        return f"Removed {removed} outlier rows from '{col}' using IQR method.", True

    def _fix_dtype(self, col, p) -> Tuple[str, bool]:
        if col is None or col not in self._df.columns:
            return f"Column '{col}' not found.", False
        dtype = str(p.get("dtype", "float")).lower()
        try:
            if dtype == "float":
                self._df[col] = pd.to_numeric(self._df[col], errors="coerce").astype(float)
            elif dtype == "int":
                # NaN cannot live in a plain int column; coerce to numeric and
                # leave the result NaN-safe rather than raising.
                self._df[col] = pd.to_numeric(self._df[col], errors="coerce")
            elif dtype == "str":
                self._df[col] = self._df[col].astype(str)
            else:
                return f"Unknown dtype '{dtype}'.", False
            return f"Converted '{col}' to {dtype}.", True
        except Exception as exc:
            return f"dtype conversion failed: {exc}", False
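A minimal in-process episode sketch against this class (illustrative; no HTTP server needed, action values taken from the dispatcher above):

```python
from models import DataCleaningAction
from server.environment import DataCleaningEnvironment

env = DataCleaningEnvironment()
obs = env.reset(task_id=1)                 # Task 1: employee records

# One cleaning step: median-fill the 'age' column
obs = env.step(DataCleaningAction(
    operation="fill_missing",
    column="age",
    params={"strategy": "median"},
))
print(obs.message, obs.current_score, obs.done)
```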
server/tasks/__init__.py
ADDED
File without changes
server/tasks/task1_missing.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Task 1 — Easy: Fill Missing Values
Objective: Fill all NaN values in the employee records DataFrame.
Score: 1.0 - (remaining_nulls / original_nulls)
"""

from server.data_generator import generate_task1_datasets

TASK_ID = 1
MAX_STEPS = 20
DESCRIPTION = (
    "Task 1 (Easy) — Fill Missing Values\n"
    "You have an employee records dataset with missing values (NaN) in "
    "'age', 'salary', and 'department' columns. "
    "Your goal is to fill all missing values so the dataset is complete.\n\n"
    "Available operation: fill_missing\n"
    "  params.strategy: 'median' | 'mean' | 'mode' | 'constant'\n"
    "  params.value: (required when strategy='constant') the fill value\n"
    "Example action: {\"operation\": \"fill_missing\", \"column\": \"age\", \"params\": {\"strategy\": \"median\"}}"
)


def load():
    """Return (dirty_df, clean_df, original_null_count)."""
    dirty, clean = generate_task1_datasets()
    original_nulls = int(dirty.isnull().sum().sum())
    return dirty.copy(), clean, original_nulls


def score(current_df, original_nulls: int) -> float:
    """Score in [0, 1]: fraction of nulls filled."""
    if original_nulls == 0:
        return 1.0
    remaining = int(current_df.isnull().sum().sum())
    return round(max(0.0, 1.0 - remaining / original_nulls), 4)


def count_errors(current_df) -> int:
    return int(current_df.isnull().sum().sum())
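For orientation, a minimal usage sketch of this grader's contract: load() yields the dirty frame plus the null baseline, and score() pays out partial progress. The toy DataFrame below is hypothetical, not part of the repository's generated data:

import pandas as pd
from server.tasks import task1_missing

# Hypothetical toy frame: 2 of 6 cells start as NaN.
df = pd.DataFrame({"age": [30, None, 41], "salary": [50_000.0, 62_000.0, None]})
orig = int(df.isnull().sum().sum())        # 2 original nulls

df["age"] = df["age"].fillna(df["age"].median())
print(task1_missing.score(df, orig))       # 0.5 — one of two nulls filled
print(task1_missing.count_errors(df))      # 1 remaining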
server/tasks/task2_format.py
ADDED
@@ -0,0 +1,68 @@
"""
Task 2 — Medium: Fix Formats + Remove Duplicates
Objective: Standardise phone & date formats and drop duplicate rows.
Score: weighted average — format score 0.70 (split evenly between phone
and date) + duplicate score 0.30.
"""

import re
import pandas as pd
from server.data_generator import generate_task2_datasets

TASK_ID = 2
MAX_STEPS = 30
DESCRIPTION = (
    "Task 2 (Medium) — Fix Formats and Remove Duplicates\n"
    "You have a product catalog with:\n"
    "  • Phone numbers in mixed formats (need: NNN-NNN-NNNN)\n"
    "  • Dates in mixed formats (need: YYYY-MM-DD)\n"
    "  • Duplicate rows (~15)\n\n"
    "Available operations:\n"
    "  fix_format — column: 'phone' | 'listed_date'\n"
    "  drop_duplicates — no column needed\n\n"
    "Example actions:\n"
    '  {"operation": "fix_format", "column": "phone"}\n'
    '  {"operation": "fix_format", "column": "listed_date"}\n'
    '  {"operation": "drop_duplicates"}'
)

PHONE_RE = re.compile(r"^\d{3}-\d{3}-\d{4}$")
DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")


def load():
    dirty, clean = generate_task2_datasets()
    # na=False counts missing phones as issues instead of crashing the grader.
    original_phone_issues = int((~dirty["phone"].str.match(PHONE_RE, na=False)).sum())
    original_date_issues = int((~dirty["listed_date"].apply(
        lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
    )).sum())
    original_dupes = len(dirty) - len(dirty.drop_duplicates())
    meta = {
        "orig_phone": original_phone_issues,
        "orig_date": original_date_issues,
        "orig_dupes": original_dupes,
    }
    return dirty.copy(), clean, meta


def score(current_df, meta: dict) -> float:
    phone_issues = int((~current_df["phone"].str.match(PHONE_RE, na=False)).sum())
    date_issues = int((~current_df["listed_date"].apply(
        lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
    )).sum())
    dupes = len(current_df) - len(current_df.drop_duplicates())

    phone_score = 1.0 - phone_issues / max(meta["orig_phone"], 1)
    date_score = 1.0 - date_issues / max(meta["orig_date"], 1)
    dupe_score = 1.0 - dupes / max(meta["orig_dupes"], 1)

    combined = 0.35 * phone_score + 0.35 * date_score + 0.30 * dupe_score
    return round(max(0.0, min(1.0, combined)), 4)


def count_errors(current_df, meta: dict) -> int:
    phone_issues = int((~current_df["phone"].str.match(PHONE_RE, na=False)).sum())
    date_issues = int((~current_df["listed_date"].apply(
        lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
    )).sum())
    dupes = len(current_df) - len(current_df.drop_duplicates())
    return phone_issues + date_issues + dupes
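The graders above only check the target patterns; the normalisation itself is the fix_format operation implemented in server/environment.py (not reproduced in full here). An illustrative sketch of the kind of transformation those regexes demand — helper names here are hypothetical, not the repository's actual functions:

import re
import pandas as pd

def normalise_phone(raw) -> str:
    """Collapse '(555) 123-4567', '555.123.4567', etc. to NNN-NNN-NNNN."""
    digits = re.sub(r"\D", "", str(raw))
    if len(digits) == 11 and digits.startswith("1"):
        digits = digits[1:]          # strip a leading US country code
    if len(digits) != 10:
        return raw                   # leave unparseable values untouched
    return f"{digits[:3]}-{digits[3:6]}-{digits[6:]}"

def normalise_date(raw) -> str:
    """Parse mixed date formats and re-emit YYYY-MM-DD; NaT-safe."""
    ts = pd.to_datetime(raw, errors="coerce")
    return ts.strftime("%Y-%m-%d") if pd.notna(ts) else raw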
server/tasks/task3_pipeline.py
ADDED
@@ -0,0 +1,104 @@
"""
Task 3 — Hard: Full Cleaning Pipeline
Objective: Fix missing values, remove duplicates, handle outliers, standardise
country capitalisation and date formats.
Score: weighted average of 5 sub-scores (nulls 0.25, duplicates 0.20,
outliers 0.20, country 0.175, dates 0.175).
"""

import re
import pandas as pd
from server.data_generator import generate_task3_datasets

TASK_ID = 3
MAX_STEPS = 40
DESCRIPTION = (
    "Task 3 (Hard) — Full Cleaning Pipeline\n"
    "You have a customer database with multiple issues:\n"
    "  1. Missing values in 'age', 'purchase_amount', 'country', 'signup_date'\n"
    "  2. ~20 duplicate rows\n"
    "  3. Outliers in 'purchase_amount' (injected values ~10x normal)\n"
    "  4. Mixed case in 'country' (need canonical names, e.g. 'usa' → 'USA')\n"
    "  5. Mixed date formats in 'signup_date' (need: YYYY-MM-DD)\n\n"
    "Available operations:\n"
    "  fill_missing — column + params.strategy ('median'|'mean'|'mode'|'constant')\n"
    "  drop_duplicates — no column needed\n"
    "  drop_outliers — column (numeric); uses IQR method\n"
    "  fix_format — column: 'country' | 'signup_date'\n"
    "  fix_dtype — column + params.dtype ('float'|'int'|'str')\n\n"
    "Example actions:\n"
    '  {"operation": "fill_missing", "column": "age", "params": {"strategy": "median"}}\n'
    '  {"operation": "drop_duplicates"}\n'
    '  {"operation": "drop_outliers", "column": "purchase_amount"}\n'
    '  {"operation": "fix_format", "column": "signup_date"}\n'
    '  {"operation": "fix_format", "column": "country"}'
)

DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
VALID_COUNTRIES = {"USA", "UK", "Canada", "Australia", "Germany"}


def load():
    dirty, clean = generate_task3_datasets()
    orig_nulls = int(dirty.isnull().sum().sum())
    orig_dupes = len(dirty) - len(dirty.drop_duplicates())

    # Outlier baseline: count rows where purchase_amount > Q3 + 3*IQR
    pa = dirty["purchase_amount"].dropna()
    q1, q3 = pa.quantile(0.25), pa.quantile(0.75)
    iqr = q3 - q1
    orig_outliers = int((pa > q3 + 3 * iqr).sum())

    orig_country_issues = int((~dirty["country"].isin(VALID_COUNTRIES) &
                               dirty["country"].notna()).sum())
    orig_date_issues = int((~dirty["signup_date"].apply(
        lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
    )).sum())

    meta = {
        "orig_nulls": orig_nulls,
        "orig_dupes": orig_dupes,
        # Clamp to >= 1 so the divisions in score() never hit zero.
        "orig_outliers": max(orig_outliers, 1),
        "orig_country_issues": max(orig_country_issues, 1),
        "orig_date_issues": max(orig_date_issues, 1),
        "q1": q1, "q3": q3, "iqr": iqr,
    }
    return dirty.copy(), clean, meta


def score(current_df, meta: dict) -> float:
    remaining_nulls = int(current_df.isnull().sum().sum())
    remaining_dupes = len(current_df) - len(current_df.drop_duplicates())

    pa = current_df["purchase_amount"].dropna()
    remaining_outliers = int((pa > meta["q3"] + 3 * meta["iqr"]).sum())

    remaining_country = int((~current_df["country"].isin(VALID_COUNTRIES) &
                             current_df["country"].notna()).sum())
    remaining_dates = int((~current_df["signup_date"].apply(
        lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
    )).sum())

    null_score = 1.0 - remaining_nulls / max(meta["orig_nulls"], 1)
    dupe_score = 1.0 - remaining_dupes / max(meta["orig_dupes"], 1)
    outlier_score = 1.0 - remaining_outliers / meta["orig_outliers"]
    country_score = 1.0 - remaining_country / meta["orig_country_issues"]
    date_score = 1.0 - remaining_dates / meta["orig_date_issues"]

    combined = (0.25 * null_score + 0.20 * dupe_score + 0.20 * outlier_score
                + 0.175 * country_score + 0.175 * date_score)
    return round(max(0.0, min(1.0, combined)), 4)


def count_errors(current_df, meta: dict) -> int:
    remaining_nulls = int(current_df.isnull().sum().sum())
    remaining_dupes = len(current_df) - len(current_df.drop_duplicates())
    pa = current_df["purchase_amount"].dropna()
    remaining_outliers = int((pa > meta["q3"] + 3 * meta["iqr"]).sum())
    remaining_country = int((~current_df["country"].isin(VALID_COUNTRIES) &
                             current_df["country"].notna()).sum())
    remaining_dates = int((~current_df["signup_date"].apply(
        lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
    )).sum())
    return (remaining_nulls + remaining_dupes + remaining_outliers
            + remaining_country + remaining_dates)
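The five weights sum to 1.0 (0.25 + 0.20 + 0.20 + 0.175 + 0.175), so a fully cleaned frame scores exactly 1.0 and each sub-task contributes proportionally. A quick hypothetical mid-episode check of the arithmetic:

# Hypothetical state: all nulls filled and duplicates dropped, but
# outliers, country names and dates still untouched.
null_score = dupe_score = 1.0
outlier_score = country_score = date_score = 0.0
combined = (0.25 * null_score + 0.20 * dupe_score + 0.20 * outlier_score
            + 0.175 * country_score + 0.175 * date_score)
print(combined)  # 0.45 — partial progress is rewarded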