CodeKnightDebjit commited on
Commit
ef85ae8
·
verified ·
1 Parent(s): 9e08f1d

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +352 -251
inference.py CHANGED
@@ -1,336 +1,437 @@
1
  """
2
- Inference Script β€” Data Cleaning Environment
3
- =============================================
4
- MANDATORY environment variables:
5
- API_BASE_URL The API endpoint for the LLM.
6
- MODEL_NAME The model identifier to use for inference.
7
- HF_TOKEN Your Hugging Face / API key.
8
- LOCAL_IMAGE_NAME Docker image name (when using from_docker_image()).
9
-
10
- Defaults are set only for API_BASE_URL and MODEL_NAME (not HF_TOKEN).
11
-
12
- STDOUT FORMAT
13
- - [START] task=<task_name> env=<benchmark> model=<model_name>
14
- - [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
15
- - [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
 
 
 
 
 
 
16
  """
17
 
18
  import asyncio
19
  import json
20
  import os
21
  import re
22
- import textwrap
23
- from typing import List, Optional
24
 
25
  from openai import OpenAI
26
 
27
- from client import DataCleaningEnv
28
- from models import CleanAction
29
-
30
- # ── Environment variables ────────────────────────────────────────────────────
31
- LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "openenv-data_cleaning:latest")
32
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
33
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
34
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
35
- BENCHMARK = "data_cleaning_env"
36
-
37
- # ── Per-task config (mirrors server constants) ────────────────────────────────
38
- TASK_CONFIG = {
39
- "easy": {"max_steps": 40, "threshold": 0.95},
40
- "medium": {"max_steps": 80, "threshold": 0.85},
41
- "hard": {"max_steps": 150, "threshold": 0.80},
42
- }
43
-
44
- TEMPERATURE = 0.2 # low temp β†’ more deterministic action parsing
45
- MAX_TOKENS = 256
46
-
47
- # ── Logging helpers (strict stdout format) ───────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def log_start(task: str, env: str, model: str) -> None:
50
  print(f"[START] task={task} env={env} model={model}", flush=True)
51
 
52
 
53
- def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
 
54
  error_val = error if error else "null"
55
- done_val = str(done).lower()
56
  print(
57
- f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
 
58
  flush=True,
59
  )
60
 
61
 
62
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
 
63
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
64
  print(
65
- f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
 
66
  flush=True,
67
  )
68
 
69
- # ── Prompt builders ───────────────────────────────────────────────────────────
70
 
