OpenEnv_hack / server /environment.py
srishtichugh's picture
add ui
40fcf49
"""
Core environment implementing reset / step / state.
Each call to reset() picks the next task round-robin (1 -> 2 -> 3 -> 1 ...),
or a specific task can be forced via reset(task_id=N). Task 4 is not part of
the rotation and is only reachable via reset(task_id=4).
Phase 2 additions:
- DataQualityMetrics computed every step (completeness, uniqueness, validity)
- tried_operations: deduplication log so agent avoids repeating useless ops
- plan: rule-based next-action recommendations surfaced in every observation
- Episode history tracked for /report endpoint
"""
import re
import uuid
import numpy as np
import pandas as pd
from typing import Any, Dict, List, Optional, Tuple
from models import (
DataCleaningAction, DataCleaningObservation,
DataCleaningState, DataQualityMetrics, EpisodeReport,
)
import server.tasks.task1_missing as t1
import server.tasks.task2_format as t2
import server.tasks.task3_pipeline as t3
import server.tasks.task4_merge as t4
# Task id -> module implementing that task. The environment reads MAX_STEPS,
# DESCRIPTION, load(), score() and count_errors() from these modules.
TASK_MODULES = {1: t1, 2: t2, 3: t3, 4: t4}
# Task id -> human-readable title shown in the episode report.
TASK_NAMES = {
    1: "Fill Missing Values",
    2: "Fix Formats + Remove Duplicates",
    3: "Full Cleaning Pipeline",
    4: "Multi-Source Schema Alignment + Merge",
}
# Canonical phone format: NNN-NNN-NNNN.
PHONE_RE = re.compile(r"^\d{3}-\d{3}-\d{4}$")
# Canonical date format: YYYY-MM-DD.
DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
# Exact country spellings treated as valid; any other non-null value counts
# as an invalid cell.
VALID_COUNTRIES = {"USA", "UK", "Canada", "Australia", "Germany"}
class DataCleaningEnvironment:
def __init__(self):
    """Create an environment with no active episode.

    All episode state starts empty; reset() must be called before step().
    """
    # Episode identity and step budget
    self._episode_id: str = ""
    self._task_id: int = 1
    self._task_cycle: int = 0
    self._step_count: int = 0
    self._max_steps: int = 20

    # Working data, reference data and task metadata (populated by reset())
    self._df: Optional[pd.DataFrame] = None
    self._clean_df: Optional[pd.DataFrame] = None
    self._meta: Any = None

    # Scoring
    self._total_errors: int = 0
    self._initial_score: float = 0.01
    self._last_score: float = 0.01

    # Phase 2 tracking
    self._tried_operations: List[str] = []
    self._operations_log: List[str] = []
    self._issues_fixed: Dict[str, int] = {}
    self._initial_dq: Optional[DataQualityMetrics] = None

    # Task 4 state
    self._source_b: Optional[pd.DataFrame] = None  # held until merge_sources called
    self._schema_aligned: bool = False
    self._sources_merged: bool = False
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def reset(self, task_id: Optional[int] = None) -> DataCleaningObservation:
    """Start a new episode and return its first observation.

    Args:
        task_id: Task to load. When omitted, the next task in the
            1 -> 2 -> 3 rotation is used.

    Returns:
        The initial observation; its reward field carries the starting score.

    Raises:
        ValueError: If ``task_id`` is not a key of TASK_MODULES.
    """
    if task_id is None:
        # Advance the rotation; the modulo keeps it within tasks 1..3.
        self._task_cycle = self._task_cycle % 3 + 1
        task_id = self._task_cycle
    if task_id not in TASK_MODULES:
        raise ValueError(f"task_id must be 1, 2, 3, or 4 — got {task_id}")

    task_module = TASK_MODULES[task_id]
    self._task_id = task_id
    self._episode_id = str(uuid.uuid4())
    self._step_count = 0
    self._max_steps = task_module.MAX_STEPS

    # Task 4 returns 4 values (it also carries Source B); others return 3.
    loaded = task_module.load()
    if task_id == 4:
        self._df, self._source_b, self._clean_df, self._meta = loaded
    else:
        self._df, self._clean_df, self._meta = loaded
        self._source_b = None
    self._schema_aligned = False
    self._sources_merged = False

    starting_score = self._compute_score()
    self._last_score = starting_score
    self._initial_score = starting_score
    self._total_errors = self._count_errors()

    # Reset Phase 2 state
    self._tried_operations = []
    self._operations_log = []
    self._issues_fixed = {
        "nulls_filled": 0,
        "dupes_removed": 0,
        "formats_fixed": 0,
        "outliers_removed": 0,
    }
    self._initial_dq = self._compute_dq_metrics()
    return self._build_obs(self._last_score, False, "Episode started. Begin cleaning.")
