Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +250 -5
- __init__.py +16 -0
- client.py +319 -0
- dataset_factory.py +550 -0
- graders.py +686 -0
- inference.py +271 -0
- models.py +463 -0
- openenv.yaml +151 -0
- openenv_data_cleaning_env.egg-info/PKG-INFO +13 -0
- openenv_data_cleaning_env.egg-info/SOURCES.txt +21 -0
- openenv_data_cleaning_env.egg-info/dependency_links.txt +1 -0
- openenv_data_cleaning_env.egg-info/entry_points.txt +2 -0
- openenv_data_cleaning_env.egg-info/requires.txt +9 -0
- openenv_data_cleaning_env.egg-info/top_level.txt +1 -0
- pyproject.toml +48 -0
- server/__init__.py +11 -0
- server/app.py +25 -0
- server/data_cleaning_env.py +827 -0
- server/requirements.txt +5 -0
- uv.lock +0 -0
- validate-submission.sh +185 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Multi-stage build using openenv-base
# This Dockerfile is flexible and works for both:
# - In-repo environments (with local OpenEnv sources)
# - Standalone environments (with openenv from PyPI/Git)
# The build script (openenv build) handles context detection and sets appropriate build args.

ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# Ensure git is available (required for installing dependencies from VCS)
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

# Build argument to control whether we're building standalone or in-repo
ARG BUILD_MODE=in-repo
ARG ENV_NAME=data_cleaning_env

# Copy environment code (always at root of build context)
COPY . /app/env

# For in-repo builds, openenv is already vendored in the build context
# For standalone builds, openenv will be installed via pyproject.toml
WORKDIR /app/env

# Ensure uv is available (for local builds where base image lacks it)
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Install dependencies using uv sync
# If uv.lock exists, use it; otherwise resolve on the fly
# First pass: dependencies only (--no-install-project) so the layer caches
# independently of project-source changes.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

# Second pass: install the project itself into the venv.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# Final runtime stage
FROM ${BASE_IMAGE}

WORKDIR /app

# Copy the virtual environment from builder
COPY --from=builder /app/env/.venv /app/.venv

# Copy the environment code
COPY --from=builder /app/env /app/env

# Set PATH to use the virtual environment
ENV PATH="/app/.venv/bin:$PATH"