71
- SYSTEM_PROMPT = textwrap.dedent("""
72
- You are an expert data cleaning agent. You receive a dirty CSV dataset and must
73
- fix it step by step to match a hidden clean ground truth.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- Available commands (respond with EXACTLY one JSON object, no extra text):
 
 
76
 
77
- {"command": "SET_VALUE", "row_index": <int>, "column": "<col>", "value": "<val>"}
78
- {"command": "DROP_ROW", "row_index": <int>}
79
- {"command": "STANDARDIZE_COL", "column": "<col>"}
80
- {"command": "FILL_MISSING", "column": "<col>", "fill_strategy": "mean|median|mode|drop"}
81
- {"command": "DONE"}
 
 
 
82
 
83
- Rules:
84
- - Output ONLY the JSON object β€” no explanation, no markdown, no backticks.
85
- - Use DONE only when you are confident the score meets the task threshold.
86
- - SET_VALUE fixes a single bad cell.
87
- - STANDARDIZE_COL normalises an entire column's format.
88
- - FILL_MISSING fills NaN values in a column.
89
- - DROP_ROW removes a row; use carefully β€” false positives are penalised.
90
- - Row indices are 0-based positional indices (they shift after each DROP_ROW).
91
- """).strip()
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- def build_user_prompt(obs, history: List[str]) -> str:
95
- history_block = "\n".join(history[-15:]) if history else "None yet."
96
- return textwrap.dedent(f"""
97
- Task: {obs.task_id}
98
- Schema hint: {obs.schema_hint}
99
- Step: {obs.step_number} / {obs.max_steps}
100
- Current score: {obs.current_score:.4f}
101
- Issues remaining: {obs.issues_remaining}
102
- Initial dirty cells: {obs.initial_dirty_cells}
103
- Last action success: {obs.last_action_success}
104
- Last action error: {obs.last_action_error or 'none'}
105
-
106
- === ACTION HISTORY (most recent 15) ===
107
- {history_block}
108
-
109
- IMPORTANT RULES:
110
- - Do NOT repeat any action that already appears in the history with score_delta=0.0000.
111
- - Do NOT repeat STANDARDIZE_COL or FILL_MISSING on the same column twice.
112
- - If score is not improving after 2 steps, switch strategy entirely.
113
- - Use SET_VALUE to fix specific bad cells (wrong types, "N/A" strings, outliers, future dates).
114
- - Inspect the CSV carefully before choosing your action.
115
-
116
- Current CSV (first 80 rows shown if large):
117
- {_truncate_csv(obs.dirty_csv, max_rows=80)}
118
-
119
- Output your next CleanAction as a single JSON object.
120
- """).strip()
121
-
122
-
123
- def _truncate_csv(csv_text: str, max_rows: int = 80) -> str:
124
- lines = csv_text.splitlines()
125
- if len(lines) <= max_rows + 1: # +1 for header
126
- return csv_text
127
- header = lines[0]
128
- body = lines[1: max_rows + 1]
129
- omitted = len(lines) - 1 - max_rows
130
- return "\n".join([header] + body + [f"... ({omitted} more rows omitted)"])
131
-
132
- # ── Action parsing ────────────────────────────────────────────────────────────
133
-
134
- VALID_COMMANDS = {"SET_VALUE", "DROP_ROW", "STANDARDIZE_COL", "FILL_MISSING", "DONE"}
135
  VALID_STRATEGIES = {"mean", "median", "mode", "drop"}
136
 
137
 
138
- def parse_action(llm_output: str) -> CleanAction:
139
- """
140
- Parse the LLM's JSON output into a CleanAction.
141
- Falls back to STANDARDIZE_COL on the first column if parsing fails.
142
- """
143
- text = llm_output.strip()
144
 
145
- # Strip accidental markdown fences
146
- text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
147
- text = re.sub(r"```$", "", text).strip()
148
 
149
- # Extract first JSON object
150
- match = re.search(r"\{.*?\}", text, re.DOTALL)
151
- if not match:
152
- raise ValueError(f"No JSON object found in LLM output: {text!r}")
153
 
154
- data = json.loads(match.group())
155
- command = data.get("command", "").upper()
156
 
157
- if command not in VALID_COMMANDS:
158
- raise ValueError(f"Unknown command: {command!r}")
159
 
160
- if command == "SET_VALUE":
161
- return CleanAction(
162
- command="SET_VALUE",
163
- row_index=int(data["row_index"]),
164
- column=str(data["column"]),
165
- value=str(data["value"]),
166
- )
167
- elif command == "DROP_ROW":
168
- return CleanAction(command="DROP_ROW", row_index=int(data["row_index"]))
169
- elif command == "STANDARDIZE_COL":
170
- return CleanAction(command="STANDARDIZE_COL", column=str(data["column"]))
171
- elif command == "FILL_MISSING":
172
- strategy = str(data.get("fill_strategy", "median")).lower()
173
- if strategy not in VALID_STRATEGIES:
174
- strategy = "median"
175
- return CleanAction(
176
- command="FILL_MISSING",
177
- column=str(data["column"]),
178
- fill_strategy=strategy,
179
- )
180
- else: # DONE
181
  return CleanAction(command="DONE")
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- def _action_to_str(action: CleanAction) -> str:
185
- """Compact single-line string for [STEP] log."""
186
- parts = [action.command]
187
- if action.row_index is not None:
188
- parts.append(f"row={action.row_index}")
189
- if action.column:
190
- parts.append(f"col={action.column}")
191
- if action.value is not None:
192
- val_repr = str(action.value)[:30]
193
- parts.append(f"val={val_repr!r}")
194
- if action.fill_strategy:
195
- parts.append(f"strategy={action.fill_strategy}")
196
- return "(" + ",".join(parts) + ")"
197
-
198
- # ── LLM call ──────────────────────────────────────────────────────────────────
199
-
200
- def get_model_action(client: OpenAI, obs, history: List[str]) -> CleanAction:
201
- user_prompt = build_user_prompt(obs, history)
202
  try:
203
- completion = client.chat.completions.create(
204
- model=MODEL_NAME,
205
- messages=[
206
- {"role": "system", "content": SYSTEM_PROMPT},
207
- {"role": "user", "content": user_prompt},
208
- ],
209
- temperature=TEMPERATURE,
210
- max_tokens=MAX_TOKENS,
211
- stream=False,
212
  )
213
- text = (completion.choices[0].message.content or "").strip()
214
- return parse_action(text)
215
- except Exception as exc:
216
- print(f"[DEBUG] Model/parse error: {exc}", flush=True)
217
- return CleanAction(command="FILL_MISSING", column="quantity", fill_strategy="median")
218
-
219
- # ── Episode runner ────────────────────────────────────────────────────────────
220
-
221
- async def run_episode(env: DataCleaningEnv, client: OpenAI, task_id: str) -> dict:
222
- """
223
- Run a single episode for task_id. Returns a summary dict.
224
- """
225
- cfg = TASK_CONFIG[task_id]
226
- max_steps = cfg["max_steps"]
227
- threshold = cfg["threshold"]
228
-
229
- rewards: List[float] = []
230
- history: List[str] = [] # action history fed back to LLM each step
231
- steps_taken: int = 0
232
- score: float = 0.0
233
- prev_score: float = 0.0
234
- success: bool = False
 
 
235
 
236
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
237
 
238
  try:
239
- result = await env.reset(task_id=task_id)
240
- obs = result.observation
241
- prev_score = obs.current_score
 
 
242
 
243
  for step in range(1, max_steps + 1):
244
- if result.done:
245
  break
246
 