def step(self, action: DataCleaningAction) -> DataCleaningObservation:
    """Apply one cleaning action and return the resulting observation.

    The reward is the score improvement when the action was applied and
    raised the score; otherwise it is a flat -0.01 penalty. The episode is
    done once the score reaches 0.95 or the step budget is used up.

    Raises:
        RuntimeError: If no episode has been started with reset().
    """
    if self._df is None:
        raise RuntimeError("Call reset() before step().")

    self._step_count += 1
    previous_score = self._last_score
    # Build the operation key BEFORE applying (for feedback loop)
    op_key = self._make_op_key(action)
    message, applied = self._apply_action(action)

    current_score = self._compute_score()
    self._last_score = current_score
    improvement = current_score - previous_score

    if not applied or improvement <= 0:
        reward = -0.01
    else:
        reward = round(improvement, 4)
        # Log successful operation.
        # NOTE(review): the source this was restored from had lost its
        # indentation; the bookkeeping below is assumed to belong to the
        # improving branch only, with just the append guarded by the
        # membership test — confirm against the upstream file.
        if op_key not in self._tried_operations:
            self._tried_operations.append(op_key)
        self._operations_log.append(message)
        self._update_issues_fixed(action, message)

    budget_spent = self._step_count >= self._max_steps
    done = (current_score >= 0.95) or budget_spent
    reward = round(max(-0.05, min(0.99, reward)), 4)
    return self._build_obs(reward, done, message)
def state(self) -> DataCleaningState:
    """Return a snapshot of episode progress.

    Before the first reset() every field is zero/empty; afterwards the
    remaining error count is recomputed from the current data on each call.
    """
    if self._df is not None:
        return DataCleaningState(
            episode_id=self._episode_id,
            task_id=self._task_id,
            step_count=self._step_count,
            max_steps=self._max_steps,
            total_errors=self._total_errors,
            errors_remaining=self._count_errors(),
        )
    # No active episode yet.
    return DataCleaningState(
        episode_id="",
        task_id=0,
        step_count=0,
        max_steps=0,
        total_errors=0,
        errors_remaining=0,
    )
def get_profile(self) -> Dict[str, Any]:
    """Rich data profile for GET /profile endpoint.

    Returns:
        An empty dict when no episode is active. Otherwise a dict with the
        episode id, task id, frame shape, the data-quality metrics and one
        entry per column. Every column reports dtype, null and unique
        counts/percentages; numeric columns add min/max/mean/median/std
        (``None`` where undefined) and all other columns, including
        booleans, report their three most frequent values.
    """
    if self._df is None:
        return {}

    dq = self._compute_dq_metrics()
    n_rows, n_cols = self._df.shape
    columns: Dict[str, Any] = {}

    for col in self._df.columns:
        series = self._df[col]
        n_unique = int(series.nunique(dropna=True))
        col_info: Dict[str, Any] = {
            "dtype": str(series.dtype),
            "null_count": int(series.isnull().sum()),
            "null_pct": round(series.isnull().mean() * 100, 2),
            "unique_count": n_unique,
            "unique_pct": round(n_unique / max(len(series), 1) * 100, 2),
        }
        # Boolean columns pass is_numeric_dtype, but describe() summarises
        # them like categoricals (no min/max/mean), so they are profiled
        # through the top-values branch instead of raising KeyError.
        is_numeric = (
            pd.api.types.is_numeric_dtype(series)
            and not pd.api.types.is_bool_dtype(series)
        )
        if is_numeric:
            stats = {
                "min": series.min(),
                "max": series.max(),
                "mean": series.mean(),
                "median": series.median(),
                "std": series.std(),
            }
            for stat_name, value in stats.items():
                # Undefined statistics (empty column, single value) become None.
                col_info[stat_name] = (
                    round(float(value), 4) if pd.notna(value) else None
                )
        else:
            top = series.value_counts(dropna=True).head(3).to_dict()
            col_info["top_values"] = {str(k): int(v) for k, v in top.items()}
        columns[col] = col_info

    return {
        "episode_id": self._episode_id,
        "task_id": self._task_id,
        "shape": {"rows": n_rows, "cols": n_cols},
        "dq_metrics": dq.model_dump(),
        "columns": columns,
    }
