CodeKnightDebjit commited on
Commit
73f852c
Β·
verified Β·
1 Parent(s): eee232c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +298 -221
inference.py CHANGED
@@ -1,17 +1,26 @@
1
  """
2
  inference.py
3
  ------------
4
- Official submission inference script for the Data Cleaning Pipeline environment.
5
-
6
- Environment variables (all free β€” no paid API):
7
- API_BASE_URL LLM endpoint. Default: HuggingFace free router.
8
- MODEL_NAME Model to use. Default: Qwen/Qwen2.5-72B-Instruct (free).
9
- HF_TOKEN Your free HuggingFace token (hf_...).
10
- LOCAL_IMAGE_NAME Docker image name if using from_docker_image() β€” leave
11
- unset to use ENV_BASE_URL instead.
12
- ENV_BASE_URL Server URL. Default: http://localhost:8000
13
-
14
- STDOUT FORMAT (evaluator parses these lines β€” do not modify):
 
 
 
 
 
 
 
 
 
15
  [START] task=<n> env=<benchmark> model=<model>
16
  [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
17
  [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
@@ -35,277 +44,334 @@ except ImportError:
35
  from models import CleanAction, MAX_STEPS, DONE_THRESHOLD
36
 
37
 
38
- # ── Configuration β€” all defaults are FREE ─────────────────────────────────────
39
 
40
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
41
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
42
- HF_TOKEN = os.getenv("HF_TOKEN", "")
43
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "")
44
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
45
 
46
- BENCHMARK = "data_cleaning_env"
47
- TASK_IDS = ["easy", "medium", "hard"]
48
- STEP_LIMITS = {"easy": 25, "medium": 50, "hard": 80}
49
 
50
 
51
- # ── System prompt (expert data cleaning agent) ────────────────────────────────
52
 
53
- SYSTEM_PROMPT = """You are an expert data cleaning agent operating in a structured environment.
54
- Your goal is to transform the dataset into a fully clean state using the minimum number of steps.
55
- You are given:
56
- 1. Column schema (with data types)
 
 
 
57
  2. Column status:
58
  - missing values count
59
- - whether the column is standardized
60
- 3. Remaining issues (global view)
61
- 4. Previous actions
62
-
63
- ---
64
- ## STRICT RULES
65
-
66
- ### 1. DO NOT terminate early
67
- You MUST NOT output DONE unless ALL of the following are true:
68
- - No missing values remain in any column
69
  - All columns are standardized
70
- - No formatting issues remain
71
- - No invalid values exist
72
- If ANY issue remains β†’ continue acting.
73
-
74
- ---
75
- ### 2. Prioritize column-level fixes
76
- Prefer:
77
- - FILL_MISSING (for missing values)
78
- - STANDARDIZE_COL (for formatting / normalization)
79
- Avoid:
80
- - SET_VALUE (only use for isolated anomalies)
81
- NEVER fix an entire column using repeated SET_VALUE.
82
-
83
- ---
84
- ### 3. Use correct strategies based on column type
85
- - Numeric columns β†’ mean or median
86
- - Categorical columns β†’ mode
87
- - Datetime columns β†’ STANDARDIZE_COL (not SET_VALUE unless single anomaly)
88
-
89
- ---
90
- ### 4. Do not repeat work
91
- - Do NOT standardize a column more than once unless state changed
 
 
 
 
 
92
  - Do NOT fill missing if missing = 0
93
 
94
- ---
95
- ### 5. Always reason about global completion
96
- Before choosing DONE, check:
97
- - column_status
98
- - remaining_issues
99
- If any column has:
100
- - missing > 0
101
- - standardized = false
102
- β†’ DO NOT choose DONE
103
-
104
- ---
105
  ## DECISION PROCESS (MANDATORY)
106
  At each step:
107
- 1. Identify remaining issues
108
- 2. Select the MOST impactful action
109
- 3. Prefer actions that resolve entire columns
110
- 4. Avoid redundant or low-value actions
 
 
 
111
 
112
- ---
113
- ## OUTPUT FORMAT
114
- Return ONLY a valid JSON action β€” no explanation, no markdown fences:
115
 
116
- For column-level fixes:
117
- {"action": "FILL_MISSING", "column": "<col>", "strategy": "<mean|median|mode>"}
118
- {"action": "STANDARDIZE_COL", "column": "<col>"}
119
 
120
- For isolated cell fixes:
121
- {"action": "SET_VALUE", "column": "<col>", "row": <int>, "value": "<str>"}
122
 