247
- action = get_model_action(client, obs, history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
  result = await env.step(action)
250
  obs = result.observation
251
 
252
- reward = result.reward or 0.0
253
- done = result.done
254
- error = obs.last_action_error if not obs.last_action_success else None
255
- score_delta = obs.current_score - prev_score
256
- prev_score = obs.current_score
257
 
 
258
  rewards.append(reward)
259
- steps_taken = step
260
-
261
- # Build a rich history entry the LLM can learn from
262
- action_desc = _action_to_str(action)
263
- status = "βœ“" if obs.last_action_success else "βœ—"
264
- delta_str = f"+{score_delta:.4f}" if score_delta > 0 else f"{score_delta:.4f}"
265
- history.append(
266
- f"step={step} {status} {action_desc} reward={reward:+.2f} "
267
- f"score_delta={delta_str} score={obs.current_score:.4f}"
268
- + (f" ERROR={error}" if error else "")
269
- )
270
 
271
  log_step(
272
- step=step,
273
- action=action_desc,
274
- reward=reward,
275
- done=done,
276
- error=error,
277
  )
278
 
279
- if done:
 
 
 
 
 
 
 
 
280
  break
281
 
282
- score = obs.current_score if obs else 0.0
283
  success = score >= threshold
284
 
285
- finally:
286
- score = score if score else 0.0
287
- success = success if success else False
288
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
 
 
 
 
 
 
289
 
290
- return {
291
- "task": task_id,
292
- "score": score,
293
- "reward": sum(rewards),
294
- "steps": steps_taken,
295
- "success": success,
296
- }
297
 
298
- # ── Main ──────────────────────────────────────────────────────────────────────
299
 
300
  async def main() -> None:
301
- print(f"API_BASE_URL : {API_BASE_URL}")
302
- print(f"MODEL_NAME : {MODEL_NAME}")
303
- print(f"LOCAL_IMAGE_NAME : {LOCAL_IMAGE_NAME}")
304
- print()
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
307
 
308
- if os.getenv("SPACE_ID"): # Running inside HF Space
309
- env = DataCleaningEnv(base_url="http://localhost:7860")
310
- await env.connect()
311
- else:
312
- env = await DataCleaningEnv.from_docker_image(LOCAL_IMAGE_NAME)
313
 
314
  results = []
315
  try:
316
- for task_id in ("easy", "medium", "hard"):
317
- summary = await run_episode(env, client, task_id)
318
  results.append(summary)
319
- print() # blank line between tasks
320
  finally:
321
  try:
322
  await env.close()
323
- except Exception as e:
324
- print(f"[DEBUG] env.close() error: {e}", flush=True)
325
 
326
- # ── Summary table ────────────────────────────────────────────────────────
327
- print("═" * 56)
328
- print(f"{'Task':<12} {'Score':>7} {'Reward':>7} {'Steps':>5} {'Pass'}")
329
- print("─" * 56)
330
  for r in results:
331
- flag = "YES" if r["success"] else " NO"
332
- print(f"{r['task']:<12} {r['score']:>7.4f} {r['reward']:>7.4f} {r['steps']:>5} {flag}")
333
- print("═" * 56)
 
 
 
334
 
335
 
336
  if __name__ == "__main__":
 
1
  """
2
+ inference.py
3
+ ------------
4
+ Official submission inference script for the Data Cleaning Pipeline environment.
5
+
6
+ Environment variables:
7
+ API_BASE_URL LLM endpoint. Default: HuggingFace free router.
8
+ MODEL_NAME Model to use. Default: Qwen/Qwen2.5-72B-Instruct (free).
9
+ HF_TOKEN Your HuggingFace token (hf_...).
10
+ ENV_BASE_URL The running environment URL.
11
+ Set this to your HuggingFace Space URL, e.g.:
12
+ https://CodeKnightDebjit-data-cleaning-env.hf.space
13
+
14
+ NOTE: Do NOT use LOCAL_IMAGE_NAME / from_docker_image() in submitted scripts.
15
+ The evaluator machine does not have your local Docker image β€” it connects to
16
+ your live HF Space via ENV_BASE_URL.
17
+
18
+ STDOUT FORMAT (evaluator parses these exactly):
19
+ [START] task=<n> env=<benchmark> model=<model>
20
+ [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
21
+ [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
22
  """
23
 
24
  import asyncio
25
  import json
26
  import os
27
  import re
28
+ import sys
29
+ from typing import Any, Dict, List, Optional
30
 
31
  from openai import OpenAI
32
 
33
+ try:
34
+ from client import DataCleaningEnv
35
+ from models import CleanAction, MAX_STEPS, DONE_THRESHOLD
36
+ except ImportError:
37
+ sys.path.insert(0, os.path.dirname(__file__))
38
+ from client import DataCleaningEnv
39
+ from models import CleanAction, MAX_STEPS, DONE_THRESHOLD
40
+
41
+
42
+ # ── Configuration ──────────────────────────────────────────────────────────────
43
+ # ENV_BASE_URL must point to your live HuggingFace Space.
44
+ # The evaluator sets this automatically when it runs your script.
45
+
46
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
47
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
48
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
49
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://CodeKnightDebjit-data-cleaning-env.hf.space")
50
+
51
+ BENCHMARK = "data_cleaning_env"
52
+ TASK_IDS = ["easy", "medium", "hard"]
53
+ STEP_LIMITS = {"easy": 25, "medium": 50, "hard": 80}
54
+
55
+
56
+ # ── System prompt ──────────────────────────────────────────────────────────────
57
+
58
+ SYSTEM_PROMPT = """You are a deterministic data cleaning agent.
59
+ Your task is to clean a dataset step-by-step using valid actions.
60
+ You are operating inside an environment with strict rules.
61
+ --------------------------------------------------
62
+ ## INPUT PROVIDED EACH STEP
63
+ You will receive:
64
+ 1. Column schema (LIST OF VALID COLUMN NAMES β€” CASE SENSITIVE)
65
+ 2. Column status:
66
+ - missing values count
67
+ - whether standardized (true/false)
68
+ 3. Remaining issues (global state)
69
+ 4. Previous actions taken
70
+ --------------------------------------------------
71
+ ## OBJECTIVE
72
+ Fully clean the dataset with MINIMUM steps.
73
+ A dataset is CLEAN only if:
74
+ - No missing values remain
75
+ - All columns are standardized
76
+ - No invalid formats exist
77
+ --------------------------------------------------
78
+ ## STRICT RULES (MUST FOLLOW)
79
+ ### 1. NEVER TERMINATE EARLY
80
+ You MUST NOT output DONE unless:
81
+ - ALL columns have missing = 0
82
+ - ALL columns have standardized = true
83
+ - remaining_issues is empty
84
+ If ANY issue remains β†’ DO NOT output DONE.
85
+ --------------------------------------------------
86
+ ### 2. USE ONLY VALID COLUMNS
87
+ - You MUST use EXACT column names from schema
88
+ - Column names are CASE SENSITIVE
89
+ - NEVER invent new column names
90
+ --------------------------------------------------
91
+ ### 3. PRIORITIZE COLUMN-LEVEL ACTIONS
92
+ Preferred actions:
93
+ - FILL_MISSING (fixes entire column)
94
+ - STANDARDIZE_COL (fixes formatting)
95
+ Avoid:
96
+ - SET_VALUE (only for single isolated errors)
97
+ NEVER fix a full column using repeated SET_VALUE.
98
+ --------------------------------------------------
99
+ ### 4. DO NOT REPEAT ACTIONS
100
+ - Do NOT apply the same action repeatedly on the same column
101
+ - Do NOT standardize an already standardized column
102
+ - Do NOT fill missing if missing = 0
103
+ --------------------------------------------------
104
+ ### 5. AVOID DESTRUCTIVE ACTIONS
105
+ - DROP_ROW should be used ONLY when absolutely necessary
106
+ --------------------------------------------------
107
+ ## OUTPUT FORMAT (STRICT JSON ONLY)
108
+ Return ONLY one of these β€” no explanation, no markdown:
109
+ {"action": "FILL_MISSING", "column": "<col>", "strategy": "<mean|median|mode>"}
110
+ {"action": "STANDARDIZE_COL", "column": "<col>"}
111
+ {"action": "SET_VALUE", "column": "<col>", "row": <int>, "value": "<str>"}
112
+ {"action": "DROP_ROW", "row": <int>}
113
+ {"action": "DONE"}
114
+ --------------------------------------------------
115
+ ## FAILURE CONDITIONS (AVOID THESE)
116
+ - DONE prematurely β†’ penalty -1.0
117
+ - Invalid column names β†’ action fails
118
+ - Repeated same action β†’ wasted step
119
+ --------------------------------------------------
120
+ Every step must move the dataset closer to a fully clean state."""
121
+
122
+
123
+ # ── Official log format ────────────────────────────────────────────────────────
124
 
125
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] record the evaluator parses at the top of an episode."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
127
 