def get_report(self) -> EpisodeReport:
    """Full episode cleaning summary for GET /report endpoint.

    Step efficiency is the unused share of the step budget in percent,
    floored at 0. The episode counts as completed when the latest score is
    at least 0.95.

    Raises:
        RuntimeError: If no episode is active.
    """
    if self._df is None:
        raise RuntimeError("No active episode.")

    steps_used = self._step_count
    budget = max(self._max_steps, 1)  # guard against a zero step budget
    efficiency = round((1 - steps_used / budget) * 100, 1)
    final_score = self._last_score

    report_fields = {
        "episode_id": self._episode_id,
        "task_id": self._task_id,
        "task_name": TASK_NAMES.get(self._task_id, f"Task {self._task_id}"),
        "initial_score": self._initial_score,
        "final_score": final_score,
        "score_improvement": round(final_score - self._initial_score, 4),
        "steps_taken": steps_used,
        "max_steps": self._max_steps,
        "step_efficiency_pct": max(0.0, efficiency),
        # Copies, so the report does not alias live episode state.
        "operations_applied": list(self._operations_log),
        "issues_fixed": dict(self._issues_fixed),
        "final_dq_metrics": self._compute_dq_metrics(),
        "completed": final_score >= 0.95,
    }
    return EpisodeReport(**report_fields)
def get_export(self) -> str:
    """Return current cleaned DataFrame as CSV string for GET /export.

    The index is not written.

    Raises:
        RuntimeError: If no episode is active.
    """
    frame = self._df
    if frame is None:
        raise RuntimeError("No active episode.")
    csv_text = frame.to_csv(index=False)
    return csv_text
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _make_op_key(self, action: DataCleaningAction) -> str:
    """Build the key under which an action is recorded in tried_operations.

    The operation name is stripped and lower-cased exactly as the action
    dispatcher does, so that differently-cased or padded spellings of the
    same operation share one key and match the lower-case
    ``"<operation>:<column>"`` keys the planner looks up.

    Returns:
        ``"<operation>:<column>"`` when the action names a column,
        otherwise just ``"<operation>"``.
    """
    operation = action.operation.strip().lower()
    if action.column:
        return f"{operation}:{action.column}"
    return operation
def _update_issues_fixed(self, action: DataCleaningAction, message: str) -> None:
    """Add the number of issues an operation fixed to the episode tally.

    The count is the first integer found in ``message``
    (e.g. "Filled 20 missing values..."), or 1 when the message contains no
    digits. Operations without a counter are ignored. The operation name is
    stripped and lower-cased the same way the action dispatcher does, so
    every spelling the dispatcher accepts is also counted.
    """
    counters = {
        "fill_missing": "nulls_filled",
        "drop_duplicates": "dupes_removed",
        "fix_format": "formats_fixed",
        "drop_outliers": "outliers_removed",
    }
    counter = counters.get(action.operation.strip().lower())
    if counter is None:
        return

    first_number = re.search(r"\d+", message)
    n = int(first_number.group()) if first_number else 1
    self._issues_fixed[counter] = self._issues_fixed.get(counter, 0) + n
def _compute_dq_metrics(self) -> DataQualityMetrics:
    """Compute completeness, uniqueness and validity for the current data.

    Each percentage is ``(1 - bad / total) * 100`` rounded to two decimals.
    Completeness and validity are measured per cell, uniqueness per row.
    Denominators are floored at 1 so an empty frame does not divide by zero.
    """
    frame = self._df

    def pct_clean(bad: int, total: int) -> float:
        return round((1 - bad / max(total, 1)) * 100, 2)

    total_cells = int(frame.size)
    null_cells = int(frame.isnull().sum().sum())
    duplicate_rows = int(len(frame) - len(frame.drop_duplicates()))
    invalid_cells = self._count_invalid_cells()

    return DataQualityMetrics(
        completeness_pct=pct_clean(null_cells, total_cells),
        uniqueness_pct=pct_clean(duplicate_rows, len(frame)),
        validity_pct=pct_clean(invalid_cells, total_cells),
        total_cells=total_cells,
        null_cells=null_cells,
        duplicate_rows=duplicate_rows,
        invalid_cells=invalid_cells,
    )
