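# NOTE(review): the next three lines look like hosting-page residue, not source; kept as comments.
# Dishaaa25's picture
# Upload folder using huggingface_hub
# c22bf49 verified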
from __future__ import annotations
import argparse
from typing import Any, Literal
import uvicorn
from fastapi import Body, FastAPI
from fastapi.responses import RedirectResponse, Response
from pydantic import BaseModel
from models import Action, Observation
from .environment import DataCleaningEnv
# Identifier reported as the environment's name in metadata payloads.
ENV_NAME = "data_cleaning_env"

# Task names exposed by the API; ResetRequest.task_name accepts exactly these values.
TASKS = [
    "basic_cleaning",
    "moderate_cleaning",
    "full_pipeline",
]

# Human-readable summary reported in metadata payloads.
ENV_DESCRIPTION = (
    "RL environment for interactive tabular data cleaning and preparation. "
    "Agents must fix missing values, duplicates, dtype issues, category inconsistencies, "
    "and derived-feature requirements."
)
# ASGI application object; every route below is registered on it and main() serves it.
app = FastAPI(title="Data Cleaning OpenEnv", version="1.0.0")
# Single module-level environment instance shared by the /reset, /step and /state handlers.
ENV = DataCleaningEnv()
class ResetRequest(BaseModel):
    """Request body accepted by the ``/reset`` endpoint."""

    # Task to load on reset; the allowed literals mirror the module-level TASKS list.
    task_name: Literal[
        "basic_cleaning",
        "moderate_cleaning",
        "full_pipeline",
    ] = "basic_cleaning"
def _metadata() -> dict[str, Any]:
    """Build a fresh metadata payload describing this environment.

    Returns:
        A new dict on every call containing the environment name, description,
        version, supported task names and mode. Callers may mutate the result
        (``root`` adds a ``"status"`` key) without affecting module state.
    """
    return {
        "name": ENV_NAME,
        "description": ENV_DESCRIPTION,
        # Read the version from the app so the two values cannot drift apart.
        "version": app.version,
        # Copy the list so mutating the payload never alters the TASKS constant.
        "tasks": list(TASKS),
        "mode": "simulation",
    }
@app.get("/")
def root() -> dict[str, Any]:
    """Return the environment metadata with an added ``"status": "ok"`` entry."""
    return {**_metadata(), "status": "ok"}
@app.get("/web", include_in_schema=False)
def web_root() -> RedirectResponse:
    """Redirect ``/web`` to the root endpoint with a 307 (temporary) redirect."""
    redirect = RedirectResponse(status_code=307, url="/")
    return redirect
@app.get("/web/", include_in_schema=False)
def web_root_slash() -> RedirectResponse:
    """Redirect ``/web/`` (trailing slash) to the root endpoint with a 307 redirect."""
    redirect = RedirectResponse(status_code=307, url="/")
    return redirect
@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
    """Answer favicon requests with an empty 204 (No Content) response."""
    empty_reply = Response(status_code=204)
    return empty_reply
@app.get("/health")
def health() -> dict[str, str]:
    """Return a constant liveness payload: ``{"status": "healthy"}``."""
    return dict(status="healthy")
@app.get("/metadata")
def metadata() -> dict[str, Any]:
    """Return the environment metadata payload built by ``_metadata``."""
    info = _metadata()
    return info
@app.get("/tasks")
def list_tasks() -> dict[str, list[str]]:
    """Return the supported task names under the ``"tasks"`` key."""
    return dict(tasks=TASKS)
@app.get("/schema")
def schema() -> dict[str, Any]:
    """Return the JSON schemas for actions, observations and state.

    The ``"state"`` entry reuses the observation schema object, so both keys
    describe the same ``Observation`` model.
    """
    obs_schema = Observation.model_json_schema()
    action_schema = Action.model_json_schema()
    return dict(action=action_schema, observation=obs_schema, state=obs_schema)
@app.post("/mcp")
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Reject any MCP call with a JSON-RPC 2.0 error envelope.

    Args:
        payload: Parsed JSON body; only its ``"id"`` entry is read and echoed back.

    Returns:
        A JSON-RPC 2.0 response whose error code is -32601 ("Method not found"
        in the JSON-RPC specification) and whose ``"id"`` is ``None`` when the
        request supplied none.
    """
    request_id = payload.get("id")
    not_implemented = {
        "code": -32601,
        "message": "MCP methods are not implemented for this benchmark.",
    }
    return {"jsonrpc": "2.0", "id": request_id, "error": not_implemented}
@app.post("/reset")
def reset(request: ResetRequest | None = None) -> dict[str, Any]:
    """Select a task on the shared environment, reset it, and return the observation.

    Args:
        request: Optional body naming the task; when absent a default
            ``ResetRequest`` (task ``"basic_cleaning"``) is used.

    Returns:
        The observation produced by ``ENV.reset()``, dumped to a plain dict.
    """
    chosen = request if request else ResetRequest()
    # The task must be assigned before reset() is called.
    ENV.task_name = chosen.task_name
    return ENV.reset().model_dump()
@app.post("/step")
def step(action: Action) -> dict[str, Any]:
    """Apply one action to the shared environment and report the outcome.

    Args:
        action: The validated action parsed from the request body.

    Returns:
        A dict with the dumped observation plus the ``reward``, ``done`` and
        ``info`` values returned by ``ENV.step``.
    """
    obs, reward_value, is_done, extra = ENV.step(action)
    response: dict[str, Any] = {"observation": obs.model_dump()}
    response["reward"] = reward_value
    response["done"] = is_done
    response["info"] = extra
    return response
@app.get("/state")
def state() -> dict[str, Any]:
    """Return the shared environment's current state as a plain dict.

    When ``ENV.dataset`` is falsy the environment is reset first, so a state
    is always available to dump.
    """
    if not ENV.dataset:
        ENV.reset()
    current = ENV.state()
    return current.model_dump()
def main(host: str | None = None, port: int | None = None) -> None:
    """Serve the FastAPI app with uvicorn.

    Args:
        host: Interface to bind. When ``None`` it is taken from ``--host``
            (default ``"0.0.0.0"``).
        port: Port to bind. When ``None`` it is taken from ``--port``
            (default ``7860``).

    The command line is parsed only when at least one argument is missing;
    explicitly passed values always win over the command line.
    """
    if host is None or port is None:
        cli = argparse.ArgumentParser()
        cli.add_argument("--host", default="0.0.0.0")
        cli.add_argument("--port", type=int, default=7860)
        parsed = cli.parse_args()
        if host is None:
            host = parsed.host
        if port is None:
            port = parsed.port
    uvicorn.run(app, host=host, port=port)
# Start the server only when this module is executed as a script.
if __name__ == "__main__":
    main()