128
 
129
def log_step(step: int, action: str, reward: float, done: bool,
             error: Optional[str]) -> None:
    """Emit the [STEP] record the evaluator parses after every env.step().

    The evaluator's parser is line-oriented, so every interpolated field must
    stay on one line: the action string is capped at 80 chars and embedded
    newlines in both the action and the error are flattened to spaces.
    """
    flat_action = action[:80].replace("\n", " ")
    # "null" is the literal the evaluator expects for "no error". A raw
    # multi-line error would split the record and break parsing, so flatten it.
    error_val = (error or "null").replace("\n", " ")
    print(
        f"[STEP] step={step} action={flat_action} "
        f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )
137
 
138
 
139
def log_end(success: bool, steps: int, score: float,
            rewards: List[float]) -> None:
    """Emit the [END] summary record the evaluator parses at episode end."""
    joined = ",".join(f"{r:.2f}" for r in rewards)
    line = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.2f} rewards={joined}"
    )
    print(line, flush=True)
147
 
 
148
 
149
+ # ── Prompt builder ─────────────────────────────────────────────────────────────
150
+
151
+ def _col_status_block(column_status: Dict[str, Any]) -> str:
152
+ if not column_status:
153
+ return " (not available)"
154
+ lines = []
155
+ for col, s in column_status.items():
156
+ missing = s.get("missing", 0)
157
+ standardized = s.get("standardized", True)
158
+ issues = s.get("issues", [])
159
+ flag = "OK" if (missing == 0 and standardized) else "NEEDS_FIX"
160
+ issue_str = ", ".join(issues) if issues else ""
161
+ lines.append(
162
+ f" {col:<26} missing={missing:<3} standardized={str(standardized).lower():<5}"
163
+ + (f" issues=[{issue_str}]" if issue_str else "")
164
+ + f" β†’ {flag}"
165
+ )
166
+ return "\n".join(lines)
167
+
168
+
169
def build_user_prompt(obs, history: List[str]) -> str:
    """Render the per-step user message: schema, column status, global state,
    a bounded CSV preview, and the recent action history."""
    status_map: Dict[str, Any] = getattr(obs, "column_status", {})
    schema_cols = list(status_map.keys())
    dirty_cols = [c for c, s in status_map.items()
                  if s.get("missing", 0) > 0 or not s.get("standardized", True)]

    # Header row plus the first 20 data rows, to bound prompt size.
    csv_lines = obs.dirty_csv.strip().split("\n")
    preview = "\n".join(csv_lines[:21])

    if dirty_cols:
        done_hint = f"{len(dirty_cols)} column(s) still broken → DO NOT output DONE"
    else:
        done_hint = "ALL columns clean → you MAY output DONE"

    if history:
        recent_block = "\n".join(f" {h}" for h in history[-6:])
    else:
        recent_block = " none"

    return f"""--------------------------------------------------
## COLUMN SCHEMA (EXACT CASE-SENSITIVE NAMES — USE THESE EXACTLY)
{chr(10).join(f' - {c}' for c in schema_cols)}

--------------------------------------------------
## COLUMN STATUS
{_col_status_block(status_map)}

--------------------------------------------------
## GLOBAL STATE
Task: {obs.task_id}
Step: {obs.step_number} / {obs.max_steps}
Score: {obs.current_score:.4f} (need >= {DONE_THRESHOLD[obs.task_id]:.2f})
Remaining issues: {obs.issues_remaining}
Broken columns: {dirty_cols}
DONE status: {done_hint}

--------------------------------------------------
## SCHEMA HINT
{obs.schema_hint}

--------------------------------------------------
## CSV PREVIEW (first 20 rows)
{preview}

--------------------------------------------------
## PREVIOUS ACTIONS
{recent_block}

--------------------------------------------------
Return ONLY valid JSON — no explanation, no markdown."""
218
+
219
+
220
+ # ── Action parsing ─────────────────────────────────────────────────────────────
221
+
222
+ COMMAND_MAP = {
223
+ "FILL_MISSING": "FILL_MISSING",
224
+ "STANDARDIZE_COL": "STANDARDIZE_COL",
225
+ "STANDARDIZE": "STANDARDIZE_COL",
226
+ "SET_VALUE": "SET_VALUE",
227
+ "DROP_ROW": "DROP_ROW",
228
+ "DROP": "DROP_ROW",
229
+ }
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  VALID_STRATEGIES = {"mean", "median", "mode", "drop"}
232
 