def _count_invalid_cells(self) -> int:
    """Count cells with format/dtype/range violations.

    Only non-null cells of known columns are inspected:

    * ``phone`` must match NNN-NNN-NNNN,
    * ``listed_date`` / ``signup_date`` must match YYYY-MM-DD,
    * ``country`` must be one of VALID_COUNTRIES,
    * ``age`` must be a number in [0, 120],
    * ``salary`` / ``purchase_amount`` must be a non-negative number.

    Numeric columns are coerced before the range check, so a column that
    currently holds strings is counted (unparseable values are invalid)
    instead of raising on the comparison.
    """
    invalid = 0
    for col in self._df.columns:
        series = self._df[col].dropna()
        if col == "phone":
            ok = series.astype(str).str.match(PHONE_RE, na=False)
            invalid += int((~ok).sum())
        elif col in ("listed_date", "signup_date"):
            # astype(bool) keeps `~` valid when the column has no non-null
            # values and apply() therefore returns a non-boolean dtype.
            ok = series.apply(lambda x: bool(DATE_RE.match(str(x)))).astype(bool)
            invalid += int((~ok).sum())
        elif col == "country":
            invalid += int((~series.isin(VALID_COUNTRIES)).sum())
        elif col in ("age", "salary", "purchase_amount"):
            values = pd.to_numeric(series, errors="coerce")
            bad = values.isna() | (values < 0)
            if col == "age":
                bad = bad | (values > 120)
            invalid += int(bad.sum())
    return invalid
def _generate_plan(self) -> List[str]:
    """
    Rule-based planning engine — inspects current DataFrame state
    and returns up to 3 prioritised recommended actions.
    Inspired by AutoDCWorkflow (EMNLP 2025).

    Priorities: missing values (largest null count first, at most two),
    duplicate rows, format problems in known columns, then outliers in the
    first numeric column that has any. Operations already recorded in
    tried_operations are not suggested again. For Task 4 the plan is only
    the pending align_schema / merge_sources step until both are done.

    Outliers are detected on both tails with the same 3 * IQR fences that
    _drop_outliers applies, so the two always agree.
    """
    plan: List[str] = []
    if self._df is None:
        return plan

    # Task 4: schema alignment + merge must happen first
    if self._task_id == 4:
        if not self._schema_aligned:
            return ["align_schema — rename Source A columns to canonical schema (required first step)"]
        if not self._sources_merged:
            return ["merge_sources — concatenate aligned Source A + Source B (required before cleaning)"]

    missing = {col: int(n) for col, n in self._df.isnull().sum().items() if n > 0}
    dupes = len(self._df) - len(self._df.drop_duplicates())

    # Priority 1: missing values (highest DQ impact)
    for col, count in sorted(missing.items(), key=lambda x: -x[1]):
        op_key = f"fill_missing:{col}"
        if op_key not in self._tried_operations:
            strategy = "mode" if self._df[col].dtype == object else "median"
            plan.append(
                f'fill_missing on "{col}" ({count} nulls) using {strategy}'
            )
        if len(plan) >= 2:
            break

    # Priority 2: duplicates
    if dupes > 0 and "drop_duplicates" not in self._tried_operations:
        plan.append(f"drop_duplicates ({dupes} duplicate rows found)")

    # Priority 3: format issues
    for col in self._df.columns:
        if len(plan) >= 3:
            break
        op_key = f"fix_format:{col}"
        if op_key in self._tried_operations:
            continue
        present = self._df[col].dropna()
        if col == "phone":
            bad = int((~present.astype(str).str.match(PHONE_RE, na=False)).sum())
            if bad > 0:
                plan.append(f'fix_format on "phone" ({bad} malformed numbers)')
        elif col in ("listed_date", "signup_date"):
            # astype(bool) keeps `~` valid for a column with no non-null values.
            ok = present.apply(lambda x: bool(DATE_RE.match(str(x)))).astype(bool)
            bad = int((~ok).sum())
            if bad > 0:
                plan.append(f'fix_format on "{col}" ({bad} malformed dates)')
        elif col == "country":
            bad = int((~present.isin(VALID_COUNTRIES)).sum())
            if bad > 0:
                plan.append(f'fix_format on "country" ({bad} invalid values)')

    # Priority 4: outliers on numeric cols
    if len(plan) < 3:
        for col in self._df.select_dtypes(include=[np.number]).columns:
            op_key = f"drop_outliers:{col}"
            if op_key in self._tried_operations:
                continue
            values = self._df[col]
            q1, q3 = values.quantile(0.25), values.quantile(0.75)
            iqr = q3 - q1
            # Both tails, mirroring the rows _drop_outliers would remove.
            beyond_fences = (values < q1 - 3 * iqr) | (values > q3 + 3 * iqr)
            outliers = int(beyond_fences.sum())
            if outliers > 0:
                plan.append(f'drop_outliers on "{col}" ({outliers} extreme values)')
                break

    return plan[:3]