123
- For outlier rows:
124
  {"action": "DROP_ROW", "row": <int>}
125
 
126
- When everything is clean:
127
  {"action": "DONE"}
128
 
129
- ---
130
- ## OBJECTIVE
131
- Minimize: number of steps, redundant operations, row-level edits.
132
- Maximize: completeness, correctness, efficiency.
133
-
134
- You will be penalized for: premature DONE, repeated actions, unnecessary SET_VALUE usage.
135
-
136
- Think step-by-step internally, but ONLY output the final JSON action."""
 
 
 
 
137
 
138
 
139
- # ── Official log format ────────────────────────────────────────────────────────
140
 
141
  def log_start(task: str, env: str, model: str) -> None:
142
  print(f"[START] task={task} env={env} model={model}", flush=True)
143
 
144
 
145
- def log_step(step: int, action: str, reward: float, done: bool,
146
- error: Optional[str]) -> None:
147
- error_val = error if error else "null"
148
  print(
149
- f"[STEP] step={step} action={action[:80].replace(chr(10),' ')} "
150
- f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
 
151
  flush=True,
152
  )
153
 
154
 
155
  def log_end(success: bool, steps: int, score: float,
156
  rewards: List[float]) -> None:
157
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
158
  print(
159
  f"[END] success={str(success).lower()} steps={steps} "
160
- f"score={score:.2f} rewards={rewards_str}",
161
  flush=True,
162
  )
163
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  # ── Prompt builder ────────────────────────────────────────────────────────────
166
 
167
- def _format_column_status(column_status: Dict[str, Any]) -> str:
168
- """Render column_status as a compact, agent-readable block."""
169
- if not column_status:
170
- return " (not available)"
171
- lines = []
172
- for col, status in column_status.items():
173
- missing = status.get("missing", 0)
174
- standardized = status.get("standardized", True)
175
- issues = status.get("issues", [])
176
- flag = "OK" if missing == 0 and standardized else "NEEDS_FIX"
177
- issue_str = ", ".join(issues) if issues else ""
178
- lines.append(
179
- f" {col:<22} missing={missing} standardized={str(standardized).lower()}"
180
- + (f" issues=[{issue_str}]" if issue_str else "")
181
- + f" β†’ {flag}"
182
- )
183
- return "\n".join(lines)
 
 
 
 
 
 
 
 
184
 
185
 
186
  def build_user_prompt(obs, history: List[str]) -> str:
187
  rows = obs.dirty_csv.strip().split("\n")
188
  header = rows[0] if rows else ""
189
- preview = "\n".join(rows[:25])
190
- truncated = len(rows) > 25
191
-
192
- col_status_block = _format_column_status(
193
- getattr(obs, "column_status", {})
194
- )
195
-
196
- history_block = (
197
- "\n".join(f" {h}" for h in history[-6:]) if history else " (none yet)"
198
- )
199
 
200
- # Count truly broken columns
201
- col_status = getattr(obs, "column_status", {})
202
  broken = [
203
  c for c, s in col_status.items()
204
  if s.get("missing", 0) > 0 or not s.get("standardized", True)
205
  ]
206
 
207
- return f"""## Current State
208
- Task: {obs.task_id}
209
- Step: {obs.step_number}/{obs.max_steps}
210
- Score: {obs.current_score:.4f} (need {DONE_THRESHOLD[obs.task_id]:.2f} for success)
211
- Issues remaining: {obs.issues_remaining}
212
- Broken columns: {len(broken)} β†’ {broken[:8]}
213
-
214
- ## Schema hint
215
- {obs.schema_hint}
216
-
217
- ## Column status
218
- {col_status_block}
219
-
220
- ## CSV columns
221
- {header}
222
 
223
- ## CSV preview{' (first 25 rows)' if truncated else ''}
224
- {preview}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- ## Previous actions (last 6)
227
- {history_block}
228
 