233
 
234
def parse_action(raw: str, valid_columns: List[str]) -> CleanAction:
    """Parse the model's reply into a CleanAction.

    Tolerates markdown fences and surrounding prose. Any reply that cannot be
    mapped to a valid action degrades to DONE rather than raising, so a
    malformed reply never crashes the episode.

    Args:
        raw: The raw LLM completion text.
        valid_columns: Exact (case-sensitive) column names from the schema,
            used to validate / case-correct the "column" field.

    Returns:
        A CleanAction; command "DONE" on any parse or validation failure.
    """
    text = raw.strip()

    # Strip a ``` / ```json fence if the model wrapped its reply in one.
    if text.startswith("```"):
        lines = text.split("\n")
        inner = lines[1:-1] if lines[-1].strip().startswith("```") else lines[1:]
        text = "\n".join(inner).strip()

    # Grab the first flat {...} object; the prompt forbids nested JSON.
    m = re.search(r"\{[^{}]*\}", text, re.DOTALL)
    if not m:
        return CleanAction(command="DONE")

    try:
        data: Dict[str, Any] = json.loads(m.group())
    except json.JSONDecodeError:
        return CleanAction(command="DONE")

    # Accept the current "action" key and, as a fallback, the legacy
    # "command" key from the old prompt format.
    action_raw = str(data.get("action") or data.get("command") or "DONE")
    action_raw = action_raw.strip().upper().replace(" ", "_")

    if action_raw == "DONE":
        return CleanAction(command="DONE")

    command = COMMAND_MAP.get(action_raw)
    if command is None:
        return CleanAction(command="DONE")

    # Validate column name (case-sensitive, with case-insensitive fallback).
    column = data.get("column")
    if column is not None and valid_columns:
        if column not in valid_columns:
            col_lower = {c.lower(): c for c in valid_columns}
            column = col_lower.get(str(column).lower())  # None if no match

    # "strategy" (prompt schema) or "fill_strategy" (model schema). Always
    # normalize to lowercase — previously a valid-but-cased value such as
    # "Mean" slipped through uncased — and fall back to "median" for
    # anything outside VALID_STRATEGIES.
    fill_strategy = data.get("strategy") or data.get("fill_strategy")
    if fill_strategy:
        fill_strategy = str(fill_strategy).lower()
        if fill_strategy not in VALID_STRATEGIES:
            fill_strategy = "median"

    # "row" (prompt schema) or "row_index" (model schema) → int or None.
    row_raw = data.get("row") if data.get("row") is not None else data.get("row_index")
    row_index = None
    if row_raw is not None:
        try:
            row_index = int(row_raw)
        except (TypeError, ValueError):
            pass

    value = data.get("value")

    try:
        return CleanAction(
            command=command,
            column=column,
            fill_strategy=fill_strategy,
            row_index=row_index,
            value=str(value) if value is not None else None,
        )
    except Exception:
        # Model-side validation failure — never crash the episode.
        return CleanAction(command="DONE")