def _compute_score(self) -> float:
    """Score the current data with the active task's scorer.

    The raw score is rounded to four decimals and then clamped to
    [0.0001, 0.9999]. Rounding happens before clamping so that a value such
    as 0.99996 cannot round up to exactly 1.0 (or 0.00004 down to 0.0)
    after the clamp was already applied. A NaN score is reported as the
    lower bound so it cannot leak into reward arithmetic.
    """
    if self._task_id == 1:
        raw = t1.score(self._df, self._meta)
    elif self._task_id == 2:
        raw = t2.score(self._df, self._meta)
    elif self._task_id == 3:
        raw = t3.score(self._df, self._meta)
    else:
        raw = t4.score(self._df, self._meta)

    raw = float(raw)
    EPS = 1e-4
    lower, upper = EPS, round(1.0 - EPS, 4)
    if pd.isna(raw):
        return lower
    return min(max(round(raw, 4), lower), upper)
def _count_errors(self) -> int:
    """Return the active task's count of errors in the current data.

    Task 1's counter takes only the frame; the other tasks also receive the
    task metadata. Any task id other than 1, 2 or 3 is handled by task 4.
    """
    task = self._task_id
    if task == 1:
        return t1.count_errors(self._df)
    if task == 2:
        counter = t2.count_errors
    elif task == 3:
        counter = t3.count_errors
    else:
        counter = t4.count_errors
    return counter(self._df, self._meta)
def _build_obs(self, reward: float, done: bool, message: str) -> DataCleaningObservation:
    """Assemble the observation returned by reset() and step().

    Besides the given reward, done flag and message, the observation carries
    a CSV preview of the first 10 rows, the frame shape, per-column null
    counts (only columns that have nulls), the duplicate-row count, dtype
    warnings, data-quality metrics, the tried-operations log and the plan.
    """
    frame = self._df
    null_counts = frame.isnull().sum()
    fields = {
        "done": done,
        "reward": reward,
        "data_preview": frame.head(10).to_csv(index=False),
        "data_shape": list(frame.shape),
        "missing_counts": {col: int(n) for col, n in null_counts.items() if n > 0},
        "duplicate_count": len(frame) - len(frame.drop_duplicates()),
        "dtype_issues": self._detect_dtype_issues(),
        "task_description": TASK_MODULES[self._task_id].DESCRIPTION,
        "message": message,
        "step_count": self._step_count,
        "current_score": self._last_score,
        "dq_metrics": self._compute_dq_metrics(),
        # Copy, so the observation does not alias the live log.
        "tried_operations": list(self._tried_operations),
        "plan": self._generate_plan(),
    }
    return DataCleaningObservation(**fields)
def _detect_dtype_issues(self) -> Dict[str, str]:
    """Flag object-dtype columns whose values are mostly numeric.

    A column is flagged when more than 80% of its non-null values parse as
    numbers. Columns with no non-null values are skipped.
    """
    flagged: Dict[str, str] = {}
    for col in self._df.columns:
        column = self._df[col]
        present = column.dropna()
        if present.empty or column.dtype != object:
            continue
        parsed = pd.to_numeric(present, errors="coerce")
        numeric_share = parsed.notna().sum() / len(present)
        if numeric_share > 0.8:
            flagged[col] = "stored as string but appears numeric"
    return flagged
# ------------------------------------------------------------------
# Action dispatcher
# ------------------------------------------------------------------
def _apply_action(self, action: DataCleaningAction) -> Tuple[str, bool]:
    """Route an action to the handler for its operation.

    The operation name is matched after stripping whitespace and
    lower-casing. Any exception raised by a handler is reported as a failed
    operation rather than propagated.

    Returns:
        ``(message, applied)`` — ``applied`` is False for unknown
        operations, failed handlers and handlers that changed nothing.
    """
    op = action.operation.strip().lower()
    col = action.column
    p = action.params or {}

    # Lambdas keep handler lookup lazy: only the selected one touches self.
    handlers = {
        "fill_missing": lambda: self._fill_missing(col, p),
        "drop_duplicates": lambda: self._drop_duplicates(),
        "fix_format": lambda: self._fix_format(col),
        "replace_value": lambda: self._replace_value(col, p),
        "drop_outliers": lambda: self._drop_outliers(col),
        "fix_dtype": lambda: self._fix_dtype(col, p),
        "align_schema": lambda: self._align_schema(),
        "merge_sources": lambda: self._merge_sources(),
    }
    try:
        handler = handlers.get(op)
        if handler is not None:
            return handler()
        return (
            f"Unknown operation '{op}'. Choose from: fill_missing, "
            "drop_duplicates, fix_format, replace_value, drop_outliers, "
            "fix_dtype, align_schema, merge_sources.",
            False,
        )
    except Exception as exc:
        return f"Operation failed: {exc}", False