# Set PYTHONPATH so imports work correctly
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# Health check
# NOTE(review): assumes `curl` is present in the runtime base image — verify,
# since git is the only package this file installs and only in the builder stage.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the FastAPI server
# The module path is constructed to work with the /app/env structure
ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,255 @@
|
|
| 1 |
---
|
| 2 |
-
title: Data Cleaning Env
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Data Cleaning Env Environment Server
|
| 3 |
+
emoji: 🎹
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Data Cleaning Env Environment
|
| 15 |
+
|
| 16 |
+
A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
The simplest way to use the Data Cleaning Env environment is through the `DataCleaningEnv` class:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from data_cleaning_env import CleanAction, DataCleaningEnv
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
# Create environment from Docker image
|
| 27 |
+
data_cleaning_envenv = DataCleaningEnv.from_docker_image("data_cleaning_env-env:latest")
|
| 28 |
+
|
| 29 |
+
# Reset
|
| 30 |
+
result = data_cleaning_envenv.reset()
|
| 31 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 32 |
+
|
| 33 |
+
# Send multiple messages
|
| 34 |
+
messages = ["Hello, World!", "Testing echo", "Final message"]
|
| 35 |
+
|
| 36 |
+
for msg in messages:
|
| 37 |
+
result = data_cleaning_envenv.step(CleanAction(message=msg))
|
| 38 |
+
print(f"Sent: '{msg}'")
|
| 39 |
+
print(f" → Echoed: '{result.observation.echoed_message}'")
|
| 40 |
+
print(f" → Length: {result.observation.message_length}")
|
| 41 |
+
print(f" → Reward: {result.reward}")
|
| 42 |
+
|
| 43 |
+
finally:
|
| 44 |
+
# Always clean up
|
| 45 |
+
data_cleaning_envenv.close()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
That's it! The `DataCleaningEnv.from_docker_image()` method handles:
|
| 49 |
+
- Starting the Docker container
|
| 50 |
+
- Waiting for the server to be ready
|
| 51 |
+
- Connecting to the environment
|
| 52 |
+
- Container cleanup when you call `close()`
|
| 53 |
+
|
| 54 |
+
## Building the Docker Image
|
| 55 |
+
|
| 56 |
+
Before using the environment, you need to build the Docker image:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# From project root
|
| 60 |
+
docker build -t data_cleaning_env-env:latest -f server/Dockerfile .
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Deploying to Hugging Face Spaces
|
| 64 |
+
|
| 65 |
+
You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
# From the environment directory (where openenv.yaml is located)
|
| 69 |
+
openenv push
|
| 70 |
+
|
| 71 |
+
# Or specify options
|
| 72 |
+
openenv push --namespace my-org --private
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
The `openenv push` command will:
|
| 76 |
+
1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
|
| 77 |
+
2. Prepare a custom build for Hugging Face Docker space (enables web interface)
|
| 78 |
+
3. Upload to Hugging Face (ensuring you're logged in)
|
| 79 |
+
|
| 80 |
+
### Prerequisites
|
| 81 |
+
|
| 82 |
+
- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
|
| 83 |
+
|
| 84 |
+
### Options
|
| 85 |
+
|
| 86 |
+
- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
|
| 87 |
+
- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
|
| 88 |
+
- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
|
| 89 |
+
- `--private`: Deploy the space as private (default: public)
|
| 90 |
+
|
| 91 |
+
### Examples
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
|
| 95 |
+
openenv push
|
| 96 |
+
|
| 97 |
+
# Push to a specific repository
|
| 98 |
+
openenv push --repo-id my-org/my-env
|
| 99 |
+
|
| 100 |
+
# Push with a custom base image
|
| 101 |
+
openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 102 |
+
|
| 103 |
+
# Push as a private space
|
| 104 |
+
openenv push --private
|
| 105 |
+
|
| 106 |
+
# Combine options
|
| 107 |
+
openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
After deployment, your space will be available at:
|
| 111 |
+
`https://huggingface.co/spaces/<repo-id>`
|
| 112 |
+
|
| 113 |
+
The deployed space includes:
|
| 114 |
+
- **Web Interface** at `/web` - Interactive UI for exploring the environment
|
| 115 |
+
- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
|
| 116 |
+
- **Health Check** at `/health` - Container health monitoring
|
| 117 |
+
- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
|
| 118 |
+
|
| 119 |
+
## Environment Details
|
| 120 |
+
|
| 121 |
+
### Action
|
| 122 |
+
**CleanAction**: Contains a single field
|
| 123 |
+
- `message` (str) - The message to echo back
|
| 124 |
+
|
| 125 |
+
### Observation
|
| 126 |
+
**CleanObservation**: Contains the echo response and metadata
|
| 127 |
+
- `echoed_message` (str) - The message echoed back
|
| 128 |
+
- `message_length` (int) - Length of the message
|
| 129 |
+
- `reward` (float) - Reward based on message length (length × 0.1)
|
| 130 |
+
- `done` (bool) - Always False for echo environment
|
| 131 |
+
- `metadata` (dict) - Additional info like step count
|
| 132 |
+
|
| 133 |
+
### Reward
|
| 134 |
+
The reward is calculated as: `message_length × 0.1`
|
| 135 |
+
- "Hi" → reward: 0.2
|
| 136 |
+
- "Hello, World!" → reward: 1.3
|
| 137 |
+
- Empty message → reward: 0.0
|
| 138 |
+
|
| 139 |
+
## Advanced Usage
|
| 140 |
+
|
| 141 |
+
### Connecting to an Existing Server
|
| 142 |
+
|
| 143 |
+
If you already have a Data Cleaning Env environment server running, you can connect directly:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from data_cleaning_env import DataCleaningEnv
|
| 147 |
+
|
| 148 |
+
# Connect to existing server
|
| 149 |
+
data_cleaning_envenv = DataCleaningEnv(base_url="<ENV_HTTP_URL_HERE>")
|
| 150 |
+
|
| 151 |
+
# Use as normal
|
| 152 |
+
result = data_cleaning_envenv.reset()
|
| 153 |
+
result = data_cleaning_envenv.step(CleanAction(message="Hello!"))
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Note: When connecting to an existing server, `data_cleaning_envenv.close()` will NOT stop the server.
|
| 157 |
+
|
| 158 |
+
### Using the Context Manager
|
| 159 |
+
|
| 160 |
+
The client supports context manager usage for automatic connection management:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from data_cleaning_env import CleanAction, DataCleaningEnv
|
| 164 |
+
|
| 165 |
+
# Connect with context manager (auto-connects and closes)
|
| 166 |
+
with DataCleaningEnv(base_url="http://localhost:8000") as env:
|
| 167 |
+
result = env.reset()
|
| 168 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 169 |
+
# Multiple steps with low latency
|
| 170 |
+
for msg in ["Hello", "World", "!"]:
|
| 171 |
+
result = env.step(CleanAction(message=msg))
|
| 172 |
+
print(f"Echoed: {result.observation.echoed_message}")
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
The client uses WebSocket connections for:
|
| 176 |
+
- **Lower latency**: No HTTP connection overhead per request
|
| 177 |
+
- **Persistent session**: Server maintains your environment state
|
| 178 |
+
- **Efficient for episodes**: Better for many sequential steps
|
| 179 |
+
|
| 180 |
+
### Concurrent WebSocket Sessions
|
| 181 |
+
|
| 182 |
+
The server supports multiple concurrent WebSocket connections. To enable this,
|
| 183 |
+
modify `server/app.py` to use factory mode:
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
# In server/app.py - use factory mode for concurrent sessions
|
| 187 |
+
app = create_app(
|
| 188 |
+
DataCleaningEnvironment, # Pass class, not instance
|
| 189 |
+
CleanAction,
|
| 190 |
+
CleanObservation,
|
| 191 |
+
max_concurrent_envs=4, # Allow 4 concurrent sessions
|
| 192 |
+
)
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Then multiple clients can connect simultaneously:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
from data_cleaning_env import CleanAction, DataCleaningEnv
|
| 199 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 200 |
+
|
| 201 |
+
def run_episode(client_id: int):
|
| 202 |
+
with DataCleaningEnv(base_url="http://localhost:8000") as env:
|
| 203 |
+
result = env.reset()
|
| 204 |
+
for i in range(10):
|
| 205 |
+
result = env.step(CleanAction(message=f"Client {client_id}, step {i}"))
|
| 206 |
+
return client_id, result.observation.message_length
|
| 207 |
+
|
| 208 |
+
# Run 4 episodes concurrently
|
| 209 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 210 |
+
results = list(executor.map(run_episode, range(4)))
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
## Development & Testing
|
| 214 |
+
|
| 215 |
+
### Direct Environment Testing
|
| 216 |
+
|
| 217 |
+
Test the environment logic directly without starting the HTTP server:
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
# From the server directory
|
| 221 |
+
python3 server/data_cleaning_env.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
This verifies that:
|
| 225 |
+
- Environment resets correctly
|
| 226 |
+
- Step executes actions properly
|
| 227 |
+
- State tracking works
|
| 228 |
+
- Rewards are calculated correctly
|
| 229 |
+
|
| 230 |
+
### Running Locally
|
| 231 |
+
|
| 232 |
+
Run the server locally for development:
|
| 233 |
+
|
| 234 |
+
```bash
|
| 235 |
+
uvicorn server.app:app --reload
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## Project Structure
|
| 239 |
+
|
| 240 |
+
```
|
| 241 |
+
data_cleaning_env/
|
| 242 |
+
├── .dockerignore # Docker build exclusions
|
| 243 |
+
├── __init__.py # Module exports
|
| 244 |
+
├── README.md # This file
|
| 245 |
+
├── openenv.yaml # OpenEnv manifest
|
| 246 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 247 |
+
├── uv.lock # Locked dependencies (generated)
|
| 248 |
+
├── client.py # DataCleaningEnv client
|
| 249 |
+
├── models.py # Action and Observation models
|
| 250 |
+
└── server/
|
| 251 |
+
├── __init__.py # Server module exports
|
| 252 |
+
├── data_cleaning_env.py # Core environment logic
|
| 253 |
+
├── app.py # FastAPI application (HTTP + WebSocket endpoints)
|
| 254 |
+
└── Dockerfile # Container image definition
|
| 255 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Data Cleaning Env Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import DataCleaningEnv
|
| 10 |
+
from .models import CleanAction, CleanObservation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"CleanAction",
|
| 14 |
+
"CleanObservation",
|
| 15 |
+
"DataCleaningEnv",
|
| 16 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
client.py
|
| 3 |
+
---------
|
| 4 |
+
DataCleaningEnv — the typed WebSocket client for the data cleaning pipeline.
|
| 5 |
+
|
| 6 |
+
This module contains exactly one public class: ``DataCleaningEnv``.
|
| 7 |
+
It extends ``EnvClient`` from OpenEnv core and implements the three abstract
|
| 8 |
+
translation methods that bridge Python objects and the server's JSON wire format:
|
| 9 |
+
|
| 10 |
+
_step_payload(action) CleanAction → dict (outbound)
|
| 11 |
+
_parse_result(payload) dict → StepResult[CleanObservation] (inbound)
|
| 12 |
+
_parse_state(payload) dict → CleanState (inbound)
|
| 13 |
+
|
| 14 |
+
Everything else — WebSocket lifecycle, connect/disconnect, async context
|
| 15 |
+
manager, the `.sync()` wrapper — is handled by the base class.
|
| 16 |
+
|
| 17 |
+
Usage (async)
|
| 18 |
+
-------------
|
| 19 |
+
import asyncio
|
| 20 |
+
from data_cleaning_env.client import DataCleaningEnv
|
| 21 |
+
from data_cleaning_env.models import CleanAction
|
| 22 |
+
|
| 23 |
+
async def main():
|
| 24 |
+
async with DataCleaningEnv(base_url="http://localhost:8000") as env:
|
| 25 |
+
result = await env.reset(task_id="easy")
|
| 26 |
+
print(result.observation.schema_hint)
|
| 27 |
+
|
| 28 |
+
result = await env.set_value(row_index=3, column="price", value="29.99")
|
| 29 |
+
print(result.reward, result.observation.current_score)
|
| 30 |
+
|
| 31 |
+
result = await env.done()
|
| 32 |
+
|
| 33 |
+
asyncio.run(main())
|
| 34 |
+
|
| 35 |
+
Usage (sync wrapper)
|
| 36 |
+
--------------------
|
| 37 |
+
env = DataCleaningEnv(base_url="http://localhost:8000").sync()
|
| 38 |
+
with env:
|
| 39 |
+
result = env.reset(task_id="medium")
|
| 40 |
+
result = env.fill_missing(column="amount", fill_strategy="median")
|
| 41 |
+
result = env.done()
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
from __future__ import annotations
|
| 45 |
+
|
| 46 |
+
from typing import Any, Optional
|
| 47 |
+
|
| 48 |
+
# ── OpenEnv core imports ──────────────────────────────────────────────────────
|
| 49 |
+
try:
|
| 50 |
+
from openenv.core.client_types import StepResult
|
| 51 |
+
from openenv.core.env_client import EnvClient
|
| 52 |
+
except ImportError:
|
| 53 |
+
from openenv.core.client_types import StepResult # type: ignore[no-redef]
|
| 54 |
+
from openenv.core.env_client import EnvClient # type: ignore[no-redef]
|
| 55 |
+
|
| 56 |
+
# ── Local model imports (try relative then absolute) ──────────────────────────
|
| 57 |
+
try:
|
| 58 |
+
from .models import (
|
| 59 |
+
CleanAction,
|
| 60 |
+
CleanObservation,
|
| 61 |
+
CleanState,
|
| 62 |
+
MAX_STEPS,
|
| 63 |
+
DONE_THRESHOLD,
|
| 64 |
+
)
|
| 65 |
+
except ImportError:
|
| 66 |
+
from models import ( # type: ignore[no-redef]
|
| 67 |
+
CleanAction,
|
| 68 |
+
CleanObservation,
|
| 69 |
+
CleanState,
|
| 70 |
+
MAX_STEPS,
|
| 71 |
+
DONE_THRESHOLD,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class DataCleaningEnv(EnvClient[CleanAction, CleanObservation, CleanState]):
    """
    Async WebSocket client for the Data Cleaning Pipeline environment.

    Connects to a running ``DataCleaningEnvironment`` server and exposes the
    standard OpenEnv interface (``reset``, ``step``, ``state``) plus typed
    convenience helpers for each command.

    All methods are async. For synchronous use, call ``.sync()`` to get a
    ``SyncEnvClient`` wrapper:

        with DataCleaningEnv(base_url="http://localhost:8000").sync() as env:
            result = env.reset(task_id="easy")
            result = env.set_value(row_index=0, column="price", value="9.99")

    Connecting to different backends
    ---------------------------------
    Local dev server (after ``openenv serve``):
        env = DataCleaningEnv(base_url="http://localhost:8000")

    Local Docker image (after ``openenv build``):
        env = await DataCleaningEnv.from_docker_image("data-cleaning-env:latest")

    Hugging Face Space (after ``openenv push``):
        env = await DataCleaningEnv.from_env("your-org/data-cleaning-env")
    """

    # ─────────────────────────────────────────────────────────────────────────
    # Abstract method implementations — the three translation methods
    # ─────────────────────────────────────────────────────────────────────────

    def _step_payload(self, action: CleanAction) -> dict[str, Any]:
        """
        Serialise a CleanAction to the JSON dict the server expects.

        The server's ``step()`` endpoint receives this dict, validates it
        against ``CleanAction``, and dispatches to the correct handler.

        We use ``model_dump(exclude_none=True)`` to omit fields the agent
        left as ``None`` — this keeps the wire message minimal and avoids
        triggering Pydantic's ``extra="forbid"`` validator on the server side
        for fields that weren't set.
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[CleanObservation]:
        """
        Parse the server's step/reset response into a ``StepResult``.

        Wire format (what the server sends back):
        ::
            {
                "observation": {
                    "done": false,
                    "reward": -0.005,
                    "metadata": {},
                    "task_id": "easy",
                    "schema_hint": "Sales orders...",
                    "initial_dirty_cells": 29,
                    "dirty_csv": "row_index,order_id,...\\n0,1001,...",
                    "current_score": 0.9550,
                    "issues_remaining": 18,
                    "step_number": 1,
                    "max_steps": 40,
                    "last_action_success": true,
                    "last_action_error": null
                },
                "reward": -0.005,
                "done": false
            }

        Note: ``reward`` and ``done`` appear both at the top level (for
        convenience) and inside ``observation`` (because ``Observation`` base
        carries them). We use the top-level copies for ``StepResult`` so the
        caller doesn't have to dig into the observation.

        NOTE(review): unlike ``_parse_state``, the required observation keys
        (``task_id``, ``schema_hint``, ``initial_dirty_cells``, ``dirty_csv``,
        ``max_steps``) are indexed directly, so a payload missing any of them
        raises ``KeyError`` here rather than silently defaulting — presumably
        intentional fail-fast on a malformed server response; confirm.
        """
        obs_data = payload.get("observation", {})

        observation = CleanObservation(
            # ── inherited from Observation base ──────────────────────────────
            done=payload.get("done", obs_data.get("done", False)),
            reward=payload.get("reward", obs_data.get("reward")),
            metadata=obs_data.get("metadata", {}),

            # ── task context (constant for the episode) ───────────────────────
            task_id=obs_data["task_id"],
            schema_hint=obs_data["schema_hint"],
            initial_dirty_cells=obs_data["initial_dirty_cells"],

            # ── per-step state ────────────────────────────────────────────────
            dirty_csv=obs_data["dirty_csv"],
            current_score=obs_data.get("current_score", 0.0),
            issues_remaining=obs_data.get("issues_remaining", 0),
            step_number=obs_data.get("step_number", 0),
            max_steps=obs_data["max_steps"],

            # ── last-action feedback ──────────────────────────────────────────
            last_action_success=obs_data.get("last_action_success", True),
            last_action_error=obs_data.get("last_action_error"),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: dict[str, Any]) -> CleanState:
        """
        Parse the server's state response into a ``CleanState``.

        The server serialises ``CleanState`` via Pydantic's ``model_dump()``,
        so the wire keys match our field names exactly. We use ``.get()``
        with sensible defaults everywhere so a partially-initialised state
        (e.g. before the first reset) doesn't crash the client.
        """
        return CleanState(
            # ── inherited from State base ─────────────────────────────────────
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),

            # ── task identity ─────────────────────────────────────────────────
            task_id=payload.get("task_id", "easy"),

            # ── DataFrame snapshots ───────────────────────────────────────────
            dirty_csv_snapshot=payload.get("dirty_csv_snapshot", ""),
            clean_csv_snapshot=payload.get("clean_csv_snapshot", ""),

            # ── scoring ───────────────────────────────────────────────────────
            initial_dirty_cells=payload.get("initial_dirty_cells", 0),
            current_score=payload.get("current_score", 0.0),
            previous_score=payload.get("previous_score", 0.0),

            # ── grader metadata ───────────────────────────────────────────────
            task_metadata=payload.get("task_metadata", {}),

            # ── schema ────────────────────────────────────────────────────────
            schema_hint=payload.get("schema_hint", ""),

            # ── step budget ───────────────────────────────────────────────────
            max_steps=payload.get("max_steps", 40),
        )

    # ─────────────────────────────────────────────────────────────────────────
    # Typed convenience helpers — one per CleanAction command
    # ─────────────────────────────────────────────────────────────────────────
    # These methods exist purely for ergonomics: they let callers write
    #
    #     await env.set_value(row_index=3, column="price", value="29.99")
    #
    # instead of the more verbose:
    #
    #     await env.step(CleanAction(
    #         command="SET_VALUE", row_index=3, column="price", value="29.99"
    #     ))
    #
    # The baseline inference script can use either form.

    async def set_value(
        self,
        row_index: int,
        column: str,
        value: str,
    ) -> StepResult[CleanObservation]:
        """Fix a single cell. ``value`` is always passed as a string; the
        server casts it to the column's target dtype automatically."""
        return await self.step(
            CleanAction(
                command="SET_VALUE",
                row_index=row_index,
                column=column,
                value=value,
            )
        )

    async def drop_row(self, row_index: int) -> StepResult[CleanObservation]:
        """Remove an entire row (e.g. a true outlier in the medium task)."""
        return await self.step(
            CleanAction(command="DROP_ROW", row_index=row_index)
        )

    async def standardize_col(self, column: str) -> StepResult[CleanObservation]:
        """Normalise a whole column's format.

        The server auto-detects what to do:
          - Date columns → parse any format, reformat as ``YYYY-MM-DD``
          - Numeric columns → coerce to float/int, drop unit strings
          - String columns → strip leading/trailing whitespace
        """
        return await self.step(
            CleanAction(command="STANDARDIZE_COL", column=column)
        )

    async def fill_missing(
        self,
        column: str,
        fill_strategy: str,
    ) -> StepResult[CleanObservation]:
        """Fill ``NaN`` values in ``column``.

        Args:
            column: Column name to fill.
            fill_strategy: One of ``"mean"``, ``"median"``, ``"mode"``, ``"drop"``.
                ``"drop"`` removes rows where the column is ``NaN``.
        """
        return await self.step(
            CleanAction(
                command="FILL_MISSING",
                column=column,
                fill_strategy=fill_strategy,
            )
        )

    async def done(self) -> StepResult[CleanObservation]:
        """Signal that the agent believes the CSV is clean.

        This ends the episode immediately. If the current score is below
        ``EARLY_DONE_THRESHOLD`` (0.60) a penalty of -0.20 is applied.
        """
        return await self.step(CleanAction(command="DONE"))

    # ─────────────────────────────────────────────────────────────────────────
    # Introspection helpers — each performs one round-trip via ``state()``
    # ─────────────────────────────────────────────────────────────────────────

    async def current_score(self) -> float:
        """Return the grader score from the last step (0.0–1.0)."""
        st = await self.state()
        return st.current_score

    async def task_id(self) -> str:
        """Return the active task ID (``"easy"``, ``"medium"``, or ``"hard"``)."""
        st = await self.state()
        return st.task_id

    async def steps_remaining(self) -> int:
        """Return the number of steps left before forced termination."""
        st = await self.state()
        return max(0, st.max_steps - st.step_count)

    async def is_solved(self) -> bool:
        """Return ``True`` if the current score meets the task's done threshold.

        Falls back to 0.95 for task IDs absent from ``DONE_THRESHOLD``.
        """
        st = await self.state()
        threshold = DONE_THRESHOLD.get(st.task_id, 0.95)
        return st.current_score >= threshold
|
dataset_factory.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
dataset_factory.py
|
| 3 |
+
------------------
|
| 4 |
+
Generates (dirty_df, clean_df, metadata) triples for all 3 tasks.
|
| 5 |
+
|
| 6 |
+
Key design decisions:
|
| 7 |
+
- Fixed random seeds per task → reproducible grader scores
|
| 8 |
+
- clean_df is ALWAYS generated first, then dirt is injected
|
| 9 |
+
- metadata carries ground-truth info the grader needs (e.g. which
|
| 10 |
+
rows are real outliers vs valid extremes in Task 2)
|
| 11 |
+
- No external files needed — everything is generated in memory
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import random
|
| 18 |
+
import string
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import pandas as pd
|
| 24 |
+
|
| 25 |
+
# ── Reproducible seeds ────────────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
# Per-task RNG seeds. Fixed values mean every call to make_dataset(task_id)
# regenerates byte-identical dirty/clean frames, so grader scores are
# reproducible across runs and machines.
SEEDS = {
    "easy": 42,
    "medium": 137,
    "hard": 999,
}
|
| 32 |
+
|
| 33 |
+
# ── Return type ───────────────────────────────────────────────────────────────
|
| 34 |
+
|
| 35 |
+
@dataclass
class TaskDataset:
    """Everything the environment and grader need for one episode."""
    task_id: str                    # "easy" | "medium" | "hard"
    dirty_df: pd.DataFrame          # frame the agent starts from (dirt injected)
    clean_df: pd.DataFrame          # ground-truth target frame
    schema_hint: str                # plain-English schema description
    total_dirty_cells: int          # how many cells differ at episode start
    metadata: dict[str, Any] = field(default_factory=dict)
    # metadata keys used by graders (as actually emitted by the _make_* builders):
    #   "injected_cells"     (Task 1) — list of dirtied (row, column) pairs
    #   "outlier_rows"       (Task 2) — row indices that ARE true outliers
    #   "valid_extreme_rows" (Task 2) — valid rows that look extreme but must stay
    #   "typo_cells"         (Task 2) — (row, dirty_val, clean_val) triples
    #   "canonical_columns"  (Task 3) — ordered list of canonical column names
    #   "canonical_lookup"   (Task 3) — {alias: canonical_name} mapping
    #   "duplicate_pairs"    (Task 3) — (original_idx, duplicate_idx) pairs,
    #                                   indices in the pre-shuffle dirty frame
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ── Public API ────────────────────────────────────────────────────────────────
|
| 52 |
+
|
| 53 |
+
def make_dataset(task_id: str) -> TaskDataset:
|
| 54 |
+
"""Entry point. Call this from the environment's reset()."""
|
| 55 |
+
if task_id == "easy":
|
| 56 |
+
return _make_easy()
|
| 57 |
+
elif task_id == "medium":
|
| 58 |
+
return _make_medium()
|
| 59 |
+
elif task_id == "hard":
|
| 60 |
+
return _make_hard()
|
| 61 |
+
else:
|
| 62 |
+
raise ValueError(f"Unknown task_id: {task_id!r}. Must be easy/medium/hard.")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def count_dirty_cells(dirty_df: pd.DataFrame, clean_df: pd.DataFrame) -> int:
|
| 66 |
+
"""Number of cells that differ between dirty and clean DataFrames."""
|
| 67 |
+
# Align on same dtypes for comparison
|
| 68 |
+
d = dirty_df.astype(str).reset_index(drop=True)
|
| 69 |
+
c = clean_df.astype(str).reset_index(drop=True)
|
| 70 |
+
return int((d != c).sum().sum())
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ── Task 1: easy ─────────────────────────────────────────────────────────────
|
| 74 |
+
#
|
| 75 |
+
# 50-row sales CSV.
|
| 76 |
+
# Clean schema:
|
| 77 |
+
# order_id (int), customer (str), product (str), category (str),
|
| 78 |
+
# price (float, 2dp), quantity (int), order_date (YYYY-MM-DD),
|
| 79 |
+
# region (str)
|
| 80 |
+
#
|
| 81 |
+
# Injected issues (29 dirty cells total):
|
| 82 |
+
# • 10 wrong-type cells — numeric column contains a word
|
| 83 |
+
# • 8 missing values — NaN in various columns
|
| 84 |
+
# • 5 bad dates — future year (2099-xx-xx)
|
| 85 |
+
# • 6 whitespace cells — leading/trailing spaces in string columns
|
| 86 |
+
|
| 87 |
+
def _make_easy() -> TaskDataset:
    """Build the Task 1 ("easy") dataset: a 50-row sales CSV.

    The clean frame is generated first, then 29 dirty cells are injected
    (10 wrong-type, 8 missing, 5 future dates, 6 whitespace). Both RNGs
    are seeded with SEEDS["easy"], so the exact same frames are produced
    on every call — do not reorder any RNG call below.
    """
    rng = random.Random(SEEDS["easy"])
    np_rng = np.random.default_rng(SEEDS["easy"])

    n = 50
    categories = ["Electronics", "Clothing", "Home", "Sports", "Books"]
    regions = ["North", "South", "East", "West"]
    products = ["Widget A", "Widget B", "Gadget X", "Gadget Y", "Item Z"]
    customers = [f"Customer_{i:03d}" for i in range(1, 31)]

    # ── Build clean DataFrame ────────────────────────────────────────────────
    clean = pd.DataFrame({
        "order_id": range(1001, 1001 + n),
        "customer": [rng.choice(customers) for _ in range(n)],
        "product": [rng.choice(products) for _ in range(n)],
        "category": [rng.choice(categories) for _ in range(n)],
        "price": np_rng.uniform(5.0, 500.0, n).round(2),
        "quantity": np_rng.integers(1, 20, n),
        "order_date": _random_dates(np_rng, n, "2023-01-01", "2024-06-30"),
        "region": [rng.choice(regions) for _ in range(n)],
    })
    clean["price"] = clean["price"].astype(float)
    clean["quantity"] = clean["quantity"].astype(int)

    # ── Inject dirt ──────────────────────────────────────────────────────────
    # object dtype lets arbitrary strings / NaN land in numeric columns.
    dirty = clean.copy(deep=True).astype(object)

    # (row, column) cells already dirtied — prevents double-injection.
    injected: set[tuple[int, str]] = set()

    def pick_fresh(col: str, exclude: set) -> int:
        # Pick a random row whose (row, col) cell has not been dirtied yet.
        rows = [r for r in range(n) if (r, col) not in exclude]
        return rng.choice(rows)

    # 10 wrong-type cells in numeric columns
    bad_words = ["N/A", "unknown", "missing", "null", "TBD", "??", "-", "n/a", "none", "—"]
    for word, col in zip(bad_words, rng.choices(["price", "quantity"], k=10)):
        row = pick_fresh(col, injected)
        dirty.at[row, col] = word
        injected.add((row, col))

    # 8 missing values in various columns
    missing_cols = rng.choices(["customer", "product", "price", "quantity", "region"], k=8)
    for col in missing_cols:
        row = pick_fresh(col, injected)
        dirty.at[row, col] = np.nan
        injected.add((row, col))

    # 5 bad dates — far-future year
    bad_date_templates = [
        "2099-01-15", "2099-07-04", "2099-12-31", "2099-03-22", "2099-11-11"
    ]
    for bad_date in bad_date_templates:
        row = pick_fresh("order_date", injected)
        dirty.at[row, "order_date"] = bad_date
        injected.add((row, "order_date"))

    # 6 whitespace cells in string columns
    ws_cols = rng.choices(["customer", "product", "category", "region"], k=6)
    for col in ws_cols:
        row = pick_fresh(col, injected)
        orig = str(dirty.at[row, col])
        dirty.at[row, col] = f" {orig} "
        injected.add((row, col))

    # Injected NaN stringifies to "nan" on the dirty side only, so every
    # injected cell (including missing values) counts as dirty here.
    dirty_cell_count = count_dirty_cells(dirty.astype(str), clean.astype(str))

    schema_hint = (
        "Sales orders dataset. Expected columns: "
        "order_id (integer), customer (string, no leading/trailing spaces), "
        "product (string, no spaces), category (one of: Electronics/Clothing/Home/Sports/Books), "
        "price (float, 2 decimal places, no text), "
        "quantity (integer, no text), "
        "order_date (YYYY-MM-DD format, year must be 2023 or 2024), "
        "region (one of: North/South/East/West, no spaces). "
        "No missing values allowed."
    )

    return TaskDataset(
        task_id="easy",
        dirty_df=dirty,
        clean_df=clean.astype(object),
        schema_hint=schema_hint,
        total_dirty_cells=dirty_cell_count,
        metadata={"injected_cells": list(injected)},
    )
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ── Task 2: medium ────────────────────────────────────────────────────────────
|
| 175 |
+
#
|
| 176 |
+
# 200-row customer transaction CSV.
|
| 177 |
+
# Clean schema:
|
| 178 |
+
# tx_id (int), customer_id (int), amount (float), tx_date (YYYY-MM-DD),
|
| 179 |
+
# category (str), country (str), status (str)
|
| 180 |
+
#
|
| 181 |
+
# Injected issues:
|
| 182 |
+
# • 15 statistical outliers — amount Z-score > 4.0 (should be removed/capped)
|
| 183 |
+
# • 5 valid extremes — genuinely large transactions, must NOT be removed
|
| 184 |
+
# • 12 category typos — slight misspellings
|
| 185 |
+
|
| 186 |
+
def _make_medium() -> TaskDataset:
    """Build the Task 2 ("medium") dataset: 200 customer transactions.

    Injects 15 true outlier amounts (to be removed/fixed), 5 valid
    extreme amounts (which must be kept), and 12 category typos. Seeded
    with SEEDS["medium"] — do not reorder any RNG call below, or the
    generated data (and grader ground truth) changes.
    """
    rng = random.Random(SEEDS["medium"])
    np_rng = np.random.default_rng(SEEDS["medium"])

    n = 200
    categories = ["Food", "Electronics", "Travel", "Healthcare", "Entertainment"]
    countries = ["US", "UK", "CA", "AU", "DE"]
    statuses = ["completed", "pending", "refunded"]

    # ── Build clean base ────────────────────────────────────────────────────
    # Normal transaction amounts: mean $150, sd $60, clipped to [5, 800]
    amounts = np_rng.normal(150, 60, n).clip(5, 800).round(2)

    clean = pd.DataFrame({
        "tx_id": range(9001, 9001 + n),
        "customer_id": np_rng.integers(1, 501, n),
        "amount": amounts,
        "tx_date": _random_dates(np_rng, n, "2023-01-01", "2024-06-30"),
        "category": [rng.choice(categories) for _ in range(n)],
        "country": [rng.choice(countries) for _ in range(n)],
        "status": [rng.choice(statuses) for _ in range(n)],
    })

    # ── Choose outlier rows (15) — will be injected with extreme amounts ─────
    all_rows = list(range(n))
    outlier_rows: list[int] = rng.sample(all_rows, 15)
    remaining = [r for r in all_rows if r not in outlier_rows]

    # ── Choose valid extreme rows (5) — large but legitimate ─────────────────
    # These are NOT in outlier_rows; amounts are large (Z > 3) but real
    valid_extreme_rows: list[int] = rng.sample(remaining, 5)

    # ── Build dirty DataFrame ────────────────────────────────────────────────
    dirty = clean.copy(deep=True).astype(object)

    # Inject true outliers: very high or very low (Z > 4). ~70% high, ~30% negative.
    for row in outlier_rows:
        if rng.random() > 0.3:
            dirty.at[row, "amount"] = round(rng.uniform(5000, 15000), 2)  # extreme high
        else:
            dirty.at[row, "amount"] = round(rng.uniform(-500, -10), 2)  # negative (impossible)

    # Inject valid extremes into clean AND dirty — identical in both, so
    # they do not count as dirty cells and must survive cleaning.
    for row in valid_extreme_rows:
        valid_large = round(rng.uniform(900, 2000), 2)
        clean.at[row, "amount"] = valid_large
        dirty.at[row, "amount"] = valid_large

    # Inject 12 category typos
    typo_map: dict[str, list[str]] = {
        "Electronics": ["Electrnics", "Electronis", "Electonics"],
        "Food": ["Foood", "Fod", "Fo0d"],
        "Travel": ["Travle", "Trevel", "Travell"],
        "Healthcare": ["Helthcare", "Healtcare", "Heathcare"],
        "Entertainment": ["Entertainmnt", "Entertainmet", "Entertainmen"],
    }
    injected_typo_rows: set[int] = set()
    typo_count = 0
    typo_cells: list[tuple[int, str, str]] = []  # (row, dirty_val, clean_val)

    # typo_map covers every category, so this loop reliably produces 12 typos.
    for row in rng.sample(remaining, min(12, len(remaining))):
        if typo_count >= 12:
            break
        if row in injected_typo_rows:
            continue
        orig_cat = str(clean.at[row, "category"])
        misspellings = typo_map.get(orig_cat)
        if misspellings:
            bad = rng.choice(misspellings)
            dirty.at[row, "category"] = bad
            typo_cells.append((row, bad, orig_cat))
            injected_typo_rows.add(row)
            typo_count += 1

    dirty_cell_count = count_dirty_cells(dirty.astype(str), clean.astype(str))

    schema_hint = (
        "Customer transactions dataset. Expected columns: "
        "tx_id (integer), customer_id (integer 1–500), "
        "amount (float, must be positive; realistic range is $5–$2000; "
        "amounts above $2000 or below $0 are data errors), "
        "tx_date (YYYY-MM-DD), "
        "category (one of: Food/Electronics/Travel/Healthcare/Entertainment — exact spelling), "
        "country (two-letter code: US/UK/CA/AU/DE), "
        "status (one of: completed/pending/refunded). "
        "Note: some large transactions ($900–$2000) are legitimate — do not remove them. "
        "Only remove rows where the amount is clearly erroneous (negative or > $2000)."
    )

    return TaskDataset(
        task_id="medium",
        dirty_df=dirty,
        clean_df=clean.astype(object),
        schema_hint=schema_hint,
        total_dirty_cells=dirty_cell_count,
        metadata={
            "outlier_rows": outlier_rows,
            "valid_extreme_rows": valid_extreme_rows,
            "typo_cells": typo_cells,  # [(row, dirty_val, clean_val)]
        },
    )
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# ── Task 3: hard ──────────────────────────────────────────────────────────────
|
| 290 |
+
#
|
| 291 |
+
# 400-row CSV merged from 3 fictional data sources.
|
| 292 |
+
# Each source uses different column names for the same concepts.
|
| 293 |
+
# Issues:
|
| 294 |
+
# • Inconsistent column naming (3 aliases per concept)
|
| 295 |
+
# • Mixed date formats across sources (ISO, US, EU)
|
| 296 |
+
# • 30 duplicate rows (exact and near-duplicate)
|
| 297 |
+
# • No schema documentation — agent must infer canonical form
|
| 298 |
+
#
|
| 299 |
+
# Canonical schema (what the agent must produce):
|
| 300 |
+
# record_id, customer_id, full_name, email, amount,
|
| 301 |
+
# currency, purchase_date (YYYY-MM-DD), product_name, region
|
| 302 |
+
|
| 303 |
+
# Canonical column order for the hard task's clean output.
_CANONICAL_COLS = [
    "record_id", "customer_id", "full_name", "email",
    "amount", "currency", "purchase_date", "product_name", "region",
]

# Column aliases per source: {canonical_name: source-specific column name}.
# Inverted in _make_hard to build the alias → canonical lookup for the grader.
_SOURCE_ALIASES = {
    "source_a": {
        "record_id": "record_id",
        "customer_id": "cust_id",
        "full_name": "name",
        "email": "email_address",
        "amount": "sale_amount",
        "currency": "ccy",
        "purchase_date":"date",
        "product_name": "item",
        "region": "territory",
    },
    "source_b": {
        "record_id": "id",
        "customer_id": "customer_id",
        "full_name": "full_name",
        "email": "contact_email",
        "amount": "value",
        "currency": "currency",
        "purchase_date":"purchase_date",
        "product_name": "product",
        "region": "area",
    },
    "source_c": {
        "record_id": "RecordID",
        "customer_id": "CustomerID",
        "full_name": "CustomerName",
        "email": "Email",
        "amount": "Amount",
        "currency": "Currency",
        "purchase_date":"PurchaseDate",
        "product_name": "ProductName",
        "region": "Region",
    },
}

# strftime/strptime date format used by each source's raw export.
_SOURCE_DATE_FORMATS = {
    "source_a": "%Y-%m-%d",  # ISO: 2023-04-15
    "source_b": "%m/%d/%Y",  # US: 04/15/2023
    "source_c": "%d.%m.%Y",  # EU: 15.04.2023
}
|
| 351 |
+
|
| 352 |
+
def _make_hard() -> TaskDataset:
    """Build the Task 3 ("hard") dataset: ~400 rows merged from 3 sources.

    Each source block uses its own column aliases and date format (see
    _SOURCE_ALIASES / _SOURCE_DATE_FORMATS); 30 duplicates (exact and
    near-duplicate) are appended and the result shuffled. Seeded with
    SEEDS["hard"] — do not reorder any RNG call, or the data changes.
    """
    rng = random.Random(SEEDS["hard"])
    np_rng = np.random.default_rng(SEEDS["hard"])

    currencies = ["USD", "EUR", "GBP"]
    regions = ["APAC", "EMEA", "AMER", "LATAM"]
    products = [
        "Pro Subscription", "Enterprise License", "Support Package",
        "Training Course", "Hardware Bundle", "Consulting Day",
    ]

    # Helper: generate a block of rows for one source, using that source's
    # column aliases and date format.
    def _source_block(source: str, n: int, id_start: int) -> pd.DataFrame:
        aliases = _SOURCE_ALIASES[source]
        date_fmt = _SOURCE_DATE_FORMATS[source]
        cust_ids = np_rng.integers(2001, 3001, n)
        amounts = np_rng.uniform(100, 5000, n).round(2)
        iso_dates = _random_dates(np_rng, n, "2022-01-01", "2024-06-30")

        # Format dates in source-specific format
        formatted_dates = [
            pd.to_datetime(d).strftime(date_fmt)
            for d in iso_dates
        ]

        names = [_random_name(rng) for _ in range(n)]
        emails = [_name_to_email(nm) for nm in names]

        data = {
            aliases["record_id"]: range(id_start, id_start + n),
            aliases["customer_id"]: cust_ids.tolist(),
            aliases["full_name"]: names,
            aliases["email"]: emails,
            aliases["amount"]: amounts.tolist(),
            aliases["currency"]: [rng.choice(currencies) for _ in range(n)],
            aliases["purchase_date"]: formatted_dates,
            aliases["product_name"]: [rng.choice(products) for _ in range(n)],
            aliases["region"]: [rng.choice(regions) for _ in range(n)],
        }
        return pd.DataFrame(data)

    # Three sources, ~133 rows each (total 400); record ids are contiguous
    # across blocks (1-134, 135-267, 268-400).
    block_a = _source_block("source_a", 134, id_start=1)
    block_b = _source_block("source_b", 133, id_start=135)
    block_c = _source_block("source_c", 133, id_start=268)

    # ── Canonical (clean) dataframe ─────────────────────────────────────────
    def _to_canonical(df: pd.DataFrame, source: str) -> pd.DataFrame:
        # Rename aliased columns back to canonical names.
        rev = {v: k for k, v in _SOURCE_ALIASES[source].items()}
        renamed = df.rename(columns=rev)
        # Normalise date to YYYY-MM-DD
        renamed["purchase_date"] = pd.to_datetime(
            renamed["purchase_date"],
            format=_SOURCE_DATE_FORMATS[source],
        ).dt.strftime("%Y-%m-%d")
        return renamed[_CANONICAL_COLS]

    clean_a = _to_canonical(block_a, "source_a")
    clean_b = _to_canonical(block_b, "source_b")
    clean_c = _to_canonical(block_c, "source_c")
    clean = pd.concat([clean_a, clean_b, clean_c], ignore_index=True)
    # Re-number 1..len(clean); matches the contiguous per-block ids above.
    clean["record_id"] = range(1, len(clean) + 1)

    # ── Dirty dataframe = concat of raw source blocks ────────────────────────
    # (columns are still in aliased form, dates in source-specific format;
    # each row has NaN in the other sources' columns)
    dirty = pd.concat([block_a, block_b, block_c], ignore_index=True)

    # ── Inject 30 duplicate rows ─────────────────────────────────────────────
    n_clean = len(dirty)
    sampled_orig = rng.sample(range(n_clean), 30)
    duplicate_rows_to_inject: list[pd.DataFrame] = []
    duplicate_pairs: list[tuple[int, int]] = []

    for orig_idx in sampled_orig:
        dup = dirty.iloc[[orig_idx]].copy()
        # Near-duplicate: 40% chance of a minor field change
        if rng.random() < 0.4:
            # Slightly alter the amount (±1%)
            # NOTE(review): col_amount is computed but never used below — dead local.
            col_amount = list(_SOURCE_ALIASES["source_a"].values())[4]  # 'sale_amount'
            # Find which column name is 'amount-like' in this row's source
            # Since we concat all sources, each row might have NaN in other sources' cols.
            # Simpler: just modify the raw value in the only non-null amount column.
            for amt_col in ["sale_amount", "value", "Amount"]:
                if amt_col in dup.columns and pd.notna(dup.iloc[0].get(amt_col)):
                    old_val = dup.at[dup.index[0], amt_col]
                    dup.at[dup.index[0], amt_col] = round(float(old_val) * rng.uniform(0.99, 1.01), 2)
                    break
        duplicate_rows_to_inject.append(dup)
        # Duplicate index is its position after the clean rows, pre-shuffle.
        duplicate_pairs.append((orig_idx, n_clean + len(duplicate_pairs)))

    dirty = pd.concat([dirty] + duplicate_rows_to_inject, ignore_index=True)

    # Shuffle so duplicates aren't obviously at the bottom
    # (this invalidates positional use of duplicate_pairs — they refer to the
    # pre-shuffle order, as noted in the metadata comment below).
    dirty = dirty.sample(frac=1, random_state=SEEDS["hard"]).reset_index(drop=True)

    # Build canonical alias lookup for grader
    canonical_lookup: dict[str, str] = {}
    for source, aliases in _SOURCE_ALIASES.items():
        for canonical, alias in aliases.items():
            canonical_lookup[alias] = canonical

    dirty_cell_count = len(dirty) * len(_CANONICAL_COLS)  # hard task: whole-df scope

    schema_hint = (
        "Merged dataset from 3 sources with inconsistent schemas. "
        "Your goal is to produce a single clean DataFrame with these canonical columns: "
        "record_id (integer, unique), customer_id (integer), full_name (string), "
        "email (string), amount (float), currency (one of: USD/EUR/GBP), "
        "purchase_date (YYYY-MM-DD), product_name (string), region (one of: APAC/EMEA/AMER/LATAM). "
        "Column names in the raw data vary by source (e.g. 'cust_id', 'customer_id', 'CustomerID' "
        "all mean customer_id). Date formats also vary (ISO, US MM/DD/YYYY, EU DD.MM.YYYY). "
        "There are also ~30 duplicate rows (some exact, some near-duplicate). "
        "Remove duplicates, normalise all column names and date formats."
    )

    return TaskDataset(
        task_id="hard",
        dirty_df=dirty,
        clean_df=clean.astype(object),
        schema_hint=schema_hint,
        total_dirty_cells=dirty_cell_count,
        metadata={
            "canonical_columns": _CANONICAL_COLS,
            "canonical_lookup": canonical_lookup,  # alias → canonical name
            "source_aliases": _SOURCE_ALIASES,
            "source_date_formats": _SOURCE_DATE_FORMATS,
            "duplicate_pairs": duplicate_pairs,  # (original_idx, dup_idx) in pre-shuffle dirty
            "n_clean_rows": len(clean),
        },
    )
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
# ── Internal helpers ──────────────────────────────────────────────────────────
|
| 485 |
+
|
| 486 |
+
def _random_dates(
|
| 487 |
+
rng: np.random.Generator,
|
| 488 |
+
n: int,
|
| 489 |
+
start: str,
|
| 490 |
+
end: str,
|
| 491 |
+
) -> list[str]:
|
| 492 |
+
"""Generate n random ISO-format date strings between start and end."""
|
| 493 |
+
start_ts = pd.Timestamp(start)
|
| 494 |
+
end_ts = pd.Timestamp(end)
|
| 495 |
+
delta_days = (end_ts - start_ts).days
|
| 496 |
+
offsets = rng.integers(0, delta_days, n)
|
| 497 |
+
return [
|
| 498 |
+
(start_ts + pd.Timedelta(days=int(d))).strftime("%Y-%m-%d")
|
| 499 |
+
for d in offsets
|
| 500 |
+
]
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
# Name pools for _random_name; sized so generated "First Last" names repeat
# (26 × 20 combinations), which is what makes near-duplicate rows plausible.
_FIRST_NAMES = [
    "Alice", "Bob", "Carol", "David", "Eva", "Frank", "Grace", "Henry",
    "Iris", "Jack", "Karen", "Leo", "Mia", "Nathan", "Olivia", "Paul",
    "Quinn", "Rosa", "Sam", "Tara", "Uma", "Victor", "Wendy", "Xavier",
    "Yuki", "Zara",
]

_LAST_NAMES = [
    "Smith", "Jones", "Williams", "Brown", "Taylor", "Davies", "Evans",
    "Wilson", "Thomas", "Roberts", "Johnson", "Lee", "Martin", "Garcia",
    "Martinez", "Anderson", "Thompson", "White", "Harris", "Clark",
]
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def _random_name(rng: random.Random) -> str:
    """Compose a random ``"First Last"`` name using *rng*.

    Draws the first name before the last name — callers rely on seeded
    generators, so the draw order is part of the contract.
    """
    first = rng.choice(_FIRST_NAMES)
    last = rng.choice(_LAST_NAMES)
    return " ".join((first, last))
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def _name_to_email(name: str) -> str:
|
| 522 |
+
first, last = name.lower().split()
|
| 523 |
+
domains = ["example.com", "mail.com", "inbox.net", "corp.io"]
|
| 524 |
+
return f"{first}.{last}@{domains[hash(name) % len(domains)]}"
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# ── Smoke test ────────────────────────────────────────────────────────────────
|
| 528 |
+
|
| 529 |
+
if __name__ == "__main__":
|
| 530 |
+
for task_id in ("easy", "medium", "hard"):
|
| 531 |
+
ds = make_dataset(task_id)
|
| 532 |
+
print(f"\n{'─'*60}")
|
| 533 |
+
print(f"Task: {task_id.upper()}")
|
| 534 |
+
print(f" dirty shape : {ds.dirty_df.shape}")
|
| 535 |
+
print(f" clean shape : {ds.clean_df.shape}")
|
| 536 |
+
print(f" dirty cells : {ds.total_dirty_cells}")
|
| 537 |
+
print(f" schema hint : {ds.schema_hint[:80]}…")
|
| 538 |
+
print(f" metadata keys: {list(ds.metadata.keys())}")
|
| 539 |
+
if task_id == "easy":
|
| 540 |
+
print(f"\n Sample dirty rows (price/quantity col):")
|
| 541 |
+
mask = ds.dirty_df["price"].astype(str).str.contains(
|
| 542 |
+
r"[a-zA-Z]|nan", na=True
|
| 543 |
+
)
|
| 544 |
+
print(ds.dirty_df[mask][["order_id","price","quantity"]].head(3).to_string(index=False))
|
| 545 |
+
if task_id == "medium":
|
| 546 |
+
print(f"\n Outlier rows (first 5): {ds.metadata['outlier_rows'][:5]}")
|
| 547 |
+
print(f" Valid extreme rows: {ds.metadata['valid_extreme_rows']}")
|
| 548 |
+
if task_id == "hard":
|
| 549 |
+
print(f"\n Raw column names: {list(ds.dirty_df.columns)}")
|
| 550 |
+
print(f" Duplicate pairs (first 3): {ds.metadata['duplicate_pairs'][:3]}")
|
graders.py
ADDED
|
@@ -0,0 +1,686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
graders.py
|
| 3 |
+
----------
|
| 4 |
+
Deterministic graders for all three tasks.
|
| 5 |
+
|
| 6 |
+
Each grader receives the agent's current working DataFrame and the
|
| 7 |
+
TaskDataset produced by dataset_factory, and returns a GradeResult
|
| 8 |
+
with a scalar score in [0.0, 1.0] plus a human-readable breakdown.
|
| 9 |
+
|
| 10 |
+
Public API
|
| 11 |
+
----------
|
| 12 |
+
grade(task_id, agent_df, dataset) -> GradeResult
|
| 13 |
+
|
| 14 |
+
Dispatches to the correct grader. Call this from step().
|
| 15 |
+
|
| 16 |
+
GradeResult
|
| 17 |
+
.score float 0.0–1.0 (the number that feeds the reward)
|
| 18 |
+
.breakdown dict (sub-scores, useful for logging/debugging)
|
| 19 |
+
.issues_remaining int (how many cells still need fixing)
|
| 20 |
+
.detail str (one-line human summary)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import re
|
| 26 |
+
from dataclasses import dataclass, field
|
| 27 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import pandas as pd
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 34 |
+
# Return type
|
| 35 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
@dataclass
class GradeResult:
    """Outcome of grading one agent DataFrame against ground truth.

    ``score`` is clamped into [0.0, 1.0] and rounded to 4 decimal places
    by ``__post_init__`` so grader arithmetic can never leak an
    out-of-range reward.
    """
    score: float                # 0.0 – 1.0, fed into reward
    # Per-grader sub-scores and counters, for logging/debugging.
    breakdown: Dict[str, float] = field(default_factory=dict)
    # How many cells/rows still need fixing (grader-specific proxy).
    issues_remaining: int = 0
    # One-line human-readable summary of the grade.
    detail: str = ""

    def __post_init__(self) -> None:
        # Clamp first, then round — guarantees the stored score is a
        # plain float inside [0, 1] regardless of what the grader passed.
        self.score = round(float(np.clip(self.score, 0.0, 1.0)), 4)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 49 |
+
# Public dispatcher
|
| 50 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 51 |
+
|
| 52 |
+
def grade(
    task_id: str,
    agent_df: pd.DataFrame,
    clean_df: pd.DataFrame,
    metadata: Dict[str, Any],
    initial_dirty_cells: int,
) -> GradeResult:
    """Route to the task-specific grader and return its GradeResult.

    Parameters
    ----------
    task_id
        One of "easy", "medium", "hard".
    agent_df
        The agent's current working DataFrame (may still be dirty).
    clean_df
        Ground-truth clean DataFrame from TaskDataset.
    metadata
        TaskDataset.metadata dict (grader-specific ground truth).
    initial_dirty_cells
        Dirty cell count at episode start; consumed by the easy/medium
        graders only.

    Raises
    ------
    ValueError
        If ``task_id`` is not one of the three known tasks.
    """
    # A missing or empty frame can never earn reward — short-circuit.
    if agent_df is None or len(agent_df) == 0:
        return GradeResult(score=0.0, detail="Empty DataFrame — no score.")

    if task_id == "easy":
        return _grade_easy(agent_df, clean_df, metadata, initial_dirty_cells)
    if task_id == "medium":
        return _grade_medium(agent_df, clean_df, metadata, initial_dirty_cells)
    if task_id == "hard":
        # The hard grader derives everything from metadata; it does not
        # use the initial dirty-cell count.
        return _grade_hard(agent_df, clean_df, metadata)
    raise ValueError(f"Unknown task_id: {task_id!r}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 90 |
+
# Task 1 — easy: cell-level match against ground truth
|
| 91 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 92 |
+
#
|
| 93 |
+
# Score = (cells matching ground truth) / (total cells)
|
| 94 |
+
#
|
| 95 |
+
# "Matching" is defined after normalisation:
|
| 96 |
+
# - strip leading/trailing whitespace
|
| 97 |
+
# - numeric columns: round to 2dp, compare as float strings
|
| 98 |
+
# - date column: accept YYYY-MM-DD only
|
| 99 |
+
# - string columns: case-sensitive exact match after strip
|
| 100 |
+
# - NaN vs NaN → always mismatch (agent must fill or fix them)
|
| 101 |
+
|
| 102 |
+
def _grade_easy(
    agent_df: pd.DataFrame,
    clean_df: pd.DataFrame,
    metadata: Dict[str, Any],
    initial_dirty_cells: int,
) -> GradeResult:
    """Cell-level grade for the easy task: matching cells / total cells.

    Both frames are canonicalised to strings via ``_normalise_easy``
    before comparison.  Row-count drift is penalised: rows the agent
    dropped count as all-wrong (sentinel padding), surplus rows are
    truncated.
    """
    agent_norm = _normalise_easy(agent_df, clean_df)
    clean_norm = _normalise_easy(clean_df, clean_df)

    total_cells = clean_norm.size
    n_clean = len(clean_norm)

    # Align row counts before the element-wise comparison.
    if len(agent_norm) < n_clean:
        n_missing = n_clean - len(agent_norm)
        filler = pd.DataFrame(
            [["__MISSING__"] * len(clean_norm.columns)] * n_missing,
            columns=clean_norm.columns,
        )
        agent_norm = pd.concat([agent_norm, filler], ignore_index=True)
    elif len(agent_norm) > n_clean:
        agent_norm = agent_norm.iloc[:n_clean].copy()

    # Single equality mask drives both the match and mismatch counts.
    eq_mask = agent_norm == clean_norm
    matches = eq_mask.sum().sum()
    mismatches = int((~eq_mask).sum().sum())
    score = matches / total_cells

    breakdown = {
        "cell_match_ratio": round(score, 4),
        "cells_matched": int(matches),
        "total_cells": int(total_cells),
        "cells_mismatched": mismatches,
    }

    detail = (
        f"{int(matches)}/{total_cells} cells correct "
        f"({100*score:.1f}%) — {mismatches} still need fixing."
    )

    return GradeResult(
        score=score,
        breakdown=breakdown,
        issues_remaining=mismatches,
        detail=detail,
    )
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _normalise_easy(df: pd.DataFrame, clean_df: pd.DataFrame) -> pd.DataFrame:
|
| 153 |
+
"""
|
| 154 |
+
Bring a DataFrame to a canonical string form for cell-level comparison.
|
| 155 |
+
|
| 156 |
+
Rules applied per column based on clean_df's dtype:
|
| 157 |
+
- Numeric (price, quantity): round to 2 decimal places → string
|
| 158 |
+
- Date (order_date): parse and reformat to YYYY-MM-DD
|
| 159 |
+
- String (all others): strip whitespace, leave case unchanged
|
| 160 |
+
- NaN / unparseable: normalise to the sentinel "__NAN__"
|
| 161 |
+
"""
|
| 162 |
+
out = {}
|
| 163 |
+
NUMERIC_COLS = {"price", "quantity"}
|
| 164 |
+
DATE_COLS = {"order_date"}
|
| 165 |
+
|
| 166 |
+
for col in clean_df.columns:
|
| 167 |
+
if col not in df.columns:
|
| 168 |
+
# Agent removed or renamed the column — all cells wrong
|
| 169 |
+
out[col] = pd.Series(["__MISSING_COL__"] * len(df))
|
| 170 |
+
continue
|
| 171 |
+
|
| 172 |
+
series = df[col].copy()
|
| 173 |
+
|
| 174 |
+
if col in NUMERIC_COLS:
|
| 175 |
+
out[col] = series.apply(_to_numeric_str)
|
| 176 |
+
elif col in DATE_COLS:
|
| 177 |
+
out[col] = series.apply(_to_date_str)
|
| 178 |
+
else:
|
| 179 |
+
out[col] = series.apply(
|
| 180 |
+
lambda x: "__NAN__" if _is_missing(x) else str(x).strip()
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
return pd.DataFrame(out, dtype=str)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _to_numeric_str(x: Any) -> str:
    """Canonical 2-decimal string for a numeric cell; sentinel otherwise.

    Missing values map to "__NAN__"; anything float() cannot parse
    (after stripping whitespace and thousands commas) to "__INVALID__".
    """
    if _is_missing(x):
        return "__NAN__"
    cleaned = str(x).strip().replace(",", "")
    try:
        value = float(cleaned)
    except (ValueError, TypeError):
        return "__INVALID__"
    return f"{value:.2f}"
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _to_date_str(x: Any) -> str:
    """Canonicalise a date cell to 'YYYY-MM-DD', or a sentinel string.

    Returns "__NAN__" for missing values, "__BAD_DATE__" for dates that
    parse but fall outside the plausible 2000–2030 window, and
    "__INVALID_DATE__" for anything pandas cannot handle.  The broad
    except deliberately wraps the year check and strftime too —
    presumably so NaT results (e.g. from junk like "nan") that fail
    attribute access or strftime also degrade to "__INVALID_DATE__";
    TODO confirm the NaT path.
    """
    if _is_missing(x):
        return "__NAN__"
    s = str(x).strip()
    # Reject obviously wrong dates (e.g. year 2099)
    try:
        parsed = pd.to_datetime(s, dayfirst=False)
        if parsed.year > 2030 or parsed.year < 2000:
            return "__BAD_DATE__"
        return parsed.strftime("%Y-%m-%d")
    except Exception:
        return "__INVALID_DATE__"
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _is_missing(x: Any) -> bool:
|
| 210 |
+
if x is None:
|
| 211 |
+
return True
|
| 212 |
+
try:
|
| 213 |
+
return bool(pd.isna(x))
|
| 214 |
+
except (TypeError, ValueError):
|
| 215 |
+
return False
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 219 |
+
# Task 2 — medium: F1 on outlier detection + typo correction
|
| 220 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 221 |
+
#
|
| 222 |
+
# Two independent sub-scores, equally weighted:
|
| 223 |
+
#
|
| 224 |
+
# outlier_f1 — precision/recall on which rows were fixed or removed
|
| 225 |
+
# typo_score — fraction of category typo-cells correctly fixed
|
| 226 |
+
#
|
| 227 |
+
# Final score = 0.50 * outlier_f1 + 0.50 * typo_score
|
| 228 |
+
#
|
| 229 |
+
# Outlier logic:
|
| 230 |
+
# A true-outlier row is "correctly handled" if:
|
| 231 |
+
# (a) the row still exists AND amount is now in [5, 800], OR
|
| 232 |
+
# (b) the row was dropped entirely
|
| 233 |
+
# A valid-extreme row is a "false positive" if it was dropped OR
|
| 234 |
+
# its amount was changed to something outside [900, 2000].
|
| 235 |
+
#
|
| 236 |
+
# The thresholds match the schema_hint the agent was given.
|
| 237 |
+
|
| 238 |
+
# Amount thresholds for the medium task, matching the schema_hint shown to
# the agent: normal transactions fall in [5, 800]; legitimate "extreme"
# rows occupy [900, 2000] and must not be treated as outliers.
_VALID_AMOUNT_MIN = 5.0
_VALID_AMOUNT_MAX = 800.0
_EXTREME_AMOUNT_MIN = 900.0
_EXTREME_AMOUNT_MAX = 2000.0
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _grade_medium(
    agent_df: pd.DataFrame,
    clean_df: pd.DataFrame,
    metadata: Dict[str, Any],
    initial_dirty_cells: int,
) -> GradeResult:
    """Grade the medium task: outlier-handling F1 + category-typo fixes.

    Score = 0.50 * outlier_f1 + 0.50 * typo_score.  Rows are tracked by
    the stable, unique ``tx_id`` key rather than positional index, since
    the agent may drop rows and/or reset the index.
    """
    outlier_rows: List[int] = metadata.get("outlier_rows", [])
    valid_extreme_rows: List[int] = metadata.get("valid_extreme_rows", [])
    typo_cells: List[Tuple[int, str, str]] = metadata.get("typo_cells", [])

    # BUGFIX/robustness: the previous code re-ran agent_df["tx_id"].astype(int)
    # for every lookup, which raises if the agent introduced NaNs or
    # non-numeric junk into tx_id and would crash the grader mid-episode.
    # Coerce once up front; unparseable ids degrade to NaN (never match).
    if "tx_id" in agent_df.columns:
        agent_tx = pd.to_numeric(agent_df["tx_id"], errors="coerce")
        agent_tx_ids = set(agent_tx.dropna().astype(int).tolist())
    else:
        agent_tx = None
        agent_tx_ids = set()

    def _rows_for(tx_id_val: int) -> pd.DataFrame:
        """All agent rows whose tx_id equals tx_id_val (possibly empty)."""
        if agent_tx is None:
            return pd.DataFrame()
        return agent_df[agent_tx == tx_id_val]

    tp = 0  # outlier rows that were correctly handled
    fn = 0  # outlier rows still wrong (extreme amount still present)
    fp = 0  # valid-extreme rows wrongly removed or damaged

    # ── True-positive check ──────────────────────────────────────────────
    # Each injected outlier row must be dropped or have its amount brought
    # back into the valid [5, 800] band.
    for orig_idx in outlier_rows:
        if orig_idx >= len(clean_df):
            continue
        tx_id_val = int(clean_df.iloc[orig_idx]["tx_id"])
        if tx_id_val not in agent_tx_ids:
            # Row was dropped — counts as correctly handled (outlier removed)
            tp += 1
            continue
        agent_row = _rows_for(tx_id_val)
        if len(agent_row) == 0:
            tp += 1  # dropped after all
        else:
            amt = _safe_float(agent_row.iloc[0].get("amount"))
            if amt is not None and _VALID_AMOUNT_MIN <= amt <= _VALID_AMOUNT_MAX:
                tp += 1
            else:
                fn += 1

    # ── False-positive check ─────────────────────────────────────────────
    # Valid-extreme rows must survive with roughly their original amount
    # (±5%); dropping or distorting them is penalised.
    for orig_idx in valid_extreme_rows:
        if orig_idx >= len(clean_df):
            continue
        tx_id_val = int(clean_df.iloc[orig_idx]["tx_id"])
        clean_amt = float(clean_df.iloc[orig_idx]["amount"])

        if tx_id_val not in agent_tx_ids:
            fp += 1  # wrongly dropped a valid row
            continue
        agent_row = _rows_for(tx_id_val)
        if len(agent_row) == 0:
            fp += 1
        else:
            amt = _safe_float(agent_row.iloc[0].get("amount"))
            if amt is None or not (clean_amt * 0.95 <= amt <= clean_amt * 1.05):
                fp += 1

    n_outliers = len(outlier_rows)
    # Epsilon terms guard against division by zero when lists are empty.
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (n_outliers + 1e-9)
    outlier_f1 = (2 * precision * recall) / (precision + recall + 1e-9)

    # ── Typo sub-score ───────────────────────────────────────────────────
    # Fraction of injected category typos now equal to the clean value.
    # Dropped rows earn neither credit nor penalty here.
    typo_correct = 0
    for (row_idx, dirty_val, clean_val) in typo_cells:
        if "tx_id" not in clean_df.columns or row_idx >= len(clean_df):
            continue
        tx_id_val = int(clean_df.iloc[row_idx]["tx_id"])
        agent_rows = _rows_for(tx_id_val)
        if len(agent_rows) == 0:
            continue  # row dropped; neither credit nor penalty
        agent_cat = str(agent_rows.iloc[0].get("category", "")).strip()
        if agent_cat == clean_val:
            typo_correct += 1

    typo_score = typo_correct / max(len(typo_cells), 1)

    # ── Combined score ───────────────────────────────────────────────────
    score = 0.50 * outlier_f1 + 0.50 * typo_score

    # Approximate issues remaining: unsolved outliers + unsolved typos
    issues_remaining = fn + (len(typo_cells) - typo_correct)

    breakdown = {
        "outlier_f1": round(outlier_f1, 4),
        "outlier_tp": tp,
        "outlier_fn": fn,
        "outlier_fp": fp,
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "typo_score": round(typo_score, 4),
        "typos_fixed": typo_correct,
        "typos_total": len(typo_cells),
        "combined": round(score, 4),
    }

    detail = (
        f"Outlier F1={outlier_f1:.3f} (TP={tp}, FP={fp}, FN={fn}) | "
        f"Typos {typo_correct}/{len(typo_cells)} fixed → score={score:.3f}"
    )

    return GradeResult(
        score=score,
        breakdown=breakdown,
        issues_remaining=issues_remaining,
        detail=detail,
    )
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _safe_float(x: Any) -> Optional[float]:
    """Best-effort float parse: None for missing or unparseable values.

    Thousands separators (commas) and surrounding whitespace are
    stripped before conversion.
    """
    if _is_missing(x):
        return None
    text = str(x).strip().replace(",", "")
    try:
        return float(text)
    except (ValueError, TypeError):
        return None
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 369 |
+
# Task 3 — hard: schema normalisation + deduplication + date formatting
|
| 370 |
+
# ──────────────────��──────────────────────────────────────────────────────────
|
| 371 |
+
#
|
| 372 |
+
# Three independent sub-scores:
|
| 373 |
+
#
|
| 374 |
+
# schema_score (weight 0.40)
|
| 375 |
+
# Fraction of canonical column names present in agent_df.
|
| 376 |
+
# Bonus: all 9 canonical columns present AND no extra columns → +0.1
|
| 377 |
+
#
|
| 378 |
+
# dedup_score (weight 0.35)
|
| 379 |
+
# How many of the 30 true duplicate tx records were removed.
|
| 380 |
+
# Penalises over-deletion (removing rows that were not duplicates).
|
| 381 |
+
# dedup_precision = removed_true_dups / (rows_removed + ε)
|
| 382 |
+
# dedup_recall = removed_true_dups / n_duplicate_pairs
|
| 383 |
+
# dedup_f1 = harmonic mean
|
| 384 |
+
#
|
| 385 |
+
# format_score (weight 0.25)
|
| 386 |
+
# Fraction of values in the purchase_date column (or canonical alias)
|
| 387 |
+
# that are valid YYYY-MM-DD strings.
|
| 388 |
+
#
|
| 389 |
+
# Final score = 0.40 * schema_score + 0.35 * dedup_score + 0.25 * format_score
|
| 390 |
+
|
| 391 |
+
# Target schema for the hard task: the 9 canonical column names the agent
# must normalise the raw/aliased headers to.
_CANONICAL_COLS = [
    "record_id", "customer_id", "full_name", "email",
    "amount", "currency", "purchase_date", "product_name", "region",
]

# Strict ISO calendar-date shape (YYYY-MM-DD); year plausibility is a
# separate check layered on top by the date-format grader.
_ISO_DATE_PATTERN = re.compile(r"^\d{4}-\d{2}-\d{2}$")
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def _grade_hard(
    agent_df: pd.DataFrame,
    clean_df: pd.DataFrame,
    metadata: Dict[str, Any],
) -> GradeResult:
    """Grade the hard task: weighted schema + dedup + date-format scores.

    Final score = 0.40 * schema + 0.35 * dedup + 0.25 * date-format.
    """
    canonical_lookup: Dict[str, str] = metadata.get("canonical_lookup", {})
    n_clean_rows: int = metadata.get("n_clean_rows", len(clean_df))

    # Three independent sub-graders.
    schema_score, schema_detail = _grade_schema(agent_df, canonical_lookup)
    dedup_score, dedup_detail = _grade_deduplication(
        agent_df, clean_df, n_clean_rows, canonical_lookup
    )
    format_score, format_detail = _grade_date_format(agent_df, canonical_lookup)

    # Fixed weighting of the three sub-scores.
    score = 0.40 * schema_score + 0.35 * dedup_score + 0.25 * format_score

    # Rough proxy for outstanding work: canonical columns still missing
    # plus rows beyond the expected deduplicated count.
    missing_cols = sum(
        1 for c in _CANONICAL_COLS if c not in agent_df.columns
    )
    excess_rows = max(0, len(agent_df) - n_clean_rows)
    issues_remaining = missing_cols + excess_rows

    breakdown = {
        "schema_score": round(schema_score, 4),
        "dedup_score": round(dedup_score, 4),
        "format_score": round(format_score, 4),
        "combined": round(score, 4),
        **{f"schema_{k}": v for k, v in schema_detail.items()},
        **{f"dedup_{k}": v for k, v in dedup_detail.items()},
        **{f"fmt_{k}": v for k, v in format_detail.items()},
    }

    detail = (
        f"Schema={schema_score:.3f} | "
        f"Dedup={dedup_score:.3f} | "
        f"DateFmt={format_score:.3f} → score={score:.3f}"
    )

    return GradeResult(
        score=score,
        breakdown=breakdown,
        issues_remaining=issues_remaining,
        detail=detail,
    )
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def _grade_schema(
    agent_df: pd.DataFrame,
    canonical_lookup: Dict[str, str],
) -> Tuple[float, Dict[str, Any]]:
    """
    Score how well the agent normalised column names.

    - Partial credit per canonical column present under its canonical name.
    - +0.10 bonus when ALL 9 canonical columns are present AND no
      non-canonical (alias/extra) columns remain; capped at 1.0.

    ``canonical_lookup`` (alias → canonical) stays in the signature for
    interface stability, but this scorer only needs the canonical names
    themselves.  (The previous version built an unused alias-union set
    from it; that dead local has been removed.)
    """
    agent_cols = set(agent_df.columns)
    canonical_set = set(_CANONICAL_COLS)

    # Count canonical columns present.
    found = [c for c in _CANONICAL_COLS if c in agent_cols]
    n_found = len(found)
    base = n_found / len(_CANONICAL_COLS)

    # Bonus: fully renamed schema with nothing extra left behind.
    leftover_aliases = [c for c in agent_cols if c not in canonical_set]
    all_present = n_found == len(_CANONICAL_COLS)
    clean_rename = len(leftover_aliases) == 0

    bonus = 0.10 if (all_present and clean_rename) else 0.0

    score = min(1.0, base + bonus)

    detail: Dict[str, Any] = {
        "canonical_found": n_found,
        "canonical_total": len(_CANONICAL_COLS),
        "leftover_aliases": len(leftover_aliases),
        "rename_bonus": bonus,
    }
    return score, detail
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def _grade_deduplication(
|
| 498 |
+
agent_df: pd.DataFrame,
|
| 499 |
+
clean_df: pd.DataFrame,
|
| 500 |
+
n_clean_rows: int,
|
| 501 |
+
canonical_lookup: Dict[str, str],
|
| 502 |
+
) -> Tuple[float, Dict[str, Any]]:
|
| 503 |
+
"""
|
| 504 |
+
Score how well the agent removed duplicate rows.
|
| 505 |
+
|
| 506 |
+
We compare row counts and detect near-duplicate detection quality:
|
| 507 |
+
- n_injected_dups: 30 (hardcoded from dataset_factory)
|
| 508 |
+
- expected_final_rows: n_clean_rows (400)
|
| 509 |
+
- rows_removed: (raw dirty rows = 430) - len(agent_df)
|
| 510 |
+
- true_dups_removed: min(rows_removed, 30) if rows_removed ≤ 35
|
| 511 |
+
(we're lenient — removing 1–35 rows likely targets dups)
|
| 512 |
+
- over_deletion: max(0, rows_removed - 30) rows beyond the dup count
|
| 513 |
+
penalises removing valid data.
|
| 514 |
+
|
| 515 |
+
Precision = true_dups_removed / (rows_removed + ε)
|
| 516 |
+
Recall = true_dups_removed / 30
|
| 517 |
+
F1 = harmonic mean
|
| 518 |
+
"""
|
| 519 |
+
N_INJECTED_DUPS = 30
|
| 520 |
+
N_DIRTY_ROWS = n_clean_rows + N_INJECTED_DUPS # 430
|
| 521 |
+
|
| 522 |
+
rows_removed = max(0, N_DIRTY_ROWS - len(agent_df))
|
| 523 |
+
|
| 524 |
+
# Heuristic: any removal ≤ 35 rows is probably targeting dups
|
| 525 |
+
true_dups_removed = min(rows_removed, N_INJECTED_DUPS)
|
| 526 |
+
|
| 527 |
+
# Penalise over-removal (agent deleted valid rows beyond dups)
|
| 528 |
+
over_deletion = max(0, rows_removed - N_INJECTED_DUPS)
|
| 529 |
+
# Each over-deleted row reduces precision
|
| 530 |
+
effective_true = max(0, true_dups_removed - over_deletion)
|
| 531 |
+
|
| 532 |
+
precision = effective_true / (rows_removed + 1e-9)
|
| 533 |
+
recall = true_dups_removed / (N_INJECTED_DUPS + 1e-9)
|
| 534 |
+
f1 = (2 * precision * recall) / (precision + recall + 1e-9)
|
| 535 |
+
|
| 536 |
+
detail: Dict[str, Any] = {
|
| 537 |
+
"rows_removed": rows_removed,
|
| 538 |
+
"true_dups_removed": true_dups_removed,
|
| 539 |
+
"over_deletion": over_deletion,
|
| 540 |
+
"precision": round(precision, 4),
|
| 541 |
+
"recall": round(recall, 4),
|
| 542 |
+
"f1": round(f1, 4),
|
| 543 |
+
}
|
| 544 |
+
return f1, detail
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def _grade_date_format(
|
| 548 |
+
agent_df: pd.DataFrame,
|
| 549 |
+
canonical_lookup: Dict[str, str],
|
| 550 |
+
) -> Tuple[float, Dict[str, Any]]:
|
| 551 |
+
"""
|
| 552 |
+
Fraction of purchase_date values matching YYYY-MM-DD.
|
| 553 |
+
|
| 554 |
+
Looks for the canonical name "purchase_date" first; falls back to
|
| 555 |
+
known aliases ("date", "PurchaseDate") if the agent hasn't renamed yet.
|
| 556 |
+
"""
|
| 557 |
+
DATE_ALIASES = {"purchase_date", "date", "PurchaseDate"}
|
| 558 |
+
|
| 559 |
+
date_col = None
|
| 560 |
+
# Prefer canonical name
|
| 561 |
+
if "purchase_date" in agent_df.columns:
|
| 562 |
+
date_col = "purchase_date"
|
| 563 |
+
else:
|
| 564 |
+
for alias in DATE_ALIASES:
|
| 565 |
+
if alias in agent_df.columns:
|
| 566 |
+
date_col = alias
|
| 567 |
+
break
|
| 568 |
+
|
| 569 |
+
if date_col is None:
|
| 570 |
+
return 0.0, {"date_col_found": False, "valid_ratio": 0.0}
|
| 571 |
+
|
| 572 |
+
# Guard: duplicate column names after rename produce a DataFrame, not Series.
|
| 573 |
+
# Take the first occurrence.
|
| 574 |
+
col_data = agent_df[date_col]
|
| 575 |
+
if isinstance(col_data, pd.DataFrame):
|
| 576 |
+
col_data = col_data.iloc[:, 0]
|
| 577 |
+
|
| 578 |
+
# Force object dtype so .sum() always returns a numeric 0, not '' (the
|
| 579 |
+
# StringDtype identity). Python 3.14 + pandas 2.2+ infer StringDtype
|
| 580 |
+
# from .astype(str), which makes .sum() on an empty Series return ''.
|
| 581 |
+
series = col_data.dropna().astype(object).apply(str).str.strip()
|
| 582 |
+
n_total = len(series)
|
| 583 |
+
if n_total == 0:
|
| 584 |
+
return 0.0, {"date_col_found": True, "valid_ratio": 0.0, "n_total": 0}
|
| 585 |
+
|
| 586 |
+
# Combined check: ISO pattern match AND year in plausible range
|
| 587 |
+
def _is_valid_iso(s: str) -> bool:
|
| 588 |
+
if not _ISO_DATE_PATTERN.match(s):
|
| 589 |
+
return False
|
| 590 |
+
try:
|
| 591 |
+
return 2000 <= int(s[:4]) <= 2030
|
| 592 |
+
except Exception:
|
| 593 |
+
return False
|
| 594 |
+
|
| 595 |
+
valid_flags = series.apply(_is_valid_iso)
|
| 596 |
+
n_valid = int(valid_flags.sum()) # int() guards against numpy/pandas scalar types
|
| 597 |
+
n_year_ok = n_valid # same condition — kept for breakdown detail
|
| 598 |
+
valid_ratio = n_year_ok / n_total
|
| 599 |
+
|
| 600 |
+
detail: Dict[str, Any] = {
|
| 601 |
+
"date_col_found": True,
|
| 602 |
+
"date_col_used": date_col,
|
| 603 |
+
"n_total": int(n_total),
|
| 604 |
+
"n_valid_iso": int(n_valid),
|
| 605 |
+
"n_year_ok": int(n_year_ok),
|
| 606 |
+
"valid_ratio": round(valid_ratio, 4),
|
| 607 |
+
}
|
| 608 |
+
return valid_ratio, detail
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 612 |
+
# Smoke test
|
| 613 |
+
# ────────────────────────────────────────────────────────────────���────────────
|
| 614 |
+
|
| 615 |
+
# Manual smoke test: for each task, grade the dirty frame (expect a low
# score), the clean frame (expect ~1.0), and a simulated partial fix, then
# print the score breakdowns. Requires dataset_factory.py alongside.
if __name__ == "__main__":
    import sys
    sys.path.insert(0, ".")
    from dataset_factory import make_dataset

    SEP = "─" * 62

    # ── Task 1: easy ─────────────────────────────────────────────────────
    print(f"\n{SEP}\nTASK: easy\n{SEP}")
    ds = make_dataset("easy")

    # Baseline: grade dirty df (should be low)
    r_dirty = grade("easy", ds.dirty_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[dirty] score={r_dirty.score:.4f} {r_dirty.detail}")

    # Perfect: grade clean df (should be 1.0)
    r_clean = grade("easy", ds.clean_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[clean] score={r_clean.score:.4f} {r_clean.detail}")

    # Partial: fix half the injected cells
    partial = ds.dirty_df.copy()
    injected = ds.metadata.get("injected_cells", [])
    for (row, col) in injected[:len(injected)//2]:
        partial.at[row, col] = ds.clean_df.at[row, col]
    r_partial = grade("easy", partial, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[half] score={r_partial.score:.4f} {r_partial.detail}")

    print(f"Breakdown: {r_partial.breakdown}")

    # ── Task 2: medium ────────────────────────────────────────────────────
    print(f"\n{SEP}\nTASK: medium\n{SEP}")
    ds = make_dataset("medium")

    r_dirty = grade("medium", ds.dirty_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[dirty] score={r_dirty.score:.4f} {r_dirty.detail}")

    r_clean = grade("medium", ds.clean_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[clean] score={r_clean.score:.4f} {r_clean.detail}")

    # Simulate agent fixing all outliers (set amount to 150.0) + all typos
    fixed = ds.dirty_df.copy()
    for row in ds.metadata["outlier_rows"]:
        if "tx_id" in ds.clean_df.columns:
            fixed.at[row, "amount"] = 150.0
    for (row, dirty_val, clean_val) in ds.metadata["typo_cells"]:
        fixed.at[row, "category"] = clean_val
    r_fixed = grade("medium", fixed, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[fixed] score={r_fixed.score:.4f} {r_fixed.detail}")

    print(f"Breakdown: {r_fixed.breakdown}")

    # ── Task 3: hard ──────────────────────────────────────────────────────
    print(f"\n{SEP}\nTASK: hard\n{SEP}")
    ds = make_dataset("hard")

    r_dirty = grade("hard", ds.dirty_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[dirty] score={r_dirty.score:.4f} {r_dirty.detail}")

    r_clean = grade("hard", ds.clean_df, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[clean] score={r_clean.score:.4f} {r_clean.detail}")

    # Simulate partial fix: rename columns only, don't dedup or fix dates
    partial_hard = ds.dirty_df.copy()
    rename_map = ds.metadata.get("canonical_lookup", {})
    partial_hard = partial_hard.rename(columns=rename_map)
    # Keep only canonical columns that exist
    canonical_present = [c for c in _CANONICAL_COLS if c in partial_hard.columns]
    partial_hard = partial_hard[canonical_present]
    r_renamed = grade("hard", partial_hard, ds.clean_df, ds.metadata, ds.total_dirty_cells)
    print(f"[rename] score={r_renamed.score:.4f} {r_renamed.detail}")

    print(f"Breakdown: {r_renamed.breakdown}")
|
inference.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py
|
| 3 |
+
------------
|
| 4 |
+
Official submission inference script for the Data Cleaning Pipeline environment.
|
| 5 |
+
|
| 6 |
+
Reads from environment variables (ALL FREE — no paid API needed):
|
| 7 |
+
API_BASE_URL LLM endpoint. Default: HuggingFace free router.
|
| 8 |
+
MODEL_NAME Model to use. Default: free open model.
|
| 9 |
+
HF_TOKEN Your free HuggingFace token (hf_...).
|
| 10 |
+
LOCAL_IMAGE_NAME Docker image name if using from_docker_image().
|
| 11 |
+
Leave unset to connect via ENV_BASE_URL instead.
|
| 12 |
+
ENV_BASE_URL Direct server URL. Default: http://localhost:8000
|
| 13 |
+
|
| 14 |
+
STDOUT FORMAT (evaluator parses these lines exactly — do not modify):
|
| 15 |
+
[START] task=<n> env=<benchmark> model=<model>
|
| 16 |
+
[STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 17 |
+
[END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import asyncio
import json
import os
import re
import sys
from typing import List, Optional

from openai import OpenAI

# ── Environment client imports ────────────────────────────────────────────────
# Import via try/except so the script works both when this directory is already
# on sys.path and when it is run from elsewhere. (An earlier revision also did
# an unconditional `from client import ...` above this block, which raised
# ImportError before the fallback could run, and pulled in an accidental
# `from unittest import result` — both removed.)
try:
    from client import DataCleaningEnv
    from models import CleanAction, MAX_STEPS, DONE_THRESHOLD
except ImportError:
    sys.path.insert(0, os.path.dirname(__file__))
    from client import DataCleaningEnv
    from models import CleanAction, MAX_STEPS, DONE_THRESHOLD


# ── Configuration — all defaults are FREE ────────────────────────────────────

API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN", "")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "")
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")

BENCHMARK = "data_cleaning_env"
TASK_IDS = ["easy", "medium", "hard"]

# Conservative budgets — keeps total runtime under 20 min on vcpu=2 / 8 GB
STEP_LIMITS = {"easy": 25, "medium": 50, "hard": 80}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ── Official log helpers ──────────────────────────────────────────────────────
|
| 56 |
+
# Field names, order, and spacing match the evaluator spec exactly.
|
| 57 |
+
|
| 58 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Emit the official [START] line; the evaluator parses it verbatim."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: Optional[str],
) -> None:
    """Emit one official [STEP] line (single line, evaluator-parsed)."""
    flat = action[:80].replace("\n", " ")  # truncate + keep line single-line
    fields = (
        f"step={step}",
        f"action={flat}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def log_end(
    success: bool,
    steps: int,
    score: float,
    rewards: List[float],
) -> None:
    """Emit the final [END] line; rewards are comma-joined at 2 dp."""
    joined = ",".join(f"{r:.2f}" for r in rewards)
    status = str(success).lower()
    line = f"[END] success={status} steps={steps} score={score:.2f} rewards={joined}"
    print(line, flush=True)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ── LLM helpers ───────────────────────────────────────────────────────────────
|
| 94 |
+
|
| 95 |
+
# Sent once per episode as the system message. Pushes the model toward
# emitting exactly one bare JSON object per turn, which parse_action expects.
SYSTEM_PROMPT = (
    "You are a data cleaning agent. You receive a dirty CSV and must fix it "
    "step by step using JSON action commands. Fix the most impactful issues "
    "first. Be precise — wrong column names cause errors. "
    "Output a single valid JSON object and nothing else — no explanation, no markdown."
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def build_prompt(obs) -> str:
    """Render the per-step user message: state summary, CSV preview, action menu."""
    csv_lines = obs.dirty_csv.strip().split("\n")
    is_truncated = len(csv_lines) > 30
    preview = "\n".join(csv_lines[:30])
    err_suffix = ""
    if obs.last_action_error:
        err_suffix = f"\nLast error: {obs.last_action_error}"
    label = " (first 30 rows)" if is_truncated else ""
    header = (
        f"Task: {obs.task_id}\n"
        f"Schema: {obs.schema_hint}\n"
        f"Score: {obs.current_score:.4f} | Issues remaining: {obs.issues_remaining}\n"
        f"Step {obs.step_number}/{obs.max_steps}{err_suffix}\n"
        f"\nCSV{label}:\n{preview}\n\n"
    )
    menu = (
        "Reply with ONE JSON action:\n"
        ' {"command":"SET_VALUE", "row_index":<int>, "column":"<name>", "value":"<str>"}\n'
        ' {"command":"DROP_ROW", "row_index":<int>}\n'
        ' {"command":"STANDARDIZE_COL", "column":"<name>"}\n'
        ' {"command":"FILL_MISSING", "column":"<name>", "fill_strategy":"mean|median|mode|drop"}\n'
        ' {"command":"DONE"}\n'
        "row_index = integer in the leftmost column of the CSV. JSON only."
    )
    return header + menu
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def parse_action(raw: str) -> CleanAction:
    """Convert model output to CleanAction. Falls back to DONE on any error."""
    text = raw.strip()
    # Strip a markdown code fence if the model wrapped its JSON in one.
    if text.startswith("```"):
        fence_lines = text.split("\n")
        if fence_lines[-1].strip().startswith("```"):
            fence_lines = fence_lines[1:-1]
        else:
            fence_lines = fence_lines[1:]
        text = "\n".join(fence_lines).strip()
    try:
        return CleanAction(**json.loads(text))
    except Exception:
        # Whole reply was not clean JSON — salvage the first {...} span.
        match = re.search(r"\{[^{}]+\}", text, re.DOTALL)
        if match is not None:
            try:
                return CleanAction(**json.loads(match.group()))
            except Exception:
                pass
    return CleanAction(command="DONE")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def call_llm(client: OpenAI, messages: list) -> str:
    """Send the conversation to the configured model and return its reply text."""
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=150,  # actions are short; saves free-tier quota
        temperature=0.1,
    )
    content = completion.choices[0].message.content
    return content.strip() if content else ""
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ── Episode loop ───────────────────────────────────────────────────────────────
|
| 154 |
+
|
| 155 |
+
async def run_episode(env, client: OpenAI, task_id: str) -> dict:
    """Run one episode. Emits [START] → N×[STEP] → [END].

    Drives the agent loop for ``task_id``: reset the environment, repeatedly
    ask the LLM for one JSON action, apply it via ``env.step``, and log each
    step in the evaluator's exact format.

    Returns a summary dict with keys ``task_id``, ``score``, ``reward``
    (sum of per-step rewards), ``steps``, and ``success``.
    """
    max_steps = STEP_LIMITS[task_id]
    threshold = DONE_THRESHOLD[task_id]
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        result = await env.reset(task_id=task_id)
        obs = result.observation
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        for step in range(1, max_steps + 1):
            # Environment may already be terminal (e.g. ended on a prior step).
            if obs.done:
                break

            steps_taken = step
            messages.append({"role": "user", "content": build_prompt(obs)})

            try:
                raw = call_llm(client, messages)
                action = parse_action(raw)
                messages.append({"role": "assistant", "content": raw})
            except Exception as exc:
                # API or parse failure — log and stop episode
                log_step(step, "DONE", 0.00, True, str(exc)[:120])
                rewards.append(0.0)
                break

            # Keep only system + last 8 exchanges to stay inside free-tier context limits
            if len(messages) > 17:
                messages = [messages[0]] + messages[-16:]

            result = await env.step(action)
            obs = result.observation
            reward = result.reward or 0.0  # reward may be None; treat as 0
            rewards.append(reward)
            score = obs.current_score

            log_step(
                step = step,
                action = action.command,
                reward = reward,
                done = obs.done,
                error = obs.last_action_error,
            )

            # Stop early once the env terminates or the task's pass
            # threshold is reached — extra steps can only cost quota.
            if obs.done or score >= threshold:
                break

        success = score >= threshold

    finally:
        # [END] is always emitted, even if the episode crashed
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {"task_id": task_id, "score": score,
            "reward": sum(rewards), "steps": steps_taken, "success": success}
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ── Entry point ────────────────────────────────────────────────────────────────
|
| 220 |
+
|
| 221 |
+
async def main() -> None:
    """Entry point: validate config, connect to the env, run all tasks, summarise."""
    if not HF_TOKEN:
        # Fail fast with setup instructions — every LLM call needs the token.
        print(
            "ERROR: HF_TOKEN is not set.\n"
            "1. Go to https://huggingface.co/settings/tokens\n"
            "2. Click 'New token' → choose 'Read' → copy it\n"
            "3. In PowerShell: $env:HF_TOKEN='hf_xxxxxxxxxxxx'\n"
            "4. Then run: python inference.py",
            file=sys.stderr,
        )
        sys.exit(1)

    # Echo effective config so evaluation logs show exactly what ran.
    print(f"API_BASE_URL : {API_BASE_URL}", flush=True)
    print(f"MODEL_NAME : {MODEL_NAME}", flush=True)
    print(f"LOCAL_IMAGE_NAME : {LOCAL_IMAGE_NAME or '(not set — using ENV_BASE_URL)'}", flush=True)
    print(f"ENV_BASE_URL : {ENV_BASE_URL}", flush=True)
    print("", flush=True)

    # ✅ Create llm and env in the correct order
    llm = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    # Either spin up a local Docker container or attach to a running server.
    if LOCAL_IMAGE_NAME:
        env = await DataCleaningEnv.from_docker_image(LOCAL_IMAGE_NAME)
    else:
        env = DataCleaningEnv(base_url=ENV_BASE_URL)
        await env.connect()

    results = []
    try:
        for task_id in TASK_IDS:
            summary = await run_episode(env, llm, task_id)
            results.append(summary)
            print("", flush=True)
    finally:
        # Always release the env connection/container, even on failure.
        await env.close()

    # Human-readable summary (evaluator ignores lines that don't start with [START]/[STEP]/[END])
    print("=" * 56, flush=True)
    print(f"{'Task':<10} {'Score':>7} {'Reward':>9} {'Steps':>6} {'Pass':>5}")
    print("-" * 56, flush=True)
    for r in results:
        print(
            f"{r['task_id']:<10} {r['score']:>7.4f} {r['reward']:>9.4f} "
            f"{r['steps']:>6} {'YES' if r['success'] else 'NO':>4}",
            flush=True,
        )
    print("=" * 56, flush=True)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
    # Script entry point: drive the async episode loop to completion.
    asyncio.run(main())
|
models.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models.py
|
| 3 |
+
---------
|
| 4 |
+
Pydantic models for the Data Cleaning Pipeline environment.
|
| 5 |
+
|
| 6 |
+
Three models define the full agent↔environment contract:
|
| 7 |
+
|
| 8 |
+
CleanAction — what the agent sends on each step
|
| 9 |
+
CleanObservation — what the agent receives back
|
| 10 |
+
CleanState — internal server state (not sent to agent directly)
|
| 11 |
+
|
| 12 |
+
Inheritance chain (confirmed from OpenEnv source):
|
| 13 |
+
Action → extra="forbid", has: metadata: Dict[str, Any]
|
| 14 |
+
Observation → extra="forbid", has: done: bool, reward: float|None, metadata: Dict[str, Any]
|
| 15 |
+
State → extra="allow", has: episode_id: Optional[str], step_count: int
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 21 |
+
|
| 22 |
+
from pydantic import Field, field_validator, model_validator
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 26 |
+
except ImportError:
|
| 27 |
+
# Fallback for local development without the full OpenEnv install
|
| 28 |
+
from openenv.core.env_server import Action, Observation, State
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ── Valid values (used by validators + schema hints) ──────────────────────────
|
| 32 |
+
|
| 33 |
+
# NOTE: these Literal aliases do double duty — pydantic enforces membership at
# runtime, and they surface as enums in the generated JSON schema hints.
VALID_COMMANDS = Literal[
    "SET_VALUE",  # Fix a specific cell: (row_index, column, value)
    "DROP_ROW",  # Remove an entire row: (row_index,)
    "STANDARDIZE_COL",  # Normalize an entire column's format: (column,)
    "FILL_MISSING",  # Fill NaN values in a column: (column, fill_strategy)
    "DONE",  # Agent signals episode is complete: ()
]

# Strategies accepted by FILL_MISSING; 'drop' removes NaN rows instead of filling.
VALID_FILL_STRATEGIES = Literal["mean", "median", "mode", "drop"]

# The three difficulty tiers served by the environment.
VALID_TASK_IDS = Literal["easy", "medium", "hard"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 47 |
+
# CleanAction
|
| 48 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 49 |
+
|
| 50 |
+
class CleanAction(Action):
    """Action sent by the agent each step.

    The ``command`` field selects the operation. Depending on command,
    only a subset of the remaining fields are required:

    +-----------------+------------+--------+-------+---------------+
    | command         | row_index  | column | value | fill_strategy |
    +=================+============+========+=======+===============+
    | SET_VALUE       | required   | req    | req   | —             |
    | DROP_ROW        | required   | —      | —     | —             |
    | STANDARDIZE_COL | —          | req    | —     | —             |
    | FILL_MISSING    | —          | req    | —     | required      |
    | DONE            | —          | —      | —     | —             |
    +-----------------+------------+--------+-------+---------------+

    Example (fix a single cell)::

        CleanAction(
            command="SET_VALUE",
            row_index=3,
            column="price",
            value="29.99",
        )

    Example (drop a whole row)::

        CleanAction(command="DROP_ROW", row_index=17)

    Example (fill all NaN in a column with the median)::

        CleanAction(
            command="FILL_MISSING",
            column="quantity",
            fill_strategy="median",
        )
    """

    command: VALID_COMMANDS = Field(
        ...,
        description=(
            "Operation to perform. One of: SET_VALUE, DROP_ROW, "
            "STANDARDIZE_COL, FILL_MISSING, DONE."
        ),
    )

    # Negative values are rejected by the ge=0 constraint during field
    # validation, so no extra validator is needed for that.
    row_index: Optional[int] = Field(
        default=None,
        ge=0,
        description=(
            "Zero-based row index to target. "
            "Required for SET_VALUE and DROP_ROW."
        ),
    )

    column: Optional[str] = Field(
        default=None,
        min_length=1,
        description=(
            "Name of the column to target. "
            "Required for SET_VALUE, STANDARDIZE_COL, and FILL_MISSING."
        ),
    )

    value: Optional[str] = Field(
        default=None,
        description=(
            "New cell value as a string. "
            "Required for SET_VALUE. The environment casts this to the "
            "column's expected dtype (e.g. '29.99' → float for a price column)."
        ),
    )

    fill_strategy: Optional[VALID_FILL_STRATEGIES] = Field(
        default=None,
        description=(
            "Strategy for FILL_MISSING. One of: mean, median, mode, drop. "
            "'drop' removes rows where the column is NaN."
        ),
    )

    @model_validator(mode="after")
    def _check_required_fields(self) -> "CleanAction":
        """Ensure each command has exactly the fields it needs."""
        cmd = self.command

        if cmd == "SET_VALUE":
            missing = []
            if self.row_index is None:
                missing.append("row_index")
            if self.column is None:
                missing.append("column")
            if self.value is None:
                missing.append("value")
            if missing:
                raise ValueError(
                    f"SET_VALUE requires: {', '.join(missing)}"
                )

        elif cmd == "DROP_ROW":
            if self.row_index is None:
                raise ValueError("DROP_ROW requires row_index")

        elif cmd == "STANDARDIZE_COL":
            if self.column is None:
                raise ValueError("STANDARDIZE_COL requires column")

        elif cmd == "FILL_MISSING":
            missing = []
            if self.column is None:
                missing.append("column")
            if self.fill_strategy is None:
                missing.append("fill_strategy")
            if missing:
                raise ValueError(
                    f"FILL_MISSING requires: {', '.join(missing)}"
                )

        # DONE requires nothing — always valid

        return self

    # NOTE: a `_non_negative_row` field validator previously duplicated the
    # ge=0 constraint above. In pydantic v2 the constraint rejects negative
    # values before an after-mode field validator runs, so that validator was
    # dead code and has been removed; ValidationError behavior is unchanged.
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 181 |
+
# CleanObservation
|
| 182 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 183 |
+
|
| 184 |
+
class CleanObservation(Observation):
    """Observation returned to the agent after each step (and at reset).

    The agent sees the full current state of the dirty CSV at every step
    so it can decide what to fix next. This is intentionally verbose —
    passing the whole CSV string keeps the environment stateless from the
    agent's perspective (no hidden memory needed).

    Inherited from Observation (do NOT redeclare these):
        done: bool — True when the episode has ended
        reward: float | None — per-step reward (None at reset)
        metadata: Dict[str, Any] — extra info (unused by core loop)
    """

    # ── Task context (set at reset, constant for the episode) ────────────────

    task_id: VALID_TASK_IDS = Field(
        ...,
        description="Which task is active: 'easy', 'medium', or 'hard'.",
    )

    schema_hint: str = Field(
        ...,
        description=(
            "Plain-English description of the target schema. "
            "Tells the agent what the clean data should look like."
        ),
    )

    initial_dirty_cells: int = Field(
        ...,
        ge=0,
        description=(
            "Total number of cells that differed from ground truth at episode start. "
            "Used to compute a normalised progress score."
        ),
    )

    # ── Per-step state ───────────────────────────────────────────────────────

    dirty_csv: str = Field(
        ...,
        description=(
            "Full current state of the working DataFrame serialised as a CSV string. "
            "This reflects all changes the agent has made so far this episode."
        ),
    )

    current_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description=(
            "Grader score after the last action (0.0 = no cells correct, "
            "1.0 = perfect match with ground truth)."
        ),
    )

    issues_remaining: int = Field(
        default=0,
        ge=0,
        description=(
            "Approximate count of cells still differing from ground truth. "
            "Convenience field — agents can also derive this from the CSV."
        ),
    )

    step_number: int = Field(
        default=0,
        ge=0,
        description="How many steps have been taken in this episode so far.",
    )

    max_steps: int = Field(
        ...,
        ge=1,
        description="Maximum steps allowed for this task before forced termination.",
    )

    # ── Last-action feedback ────────────────────────────────────────────────

    last_action_success: bool = Field(
        default=True,
        description=(
            "Whether the last action was applied without errors. "
            "False if the column/row didn't exist, value couldn't be cast, etc."
        ),
    )

    last_action_error: Optional[str] = Field(
        default=None,
        description=(
            "Error message if last_action_success is False, else None. "
            "Helps the agent self-correct."
        ),
    )

    @field_validator("current_score")
    @classmethod
    def _round_score(cls, v: float) -> float:
        # Round to 4 dp so serialised observations stay compact and stable.
        return round(v, 4)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 288 |
+
# CleanState
|
| 289 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 290 |
+
|
| 291 |
+
class CleanState(State):
    """Internal server-side state. Never sent to the agent directly.

    Holds the live DataFrames, ground truth, and grader metadata.
    Because State uses extra="allow", we can store arbitrary fields
    without listing them in the JSON schema.

    Inherited from State:
        episode_id: Optional[str] — unique episode identifier
        step_count: int — steps taken this episode (ge=0)
    """

    # ── Task identity ────────────────────────────────────────────────────────

    task_id: str = Field(
        default="easy",
        description="Active task: 'easy', 'medium', or 'hard'.",
    )

    # ── DataFrame snapshots (stored as CSV strings for serialisation) ────────
    # NOTE: The environment keeps live pd.DataFrame objects in instance vars.
    # These string fields are the serialised snapshots used by state() calls
    # and for WebSocket state responses.

    dirty_csv_snapshot: str = Field(
        default="",
        description="Current working DataFrame serialised to CSV string.",
    )

    clean_csv_snapshot: str = Field(
        default="",
        description="Ground-truth clean DataFrame serialised to CSV string.",
    )

    # ── Scoring ──────────────────────────────────────────────────────────────

    initial_dirty_cells: int = Field(
        default=0,
        ge=0,
        description="Dirty cell count at episode start (denominator for progress).",
    )

    current_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Grader score after the last step.",
    )

    previous_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Grader score before the last step (for reward delta).",
    )

    # ── Task metadata (passed through from TaskDataset.metadata) ─────────────
    # Contains grader-specific ground truth: outlier_rows, canonical_lookup, etc.

    task_metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "Task-specific metadata from dataset_factory.TaskDataset.metadata. "
            "Contains grader ground truth (outlier_rows, duplicate_pairs, etc.)."
        ),
    )

    # ── Schema hint (echoed in observations) ────────────────────────────────

    schema_hint: str = Field(
        default="",
        description="Plain-English schema description for this task.",
    )

    # ── Per-task step budget ─────────────────────────────────────────────────

    max_steps: int = Field(
        default=40,
        ge=1,
        description="Maximum steps for this task (40 / 80 / 150 for easy/medium/hard).",
    )

    @field_validator("current_score", "previous_score")
    @classmethod
    def _clamp_score(cls, v: float) -> float:
        # Clamp into [0, 1] and round to 4 dp for compact, stable snapshots.
        return round(max(0.0, min(1.0, v)), 4)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# ── Step budget constants ─────────────────────────────────────────────────────
|
| 380 |
+
|
| 381 |
+
# Per-task step budgets. Keys must cover every value of VALID_TASK_IDS.
MAX_STEPS: Dict[str, int] = {
    "easy": 40,
    "medium": 80,
    "hard": 150,
}

# Done threshold: score at which the agent is considered successful
# (harder tasks accept a lower score).
DONE_THRESHOLD: Dict[str, float] = {
    "easy": 0.95,
    "medium": 0.85,
    "hard": 0.80,
}
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# ── Smoke test ────────────────────────────────────────────────────────────────
# Manual sanity check: exercises action construction, validation failures,
# and observation/state serialisation. Run with `python models.py`.

if __name__ == "__main__":
    import json

    print("── CleanAction examples ──────────────────────────────────────")

    a1 = CleanAction(command="SET_VALUE", row_index=3, column="price", value="29.99")
    print("SET_VALUE: ", a1.model_dump())

    a2 = CleanAction(command="DROP_ROW", row_index=17)
    print("DROP_ROW: ", a2.model_dump())

    a3 = CleanAction(command="FILL_MISSING", column="quantity", fill_strategy="median")
    print("FILL_MISSING: ", a3.model_dump())

    a4 = CleanAction(command="STANDARDIZE_COL", column="order_date")
    print("STANDARDIZE_COL:", a4.model_dump())

    a5 = CleanAction(command="DONE")
    print("DONE: ", a5.model_dump())

    # Validation: SET_VALUE without row_index should fail
    print("\n── Validation ────────────────────────────────────────────────")
    try:
        CleanAction(command="SET_VALUE", column="price", value="10.0")
    except Exception as e:
        print(f"Expected error (missing row_index): {e}")
    else:
        # Fix: the original silently continued when the validator failed to
        # raise, so a broken validator would go unnoticed by this smoke test.
        print("ERROR: SET_VALUE without row_index did not raise")

    try:
        CleanAction(command="FILL_MISSING", column="price")
    except Exception as e:
        print(f"Expected error (missing fill_strategy): {e}")
    else:
        print("ERROR: FILL_MISSING without fill_strategy did not raise")

    print("\n── CleanObservation ──────────────────────────────────────────")
    obs = CleanObservation(
        task_id="easy",
        schema_hint="Sales orders dataset. price must be float.",
        initial_dirty_cells=29,
        dirty_csv="order_id,price\n1001,N/A\n1002,19.99",
        current_score=0.0,
        issues_remaining=29,
        step_number=0,
        max_steps=40,
        done=False,
        reward=None,
    )
    print(json.dumps(obs.model_dump(), indent=2))

    print("\n── CleanState ────────────────────────────────────────────────")
    state = CleanState(
        episode_id="ep-001",
        step_count=0,
        task_id="easy",
        dirty_csv_snapshot="order_id,price\n1001,N/A",
        clean_csv_snapshot="order_id,price\n1001,14.99",
        initial_dirty_cells=29,
        current_score=0.0,
        previous_score=0.0,
        task_metadata={"injected_cells": [(0, "price")]},
        schema_hint="Sales orders dataset.",
        max_steps=40,
    )
    print(json.dumps(state.model_dump(), indent=2))

    print("\n── JSON schemas ──────────────────────────────────────────────")
    print("Action schema keys: ", list(CleanAction.model_json_schema()["properties"].keys()))
    print("Observation schema keys:", list(CleanObservation.model_json_schema()["properties"].keys()))
    print("State schema keys: ", list(CleanState.model_json_schema()["properties"].keys()))
|
openenv.yaml
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# openenv.yaml
|
| 2 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 3 |
+
# Manifest for the Data Cleaning Pipeline OpenEnv environment.
|
| 4 |
+
#
|
| 5 |
+
# Field reference
|
| 6 |
+
# ───────────────
|
| 7 |
+
# Required by the CLI (serve / build / push / validate):
|
| 8 |
+
# spec_version — always 1 for this generation of the spec
|
| 9 |
+
# name — environment identifier used by the CLI and auto-discovery
|
| 10 |
+
# type — "space" means it can be deployed as a Hugging Face Space
|
| 11 |
+
# runtime — "fastapi" tells the server how to boot
|
| 12 |
+
# app — Python import path to the FastAPI app object
|
| 13 |
+
# port — port the server listens on inside the container
|
| 14 |
+
#
|
| 15 |
+
# Read by AutoEnv auto-discovery (openenv.auto._discovery):
|
| 16 |
+
# name — maps to env_key after stripping the "_env" suffix
|
| 17 |
+
# description — human-readable label shown in env listings
|
| 18 |
+
# spec_version — stored in EnvironmentInfo for introspection
|
| 19 |
+
# action — EXPLICIT override of the auto-inferred class name
|
| 20 |
+
# observation — EXPLICIT override of the auto-inferred class name
|
| 21 |
+
#
|
| 22 |
+
# NOTE on action / observation overrides:
|
| 23 |
+
# Auto-discovery infers class names from the env name using PascalCase:
|
| 24 |
+
# "data_cleaning_env" → base "data_cleaning" → "DataCleaningAction"
|
| 25 |
+
# Our actual class is named "CleanAction" (not the inferred "DataCleaningAction"),
|
| 26 |
+
# so these fields MUST be set to avoid ImportError on AutoEnv.from_env().
|
| 27 |
+
#
|
| 28 |
+
# All other fields (tasks, reward, tags) are informational. They are not
|
| 29 |
+
# parsed by the current OpenEnv tooling but are preserved in
|
| 30 |
+
# EnvironmentInfo.manifest and available to the web UI and external tools.
|
| 31 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 32 |
+
|
| 33 |
+
# ── Core deployment fields ────────────────────────────────────────────────────
|
| 34 |
+
|
| 35 |
+
spec_version: 1
|
| 36 |
+
name: data_cleaning_env
|
| 37 |
+
type: space
|
| 38 |
+
runtime: fastapi
|
| 39 |
+
app: server.app:app
|
| 40 |
+
port: 8000
|
| 41 |
+
|
| 42 |
+
# ── Package metadata ──────────────────────────────────────────────────────────
|
| 43 |
+
|
| 44 |
+
version: "1.0.0"
|
| 45 |
+
|
| 46 |
+
description: >-
|
| 47 |
+
Data cleaning pipeline: the agent receives a dirty CSV and must detect
|
| 48 |
+
and fix type errors, missing values, outliers, and schema inconsistencies
|
| 49 |
+
to match a hidden ground-truth dataset. Three tasks (easy → medium → hard)
|
| 50 |
+
with a deterministic grader that returns a continuous score in [0.0, 1.0].
|
| 51 |
+
|
| 52 |
+
# ── Auto-discovery class overrides ───────────────────────────────────────────
|
| 53 |
+
# These override auto-inferred names (which would be DataCleaningAction /
|
| 54 |
+
# DataCleaningObservation) to match the actual class names defined in models.py.
|
| 55 |
+
|
| 56 |
+
action: CleanAction
|
| 57 |
+
observation: CleanObservation
|
| 58 |
+
|
| 59 |
+
# The client class is correctly inferred as DataCleaningEnv (data_cleaning →
|
| 60 |
+
# DataCleaning + Env), which matches client.py, so no override is needed.
|
| 61 |
+
|
| 62 |
+
# ── Tags (informational) ──────────────────────────────────────────────────────
|
| 63 |
+
|
| 64 |
+
tags:
|
| 65 |
+
- data-cleaning
|
| 66 |
+
- tabular
|
| 67 |
+
- real-world
|
| 68 |
+
- hackathon
|
| 69 |
+
|
| 70 |
+
# ── Task manifest (informational) ─────────────────────────────────────────────
|
| 71 |
+
# One entry per task. These values mirror the constants in models.py
|
| 72 |
+
# (MAX_STEPS, DONE_THRESHOLD) and the descriptions in dataset_factory.py.
|
| 73 |
+
|
| 74 |
+
tasks:
|
| 75 |
+
- id: easy
|
| 76 |
+
name: Fix obvious errors
|
| 77 |
+
description: >-
|
| 78 |
+
50-row sales CSV with 29 injected dirty cells: 10 type mismatches
|
| 79 |
+
(text in numeric columns), 8 missing values, 5 far-future dates
|
| 80 |
+
(year 2099), and 6 cells with leading/trailing whitespace.
|
| 81 |
+
Graded by exact cell-level match against the ground truth (0.0–1.0).
|
| 82 |
+
dataset_rows: 50
|
| 83 |
+
dirty_cells: 29
|
| 84 |
+
max_steps: 40
|
| 85 |
+
done_threshold: 0.95
|
| 86 |
+
|
| 87 |
+
- id: medium
|
| 88 |
+
name: Outlier detection without false positives
|
| 89 |
+
description: >-
|
| 90 |
+
200-row customer transaction CSV with 15 true statistical outliers
|
| 91 |
+
(negative or > $2000 amounts) that must be fixed or removed, 5 valid
|
| 92 |
+
large transactions ($900–$2000) that must NOT be removed, and 12
|
| 93 |
+
category spelling typos. Graded by F1 score on outlier detection
|
| 94 |
+
(0.5 weight) and typo correction rate (0.5 weight).
|
| 95 |
+
dataset_rows: 200
|
| 96 |
+
dirty_cells: 27
|
| 97 |
+
max_steps: 80
|
| 98 |
+
done_threshold: 0.85
|
| 99 |
+
|
| 100 |
+
- id: hard
|
| 101 |
+
name: Multi-source schema normalisation and deduplication
|
| 102 |
+
description: >-
|
| 103 |
+
430-row CSV (400 clean + 30 duplicates) merged from 3 fictional data
|
| 104 |
+
sources with inconsistent column naming (e.g. cust_id / customer_id /
|
| 105 |
+
CustomerID), mixed date formats (ISO, US, EU), and ~30 duplicate rows
|
| 106 |
+
(exact and near-duplicate). Agent must infer the canonical 9-column
|
| 107 |
+
schema without explicit documentation. Graded by schema match (40%),
|
| 108 |
+
deduplication F1 (35%), and date format compliance (25%).
|
| 109 |
+
dataset_rows: 430
|
| 110 |
+
canonical_rows: 400
|
| 111 |
+
canonical_columns: 9
|
| 112 |
+
duplicate_rows: 30
|
| 113 |
+
max_steps: 150
|
| 114 |
+
done_threshold: 0.80
|
| 115 |
+
|
| 116 |
+
# ── Reward function summary (informational) ───────────────────────────────────
|
| 117 |
+
|
| 118 |
+
reward:
|
| 119 |
+
type: dense
|
| 120 |
+
range: [-0.5, 1.0]
|
| 121 |
+
step_cost: -0.005
|
| 122 |
+
components:
|
| 123 |
+
- name: progress
|
| 124 |
+
weight: primary
|
| 125 |
+
description: >-
|
| 126 |
+
Grader score delta each step (curr_score − prev_score).
|
| 127 |
+
The main learning signal — any cell fixed produces a non-zero reward.
|
| 128 |
+
|
| 129 |
+
- name: efficiency_bonus
|
| 130 |
+
weight: "+0.10 × (1 − step_fraction)"
|
| 131 |
+
description: >-
|
| 132 |
+
Small bonus awarded the step the episode is solved (score crosses
|
| 133 |
+
done_threshold). Rewards finishing early relative to the step budget.
|
| 134 |
+
|
| 135 |
+
- name: false_positive_penalty
|
| 136 |
+
weight: -0.15
|
| 137 |
+
description: >-
|
| 138 |
+
Applied when DROP_ROW removes a valid-extreme row in the medium task.
|
| 139 |
+
Penalises aggressive deletion without checking schema_hint.
|
| 140 |
+
|
| 141 |
+
- name: early_done_penalty
|
| 142 |
+
weight: -0.20
|
| 143 |
+
description: >-
|
| 144 |
+
Applied when the agent sends DONE with current_score < 0.60.
|
| 145 |
+
Discourages giving up prematurely.
|
| 146 |
+
|
| 147 |
+
- name: step_cost
|
| 148 |
+
weight: -0.005
|
| 149 |
+
description: >-
|
| 150 |
+
Fixed cost every step regardless of outcome.
|
| 151 |
+
Prevents infinite loops and padding.
|
openenv_data_cleaning_env.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-data_cleaning_env
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Data Cleaning Env environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core
|
| 7 |
+
Requires-Dist: pandas>=2.0
|
| 8 |
+
Requires-Dist: numpy>=1.24
|
| 9 |
+
Requires-Dist: fastapi
|
| 10 |
+
Requires-Dist: uvicorn
|
| 11 |
+
Provides-Extra: dev
|
| 12 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 13 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_data_cleaning_env.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
__init__.py
|
| 3 |
+
client.py
|
| 4 |
+
dataset_factory.py
|
| 5 |
+
graders.py
|
| 6 |
+
models.py
|
| 7 |
+
pyproject.toml
|
| 8 |
+
./__init__.py
|
| 9 |
+
./client.py
|
| 10 |
+
./dataset_factory.py
|
| 11 |
+
./graders.py
|
| 12 |
+
./models.py
|
| 13 |
+
openenv_data_cleaning_env.egg-info/PKG-INFO
|
| 14 |
+
openenv_data_cleaning_env.egg-info/SOURCES.txt
|
| 15 |
+
openenv_data_cleaning_env.egg-info/dependency_links.txt
|
| 16 |
+
openenv_data_cleaning_env.egg-info/entry_points.txt
|
| 17 |
+
openenv_data_cleaning_env.egg-info/requires.txt
|
| 18 |
+
openenv_data_cleaning_env.egg-info/top_level.txt
|
| 19 |
+
server/__init__.py
|
| 20 |
+
server/app.py
|
| 21 |
+
server/data_cleaning_env.py
|
openenv_data_cleaning_env.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_data_cleaning_env.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = data_cleaning_env.server.app:main
|
openenv_data_cleaning_env.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core
|
| 2 |
+
pandas>=2.0
|
| 3 |
+
numpy>=1.24
|
| 4 |
+
fastapi
|
| 5 |
+
uvicorn
|
| 6 |
+
|
| 7 |
+
[dev]
|
| 8 |
+
pytest>=8.0.0
|
| 9 |
+
pytest-cov>=4.0.0
|
openenv_data_cleaning_env.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
data_cleaning_env
|
pyproject.toml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-data_cleaning_env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Data Cleaning Env environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
# Environment-specific dependencies
|
| 21 |
+
# Add all dependencies needed for your environment here
|
| 22 |
+
# Examples:
|
| 23 |
+
# "torch>=2.0.0",
|
| 24 |
+
# "gymnasium>=0.29.0",
|
| 25 |
+
# "openspiel>=1.0.0",
|
| 26 |
+
# "smolagents>=1.22.0,<2",
|
| 27 |
+
"openenv-core",
|
| 28 |
+
"pandas>=2.0",
|
| 29 |
+
"numpy>=1.24",
|
| 30 |
+
"fastapi",
|
| 31 |
+
"uvicorn",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
[project.optional-dependencies]
|
| 35 |
+
dev = [
|
| 36 |
+
"pytest>=8.0.0",
|
| 37 |
+
"pytest-cov>=4.0.0",
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
[project.scripts]
|
| 41 |
+
# Server entry point - enables running via: uv run --project . server
|
| 42 |
+
# or: python -m data_cleaning_env.server.app
|
| 43 |
+
server = "data_cleaning_env.server.app:main"
|
| 44 |
+
|
| 45 |
+
[tool.setuptools]
|
| 46 |
+
include-package-data = true
|
| 47 |
+
packages = ["data_cleaning_env", "data_cleaning_env.server"]
|
| 48 |
+
package-dir = { "data_cleaning_env" = ".", "data_cleaning_env.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Data Cleaning Env environment server components."""
|
| 8 |
+
|
| 9 |
+
from .data_cleaning_env import DataCleaningEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["DataCleaningEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI entry point for the data-cleaning OpenEnv server.

Exposes the module-level ``app`` object referenced by openenv.yaml
(``app: server.app:app``) plus a ``main()`` console-script entry point.
"""

try:
    # Package context: imported as data_cleaning_env.server.app.
    from openenv.core.env_server import create_app

    from ..models import CleanAction, CleanObservation
    from .data_cleaning_env import DataCleaningEnvironment
except ImportError:
    # Script context (repo root on sys.path) — the relative imports above
    # fail, so fall back to absolute imports. create_app is repeated so this
    # branch is self-contained; if openenv itself is missing, the same
    # ImportError propagates either way.
    from openenv.core.env_server import create_app

    from models import CleanAction, CleanObservation
    from server.data_cleaning_env import DataCleaningEnvironment

# Pass the environment *class* (not an instance) so each client session gets
# its own isolated DataCleaningEnvironment.
app = create_app(
    DataCleaningEnvironment,
    CleanAction,
    CleanObservation,
    env_name="data_cleaning_env",
)


def main() -> None:
    """Entry point for openenv serve / uv run / python -m."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()
|
server/data_cleaning_env.py
ADDED
|
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
server/data_cleaning_env.py
|
| 3 |
+
---------------------------
|
| 4 |
+
DataCleaningEnvironment — the heart of the environment.
|
| 5 |
+
|
| 6 |
+
Implements the three abstract methods from openenv.core.env_server.interfaces.Environment:
|
| 7 |
+
reset(seed, episode_id, **kwargs) -> CleanObservation
|
| 8 |
+
step(action, timeout_s, **kwargs) -> CleanObservation
|
| 9 |
+
state (property) -> CleanState
|
| 10 |
+
|
| 11 |
+
Architecture
|
| 12 |
+
------------
|
| 13 |
+
Live DataFrames (_dirty_df, _clean_df) live as instance variables for speed.
|
| 14 |
+
CleanState holds lightweight CSV snapshots used only for WebSocket state()
|
| 15 |
+
responses — not for every step. This avoids serialising a 400-row DataFrame
|
| 16 |
+
on every call.
|
| 17 |
+
|
| 18 |
+
Action dispatch
|
| 19 |
+
---------------
|
| 20 |
+
Each CleanAction.command routes to a private _apply_* method that mutates
|
| 21 |
+
_dirty_df in place. Errors in those methods (bad column name, out-of-bounds
|
| 22 |
+
row) are caught and returned as (success=False, error_msg=...) so the agent
|
| 23 |
+
gets corrective feedback instead of a 500.
|
| 24 |
+
|
| 25 |
+
Reward
|
| 26 |
+
------
|
| 27 |
+
compute_reward() implements the dense reward formula designed in the plan:
|
| 28 |
+
progress term — grader score delta (main signal)
|
| 29 |
+
efficiency bonus — small reward for early completion
|
| 30 |
+
false-positive penalty — for dropping a valid-extreme row (medium task)
|
| 31 |
+
early-DONE penalty — for calling DONE with a low score
|
| 32 |
+
step cost — -0.005 every step to discourage padding
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
from __future__ import annotations
|
| 36 |
+
|
| 37 |
+
import sys
|
| 38 |
+
import os
|
| 39 |
+
from typing import Any, Optional
|
| 40 |
+
from uuid import uuid4
|
| 41 |
+
|
| 42 |
+
import numpy as np
|
| 43 |
+
import pandas as pd
|
| 44 |
+
|
| 45 |
+
# ── OpenEnv imports ───────────────────────────────────────────────────────────
# Fix: the previous try/except fallback repeated these exact same imports in
# its except branch, so it could never succeed where the try block failed —
# the dead fallback (and the misleading "relative → absolute" comment) is
# removed. A missing openenv-core still raises the same ImportError.
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata
|
| 52 |
+
|
| 53 |
+
# ── Local imports (try relative → absolute for both server and standalone) ───
|
| 54 |
+
try:
|
| 55 |
+
from ..models import (
|
| 56 |
+
CleanAction, CleanObservation, CleanState,
|
| 57 |
+
MAX_STEPS, DONE_THRESHOLD,
|
| 58 |
+
)
|
| 59 |
+
from ..dataset_factory import make_dataset, TaskDataset
|
| 60 |
+
from ..graders import grade, GradeResult
|
| 61 |
+
except ImportError:
|
| 62 |
+
try:
|
| 63 |
+
from models import (
|
| 64 |
+
CleanAction, CleanObservation, CleanState,
|
| 65 |
+
MAX_STEPS, DONE_THRESHOLD,
|
| 66 |
+
)
|
| 67 |
+
from dataset_factory import make_dataset, TaskDataset
|
| 68 |
+
from graders import grade, GradeResult
|
| 69 |
+
except ImportError:
|
| 70 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 71 |
+
from models import (
|
| 72 |
+
CleanAction, CleanObservation, CleanState,
|
| 73 |
+
MAX_STEPS, DONE_THRESHOLD,
|
| 74 |
+
)
|
| 75 |
+
from dataset_factory import make_dataset, TaskDataset
|
| 76 |
+
from graders import grade, GradeResult
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ── Constants ─────────────────────────────────────────────────────────────────
|
| 80 |
+
|
| 81 |
+
# Per-step cost that discourages infinite loops / padding
|
| 82 |
+
STEP_COST = -0.005
|
| 83 |
+
|
| 84 |
+
# Penalty for calling DONE before the score is reasonable
|
| 85 |
+
EARLY_DONE_PENALTY = -0.20
|
| 86 |
+
EARLY_DONE_THRESHOLD = 0.60 # DONE below this score triggers the penalty
|
| 87 |
+
|
| 88 |
+
# Penalty for removing a valid-extreme row in the medium task
|
| 89 |
+
FALSE_POSITIVE_PENALTY = -0.15
|
| 90 |
+
|
| 91 |
+
# Efficiency bonus multiplier (only awarded when episode is solved)
|
| 92 |
+
EFFICIENCY_BONUS_WEIGHT = 0.10
|
| 93 |
+
|
| 94 |
+
# Date formats the STANDARDIZE_COL handler will try, in priority order
|
| 95 |
+
_DATE_PARSE_FORMATS = [
|
| 96 |
+
"%Y-%m-%d", # ISO — most reliable, try first
|
| 97 |
+
"%m/%d/%Y", # US
|
| 98 |
+
"%d.%m.%Y", # EU
|
| 99 |
+
"%d/%m/%Y", # EU alt
|
| 100 |
+
"%Y/%m/%d", # Asian
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 105 |
+
# DataCleaningEnvironment
|
| 106 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 107 |
+
|
| 108 |
+
class DataCleaningEnvironment(Environment):
|
| 109 |
+
"""
|
| 110 |
+
Gym-style environment for the data cleaning pipeline task.
|
| 111 |
+
|
| 112 |
+
Each episode:
|
| 113 |
+
1. reset(task_id="easy"|"medium"|"hard") loads a dirty/clean CSV pair.
|
| 114 |
+
2. The agent calls step() repeatedly, each time sending a CleanAction.
|
| 115 |
+
3. The episode ends when the agent sends DONE, the score crosses the
|
| 116 |
+
task threshold, or the step budget is exhausted.
|
| 117 |
+
|
| 118 |
+
The environment is fully stateless between sessions — all mutable state
|
| 119 |
+
lives in instance variables, so concurrent sessions each get their own
|
| 120 |
+
isolated copy (SUPPORTS_CONCURRENT_SESSIONS = True).
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 124 |
+
|
| 125 |
+
def __init__(self) -> None:
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
# Live DataFrames — mutated by each step()
|
| 129 |
+
self._dirty_df: Optional[pd.DataFrame] = None
|
| 130 |
+
self._clean_df: Optional[pd.DataFrame] = None
|
| 131 |
+
|
| 132 |
+
# Full task dataset from dataset_factory (holds metadata for grader)
|
| 133 |
+
self._dataset: Optional[TaskDataset] = None
|
| 134 |
+
|
| 135 |
+
# Pydantic state (lightweight; updated on demand)
|
| 136 |
+
self._state: Optional[CleanState] = None
|
| 137 |
+
|
| 138 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 139 |
+
# reset()
|
| 140 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 141 |
+
|
| 142 |
+
def reset(
|
| 143 |
+
self,
|
| 144 |
+
seed: Optional[int] = None,
|
| 145 |
+
episode_id: Optional[str] = None,
|
| 146 |
+
task_id: str = "easy",
|
| 147 |
+
**kwargs: Any,
|
| 148 |
+
) -> CleanObservation:
|
| 149 |
+
"""
|
| 150 |
+
Reset the environment for a new episode.
|
| 151 |
+
|
| 152 |
+
Parameters
|
| 153 |
+
----------
|
| 154 |
+
seed
|
| 155 |
+
Ignored — datasets use fixed seeds per task for reproducibility.
|
| 156 |
+
episode_id
|
| 157 |
+
Optional; auto-generated if not provided.
|
| 158 |
+
task_id
|
| 159 |
+
Which task to load: "easy", "medium", or "hard".
|
| 160 |
+
"""
|
| 161 |
+
if task_id not in MAX_STEPS:
|
| 162 |
+
raise ValueError(
|
| 163 |
+
f"Unknown task_id {task_id!r}. Must be one of: {list(MAX_STEPS)}"
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# Load dataset (always deterministic via fixed seed in dataset_factory)
|
| 167 |
+
self._dataset = make_dataset(task_id)
|
| 168 |
+
self._dirty_df = self._dataset.dirty_df.copy(deep=True)
|
| 169 |
+
self._clean_df = self._dataset.clean_df.copy(deep=True)
|
| 170 |
+
|
| 171 |
+
max_steps = MAX_STEPS[task_id]
|
| 172 |
+
|
| 173 |
+
# Run grader on the initial dirty state so we have a starting score
|
| 174 |
+
initial_result = grade(
|
| 175 |
+
task_id=task_id,
|
| 176 |
+
agent_df=self._dirty_df,
|
| 177 |
+
clean_df=self._clean_df,
|
| 178 |
+
metadata=self._dataset.metadata,
|
| 179 |
+
initial_dirty_cells=self._dataset.total_dirty_cells,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
self._state = CleanState(
|
| 183 |
+
episode_id=episode_id or str(uuid4()),
|
| 184 |
+
step_count=0,
|
| 185 |
+
task_id=task_id,
|
| 186 |
+
dirty_csv_snapshot=self._df_to_csv(self._dirty_df),
|
| 187 |
+
clean_csv_snapshot=self._df_to_csv(self._clean_df),
|
| 188 |
+
initial_dirty_cells=self._dataset.total_dirty_cells,
|
| 189 |
+
current_score=initial_result.score,
|
| 190 |
+
previous_score=0.0,
|
| 191 |
+
task_metadata=self._dataset.metadata,
|
| 192 |
+
schema_hint=self._dataset.schema_hint,
|
| 193 |
+
max_steps=max_steps,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
return self._build_observation(
|
| 197 |
+
reward=None,
|
| 198 |
+
done=False,
|
| 199 |
+
last_action_success=True,
|
| 200 |
+
last_action_error=None,
|
| 201 |
+
grader_result=initial_result,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 205 |
+
# step()
|
| 206 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 207 |
+
|
| 208 |
+
    def step(
        self,
        action: CleanAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CleanObservation:
        """
        Apply one CleanAction and return the resulting observation.

        Never raises for bad action inputs — instead returns
        last_action_success=False with a descriptive error message so the
        agent can self-correct on the next step.

        Parameters
        ----------
        action : CleanAction
            The command to apply ("DONE" short-circuits; anything else is
            dispatched through self._apply_action).
        timeout_s : float, optional
            Accepted for interface compatibility; not used by this
            implementation.
        **kwargs
            Accepted for interface compatibility; not used by this
            implementation.

        Returns
        -------
        CleanObservation
            Observation with reward, done flag, and last-action feedback.

        Raises
        ------
        RuntimeError
            If called before reset() has initialised the episode.
        """
        if self._state is None or self._dirty_df is None:
            raise RuntimeError("Environment not initialised. Call reset() first.")

        self._state.step_count += 1

        # ── Save previous score before mutating ──────────────────────────────
        prev_score = self._state.current_score
        self._state.previous_score = prev_score

        # ── DONE shortcut ────────────────────────────────────────────────────
        # DONE skips the grader entirely: the episode ends with the score as-is.
        if action.command == "DONE":
            reward = self._compute_reward(
                action=action,
                prev_score=prev_score,
                curr_score=prev_score,  # score doesn't change on DONE
                action_success=True,
                was_false_positive=False,
            )
            done = True
            self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df)
            return self._build_observation(
                reward=reward,
                done=done,
                last_action_success=True,
                last_action_error=None,
                # Synthesised result: issues_remaining is approximated from the
                # last score rather than re-grading the frame.
                grader_result=GradeResult(
                    score=prev_score,
                    issues_remaining=self._state.initial_dirty_cells
                    - int(prev_score * self._state.initial_dirty_cells),
                    detail="Agent signalled DONE.",
                ),
            )

        # ── Apply action to _dirty_df ────────────────────────────────────────
        action_success, error_msg, was_false_positive = self._apply_action(action)

        # ── Grade the result ──────────────────────────────────────────────────
        # Grading runs even when the action failed, so the score reflects the
        # (unchanged) frame and the reward stays consistent.
        grader_result = grade(
            task_id=self._state.task_id,
            agent_df=self._dirty_df,
            clean_df=self._clean_df,
            metadata=self._state.task_metadata,
            initial_dirty_cells=self._state.initial_dirty_cells,
        )
        curr_score = grader_result.score
        self._state.current_score = curr_score

        # ── Compute reward ────────────────────────────────────────────────────
        reward = self._compute_reward(
            action=action,
            prev_score=prev_score,
            curr_score=curr_score,
            action_success=action_success,
            was_false_positive=was_false_positive,
        )

        # ── Check termination ────────────────────────────────────────────────
        # Episode ends when the score crosses the task threshold or the step
        # budget is exhausted.
        done = (
            curr_score >= DONE_THRESHOLD[self._state.task_id]
            or self._state.step_count >= self._state.max_steps
        )

        # ── Sync state snapshot ──────────────────────────────────────────────
        self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df)

        return self._build_observation(
            reward=reward,
            done=done,
            last_action_success=action_success,
            last_action_error=error_msg,
            grader_result=grader_result,
        )
|
| 293 |
+
|
| 294 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 295 |
+
# state (property)
|
| 296 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 297 |
+
|
| 298 |
+
@property
|
| 299 |
+
def state(self) -> CleanState:
|
| 300 |
+
"""Return the current environment state (serialisable snapshot)."""
|
| 301 |
+
if self._state is None:
|
| 302 |
+
raise RuntimeError("Environment not initialised. Call reset() first.")
|
| 303 |
+
# Keep snapshot fresh in case step() was called without triggering a sync
|
| 304 |
+
if self._dirty_df is not None:
|
| 305 |
+
self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df)
|
| 306 |
+
return self._state
|
| 307 |
+
|
| 308 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 309 |
+
# Action dispatch
|
| 310 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 311 |
+
|
| 312 |
+
def _apply_action(
|
| 313 |
+
self, action: CleanAction
|
| 314 |
+
) -> tuple[bool, Optional[str], bool]:
|
| 315 |
+
"""
|
| 316 |
+
Mutate self._dirty_df according to the action.
|
| 317 |
+
|
| 318 |
+
Returns
|
| 319 |
+
-------
|
| 320 |
+
(success, error_msg, was_false_positive)
|
| 321 |
+
success — True if action applied without error
|
| 322 |
+
error_msg — human-readable description if success=False
|
| 323 |
+
was_false_positive — True if a DROP_ROW removed a valid-extreme row
|
| 324 |
+
"""
|
| 325 |
+
cmd = action.command
|
| 326 |
+
|
| 327 |
+
if cmd == "SET_VALUE":
|
| 328 |
+
return self._apply_set_value(action)
|
| 329 |
+
|
| 330 |
+
elif cmd == "DROP_ROW":
|
| 331 |
+
return self._apply_drop_row(action)
|
| 332 |
+
|
| 333 |
+
elif cmd == "STANDARDIZE_COL":
|
| 334 |
+
return self._apply_standardize_col(action)
|
| 335 |
+
|
| 336 |
+
elif cmd == "FILL_MISSING":
|
| 337 |
+
return self._apply_fill_missing(action)
|
| 338 |
+
|
| 339 |
+
else:
|
| 340 |
+
return False, f"Unknown command: {cmd!r}", False
|
| 341 |
+
|
| 342 |
+
# ── SET_VALUE ─────────────────────────────────────────────────────────────
|
| 343 |
+
|
| 344 |
+
def _apply_set_value(
|
| 345 |
+
self, action: CleanAction
|
| 346 |
+
) -> tuple[bool, Optional[str], bool]:
|
| 347 |
+
df = self._dirty_df
|
| 348 |
+
row_idx = action.row_index
|
| 349 |
+
col = action.column
|
| 350 |
+
val = action.value
|
| 351 |
+
|
| 352 |
+
# Validate column
|
| 353 |
+
if col not in df.columns:
|
| 354 |
+
return (
|
| 355 |
+
False,
|
| 356 |
+
f"Column {col!r} not found. Available: {list(df.columns)}",
|
| 357 |
+
False,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# Validate row index (positional)
|
| 361 |
+
if row_idx < 0 or row_idx >= len(df):
|
| 362 |
+
return (
|
| 363 |
+
False,
|
| 364 |
+
f"Row index {row_idx} out of range. DataFrame has {len(df)} rows (0–{len(df)-1}).",
|
| 365 |
+
False,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Try to cast value to the column's expected type
|
| 369 |
+
cast_val, cast_err = self._cast_value(val, df, col)
|
| 370 |
+
if cast_err:
|
| 371 |
+
return False, cast_err, False
|
| 372 |
+
|
| 373 |
+
df.iloc[row_idx, df.columns.get_loc(col)] = cast_val
|
| 374 |
+
return True, None, False
|
| 375 |
+
|
| 376 |
+
# ── DROP_ROW ──────────────────────────────────────────────────────────────
|
| 377 |
+
|
| 378 |
+
def _apply_drop_row(
|
| 379 |
+
self, action: CleanAction
|
| 380 |
+
) -> tuple[bool, Optional[str], bool]:
|
| 381 |
+
df = self._dirty_df
|
| 382 |
+
row_idx = action.row_index
|
| 383 |
+
|
| 384 |
+
if row_idx < 0 or row_idx >= len(df):
|
| 385 |
+
return (
|
| 386 |
+
False,
|
| 387 |
+
f"Row index {row_idx} out of range. DataFrame has {len(df)} rows.",
|
| 388 |
+
False,
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# Detect false positive for medium task: is this a valid-extreme row?
|
| 392 |
+
was_false_positive = self._is_valid_extreme_row(row_idx)
|
| 393 |
+
|
| 394 |
+
# Drop the row and reset positional index so future iloc references stay valid
|
| 395 |
+
self._dirty_df = df.drop(df.index[row_idx]).reset_index(drop=True)
|
| 396 |
+
return True, None, was_false_positive
|
| 397 |
+
|
| 398 |
+
def _is_valid_extreme_row(self, iloc_idx: int) -> bool:
|
| 399 |
+
"""
|
| 400 |
+
Return True if dropping this row would be a false positive.
|
| 401 |
+
Only applies to the medium task, which tracks valid_extreme_rows
|
| 402 |
+
by their original tx_id.
|
| 403 |
+
"""
|
| 404 |
+
if self._state is None or self._state.task_id != "medium":
|
| 405 |
+
return False
|
| 406 |
+
|
| 407 |
+
valid_extreme_rows: list = self._state.task_metadata.get(
|
| 408 |
+
"valid_extreme_rows", []
|
| 409 |
+
)
|
| 410 |
+
if not valid_extreme_rows or self._clean_df is None:
|
| 411 |
+
return False
|
| 412 |
+
|
| 413 |
+
df = self._dirty_df
|
| 414 |
+
if "tx_id" not in df.columns:
|
| 415 |
+
return False
|
| 416 |
+
|
| 417 |
+
# Get the tx_id of the row being dropped
|
| 418 |
+
try:
|
| 419 |
+
tx_id_to_drop = int(df.iloc[iloc_idx]["tx_id"])
|
| 420 |
+
except (IndexError, ValueError, KeyError):
|
| 421 |
+
return False
|
| 422 |
+
|
| 423 |
+
# Check if any valid-extreme row in clean_df has this tx_id
|
| 424 |
+
for orig_idx in valid_extreme_rows:
|
| 425 |
+
if orig_idx >= len(self._clean_df):
|
| 426 |
+
continue
|
| 427 |
+
if int(self._clean_df.iloc[orig_idx]["tx_id"]) == tx_id_to_drop:
|
| 428 |
+
return True
|
| 429 |
+
|
| 430 |
+
return False
|
| 431 |
+
|
| 432 |
+
# ── STANDARDIZE_COL ───────────────────────────────────────────────────────
|
| 433 |
+
|
| 434 |
+
    def _apply_standardize_col(
        self, action: CleanAction
    ) -> tuple[bool, Optional[str], bool]:
        """
        Normalise one column in place.

        Cascade (first match wins): date-like columns → ISO YYYY-MM-DD;
        numeric-like columns → pd.to_numeric (only if little data is lost);
        otherwise strings are whitespace-stripped.

        Returns (success, error_msg, was_false_positive); the last element is
        always False for this command.
        """
        df = self._dirty_df
        col = action.column

        if col not in df.columns:
            return (
                False,
                f"Column {col!r} not found. Available: {list(df.columns)}",
                False,
            )

        series = df[col].copy()

        # ── Try date normalisation first ──────────────────────────────────────
        if self._looks_like_date_column(col, series):
            normalised, err = self._normalise_dates(series)
            if err:
                return False, f"Date normalisation failed for column {col!r}: {err}", False
            self._dirty_df[col] = normalised
            return True, None, False

        # ── Try numeric coercion ──────────────────────────────────────────────
        if self._looks_like_numeric_column(col, series):
            numeric = pd.to_numeric(series, errors="coerce")
            # Only apply if we didn't lose more than 20% of non-null values
            # (coercion turns unparseable entries into NaN).
            original_non_null = series.notna().sum()
            coerced_non_null = numeric.notna().sum()
            if original_non_null == 0 or coerced_non_null / original_non_null >= 0.8:
                self._dirty_df[col] = numeric
                return True, None, False
            # NOTE: if coercion would be too lossy we fall through to the
            # string path below rather than failing the action.

        # ── String normalisation: strip whitespace ───────────────────────────
        # Missing values pass through unchanged; everything else becomes a
        # stripped string.
        self._dirty_df[col] = series.apply(
            lambda x: str(x).strip() if not _is_nan(x) else x
        )
        return True, None, False
|
| 472 |
+
|
| 473 |
+
def _looks_like_date_column(self, col: str, series: pd.Series) -> bool:
|
| 474 |
+
"""Heuristic: column name contains 'date' or most non-null values parse as dates."""
|
| 475 |
+
if "date" in col.lower():
|
| 476 |
+
return True
|
| 477 |
+
sample = series.dropna().astype(str).head(5)
|
| 478 |
+
parsed = 0
|
| 479 |
+
for s in sample:
|
| 480 |
+
for fmt in _DATE_PARSE_FORMATS:
|
| 481 |
+
try:
|
| 482 |
+
pd.to_datetime(s, format=fmt)
|
| 483 |
+
parsed += 1
|
| 484 |
+
break
|
| 485 |
+
except Exception:
|
| 486 |
+
pass
|
| 487 |
+
return parsed >= max(1, len(sample) // 2)
|
| 488 |
+
|
| 489 |
+
def _looks_like_numeric_column(self, col: str, series: pd.Series) -> bool:
|
| 490 |
+
"""Heuristic: column name or majority of values suggests numeric data."""
|
| 491 |
+
numeric_keywords = {"price", "amount", "value", "quantity", "qty", "count", "id", "num"}
|
| 492 |
+
if any(kw in col.lower() for kw in numeric_keywords):
|
| 493 |
+
return True
|
| 494 |
+
sample = series.dropna().head(10)
|
| 495 |
+
if len(sample) == 0:
|
| 496 |
+
return False
|
| 497 |
+
convertible = pd.to_numeric(sample, errors="coerce").notna().sum()
|
| 498 |
+
return convertible / len(sample) >= 0.7
|
| 499 |
+
|
| 500 |
+
    def _normalise_dates(self, series: pd.Series) -> tuple[pd.Series, Optional[str]]:
        """Parse dates in any supported format and reformat as YYYY-MM-DD.

        Returns (normalised_series, error). The error slot is always None in
        this implementation: values that cannot be parsed are left unchanged
        rather than failing the whole column.
        """
        def _parse_one(x: Any) -> Any:
            # Missing values pass through untouched.
            if _is_nan(x):
                return x
            s = str(x).strip()
            # Try the explicit formats first, in declared priority order.
            for fmt in _DATE_PARSE_FORMATS:
                try:
                    return pd.to_datetime(s, format=fmt).strftime("%Y-%m-%d")
                except Exception:
                    pass
            # Last resort: let pandas guess
            try:
                parsed = pd.to_datetime(s, dayfirst=False)
                # Sanity window guards against wild guesses (e.g. epoch-like
                # parses far outside the dataset's plausible range).
                if 2000 <= parsed.year <= 2030:
                    return parsed.strftime("%Y-%m-%d")
            except Exception:
                pass
            return x  # leave unchanged if unparseable

        return series.apply(_parse_one), None
|
| 521 |
+
|
| 522 |
+
# ── FILL_MISSING ──────────────────────────────────────────────────────────
|
| 523 |
+
|
| 524 |
+
def _apply_fill_missing(
|
| 525 |
+
self, action: CleanAction
|
| 526 |
+
) -> tuple[bool, Optional[str], bool]:
|
| 527 |
+
df = self._dirty_df
|
| 528 |
+
col = action.column
|
| 529 |
+
strategy = action.fill_strategy
|
| 530 |
+
|
| 531 |
+
if col not in df.columns:
|
| 532 |
+
return (
|
| 533 |
+
False,
|
| 534 |
+
f"Column {col!r} not found. Available: {list(df.columns)}",
|
| 535 |
+
False,
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
series = df[col].copy()
|
| 539 |
+
numeric = pd.to_numeric(series, errors="coerce")
|
| 540 |
+
has_numeric = numeric.notna().sum() > 0
|
| 541 |
+
|
| 542 |
+
if strategy == "mean":
|
| 543 |
+
if not has_numeric:
|
| 544 |
+
return False, f"Cannot compute mean for non-numeric column {col!r}.", False
|
| 545 |
+
fill_val = numeric.mean()
|
| 546 |
+
self._dirty_df[col] = numeric.fillna(round(fill_val, 2))
|
| 547 |
+
|
| 548 |
+
elif strategy == "median":
|
| 549 |
+
if not has_numeric:
|
| 550 |
+
return False, f"Cannot compute median for non-numeric column {col!r}.", False
|
| 551 |
+
fill_val = numeric.median()
|
| 552 |
+
self._dirty_df[col] = numeric.fillna(round(fill_val, 2))
|
| 553 |
+
|
| 554 |
+
elif strategy == "mode":
|
| 555 |
+
mode_result = series.mode(dropna=True)
|
| 556 |
+
if mode_result.empty:
|
| 557 |
+
return False, f"No mode found for column {col!r} (all values missing?).", False
|
| 558 |
+
self._dirty_df[col] = series.fillna(mode_result.iloc[0])
|
| 559 |
+
|
| 560 |
+
elif strategy == "drop":
|
| 561 |
+
before = len(self._dirty_df)
|
| 562 |
+
self._dirty_df = self._dirty_df.dropna(subset=[col]).reset_index(drop=True)
|
| 563 |
+
after = len(self._dirty_df)
|
| 564 |
+
return True, None, False
|
| 565 |
+
|
| 566 |
+
else:
|
| 567 |
+
return False, f"Unknown fill_strategy: {strategy!r}", False
|
| 568 |
+
|
| 569 |
+
return True, None, False
|
| 570 |
+
|
| 571 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 572 |
+
# Reward computation
|
| 573 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 574 |
+
|
| 575 |
+
    def _compute_reward(
        self,
        action: CleanAction,
        prev_score: float,
        curr_score: float,
        action_success: bool,
        was_false_positive: bool,
    ) -> float:
        """
        Dense per-step reward in the range [-0.5, +1.0].

        Components
        ----------
        progress            score delta (main learning signal)
        efficiency bonus    small reward for solving with steps to spare
        fp_penalty          penalise removing a valid-extreme row (medium task)
        early_done_penalty  penalise calling DONE with a very low score
        step_cost           tiny constant cost to discourage padding

        All components are summed, so the penalty/cost constants
        (FALSE_POSITIVE_PENALTY, EARLY_DONE_PENALTY, STEP_COST — defined
        elsewhere in this module) are presumably non-positive; confirm.
        NOTE(review): `action_success` is accepted but not used here.
        """
        if self._state is None:
            return 0.0

        max_steps = self._state.max_steps
        step_count = self._state.step_count

        # 1. Progress term
        progress = curr_score - prev_score

        # 2. Efficiency bonus (only when task is solved this step)
        # "Just solved" means the score crossed the task threshold this step.
        threshold = DONE_THRESHOLD[self._state.task_id]
        just_solved = prev_score < threshold <= curr_score
        step_fraction = step_count / max_steps
        efficiency = EFFICIENCY_BONUS_WEIGHT * (1.0 - step_fraction) if just_solved else 0.0

        # 3. False-positive penalty
        fp_penalty = FALSE_POSITIVE_PENALTY if was_false_positive else 0.0

        # 4. Early-DONE penalty
        early_done = (
            EARLY_DONE_PENALTY
            if action.command == "DONE" and curr_score < EARLY_DONE_THRESHOLD
            else 0.0
        )

        # 5. Step cost
        step_cost = STEP_COST

        # Clip the sum into the documented band, then round for stable logs.
        reward = progress + efficiency + fp_penalty + early_done + step_cost
        return round(float(np.clip(reward, -0.5, 1.0)), 4)
|
| 624 |
+
|
| 625 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 626 |
+
# Observation builder
|
| 627 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 628 |
+
|
| 629 |
+
    def _build_observation(
        self,
        reward: Optional[float],
        done: bool,
        last_action_success: bool,
        last_action_error: Optional[str],
        grader_result: GradeResult,
    ) -> CleanObservation:
        """Assemble a CleanObservation from current state plus a grade result.

        Raises RuntimeError if called before reset() populated self._state.
        """
        if self._state is None:
            raise RuntimeError("State not initialised.")

        return CleanObservation(
            # Inherited from Observation base
            done=done,
            reward=reward,
            # Task context
            task_id=self._state.task_id,
            schema_hint=self._state.schema_hint,
            initial_dirty_cells=self._state.initial_dirty_cells,
            # Per-step state (CSV is re-serialised from the live DataFrame)
            dirty_csv=self._df_to_csv(self._dirty_df),
            current_score=grader_result.score,
            issues_remaining=grader_result.issues_remaining,
            step_number=self._state.step_count,
            max_steps=self._state.max_steps,
            # Last-action feedback
            last_action_success=last_action_success,
            last_action_error=last_action_error,
        )
|
| 658 |
+
|
| 659 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 660 |
+
# Utilities
|
| 661 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 662 |
+
|
| 663 |
+
@staticmethod
|
| 664 |
+
def _df_to_csv(df: Optional[pd.DataFrame]) -> str:
|
| 665 |
+
"""Serialise DataFrame to CSV string with the integer position index."""
|
| 666 |
+
if df is None:
|
| 667 |
+
return ""
|
| 668 |
+
return df.to_csv(index=True, index_label="row_index")
|
| 669 |
+
|
| 670 |
+
@staticmethod
|
| 671 |
+
def _cast_value(
|
| 672 |
+
val: str, df: pd.DataFrame, col: str
|
| 673 |
+
) -> tuple[Any, Optional[str]]:
|
| 674 |
+
"""
|
| 675 |
+
Try to cast a string value to the appropriate type for `col`.
|
| 676 |
+
|
| 677 |
+
Returns (cast_value, error_message). error_message is None on success.
|
| 678 |
+
"""
|
| 679 |
+
# Determine target type from the clean (non-null, non-text) column values
|
| 680 |
+
sample = pd.to_numeric(
|
| 681 |
+
df[col].dropna().astype(str).str.strip(), errors="coerce"
|
| 682 |
+
)
|
| 683 |
+
majority_numeric = sample.notna().sum() / max(len(df[col].dropna()), 1) >= 0.5
|
| 684 |
+
|
| 685 |
+
if majority_numeric:
|
| 686 |
+
try:
|
| 687 |
+
float_val = float(val.strip().replace(",", ""))
|
| 688 |
+
# If all sample values are whole numbers, keep as int
|
| 689 |
+
if (sample.dropna() % 1 == 0).all() and float_val % 1 == 0:
|
| 690 |
+
return int(float_val), None
|
| 691 |
+
return round(float_val, 2), None
|
| 692 |
+
except (ValueError, AttributeError):
|
| 693 |
+
return (
|
| 694 |
+
None,
|
| 695 |
+
f"Cannot cast {val!r} to numeric for column {col!r}. "
|
| 696 |
+
f"Provide a plain number (e.g. '29.99').",
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
# String column — accept as-is (strip whitespace)
|
| 700 |
+
return val.strip(), None
|
| 701 |
+
|
| 702 |
+
    # ─────────────────────────────────────────────────────────────────────────
|
| 703 |
+
# Lifecycle
|
| 704 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 705 |
+
|
| 706 |
+
def close(self) -> None:
|
| 707 |
+
self._dirty_df = None
|
| 708 |
+
self._clean_df = None
|
| 709 |
+
self._dataset = None
|
| 710 |
+
self._state = None
|
| 711 |
+
|
| 712 |
+
def get_metadata(self) -> EnvironmentMetadata:
|
| 713 |
+
return EnvironmentMetadata(
|
| 714 |
+
name="data_cleaning_env",
|
| 715 |
+
description=(
|
| 716 |
+
"Data cleaning pipeline: the agent receives a dirty CSV "
|
| 717 |
+
"and must fix type errors, outliers, missing values, and "
|
| 718 |
+
"schema inconsistencies to match a hidden ground truth."
|
| 719 |
+
),
|
| 720 |
+
version="1.0.0",
|
| 721 |
+
author="hackathon",
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 726 |
+
# Helpers
|
| 727 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 728 |
+
|
| 729 |
+
def _is_nan(x: Any) -> bool:
    """Return True if x is any flavour of missing value."""
    if x is None:
        return True
    try:
        # bool() must stay inside the try: pd.isna on array-likes returns an
        # array whose truth value raises ValueError — treated as "not missing".
        missing = pd.isna(x)
        return bool(missing)
    except (TypeError, ValueError):
        return False
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 740 |
+
# Smoke test
|
| 741 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 742 |
+
|
| 743 |
+
if __name__ == "__main__":
    # Smoke test: exercises reset/step/state/close for every task, including
    # deliberate bad actions (unknown column, out-of-bounds row) that must
    # produce feedback rather than exceptions.
    SEP = "─" * 64

    for task_id in ("easy", "medium", "hard"):
        print(f"\n{SEP}\nTASK: {task_id.upper()}\n{SEP}")

        env = DataCleaningEnvironment()

        # ── reset ────────────────────────────────────────────────────────────
        obs = env.reset(task_id=task_id)
        print(f"reset() → score={obs.current_score:.4f} "
              f"issues={obs.issues_remaining} done={obs.done}")
        assert obs.reward is None, "reward must be None after reset"
        assert obs.done is False, "done must be False after reset"

        lines = obs.dirty_csv.strip().split("\n")
        print(f" CSV: {len(lines)} rows, {len(lines[0].split(','))} cols")
        print(f" Hint: {obs.schema_hint[:70]}…")

        # ── state() ──────────────────────────────────────────────────────────
        st = env.state
        print(f"state() → episode_id={st.episode_id[:8]}… step_count={st.step_count}")

        # ── step: bad column (should give feedback, not crash) ───────────────
        bad_action = CleanAction(
            command="SET_VALUE", row_index=0, column="DOES_NOT_EXIST", value="0"
        )
        obs2 = env.step(bad_action)
        assert obs2.last_action_success is False
        print(f"step (bad col) → success={obs2.last_action_success} "
              f"error='{obs2.last_action_error[:50]}…'")

        # ── step: out-of-bounds row ──────────────────────────────────────────
        bad_row = CleanAction(
            command="SET_VALUE", row_index=9999, column="price", value="10.0"
        )
        obs3 = env.step(bad_row)
        assert obs3.last_action_success is False
        print(f"step (bad row) → success={obs3.last_action_success} "
              f"error='{obs3.last_action_error[:50]}…'")

        # ── step: valid fix ──────────────────────────────────────────────────
        if task_id == "easy":
            # Find the first injected dirty cell and fix it
            injected = env._dataset.metadata.get("injected_cells", [])
            if injected:
                row, col = injected[0]
                clean_val = str(env._clean_df.iloc[row][col])
                fix_action = CleanAction(
                    command="SET_VALUE", row_index=row, column=col, value=clean_val
                )
                obs4 = env.step(fix_action)
                print(f"step (fix row={row} col={col!r}) → "
                      f"success={obs4.last_action_success} "
                      f"score={obs4.current_score:.4f} "
                      f"reward={obs4.reward:.4f}")
                assert obs4.last_action_success is True
                assert obs4.reward is not None

        elif task_id == "medium":
            # Fix one outlier row via FILL_MISSING on amount
            obs4 = env.step(CleanAction(
                command="FILL_MISSING", column="amount", fill_strategy="median"
            ))
            print(f"step (FILL_MISSING amount/median) → "
                  f"score={obs4.current_score:.4f} reward={obs4.reward:.4f}")

        elif task_id == "hard":
            # Standardize the date column
            obs4 = env.step(CleanAction(
                command="STANDARDIZE_COL", column="date"
            ))
            print(f"step (STANDARDIZE_COL date) → "
                  f"success={obs4.last_action_success} "
                  f"score={obs4.current_score:.4f} reward={obs4.reward:.4f}")

        # ── DONE action ───────────────────────────────────────────────────────
        done_obs = env.step(CleanAction(command="DONE"))
        assert done_obs.done is True
        print(f"step (DONE) → done={done_obs.done} "
              f"reward={done_obs.reward:.4f} score={done_obs.current_score:.4f}")

        env.close()

    print(f"\n{SEP}\nAll smoke tests passed.\n{SEP}")
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
pandas>=2.0.0
|
| 5 |
+
numpy>=2.0.0
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
validate-submission.sh
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
#
|
| 3 |
+
# validate-submission.sh — OpenEnv Submission Validator
|
| 4 |
+
#
|
| 5 |
+
# Checks that your HF Space is live, Docker image builds, and openenv validate passes.
|
| 6 |
+
#
|
| 7 |
+
# Prerequisites:
|
| 8 |
+
# - Docker: https://docs.docker.com/get-docker/
|
| 9 |
+
# - openenv-core: pip install openenv-core
|
| 10 |
+
# - curl (usually pre-installed)
|
| 11 |
+
#
|
| 12 |
+
# Run:
|
| 13 |
+
# curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
|
| 14 |
+
#
|
| 15 |
+
# Or download and run locally:
|
| 16 |
+
# chmod +x validate-submission.sh
|
| 17 |
+
# ./validate-submission.sh <ping_url> [repo_dir]
|
| 18 |
+
#
|
| 19 |
+
# Arguments:
|
| 20 |
+
# ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
|
| 21 |
+
# repo_dir Path to your repo (default: current directory)
|
| 22 |
+
#
|
| 23 |
+
# Examples:
|
| 24 |
+
# ./validate-submission.sh https://my-team.hf.space
|
| 25 |
+
# ./validate-submission.sh https://my-team.hf.space ./my-repo
|
| 26 |
+
#
|
| 27 |
+
|
| 28 |
+
set -uo pipefail
|
| 29 |
+
|
| 30 |
+
DOCKER_BUILD_TIMEOUT=600
|
| 31 |
+
if [ -t 1 ]; then
|
| 32 |
+
RED='\033[0;31m'
|
| 33 |
+
GREEN='\033[0;32m'
|
| 34 |
+
YELLOW='\033[1;33m'
|
| 35 |
+
BOLD='\033[1m'
|
| 36 |
+
NC='\033[0m'
|
| 37 |
+
else
|
| 38 |
+
RED='' GREEN='' YELLOW='' BOLD='' NC=''
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Run a command with a wall-clock timeout (first arg = seconds).
# Prefers coreutils `timeout` (or `gtimeout`, the Homebrew name on macOS);
# otherwise falls back to a background watcher that kills the command.
run_with_timeout() {
    local secs="$1"; shift
    if command -v timeout &>/dev/null; then
        timeout "$secs" "$@"
    elif command -v gtimeout &>/dev/null; then
        gtimeout "$secs" "$@"
    else
        # Portable fallback: run the command in the background and pair it
        # with a watcher subshell that kills it after $secs.
        "$@" &
        local pid=$!
        ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
        local watcher=$!
        wait "$pid" 2>/dev/null
        local rc=$?
        # Reap the watcher so it doesn't outlive this function.
        kill "$watcher" 2>/dev/null
        wait "$watcher" 2>/dev/null
        return $rc
    fi
}
|
| 59 |
+
|
| 60 |
+
# Create a temp file named with a recognisable prefix; fall back to a bare
# `mktemp` on platforms where the template form fails.
portable_mktemp() {
    local prefix="${1:-validate}"
    mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
}
|
| 64 |
+
|
| 65 |
+
CLEANUP_FILES=()
|
| 66 |
+
cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
|
| 67 |
+
trap cleanup EXIT
|
| 68 |
+
|
| 69 |
+
# --- Argument handling ---------------------------------------------------
# $1 (required): the HF Space URL to ping; $2 (optional): repo directory.
PING_URL="${1:-}"
REPO_DIR="${2:-.}"

if [ -z "$PING_URL" ]; then
    printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
    printf "\n"
    printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
    printf " repo_dir Path to your repo (default: current directory)\n"
    exit 1
fi

# Canonicalize REPO_DIR to an absolute path; this also verifies it exists.
if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
    printf "Error: directory '%s' not found\n" "${2:-.}"
    exit 1
fi

# Drop a trailing slash so endpoint paths can be appended safely.
PING_URL="${PING_URL%/}"
export PING_URL

# Running tally of checks passed.
PASS=0
|
| 87 |
+
|
| 88 |
+
# --- Logging helpers -----------------------------------------------------
# log: UTC-timestamped line; %b so color escapes in the message render.
log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
# pass/fail: report a check result; pass also bumps the PASS counter.
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
fail() { log "${RED}FAILED${NC} -- $1"; }
# hint: indented remediation tip under a failure line.
hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
# stop_at: terminal failure banner for a given step, then exit nonzero.
stop_at() {
    printf "\n"
    printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
    exit 1
}
|
| 97 |
+
|
| 98 |
+
# --- Header banner -------------------------------------------------------
printf "\n"
printf "${BOLD}========================================${NC}\n"
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
printf "${BOLD}========================================${NC}\n"
log "Repo: $REPO_DIR"
log "Ping URL: $PING_URL"
printf "\n"
|
| 105 |
+
|
| 106 |
+
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."

# POST an empty JSON body to /reset and capture the HTTP status via -w.
# On any curl failure (DNS, refused connection, timeout) normalize the
# code to "000" by REPLACING the captured value -- the previous
# `|| printf "000"` appended to curl's own "000" output, producing
# "000000", which matched neither "200" nor "000" and misrouted the
# failure to the wrong branch with a nonsense code.
# Stderr goes to /dev/null instead of clobbering the -o body file
# (both previously pointed at the same path).
CURL_OUTPUT=$(portable_mktemp "validate-curl")
CLEANUP_FILES+=("$CURL_OUTPUT")
HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
    -H "Content-Type: application/json" -d '{}' \
    "$PING_URL/reset" --max-time 30 2>/dev/null) || HTTP_CODE="000"

if [ "$HTTP_CODE" = "200" ]; then
    pass "HF Space is live and responds to /reset"
elif [ "$HTTP_CODE" = "000" ]; then
    fail "HF Space not reachable (connection failed or timed out)"
    hint "Check your network connection and that the Space is running."
    hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
    stop_at "Step 1"
else
    fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
    hint "Make sure your Space is running and the URL is correct."
    hint "Try opening $PING_URL in your browser first."
    stop_at "Step 1"
fi
|
| 127 |
+
|
| 128 |
+
log "${BOLD}Step 2/3: Running docker build${NC} ..."

# Docker must be installed before a build can even be attempted.
if ! command -v docker >/dev/null 2>&1; then
    fail "docker command not found"
    hint "Install Docker: https://docs.docker.com/get-docker/"
    stop_at "Step 2"
fi

# Accept a Dockerfile either at the repo root or under server/.
if [ -f "$REPO_DIR/Dockerfile" ]; then
    DOCKER_CONTEXT="$REPO_DIR"
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
    DOCKER_CONTEXT="$REPO_DIR/server"
else
    fail "No Dockerfile found in repo root or server/ directory"
    stop_at "Step 2"
fi

log " Found Dockerfile in $DOCKER_CONTEXT"

# Build under a hard time limit; capture combined output so the tail can
# be shown on failure.
BUILD_OK=false
BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true

if [ "$BUILD_OK" = true ]; then
    pass "Docker build succeeded"
else
    fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
    printf "%s\n" "$BUILD_OUTPUT" | tail -20
    stop_at "Step 2"
fi
|
| 157 |
+
|
| 158 |
+
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."

# The openenv CLI is required for this step.
if ! command -v openenv >/dev/null 2>&1; then
    fail "openenv command not found"
    hint "Install it: pip install openenv-core"
    stop_at "Step 3"
fi

# Run inside a subshell so the `cd` does not leak into the rest of
# the script.
VALIDATE_OK=false
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true

if [ "$VALIDATE_OK" = true ]; then
    pass "openenv validate passed"
    [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
else
    fail "openenv validate failed"
    printf "%s\n" "$VALIDATE_OUTPUT"
    stop_at "Step 3"
fi
|
| 177 |
+
|
| 178 |
+
# --- Success banner ------------------------------------------------------
# Reached only if all three steps passed (failures exit via stop_at).
printf "\n"
printf "${BOLD}========================================${NC}\n"
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
printf "${BOLD}========================================${NC}\n"
printf "\n"

exit 0
|