292
+
293
+
294
def call_llm(client: OpenAI, messages: list) -> str:
    """Request one deterministic completion and return the stripped reply text."""
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=100,
        temperature=0.0,
    )
    reply = completion.choices[0].message.content
    return (reply or "").strip()
302
+
303
+
304
+ # ── Episode runner ─────────────────────────────────────────────────────────────
305
+
306
async def run_episode(env, client: OpenAI, task_id: str) -> dict:
    """Run one full episode for *task_id* and return a summary dict.

    Emits exactly one [START] and one [END] record even when the episode
    crashes mid-way, so the evaluator's log parser always sees a complete
    episode.
    """
    step_budget = STEP_LIMITS[task_id]
    pass_mark = DONE_THRESHOLD[task_id]
    rewards: List[float] = []
    history: List[str] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        reset_result = await env.reset(task_id=task_id)
        obs = reset_result.observation

        valid_columns: List[str] = list(getattr(obs, "column_status", {}).keys())
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        for step in range(1, step_budget + 1):
            if obs.done:
                break
            steps_taken = step

            messages.append({"role": "user", "content": build_user_prompt(obs, history)})
            try:
                raw = call_llm(client, messages)
                action = parse_action(raw, valid_columns)
                messages.append({"role": "assistant", "content": raw})
            except Exception as exc:
                log_step(step, "DONE", 0.00, True, str(exc)[:120])
                rewards.append(0.0)
                break

            # Keep system + last 10 turns inside free-tier context limit
            if len(messages) > 21:
                messages = [messages[0]] + messages[-20:]

            step_result = await env.step(action)
            obs = step_result.observation

            if getattr(obs, "column_status", {}):
                valid_columns = list(obs.column_status.keys())

            reward = step_result.reward or 0.0
            rewards.append(reward)
            score = obs.current_score

            log_step(step=step, action=action.command, reward=reward,
                     done=obs.done, error=obs.last_action_error)

            # Feed a compact summary of this step into the next prompt.
            entry = [f"step {step}: {action.command}"]
            if action.column:
                entry.append(f"col={action.column}")
            if action.fill_strategy:
                entry.append(f"strategy={action.fill_strategy}")
            entry.append(f"score={score:.4f}")
            if obs.last_action_error:
                entry.append(f"[BLOCKED: {obs.last_action_error[:60]}]")
            history.append(" ".join(entry))

            if obs.done or score >= pass_mark:
                break

        success = score >= pass_mark

    except Exception as episode_err:
        # Catch-all so [END] is always emitted even if the episode crashes
        print(f"[DEBUG] Episode error: {episode_err}", flush=True)
        log_end(success=False, steps=steps_taken, score=score, rewards=rewards)
        return {"task_id": task_id, "score": score, "reward": sum(rewards),
                "steps": steps_taken, "success": False}

    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return {"task_id": task_id, "score": score, "reward": sum(rewards),
            "steps": steps_taken, "success": success}