def _fill_missing(self, col, p) -> Tuple[str, bool]:
    """Fill the missing values of one column.

    Args:
        col: Column to fill.
        p: Parameters; ``strategy`` is one of ``median`` (default),
            ``mean``, ``mode`` or ``constant`` (which uses ``value``).

    Returns:
        ``(message, applied)``. The operation is rejected when the column
        is unknown, has nothing to fill, the strategy is unknown, or no
        usable fill value exists — including a NaN median/mean, which is
        what an entirely-null column produces and would otherwise be
        reported as a successful fill of 0 values.
    """
    if col is None or col not in self._df.columns:
        return f"Column '{col}' not found.", False

    column = self._df[col]
    n_before = int(column.isnull().sum())
    if n_before == 0:
        return f"No missing values in '{col}'.", False

    strategy = str(p.get("strategy", "median")).lower()
    if strategy == "median":
        fill_val = column.median(skipna=True)
    elif strategy == "mean":
        fill_val = column.mean(skipna=True)
    elif strategy == "mode":
        mode = column.mode(dropna=True)
        fill_val = mode.iloc[0] if not mode.empty else None
    elif strategy == "constant":
        fill_val = p.get("value")
    else:
        return f"Unknown strategy '{strategy}'.", False

    try:
        unusable = fill_val is None or bool(pd.isna(fill_val))
    except (TypeError, ValueError):
        # Non-scalar constants cannot be tested for NaN; let fillna judge them.
        unusable = False
    if unusable:
        return "Could not determine fill value.", False

    self._df[col] = column.fillna(fill_val)
    n_after = int(self._df[col].isnull().sum())
    return f"Filled {n_before - n_after} missing values in '{col}' using {strategy}.", True
def _drop_duplicates(self) -> Tuple[str, bool]:
    """Remove fully duplicated rows and renumber the index.

    Returns:
        ``(message, applied)`` — ``applied`` is False when no row was removed.
    """
    rows_before = len(self._df)
    deduped = self._df.drop_duplicates()
    self._df = deduped.reset_index(drop=True)
    removed = rows_before - len(self._df)
    if removed:
        return f"Dropped {removed} duplicate rows.", True
    return "No duplicate rows found.", False
def _fix_format(self, col) -> Tuple[str, bool]:
    """Normalise the format of a known column.

    Format rules exist for ``phone``, ``listed_date``, ``signup_date`` and
    ``country``; any other column is rejected.

    Returns:
        ``(message, applied)`` from the column-specific fixer.
    """
    if col is None or col not in self._df.columns:
        return f"Column '{col}' not found.", False

    fixer_by_column = {
        "phone": "_fix_phone",
        "listed_date": "_fix_date",
        "signup_date": "_fix_date",
        "country": "_fix_country",
    }
    fixer_name = fixer_by_column.get(col)
    if fixer_name is None:
        return f"No format rule defined for column '{col}'.", False
    return getattr(self, fixer_name)(col)
def _fix_phone(self, col) -> Tuple[str, bool]:
    """Rewrite phone numbers with exactly ten digits as NNN-NNN-NNNN.

    Values with any other digit count are left unchanged, as are missing
    values. Malformed values are counted on the string form of the non-null
    cells — the same way the validity counter and the planner count them —
    so a phone column stored as numbers can be fixed instead of failing on
    the ``.str`` accessor.

    Returns:
        ``(message, applied)`` — ``applied`` is False when no value changed
        from malformed to well-formed.
    """
    def normalise(val):
        if pd.isna(val):
            return val
        digits = re.sub(r"\D", "", str(val))
        if len(digits) == 10:
            return f"{digits[:3]}-{digits[3:6]}-{digits[6:]}"
        return val

    def count_malformed() -> int:
        present = self._df[col].dropna()
        well_formed = present.astype(str).str.match(PHONE_RE, na=False)
        return int((~well_formed).sum())

    before = count_malformed()
    self._df[col] = self._df[col].apply(normalise)
    fixed = before - count_malformed()
    if fixed == 0:
        return f"No phone format issues found in '{col}'.", False
    return f"Fixed {fixed} phone numbers in '{col}' to NNN-NNN-NNNN format.", True
