CodeKnightDebjit committed
Commit d627dc7 · verified · 1 parent: 3feb717

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Multi-stage build using openenv-base
+ # This Dockerfile is flexible and works for both:
+ # - In-repo environments (with local OpenEnv sources)
+ # - Standalone environments (with openenv from PyPI/Git)
+ # The build script (openenv build) handles context detection and sets appropriate build args.
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Ensure git is available (required for installing dependencies from VCS)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Build argument to control whether we're building standalone or in-repo
+ ARG BUILD_MODE=in-repo
+ ARG ENV_NAME=data_cleaning_env
+
+ # Copy environment code (always at root of build context)
+ COPY . /app/env
+
+ # For in-repo builds, openenv is already vendored in the build context
+ # For standalone builds, openenv will be installed via pyproject.toml
+ WORKDIR /app/env
+
+ # Ensure uv is available (for local builds where the base image lacks it)
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies using uv sync
+ # If uv.lock exists, use it; otherwise resolve on the fly
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy the virtual environment from builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy the environment code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH to use the virtual environment
+ ENV PATH="/app/.venv/bin:$PATH"
+
+ # Set PYTHONPATH so imports work correctly
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ # Run the FastAPI server
+ # The module path is constructed to work with the /app/env structure
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,255 @@
  ---
- title: Data Cleaning Env
- emoji: 📚
- colorFrom: gray
- colorTo: gray
+ title: Data Cleaning Env Environment Server
+ emoji: 🎹
+ colorFrom: indigo
+ colorTo: red
  sdk: docker
  pinned: false
+ app_port: 8000
+ base_path: /web
+ tags:
+ - openenv
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Data Cleaning Env Environment
+
+ A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
+
+ ## Quick Start
+
+ The simplest way to use the Data Cleaning Env environment is through the `DataCleaningEnv` class:
+
+ ```python
+ from data_cleaning_env import CleanAction, DataCleaningEnv
+
+ # Create environment from Docker image
+ env = DataCleaningEnv.from_docker_image("data_cleaning_env-env:latest")
+
+ try:
+     # Reset
+     result = env.reset()
+     print(f"Reset: {result.observation.echoed_message}")
+
+     # Send multiple messages
+     messages = ["Hello, World!", "Testing echo", "Final message"]
+
+     for msg in messages:
+         result = env.step(CleanAction(message=msg))
+         print(f"Sent: '{msg}'")
+         print(f"  → Echoed: '{result.observation.echoed_message}'")
+         print(f"  → Length: {result.observation.message_length}")
+         print(f"  → Reward: {result.reward}")
+
+ finally:
+     # Always clean up
+     env.close()
+ ```
+
+ That's it! The `DataCleaningEnv.from_docker_image()` method handles:
+ - Starting the Docker container
+ - Waiting for the server to be ready
+ - Connecting to the environment
+ - Container cleanup when you call `close()`
+
+ ## Building the Docker Image
+
+ Before using the environment, you need to build the Docker image:
+
+ ```bash
+ # From project root
+ docker build -t data_cleaning_env-env:latest -f server/Dockerfile .
+ ```
+
+ ## Deploying to Hugging Face Spaces
+
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
+
+ ```bash
+ # From the environment directory (where openenv.yaml is located)
+ openenv push
+
+ # Or specify options
+ openenv push --namespace my-org --private
+ ```
+
+ The `openenv push` command will:
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
+ 2. Prepare a custom build for a Hugging Face Docker Space (enables the web interface)
+ 3. Upload to Hugging Face (ensuring you're logged in)
+
+ ### Prerequisites
+
+ - Authenticate with Hugging Face: the command will prompt for login if you are not already authenticated
+
+ ### Options
+
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to the current directory)
+ - `--repo-id`, `-r`: Repository ID in the format `username/repo-name` (defaults to `username/env-name` from openenv.yaml)
+ - `--base-image`, `-b`: Base Docker image to use (overrides the Dockerfile `FROM`)
+ - `--private`: Deploy the Space as private (default: public)
+
+ ### Examples
+
+ ```bash
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
+ openenv push
+
+ # Push to a specific repository
+ openenv push --repo-id my-org/my-env
+
+ # Push with a custom base image
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
+
+ # Push as a private space
+ openenv push --private
+
+ # Combine options
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
+ ```
+
+ After deployment, your Space will be available at:
+ `https://huggingface.co/spaces/<repo-id>`
+
+ The deployed Space includes:
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
+ - **Health Check** at `/health` - Container health monitoring
+ - **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
+
+ ## Environment Details
+
+ ### Action
+ **CleanAction**: Contains a single field
+ - `message` (str) - The message to echo back
+
+ ### Observation
+ **CleanObservation**: Contains the echo response and metadata
+ - `echoed_message` (str) - The message echoed back
+ - `message_length` (int) - Length of the message
+ - `reward` (float) - Reward based on message length (length × 0.1)
+ - `done` (bool) - Always False for the echo environment
+ - `metadata` (dict) - Additional info such as step count
+
+ ### Reward
+ The reward is calculated as `message_length × 0.1`:
+ - "Hi" → reward: 0.2
+ - "Hello, World!" → reward: 1.3
+ - Empty message → reward: 0.0
+
+ ## Advanced Usage
+
+ ### Connecting to an Existing Server
+
+ If you already have a Data Cleaning Env environment server running, you can connect directly:
+
+ ```python
+ from data_cleaning_env import CleanAction, DataCleaningEnv
+
+ # Connect to an existing server
+ env = DataCleaningEnv(base_url="<ENV_HTTP_URL_HERE>")
+
+ # Use as normal
+ result = env.reset()
+ result = env.step(CleanAction(message="Hello!"))
+ ```
+
+ Note: when connecting to an existing server, `env.close()` will NOT stop the server.
+
+ ### Using the Context Manager
+
+ The client supports context-manager usage for automatic connection management:
+
+ ```python
+ from data_cleaning_env import CleanAction, DataCleaningEnv
+
+ # Connect with a context manager (auto-connects and closes)
+ with DataCleaningEnv(base_url="http://localhost:8000") as env:
+     result = env.reset()
+     print(f"Reset: {result.observation.echoed_message}")
+     # Multiple steps with low latency
+     for msg in ["Hello", "World", "!"]:
+         result = env.step(CleanAction(message=msg))
+         print(f"Echoed: {result.observation.echoed_message}")
+ ```
+
+ The client uses WebSocket connections for:
+ - **Lower latency**: No per-request HTTP connection overhead
+ - **Persistent session**: The server maintains your environment state
+ - **Efficient episodes**: Better for many sequential steps
+
+ ### Concurrent WebSocket Sessions
+
+ The server supports multiple concurrent WebSocket connections. To enable this,
+ modify `server/app.py` to use factory mode:
+
+ ```python
+ # In server/app.py - use factory mode for concurrent sessions
+ app = create_app(
+     DataCleaningEnvironment,  # Pass the class, not an instance
+     CleanAction,
+     CleanObservation,
+     max_concurrent_envs=4,  # Allow 4 concurrent sessions
+ )
+ ```
+
+ Then multiple clients can connect simultaneously:
+
+ ```python
+ from concurrent.futures import ThreadPoolExecutor
+
+ from data_cleaning_env import CleanAction, DataCleaningEnv
+
+ def run_episode(client_id: int):
+     with DataCleaningEnv(base_url="http://localhost:8000") as env:
+         result = env.reset()
+         for i in range(10):
+             result = env.step(CleanAction(message=f"Client {client_id}, step {i}"))
+         return client_id, result.observation.message_length
+
+ # Run 4 episodes concurrently
+ with ThreadPoolExecutor(max_workers=4) as executor:
+     results = list(executor.map(run_episode, range(4)))
+ ```
+
+ ## Development & Testing
+
+ ### Direct Environment Testing
+
+ Test the environment logic directly without starting the HTTP server:
+
+ ```bash
+ # From the environment root
+ python3 server/data_cleaning_env_environment.py
+ ```
+
+ This verifies that:
+ - The environment resets correctly
+ - Step executes actions properly
+ - State tracking works
+ - Rewards are calculated correctly
+
+ ### Running Locally
+
+ Run the server locally for development:
+
+ ```bash
+ uvicorn server.app:app --reload
+ ```
+
+ ## Project Structure
+
+ ```
+ data_cleaning_env/
+ ├── .dockerignore                        # Docker build exclusions
+ ├── __init__.py                          # Module exports
+ ├── README.md                            # This file
+ ├── openenv.yaml                         # OpenEnv manifest
+ ├── pyproject.toml                       # Project metadata and dependencies
+ ├── uv.lock                              # Locked dependencies (generated)
+ ├── client.py                            # DataCleaningEnv client
+ ├── models.py                            # Action and Observation models
+ └── server/
+     ├── __init__.py                      # Server module exports
+     ├── data_cleaning_env_environment.py # Core environment logic
+     ├── app.py                           # FastAPI application (HTTP + WebSocket endpoints)
+     └── Dockerfile                       # Container image definition
+ ```
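The reward rule stated in the README above (`message_length × 0.1`) can be checked in isolation. `echo_reward` below is a hypothetical helper sketching that rule, not part of the committed code; the real calculation lives server-side:

```python
def echo_reward(message: str) -> float:
    # README rule: reward = message length × 0.1
    # (hypothetical helper; rounding to one decimal matches the README examples)
    return round(len(message) * 0.1, 1)

print(echo_reward("Hi"))             # 0.2
print(echo_reward("Hello, World!"))  # 1.3
print(echo_reward(""))               # 0.0
```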
__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Data Cleaning Env Environment."""
+
+ from .client import DataCleaningEnv
+ from .models import CleanAction, CleanObservation
+
+ __all__ = [
+     "CleanAction",
+     "CleanObservation",
+     "DataCleaningEnv",
+ ]
client.py ADDED
@@ -0,0 +1,319 @@
+ """
+ client.py
+ ---------
+ DataCleaningEnv — the typed WebSocket client for the data cleaning pipeline.
+
+ This module contains exactly one public class: ``DataCleaningEnv``.
+ It extends ``EnvClient`` from OpenEnv core and implements the three abstract
+ translation methods that bridge Python objects and the server's JSON wire format:
+
+     _step_payload(action)   CleanAction → dict (outbound)
+     _parse_result(payload)  dict → StepResult[CleanObservation] (inbound)
+     _parse_state(payload)   dict → CleanState (inbound)
+
+ Everything else — WebSocket lifecycle, connect/disconnect, async context
+ manager, the ``.sync()`` wrapper — is handled by the base class.
+
+ Usage (async)
+ -------------
+     import asyncio
+     from data_cleaning_env.client import DataCleaningEnv
+     from data_cleaning_env.models import CleanAction
+
+     async def main():
+         async with DataCleaningEnv(base_url="http://localhost:8000") as env:
+             result = await env.reset(task_id="easy")
+             print(result.observation.schema_hint)
+
+             result = await env.set_value(row_index=3, column="price", value="29.99")
+             print(result.reward, result.observation.current_score)
+
+             result = await env.done()
+
+     asyncio.run(main())
+
+ Usage (sync wrapper)
+ --------------------
+     env = DataCleaningEnv(base_url="http://localhost:8000").sync()
+     with env:
+         result = env.reset(task_id="medium")
+         result = env.fill_missing(column="amount", fill_strategy="median")
+         result = env.done()
+ """
+
+ from __future__ import annotations
+
+ from typing import Any, Optional
+
+ # ── OpenEnv core imports ──────────────────────────────────────────────────────
+ try:
+     from openenv.core.client_types import StepResult
+     from openenv.core.env_client import EnvClient
+ except ImportError:
+     from openenv.core.client_types import StepResult  # type: ignore[no-redef]
+     from openenv.core.env_client import EnvClient  # type: ignore[no-redef]
+
+ # ── Local model imports (try relative then absolute) ──────────────────────────
+ try:
+     from .models import (
+         CleanAction,
+         CleanObservation,
+         CleanState,
+         MAX_STEPS,
+         DONE_THRESHOLD,
+     )
+ except ImportError:
+     from models import (  # type: ignore[no-redef]
+         CleanAction,
+         CleanObservation,
+         CleanState,
+         MAX_STEPS,
+         DONE_THRESHOLD,
+     )
+
+
+ class DataCleaningEnv(EnvClient[CleanAction, CleanObservation, CleanState]):
+     """
+     Async WebSocket client for the Data Cleaning Pipeline environment.
+
+     Connects to a running ``DataCleaningEnvironment`` server and exposes the
+     standard OpenEnv interface (``reset``, ``step``, ``state``) plus typed
+     convenience helpers for each command.
+
+     All methods are async. For synchronous use, call ``.sync()`` to get a
+     ``SyncEnvClient`` wrapper:
+
+         with DataCleaningEnv(base_url="http://localhost:8000").sync() as env:
+             result = env.reset(task_id="easy")
+             result = env.set_value(row_index=0, column="price", value="9.99")
+
+     Connecting to different backends
+     --------------------------------
+     Local dev server (after ``openenv serve``):
+         env = DataCleaningEnv(base_url="http://localhost:8000")
+
+     Local Docker image (after ``openenv build``):
+         env = await DataCleaningEnv.from_docker_image("data-cleaning-env:latest")
+
+     Hugging Face Space (after ``openenv push``):
+         env = await DataCleaningEnv.from_env("your-org/data-cleaning-env")
+     """
+
+     # ─────────────────────────────────────────────────────────────────────────
+     # Abstract method implementations — the three translation methods
+     # ─────────────────────────────────────────────────────────────────────────
+
+     def _step_payload(self, action: CleanAction) -> dict[str, Any]:
+         """
+         Serialise a CleanAction to the JSON dict the server expects.
+
+         The server's ``step()`` endpoint receives this dict, validates it
+         against ``CleanAction``, and dispatches to the correct handler.
+
+         We use ``model_dump(exclude_none=True)`` to omit fields the agent
+         left as ``None`` — this keeps the wire message minimal and avoids
+         triggering Pydantic's ``extra="forbid"`` validator on the server side
+         for fields that weren't set.
+         """
+         return action.model_dump(exclude_none=True)
+
+     def _parse_result(self, payload: dict[str, Any]) -> StepResult[CleanObservation]:
+         """
+         Parse the server's step/reset response into a ``StepResult``.
+
+         Wire format (what the server sends back)::
+
+             {
+                 "observation": {
+                     "done": false,
+                     "reward": -0.005,
+                     "metadata": {},
+                     "task_id": "easy",
+                     "schema_hint": "Sales orders...",
+                     "initial_dirty_cells": 29,
+                     "dirty_csv": "row_index,order_id,...\\n0,1001,...",
+                     "current_score": 0.9550,
+                     "issues_remaining": 18,
+                     "step_number": 1,
+                     "max_steps": 40,
+                     "last_action_success": true,
+                     "last_action_error": null
+                 },
+                 "reward": -0.005,
+                 "done": false
+             }
+
+         Note: ``reward`` and ``done`` appear both at the top level (for
+         convenience) and inside ``observation`` (because the ``Observation``
+         base carries them). We use the top-level copies for ``StepResult`` so
+         the caller doesn't have to dig into the observation.
+         """
+         obs_data = payload.get("observation", {})
+
+         observation = CleanObservation(
+             # ── inherited from Observation base ──────────────────────────────
+             done=payload.get("done", obs_data.get("done", False)),
+             reward=payload.get("reward", obs_data.get("reward")),
+             metadata=obs_data.get("metadata", {}),
+             # ── task context (constant for the episode) ──────────────────────
+             task_id=obs_data["task_id"],
+             schema_hint=obs_data["schema_hint"],
+             initial_dirty_cells=obs_data["initial_dirty_cells"],
+             # ── per-step state ───────────────────────────────────────────────
+             dirty_csv=obs_data["dirty_csv"],
+             current_score=obs_data.get("current_score", 0.0),
+             issues_remaining=obs_data.get("issues_remaining", 0),
+             step_number=obs_data.get("step_number", 0),
+             max_steps=obs_data["max_steps"],
+             # ── last-action feedback ─────────────────────────────────────────
+             last_action_success=obs_data.get("last_action_success", True),
+             last_action_error=obs_data.get("last_action_error"),
+         )
+
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: dict[str, Any]) -> CleanState:
+         """
+         Parse the server's state response into a ``CleanState``.
+
+         The server serialises ``CleanState`` via Pydantic's ``model_dump()``,
+         so the wire keys match our field names exactly. We use ``.get()``
+         with sensible defaults everywhere so a partially-initialised state
+         (e.g. before the first reset) doesn't crash the client.
+         """
+         return CleanState(
+             # ── inherited from State base ────────────────────────────────────
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             # ── task identity ────────────────────────────────────────────────
+             task_id=payload.get("task_id", "easy"),
+             # ── DataFrame snapshots ──────────────────────────────────────────
+             dirty_csv_snapshot=payload.get("dirty_csv_snapshot", ""),
+             clean_csv_snapshot=payload.get("clean_csv_snapshot", ""),
+             # ── scoring ──────────────────────────────────────────────────────
+             initial_dirty_cells=payload.get("initial_dirty_cells", 0),
+             current_score=payload.get("current_score", 0.0),
+             previous_score=payload.get("previous_score", 0.0),
+             # ── grader metadata ──────────────────────────────────────────────
+             task_metadata=payload.get("task_metadata", {}),
+             # ── schema ───────────────────────────────────────────────────────
+             schema_hint=payload.get("schema_hint", ""),
+             # ── step budget ──────────────────────────────────────────────────
+             max_steps=payload.get("max_steps", 40),
+         )
+
+     # ─────────────────────────────────────────────────────────────────────────
+     # Typed convenience helpers — one per CleanAction command
+     # ─────────────────────────────────────────────────────────────────────────
+     # These methods exist purely for ergonomics: they let callers write
+     #
+     #     await env.set_value(row_index=3, column="price", value="29.99")
+     #
+     # instead of the more verbose:
+     #
+     #     await env.step(CleanAction(
+     #         command="SET_VALUE", row_index=3, column="price", value="29.99"
+     #     ))
+     #
+     # The baseline inference script can use either form.
+
+     async def set_value(
+         self,
+         row_index: int,
+         column: str,
+         value: str,
+     ) -> StepResult[CleanObservation]:
+         """Fix a single cell. ``value`` is always passed as a string; the
+         server casts it to the column's target dtype automatically."""
+         return await self.step(
+             CleanAction(
+                 command="SET_VALUE",
+                 row_index=row_index,
+                 column=column,
+                 value=value,
+             )
+         )
+
+     async def drop_row(self, row_index: int) -> StepResult[CleanObservation]:
+         """Remove an entire row (e.g. a true outlier in the medium task)."""
+         return await self.step(
+             CleanAction(command="DROP_ROW", row_index=row_index)
+         )
+
+     async def standardize_col(self, column: str) -> StepResult[CleanObservation]:
+         """Normalise a whole column's format.
+
+         The server auto-detects what to do:
+         - Date columns → parse any format, reformat as ``YYYY-MM-DD``
+         - Numeric columns → coerce to float/int, drop unit strings
+         - String columns → strip leading/trailing whitespace
+         """
+         return await self.step(
+             CleanAction(command="STANDARDIZE_COL", column=column)
+         )
+
+     async def fill_missing(
+         self,
+         column: str,
+         fill_strategy: str,
+     ) -> StepResult[CleanObservation]:
+         """Fill ``NaN`` values in ``column``.
+
+         Args:
+             column: Column name to fill.
+             fill_strategy: One of ``"mean"``, ``"median"``, ``"mode"``, ``"drop"``.
+                 ``"drop"`` removes rows where the column is ``NaN``.
+         """
+         return await self.step(
+             CleanAction(
+                 command="FILL_MISSING",
+                 column=column,
+                 fill_strategy=fill_strategy,
+             )
+         )
+
+     async def done(self) -> StepResult[CleanObservation]:
+         """Signal that the agent believes the CSV is clean.
+
+         This ends the episode immediately. If the current score is below
+         ``EARLY_DONE_THRESHOLD`` (0.60) a penalty of -0.20 is applied.
+         """
+         return await self.step(CleanAction(command="DONE"))
+
+     # ─────────────────────────────────────────────────────────────────────────
+     # Introspection helpers
+     # ─────────────────────────────────────────────────────────────────────────
+
+     async def current_score(self) -> float:
+         """Return the grader score from the last step (0.0–1.0)."""
+         st = await self.state()
+         return st.current_score
+
+     async def task_id(self) -> str:
+         """Return the active task ID (``"easy"``, ``"medium"``, or ``"hard"``)."""
+         st = await self.state()
+         return st.task_id
+
+     async def steps_remaining(self) -> int:
+         """Return the number of steps left before forced termination."""
+         st = await self.state()
+         return max(0, st.max_steps - st.step_count)
+
+     async def is_solved(self) -> bool:
+         """Return ``True`` if the current score meets the task's done threshold."""
+         st = await self.state()
+         threshold = DONE_THRESHOLD.get(st.task_id, 0.95)
+         return st.current_score >= threshold
dataset_factory.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ dataset_factory.py
3
+ ------------------
4
+ Generates (dirty_df, clean_df, metadata) triples for all 3 tasks.
5
+
6
+ Key design decisions:
7
+ - Fixed random seeds per task → reproducible grader scores
8
+ - clean_df is ALWAYS generated first, then dirt is injected
9
+ - metadata carries ground-truth info the grader needs (e.g. which
10
+ rows are real outliers vs valid extremes in Task 2)
11
+ - No external files needed — everything is generated in memory
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import copy
17
+ import random
18
+ import string
19
+ from dataclasses import dataclass, field
20
+ from typing import Any
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+
25
+ # ── Reproducible seeds ────────────────────────────────────────────────────────
26
+
27
+ SEEDS = {
28
+ "easy": 42,
29
+ "medium": 137,
30
+ "hard": 999,
31
+ }
32
+
33
+ # ── Return type ───────────────────────────────────────────────────────────────
34
+
35
+ @dataclass
36
+ class TaskDataset:
37
+ """Everything the environment and grader need for one episode."""
38
+ task_id: str
39
+ dirty_df: pd.DataFrame
40
+ clean_df: pd.DataFrame
41
+ schema_hint: str # plain-English schema description
42
+ total_dirty_cells: int # how many cells differ at episode start
43
+ metadata: dict[str, Any] = field(default_factory=dict)
44
+ # metadata keys used by graders:
45
+ # "outlier_rows" (Task 2) — list of row indices that ARE true outliers
46
+ # "valid_extreme_rows" (Task 2) — valid rows that look extreme but must stay
47
+ # "canonical_columns" (Task 3) — {alias: canonical_name} mapping
48
+ # "duplicate_row_ids" (Task 3) — list of (original_idx, duplicate_idx) pairs
49
+
50
+
51
+ # ── Public API ────────────────────────────────────────────────────────────────
52
+
53
+ def make_dataset(task_id: str) -> TaskDataset:
54
+ """Entry point. Call this from the environment's reset()."""
55
+ if task_id == "easy":
56
+ return _make_easy()
57
+ elif task_id == "medium":
58
+ return _make_medium()
59
+ elif task_id == "hard":
60
+ return _make_hard()
61
+ else:
62
+ raise ValueError(f"Unknown task_id: {task_id!r}. Must be easy/medium/hard.")
63
+
64
+
65
+ def count_dirty_cells(dirty_df: pd.DataFrame, clean_df: pd.DataFrame) -> int:
66
+ """Number of cells that differ between dirty and clean DataFrames."""
67
+ # Align on same dtypes for comparison
68
+ d = dirty_df.astype(str).reset_index(drop=True)
69
+ c = clean_df.astype(str).reset_index(drop=True)
70
+ return int((d != c).sum().sum())
71
+
72
+
73
+ # ── Task 1: easy ─────────────────────────────────────────────────────────────
74
+ #
75
+ # 50-row sales CSV.
76
+ # Clean schema:
77
+ # order_id (int), customer (str), product (str), category (str),
78
 + #   price (float, 2dp), quantity (int), order_date (YYYY-MM-DD),
 + #   region (str)
 + #
 + # Injected issues (29 dirty cells total):
 + #   • 10 wrong-type cells — numeric column contains a word
 + #   • 8 missing values — NaN in various columns
 + #   • 5 bad dates — future year (2099-xx-xx)
 + #   • 6 whitespace cells — leading/trailing spaces in string columns
 +
 + def _make_easy() -> TaskDataset:
 +     rng = random.Random(SEEDS["easy"])
 +     np_rng = np.random.default_rng(SEEDS["easy"])
 +
 +     n = 50
 +     categories = ["Electronics", "Clothing", "Home", "Sports", "Books"]
 +     regions = ["North", "South", "East", "West"]
 +     products = ["Widget A", "Widget B", "Gadget X", "Gadget Y", "Item Z"]
 +     customers = [f"Customer_{i:03d}" for i in range(1, 31)]
 +
 +     # ── Build clean DataFrame ────────────────────────────────────────────────
 +     clean = pd.DataFrame({
 +         "order_id": range(1001, 1001 + n),
 +         "customer": [rng.choice(customers) for _ in range(n)],
 +         "product": [rng.choice(products) for _ in range(n)],
 +         "category": [rng.choice(categories) for _ in range(n)],
 +         "price": np_rng.uniform(5.0, 500.0, n).round(2),
 +         "quantity": np_rng.integers(1, 20, n),
 +         "order_date": _random_dates(np_rng, n, "2023-01-01", "2024-06-30"),
 +         "region": [rng.choice(regions) for _ in range(n)],
 +     })
 +     clean["price"] = clean["price"].astype(float)
 +     clean["quantity"] = clean["quantity"].astype(int)
 +
 +     # ── Inject dirt ──────────────────────────────────────────────────────────
 +     dirty = clean.copy(deep=True).astype(object)
 +
 +     injected: set[tuple[int, str]] = set()
 +
 +     def pick_fresh(col: str, exclude: set) -> int:
 +         rows = [r for r in range(n) if (r, col) not in exclude]
 +         return rng.choice(rows)
 +
 +     # 10 wrong-type cells in numeric columns
 +     bad_words = ["N/A", "unknown", "missing", "null", "TBD", "??", "-", "n/a", "none", "—"]
 +     for word, col in zip(bad_words, rng.choices(["price", "quantity"], k=10)):
 +         row = pick_fresh(col, injected)
 +         dirty.at[row, col] = word
 +         injected.add((row, col))
 +
 +     # 8 missing values in various columns
 +     missing_cols = rng.choices(["customer", "product", "price", "quantity", "region"], k=8)
 +     for col in missing_cols:
 +         row = pick_fresh(col, injected)
 +         dirty.at[row, col] = np.nan
 +         injected.add((row, col))
 +
 +     # 5 bad dates — far-future year
 +     bad_date_templates = [
 +         "2099-01-15", "2099-07-04", "2099-12-31", "2099-03-22", "2099-11-11"
 +     ]
 +     for bad_date in bad_date_templates:
 +         row = pick_fresh("order_date", injected)
 +         dirty.at[row, "order_date"] = bad_date
 +         injected.add((row, "order_date"))
 +
 +     # 6 whitespace cells in string columns
 +     ws_cols = rng.choices(["customer", "product", "category", "region"], k=6)
 +     for col in ws_cols:
 +         row = pick_fresh(col, injected)
 +         orig = str(dirty.at[row, col])
 +         dirty.at[row, col] = f" {orig} "
 +         injected.add((row, col))
 +
 +     dirty_cell_count = count_dirty_cells(dirty.astype(str), clean.astype(str))
 +
 +     schema_hint = (
 +         "Sales orders dataset. Expected columns: "
 +         "order_id (integer), customer (string, no leading/trailing spaces), "
 +         "product (string, no spaces), category (one of: Electronics/Clothing/Home/Sports/Books), "
 +         "price (float, 2 decimal places, no text), "
 +         "quantity (integer, no text), "
 +         "order_date (YYYY-MM-DD format, year must be 2023 or 2024), "
 +         "region (one of: North/South/East/West, no spaces). "
 +         "No missing values allowed."
 +     )
 +
 +     return TaskDataset(
 +         task_id="easy",
 +         dirty_df=dirty,
 +         clean_df=clean.astype(object),
 +         schema_hint=schema_hint,
 +         total_dirty_cells=dirty_cell_count,
 +         metadata={"injected_cells": list(injected)},
 +     )
 +
 +
 + # ── Task 2: medium ────────────────────────────────────────────────────────────
 + #
 + # 200-row customer transaction CSV.
 + # Clean schema:
 + #   tx_id (int), customer_id (int), amount (float), tx_date (YYYY-MM-DD),
 + #   category (str), country (str), status (str)
 + #
 + # Injected issues:
 + #   • 15 statistical outliers — amount Z-score > 4.0 (should be removed/capped)
 + #   • 5 valid extremes — genuinely large transactions, must NOT be removed
 + #   • 12 category typos — slight misspellings
 +
 + def _make_medium() -> TaskDataset:
 +     rng = random.Random(SEEDS["medium"])
 +     np_rng = np.random.default_rng(SEEDS["medium"])
 +
 +     n = 200
 +     categories = ["Food", "Electronics", "Travel", "Healthcare", "Entertainment"]
 +     countries = ["US", "UK", "CA", "AU", "DE"]
 +     statuses = ["completed", "pending", "refunded"]
 +
 +     # ── Build clean base ────────────────────────────────────────────────────
 +     # Normal transaction amounts: mean $150, sd $60, clipped to [5, 800]
 +     amounts = np_rng.normal(150, 60, n).clip(5, 800).round(2)
 +
 +     clean = pd.DataFrame({
 +         "tx_id": range(9001, 9001 + n),
 +         "customer_id": np_rng.integers(1, 501, n),
 +         "amount": amounts,
 +         "tx_date": _random_dates(np_rng, n, "2023-01-01", "2024-06-30"),
 +         "category": [rng.choice(categories) for _ in range(n)],
 +         "country": [rng.choice(countries) for _ in range(n)],
 +         "status": [rng.choice(statuses) for _ in range(n)],
 +     })
 +
 +     # ── Choose outlier rows (15) — will be injected with extreme amounts ─────
 +     all_rows = list(range(n))
 +     outlier_rows: list[int] = rng.sample(all_rows, 15)
 +     remaining = [r for r in all_rows if r not in outlier_rows]
 +
 +     # ── Choose valid extreme rows (5) — large but legitimate ─────────────────
 +     # These are NOT in outlier_rows; amounts are large (Z > 3) but real
 +     valid_extreme_rows: list[int] = rng.sample(remaining, 5)
 +
 +     # ── Build dirty DataFrame ────────────────────────────────────────────────
 +     dirty = clean.copy(deep=True).astype(object)
 +
 +     # Inject true outliers: very high or very low (Z > 4)
 +     for row in outlier_rows:
 +         if rng.random() > 0.3:
 +             dirty.at[row, "amount"] = round(rng.uniform(5000, 15000), 2)  # extreme high
 +         else:
 +             dirty.at[row, "amount"] = round(rng.uniform(-500, -10), 2)  # negative (impossible)
 +
 +     # Inject valid extremes (in clean AND dirty — they stay)
 +     for row in valid_extreme_rows:
 +         valid_large = round(rng.uniform(900, 2000), 2)
 +         clean.at[row, "amount"] = valid_large
 +         dirty.at[row, "amount"] = valid_large
 +
 +     # Inject 12 category typos
 +     typo_map: dict[str, list[str]] = {
 +         "Electronics": ["Electrnics", "Electronis", "Electonics"],
 +         "Food": ["Foood", "Fod", "Fo0d"],
 +         "Travel": ["Travle", "Trevel", "Travell"],
 +         "Healthcare": ["Helthcare", "Healtcare", "Heathcare"],
 +         "Entertainment": ["Entertainmnt", "Entertainmet", "Entertainmen"],
 +     }
 +     injected_typo_rows: set[int] = set()
 +     typo_count = 0
 +     typo_cells: list[tuple[int, str, str]] = []  # (row, dirty_val, clean_val)
 +
 +     for row in rng.sample(remaining, min(12, len(remaining))):
 +         if typo_count >= 12:
 +             break
 +         if row in injected_typo_rows:
 +             continue
 +         orig_cat = str(clean.at[row, "category"])
 +         misspellings = typo_map.get(orig_cat)
 +         if misspellings:
 +             bad = rng.choice(misspellings)
 +             dirty.at[row, "category"] = bad
 +             typo_cells.append((row, bad, orig_cat))
 +             injected_typo_rows.add(row)
 +             typo_count += 1
 +
 +     dirty_cell_count = count_dirty_cells(dirty.astype(str), clean.astype(str))
 +
 +     schema_hint = (
 +         "Customer transactions dataset. Expected columns: "
 +         "tx_id (integer), customer_id (integer 1–500), "
 +         "amount (float, must be positive; realistic range is $5–$2000; "
 +         "amounts above $2000 or below $0 are data errors), "
 +         "tx_date (YYYY-MM-DD), "
 +         "category (one of: Food/Electronics/Travel/Healthcare/Entertainment — exact spelling), "
 +         "country (two-letter code: US/UK/CA/AU/DE), "
 +         "status (one of: completed/pending/refunded). "
 +         "Note: some large transactions ($900–$2000) are legitimate — do not remove them. "
 +         "Only remove rows where the amount is clearly erroneous (negative or > $2000)."
 +     )
 +
 +     return TaskDataset(
 +         task_id="medium",
 +         dirty_df=dirty,
 +         clean_df=clean.astype(object),
 +         schema_hint=schema_hint,
 +         total_dirty_cells=dirty_cell_count,
 +         metadata={
 +             "outlier_rows": outlier_rows,
 +             "valid_extreme_rows": valid_extreme_rows,
 +             "typo_cells": typo_cells,  # [(row, dirty_val, clean_val)]
 +         },
 +     )
 +
 +
 + # ── Task 3: hard ──────────────────────────────────────────────────────────────
 + #
 + # 400-row CSV merged from 3 fictional data sources.
 + # Each source uses different column names for the same concepts.
 + # Issues:
 + #   • Inconsistent column naming (3 aliases per concept)
 + #   • Mixed date formats across sources (ISO, US, EU)
 + #   • 30 duplicate rows (exact and near-duplicate)
 + #   • No schema documentation — agent must infer canonical form
 + #
 + # Canonical schema (what the agent must produce):
 + #   record_id, customer_id, full_name, email, amount,
 + #   currency, purchase_date (YYYY-MM-DD), product_name, region
 +
 + _CANONICAL_COLS = [
 +     "record_id", "customer_id", "full_name", "email",
 +     "amount", "currency", "purchase_date", "product_name", "region",
 + ]
 +
 + # Column aliases per source
 + _SOURCE_ALIASES = {
 +     "source_a": {
 +         "record_id": "record_id",
 +         "customer_id": "cust_id",
 +         "full_name": "name",
 +         "email": "email_address",
 +         "amount": "sale_amount",
 +         "currency": "ccy",
 +         "purchase_date": "date",
 +         "product_name": "item",
 +         "region": "territory",
 +     },
 +     "source_b": {
 +         "record_id": "id",
 +         "customer_id": "customer_id",
 +         "full_name": "full_name",
 +         "email": "contact_email",
 +         "amount": "value",
 +         "currency": "currency",
 +         "purchase_date": "purchase_date",
 +         "product_name": "product",
 +         "region": "area",
 +     },
 +     "source_c": {
 +         "record_id": "RecordID",
 +         "customer_id": "CustomerID",
 +         "full_name": "CustomerName",
 +         "email": "Email",
 +         "amount": "Amount",
 +         "currency": "Currency",
 +         "purchase_date": "PurchaseDate",
 +         "product_name": "ProductName",
 +         "region": "Region",
 +     },
 + }
 +
 + # Date format used by each source
 + _SOURCE_DATE_FORMATS = {
 +     "source_a": "%Y-%m-%d",  # ISO: 2023-04-15
 +     "source_b": "%m/%d/%Y",  # US: 04/15/2023
 +     "source_c": "%d.%m.%Y",  # EU: 15.04.2023
 + }
 +
 + def _make_hard() -> TaskDataset:
 +     rng = random.Random(SEEDS["hard"])
 +     np_rng = np.random.default_rng(SEEDS["hard"])
 +
 +     currencies = ["USD", "EUR", "GBP"]
 +     regions = ["APAC", "EMEA", "AMER", "LATAM"]
 +     products = [
 +         "Pro Subscription", "Enterprise License", "Support Package",
 +         "Training Course", "Hardware Bundle", "Consulting Day",
 +     ]
 +
 +     # Helper: generate a block of rows for one source
 +     def _source_block(source: str, n: int, id_start: int) -> pd.DataFrame:
 +         aliases = _SOURCE_ALIASES[source]
 +         date_fmt = _SOURCE_DATE_FORMATS[source]
 +         cust_ids = np_rng.integers(2001, 3001, n)
 +         amounts = np_rng.uniform(100, 5000, n).round(2)
 +         iso_dates = _random_dates(np_rng, n, "2022-01-01", "2024-06-30")
 +
 +         # Format dates in source-specific format
 +         formatted_dates = [
 +             pd.to_datetime(d).strftime(date_fmt)
 +             for d in iso_dates
 +         ]
 +
 +         names = [_random_name(rng) for _ in range(n)]
 +         emails = [_name_to_email(nm) for nm in names]
 +
 +         data = {
 +             aliases["record_id"]: range(id_start, id_start + n),
 +             aliases["customer_id"]: cust_ids.tolist(),
 +             aliases["full_name"]: names,
 +             aliases["email"]: emails,
 +             aliases["amount"]: amounts.tolist(),
 +             aliases["currency"]: [rng.choice(currencies) for _ in range(n)],
 +             aliases["purchase_date"]: formatted_dates,
 +             aliases["product_name"]: [rng.choice(products) for _ in range(n)],
 +             aliases["region"]: [rng.choice(regions) for _ in range(n)],
 +         }
 +         return pd.DataFrame(data)
 +
 +     # Three sources, ~133 rows each (400 total)
 +     block_a = _source_block("source_a", 134, id_start=1)
 +     block_b = _source_block("source_b", 133, id_start=135)
 +     block_c = _source_block("source_c", 133, id_start=268)
 +
 +     # ── Canonical (clean) dataframe ─────────────────────────────────────────
 +     def _to_canonical(df: pd.DataFrame, source: str) -> pd.DataFrame:
 +         rev = {v: k for k, v in _SOURCE_ALIASES[source].items()}
 +         renamed = df.rename(columns=rev)
 +         # Normalise date to YYYY-MM-DD
 +         renamed["purchase_date"] = pd.to_datetime(
 +             renamed["purchase_date"],
 +             format=_SOURCE_DATE_FORMATS[source],
 +         ).dt.strftime("%Y-%m-%d")
 +         return renamed[_CANONICAL_COLS]
 +
 +     clean_a = _to_canonical(block_a, "source_a")
 +     clean_b = _to_canonical(block_b, "source_b")
 +     clean_c = _to_canonical(block_c, "source_c")
 +     clean = pd.concat([clean_a, clean_b, clean_c], ignore_index=True)
 +     clean["record_id"] = range(1, len(clean) + 1)
 +
 +     # ── Dirty dataframe = concat of raw source blocks ────────────────────────
 +     # (columns are still in aliased form, dates in source-specific format)
 +     dirty = pd.concat([block_a, block_b, block_c], ignore_index=True)
 +
 +     # ── Inject 30 duplicate rows ─────────────────────────────────────────────
 +     n_clean = len(dirty)
 +     sampled_orig = rng.sample(range(n_clean), 30)
 +     duplicate_rows_to_inject: list[pd.DataFrame] = []
 +     duplicate_pairs: list[tuple[int, int]] = []
 +
 +     for orig_idx in sampled_orig:
 +         dup = dirty.iloc[[orig_idx]].copy()
 +         # Near-duplicate: 40% chance of a minor field change
 +         if rng.random() < 0.4:
 +             # Slightly alter the amount (±1%). After the concat each row has
 +             # values only in its own source's columns (NaN elsewhere), so
 +             # modify whichever aliased amount column is non-null.
 +             for amt_col in ["sale_amount", "value", "Amount"]:
 +                 if amt_col in dup.columns and pd.notna(dup.iloc[0].get(amt_col)):
 +                     old_val = dup.at[dup.index[0], amt_col]
 +                     dup.at[dup.index[0], amt_col] = round(float(old_val) * rng.uniform(0.99, 1.01), 2)
 +                     break
 +         duplicate_rows_to_inject.append(dup)
 +         duplicate_pairs.append((orig_idx, n_clean + len(duplicate_pairs)))
 +
 +     dirty = pd.concat([dirty] + duplicate_rows_to_inject, ignore_index=True)
 +
 +     # Shuffle so duplicates aren't obviously at the bottom
 +     dirty = dirty.sample(frac=1, random_state=SEEDS["hard"]).reset_index(drop=True)
 +
 +     # Build canonical alias lookup for grader
 +     canonical_lookup: dict[str, str] = {}
 +     for source, aliases in _SOURCE_ALIASES.items():
 +         for canonical, alias in aliases.items():
 +             canonical_lookup[alias] = canonical
 +
 +     dirty_cell_count = len(dirty) * len(_CANONICAL_COLS)  # hard task: whole-df scope
 +
 +     schema_hint = (
 +         "Merged dataset from 3 sources with inconsistent schemas. "
 +         "Your goal is to produce a single clean DataFrame with these canonical columns: "
 +         "record_id (integer, unique), customer_id (integer), full_name (string), "
 +         "email (string), amount (float), currency (one of: USD/EUR/GBP), "
 +         "purchase_date (YYYY-MM-DD), product_name (string), region (one of: APAC/EMEA/AMER/LATAM). "
 +         "Column names in the raw data vary by source (e.g. 'cust_id', 'customer_id', 'CustomerID' "
 +         "all mean customer_id). Date formats also vary (ISO, US MM/DD/YYYY, EU DD.MM.YYYY). "
 +         "There are also ~30 duplicate rows (some exact, some near-duplicate). "
 +         "Remove duplicates, and normalise all column names and date formats."
 +     )
 +
 +     return TaskDataset(
 +         task_id="hard",
 +         dirty_df=dirty,
 +         clean_df=clean.astype(object),
 +         schema_hint=schema_hint,
 +         total_dirty_cells=dirty_cell_count,
 +         metadata={
 +             "canonical_columns": _CANONICAL_COLS,
 +             "canonical_lookup": canonical_lookup,  # alias → canonical name
 +             "source_aliases": _SOURCE_ALIASES,
 +             "source_date_formats": _SOURCE_DATE_FORMATS,
 +             "duplicate_pairs": duplicate_pairs,  # (original_idx, dup_idx) in pre-shuffle dirty
 +             "n_clean_rows": len(clean),
 +         },
 +     )
 +
 +
 + # ── Internal helpers ──────────────────────────────────────────────────────────
 +
 + def _random_dates(
 +     rng: np.random.Generator,
 +     n: int,
 +     start: str,
 +     end: str,
 + ) -> list[str]:
 +     """Generate n random ISO-format date strings between start and end."""
 +     start_ts = pd.Timestamp(start)
 +     end_ts = pd.Timestamp(end)
 +     delta_days = (end_ts - start_ts).days
 +     offsets = rng.integers(0, delta_days, n)
 +     return [
 +         (start_ts + pd.Timedelta(days=int(d))).strftime("%Y-%m-%d")
 +         for d in offsets
 +     ]
 +
 +
 + _FIRST_NAMES = [
 +     "Alice", "Bob", "Carol", "David", "Eva", "Frank", "Grace", "Henry",
 +     "Iris", "Jack", "Karen", "Leo", "Mia", "Nathan", "Olivia", "Paul",
 +     "Quinn", "Rosa", "Sam", "Tara", "Uma", "Victor", "Wendy", "Xavier",
 +     "Yuki", "Zara",
 + ]
 +
 + _LAST_NAMES = [
 +     "Smith", "Jones", "Williams", "Brown", "Taylor", "Davies", "Evans",
 +     "Wilson", "Thomas", "Roberts", "Johnson", "Lee", "Martin", "Garcia",
 +     "Martinez", "Anderson", "Thompson", "White", "Harris", "Clark",
 + ]
 +
 +
 + def _random_name(rng: random.Random) -> str:
 +     return f"{rng.choice(_FIRST_NAMES)} {rng.choice(_LAST_NAMES)}"
 +
 +
 + def _name_to_email(name: str) -> str:
 +     first, last = name.lower().split()
 +     domains = ["example.com", "mail.com", "inbox.net", "corp.io"]
 +     # Built-in hash() is salted per interpreter run, which would break the
 +     # seeded reproducibility of the datasets; use a stable digest instead.
 +     idx = sum(ord(c) for c in name) % len(domains)
 +     return f"{first}.{last}@{domains[idx]}"
 +
 +
 + # ── Smoke test ────────────────────────────────────────────────────────────────
 +
 + if __name__ == "__main__":
 +     for task_id in ("easy", "medium", "hard"):
 +         ds = make_dataset(task_id)
 +         print(f"\n{'─'*60}")
 +         print(f"Task: {task_id.upper()}")
 +         print(f"  dirty shape  : {ds.dirty_df.shape}")
 +         print(f"  clean shape  : {ds.clean_df.shape}")
 +         print(f"  dirty cells  : {ds.total_dirty_cells}")
 +         print(f"  schema hint  : {ds.schema_hint[:80]}…")
 +         print(f"  metadata keys: {list(ds.metadata.keys())}")
 +         if task_id == "easy":
 +             print("\n  Sample dirty rows (price/quantity cols):")
 +             mask = ds.dirty_df["price"].astype(str).str.contains(
 +                 r"[a-zA-Z]|nan", na=True
 +             )
 +             print(ds.dirty_df[mask][["order_id", "price", "quantity"]].head(3).to_string(index=False))
 +         if task_id == "medium":
 +             print(f"\n  Outlier rows (first 5): {ds.metadata['outlier_rows'][:5]}")
 +             print(f"  Valid extreme rows: {ds.metadata['valid_extreme_rows']}")
 +         if task_id == "hard":
 +             print(f"\n  Raw column names: {list(ds.dirty_df.columns)}")
 +             print(f"  Duplicate pairs (first 3): {ds.metadata['duplicate_pairs'][:3]}")
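The hard task mixes three per-source date formats (ISO, US, EU) that must end up as canonical `YYYY-MM-DD` strings. A minimal standalone sketch of that normalisation, assuming only pandas; the sample values are hypothetical:

```python
import pandas as pd

# One raw date per fictional source, in that source's native format
raw = {
    "source_a": "2023-04-15",  # ISO  %Y-%m-%d
    "source_b": "04/15/2023",  # US   %m/%d/%Y
    "source_c": "15.04.2023",  # EU   %d.%m.%Y
}
fmts = {"source_a": "%Y-%m-%d", "source_b": "%m/%d/%Y", "source_c": "%d.%m.%Y"}

# Parse with the known per-source format, then re-emit as canonical ISO
canonical = {
    src: pd.to_datetime(val, format=fmts[src]).strftime("%Y-%m-%d")
    for src, val in raw.items()
}
print(canonical)
```

Passing an explicit `format=` avoids the day-first/month-first ambiguity that free-form parsing would hit on strings like `04/05/2023`.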
graders.py ADDED
@@ -0,0 +1,686 @@
 + """
 + graders.py
 + ----------
 + Deterministic graders for all three tasks.
 +
 + Each grader receives the agent's current working DataFrame and the
 + TaskDataset produced by dataset_factory, and returns a GradeResult
 + with a scalar score in [0.0, 1.0] plus a human-readable breakdown.
 +
 + Public API
 + ----------
 + grade(task_id, agent_df, clean_df, metadata, initial_dirty_cells) -> GradeResult
 +
 +     Dispatches to the correct grader. Call this from step().
 +
 + GradeResult
 +     .score             float 0.0–1.0 (the number that feeds the reward)
 +     .breakdown         dict (sub-scores, useful for logging/debugging)
 +     .issues_remaining  int (how many cells still need fixing)
 +     .detail            str (one-line human summary)
 + """
 +
 + from __future__ import annotations
 +
 + import re
 + from dataclasses import dataclass, field
 + from typing import Any, Dict, List, Optional, Tuple
 +
 + import numpy as np
 + import pandas as pd
 +
 +
 + # ─────────────────────────────────────────────────────────────────────────────
 + # Return type
 + # ─────────────────────────────────────────────────────────────────────────────
 +
 + @dataclass
 + class GradeResult:
 +     score: float  # 0.0 – 1.0, fed into reward
 +     breakdown: Dict[str, float] = field(default_factory=dict)
 +     issues_remaining: int = 0
 +     detail: str = ""
 +
 +     def __post_init__(self) -> None:
 +         self.score = round(float(np.clip(self.score, 0.0, 1.0)), 4)
 +
 +
 + # ─────────────────────────────────────────────────────────────────────────────
 + # Public dispatcher
 + # ─────────────────────────────────────────────────────────────────────────────
 +
 + def grade(
 +     task_id: str,
 +     agent_df: pd.DataFrame,
 +     clean_df: pd.DataFrame,
 +     metadata: Dict[str, Any],
 +     initial_dirty_cells: int,
 + ) -> GradeResult:
 +     """
 +     Route to the correct grader and return a GradeResult.
 +
 +     Parameters
 +     ----------
 +     task_id
 +         One of "easy", "medium", "hard".
 +     agent_df
 +         The agent's current working DataFrame (may still be dirty).
 +     clean_df
 +         Ground-truth clean DataFrame from TaskDataset.
 +     metadata
 +         TaskDataset.metadata dict (grader-specific ground truth).
 +     initial_dirty_cells
 +         Dirty cell count at episode start; used to compute issues_remaining
 +         for easy/medium tasks.
 +     """
 +     if agent_df is None or len(agent_df) == 0:
 +         return GradeResult(score=0.0, detail="Empty DataFrame — no score.")
 +
 +     if task_id == "easy":
 +         return _grade_easy(agent_df, clean_df, metadata, initial_dirty_cells)
 +     elif task_id == "medium":
 +         return _grade_medium(agent_df, clean_df, metadata, initial_dirty_cells)
 +     elif task_id == "hard":
 +         return _grade_hard(agent_df, clean_df, metadata)
 +     else:
 +         raise ValueError(f"Unknown task_id: {task_id!r}")
 +
 +
 + # ─────────────────────────────────────────────────────────────────────────────
 + # Task 1 — easy: cell-level match against ground truth
 + # ─────────────────────────────────────────────────────────────────────────────
 + #
 + # Score = (cells matching ground truth) / (total cells)
 + #
 + # "Matching" is defined after normalisation:
 + #   - strip leading/trailing whitespace
 + #   - numeric columns: round to 2dp, compare as float strings
 + #   - date column: accept YYYY-MM-DD only
 + #   - string columns: case-sensitive exact match after strip
 + #   - NaN / unparseable: mapped to sentinel strings that never match the
 + #     (fully populated) clean data — the agent must fill or fix them
 +
 + def _grade_easy(
 +     agent_df: pd.DataFrame,
 +     clean_df: pd.DataFrame,
 +     metadata: Dict[str, Any],
 +     initial_dirty_cells: int,
 + ) -> GradeResult:
 +
 +     # Align shape — the agent might have a different row count if it
 +     # accidentally dropped rows; penalise by treating missing rows as all-wrong.
 +     agent_norm = _normalise_easy(agent_df, clean_df)
 +     clean_norm = _normalise_easy(clean_df, clean_df)
 +
 +     total_cells = clean_norm.size
 +
 +     # Pad or truncate agent rows to match clean row count
 +     if len(agent_norm) < len(clean_norm):
 +         pad = pd.DataFrame(
 +             [["__MISSING__"] * len(clean_norm.columns)] * (len(clean_norm) - len(agent_norm)),
 +             columns=clean_norm.columns,
 +         )
 +         agent_norm = pd.concat([agent_norm, pad], ignore_index=True)
 +     elif len(agent_norm) > len(clean_norm):
 +         agent_norm = agent_norm.iloc[: len(clean_norm)].copy()
 +
 +     matches = (agent_norm == clean_norm).sum().sum()
 +     score = matches / total_cells
 +
 +     # Issues remaining: number of cells that still differ
 +     mismatches = int((agent_norm != clean_norm).sum().sum())
 +
 +     breakdown = {
 +         "cell_match_ratio": round(score, 4),
 +         "cells_matched": int(matches),
 +         "total_cells": int(total_cells),
 +         "cells_mismatched": mismatches,
 +     }
 +
 +     detail = (
 +         f"{int(matches)}/{total_cells} cells correct "
 +         f"({100*score:.1f}%) — {mismatches} still need fixing."
 +     )
 +
 +     return GradeResult(
 +         score=score,
 +         breakdown=breakdown,
 +         issues_remaining=mismatches,
 +         detail=detail,
 +     )
 +
 +
 + def _normalise_easy(df: pd.DataFrame, clean_df: pd.DataFrame) -> pd.DataFrame:
 +     """
 +     Bring a DataFrame to a canonical string form for cell-level comparison.
 +
 +     Rules applied per column based on clean_df's dtype:
 +       - Numeric (price, quantity): round to 2 decimal places → string
 +       - Date (order_date): parse and reformat to YYYY-MM-DD
 +       - String (all others): strip whitespace, leave case unchanged
 +       - NaN / unparseable: normalise to the sentinel "__NAN__"
 +     """
 +     out = {}
 +     NUMERIC_COLS = {"price", "quantity"}
 +     DATE_COLS = {"order_date"}
 +
 +     for col in clean_df.columns:
 +         if col not in df.columns:
 +             # Agent removed or renamed the column — all cells wrong
 +             out[col] = pd.Series(["__MISSING_COL__"] * len(df))
 +             continue
 +
 +         series = df[col].copy()
 +
 +         if col in NUMERIC_COLS:
 +             out[col] = series.apply(_to_numeric_str)
 +         elif col in DATE_COLS:
 +             out[col] = series.apply(_to_date_str)
 +         else:
 +             out[col] = series.apply(
 +                 lambda x: "__NAN__" if _is_missing(x) else str(x).strip()
 +             )
 +
 +     return pd.DataFrame(out, dtype=str)
 +
 +
 + def _to_numeric_str(x: Any) -> str:
 +     if _is_missing(x):
 +         return "__NAN__"
 +     try:
 +         return f"{float(str(x).strip().replace(',', '')):.2f}"
 +     except (ValueError, TypeError):
 +         return "__INVALID__"
 +
 +
 + def _to_date_str(x: Any) -> str:
 +     if _is_missing(x):
 +         return "__NAN__"
 +     s = str(x).strip()
 +     # Reject obviously wrong dates (e.g. year 2099)
 +     try:
 +         parsed = pd.to_datetime(s, dayfirst=False)
 +         if parsed.year > 2030 or parsed.year < 2000:
 +             return "__BAD_DATE__"
 +         return parsed.strftime("%Y-%m-%d")
 +     except Exception:
 +         return "__INVALID_DATE__"
 +
 +
 + def _is_missing(x: Any) -> bool:
 +     if x is None:
 +         return True
 +     try:
 +         return bool(pd.isna(x))
 +     except (TypeError, ValueError):
 +         return False
 +
 +
 + # ─────────────────────────────────────────────────────────────────────────────
 + # Task 2 — medium: F1 on outlier detection + typo correction
 + # ─────────────────────────────────────────────────────────────────────────────
 + #
 + # Two independent sub-scores, equally weighted:
 + #
 + #   outlier_f1 — precision/recall on which rows were fixed or removed
 + #   typo_score — fraction of category typo-cells correctly fixed
 + #
 + # Final score = 0.50 * outlier_f1 + 0.50 * typo_score
 + #
 + # Outlier logic:
 + #   A true-outlier row is "correctly handled" if:
 + #     (a) the row still exists AND amount is now in [5, 800], OR
 + #     (b) the row was dropped entirely
 + #   A valid-extreme row is a "false positive" if it was dropped OR
 + #   its amount was changed to something outside [900, 2000].
 + #
 + # The thresholds match the schema_hint the agent was given.
 +
 + _VALID_AMOUNT_MIN = 5.0
 + _VALID_AMOUNT_MAX = 800.0
 + _EXTREME_AMOUNT_MIN = 900.0
 + _EXTREME_AMOUNT_MAX = 2000.0
 +
 +
 + def _grade_medium(
 +     agent_df: pd.DataFrame,
 +     clean_df: pd.DataFrame,
 +     metadata: Dict[str, Any],
 +     initial_dirty_cells: int,
 + ) -> GradeResult:
 +
 +     outlier_rows: List[int] = metadata.get("outlier_rows", [])
 +     valid_extreme_rows: List[int] = metadata.get("valid_extreme_rows", [])
 +     typo_cells: List[Tuple[int, str, str]] = metadata.get("typo_cells", [])
 +
 +     # ── Outlier sub-score ────────────────────────────────────────────────────
 +     # Detect which of the original row indices are still present in agent_df.
 +     # We track by tx_id (which is stable and unique) rather than df index,
 +     # since the agent may reset the index after dropping rows.
 +     agent_tx_ids: set = set()
 +     if "tx_id" in agent_df.columns:
 +         agent_tx_ids = set(agent_df["tx_id"].dropna().astype(int).tolist())
 +
 +     tp = 0  # outlier rows that were correctly handled
 +     fn = 0  # outlier rows still wrong (extreme amount still present)
 +     fp = 0  # valid-extreme rows wrongly removed or damaged
 +
 +     # True-positive check
 +     for orig_idx in outlier_rows:
 +         tx_id_val = int(clean_df.iloc[orig_idx]["tx_id"]) if orig_idx < len(clean_df) else None
 +         if tx_id_val is None:
 +             continue
 +         if tx_id_val not in agent_tx_ids:
 +             # Row was dropped — counts as correctly handled (outlier removed)
 +             tp += 1
 +         else:
 +             # Row still present — check if amount was fixed
 +             agent_row = agent_df[agent_df["tx_id"].astype(int) == tx_id_val]
 +             if len(agent_row) == 0:
 +                 tp += 1  # dropped after all
 +             else:
 +                 amt = _safe_float(agent_row.iloc[0].get("amount"))
 +                 if amt is not None and _VALID_AMOUNT_MIN <= amt <= _VALID_AMOUNT_MAX:
 +                     tp += 1
 +                 else:
 +                     fn += 1
 +
 +     # False-positive check (valid extremes must survive untouched)
 +     for orig_idx in valid_extreme_rows:
 +         if orig_idx >= len(clean_df):
 +             continue
 +         tx_id_val = int(clean_df.iloc[orig_idx]["tx_id"])
 +         clean_amt = float(clean_df.iloc[orig_idx]["amount"])
 +
 +         if tx_id_val not in agent_tx_ids:
 +             fp += 1  # wrongly dropped a valid row
 +         else:
 +             agent_row = agent_df[agent_df["tx_id"].astype(int) == tx_id_val]
 +             if len(agent_row) == 0:
 +                 fp += 1
 +             else:
 +                 amt = _safe_float(agent_row.iloc[0].get("amount"))
 +                 # Accept if amount is within ±5% of original clean value
 +                 if amt is None or not (clean_amt * 0.95 <= amt <= clean_amt * 1.05):
 +                     fp += 1
 +
 +     n_outliers = len(outlier_rows)
 +     precision = tp / (tp + fp + 1e-9)
 +     recall = tp / (n_outliers + 1e-9)
 +     outlier_f1 = (2 * precision * recall) / (precision + recall + 1e-9)
 +
 +     # ── Typo sub-score ───────────────────────────────────────────────────────
 +     typo_correct = 0
 +     for (row_idx, dirty_val, clean_val) in typo_cells:
 +         if "tx_id" not in clean_df.columns or row_idx >= len(clean_df):
 +             continue
 +         tx_id_val = int(clean_df.iloc[row_idx]["tx_id"])
 +         agent_rows = agent_df[agent_df["tx_id"].astype(int) == tx_id_val] \
 +             if "tx_id" in agent_df.columns else pd.DataFrame()
 +         if len(agent_rows) == 0:
 +             continue  # row dropped; neither credit nor penalty
 +         agent_cat = str(agent_rows.iloc[0].get("category", "")).strip()
 +         if agent_cat == clean_val:
 +             typo_correct += 1
 +
 +     typo_score = typo_correct / max(len(typo_cells), 1)
 +
 +     # ── Combined score ───────────────────────────────────────────────────────
 +     score = 0.50 * outlier_f1 + 0.50 * typo_score
 +
 +     # Approximate issues remaining: unsolved outliers + unsolved typos
 +     issues_remaining = fn + (len(typo_cells) - typo_correct)
 +
 +     breakdown = {
 +         "outlier_f1": round(outlier_f1, 4),
 +         "outlier_tp": tp,
 +         "outlier_fn": fn,
 +         "outlier_fp": fp,
 +         "precision": round(precision, 4),
 +         "recall": round(recall, 4),
 +         "typo_score": round(typo_score, 4),
 +         "typos_fixed": typo_correct,
 +         "typos_total": len(typo_cells),
 +         "combined": round(score, 4),
 +     }
 +
 +     detail = (
 +         f"Outlier F1={outlier_f1:.3f} (TP={tp}, FP={fp}, FN={fn}) | "
 +         f"Typos {typo_correct}/{len(typo_cells)} fixed → score={score:.3f}"
 +     )
 +
 +     return GradeResult(
 +         score=score,
 +         breakdown=breakdown,
 +         issues_remaining=issues_remaining,
 +         detail=detail,
 +     )
 +
 +
 + def _safe_float(x: Any) -> Optional[float]:
 +     if _is_missing(x):
 +         return None
 +     try:
 +         return float(str(x).strip().replace(",", ""))
 +     except (ValueError, TypeError):
 +         return None
 +
 +
368
+ # ─────────────────────────────────────────────────────────────────────────────
369
+ # Task 3 — hard: schema normalisation + deduplication + date formatting
370
+ # ──────────────────��──────────────────────────────────────────────────────────
371
+ #
372
+ # Three independent sub-scores:
373
+ #
374
+ # schema_score (weight 0.40)
375
+ # Fraction of canonical column names present in agent_df.
376
+ # Bonus: all 9 canonical columns present AND no extra columns → +0.1
377
+ #
378
+ # dedup_score (weight 0.35)
379
+ # How many of the 30 true duplicate tx records were removed.
380
+ # Penalises over-deletion (removing rows that were not duplicates).
381
+ # dedup_precision = removed_true_dups / (rows_removed + ε)
382
+ # dedup_recall = removed_true_dups / n_duplicate_pairs
383
+ # dedup_f1 = harmonic mean
384
+ #
385
+ # format_score (weight 0.25)
386
+ # Fraction of values in the purchase_date column (or canonical alias)
387
+ # that are valid YYYY-MM-DD strings.
388
+ #
389
+ # Final score = 0.40 * schema_score + 0.35 * dedup_score + 0.25 * format_score
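+ #
+ # Quick sanity check of the weighting (hypothetical sub-scores):
+ # schema 8/9 present → 0.889, dedup F1 = 0.90, 95% valid dates →
+ # 0.40*0.889 + 0.35*0.90 + 0.25*0.95 ≈ 0.908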
390
+
391
+ _CANONICAL_COLS = [
392
+ "record_id", "customer_id", "full_name", "email",
393
+ "amount", "currency", "purchase_date", "product_name", "region",
394
+ ]
395
+
396
+ _ISO_DATE_PATTERN = re.compile(r"^\d{4}-\d{2}-\d{2}$")
397
+
398
+
399
+ def _grade_hard(
400
+ agent_df: pd.DataFrame,
401
+ clean_df: pd.DataFrame,
402
+ metadata: Dict[str, Any],
403
+ ) -> GradeResult:
404
+
405
+ canonical_lookup: Dict[str, str] = metadata.get("canonical_lookup", {})
406
+ n_clean_rows: int = metadata.get("n_clean_rows", len(clean_df))
407
+
408
+ # ── 1. Schema score ──────────────────────────────────────────────────────
409
+ schema_score, schema_detail = _grade_schema(agent_df, canonical_lookup)
410
+
411
+ # ── 2. Deduplication score ───────────────────────────────────────────────
412
+ dedup_score, dedup_detail = _grade_deduplication(
413
+ agent_df, clean_df, n_clean_rows, canonical_lookup
414
+ )
415
+
416
+ # ── 3. Date format score ─────────────────────────────────────────────────
417
+ format_score, format_detail = _grade_date_format(agent_df, canonical_lookup)
418
+
419
+ # ── Combined ─────────────────────────────────────────────────────────────
420
+ score = 0.40 * schema_score + 0.35 * dedup_score + 0.25 * format_score
421
+
422
+ # issues_remaining: rough proxy (unresolved column aliases + excess rows)
423
+ n_canonical_present = sum(
424
+ 1 for c in _CANONICAL_COLS if c in agent_df.columns
425
+ )
426
+ issues_remaining = (
427
+ (len(_CANONICAL_COLS) - n_canonical_present) # missing canonical cols
428
+ + max(0, len(agent_df) - n_clean_rows) # excess rows (dups not removed)
429
+ )
430
+
431
+ breakdown = {
432
+ "schema_score": round(schema_score, 4),
433
+ "dedup_score": round(dedup_score, 4),
434
+ "format_score": round(format_score, 4),
435
+ "combined": round(score, 4),
436
+ **{f"schema_{k}": v for k, v in schema_detail.items()},
437
+ **{f"dedup_{k}": v for k, v in dedup_detail.items()},
438
+ **{f"fmt_{k}": v for k, v in format_detail.items()},
439
+ }
440
+
441
+ detail = (
442
+ f"Schema={schema_score:.3f} | "
443
+ f"Dedup={dedup_score:.3f} | "
444
+ f"DateFmt={format_score:.3f} → score={score:.3f}"
445
+ )
446
+
447
+ return GradeResult(
448
+ score=score,
449
+ breakdown=breakdown,
450
+ issues_remaining=issues_remaining,
451
+ detail=detail,
452
+ )
453
+
454
+
455
+ def _grade_schema(
456
+ agent_df: pd.DataFrame,
457
+ canonical_lookup: Dict[str, str],
458
+ ) -> Tuple[float, Dict[str, Any]]:
459
+ """
460
+ Score how well the agent normalised column names.
461
+
462
+ Strategy:
463
+ - Build a set of "recognised" columns: canonical names + their aliases.
464
+ - For each canonical column, check if the agent has it (by canonical name).
465
+ - Partial credit per canonical column found.
466
+ - Small bonus if ALL 9 are present and no unrecognised extra columns remain.
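+
+ Worked example (hypothetical): 7 of 9 canonical columns present with two
+ alias columns left over → base = 7/9 ≈ 0.778, no bonus → score 0.778.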
467
+ """
468
+ agent_cols = set(agent_df.columns)
469
+ canonical_set = set(_CANONICAL_COLS)
470
+
471
+ # Note: alias names in canonical_lookup are not needed here; scoring
472
+ # checks for canonical names only.
473
+
474
+ # Count canonical columns present
475
+ found = [c for c in _CANONICAL_COLS if c in agent_cols]
476
+ n_found = len(found)
477
+ base = n_found / len(_CANONICAL_COLS)
478
+
479
+ # Bonus: all canonical present AND no leftover alias columns
480
+ leftover_aliases = [c for c in agent_cols if c not in canonical_set]
481
+ all_present = n_found == len(_CANONICAL_COLS)
482
+ clean_rename = len(leftover_aliases) == 0
483
+
484
+ bonus = 0.10 if (all_present and clean_rename) else 0.0
485
+
486
+ score = min(1.0, base + bonus)
487
+
488
+ detail: Dict[str, Any] = {
489
+ "canonical_found": n_found,
490
+ "canonical_total": len(_CANONICAL_COLS),
491
+ "leftover_aliases": len(leftover_aliases),
492
+ "rename_bonus": bonus,
493
+ }
494
+ return score, detail
495
+
496
+
497
+ def _grade_deduplication(
498
+ agent_df: pd.DataFrame,
499
+ clean_df: pd.DataFrame,
500
+ n_clean_rows: int,
501
+ canonical_lookup: Dict[str, str],
502
+ ) -> Tuple[float, Dict[str, Any]]:
503
+ """
504
+ Score how well the agent removed duplicate rows.
505
+
506
+ We compare row counts and detect near-duplicate detection quality:
507
+ - n_injected_dups: 30 (hardcoded from dataset_factory)
508
+ - expected_final_rows: n_clean_rows (400)
509
+ - rows_removed: (raw dirty rows = 430) - len(agent_df)
510
+ - true_dups_removed: min(rows_removed, 30)
511
+ (lenient: removals are assumed to target dups, capped at 30)
512
+ - over_deletion: max(0, rows_removed - 30) rows beyond the dup count
513
+ penalises removing valid data.
514
+
515
+ Precision = true_dups_removed / (rows_removed + ε)
516
+ Recall = true_dups_removed / 30
517
+ F1 = harmonic mean
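+
+ Worked example (hypothetical counts): rows_removed=32 →
+ true_dups_removed=30, over_deletion=2, effective_true=28 →
+ precision = 28/32 = 0.875, recall = 1.0, F1 ≈ 0.933.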
518
+ """
519
+ N_INJECTED_DUPS = 30
520
+ N_DIRTY_ROWS = n_clean_rows + N_INJECTED_DUPS # 430
521
+
522
+ rows_removed = max(0, N_DIRTY_ROWS - len(agent_df))
523
+
524
+ # Heuristic: removals are assumed to target dups, capped at the 30 injected
525
+ true_dups_removed = min(rows_removed, N_INJECTED_DUPS)
526
+
527
+ # Penalise over-removal (agent deleted valid rows beyond dups)
528
+ over_deletion = max(0, rows_removed - N_INJECTED_DUPS)
529
+ # Each over-deleted row reduces precision
530
+ effective_true = max(0, true_dups_removed - over_deletion)
531
+
532
+ precision = effective_true / (rows_removed + 1e-9)
533
+ recall = true_dups_removed / (N_INJECTED_DUPS + 1e-9)
534
+ f1 = (2 * precision * recall) / (precision + recall + 1e-9)
535
+
536
+ detail: Dict[str, Any] = {
537
+ "rows_removed": rows_removed,
538
+ "true_dups_removed": true_dups_removed,
539
+ "over_deletion": over_deletion,
540
+ "precision": round(precision, 4),
541
+ "recall": round(recall, 4),
542
+ "f1": round(f1, 4),
543
+ }
544
+ return f1, detail
545
+
546
+
547
+ def _grade_date_format(
548
+ agent_df: pd.DataFrame,
549
+ canonical_lookup: Dict[str, str],
550
+ ) -> Tuple[float, Dict[str, Any]]:
551
+ """
552
+ Fraction of purchase_date values matching YYYY-MM-DD.
553
+
554
+ Looks for the canonical name "purchase_date" first; falls back to
555
+ known aliases ("date", "PurchaseDate") if the agent hasn't renamed yet.
556
+ """
557
+ DATE_ALIASES = {"purchase_date", "date", "PurchaseDate"}
558
+
559
+ date_col = None
560
+ # Prefer canonical name
561
+ if "purchase_date" in agent_df.columns:
562
+ date_col = "purchase_date"
563
+ else:
564
+ for alias in DATE_ALIASES:
565
+ if alias in agent_df.columns:
566
+ date_col = alias
567
+ break
568
+
569
+ if date_col is None:
570
+ return 0.0, {"date_col_found": False, "valid_ratio": 0.0}
571
+
572
+ # Guard: duplicate column names after rename produce a DataFrame, not Series.
573
+ # Take the first occurrence.
574
+ col_data = agent_df[date_col]
575
+ if isinstance(col_data, pd.DataFrame):
576
+ col_data = col_data.iloc[:, 0]
577
+
578
+ # Force object dtype so .sum() always returns a numeric 0, not '' (the
579
+ # StringDtype identity). Python 3.14 + pandas 2.2+ infer StringDtype
580
+ # from .astype(str), which makes .sum() on an empty Series return ''.
581
+ series = col_data.dropna().astype(object).apply(str).str.strip()
582
+ n_total = len(series)
583
+ if n_total == 0:
584
+ return 0.0, {"date_col_found": True, "valid_ratio": 0.0, "n_total": 0}
585
+
586
+ # Combined check: ISO pattern match AND year in plausible range
587
+ def _is_valid_iso(s: str) -> bool:
588
+ if not _ISO_DATE_PATTERN.match(s):
589
+ return False
590
+ try:
591
+ return 2000 <= int(s[:4]) <= 2030
592
+ except Exception:
593
+ return False
594
+
595
+ valid_flags = series.apply(_is_valid_iso)
596
+ n_valid = int(valid_flags.sum()) # int() guards against numpy/pandas scalar types
597
+ n_year_ok = n_valid # same condition — kept for breakdown detail
598
+ valid_ratio = n_year_ok / n_total
599
+
600
+ detail: Dict[str, Any] = {
601
+ "date_col_found": True,
602
+ "date_col_used": date_col,
603
+ "n_total": int(n_total),
604
+ "n_valid_iso": int(n_valid),
605
+ "n_year_ok": int(n_year_ok),
606
+ "valid_ratio": round(valid_ratio, 4),
607
+ }
608
+ return valid_ratio, detail
609
+
610
+
611
+ # ─────────────────────────────────────────────────────────────────────────────
612
+ # Smoke test
613
+ # ─────────────────────────────────────────────────────────────────────────────
614
+
615
+ if __name__ == "__main__":
616
+ import sys
617
+ sys.path.insert(0, ".")
618
+ from dataset_factory import make_dataset
619
+
620
+ SEP = "─" * 62
621
+
622
+ # ── Task 1: easy ─────────────────────────────────────────────────────────
623
+ print(f"\n{SEP}\nTASK: easy\n{SEP}")
624
+ ds = make_dataset("easy")
625
+
626
+ # Baseline: grade dirty df (should be low)
627
+ r_dirty = grade("easy", ds.dirty_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
628
+ print(f"[dirty] score={r_dirty.score:.4f} {r_dirty.detail}")
629
+
630
+ # Perfect: grade clean df (should be 1.0)
631
+ r_clean = grade("easy", ds.clean_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
632
+ print(f"[clean] score={r_clean.score:.4f} {r_clean.detail}")
633
+
634
+ # Partial: fix half the injected cells
635
+ partial = ds.dirty_df.copy()
636
+ injected = ds.metadata.get("injected_cells", [])
637
+ for (row, col) in injected[:len(injected)//2]:
638
+ partial.at[row, col] = ds.clean_df.at[row, col]
639
+ r_partial = grade("easy", partial, ds.clean_df, ds.metadata, ds.total_dirty_cells)
640
+ print(f"[half] score={r_partial.score:.4f} {r_partial.detail}")
641
+
642
+ print(f"Breakdown: {r_partial.breakdown}")
643
+
644
+ # ── Task 2: medium ────────────────────────────────────────────────────────
645
+ print(f"\n{SEP}\nTASK: medium\n{SEP}")
646
+ ds = make_dataset("medium")
647
+
648
+ r_dirty = grade("medium", ds.dirty_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
649
+ print(f"[dirty] score={r_dirty.score:.4f} {r_dirty.detail}")
650
+
651
+ r_clean = grade("medium", ds.clean_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
652
+ print(f"[clean] score={r_clean.score:.4f} {r_clean.detail}")
653
+
654
+ # Simulate agent fixing all outliers (set amount to 150.0) + all typos
655
+ fixed = ds.dirty_df.copy()
656
+ for row in ds.metadata["outlier_rows"]:
657
+ if "tx_id" in ds.clean_df.columns:
658
+ fixed.at[row, "amount"] = 150.0
659
+ for (row, dirty_val, clean_val) in ds.metadata["typo_cells"]:
660
+ fixed.at[row, "category"] = clean_val
661
+ r_fixed = grade("medium", fixed, ds.clean_df, ds.metadata, ds.total_dirty_cells)
662
+ print(f"[fixed] score={r_fixed.score:.4f} {r_fixed.detail}")
663
+
664
+ print(f"Breakdown: {r_fixed.breakdown}")
665
+
666
+ # ── Task 3: hard ──────────────────────────────────────────────────────────
667
+ print(f"\n{SEP}\nTASK: hard\n{SEP}")
668
+ ds = make_dataset("hard")
669
+
670
+ r_dirty = grade("hard", ds.dirty_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
671
+ print(f"[dirty] score={r_dirty.score:.4f} {r_dirty.detail}")
672
+
673
+ r_clean = grade("hard", ds.clean_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
674
+ print(f"[clean] score={r_clean.score:.4f} {r_clean.detail}")
675
+
676
+ # Simulate partial fix: rename columns only, don't dedup or fix dates
677
+ partial_hard = ds.dirty_df.copy()
678
+ rename_map = ds.metadata.get("canonical_lookup", {})
679
+ partial_hard = partial_hard.rename(columns=rename_map)
680
+ # Keep only canonical columns that exist
681
+ canonical_present = [c for c in _CANONICAL_COLS if c in partial_hard.columns]
682
+ partial_hard = partial_hard[canonical_present]
683
+ r_renamed = grade("hard", partial_hard, ds.clean_df, ds.metadata, ds.total_dirty_cells)
684
+ print(f"[rename] score={r_renamed.score:.4f} {r_renamed.detail}")
685
+
686
+ print(f"Breakdown: {r_renamed.breakdown}")
inference.py ADDED
@@ -0,0 +1,271 @@
1
+ """
2
+ inference.py
3
+ ------------
4
+ Official submission inference script for the Data Cleaning Pipeline environment.
5
+
6
+ Reads from environment variables (ALL FREE — no paid API needed):
7
+ API_BASE_URL LLM endpoint. Default: HuggingFace free router.
8
+ MODEL_NAME Model to use. Default: free open model.
9
+ HF_TOKEN Your free HuggingFace token (hf_...).
10
+ LOCAL_IMAGE_NAME Docker image name if using from_docker_image().
11
+ Leave unset to connect via ENV_BASE_URL instead.
12
+ ENV_BASE_URL Direct server URL. Default: http://localhost:8000
13
+
14
+ STDOUT FORMAT (evaluator parses these lines exactly — do not modify):
15
+ [START] task=<n> env=<benchmark> model=<model>
16
+ [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
17
+ [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
18
+ """
19
+
20
+ import asyncio
21
+ import json
22
+ import os
23
+ import re
24
+ import sys
25
+ from typing import List, Optional
26
+
27
+
28
+ from openai import OpenAI
29
+
30
+ # ── Environment client imports ────────────────────────────────────────────────
31
+ try:
32
+ from client import DataCleaningEnv
33
+ from models import CleanAction, MAX_STEPS, DONE_THRESHOLD
34
+ except ImportError:
35
+ sys.path.insert(0, os.path.dirname(__file__))
36
+ from client import DataCleaningEnv
37
+ from models import CleanAction, MAX_STEPS, DONE_THRESHOLD
38
+
39
+
40
+ # ── Configuration — all defaults are FREE ────────────────────────────────────
41
+
42
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
43
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
44
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
45
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "")
46
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
47
+
48
+ BENCHMARK = "data_cleaning_env"
49
+ TASK_IDS = ["easy", "medium", "hard"]
50
+
51
+ # Conservative budgets — keeps total runtime under 20 min on vcpu=2 / 8 GB
52
+ STEP_LIMITS = {"easy": 25, "medium": 50, "hard": 80}
53
+
54
+
55
+ # ── Official log helpers ──────────────────────────────────────────────────────
56
+ # Field names, order, and spacing match the evaluator spec exactly.
57
+
58
+ def log_start(task: str, env: str, model: str) -> None:
59
+ print(f"[START] task={task} env={env} model={model}", flush=True)
60
+
61
+
62
+ def log_step(
63
+ step: int,
64
+ action: str,
65
+ reward: float,
66
+ done: bool,
67
+ error: Optional[str],
68
+ ) -> None:
69
+ error_val = error if error else "null"
70
+ done_val = str(done).lower()
71
+ action_str = action[:80].replace("\n", " ") # keep line single-line
72
+ print(
73
+ f"[STEP] step={step} action={action_str} "
74
+ f"reward={reward:.2f} done={done_val} error={error_val}",
75
+ flush=True,
76
+ )
77
+
78
+
79
+ def log_end(
80
+ success: bool,
81
+ steps: int,
82
+ score: float,
83
+ rewards: List[float],
84
+ ) -> None:
85
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
86
+ print(
87
+ f"[END] success={str(success).lower()} steps={steps} "
88
+ f"score={score:.2f} rewards={rewards_str}",
89
+ flush=True,
90
+ )
91
+
92
+
93
+ # ── LLM helpers ───────────────────────────────────────────────────────────────
94
+
95
+ SYSTEM_PROMPT = (
96
+ "You are a data cleaning agent. You receive a dirty CSV and must fix it "
97
+ "step by step using JSON action commands. Fix the most impactful issues "
98
+ "first. Be precise — wrong column names cause errors. "
99
+ "Output a single valid JSON object and nothing else — no explanation, no markdown."
100
+ )
101
+
102
+
103
+ def build_prompt(obs) -> str:
104
+ rows = obs.dirty_csv.strip().split("\n")
105
+ preview = "\n".join(rows[:30])
106
+ truncated = len(rows) > 30
107
+ last_err = f"\nLast error: {obs.last_action_error}" if obs.last_action_error else ""
108
+ return (
109
+ f"Task: {obs.task_id}\n"
110
+ f"Schema: {obs.schema_hint}\n"
111
+ f"Score: {obs.current_score:.4f} | Issues remaining: {obs.issues_remaining}\n"
112
+ f"Step {obs.step_number}/{obs.max_steps}{last_err}\n"
113
+ f"\nCSV{' (first 30 rows)' if truncated else ''}:\n{preview}\n\n"
114
+ "Reply with ONE JSON action:\n"
115
+ ' {"command":"SET_VALUE", "row_index":<int>, "column":"<name>", "value":"<str>"}\n'
116
+ ' {"command":"DROP_ROW", "row_index":<int>}\n'
117
+ ' {"command":"STANDARDIZE_COL", "column":"<name>"}\n'
118
+ ' {"command":"FILL_MISSING", "column":"<name>", "fill_strategy":"mean|median|mode|drop"}\n'
119
+ ' {"command":"DONE"}\n'
120
+ "row_index = integer in the leftmost column of the CSV. JSON only."
121
+ )
122
+
123
+
124
+ def parse_action(raw: str) -> CleanAction:
125
+ """Convert model output to CleanAction. Falls back to DONE on any error."""
126
+ text = raw.strip()
127
+ if text.startswith("```"):
128
+ lines = text.split("\n")
129
+ inner = lines[1:-1] if lines[-1].strip().startswith("```") else lines[1:]
130
+ text = "\n".join(inner).strip()
131
+ try:
132
+ return CleanAction(**json.loads(text))
133
+ except Exception:
134
+ m = re.search(r"\{[^{}]+\}", text, re.DOTALL)
135
+ if m:
136
+ try:
137
+ return CleanAction(**json.loads(m.group()))
138
+ except Exception:
139
+ pass
140
+ return CleanAction(command="DONE")
141
+
142
+
143
+ def call_llm(client: OpenAI, messages: list) -> str:
144
+ response = client.chat.completions.create(
145
+ model=MODEL_NAME,
146
+ messages=messages,
147
+ max_tokens=150, # actions are short; saves free-tier quota
148
+ temperature=0.1,
149
+ )
150
+ return (response.choices[0].message.content or "").strip()
151
+
152
+
153
+ # ── Episode loop ───────────────────────────────────────────────────────────────
154
+
155
+ async def run_episode(env, client: OpenAI, task_id: str) -> dict:
156
+ """Run one episode. Emits [START] → N×[STEP] → [END]."""
157
+ max_steps = STEP_LIMITS[task_id]
158
+ threshold = DONE_THRESHOLD[task_id]
159
+ rewards: List[float] = []
160
+ steps_taken = 0
161
+ score = 0.0
162
+ success = False
163
+
164
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
165
+
166
+ try:
167
+ result = await env.reset(task_id=task_id)
168
+ obs = result.observation
169
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
170
+
171
+ for step in range(1, max_steps + 1):
172
+ if obs.done:
173
+ break
174
+
175
+ steps_taken = step
176
+ messages.append({"role": "user", "content": build_prompt(obs)})
177
+
178
+ try:
179
+ raw = call_llm(client, messages)
180
+ action = parse_action(raw)
181
+ messages.append({"role": "assistant", "content": raw})
182
+ except Exception as exc:
183
+ # API or parse failure — log and stop episode
184
+ log_step(step, "DONE", 0.00, True, str(exc)[:120])
185
+ rewards.append(0.0)
186
+ break
187
+
188
+ # Keep only system + last 8 exchanges to stay inside free-tier context limits
189
+ if len(messages) > 17:
190
+ messages = [messages[0]] + messages[-16:]
191
+
192
+ result = await env.step(action)
193
+ obs = result.observation
194
+ reward = result.reward or 0.0
195
+ rewards.append(reward)
196
+ score = obs.current_score
197
+
198
+ log_step(
199
+ step=step,
200
+ action=action.command,
201
+ reward=reward,
202
+ done=obs.done,
203
+ error=obs.last_action_error,
204
+ )
205
+
206
+ if obs.done or score >= threshold:
207
+ break
208
+
209
+ success = score >= threshold
210
+
211
+ finally:
212
+ # [END] is always emitted, even if the episode crashed
213
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
214
+
215
+ return {"task_id": task_id, "score": score,
216
+ "reward": sum(rewards), "steps": steps_taken, "success": success}
217
+
218
+
219
+ # ── Entry point ────────────────────────────────────────────────────────────────
220
+
221
+ async def main() -> None:
222
+ if not HF_TOKEN:
223
+ print(
224
+ "ERROR: HF_TOKEN is not set.\n"
225
+ "1. Go to https://huggingface.co/settings/tokens\n"
226
+ "2. Click 'New token' → choose 'Read' → copy it\n"
227
+ "3. In PowerShell: $env:HF_TOKEN='hf_xxxxxxxxxxxx'\n"
228
+ "4. Then run: python inference.py",
229
+ file=sys.stderr,
230
+ )
231
+ sys.exit(1)
232
+
233
+ print(f"API_BASE_URL : {API_BASE_URL}", flush=True)
234
+ print(f"MODEL_NAME : {MODEL_NAME}", flush=True)
235
+ print(f"LOCAL_IMAGE_NAME : {LOCAL_IMAGE_NAME or '(not set — using ENV_BASE_URL)'}", flush=True)
236
+ print(f"ENV_BASE_URL : {ENV_BASE_URL}", flush=True)
237
+ print("", flush=True)
238
+
239
+ # Create the LLM client first, then connect to the environment
240
+ llm = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
241
+
242
+ if LOCAL_IMAGE_NAME:
243
+ env = await DataCleaningEnv.from_docker_image(LOCAL_IMAGE_NAME)
244
+ else:
245
+ env = DataCleaningEnv(base_url=ENV_BASE_URL)
246
+ await env.connect()
247
+
248
+ results = []
249
+ try:
250
+ for task_id in TASK_IDS:
251
+ summary = await run_episode(env, llm, task_id)
252
+ results.append(summary)
253
+ print("", flush=True)
254
+ finally:
255
+ await env.close()
256
+
257
+ # Human-readable summary (evaluator ignores lines that don't start with [START]/[STEP]/[END])
258
+ print("=" * 56, flush=True)
259
+ print(f"{'Task':<10} {'Score':>7} {'Reward':>9} {'Steps':>6} {'Pass':>5}")
260
+ print("-" * 56, flush=True)
261
+ for r in results:
262
+ print(
263
+ f"{r['task_id']:<10} {r['score']:>7.4f} {r['reward']:>9.4f} "
264
+ f"{r['steps']:>6} {'YES' if r['success'] else 'NO':>5}",
265
+ flush=True,
266
+ )
267
+ print("=" * 56, flush=True)
268
+
269
+
270
+ if __name__ == "__main__":
271
+ asyncio.run(main())
models.py ADDED
@@ -0,0 +1,463 @@
1
+ """
2
+ models.py
3
+ ---------
4
+ Pydantic models for the Data Cleaning Pipeline environment.
5
+
6
+ Three models define the full agent↔environment contract:
7
+
8
+ CleanAction — what the agent sends on each step
9
+ CleanObservation — what the agent receives back
10
+ CleanState — internal server state (not sent to agent directly)
11
+
12
+ Inheritance chain (confirmed from OpenEnv source):
13
+ Action → extra="forbid", has: metadata: Dict[str, Any]
14
+ Observation → extra="forbid", has: done: bool, reward: float|None, metadata: Dict[str, Any]
15
+ State → extra="allow", has: episode_id: Optional[str], step_count: int
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Any, Dict, List, Literal, Optional
21
+
22
+ from pydantic import Field, field_validator, model_validator
23
+
24
+ try:
25
+ from openenv.core.env_server.types import Action, Observation, State
26
+ except ImportError:
27
+ # Fallback for local development without the full OpenEnv install
28
+ from openenv.core.env_server import Action, Observation, State
29
+
30
+
31
+ # ── Valid values (used by validators + schema hints) ──────────────────────────
32
+
33
+ VALID_COMMANDS = Literal[
34
+ "SET_VALUE", # Fix a specific cell: (row_index, column, value)
35
+ "DROP_ROW", # Remove an entire row: (row_index,)
36
+ "STANDARDIZE_COL", # Normalize an entire column's format: (column,)
37
+ "FILL_MISSING", # Fill NaN values in a column: (column, fill_strategy)
38
+ "DONE", # Agent signals episode is complete: ()
39
+ ]
40
+
41
+ VALID_FILL_STRATEGIES = Literal["mean", "median", "mode", "drop"]
42
+
43
+ VALID_TASK_IDS = Literal["easy", "medium", "hard"]
44
+
45
+
46
+ # ─────────────────────────────────────────────────────────────────────────────
47
+ # CleanAction
48
+ # ─────────────────────────────────────────────────────────────────────────────
49
+
50
+ class CleanAction(Action):
51
+ """Action sent by the agent each step.
52
+
53
+ The ``command`` field selects the operation. Depending on command,
54
+ only a subset of the remaining fields are required:
55
+
56
+ +-----------------+------------+--------+-------+---------------+
57
+ | command | row_index | column | value | fill_strategy |
58
+ +=================+============+========+=======+===============+
59
+ | SET_VALUE | required | req | req | — |
60
+ | DROP_ROW | required | — | — | — |
61
+ | STANDARDIZE_COL | — | req | — | — |
62
+ | FILL_MISSING | — | req | — | required |
63
+ | DONE | — | — | — | — |
64
+ +-----------------+------------+--------+-------+---------------+
65
+
66
+ Example (fix a single cell)::
67
+
68
+ CleanAction(
69
+ command="SET_VALUE",
70
+ row_index=3,
71
+ column="price",
72
+ value="29.99",
73
+ )
74
+
75
+ Example (drop a whole row)::
76
+
77
+ CleanAction(command="DROP_ROW", row_index=17)
78
+
79
+ Example (fill all NaN in a column with the median)::
80
+
81
+ CleanAction(
82
+ command="FILL_MISSING",
83
+ column="quantity",
84
+ fill_strategy="median",
85
+ )
86
+ """
87
+
88
+ command: VALID_COMMANDS = Field(
89
+ ...,
90
+ description=(
91
+ "Operation to perform. One of: SET_VALUE, DROP_ROW, "
92
+ "STANDARDIZE_COL, FILL_MISSING, DONE."
93
+ ),
94
+ )
95
+
96
+ row_index: Optional[int] = Field(
97
+ default=None,
98
+ ge=0,
99
+ description=(
100
+ "Zero-based row index to target. "
101
+ "Required for SET_VALUE and DROP_ROW."
102
+ ),
103
+ )
104
+
105
+ column: Optional[str] = Field(
106
+ default=None,
107
+ min_length=1,
108
+ description=(
109
+ "Name of the column to target. "
110
+ "Required for SET_VALUE, STANDARDIZE_COL, and FILL_MISSING."
111
+ ),
112
+ )
113
+
114
+ value: Optional[str] = Field(
115
+ default=None,
116
+ description=(
117
+ "New cell value as a string. "
118
+ "Required for SET_VALUE. The environment casts this to the "
119
+ "column's expected dtype (e.g. '29.99' → float for a price column)."
120
+ ),
121
+ )
122
+
123
+ fill_strategy: Optional[VALID_FILL_STRATEGIES] = Field(
124
+ default=None,
125
+ description=(
126
+ "Strategy for FILL_MISSING. One of: mean, median, mode, drop. "
127
+ "'drop' removes rows where the column is NaN."
128
+ ),
129
+ )
130
+
131
+ @model_validator(mode="after")
132
+ def _check_required_fields(self) -> "CleanAction":
133
+ """Ensure each command has exactly the fields it needs."""
134
+ cmd = self.command
135
+
136
+ if cmd == "SET_VALUE":
137
+ missing = []
138
+ if self.row_index is None:
139
+ missing.append("row_index")
140
+ if self.column is None:
141
+ missing.append("column")
142
+ if self.value is None:
143
+ missing.append("value")
144
+ if missing:
145
+ raise ValueError(
146
+ f"SET_VALUE requires: {', '.join(missing)}"
147
+ )
148
+
149
+ elif cmd == "DROP_ROW":
150
+ if self.row_index is None:
151
+ raise ValueError("DROP_ROW requires row_index")
152
+
153
+ elif cmd == "STANDARDIZE_COL":
154
+ if self.column is None:
155
+ raise ValueError("STANDARDIZE_COL requires column")
156
+
157
+ elif cmd == "FILL_MISSING":
158
+ missing = []
159
+ if self.column is None:
160
+ missing.append("column")
161
+ if self.fill_strategy is None:
162
+ missing.append("fill_strategy")
163
+ if missing:
164
+ raise ValueError(
165
+ f"FILL_MISSING requires: {', '.join(missing)}"
166
+ )
167
+
168
+ # DONE requires nothing — always valid
169
+
170
+ return self
171
+
172
+ @field_validator("row_index")
173
+ @classmethod
174
+ def _non_negative_row(cls, v: Optional[int]) -> Optional[int]:
175
+ if v is not None and v < 0:
176
+ raise ValueError(f"row_index must be >= 0, got {v}")
177
+ return v
178
+
179
+
180
+ # ─────────────────────────────────────────────────────────────────────────────
181
+ # CleanObservation
182
+ # ─────────────────────────────────────────────────────────────────────────────
183
+
184
+ class CleanObservation(Observation):
185
+ """Observation returned to the agent after each step (and at reset).
186
+
187
+ The agent sees the full current state of the dirty CSV at every step
188
+ so it can decide what to fix next. This is intentionally verbose —
189
+ passing the whole CSV string keeps the environment stateless from the
190
+ agent's perspective (no hidden memory needed).
191
+
192
+ Inherited from Observation (do NOT redeclare these):
193
+ done: bool — True when the episode has ended
194
+ reward: float | None — per-step reward (None at reset)
195
+ metadata: Dict[str, Any] — extra info (unused by core loop)
196
+ """
197
+
198
+ # ── Task context (set at reset, constant for the episode) ────────────────
199
+
200
+ task_id: VALID_TASK_IDS = Field(
201
+ ...,
202
+ description="Which task is active: 'easy', 'medium', or 'hard'.",
203
+ )
204
+
205
+ schema_hint: str = Field(
206
+ ...,
207
+ description=(
208
+ "Plain-English description of the target schema. "
209
+ "Tells the agent what the clean data should look like."
210
+ ),
211
+ )
212
+
213
+ initial_dirty_cells: int = Field(
214
+ ...,
215
+ ge=0,
216
+ description=(
217
+ "Total number of cells that differed from ground truth at episode start. "
218
+ "Used to compute a normalised progress score."
219
+ ),
220
+ )
221
+
222
+ # ── Per-step state ───────────────────────────────────────────────────────
223
+
224
+ dirty_csv: str = Field(
225
+ ...,
226
+ description=(
227
+ "Full current state of the working DataFrame serialised as a CSV string. "
228
+ "This reflects all changes the agent has made so far this episode."
229
+ ),
230
+ )
231
+
232
+ current_score: float = Field(
233
+ default=0.0,
234
+ ge=0.0,
235
+ le=1.0,
236
+ description=(
237
+ "Grader score after the last action (0.0 = no cells correct, "
238
+ "1.0 = perfect match with ground truth)."
239
+ ),
240
+ )
241
+
242
+ issues_remaining: int = Field(
243
+ default=0,
244
+ ge=0,
245
+ description=(
246
+ "Approximate count of cells still differing from ground truth. "
247
+ "Convenience field — agents can also derive this from the CSV."
248
+ ),
249
+ )
250
+
251
+ step_number: int = Field(
252
+ default=0,
253
+ ge=0,
254
+ description="How many steps have been taken in this episode so far.",
255
+ )
256
+
257
+ max_steps: int = Field(
258
+ ...,
259
+ ge=1,
260
+ description="Maximum steps allowed for this task before forced termination.",
261
+ )
262
+
263
+ # ── Last-action feedback ────────────────────────────────────────────────
264
+
265
+ last_action_success: bool = Field(
266
+ default=True,
267
+ description=(
268
+ "Whether the last action was applied without errors. "
269
+ "False if the column/row didn't exist, value couldn't be cast, etc."
270
+ ),
271
+ )
272
+
273
+ last_action_error: Optional[str] = Field(
274
+ default=None,
275
+ description=(
276
+ "Error message if last_action_success is False, else None. "
277
+ "Helps the agent self-correct."
278
+ ),
279
+ )
280
+
281
+ @field_validator("current_score")
282
+ @classmethod
283
+ def _round_score(cls, v: float) -> float:
284
+ return round(v, 4)
285
+
286
+
287
+ # ─────────────────────────────────────────────────────────────────────────────
288
+ # CleanState
289
+ # ─────────────────────────────────────────────────────────────────────────────
290
+
291
+ class CleanState(State):
292
+ """Internal server-side state. Never sent to the agent directly.
293
+
294
+ Holds the live DataFrames, ground truth, and grader metadata.
295
+ Because State uses extra="allow", we can store arbitrary fields
296
+ without listing them in the JSON schema.
297
+
298
+ Inherited from State:
299
+ episode_id: Optional[str] — unique episode identifier
300
+ step_count: int — steps taken this episode (ge=0)
301
+ """
302
+
303
+ # ── Task identity ────────────────────────────────────────────────────────
304
+
305
+ task_id: str = Field(
306
+ default="easy",
307
+ description="Active task: 'easy', 'medium', or 'hard'.",
308
+ )
309
+
310
+ # ── DataFrame snapshots (stored as CSV strings for serialisation) ────────
311
+ # NOTE: The environment keeps live pd.DataFrame objects in instance vars.
312
+ # These string fields are the serialised snapshots used by state() calls
313
+ # and for WebSocket state responses.
314
+
315
+ dirty_csv_snapshot: str = Field(
316
+ default="",
317
+ description="Current working DataFrame serialised to CSV string.",
318
+ )
319
+
320
+ clean_csv_snapshot: str = Field(
321
+ default="",
322
+ description="Ground-truth clean DataFrame serialised to CSV string.",
323
+ )
324
+
325
+ # ── Scoring ──────────────────────────────────────────────────────────────
326
+
327
+ initial_dirty_cells: int = Field(
328
+ default=0,
329
+ ge=0,
330
+ description="Dirty cell count at episode start (denominator for progress).",
331
+ )
332
+
333
+ current_score: float = Field(
334
+ default=0.0,
335
+ ge=0.0,
336
+ le=1.0,
337
+ description="Grader score after the last step.",
338
+ )
339
+
340
+ previous_score: float = Field(
341
+ default=0.0,
342
+ ge=0.0,
343
+ le=1.0,
344
+ description="Grader score before the last step (for reward delta).",
345
+ )
346
+
347
+ # ── Task metadata (passed through from TaskDataset.metadata) ─────────────
348
+ # Contains grader-specific ground truth: outlier_rows, canonical_lookup, etc.
349
+
350
+ task_metadata: Dict[str, Any] = Field(
351
+ default_factory=dict,
352
+ description=(
353
+ "Task-specific metadata from dataset_factory.TaskDataset.metadata. "
354
+ "Contains grader ground truth (outlier_rows, duplicate_pairs, etc.)."
355
+ ),
356
+ )
357
+
358
+ # ── Schema hint (echoed in observations) ────────────────────────────────
359
+
360
+ schema_hint: str = Field(
361
+ default="",
362
+ description="Plain-English schema description for this task.",
363
+ )
364
+
365
+ # ── Per-task step budget ─────────────────────────────────────────────────
366
+
367
+ max_steps: int = Field(
368
+ default=40,
369
+ ge=1,
370
+ description="Maximum steps for this task (40 / 80 / 150 for easy/medium/hard).",
371
+ )
372
+
373
+ @field_validator("current_score", "previous_score")
374
+ @classmethod
375
+ def _clamp_score(cls, v: float) -> float:
376
+ return round(max(0.0, min(1.0, v)), 4)
377
+
378
+
379
+ # ── Step budget constants ─────────────────────────────────────────────────────
380
+
381
+ MAX_STEPS: Dict[str, int] = {
382
+ "easy": 40,
383
+ "medium": 80,
384
+ "hard": 150,
385
+ }
386
+
387
+ # Done threshold: score at which the agent is considered successful
388
+ DONE_THRESHOLD: Dict[str, float] = {
389
+ "easy": 0.95,
390
+ "medium": 0.85,
391
+ "hard": 0.80,
392
+ }
393
+
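Taken together, MAX_STEPS and DONE_THRESHOLD determine when an episode ends. A minimal sketch of that termination check, assuming the constants above (`is_done` is an illustrative name, not part of this module):

```python
MAX_STEPS = {"easy": 40, "medium": 80, "hard": 150}
DONE_THRESHOLD = {"easy": 0.95, "medium": 0.85, "hard": 0.80}

def is_done(task_id: str, score: float, step_count: int) -> bool:
    # Episode ends when the grader score crosses the task threshold
    # or the per-task step budget is exhausted.
    return score >= DONE_THRESHOLD[task_id] or step_count >= MAX_STEPS[task_id]
```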
394
+
395
+ # ── Smoke test ────────────────────────────────────────────────────────────────
396
+
397
+ if __name__ == "__main__":
398
+ import json
399
+
400
+ print("── CleanAction examples ──────────────────────────────────────")
401
+
402
+ a1 = CleanAction(command="SET_VALUE", row_index=3, column="price", value="29.99")
403
+ print("SET_VALUE: ", a1.model_dump())
404
+
405
+ a2 = CleanAction(command="DROP_ROW", row_index=17)
406
+ print("DROP_ROW: ", a2.model_dump())
407
+
408
+ a3 = CleanAction(command="FILL_MISSING", column="quantity", fill_strategy="median")
409
+ print("FILL_MISSING: ", a3.model_dump())
410
+
411
+ a4 = CleanAction(command="STANDARDIZE_COL", column="order_date")
412
+ print("STANDARDIZE_COL:", a4.model_dump())
413
+
414
+ a5 = CleanAction(command="DONE")
415
+ print("DONE: ", a5.model_dump())
416
+
417
+ # Validation: SET_VALUE without row_index should fail
418
+ print("\n── Validation ────────────────────────────────────────────────")
419
+ try:
420
+ CleanAction(command="SET_VALUE", column="price", value="10.0")
421
+ except Exception as e:
422
+ print(f"Expected error (missing row_index): {e}")
+ else:
+ print("ERROR: SET_VALUE without row_index was not rejected")
423
+
424
+ try:
425
+ CleanAction(command="FILL_MISSING", column="price")
426
+ except Exception as e:
427
+ print(f"Expected error (missing fill_strategy): {e}")
+ else:
+ print("ERROR: FILL_MISSING without fill_strategy was not rejected")
428
+
429
+ print("\n── CleanObservation ──────────────────────────────────────────")
430
+ obs = CleanObservation(
431
+ task_id="easy",
432
+ schema_hint="Sales orders dataset. price must be float.",
433
+ initial_dirty_cells=29,
434
+ dirty_csv="order_id,price\n1001,N/A\n1002,19.99",
435
+ current_score=0.0,
436
+ issues_remaining=29,
437
+ step_number=0,
438
+ max_steps=40,
439
+ done=False,
440
+ reward=None,
441
+ )
442
+ print(json.dumps(obs.model_dump(), indent=2))
443
+
444
+ print("\n── CleanState ────────────────────────────────────────────────")
445
+ state = CleanState(
446
+ episode_id="ep-001",
447
+ step_count=0,
448
+ task_id="easy",
449
+ dirty_csv_snapshot="order_id,price\n1001,N/A",
450
+ clean_csv_snapshot="order_id,price\n1001,14.99",
451
+ initial_dirty_cells=29,
452
+ current_score=0.0,
453
+ previous_score=0.0,
454
+ task_metadata={"injected_cells": [(0, "price")]},
455
+ schema_hint="Sales orders dataset.",
456
+ max_steps=40,
457
+ )
458
+ print(json.dumps(state.model_dump(), indent=2))
459
+
460
+ print("\n── JSON schemas ──────────────────────────────────────────────")
461
+ print("Action schema keys: ", list(CleanAction.model_json_schema()["properties"].keys()))
462
+ print("Observation schema keys:", list(CleanObservation.model_json_schema()["properties"].keys()))
463
+ print("State schema keys: ", list(CleanState.model_json_schema()["properties"].keys()))
openenv.yaml ADDED
@@ -0,0 +1,151 @@
1
+ # openenv.yaml
2
+ # ─────────────────────────────────────────────────────────────────────────────
3
+ # Manifest for the Data Cleaning Pipeline OpenEnv environment.
4
+ #
5
+ # Field reference
6
+ # ───────────────
7
+ # Required by the CLI (serve / build / push / validate):
8
+ # spec_version — always 1 for this generation of the spec
9
+ # name — environment identifier used by the CLI and auto-discovery
10
+ # type — "space" means it can be deployed as a Hugging Face Space
11
+ # runtime — "fastapi" tells the server how to boot
12
+ # app — Python import path to the FastAPI app object
13
+ # port — port the server listens on inside the container
14
+ #
15
+ # Read by AutoEnv auto-discovery (openenv.auto._discovery):
16
+ # name — maps to env_key after stripping the "_env" suffix
17
+ # description — human-readable label shown in env listings
18
+ # spec_version — stored in EnvironmentInfo for introspection
19
+ # action — EXPLICIT override of the auto-inferred class name
20
+ # observation — EXPLICIT override of the auto-inferred class name
21
+ #
22
+ # NOTE on action / observation overrides:
23
+ # Auto-discovery infers class names from the env name using PascalCase:
24
+ # "data_cleaning_env" → base "data_cleaning" → "DataCleaningAction"
25
+ # Our actual class is named "CleanAction" (not "DataCleaningAction"),
26
+ # so these fields MUST be set to avoid ImportError on AutoEnv.from_env().
27
+ #
28
+ # All other fields (tasks, reward, tags) are informational. They are not
29
+ # parsed by the current OpenEnv tooling but are preserved in
30
+ # EnvironmentInfo.manifest and available to the web UI and external tools.
31
+ # ─────────────────────────────────────────────────────────────────────────────
32
+
33
+ # ── Core deployment fields ────────────────────────────────────────────────────
34
+
35
+ spec_version: 1
36
+ name: data_cleaning_env
37
+ type: space
38
+ runtime: fastapi
39
+ app: server.app:app
40
+ port: 8000
41
+
42
+ # ── Package metadata ──────────────────────────────────────────────────────────
43
+
44
+ version: "1.0.0"
45
+
46
+ description: >-
47
+ Data cleaning pipeline: the agent receives a dirty CSV and must detect
48
+ and fix type errors, missing values, outliers, and schema inconsistencies
49
+ to match a hidden ground-truth dataset. Three tasks (easy → medium → hard)
50
+ with a deterministic grader that returns a continuous score in [0.0, 1.0].
51
+
52
+ # ── Auto-discovery class overrides ───────────────────────────────────────────
53
+ # These override auto-inferred names (which would be DataCleaningAction /
54
+ # DataCleaningObservation) to match the actual class names defined in models.py.
55
+
56
+ action: CleanAction
57
+ observation: CleanObservation
58
+
59
+ # The client class is correctly inferred as DataCleaningEnv (data_cleaning →
60
+ # DataCleaning + Env), which matches client.py, so no override is needed.
61
+
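The PascalCase inference the comments above describe might look like the following sketch (an assumption about how openenv.auto._discovery behaves, not its actual code):

```python
def infer_class_names(env_name: str) -> tuple[str, str]:
    # "data_cleaning_env" → base "data_cleaning" → "DataCleaning"
    base = env_name.removesuffix("_env")
    pascal = "".join(part.capitalize() for part in base.split("_"))
    return f"{pascal}Action", f"{pascal}Observation"
```

For this environment the inferred names would be DataCleaningAction / DataCleaningObservation, which is why the explicit action/observation overrides above are required.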
62
+ # ── Tags (informational) ──────────────────────────────────────────────────────
63
+
64
+ tags:
65
+ - data-cleaning
66
+ - tabular
67
+ - real-world
68
+ - hackathon
69
+
70
+ # ── Task manifest (informational) ─────────────────────────────────────────────
71
+ # One entry per task. These values mirror the constants in models.py
72
+ # (MAX_STEPS, DONE_THRESHOLD) and the descriptions in dataset_factory.py.
73
+
74
+ tasks:
75
+ - id: easy
76
+ name: Fix obvious errors
77
+ description: >-
78
+ 50-row sales CSV with 29 injected dirty cells: 10 type mismatches
79
+ (text in numeric columns), 8 missing values, 5 far-future dates
80
+ (year 2099), and 6 cells with leading/trailing whitespace.
81
+ Graded by exact cell-level match against the ground truth (0.0–1.0).
82
+ dataset_rows: 50
83
+ dirty_cells: 29
84
+ max_steps: 40
85
+ done_threshold: 0.95
86
+
87
+ - id: medium
88
+ name: Outlier detection without false positives
89
+ description: >-
90
+ 200-row customer transaction CSV with 15 true statistical outliers
91
+ (negative or > $2000 amounts) that must be fixed or removed, 5 valid
92
+ large transactions ($900–$2000) that must NOT be removed, and 12
93
+ category spelling typos. Graded by F1 score on outlier detection
94
+ (0.5 weight) and typo correction rate (0.5 weight).
95
+ dataset_rows: 200
96
+ dirty_cells: 27
97
+ max_steps: 80
98
+ done_threshold: 0.85
99
+
100
+ - id: hard
101
+ name: Multi-source schema normalisation and deduplication
102
+ description: >-
103
+ 430-row CSV (400 clean + 30 duplicates) merged from 3 fictional data
104
+ sources with inconsistent column naming (e.g. cust_id / customer_id /
105
+ CustomerID), mixed date formats (ISO, US, EU), and ~30 duplicate rows
106
+ (exact and near-duplicate). Agent must infer the canonical 9-column
107
+ schema without explicit documentation. Graded by schema match (40%),
108
+ deduplication F1 (35%), and date format compliance (25%).
109
+ dataset_rows: 430
110
+ canonical_rows: 400
111
+ canonical_columns: 9
112
+ duplicate_rows: 30
113
+ max_steps: 150
114
+ done_threshold: 0.80
115
+
116
+ # ── Reward function summary (informational) ───────────────────────────────────
117
+
118
+ reward:
119
+ type: dense
120
+ range: [-0.5, 1.0]
121
+ step_cost: -0.005
122
+ components:
123
+ - name: progress
124
+ weight: primary
125
+ description: >-
126
+ Grader score delta each step (curr_score − prev_score).
127
+ The main learning signal — any cell fixed produces a non-zero reward.
128
+
129
+ - name: efficiency_bonus
130
+ weight: "+0.10 × (1 − step_fraction)"
131
+ description: >-
132
+ Small bonus awarded the step the episode is solved (score crosses
133
+ done_threshold). Rewards finishing early relative to the step budget.
134
+
135
+ - name: false_positive_penalty
136
+ weight: -0.15
137
+ description: >-
138
+ Applied when DROP_ROW removes a valid-extreme row in the medium task.
139
+ Penalises aggressive deletion without checking schema_hint.
140
+
141
+ - name: early_done_penalty
142
+ weight: -0.20
143
+ description: >-
144
+ Applied when the agent sends DONE with current_score < 0.60.
145
+ Discourages giving up prematurely.
146
+
147
+ - name: step_cost
148
+ weight: -0.005
149
+ description: >-
150
+ Fixed cost every step regardless of outcome.
151
+ Prevents infinite loops and padding.
openenv_data_cleaning_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,13 @@
1
+ Metadata-Version: 2.4
2
+ Name: openenv-data_cleaning_env
3
+ Version: 0.1.0
4
+ Summary: Data Cleaning Env environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core
7
+ Requires-Dist: pandas>=2.0
8
+ Requires-Dist: numpy>=1.24
9
+ Requires-Dist: fastapi
10
+ Requires-Dist: uvicorn
11
+ Provides-Extra: dev
12
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
13
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_data_cleaning_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,21 @@
1
+ README.md
2
+ __init__.py
3
+ client.py
4
+ dataset_factory.py
5
+ graders.py
6
+ models.py
7
+ pyproject.toml
8
+ ./__init__.py
9
+ ./client.py
10
+ ./dataset_factory.py
11
+ ./graders.py
12
+ ./models.py
13
+ openenv_data_cleaning_env.egg-info/PKG-INFO
14
+ openenv_data_cleaning_env.egg-info/SOURCES.txt
15
+ openenv_data_cleaning_env.egg-info/dependency_links.txt
16
+ openenv_data_cleaning_env.egg-info/entry_points.txt
17
+ openenv_data_cleaning_env.egg-info/requires.txt
18
+ openenv_data_cleaning_env.egg-info/top_level.txt
19
+ server/__init__.py
20
+ server/app.py
21
+ server/data_cleaning_env.py
openenv_data_cleaning_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
openenv_data_cleaning_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ server = data_cleaning_env.server.app:main
openenv_data_cleaning_env.egg-info/requires.txt ADDED
@@ -0,0 +1,9 @@
1
+ openenv-core
2
+ pandas>=2.0
3
+ numpy>=1.24
4
+ fastapi
5
+ uvicorn
6
+
7
+ [dev]
8
+ pytest>=8.0.0
9
+ pytest-cov>=4.0.0
openenv_data_cleaning_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ data_cleaning_env
pyproject.toml ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-data_cleaning_env"
13
+ version = "0.1.0"
14
+ description = "Data Cleaning Env environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ # Environment-specific dependencies
21
+ # Add all dependencies needed for your environment here
22
+ # Examples:
23
+ # "torch>=2.0.0",
24
+ # "gymnasium>=0.29.0",
25
+ # "openspiel>=1.0.0",
26
+ # "smolagents>=1.22.0,<2",
27
+ "openenv-core",
28
+ "pandas>=2.0",
29
+ "numpy>=1.24",
30
+ "fastapi",
31
+ "uvicorn",
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ dev = [
36
+ "pytest>=8.0.0",
37
+ "pytest-cov>=4.0.0",
38
+ ]
39
+
40
+ [project.scripts]
41
+ # Server entry point - enables running via: uv run --project . server
42
+ # or: python -m data_cleaning_env.server.app
43
+ server = "data_cleaning_env.server.app:main"
44
+
45
+ [tool.setuptools]
46
+ include-package-data = true
47
+ packages = ["data_cleaning_env", "data_cleaning_env.server"]
48
+ package-dir = { "data_cleaning_env" = ".", "data_cleaning_env.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Data Cleaning Env environment server components."""
8
+
9
+ from .data_cleaning_env import DataCleaningEnvironment
10
+
11
+ __all__ = ["DataCleaningEnvironment"]
server/app.py ADDED
@@ -0,0 +1,25 @@
1
+ try:
2
+ from openenv.core.env_server import create_app
3
+ from ..models import CleanAction, CleanObservation
4
+ from .data_cleaning_env import DataCleaningEnvironment
5
+ except ImportError:
6
+ from openenv.core.env_server import create_app
7
+ from models import CleanAction, CleanObservation
8
+ from server.data_cleaning_env import DataCleaningEnvironment
9
+
10
+ app = create_app(
11
+ DataCleaningEnvironment, # class, not instance
12
+ CleanAction,
13
+ CleanObservation,
14
+ env_name="data_cleaning_env",
15
+ )
16
+
17
+
18
+ def main() -> None:
19
+ """Entry point for openenv serve / uv run / python -m."""
20
+ import uvicorn
21
+ uvicorn.run(app, host="0.0.0.0", port=8000)
22
+
23
+
24
+ if __name__ == "__main__":
25
+ main()
server/data_cleaning_env.py ADDED
@@ -0,0 +1,827 @@
1
+ """
2
+ server/data_cleaning_env.py
3
+ ---------------------------
4
+ DataCleaningEnvironment — the heart of the environment.
5
+
6
+ Implements the three abstract methods from openenv.core.env_server.interfaces.Environment:
7
+ reset(seed, episode_id, **kwargs) -> CleanObservation
8
+ step(action, timeout_s, **kwargs) -> CleanObservation
9
+ state (property) -> CleanState
10
+
11
+ Architecture
12
+ ------------
13
+ Live DataFrames (_dirty_df, _clean_df) live as instance variables for speed.
14
+ CleanState holds lightweight CSV snapshots used only for WebSocket state()
15
+ responses — not for every step. This avoids serialising a 400-row DataFrame
16
+ on every call.
17
+
18
+ Action dispatch
19
+ ---------------
20
+ Each CleanAction.command routes to a private _apply_* method that mutates
21
+ _dirty_df in place. Errors in those methods (bad column name, out-of-bounds
22
+ row) are caught and returned as (success=False, error_msg=...) so the agent
23
+ gets corrective feedback instead of a 500.
24
+
25
+ Reward
26
+ ------
27
+ compute_reward() implements the dense reward formula designed in the plan:
28
+ progress term — grader score delta (main signal)
29
+ efficiency bonus — small reward for early completion
30
+ false-positive penalty — for dropping a valid-extreme row (medium task)
31
+ early-DONE penalty — for calling DONE with a low score
32
+ step cost — -0.005 every step to discourage padding
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ import sys
38
+ import os
39
+ from typing import Any, Optional
40
+ from uuid import uuid4
41
+
42
+ import numpy as np
43
+ import pandas as pd
44
+
45
+ # ── OpenEnv imports (openenv-core must be installed) ─────────────────────────
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import EnvironmentMetadata
52
+
53
+ # ── Local imports (try relative → absolute for both server and standalone) ───
54
+ try:
55
+ from ..models import (
56
+ CleanAction, CleanObservation, CleanState,
57
+ MAX_STEPS, DONE_THRESHOLD,
58
+ )
59
+ from ..dataset_factory import make_dataset, TaskDataset
60
+ from ..graders import grade, GradeResult
61
+ except ImportError:
62
+ try:
63
+ from models import (
64
+ CleanAction, CleanObservation, CleanState,
65
+ MAX_STEPS, DONE_THRESHOLD,
66
+ )
67
+ from dataset_factory import make_dataset, TaskDataset
68
+ from graders import grade, GradeResult
69
+ except ImportError:
70
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
71
+ from models import (
72
+ CleanAction, CleanObservation, CleanState,
73
+ MAX_STEPS, DONE_THRESHOLD,
74
+ )
75
+ from dataset_factory import make_dataset, TaskDataset
76
+ from graders import grade, GradeResult
77
+
78
+
79
+ # ── Constants ─────────────────────────────────────────────────────────────────
80
+
81
+ # Per-step cost that discourages infinite loops / padding
82
+ STEP_COST = -0.005
83
+
84
+ # Penalty for calling DONE before the score is reasonable
85
+ EARLY_DONE_PENALTY = -0.20
86
+ EARLY_DONE_THRESHOLD = 0.60 # DONE below this score triggers the penalty
87
+
88
+ # Penalty for removing a valid-extreme row in the medium task
89
+ FALSE_POSITIVE_PENALTY = -0.15
90
+
91
+ # Efficiency bonus multiplier (only awarded when episode is solved)
92
+ EFFICIENCY_BONUS_WEIGHT = 0.10
93
+
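A minimal sketch of how the constants above might combine into the dense reward described in the module docstring (an illustrative shape only; the environment's own `_compute_reward` method below is authoritative and may differ in detail):

```python
STEP_COST = -0.005
EARLY_DONE_PENALTY = -0.20
EARLY_DONE_THRESHOLD = 0.60
FALSE_POSITIVE_PENALTY = -0.15
EFFICIENCY_BONUS_WEIGHT = 0.10

def compute_reward(prev_score: float, curr_score: float, step_fraction: float,
                   solved: bool, done_signal: bool, was_false_positive: bool) -> float:
    # Progress term (the main signal) plus a fixed per-step cost.
    reward = (curr_score - prev_score) + STEP_COST
    if solved:
        # One-off bonus for finishing early relative to the step budget.
        reward += EFFICIENCY_BONUS_WEIGHT * (1.0 - step_fraction)
    if was_false_positive:
        # Dropped a valid-extreme row (medium task).
        reward += FALSE_POSITIVE_PENALTY
    if done_signal and curr_score < EARLY_DONE_THRESHOLD:
        # Agent gave up with a low score.
        reward += EARLY_DONE_PENALTY
    return reward
```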
94
+ # Date formats the STANDARDIZE_COL handler will try, in priority order
95
+ _DATE_PARSE_FORMATS = [
96
+ "%Y-%m-%d", # ISO — most reliable, try first
97
+ "%m/%d/%Y", # US
98
+ "%d.%m.%Y", # EU
99
+ "%d/%m/%Y", # EU alt
100
+ "%Y/%m/%d", # Asian
101
+ ]
102
+
103
+
104
+ # ─────────────────────────────────────────────────────────────────────────────
105
+ # DataCleaningEnvironment
106
+ # ─────────────────────────────────────────────────────────────────────────────
107
+
108
+ class DataCleaningEnvironment(Environment):
109
+ """
110
+ Gym-style environment for the data cleaning pipeline task.
111
+
112
+ Each episode:
113
+ 1. reset(task_id="easy"|"medium"|"hard") loads a dirty/clean CSV pair.
114
+ 2. The agent calls step() repeatedly, each time sending a CleanAction.
115
+ 3. The episode ends when the agent sends DONE, the score crosses the
116
+ task threshold, or the step budget is exhausted.
117
+
118
+ The environment is fully stateless between sessions — all mutable state
119
+ lives in instance variables, so concurrent sessions each get their own
120
+ isolated copy (SUPPORTS_CONCURRENT_SESSIONS = True).
121
+ """
122
+
123
+ SUPPORTS_CONCURRENT_SESSIONS = True
124
+
125
+ def __init__(self) -> None:
126
+ super().__init__()
127
+
128
+ # Live DataFrames — mutated by each step()
129
+ self._dirty_df: Optional[pd.DataFrame] = None
130
+ self._clean_df: Optional[pd.DataFrame] = None
131
+
132
+ # Full task dataset from dataset_factory (holds metadata for grader)
133
+ self._dataset: Optional[TaskDataset] = None
134
+
135
+ # Pydantic state (lightweight; updated on demand)
136
+ self._state: Optional[CleanState] = None
137
+
138
+ # ─────────────────────────────────────────────────────────────────────────
139
+ # reset()
140
+ # ─────────────────────────────────────────────────────────────────────────
141
+
142
+ def reset(
143
+ self,
144
+ seed: Optional[int] = None,
145
+ episode_id: Optional[str] = None,
146
+ task_id: str = "easy",
147
+ **kwargs: Any,
148
+ ) -> CleanObservation:
149
+ """
150
+ Reset the environment for a new episode.
151
+
152
+ Parameters
153
+ ----------
154
+ seed
155
+ Ignored — datasets use fixed seeds per task for reproducibility.
156
+ episode_id
157
+ Optional; auto-generated if not provided.
158
+ task_id
159
+ Which task to load: "easy", "medium", or "hard".
160
+ """
161
+ if task_id not in MAX_STEPS:
162
+ raise ValueError(
163
+ f"Unknown task_id {task_id!r}. Must be one of: {list(MAX_STEPS)}"
164
+ )
165
+
166
+ # Load dataset (always deterministic via fixed seed in dataset_factory)
167
+ self._dataset = make_dataset(task_id)
168
+ self._dirty_df = self._dataset.dirty_df.copy(deep=True)
169
+ self._clean_df = self._dataset.clean_df.copy(deep=True)
170
+
171
+ max_steps = MAX_STEPS[task_id]
172
+
173
+ # Run grader on the initial dirty state so we have a starting score
174
+ initial_result = grade(
175
+ task_id=task_id,
176
+ agent_df=self._dirty_df,
177
+ clean_df=self._clean_df,
178
+ metadata=self._dataset.metadata,
179
+ initial_dirty_cells=self._dataset.total_dirty_cells,
180
+ )
181
+
182
+ self._state = CleanState(
183
+ episode_id=episode_id or str(uuid4()),
184
+ step_count=0,
185
+ task_id=task_id,
186
+ dirty_csv_snapshot=self._df_to_csv(self._dirty_df),
187
+ clean_csv_snapshot=self._df_to_csv(self._clean_df),
188
+ initial_dirty_cells=self._dataset.total_dirty_cells,
189
+ current_score=initial_result.score,
190
+ previous_score=0.0,
191
+ task_metadata=self._dataset.metadata,
192
+ schema_hint=self._dataset.schema_hint,
193
+ max_steps=max_steps,
194
+ )
195
+
196
+ return self._build_observation(
197
+ reward=None,
198
+ done=False,
199
+ last_action_success=True,
200
+ last_action_error=None,
201
+ grader_result=initial_result,
202
+ )
203
+
204
+ # ─────────────────────────────────────────────────────────────────────────
205
+ # step()
206
+ # ─────────────────────────────────────────────────────────────────────────
207
+
208
+ def step(
209
+ self,
210
+ action: CleanAction,
211
+ timeout_s: Optional[float] = None,
212
+ **kwargs: Any,
213
+ ) -> CleanObservation:
214
+ """
215
+ Apply one CleanAction and return the resulting observation.
216
+
217
+ Never raises for bad action inputs — instead returns
218
+ last_action_success=False with a descriptive error message so the
219
+ agent can self-correct on the next step.
220
+ """
221
+ if self._state is None or self._dirty_df is None:
222
+ raise RuntimeError("Environment not initialised. Call reset() first.")
223
+
224
+ self._state.step_count += 1
225
+
226
+ # ── Save previous score before mutating ──────────────────────────────
227
+ prev_score = self._state.current_score
228
+ self._state.previous_score = prev_score
229
+
230
+ # ── DONE shortcut ────────────────────────────────────────────────────
231
+ if action.command == "DONE":
232
+ reward = self._compute_reward(
233
+ action=action,
234
+ prev_score=prev_score,
235
+ curr_score=prev_score, # score doesn't change on DONE
236
+ action_success=True,
237
+ was_false_positive=False,
238
+ )
239
+ done = True
240
+ self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df)
241
+ return self._build_observation(
242
+ reward=reward,
243
+ done=done,
244
+ last_action_success=True,
245
+ last_action_error=None,
246
+ grader_result=GradeResult(
247
+ score=prev_score,
248
+ issues_remaining=self._state.initial_dirty_cells
249
+ - int(prev_score * self._state.initial_dirty_cells),
250
+ detail="Agent signalled DONE.",
251
+ ),
252
+ )
253
+
254
+ # ── Apply action to _dirty_df ────────────────────────────────────────
255
+ action_success, error_msg, was_false_positive = self._apply_action(action)
256
+
257
+ # ── Grade the result ──────────────────────────────────────────────────
258
+ grader_result = grade(
259
+ task_id=self._state.task_id,
260
+ agent_df=self._dirty_df,
261
+ clean_df=self._clean_df,
262
+ metadata=self._state.task_metadata,
263
+ initial_dirty_cells=self._state.initial_dirty_cells,
264
+ )
265
+ curr_score = grader_result.score
266
+ self._state.current_score = curr_score
267
+
268
+ # ── Compute reward ────────────────────────────────────────────────────
269
+ reward = self._compute_reward(
270
+ action=action,
271
+ prev_score=prev_score,
272
+ curr_score=curr_score,
273
+ action_success=action_success,
274
+ was_false_positive=was_false_positive,
275
+ )
276
+
277
+ # ── Check termination ────────────────────────────────────────────────
278
+ done = (
279
+ curr_score >= DONE_THRESHOLD[self._state.task_id]
280
+ or self._state.step_count >= self._state.max_steps
281
+ )
282
+
283
+ # ── Sync state snapshot ──────────────────────────────────────────────
284
+ self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df)
285
+
286
+ return self._build_observation(
287
+ reward=reward,
288
+ done=done,
289
+ last_action_success=action_success,
290
+ last_action_error=error_msg,
291
+ grader_result=grader_result,
292
+ )
293
+
294
+ # ─────────────────────────────────────────────────────────────────────────
295
+ # state (property)
296
+ # ─────────────────────────────────────────────────────────────────────────
297
+
298
+ @property
299
+ def state(self) -> CleanState:
300
+ """Return the current environment state (serialisable snapshot)."""
301
+ if self._state is None:
302
+ raise RuntimeError("Environment not initialised. Call reset() first.")
303
+ # Keep snapshot fresh in case step() was called without triggering a sync
304
+ if self._dirty_df is not None:
305
+ self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df)
306
+ return self._state
307
+
308
+ # ─────────────────────────────────────────────────────────────────────────
309
+ # Action dispatch
310
+ # ─────────────────────────────────────────────────────────────────────────
311
+
312
+ def _apply_action(
313
+ self, action: CleanAction
314
+ ) -> tuple[bool, Optional[str], bool]:
315
+ """
316
+ Mutate self._dirty_df according to the action.
317
+
318
+ Returns
319
+ -------
320
+ (success, error_msg, was_false_positive)
321
+ success — True if action applied without error
322
+ error_msg — human-readable description if success=False
323
+ was_false_positive — True if a DROP_ROW removed a valid-extreme row
324
+ """
325
+ cmd = action.command
326
+
327
+ if cmd == "SET_VALUE":
328
+ return self._apply_set_value(action)
329
+
330
+ elif cmd == "DROP_ROW":
331
+ return self._apply_drop_row(action)
332
+
333
+ elif cmd == "STANDARDIZE_COL":
334
+ return self._apply_standardize_col(action)
335
+
336
+ elif cmd == "FILL_MISSING":
337
+ return self._apply_fill_missing(action)
338
+
339
+ else:
340
+ return False, f"Unknown command: {cmd!r}", False
341
+
342
+ # ── SET_VALUE ─────────────────────────────────────────────────────────────
343
+
344
+ def _apply_set_value(
345
+ self, action: CleanAction
346
+ ) -> tuple[bool, Optional[str], bool]:
347
+ df = self._dirty_df
348
+ row_idx = action.row_index
349
+ col = action.column
350
+ val = action.value
351
+
352
+ # Validate column
353
+ if col not in df.columns:
354
+ return (
355
+ False,
356
+ f"Column {col!r} not found. Available: {list(df.columns)}",
357
+ False,
358
+ )
359
+
360
+ # Validate row index (positional)
361
+ if row_idx < 0 or row_idx >= len(df):
362
+ return (
363
+ False,
364
+ f"Row index {row_idx} out of range. DataFrame has {len(df)} rows (0–{len(df)-1}).",
365
+ False,
366
+ )
367
+
368
+ # Try to cast value to the column's expected type
369
+ cast_val, cast_err = self._cast_value(val, df, col)
370
+ if cast_err:
371
+ return False, cast_err, False
372
+
373
+ df.iloc[row_idx, df.columns.get_loc(col)] = cast_val
374
+ return True, None, False
375
+
376
+ # ── DROP_ROW ──────────────────────────────────────────────────────────────
377
+
378
+ def _apply_drop_row(
379
+ self, action: CleanAction
380
+ ) -> tuple[bool, Optional[str], bool]:
381
+ df = self._dirty_df
382
+ row_idx = action.row_index
383
+
384
+ if row_idx < 0 or row_idx >= len(df):
385
+ return (
386
+ False,
387
+ f"Row index {row_idx} out of range. DataFrame has {len(df)} rows.",
388
+ False,
389
+ )
390
+
391
+ # Detect false positive for medium task: is this a valid-extreme row?
392
+ was_false_positive = self._is_valid_extreme_row(row_idx)
393
+
394
+ # Drop the row and reset positional index so future iloc references stay valid
395
+ self._dirty_df = df.drop(df.index[row_idx]).reset_index(drop=True)
396
+ return True, None, was_false_positive
397
+
398
+ def _is_valid_extreme_row(self, iloc_idx: int) -> bool:
399
+ """
400
+ Return True if dropping this row would be a false positive.
401
+ Only applies to the medium task, which tracks valid_extreme_rows
402
+ by their original tx_id.
403
+ """
404
+ if self._state is None or self._state.task_id != "medium":
405
+ return False
406
+
407
+ valid_extreme_rows: list = self._state.task_metadata.get(
408
+ "valid_extreme_rows", []
409
+ )
410
+ if not valid_extreme_rows or self._clean_df is None:
411
+ return False
412
+
413
+ df = self._dirty_df
414
+ if "tx_id" not in df.columns:
415
+ return False
416
+
417
+ # Get the tx_id of the row being dropped
418
+ try:
419
+ tx_id_to_drop = int(df.iloc[iloc_idx]["tx_id"])
420
+ except (IndexError, ValueError, KeyError):
421
+ return False
422
+
423
+ # Check if any valid-extreme row in clean_df has this tx_id
424
+ for orig_idx in valid_extreme_rows:
425
+ if orig_idx >= len(self._clean_df):
426
+ continue
427
+ if int(self._clean_df.iloc[orig_idx]["tx_id"]) == tx_id_to_drop:
428
+ return True
429
+
430
+ return False
431
+
432
+ # ── STANDARDIZE_COL ───────────────────────────────────────────────────────
433
+
434
+ def _apply_standardize_col(
435
+ self, action: CleanAction
436
+ ) -> tuple[bool, Optional[str], bool]:
437
+ df = self._dirty_df
438
+ col = action.column
439
+
440
+ if col not in df.columns:
441
+ return (
442
+ False,
443
+ f"Column {col!r} not found. Available: {list(df.columns)}",
444
+ False,
445
+ )
446
+
447
+ series = df[col].copy()
448
+
449
+ # ── Try date normalisation first ──────────────────────────────────────
450
+ if self._looks_like_date_column(col, series):
451
+ normalised, err = self._normalise_dates(series)
452
+ if err:
453
+ return False, f"Date normalisation failed for column {col!r}: {err}", False
454
+ self._dirty_df[col] = normalised
455
+ return True, None, False
456
+
457
+ # ── Try numeric coercion ──────────────────────────────────────────────
458
+ if self._looks_like_numeric_column(col, series):
459
+ numeric = pd.to_numeric(series, errors="coerce")
460
+ # Only apply if we didn't lose more than 20% of non-null values
461
+ original_non_null = series.notna().sum()
462
+ coerced_non_null = numeric.notna().sum()
463
+ if original_non_null == 0 or coerced_non_null / original_non_null >= 0.8:
464
+ self._dirty_df[col] = numeric
465
+ return True, None, False
466
+
467
+ # ── String normalisation: strip whitespace ───────────────────────────
468
+ self._dirty_df[col] = series.apply(
469
+ lambda x: str(x).strip() if not _is_nan(x) else x
470
+ )
471
+ return True, None, False
472
+
473
+ def _looks_like_date_column(self, col: str, series: pd.Series) -> bool:
474
+ """Heuristic: column name contains 'date' or most non-null values parse as dates."""
475
+ if "date" in col.lower():
476
+ return True
477
+ sample = series.dropna().astype(str).head(5)
478
+ parsed = 0
479
+ for s in sample:
480
+ for fmt in _DATE_PARSE_FORMATS:
481
+ try:
482
+ pd.to_datetime(s, format=fmt)
483
+ parsed += 1
484
+ break
485
+ except Exception:
486
+ pass
487
+ return parsed >= max(1, len(sample) // 2)
488
+
489
+ def _looks_like_numeric_column(self, col: str, series: pd.Series) -> bool:
490
+ """Heuristic: column name or majority of values suggests numeric data."""
491
+ numeric_keywords = {"price", "amount", "value", "quantity", "qty", "count", "id", "num"}
492
+ if any(kw in col.lower() for kw in numeric_keywords):
493
+ return True
494
+ sample = series.dropna().head(10)
495
+ if len(sample) == 0:
496
+ return False
497
+ convertible = pd.to_numeric(sample, errors="coerce").notna().sum()
498
+ return convertible / len(sample) >= 0.7
499
+
500
+ def _normalise_dates(self, series: pd.Series) -> tuple[pd.Series, Optional[str]]:
501
+ """Parse dates in any supported format and reformat as YYYY-MM-DD."""
502
+ def _parse_one(x: Any) -> Any:
503
+ if _is_nan(x):
504
+ return x
505
+ s = str(x).strip()
506
+ for fmt in _DATE_PARSE_FORMATS:
507
+ try:
508
+ return pd.to_datetime(s, format=fmt).strftime("%Y-%m-%d")
509
+ except Exception:
510
+ pass
511
+ # Last resort: let pandas guess
512
+ try:
513
+ parsed = pd.to_datetime(s, dayfirst=False)
514
+ if 2000 <= parsed.year <= 2030:
515
+ return parsed.strftime("%Y-%m-%d")
516
+ except Exception:
517
+ pass
518
+ return x # leave unchanged if unparseable
519
+
520
+ return series.apply(_parse_one), None
521
+
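The per-value parser above can be exercised in isolation. A minimal sketch of the same try-each-format pattern (the format list here is an assumption for illustration; the real module defines `_DATE_PARSE_FORMATS` elsewhere):

```python
import pandas as pd

# Assumed format list; the module's _DATE_PARSE_FORMATS is defined elsewhere.
DATE_FORMATS = ["%Y-%m-%d", "%d/%m/%Y", "%m/%d/%Y", "%b %d, %Y"]

def to_iso(value: str) -> str:
    """Return YYYY-MM-DD if any known format matches, else the input unchanged."""
    for fmt in DATE_FORMATS:
        try:
            return pd.to_datetime(value, format=fmt).strftime("%Y-%m-%d")
        except (ValueError, TypeError):
            continue
    return value

print(to_iso("Mar 05, 2024"))  # 2024-03-05
print(to_iso("not a date"))    # left unchanged
```

One design consequence worth noting: an ambiguous string such as "03/04/2024" resolves to whichever format appears first in the list, so format ordering matters.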
522
+ # ── FILL_MISSING ──────────────────────────────────────────────────────────
523
+
524
+ def _apply_fill_missing(
525
+ self, action: CleanAction
526
+ ) -> tuple[bool, Optional[str], bool]:
527
+ df = self._dirty_df
528
+ col = action.column
529
+ strategy = action.fill_strategy
530
+
531
+ if col not in df.columns:
532
+ return (
533
+ False,
534
+ f"Column {col!r} not found. Available: {list(df.columns)}",
535
+ False,
536
+ )
537
+
538
+ series = df[col].copy()
539
+ numeric = pd.to_numeric(series, errors="coerce")
540
+ has_numeric = numeric.notna().sum() > 0
541
+
542
+ if strategy == "mean":
543
+ if not has_numeric:
544
+ return False, f"Cannot compute mean for non-numeric column {col!r}.", False
545
+ fill_val = numeric.mean()
546
+ self._dirty_df[col] = numeric.fillna(round(fill_val, 2))
547
+
548
+ elif strategy == "median":
549
+ if not has_numeric:
550
+ return False, f"Cannot compute median for non-numeric column {col!r}.", False
551
+ fill_val = numeric.median()
552
+ self._dirty_df[col] = numeric.fillna(round(fill_val, 2))
553
+
554
+ elif strategy == "mode":
555
+ mode_result = series.mode(dropna=True)
556
+ if mode_result.empty:
557
+ return False, f"No mode found for column {col!r} (all values missing?).", False
558
+ self._dirty_df[col] = series.fillna(mode_result.iloc[0])
559
+
560
+ elif strategy == "drop":
561
+ # Drop rows missing this column's value and reset the positional
562
+ # index so later iloc references stay valid.
563
+ self._dirty_df = self._dirty_df.dropna(subset=[col]).reset_index(drop=True)
564
+ return True, None, False
565
+
566
+ else:
567
+ return False, f"Unknown fill_strategy: {strategy!r}", False
568
+
569
+ return True, None, False
570
+
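The numeric fill strategies above reduce to a coerce-then-fill pattern that can be sketched standalone (the sample values are illustrative):

```python
import pandas as pd

s = pd.Series(["10", "oops", None, "30"])

# Coerce to numeric first: unparseable strings and None both become NaN...
numeric = pd.to_numeric(s, errors="coerce")

# ...then fill with a rounded central statistic, as the mean/median branches do.
filled = numeric.fillna(round(numeric.median(), 2))
print(filled.tolist())  # [10.0, 20.0, 20.0, 30.0]
```

Note that because the whole column is replaced by the coerced series, values that were merely unparseable ("oops") get filled too, not just truly missing ones; the environment's mean/median branches share this behaviour.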
571
+ # ─────────────────────────────────────────────────────────────────────────
572
+ # Reward computation
573
+ # ─────────────────────────────────────────────────────────────────────────
574
+
575
+ def _compute_reward(
576
+ self,
577
+ action: CleanAction,
578
+ prev_score: float,
579
+ curr_score: float,
580
+ action_success: bool,
581
+ was_false_positive: bool,
582
+ ) -> float:
583
+ """
584
+ Dense per-step reward in the range [-0.5, +1.0].
585
+
586
+ Components
587
+ ----------
588
+ progress score delta (main learning signal)
589
+ efficiency bonus small reward for solving with steps to spare
590
+ fp_penalty penalise removing a valid-extreme row (medium task)
591
+ early_done_penalty penalise calling DONE with a very low score
592
+ step_cost tiny constant cost to discourage padding
593
+ """
594
+ if self._state is None:
595
+ return 0.0
596
+
597
+ max_steps = self._state.max_steps
598
+ step_count = self._state.step_count
599
+
600
+ # 1. Progress term
601
+ progress = curr_score - prev_score
602
+
603
+ # 2. Efficiency bonus (only when task is solved this step)
604
+ threshold = DONE_THRESHOLD[self._state.task_id]
605
+ just_solved = prev_score < threshold <= curr_score
606
+ step_fraction = step_count / max_steps
607
+ efficiency = EFFICIENCY_BONUS_WEIGHT * (1.0 - step_fraction) if just_solved else 0.0
608
+
609
+ # 3. False-positive penalty
610
+ fp_penalty = FALSE_POSITIVE_PENALTY if was_false_positive else 0.0
611
+
612
+ # 4. Early-DONE penalty
613
+ early_done = (
614
+ EARLY_DONE_PENALTY
615
+ if action.command == "DONE" and curr_score < EARLY_DONE_THRESHOLD
616
+ else 0.0
617
+ )
618
+
619
+ # 5. Step cost
620
+ step_cost = STEP_COST
621
+
622
+ reward = progress + efficiency + fp_penalty + early_done + step_cost
623
+ return round(float(np.clip(reward, -0.5, 1.0)), 4)
624
+
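The reward components above combine additively and are clipped into [-0.5, +1.0]. A worked sketch of the progress and efficiency terms (the constants here are illustrative assumptions; the real `DONE_THRESHOLD`, `EFFICIENCY_BONUS_WEIGHT`, and `STEP_COST` values are defined elsewhere in the module):

```python
import numpy as np

# Illustrative constants only; the module defines the real values elsewhere.
EFFICIENCY_BONUS_WEIGHT = 0.2
STEP_COST = -0.01

def shaped_reward(prev_score, curr_score, step_count, max_steps, threshold=0.95):
    progress = curr_score - prev_score                  # main learning signal
    just_solved = prev_score < threshold <= curr_score  # crossed this step
    efficiency = (
        EFFICIENCY_BONUS_WEIGHT * (1.0 - step_count / max_steps)
        if just_solved else 0.0
    )
    return round(float(np.clip(progress + efficiency + STEP_COST, -0.5, 1.0)), 4)

# Crossing the threshold with 17 of 20 steps to spare earns an efficiency
# bonus on top of the raw score delta: 0.07 + 0.2*(1 - 3/20) - 0.01 = 0.23
print(shaped_reward(0.90, 0.97, step_count=3, max_steps=20))  # 0.23
```

An ordinary non-solving step only earns its score delta minus the step cost, which is what pushes the policy toward short, high-impact action sequences.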
625
+ # ─────────────────────────────────────────────────────────────────────────
626
+ # Observation builder
627
+ # ─────────────────────────────────────────────────────────────────────────
628
+
629
+ def _build_observation(
630
+ self,
631
+ reward: Optional[float],
632
+ done: bool,
633
+ last_action_success: bool,
634
+ last_action_error: Optional[str],
635
+ grader_result: GradeResult,
636
+ ) -> CleanObservation:
637
+ if self._state is None:
638
+ raise RuntimeError("State not initialised.")
639
+
640
+ return CleanObservation(
641
+ # Inherited from Observation base
642
+ done=done,
643
+ reward=reward,
644
+ # Task context
645
+ task_id=self._state.task_id,
646
+ schema_hint=self._state.schema_hint,
647
+ initial_dirty_cells=self._state.initial_dirty_cells,
648
+ # Per-step state
649
+ dirty_csv=self._df_to_csv(self._dirty_df),
650
+ current_score=grader_result.score,
651
+ issues_remaining=grader_result.issues_remaining,
652
+ step_number=self._state.step_count,
653
+ max_steps=self._state.max_steps,
654
+ # Last-action feedback
655
+ last_action_success=last_action_success,
656
+ last_action_error=last_action_error,
657
+ )
658
+
659
+ # ─────────────────────────────────────────────────────────────────────────
660
+ # Utilities
661
+ # ─────────────────────────────────────────────────────────────────────────
662
+
663
+ @staticmethod
664
+ def _df_to_csv(df: Optional[pd.DataFrame]) -> str:
665
+ """Serialise the DataFrame to CSV, writing the positional index as a `row_index` column."""
666
+ if df is None:
667
+ return ""
668
+ return df.to_csv(index=True, index_label="row_index")
669
+
670
+ @staticmethod
671
+ def _cast_value(
672
+ val: str, df: pd.DataFrame, col: str
673
+ ) -> tuple[Any, Optional[str]]:
674
+ """
675
+ Try to cast a string value to the appropriate type for `col`.
676
+
677
+ Returns (cast_value, error_message). error_message is None on success.
678
+ """
679
+ # Determine target type from the clean (non-null, non-text) column values
680
+ sample = pd.to_numeric(
681
+ df[col].dropna().astype(str).str.strip(), errors="coerce"
682
+ )
683
+ majority_numeric = sample.notna().sum() / max(len(df[col].dropna()), 1) >= 0.5
684
+
685
+ if majority_numeric:
686
+ try:
687
+ float_val = float(val.strip().replace(",", ""))
688
+ # If all sample values are whole numbers, keep as int
689
+ if (sample.dropna() % 1 == 0).all() and float_val % 1 == 0:
690
+ return int(float_val), None
691
+ return round(float_val, 2), None
692
+ except (ValueError, AttributeError):
693
+ return (
694
+ None,
695
+ f"Cannot cast {val!r} to numeric for column {col!r}. "
696
+ f"Provide a plain number (e.g. '29.99').",
697
+ )
698
+
699
+ # String column — accept as-is (strip whitespace)
700
+ return val.strip(), None
701
+
702
+ # ─────────────────────────────────────────────────────────────────────────
703
+ # Lifecycle
704
+ # ─────────────────────────────────────────────────────────────────────────
705
+
706
+ def close(self) -> None:
707
+ self._dirty_df = None
708
+ self._clean_df = None
709
+ self._dataset = None
710
+ self._state = None
711
+
712
+ def get_metadata(self) -> EnvironmentMetadata:
713
+ return EnvironmentMetadata(
714
+ name="data_cleaning_env",
715
+ description=(
716
+ "Data cleaning pipeline: the agent receives a dirty CSV "
717
+ "and must fix type errors, outliers, missing values, and "
718
+ "schema inconsistencies to match a hidden ground truth."
719
+ ),
720
+ version="1.0.0",
721
+ author="hackathon",
722
+ )
723
+
724
+
725
+ # ─────────────────────────────────────────────────────────────────────────────
726
+ # Helpers
727
+ # ─────────────────────────────────────────────────────────────────────────────
728
+
729
+ def _is_nan(x: Any) -> bool:
730
+ """Return True if x is any flavour of missing value."""
731
+ if x is None:
732
+ return True
733
+ try:
734
+ return bool(pd.isna(x))
735
+ except (TypeError, ValueError):
736
+ return False
737
+
738
+
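The try/except around `pd.isna` matters because list-like inputs make it return an array, and calling `bool()` on a multi-element array raises. A quick demonstration of the helper's contract (reproduced standalone):

```python
import numpy as np
import pandas as pd

def is_nan(x):
    """Standalone mirror of the module's _is_nan helper."""
    if x is None:
        return True
    try:
        return bool(pd.isna(x))
    except (TypeError, ValueError):
        return False

print(is_nan(None))           # True
print(is_nan(float("nan")))   # True
print(is_nan("text"))         # False
print(is_nan([1.0, np.nan]))  # False: pd.isna returns an array here, bool()
                              # raises ValueError, treated as "not missing"
```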
739
+ # ─────────────────────────────────────────────────────────────────────────────
740
+ # Smoke test
741
+ # ─────────────────────────────────────────────────────────────────────────────
742
+
743
+ if __name__ == "__main__":
744
+ SEP = "─" * 64
745
+
746
+ for task_id in ("easy", "medium", "hard"):
747
+ print(f"\n{SEP}\nTASK: {task_id.upper()}\n{SEP}")
748
+
749
+ env = DataCleaningEnvironment()
750
+
751
+ # ── reset ────────────────────────────────────────────────────────────
752
+ obs = env.reset(task_id=task_id)
753
+ print(f"reset() → score={obs.current_score:.4f} "
754
+ f"issues={obs.issues_remaining} done={obs.done}")
755
+ assert obs.reward is None, "reward must be None after reset"
756
+ assert obs.done is False, "done must be False after reset"
757
+
758
+ lines = obs.dirty_csv.strip().split("\n")
759
+ print(f" CSV: {len(lines)} rows, {len(lines[0].split(','))} cols")
760
+ print(f" Hint: {obs.schema_hint[:70]}…")
761
+
762
+ # ── state() ──────────────────────────────────────────────────────────
763
+ st = env.state
764
+ print(f"state() → episode_id={st.episode_id[:8]}… step_count={st.step_count}")
765
+
766
+ # ── step: bad column (should give feedback, not crash) ───────────────
767
+ bad_action = CleanAction(
768
+ command="SET_VALUE", row_index=0, column="DOES_NOT_EXIST", value="0"
769
+ )
770
+ obs2 = env.step(bad_action)
771
+ assert obs2.last_action_success is False
772
+ print(f"step (bad col) → success={obs2.last_action_success} "
773
+ f"error='{obs2.last_action_error[:50]}…'")
774
+
775
+ # ── step: out-of-bounds row ──────────────────────────────────────────
776
+ bad_row = CleanAction(
777
+ command="SET_VALUE", row_index=9999, column="price", value="10.0"
778
+ )
779
+ obs3 = env.step(bad_row)
780
+ assert obs3.last_action_success is False
781
+ print(f"step (bad row) → success={obs3.last_action_success} "
782
+ f"error='{obs3.last_action_error[:50]}…'")
783
+
784
+ # ── step: valid fix ──────────────────────────────────────────────────
785
+ if task_id == "easy":
786
+ # Find the first injected dirty cell and fix it
787
+ injected = env._dataset.metadata.get("injected_cells", [])
788
+ if injected:
789
+ row, col = injected[0]
790
+ clean_val = str(env._clean_df.iloc[row][col])
791
+ fix_action = CleanAction(
792
+ command="SET_VALUE", row_index=row, column=col, value=clean_val
793
+ )
794
+ obs4 = env.step(fix_action)
795
+ print(f"step (fix row={row} col={col!r}) → "
796
+ f"success={obs4.last_action_success} "
797
+ f"score={obs4.current_score:.4f} "
798
+ f"reward={obs4.reward:.4f}")
799
+ assert obs4.last_action_success is True
800
+ assert obs4.reward is not None
801
+
802
+ elif task_id == "medium":
803
+ # Fill missing 'amount' values with the column median
804
+ obs4 = env.step(CleanAction(
805
+ command="FILL_MISSING", column="amount", fill_strategy="median"
806
+ ))
807
+ print(f"step (FILL_MISSING amount/median) → "
808
+ f"score={obs4.current_score:.4f} reward={obs4.reward:.4f}")
809
+
810
+ elif task_id == "hard":
811
+ # Standardize the date column
812
+ obs4 = env.step(CleanAction(
813
+ command="STANDARDIZE_COL", column="date"
814
+ ))
815
+ print(f"step (STANDARDIZE_COL date) → "
816
+ f"success={obs4.last_action_success} "
817
+ f"score={obs4.current_score:.4f} reward={obs4.reward:.4f}")
818
+
819
+ # ── DONE action ───────────────────────────────────────────────────────
820
+ done_obs = env.step(CleanAction(command="DONE"))
821
+ assert done_obs.done is True
822
+ print(f"step (DONE) → done={done_obs.done} "
823
+ f"reward={done_obs.reward:.4f} score={done_obs.current_score:.4f}")
824
+
825
+ env.close()
826
+
827
+ print(f"\n{SEP}\nAll smoke tests passed.\n{SEP}")
server/requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ pandas>=2.0.0
5
+ numpy>=2.0.0
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
validate-submission.sh ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/env bash
2
+ #
3
+ # validate-submission.sh — OpenEnv Submission Validator
4
+ #
5
+ # Checks that your HF Space is live, Docker image builds, and openenv validate passes.
6
+ #
7
+ # Prerequisites:
8
+ # - Docker: https://docs.docker.com/get-docker/
9
+ # - openenv-core: pip install openenv-core
10
+ # - curl (usually pre-installed)
11
+ #
12
+ # Run:
13
+ # curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
14
+ #
15
+ # Or download and run locally:
16
+ # chmod +x validate-submission.sh
17
+ # ./validate-submission.sh <ping_url> [repo_dir]
18
+ #
19
+ # Arguments:
20
+ # ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
21
+ # repo_dir Path to your repo (default: current directory)
22
+ #
23
+ # Examples:
24
+ # ./validate-submission.sh https://my-team.hf.space
25
+ # ./validate-submission.sh https://my-team.hf.space ./my-repo
26
+ #
27
+
28
+ set -uo pipefail
29
+
30
+ DOCKER_BUILD_TIMEOUT=600
31
+ if [ -t 1 ]; then
32
+ RED='\033[0;31m'
33
+ GREEN='\033[0;32m'
34
+ YELLOW='\033[1;33m'
35
+ BOLD='\033[1m'
36
+ NC='\033[0m'
37
+ else
38
+ RED='' GREEN='' YELLOW='' BOLD='' NC=''
39
+ fi
40
+
41
+ run_with_timeout() {
42
+ local secs="$1"; shift
43
+ if command -v timeout &>/dev/null; then
44
+ timeout "$secs" "$@"
45
+ elif command -v gtimeout &>/dev/null; then
46
+ gtimeout "$secs" "$@"
47
+ else
48
+ "$@" &
49
+ local pid=$!
50
+ ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
51
+ local watcher=$!
52
+ wait "$pid" 2>/dev/null
53
+ local rc=$?
54
+ kill "$watcher" 2>/dev/null
55
+ wait "$watcher" 2>/dev/null
56
+ return $rc
57
+ fi
58
+ }
59
+
60
+ portable_mktemp() {
61
+ local prefix="${1:-validate}"
62
+ mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
63
+ }
64
+
65
+ CLEANUP_FILES=()
66
+ cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
67
+ trap cleanup EXIT
68
+
69
+ PING_URL="${1:-}"
70
+ REPO_DIR="${2:-.}"
71
+
72
+ if [ -z "$PING_URL" ]; then
73
+ printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
74
+ printf "\n"
75
+ printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
76
+ printf " repo_dir Path to your repo (default: current directory)\n"
77
+ exit 1
78
+ fi
79
+
80
+ if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
81
+ printf "Error: directory '%s' not found\n" "${2:-.}"
82
+ exit 1
83
+ fi
84
+ PING_URL="${PING_URL%/}"
85
+ export PING_URL
86
+ PASS=0
87
+
88
+ log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
89
+ pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
90
+ fail() { log "${RED}FAILED${NC} -- $1"; }
91
+ hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
92
+ stop_at() {
93
+ printf "\n"
94
+ printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
95
+ exit 1
96
+ }
97
+
98
+ printf "\n"
99
+ printf "${BOLD}========================================${NC}\n"
100
+ printf "${BOLD} OpenEnv Submission Validator${NC}\n"
101
+ printf "${BOLD}========================================${NC}\n"
102
+ log "Repo: $REPO_DIR"
103
+ log "Ping URL: $PING_URL"
104
+ printf "\n"
105
+
106
+ log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
107
+
108
+ CURL_OUTPUT=$(portable_mktemp "validate-curl")
109
+ CLEANUP_FILES+=("$CURL_OUTPUT")
110
+ HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
111
+ -H "Content-Type: application/json" -d '{}' \
112
+ "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
113
+
114
+ if [ "$HTTP_CODE" = "200" ]; then
115
+ pass "HF Space is live and responds to /reset"
116
+ elif [ "$HTTP_CODE" = "000" ]; then
117
+ fail "HF Space not reachable (connection failed or timed out)"
118
+ hint "Check your network connection and that the Space is running."
119
+ hint "Try: curl -s -o /dev/null -w '%{http_code}' -X POST $PING_URL/reset"
120
+ stop_at "Step 1"
121
+ else
122
+ fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
123
+ hint "Make sure your Space is running and the URL is correct."
124
+ hint "Try opening $PING_URL in your browser first."
125
+ stop_at "Step 1"
126
+ fi
127
+
128
+ log "${BOLD}Step 2/3: Running docker build${NC} ..."
129
+
130
+ if ! command -v docker &>/dev/null; then
131
+ fail "docker command not found"
132
+ hint "Install Docker: https://docs.docker.com/get-docker/"
133
+ stop_at "Step 2"
134
+ fi
135
+
136
+ if [ -f "$REPO_DIR/Dockerfile" ]; then
137
+ DOCKER_CONTEXT="$REPO_DIR"
138
+ elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
139
+ DOCKER_CONTEXT="$REPO_DIR/server"
140
+ else
141
+ fail "No Dockerfile found in repo root or server/ directory"
142
+ stop_at "Step 2"
143
+ fi
144
+
145
+ log " Found Dockerfile in $DOCKER_CONTEXT"
146
+
147
+ BUILD_OK=false
148
+ BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
149
+
150
+ if [ "$BUILD_OK" = true ]; then
151
+ pass "Docker build succeeded"
152
+ else
153
+ fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
154
+ printf "%s\n" "$BUILD_OUTPUT" | tail -20
155
+ stop_at "Step 2"
156
+ fi
157
+
158
+ log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
159
+
160
+ if ! command -v openenv &>/dev/null; then
161
+ fail "openenv command not found"
162
+ hint "Install it: pip install openenv-core"
163
+ stop_at "Step 3"
164
+ fi
165
+
166
+ VALIDATE_OK=false
167
+ VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
168
+
169
+ if [ "$VALIDATE_OK" = true ]; then
170
+ pass "openenv validate passed"
171
+ [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
172
+ else
173
+ fail "openenv validate failed"
174
+ printf "%s\n" "$VALIDATE_OUTPUT"
175
+ stop_at "Step 3"
176
+ fi
177
+
178
+ printf "\n"
179
+ printf "${BOLD}========================================${NC}\n"
180
+ printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
181
+ printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
182
+ printf "${BOLD}========================================${NC}\n"
183
+ printf "\n"
184
+
185
+ exit 0