229
- ## Your task
230
- Select the single most impactful action to bring broken columns to clean state.
231
- Check column_status β€” if all columns show missing=0 and standardized=true β†’ output DONE.
232
- Otherwise β†’ pick the highest-impact fix.
233
- Return ONLY valid JSON, no markdown."""
234
 
 
 
 
 
 
 
 
 
 
235
 
236
- # ── Action parsing ─────────────────────────────────────────────────────────────
237
- # The system prompt uses {action, column, strategy, row, value}.
238
- # CleanAction uses {command, column, fill_strategy, row_index, value}.
239
- # This function bridges the two.
240
 
241
  def parse_action(raw: str) -> CleanAction:
242
  text = raw.strip()
243
 
244
- # Strip markdown fences if model wraps output
245
  if text.startswith("```"):
246
  lines = text.split("\n")
247
  inner = lines[1:-1] if lines[-1].strip().startswith("```") else lines[1:]
248
  text = "\n".join(inner).strip()
249
 
250
- # Extract first {...} block
251
- match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
252
- if not match:
253
  return CleanAction(command="DONE")
254
 
255
  try:
256
- data: Dict[str, Any] = json.loads(match.group())
257
  except json.JSONDecodeError:
258
  return CleanAction(command="DONE")
259
 
260
- # ── Field mapping: prompt format β†’ CleanAction format ─────────────────
261
- action_name: str = str(data.get("action", "DONE")).upper().replace(" ", "_")
262
-
263
- if action_name == "DONE":
264
  return CleanAction(command="DONE")
265
-
266
- # Normalise command name (prompt may say FILL_MISSING, STANDARDIZE_COL, etc.)
267
- command_map = {
268
- "FILL_MISSING": "FILL_MISSING",
269
- "STANDARDIZE_COL": "STANDARDIZE_COL",
270
- "STANDARDIZE": "STANDARDIZE_COL",
271
- "SET_VALUE": "SET_VALUE",
272
- "DROP_ROW": "DROP_ROW",
273
- "DROP": "DROP_ROW",
274
- }
275
- command = command_map.get(action_name)
276
- if command is None:
277
  return CleanAction(command="DONE")
278
 
279
  column = data.get("column")
280
- # "strategy" in prompt β†’ "fill_strategy" in CleanAction
281
  fill_strategy = data.get("strategy") or data.get("fill_strategy")
282
- # "row" in prompt β†’ "row_index" in CleanAction
283
- row_index = data.get("row") if data.get("row") is not None else data.get("row_index")
284
  value = data.get("value")
285
 
286
  try:
287
  return CleanAction(
288
- command=command,
289
- column=column,
290
- fill_strategy=fill_strategy,
291
- row_index=int(row_index) if row_index is not None else None,
292
- value=str(value) if value is not None else None,
293
  )
294
  except Exception:
295
  return CleanAction(command="DONE")
296
 
297
 
298
- def call_llm(client: OpenAI, messages: list) -> str:
299
- response = client.chat.completions.create(
300
- model=MODEL_NAME,
301
- messages=messages,
302
- max_tokens=150,
303
- temperature=0.0, # deterministic β€” the prompt is already very directive
 
 
 
 
 
 
304
  )
305
  return (response.choices[0].message.content or "").strip()
306
 
307
 
308
- # ── Episode runner ─────────────────────────────────────────────────────────────
309
 
310
  async def run_episode(env, client: OpenAI, task_id: str) -> dict:
311
  max_steps = STEP_LIMITS[task_id]
@@ -328,11 +394,10 @@ async def run_episode(env, client: OpenAI, task_id: str) -> dict:
328
  break
329
 
330
  steps_taken = step
331
- user_msg = build_user_prompt(obs, history)
332
- messages.append({"role": "user", "content": user_msg})
333
 
334
  try:
335
- raw = call_llm(client, messages)
336
  action = parse_action(raw)
337
  messages.append({"role": "assistant", "content": raw})
338
  except Exception as exc:
@@ -340,9 +405,9 @@ async def run_episode(env, client: OpenAI, task_id: str) -> dict:
340
  rewards.append(0.0)
341
  break
342
 
343
- # Keep system + last 10 turns (5 user + 5 assistant) inside context
344
- if len(messages) > 21:
345
- messages = [messages[0]] + messages[-20:]
346
 
347
  result = await env.step(action)
348
  obs = result.observation
@@ -358,14 +423,13 @@ async def run_episode(env, client: OpenAI, task_id: str) -> dict:
358
  error = obs.last_action_error,
359
  )
360
 
361
- # Track history for agent context
362
- err_note = f" [BLOCKED: {obs.last_action_error[:50]}]" \
363
- if obs.last_action_error else ""
364
  history.append(
365
  f"step {step}: {action.command}"
366
- + (f" col={action.column}" if action.column else "")
367
- + (f" strategy={action.fill_strategy}" if action.fill_strategy else "")
368
- + f" β†’ score={score:.4f}{err_note}"
 
369
  )
370
 
371
  if obs.done or score >= threshold:
@@ -373,48 +437,61 @@ async def run_episode(env, client: OpenAI, task_id: str) -> dict:
373
 
374
  success = score >= threshold
375
 
 
 
 
376
  finally:
377
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
378
 
379
- return {"task_id": task_id, "score": score,
380
- "reward": sum(rewards), "steps": steps_taken, "success": success}
 
 
 
 
 
381
 
382
 
383
  # ── Entry point ────────────────────────────────────────────────────────────────
384
 
385
  async def main() -> None:
386
- if not HF_TOKEN:
 
 
387
  print(
388
- "ERROR: HF_TOKEN is not set.\n"
389
- "1. Go to https://huggingface.co/settings/tokens\n"
390
- "2. Click 'New token' β†’ 'Read' access β†’ copy the hf_... token\n"
391
- "3. In PowerShell: $env:HF_TOKEN='hf_xxxxxxxxxxxx'\n"
392
- "4. Run: python inference.py",
393
  file=sys.stderr,
394
  )
395
  sys.exit(1)
396
 
397
  print(f"API_BASE_URL : {API_BASE_URL}", flush=True)
398
  print(f"MODEL_NAME : {MODEL_NAME}", flush=True)
399
- print(f"TARGET : {LOCAL_IMAGE_NAME or ENV_BASE_URL}", flush=True)
 
400
  print("", flush=True)
401
 
402
  llm = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
403
 
404
- if LOCAL_IMAGE_NAME:
405
- env = await DataCleaningEnv.from_docker_image(LOCAL_IMAGE_NAME)
406
- else:
407
- env = DataCleaningEnv(base_url=ENV_BASE_URL)
408
- await env.connect()
409
-
410
  results = []
411
- try:
412
- for task_id in TASK_IDS:
 
 
 
 
 
 
 
413
  summary = await run_episode(env, llm, task_id)
414
  results.append(summary)
415
- print("", flush=True)
416
- finally:
417
- await env.close()
 
 
 
418
 
419
  print("=" * 56, flush=True)
420
  print(f"{'Task':<10} {'Score':>7} {'Reward':>9} {'Steps':>6} {'Pass':>5}")
 
1
  """