def _fix_date(self, col) -> Tuple[str, bool]:
    """Rewrite parseable dates in a column as YYYY-MM-DD.

    Each value is tried against a fixed list of formats in order (so an
    ambiguous ``01/02/2020`` is read day-first), then against pandas'
    general parser; values that still cannot be parsed are left unchanged.

    Returns:
        ``(message, applied)`` — ``applied`` is False when no value changed
        from non-ISO to ISO.
    """
    known_formats = ["%Y-%m-%d", "%b %d %Y", "%d/%m/%Y", "%m/%d/%Y", "%Y/%m/%d"]

    def to_iso(val):
        if pd.isna(val):
            return val
        text = str(val).strip()
        for fmt in known_formats:
            try:
                return pd.to_datetime(text, format=fmt).strftime("%Y-%m-%d")
            except Exception:
                continue
        try:
            return pd.to_datetime(text).strftime("%Y-%m-%d")
        except Exception:
            return val

    def count_non_iso():
        # Missing values count as non-ISO both before and after, so they
        # cancel out of the difference below.
        is_iso = self._df[col].apply(
            lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
        )
        return (~is_iso).sum()

    before = count_non_iso()
    self._df[col] = self._df[col].apply(to_iso)
    after = count_non_iso()

    fixed = int(before - after)
    if fixed == 0:
        return f"No date format issues found in '{col}'.", False
    return f"Fixed {fixed} dates in '{col}' to YYYY-MM-DD format.", True
def _fix_country(self, col) -> Tuple[str, bool]:
    """Restore the canonical capitalisation of known country names.

    Values are matched case-insensitively after trimming whitespace;
    unknown countries and missing values are left unchanged.

    Returns:
        ``(message, applied)`` — ``applied`` is False when no value became
        valid.
    """
    canonical = {
        "usa": "USA",
        "uk": "UK",
        "canada": "Canada",
        "australia": "Australia",
        "germany": "Germany",
    }

    def normalise(val):
        if pd.isna(val):
            return val
        return canonical.get(str(val).strip().lower(), val)

    def count_invalid():
        values = self._df[col]
        return (~values.isin(VALID_COUNTRIES) & values.notna()).sum()

    before = count_invalid()
    self._df[col] = self._df[col].apply(normalise)
    after = count_invalid()

    fixed = int(before - after)
    if fixed == 0:
        return "No country capitalisation issues found.", False
    return f"Fixed {fixed} country values to correct capitalisation.", True
def _replace_value(self, col, p) -> Tuple[str, bool]:
    """Replace every occurrence of one value in a column.

    Args:
        col: Column to edit.
        p: Parameters; ``old`` (required) is the value to find and ``new``
            is its replacement.

    Returns:
        ``(message, applied)`` — ``applied`` is False when the column is
        unknown, ``old`` is missing, or ``old`` does not occur.
    """
    if col is None or col not in self._df.columns:
        return f"Column '{col}' not found.", False

    old, new = p.get("old"), p.get("new")
    if old is None:
        return "params.old is required for replace_value.", False

    column = self._df[col]
    count = int((column == old).sum())
    if not count:
        return f"Value '{old}' not found in '{col}'.", False

    self._df[col] = column.replace(old, new)
    return f"Replaced {count} occurrences of '{old}' with '{new}' in '{col}'.", True
def _drop_outliers(self, col) -> Tuple[str, bool]:
    """Drop rows whose value lies beyond 3 * IQR from the quartiles.

    Rows with a missing value in the column are kept; the index is
    renumbered afterwards.

    Returns:
        ``(message, applied)`` — ``applied`` is False when the column is
        unknown, not numeric, or no row was removed.
    """
    if col is None or col not in self._df.columns:
        return f"Column '{col}' not found.", False

    column = self._df[col]
    if not pd.api.types.is_numeric_dtype(column):
        return f"'{col}' is not numeric.", False

    q1 = column.quantile(0.25)
    q3 = column.quantile(0.75)
    iqr = q3 - q1
    lower_fence = q1 - 3 * iqr
    upper_fence = q3 + 3 * iqr

    within_fences = (column >= lower_fence) & (column <= upper_fence)
    keep = within_fences | column.isna()
    rows_before = len(self._df)
    self._df = self._df[keep].reset_index(drop=True)

    removed = rows_before - len(self._df)
    if removed == 0:
        return f"No outliers found in '{col}'.", False
    return f"Removed {removed} outlier rows from '{col}' using IQR method.", True