385
 
 
 
 
 
 
 
 
386
 
387
+ # ── Entry point ────────────────────────────────────────────────────────────────
388
 
389
async def main() -> None:
    """Validate credentials, connect to the live Space, run every task, and
    print a summary table."""
    if not HF_TOKEN:
        print(
            "ERROR: HF_TOKEN is not set.\n"
            "1. Go to https://huggingface.co/settings/tokens\n"
            "2. Create a Read token and copy it\n"
            "3. Set it: $env:HF_TOKEN='hf_xxxxxxxxxxxx' (PowerShell)\n"
            "   export HF_TOKEN='hf_xxxxxxxxxxxx' (bash)\n"
            "4. Run: python inference.py",
            file=sys.stderr,
        )
        sys.exit(1)

    print(f"API_BASE_URL : {API_BASE_URL}", flush=True)
    print(f"MODEL_NAME : {MODEL_NAME}", flush=True)
    print(f"ENV_BASE_URL : {ENV_BASE_URL}", flush=True)
    print("", flush=True)

    chat_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    # Always connect via URL — no Docker on the evaluator machine
    env = DataCleaningEnv(base_url=ENV_BASE_URL)
    await env.connect()

    results = []
    try:
        for task_id in TASK_IDS:
            results.append(await run_episode(env, chat_client, task_id))
            print("", flush=True)
    finally:
        # Best-effort shutdown; a close failure must not mask episode results.
        try:
            await env.close()
        except Exception:
            pass

    # Summary table.
    print("=" * 56, flush=True)
    print(f"{'Task':<10} {'Score':>7} {'Reward':>9} {'Steps':>6} {'Pass':>5}")
    print("-" * 56, flush=True)
    for row in results:
        print(
            f"{row['task_id']:<10} {row['score']:>7.4f} {row['reward']:>9.4f} "
            f"{row['steps']:>6} {'YES' if row['success'] else 'NO':>4}",
            flush=True,
        )
    print("=" * 56, flush=True)
435
 
436
 
437
  if __name__ == "__main__":