2
  inference.py
3
  ------------
4
+ Data Cleaning Pipeline β€” submission inference script.
5
+
6
+ Supports:
7
+ β€’ Ollama local llama3 (DEFAULT β€” no API key needed)
8
+ β€’ Groq free cloud API
9
+ β€’ Any OpenAI-compatible endpoint
10
+
11
+ Environment variables:
12
+ API_BASE_URL LLM endpoint. Default: http://localhost:11434/v1 (Ollama)
13
+ MODEL_NAME Model name. Default: llama3
14
+ HF_TOKEN API key. Default: "ollama" (ignored by Ollama)
15
+ LOCAL_IMAGE_NAME Docker image (leave unset to use ENV_BASE_URL)
16
+ ENV_BASE_URL Env server URL. Default: http://localhost:8000
17
+
18
+ To switch to Groq instead of Ollama:
19
+ $env:API_BASE_URL = "https://api.groq.com/openai/v1"
20
+ $env:MODEL_NAME = "llama-3.3-70b-versatile"
21
+ $env:HF_TOKEN = "gsk_xxxxxxxxxxxx"
22
+
23
+ STDOUT FORMAT (evaluator parses exactly β€” do not modify):
24
  [START] task=<n> env=<benchmark> model=<model>
25
  [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
26
  [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
 
44
  from models import CleanAction, MAX_STEPS, DONE_THRESHOLD
45
 
46
 
47
+ # ── Configuration ─────────────────────────────────────────────────────────────
48
 
49
+ API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:11434/v1")
50
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama3")
51
+ HF_TOKEN = os.getenv("HF_TOKEN", "ollama")
52
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "")
53
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
54
 
55
+ BENCHMARK = "data_cleaning_env"
56
+ TASK_IDS = ["easy", "medium", "hard"]
57
+ STEP_LIMITS = {"easy": 40, "medium": 100, "hard": 150}
58
 
59
 
60
+ # ── System prompt (deterministic agent) ──────────────────────────────────────
61
 
62
+ SYSTEM_PROMPT = """You are a deterministic data cleaning agent.
63
+ Your task is to clean a dataset step-by-step using valid actions.
64
+ You are operating inside an environment with strict rules.
65
+ --------------------------------------------------
66
+ ## INPUT PROVIDED EACH STEP
67
+ You will receive:
68
+ 1. Column schema (LIST OF VALID COLUMN NAMES - CASE SENSITIVE)
69
  2. Column status:
70
  - missing values count
71
+ - whether standardized (true/false)
72
+ 3. Remaining issues (global state)
73
+ 4. Previous actions taken
74
+ --------------------------------------------------
75
+ ## OBJECTIVE
76
+ Fully clean the dataset with MINIMUM steps.
77
+ A dataset is CLEAN only if:
78
+ - No missing values remain
 
 
79
  - All columns are standardized
80
+ - No invalid formats exist
81
+ --------------------------------------------------
82
+ ## STRICT RULES (MUST FOLLOW)
83
+
84
+ ### 1. NEVER TERMINATE EARLY
85
+ You MUST NOT output DONE unless:
86
+ - ALL columns have missing = 0
87
+ - ALL columns have standardized = true
88
+ - remaining_issues is empty
89
+ If ANY issue remains -> DO NOT output DONE.
90
+
91
+ ### 2. USE ONLY VALID COLUMNS
92
+ - You MUST use EXACT column names from the schema list
93
+ - Column names are CASE SENSITIVE
94
+ - NEVER invent new column names
95
+
96
+ ### 3. PRIORITIZE COLUMN-LEVEL ACTIONS
97
+ Preferred actions (in order):
98
+ 1. FILL_MISSING - fixes entire column missing values
99
+ 2. STANDARDIZE_COL - fixes formatting for entire column
100
+ 3. SET_VALUE - only for a single isolated bad cell
101
+ 4. DROP_ROW - only for truly corrupt/outlier rows
102
+ NEVER fix a full column using repeated SET_VALUE.
103
+
104
+ ### 4. DO NOT REPEAT ACTIONS
105
+ - Do NOT apply the same action to the same column twice
106
+ - Do NOT standardize an already standardized column
107
  - Do NOT fill missing if missing = 0
108
 
109
+ ### 5. CHOOSE THE CORRECT FILL STRATEGY
110
+ - Numeric columns (float/int): use "median" or "mean"
111
+ - Categorical/string columns: use "mode"
112
+ - NEVER use "mean" or "median" on a categorical column
113
+
114
+ ### 6. ALWAYS THINK GLOBALLY
115
+ Before choosing an action:
116
+ - Review ALL columns in column_status
117
+ - Pick the single action that fixes the largest remaining issue
118
+ --------------------------------------------------
 
119
  ## DECISION PROCESS (MANDATORY)
120
  At each step:
121
+ 1. Read column_status carefully
122
+ 2. Find columns where missing > 0 OR standardized = false
123
+ 3. If none exist AND remaining_issues is empty -> output DONE
124
+ 4. Otherwise, pick the ONE most impactful action
125
+ --------------------------------------------------
126
+ ## OUTPUT FORMAT - STRICT JSON ONLY
127
+ Return ONLY a single JSON object. No explanation. No markdown. No backticks.
128
 
129
+ Fill missing values:
130
+ {"action": "FILL_MISSING", "column": "<exact_col_name>", "strategy": "<mean|median|mode>"}
 
131
 
132
+ Standardize a column:
133
+ {"action": "STANDARDIZE_COL", "column": "<exact_col_name>"}
 
134
 
135
+ Fix one cell:
136
+ {"action": "SET_VALUE", "column": "<exact_col_name>", "row": <int>, "value": "<str>"}
137
 
138
+ Drop a bad row:
139
  {"action": "DROP_ROW", "row": <int>}
140
 
141
+ Signal completion:
142
  {"action": "DONE"}
143
 
144
+ --------------------------------------------------
145
+ ## FAILURE CONDITIONS (YOU WILL BE PENALIZED FOR):
146
+ - Outputting DONE when issues remain
147
+ - Using a column name not in the schema
148
+ - Repeating the same action on the same column
149
+ - Using SET_VALUE to fix an entire column
150
+ - Using mean/median on a categorical column
151
+ - Using mode on a numeric column
152
+ --------------------------------------------------
153
+ ## FINAL GOAL
154
+ Be efficient, precise, and minimal.
155
+ Every step must move the dataset closer to a fully clean state."""
156
 
157
 
158
+ # ── Official log helpers ──────────────────────────────────────────────────────
159
 
160
  def log_start(task: str, env: str, model: str) -> None:
161
  print(f"[START] task={task} env={env} model={model}", flush=True)
162
 
163
 
164
+ def log_step(step: int, action: str, reward: float,
165
+ done: bool, error: Optional[str]) -> None:
 
166
  print(
167
+ f"[STEP] step={step} action={action[:80].replace(chr(10), ' ')} "
168
+ f"reward={reward:.2f} done={str(done).lower()} "
169
+ f"error={error if error else 'null'}",
170
  flush=True,
171
  )
172
 
173
 
174
  def log_end(success: bool, steps: int, score: float,
175
  rewards: List[float]) -> None:
 
176
  print(
177
  f"[END] success={str(success).lower()} steps={steps} "
178
+ f"score={score:.4f} rewards={','.join(f'{r:.2f}' for r in rewards)}",
179
  flush=True,
180
  )
181
 
182
 
183
+ # ── Column type hints (used to suggest fill strategies) ──────────────────────
184
+
185
+ _COL_TYPES: Dict[str, Dict[str, str]] = {
186
+ "easy": {
187
+ "order_id": "numeric",
188
+ "customer": "categorical",
189
+ "product": "categorical",
190
+ "category": "categorical",
191
+ "price": "numeric",
192
+ "quantity": "numeric",
193
+ "order_date": "datetime",
194
+ "region": "categorical",
195
+ },
196
+ "medium": {
197
+ "tx_id": "numeric",
198
+ "customer_id": "numeric",
199
+ "amount": "numeric",
200
+ "tx_date": "datetime",
201
+ "category": "categorical",
202
+ "country": "categorical",
203
+ "status": "categorical",
204
+ },
205
+ "hard": {
206
+ "record_id": "numeric", "id": "numeric", "RecordID": "numeric",
207
+ "customer_id": "numeric", "cust_id": "numeric", "CustomerID": "numeric",
208
+ "full_name": "categorical","name": "categorical","CustomerName":"categorical",
209
+ "email": "categorical","email_address": "categorical","Email": "categorical",
210
+ "amount": "numeric", "sale_amount": "numeric", "Amount": "numeric",
211
+ "currency": "categorical","ccy": "categorical","Currency": "categorical",
212
+ "purchase_date": "datetime", "date": "datetime", "PurchaseDate":"datetime",
213
+ "product_name": "categorical","item": "categorical","ProductName": "categorical",
214
+ "region": "categorical","territory": "categorical","area": "categorical",
215
+ "contact_email": "categorical","value": "numeric", "product": "categorical",
216
+ },
217
+ }
218
+
219
+
220
+ def _strategy_hint(task_id: str, col: str) -> str:
221
+ col_type = _COL_TYPES.get(task_id, {}).get(col, "unknown")
222
+ if col_type == "numeric":
223
+ return "median"
224
+ if col_type in ("categorical", "datetime"):
225
+ return "mode"
226
+ return "median"
227
+
228
+
229
  # ── Prompt builder ────────────────────────────────────────────────────────────
230
 
231
+ def _column_status_block(obs, task_id: str) -> str:
232
+ col_status: Dict[str, Any] = getattr(obs, "column_status", {}) or {}
233
+
234
+ if col_status:
235
+ lines = []
236
+ for col, status in col_status.items():
237
+ missing = status.get("missing", 0)
238
+ standardized = status.get("standardized", True)
239
+ hint = _strategy_hint(task_id, col)
240
+ flag = "OK" if (missing == 0 and standardized) else "NEEDS_FIX"
241
+ lines.append(
242
+ f" {col:<22} missing={missing:<4} "
243
+ f"standardized={str(standardized).lower():<5} "
244
+ f"fill_strategy={hint:<7} [{flag}]"
245
+ )
246
+ return "\n".join(lines)
247
+
248
+ # Fallback: derive columns from CSV header
249
+ rows = obs.dirty_csv.strip().split("\n")
250
+ header = rows[0] if rows else ""
251
+ cols = [c.strip() for c in header.split(",")]
252
+ return "\n".join(
253
+ f" {col:<22} (status unknown) fill_strategy={_strategy_hint(task_id, col)}"
254
+ for col in cols
255
+ )
256
 
257
 
258
  def build_user_prompt(obs, history: List[str]) -> str:
259
  rows = obs.dirty_csv.strip().split("\n")
260
  header = rows[0] if rows else ""
261
+ data_rows = rows[1:]
262
+ preview = "\n".join([header] + data_rows[:10])
263
+ truncated = len(data_rows) > 10
 
 
 
 
 
 
 
264
 
265
+ col_status: Dict[str, Any] = getattr(obs, "column_status", {}) or {}
 
266
  broken = [
267
  c for c, s in col_status.items()
268
  if s.get("missing", 0) > 0 or not s.get("standardized", True)
269
  ]
270
 
271
+ history_block = (
272
+ "\n".join(f" {h}" for h in history[-6:])
273
+ if history else " (none yet)"
274
+ )
 
 
 
 
 
 
 
 
 
 
 
275
 
276
+ return (
277
+ f"--------------------------------------------------\n"
278
+ f"## STEP {obs.step_number}/{obs.max_steps}\n"
279
+ f"Score: {obs.current_score:.4f} "
280
+ f"(need >= {DONE_THRESHOLD[obs.task_id]:.2f} to pass)\n"
281
+ f"Issues remaining: {obs.issues_remaining}\n"
282
+ f"Broken columns: {len(broken)} -> {broken[:10] if broken else 'NONE β€” consider DONE'}\n"
283
+ f"\n## SCHEMA HINT\n{obs.schema_hint}\n"
284
+ f"\n## VALID COLUMN NAMES (CASE SENSITIVE β€” copy exactly)\n{header}\n"
285
+ f"\n## COLUMN STATUS (read carefully before acting)\n"
286
+ f"{_column_status_block(obs, obs.task_id)}\n"
287
+ f"\n## CSV PREVIEW"
288
+ f"{' (first 10 of ' + str(len(data_rows)) + ' rows)' if truncated else ''}\n"
289
+ f"{preview}\n"
290
+ f"\n## PREVIOUS ACTIONS (last 6)\n{history_block}\n"
291
+ f"\n--------------------------------------------------\n"
292
+ f"## DECISION CHECKLIST\n"
293
+ f"1. Any column with missing > 0? -> FILL_MISSING (use strategy from column status)\n"
294
+ f"2. Any column with standardized=false? -> STANDARDIZE_COL\n"
295
+ f"3. Isolated bad cell visible in CSV? -> SET_VALUE\n"
296
+ f"4. Clearly corrupt/outlier row? -> DROP_ROW\n"
297
+ f"5. ALL missing=0, ALL standardized=true, issues=0? -> DONE\n"
298
+ f"\nOutput ONE JSON action (no markdown, no explanation):"
299
+ )
300
 
 
 
301
 
302
+ # ── Action parser ─────────────────────────────────────────────────────────────
303
+ # Bridges {action, column, strategy, row, value} -> CleanAction
 
 
 
304
 
305
+ _COMMAND_MAP = {
306
+ "FILL_MISSING": "FILL_MISSING",
307
+ "STANDARDIZE_COL": "STANDARDIZE_COL",
308
+ "STANDARDIZE": "STANDARDIZE_COL",
309
+ "SET_VALUE": "SET_VALUE",
310
+ "DROP_ROW": "DROP_ROW",
311
+ "DROP": "DROP_ROW",
312
+ "DONE": "DONE",
313
+ }
314
 
 
 
 
 
315
 
316
  def parse_action(raw: str) -> CleanAction:
317
  text = raw.strip()
318
 
319
+ # Strip markdown fences
320
  if text.startswith("```"):
321
  lines = text.split("\n")
322
  inner = lines[1:-1] if lines[-1].strip().startswith("```") else lines[1:]
323
  text = "\n".join(inner).strip()
324
 
325
+ m = re.search(r"\{[^{}]*\}", text, re.DOTALL)
326
+ if not m:
 
327
  return CleanAction(command="DONE")
328
 
329
  try:
330
+ data: Dict[str, Any] = json.loads(m.group())
331
  except json.JSONDecodeError:
332
  return CleanAction(command="DONE")
333
 
334
+ raw_cmd = str(data.get("action", "DONE")).upper().strip().replace(" ", "_")
335
+ command = _COMMAND_MAP.get(raw_cmd)
336
+ if not command:
 
337
  return CleanAction(command="DONE")
338
+ if command == "DONE":
 
 
 
 
 
 
 
 
 
 
 
339
  return CleanAction(command="DONE")
340
 
341
  column = data.get("column")
 
342
  fill_strategy = data.get("strategy") or data.get("fill_strategy")
343
+ row_raw = data.get("row") if data.get("row") is not None else data.get("row_index")
 
344
  value = data.get("value")
345
 
346
  try:
347
  return CleanAction(
348
+ command = command,
349
+ column = column,
350
+ fill_strategy = fill_strategy,
351
+ row_index = int(row_raw) if row_raw is not None else None,
352
+ value = str(value) if value is not None else None,
353
  )
354
  except Exception:
355
  return CleanAction(command="DONE")
356
 
357
 
358
+ # ── LLM call (async β€” keeps WebSocket keepalive alive) ───────────────────────
359
+
360
+ async def call_llm_async(client: OpenAI, messages: list) -> str:
361
+ loop = asyncio.get_event_loop()
362
+ response = await loop.run_in_executor(
363
+ None,
364
+ lambda: client.chat.completions.create(
365
+ model = MODEL_NAME,
366
+ messages = messages,
367
+ max_tokens = 120,
368
+ temperature = 0.0,
369
+ ),
370
  )
371
  return (response.choices[0].message.content or "").strip()
372
 
373
 
374
+ # ── Episode loop ───────────────────────────────────────────────────────────────
375
 
376
  async def run_episode(env, client: OpenAI, task_id: str) -> dict:
377
  max_steps = STEP_LIMITS[task_id]
 
394
  break
395
 
396
  steps_taken = step
397
+ messages.append({"role": "user", "content": build_user_prompt(obs, history)})
 
398
 
399
  try:
400
+ raw = await call_llm_async(client, messages)
401
  action = parse_action(raw)
402
  messages.append({"role": "assistant", "content": raw})
403
  except Exception as exc:
 
405
  rewards.append(0.0)
406
  break
407
 
408
+ # Keep system + last 3 exchanges to avoid context overflow
409
+ if len(messages) > 7:
410
+ messages = [messages[0]] + messages[-6:]
411
 
412
  result = await env.step(action)
413
  obs = result.observation
 
423
  error = obs.last_action_error,
424
  )