def _align_schema(self) -> Tuple[str, bool]:
    """Rename Source A columns to canonical target schema (Task 4 only).

    Returns:
        ``(message, applied)`` — ``applied`` is False outside Task 4, when
        the schema was already aligned, or when an expected Source A column
        is absent from the frame.
    """
    if self._task_id != 4:
        return "align_schema is only available in Task 4.", False
    if self._schema_aligned:
        return "Schema already aligned.", False

    from server.tasks.task4_merge import SOURCE_A_RENAME, TARGET_COLUMNS

    present = self._df.columns
    missing_src = [name for name in SOURCE_A_RENAME if name not in present]
    if missing_src:
        return f"Expected Source A columns not found: {missing_src}.", False

    self._df = self._df.rename(columns=SOURCE_A_RENAME)
    self._schema_aligned = True

    source_names = ", ".join(list(SOURCE_A_RENAME.keys()))
    message = (
        f"Aligned Source A schema: renamed {len(SOURCE_A_RENAME)} columns "
        f"({source_names}) to canonical target schema."
    )
    return message, True
def _merge_sources(self) -> Tuple[str, bool]:
    """Concatenate aligned Source A with Source B (Task 4 only).

    Preconditions, each reported as a not-applied result: the active task
    is 4, the sources have not been merged yet, align_schema has already
    run, and Source B is still held.

    On success the working frame is replaced, Source B is released and the
    merged flag is set.

    Returns:
        ``(message, applied)``; the message reports the row counts of both
        sources and of the resulting frame.
    """
    if self._task_id != 4:
        return "merge_sources is only available in Task 4.", False
    if self._sources_merged:
        return "Sources already merged.", False
    if not self._schema_aligned:
        return "Run align_schema before merge_sources.", False
    if self._source_b is None:
        return "Source B not available.", False
    from server.tasks.task4_merge import TARGET_COLUMNS, _META_TEMPLATE
    # Row counts are captured now because self._df is replaced below.
    n_a = len(self._df)
    n_b = len(self._source_b)
    # Rename source_b columns to canonical schema
    source_b_rename = {
        "age_years": "age",
        "spend": "purchase_amount",
        "country_name": "country",
        "registration_date": "signup_date",
    }
    source_b_aligned = self._source_b.rename(columns=source_b_rename)
    # Concatenate both aligned sources.
    # NOTE(review): `merged` is never used after this statement. Its only
    # visible effect is that selecting TARGET_COLUMNS raises if either
    # source lacks a canonical column (the dispatcher would report that as
    # a failed operation) — confirm whether the concat is still needed.
    merged = pd.concat(
        [self._df[TARGET_COLUMNS], source_b_aligned[TARGET_COLUMNS]],
        ignore_index=True
    ).reset_index(drop=True)
    # Inject pre-computed dirty issues so grader baseline is correct.
    # The working frame becomes a copy of the template's "dirty_merged"
    # frame, not the concatenation computed above.
    dirty_merged = _META_TEMPLATE["dirty_merged"].copy()
    self._df = dirty_merged
    self._sources_merged = True
    self._source_b = None
    return (
        f"Merged Source A ({n_a} rows) + Source B ({n_b} rows) → "
        f"{len(self._df)} rows with canonical schema. "
        f"Dataset now has dirty issues to clean: missing values, "
        f"mixed country case, mixed date formats, duplicate rows.", True
    )
def _fix_dtype(self, col, p) -> Tuple[str, bool]:
    """Convert a column to another dtype.

    Args:
        col: Column to convert.
        p: Parameters; ``dtype`` is ``float`` (default), ``int`` or ``str``.

    ``float`` and ``int`` coerce unparseable values to missing. ``int``
    leaves the dtype choice to ``pd.to_numeric``, so the column is only
    integer-typed when every value is a whole number and none is missing.
    ``str`` converts the non-null values only: missing values stay missing
    instead of becoming the literal text ``"nan"``, which would hide them
    from the null counts and completeness metrics.

    Returns:
        ``(message, applied)`` — ``applied`` is False when the column or
        dtype is unknown or the conversion raised.
    """
    if col is None or col not in self._df.columns:
        return f"Column '{col}' not found.", False

    dtype = str(p.get("dtype", "float")).lower()
    try:
        column = self._df[col]
        if dtype == "float":
            self._df[col] = pd.to_numeric(column, errors="coerce").astype(float)
        elif dtype == "int":
            self._df[col] = pd.to_numeric(column, errors="coerce")
        elif dtype == "str":
            # where() blanks the positions that were null before conversion.
            self._df[col] = column.astype(str).where(column.notna())
        else:
            return f"Unknown dtype '{dtype}'.", False
        return f"Converted '{col}' to {dtype}.", True
    except Exception as exc:
        return f"dtype conversion failed: {exc}", False