425
 
426
+ err_note = f" [ERR: {obs.last_action_error[:40]}]" if obs.last_action_error else ""
 
 
427
  history.append(
428
  f"step {step}: {action.command}"
429
+ + (f"({action.column}"
430
+ + (f", {action.fill_strategy})" if action.fill_strategy else ")")
431
+ if action.column else "")
432
+ + f" -> score={score:.4f}{err_note}"
433
  )
434
 
435
  if obs.done or score >= threshold:
 
437
 
438
  success = score >= threshold
439
 
440
+ except Exception as e:
441
+ print(f"[EPISODE ERROR] task={task_id} error={str(e)[:120]}", flush=True)
442
+
443
  finally:
444
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
445
 
446
+ return {
447
+ "task_id": task_id,
448
+ "score": score,
449
+ "reward": sum(rewards),
450
+ "steps": steps_taken,
451
+ "success": success,
452
+ }
453
 
454
 
455
  # ── Entry point ────────────────────────────────────────────────────────────────
456
 
457
  async def main() -> None:
458
+ is_ollama = "11434" in API_BASE_URL or "ollama" in API_BASE_URL.lower()
459
+
460
+ if not is_ollama and (not HF_TOKEN or HF_TOKEN == "ollama"):
461
  print(
462
+ "ERROR: HF_TOKEN not set for remote API.\n"
463
+ "For Groq: $env:HF_TOKEN='gsk_xxxxxxxxxxxx'\n"
464
+ "For Ollama (local): no token needed β€” defaults already set.",
 
 
465
  file=sys.stderr,
466
  )
467
  sys.exit(1)
468
 
469
  print(f"API_BASE_URL : {API_BASE_URL}", flush=True)
470
  print(f"MODEL_NAME : {MODEL_NAME}", flush=True)
471
+ print(f"BACKEND : {'Ollama (local)' if is_ollama else 'Remote API'}", flush=True)
472
+ print(f"ENV SERVER : {LOCAL_IMAGE_NAME or ENV_BASE_URL}", flush=True)
473
  print("", flush=True)
474
 
475
  llm = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
476
 
 
 
 
 
 
 
477
  results = []
478
+ for task_id in TASK_IDS:
479
+ # Fresh connection per task β€” prevents WebSocket keepalive timeout carryover
480
+ if LOCAL_IMAGE_NAME:
481
+ env = await DataCleaningEnv.from_docker_image(LOCAL_IMAGE_NAME)
482
+ else:
483
+ env = DataCleaningEnv(base_url=ENV_BASE_URL)
484
+ await env.connect()
485
+
486
+ try:
487
  summary = await run_episode(env, llm, task_id)
488
  results.append(summary)
489
+ finally:
490
+ try:
491
+ await env.close()
492
+ except Exception:
493
+ pass
494
+ print("", flush=True)
495
 
496
  print("=" * 56, flush=True)
497
  print(f"{'Task':<10} {'Score':>7} {'Reward':>9} {'Steps':>6} {'Pass':>5}")