Spaces:
Running
Running
| """ | |
| Core environment implementing reset / step / state. | |
| Each call to reset() picks a task (round-robin: 1 -> 2 -> 3 -> 1 ...) | |
| or a specific task_id can be forced via reset(task_id=N). | |
| Phase 2 additions: | |
| - DataQualityMetrics computed every step (completeness, uniqueness, validity) | |
| - tried_operations: deduplication log so agent avoids repeating useless ops | |
| - plan: rule-based next-action recommendations surfaced in every observation | |
| - Episode history tracked for /report endpoint | |
| """ | |
| import re | |
| import uuid | |
| import numpy as np | |
| import pandas as pd | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from models import ( | |
| DataCleaningAction, DataCleaningObservation, | |
| DataCleaningState, DataQualityMetrics, EpisodeReport, | |
| ) | |
| import server.tasks.task1_missing as t1 | |
| import server.tasks.task2_format as t2 | |
| import server.tasks.task3_pipeline as t3 | |
| import server.tasks.task4_merge as t4 | |
| TASK_MODULES = {1: t1, 2: t2, 3: t3, 4: t4} | |
| TASK_NAMES = { | |
| 1: "Fill Missing Values", | |
| 2: "Fix Formats + Remove Duplicates", | |
| 3: "Full Cleaning Pipeline", | |
| 4: "Multi-Source Schema Alignment + Merge", | |
| } | |
| PHONE_RE = re.compile(r"^\d{3}-\d{3}-\d{4}$") | |
| DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") | |
| VALID_COUNTRIES = {"USA", "UK", "Canada", "Australia", "Germany"} | |
| class DataCleaningEnvironment: | |
| def __init__(self): | |
| self._df: Optional[pd.DataFrame] = None | |
| self._clean_df: Optional[pd.DataFrame] = None | |
| self._meta: Any = None | |
| self._task_id: int = 1 | |
| self._episode_id: str = "" | |
| self._step_count: int = 0 | |
| self._max_steps: int = 20 | |
| self._total_errors: int = 0 | |
| self._last_score: float = 0.01 | |
| self._initial_score: float = 0.01 | |
| self._task_cycle: int = 0 | |
| # Phase 2 tracking | |
| self._tried_operations: List[str] = [] | |
| self._operations_log: List[str] = [] | |
| self._issues_fixed: Dict[str, int] = {} | |
| self._initial_dq: Optional[DataQualityMetrics] = None | |
| # Task 4 state | |
| self._source_b: Optional[pd.DataFrame] = None # held until merge_sources called | |
| self._schema_aligned: bool = False | |
| self._sources_merged: bool = False | |
| # ------------------------------------------------------------------ | |
| # Public API | |
| # ------------------------------------------------------------------ | |
| def reset(self, task_id: Optional[int] = None) -> DataCleaningObservation: | |
| if task_id is None: | |
| self._task_cycle = (self._task_cycle % 3) + 1 | |
| task_id = self._task_cycle | |
| if task_id not in TASK_MODULES: | |
| raise ValueError(f"task_id must be 1, 2, 3, or 4 — got {task_id}") | |
| mod = TASK_MODULES[task_id] | |
| self._task_id = task_id | |
| self._episode_id = str(uuid.uuid4()) | |
| self._step_count = 0 | |
| self._max_steps = mod.MAX_STEPS | |
| # Task 4 returns 4 values; others return 3 | |
| if task_id == 4: | |
| self._df, self._source_b, self._clean_df, self._meta = mod.load() | |
| self._schema_aligned = False | |
| self._sources_merged = False | |
| else: | |
| self._df, self._clean_df, self._meta = mod.load() | |
| self._source_b = None | |
| self._schema_aligned = False | |
| self._sources_merged = False | |
| self._last_score = self._compute_score() | |
| self._initial_score = self._last_score | |
| self._total_errors = self._count_errors() | |
| # Reset Phase 2 state | |
| self._tried_operations = [] | |
| self._operations_log = [] | |
| self._issues_fixed = {"nulls_filled": 0, "dupes_removed": 0, | |
| "formats_fixed": 0, "outliers_removed": 0} | |
| self._initial_dq = self._compute_dq_metrics() | |
| return self._build_obs(self._last_score, False, "Episode started. Begin cleaning.") | |
| def step(self, action: DataCleaningAction) -> DataCleaningObservation: | |
| if self._df is None: | |
| raise RuntimeError("Call reset() before step().") | |
| self._step_count += 1 | |
| score_before = self._last_score | |
| # Track tried operations BEFORE applying (for feedback loop) | |
| op_key = self._make_op_key(action) | |
| message, applied = self._apply_action(action) | |
| score_after = self._compute_score() | |
| self._last_score = score_after | |
| delta = score_after - score_before | |
| if not applied: | |
| reward = -0.01 | |
| elif delta <= 0: | |
| reward = -0.01 | |
| else: | |
| reward = round(delta, 4) | |
| # Log successful operation | |
| if op_key not in self._tried_operations: | |
| self._tried_operations.append(op_key) | |
| self._operations_log.append(message) | |
| self._update_issues_fixed(action, message) | |
| done = (score_after >= 0.95) or (self._step_count >= self._max_steps) | |
| reward = round(max(-0.05, min(0.99, reward)), 4) | |
| return self._build_obs(reward, done, message) | |
| def state(self) -> DataCleaningState: | |
| if self._df is None: | |
| return DataCleaningState( | |
| episode_id="", task_id=0, step_count=0, | |
| max_steps=0, total_errors=0, errors_remaining=0, | |
| ) | |
| return DataCleaningState( | |
| episode_id = self._episode_id, | |
| task_id = self._task_id, | |
| step_count = self._step_count, | |
| max_steps = self._max_steps, | |
| total_errors = self._total_errors, | |
| errors_remaining = self._count_errors(), | |
| ) | |
| def get_profile(self) -> Dict[str, Any]: | |
| """Rich data profile for GET /profile endpoint.""" | |
| if self._df is None: | |
| return {} | |
| dq = self._compute_dq_metrics() | |
| profile: Dict[str, Any] = { | |
| "episode_id": self._episode_id, | |
| "task_id": self._task_id, | |
| "shape": {"rows": self._df.shape[0], "cols": self._df.shape[1]}, | |
| "dq_metrics": dq.model_dump(), | |
| "columns": {}, | |
| } | |
| for col in self._df.columns: | |
| series = self._df[col] | |
| col_info: Dict[str, Any] = { | |
| "dtype": str(series.dtype), | |
| "null_count": int(series.isnull().sum()), | |
| "null_pct": round(series.isnull().mean() * 100, 2), | |
| "unique_count": int(series.nunique(dropna=True)), | |
| "unique_pct": round(series.nunique(dropna=True) / max(len(series), 1) * 100, 2), | |
| } | |
| if pd.api.types.is_numeric_dtype(series): | |
| desc = series.describe() | |
| col_info.update({ | |
| "min": round(float(desc["min"]), 4) if pd.notna(desc["min"]) else None, | |
| "max": round(float(desc["max"]), 4) if pd.notna(desc["max"]) else None, | |
| "mean": round(float(desc["mean"]), 4) if pd.notna(desc["mean"]) else None, | |
| "median": round(float(series.median()), 4) if pd.notna(series.median()) else None, | |
| "std": round(float(desc["std"]), 4) if pd.notna(desc.get("std", float("nan"))) else None, | |
| }) | |
| else: | |
| top = series.value_counts(dropna=True).head(3).to_dict() | |
| col_info["top_values"] = {str(k): int(v) for k, v in top.items()} | |
| profile["columns"][col] = col_info | |
| return profile | |
| def get_report(self) -> EpisodeReport: | |
| """Full episode cleaning summary for GET /report endpoint.""" | |
| if self._df is None: | |
| raise RuntimeError("No active episode.") | |
| steps_used = self._step_count | |
| efficiency = round((1 - steps_used / max(self._max_steps, 1)) * 100, 1) | |
| return EpisodeReport( | |
| episode_id = self._episode_id, | |
| task_id = self._task_id, | |
| task_name = TASK_NAMES.get(self._task_id, f"Task {self._task_id}"), | |
| initial_score = self._initial_score, | |
| final_score = self._last_score, | |
| score_improvement = round(self._last_score - self._initial_score, 4), | |
| steps_taken = steps_used, | |
| max_steps = self._max_steps, | |
| step_efficiency_pct = max(0.0, efficiency), | |
| operations_applied = list(self._operations_log), | |
| issues_fixed = dict(self._issues_fixed), | |
| final_dq_metrics = self._compute_dq_metrics(), | |
| completed = self._last_score >= 0.95, | |
| ) | |
| def get_export(self) -> str: | |
| """Return current cleaned DataFrame as CSV string for GET /export.""" | |
| if self._df is None: | |
| raise RuntimeError("No active episode.") | |
| return self._df.to_csv(index=False) | |
| # ------------------------------------------------------------------ | |
| # Internal helpers | |
| # ------------------------------------------------------------------ | |
| def _make_op_key(self, action: DataCleaningAction) -> str: | |
| if action.column: | |
| return f"{action.operation}:{action.column}" | |
| return action.operation | |
| def _update_issues_fixed(self, action: DataCleaningAction, message: str) -> None: | |
| op = action.operation.lower() | |
| # Parse numbers from message e.g. "Filled 20 missing values..." | |
| nums = re.findall(r"\d+", message) | |
| n = int(nums[0]) if nums else 1 | |
| if op == "fill_missing": | |
| self._issues_fixed["nulls_filled"] = self._issues_fixed.get("nulls_filled", 0) + n | |
| elif op == "drop_duplicates": | |
| self._issues_fixed["dupes_removed"] = self._issues_fixed.get("dupes_removed", 0) + n | |
| elif op == "fix_format": | |
| self._issues_fixed["formats_fixed"] = self._issues_fixed.get("formats_fixed", 0) + n | |
| elif op == "drop_outliers": | |
| self._issues_fixed["outliers_removed"] = self._issues_fixed.get("outliers_removed", 0) + n | |
| def _compute_dq_metrics(self) -> DataQualityMetrics: | |
| total_cells = int(self._df.size) | |
| null_cells = int(self._df.isnull().sum().sum()) | |
| duplicate_rows = int(len(self._df) - len(self._df.drop_duplicates())) | |
| invalid_cells = self._count_invalid_cells() | |
| completeness = round((1 - null_cells / max(total_cells, 1)) * 100, 2) | |
| uniqueness = round((1 - duplicate_rows / max(len(self._df), 1)) * 100, 2) | |
| validity = round((1 - invalid_cells / max(total_cells, 1)) * 100, 2) | |
| return DataQualityMetrics( | |
| completeness_pct = completeness, | |
| uniqueness_pct = uniqueness, | |
| validity_pct = validity, | |
| total_cells = total_cells, | |
| null_cells = null_cells, | |
| duplicate_rows = duplicate_rows, | |
| invalid_cells = invalid_cells, | |
| ) | |
| def _count_invalid_cells(self) -> int: | |
| """Count cells with format/dtype/range violations.""" | |
| invalid = 0 | |
| for col in self._df.columns: | |
| series = self._df[col].dropna() | |
| if col == "phone": | |
| invalid += int((~series.astype(str).str.match(PHONE_RE, na=False)).sum()) | |
| elif col in ("listed_date", "signup_date"): | |
| invalid += int((~series.apply( | |
| lambda x: bool(DATE_RE.match(str(x))) | |
| )).sum()) | |
| elif col == "country": | |
| invalid += int((~series.isin(VALID_COUNTRIES)).sum()) | |
| elif col == "age": | |
| invalid += int(((series < 0) | (series > 120)).sum()) | |
| elif col == "salary": | |
| invalid += int((series < 0).sum()) | |
| elif col == "purchase_amount": | |
| invalid += int((series < 0).sum()) | |
| return invalid | |
| def _generate_plan(self) -> List[str]: | |
| """ | |
| Rule-based planning engine — inspects current DataFrame state | |
| and returns up to 3 prioritised recommended actions. | |
| Inspired by AutoDCWorkflow (EMNLP 2025). | |
| """ | |
| plan: List[str] = [] | |
| if self._df is None: | |
| return plan | |
| # Task 4: schema alignment + merge must happen first | |
| if self._task_id == 4: | |
| if not self._schema_aligned: | |
| return ["align_schema — rename Source A columns to canonical schema (required first step)"] | |
| if not self._sources_merged: | |
| return ["merge_sources — concatenate aligned Source A + Source B (required before cleaning)"] | |
| missing = {col: int(n) for col, n in self._df.isnull().sum().items() if n > 0} | |
| dupes = len(self._df) - len(self._df.drop_duplicates()) | |
| # Priority 1: missing values (highest DQ impact) | |
| for col, count in sorted(missing.items(), key=lambda x: -x[1]): | |
| op_key = f"fill_missing:{col}" | |
| if op_key not in self._tried_operations: | |
| strategy = "mode" if self._df[col].dtype == object else "median" | |
| plan.append( | |
| f'fill_missing on "{col}" ({count} nulls) using {strategy}' | |
| ) | |
| if len(plan) >= 2: | |
| break | |
| # Priority 2: duplicates | |
| if dupes > 0 and "drop_duplicates" not in self._tried_operations: | |
| plan.append(f"drop_duplicates ({dupes} duplicate rows found)") | |
| # Priority 3: format issues | |
| for col in self._df.columns: | |
| if len(plan) >= 3: | |
| break | |
| op_key = f"fix_format:{col}" | |
| if op_key in self._tried_operations: | |
| continue | |
| if col == "phone": | |
| bad = int((~self._df[col].dropna().astype(str).str.match(PHONE_RE)).sum()) | |
| if bad > 0: | |
| plan.append(f'fix_format on "phone" ({bad} malformed numbers)') | |
| elif col in ("listed_date", "signup_date"): | |
| bad = int((~self._df[col].dropna().apply( | |
| lambda x: bool(DATE_RE.match(str(x))) | |
| )).sum()) | |
| if bad > 0: | |
| plan.append(f'fix_format on "{col}" ({bad} malformed dates)') | |
| elif col == "country": | |
| bad = int((~self._df[col].dropna().isin(VALID_COUNTRIES)).sum()) | |
| if bad > 0: | |
| plan.append(f'fix_format on "country" ({bad} invalid values)') | |
| # Priority 4: outliers on numeric cols | |
| if len(plan) < 3: | |
| for col in self._df.select_dtypes(include=[np.number]).columns: | |
| op_key = f"drop_outliers:{col}" | |
| if op_key in self._tried_operations: | |
| continue | |
| q1, q3 = self._df[col].quantile(0.25), self._df[col].quantile(0.75) | |
| iqr = q3 - q1 | |
| outliers = int((self._df[col] > q3 + 3 * iqr).sum()) | |
| if outliers > 0: | |
| plan.append(f'drop_outliers on "{col}" ({outliers} extreme values)') | |
| break | |
| return plan[:3] | |
| def _compute_score(self) -> float: | |
| if self._task_id == 1: | |
| raw = t1.score(self._df, self._meta) | |
| elif self._task_id == 2: | |
| raw = t2.score(self._df, self._meta) | |
| elif self._task_id == 3: | |
| raw = t3.score(self._df, self._meta) | |
| else: | |
| raw = t4.score(self._df, self._meta) | |
| raw = float(raw) | |
| EPS = 1e-4 | |
| if raw >= 1.0: | |
| raw = 1.0 - EPS | |
| elif raw <= 0.0: | |
| raw = EPS | |
| return round(raw, 4) | |
| def _count_errors(self) -> int: | |
| if self._task_id == 1: | |
| return t1.count_errors(self._df) | |
| elif self._task_id == 2: | |
| return t2.count_errors(self._df, self._meta) | |
| elif self._task_id == 3: | |
| return t3.count_errors(self._df, self._meta) | |
| else: | |
| return t4.count_errors(self._df, self._meta) | |
| def _build_obs(self, reward: float, done: bool, message: str) -> DataCleaningObservation: | |
| mod = TASK_MODULES[self._task_id] | |
| missing = {col: int(n) for col, n in self._df.isnull().sum().items() if n > 0} | |
| dupes = len(self._df) - len(self._df.drop_duplicates()) | |
| dtype_issues = self._detect_dtype_issues() | |
| preview = self._df.head(10).to_csv(index=False) | |
| dq_metrics = self._compute_dq_metrics() | |
| plan = self._generate_plan() | |
| return DataCleaningObservation( | |
| done = done, | |
| reward = reward, | |
| data_preview = preview, | |
| data_shape = list(self._df.shape), | |
| missing_counts = missing, | |
| duplicate_count = dupes, | |
| dtype_issues = dtype_issues, | |
| task_description = mod.DESCRIPTION, | |
| message = message, | |
| step_count = self._step_count, | |
| current_score = self._last_score, | |
| dq_metrics = dq_metrics, | |
| tried_operations = list(self._tried_operations), | |
| plan = plan, | |
| ) | |
| def _detect_dtype_issues(self) -> Dict[str, str]: | |
| issues: Dict[str, str] = {} | |
| for col in self._df.columns: | |
| series = self._df[col].dropna() | |
| if series.empty: | |
| continue | |
| if self._df[col].dtype == object: | |
| numeric_count = pd.to_numeric(series, errors="coerce").notna().sum() | |
| if numeric_count / len(series) > 0.8: | |
| issues[col] = "stored as string but appears numeric" | |
| return issues | |
| # ------------------------------------------------------------------ | |
| # Action dispatcher | |
| # ------------------------------------------------------------------ | |
| def _apply_action(self, action: DataCleaningAction) -> Tuple[str, bool]: | |
| op = action.operation.strip().lower() | |
| col = action.column | |
| p = action.params or {} | |
| try: | |
| if op == "fill_missing": | |
| return self._fill_missing(col, p) | |
| elif op == "drop_duplicates": | |
| return self._drop_duplicates() | |
| elif op == "fix_format": | |
| return self._fix_format(col) | |
| elif op == "replace_value": | |
| return self._replace_value(col, p) | |
| elif op == "drop_outliers": | |
| return self._drop_outliers(col) | |
| elif op == "fix_dtype": | |
| return self._fix_dtype(col, p) | |
| elif op == "align_schema": | |
| return self._align_schema() | |
| elif op == "merge_sources": | |
| return self._merge_sources() | |
| else: | |
| return ( | |
| f"Unknown operation '{op}'. Choose from: fill_missing, " | |
| "drop_duplicates, fix_format, replace_value, drop_outliers, " | |
| "fix_dtype, align_schema, merge_sources.", | |
| False, | |
| ) | |
| except Exception as exc: | |
| return f"Operation failed: {exc}", False | |
| def _fill_missing(self, col, p) -> Tuple[str, bool]: | |
| if col is None or col not in self._df.columns: | |
| return f"Column '{col}' not found.", False | |
| n_before = int(self._df[col].isnull().sum()) | |
| if n_before == 0: | |
| return f"No missing values in '{col}'.", False | |
| strategy = str(p.get("strategy", "median")).lower() | |
| if strategy == "median": | |
| fill_val = self._df[col].median(skipna=True) | |
| elif strategy == "mean": | |
| fill_val = self._df[col].mean(skipna=True) | |
| elif strategy == "mode": | |
| mode = self._df[col].mode(dropna=True) | |
| fill_val = mode.iloc[0] if not mode.empty else None | |
| elif strategy == "constant": | |
| fill_val = p.get("value") | |
| else: | |
| return f"Unknown strategy '{strategy}'.", False | |
| if fill_val is None: | |
| return "Could not determine fill value.", False | |
| self._df[col] = self._df[col].fillna(fill_val) | |
| n_after = int(self._df[col].isnull().sum()) | |
| return f"Filled {n_before - n_after} missing values in '{col}' using {strategy}.", True | |
| def _drop_duplicates(self) -> Tuple[str, bool]: | |
| n_before = len(self._df) | |
| self._df = self._df.drop_duplicates().reset_index(drop=True) | |
| removed = n_before - len(self._df) | |
| if removed == 0: | |
| return "No duplicate rows found.", False | |
| return f"Dropped {removed} duplicate rows.", True | |
| def _fix_format(self, col) -> Tuple[str, bool]: | |
| if col is None or col not in self._df.columns: | |
| return f"Column '{col}' not found.", False | |
| if col == "phone": | |
| return self._fix_phone(col) | |
| elif col in ("listed_date", "signup_date"): | |
| return self._fix_date(col) | |
| elif col == "country": | |
| return self._fix_country(col) | |
| else: | |
| return f"No format rule defined for column '{col}'.", False | |
| def _fix_phone(self, col) -> Tuple[str, bool]: | |
| def normalise(val): | |
| if pd.isna(val): | |
| return val | |
| digits = re.sub(r"\D", "", str(val)) | |
| if len(digits) == 10: | |
| return f"{digits[:3]}-{digits[3:6]}-{digits[6:]}" | |
| return val | |
| before = (~self._df[col].str.match(PHONE_RE, na=False)).sum() | |
| self._df[col] = self._df[col].apply(normalise) | |
| after = (~self._df[col].str.match(PHONE_RE, na=False)).sum() | |
| fixed = int(before - after) | |
| if fixed == 0: | |
| return f"No phone format issues found in '{col}'.", False | |
| return f"Fixed {fixed} phone numbers in '{col}' to NNN-NNN-NNNN format.", True | |
| def _fix_date(self, col) -> Tuple[str, bool]: | |
| _DATE_FORMATS = ["%Y-%m-%d", "%b %d %Y", "%d/%m/%Y", "%m/%d/%Y", "%Y/%m/%d"] | |
| def normalise(val): | |
| if pd.isna(val): | |
| return val | |
| s = str(val).strip() | |
| for fmt in _DATE_FORMATS: | |
| try: | |
| return pd.to_datetime(s, format=fmt).strftime("%Y-%m-%d") | |
| except Exception: | |
| pass | |
| try: | |
| return pd.to_datetime(s).strftime("%Y-%m-%d") | |
| except Exception: | |
| return val | |
| before = (~self._df[col].apply( | |
| lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False | |
| )).sum() | |
| self._df[col] = self._df[col].apply(normalise) | |
| after = (~self._df[col].apply( | |
| lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False | |
| )).sum() | |
| fixed = int(before - after) | |
| if fixed == 0: | |
| return f"No date format issues found in '{col}'.", False | |
| return f"Fixed {fixed} dates in '{col}' to YYYY-MM-DD format.", True | |
| def _fix_country(self, col) -> Tuple[str, bool]: | |
| def normalise(val): | |
| if pd.isna(val): | |
| return val | |
| mapping = { | |
| "usa": "USA", "uk": "UK", "canada": "Canada", | |
| "australia": "Australia", "germany": "Germany", | |
| } | |
| return mapping.get(str(val).strip().lower(), val) | |
| before = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum() | |
| self._df[col] = self._df[col].apply(normalise) | |
| after = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum() | |
| fixed = int(before - after) | |
| if fixed == 0: | |
| return "No country capitalisation issues found.", False | |
| return f"Fixed {fixed} country values to correct capitalisation.", True | |
| def _replace_value(self, col, p) -> Tuple[str, bool]: | |
| if col is None or col not in self._df.columns: | |
| return f"Column '{col}' not found.", False | |
| old = p.get("old") | |
| new = p.get("new") | |
| if old is None: | |
| return "params.old is required for replace_value.", False | |
| count = int((self._df[col] == old).sum()) | |
| if count == 0: | |
| return f"Value '{old}' not found in '{col}'.", False | |
| self._df[col] = self._df[col].replace(old, new) | |
| return f"Replaced {count} occurrences of '{old}' with '{new}' in '{col}'.", True | |
| def _drop_outliers(self, col) -> Tuple[str, bool]: | |
| if col is None or col not in self._df.columns: | |
| return f"Column '{col}' not found.", False | |
| if not pd.api.types.is_numeric_dtype(self._df[col]): | |
| return f"'{col}' is not numeric.", False | |
| q1 = self._df[col].quantile(0.25) | |
| q3 = self._df[col].quantile(0.75) | |
| iqr = q3 - q1 | |
| mask = (self._df[col] >= q1 - 3 * iqr) & (self._df[col] <= q3 + 3 * iqr) | |
| n_before = len(self._df) | |
| self._df = self._df[mask | self._df[col].isna()].reset_index(drop=True) | |
| removed = n_before - len(self._df) | |
| if removed == 0: | |
| return f"No outliers found in '{col}'.", False | |
| return f"Removed {removed} outlier rows from '{col}' using IQR method.", True | |
| def _align_schema(self) -> Tuple[str, bool]: | |
| """Rename Source A columns to canonical target schema (Task 4 only).""" | |
| if self._task_id != 4: | |
| return "align_schema is only available in Task 4.", False | |
| if self._schema_aligned: | |
| return "Schema already aligned.", False | |
| from server.tasks.task4_merge import SOURCE_A_RENAME, TARGET_COLUMNS | |
| missing_src = [c for c in SOURCE_A_RENAME if c not in self._df.columns] | |
| if missing_src: | |
| return f"Expected Source A columns not found: {missing_src}.", False | |
| self._df = self._df.rename(columns=SOURCE_A_RENAME) | |
| self._schema_aligned = True | |
| renamed = list(SOURCE_A_RENAME.keys()) | |
| return ( | |
| f"Aligned Source A schema: renamed {len(SOURCE_A_RENAME)} columns " | |
| f"({', '.join(renamed)}) to canonical target schema.", True | |
| ) | |
| def _merge_sources(self) -> Tuple[str, bool]: | |
| """Concatenate aligned Source A with Source B (Task 4 only).""" | |
| if self._task_id != 4: | |
| return "merge_sources is only available in Task 4.", False | |
| if self._sources_merged: | |
| return "Sources already merged.", False | |
| if not self._schema_aligned: | |
| return "Run align_schema before merge_sources.", False | |
| if self._source_b is None: | |
| return "Source B not available.", False | |
| from server.tasks.task4_merge import TARGET_COLUMNS, _META_TEMPLATE | |
| n_a = len(self._df) | |
| n_b = len(self._source_b) | |
| # Rename source_b columns to canonical schema | |
| source_b_rename = { | |
| "age_years": "age", | |
| "spend": "purchase_amount", | |
| "country_name": "country", | |
| "registration_date": "signup_date", | |
| } | |
| source_b_aligned = self._source_b.rename(columns=source_b_rename) | |
| # Concatenate both aligned sources | |
| merged = pd.concat( | |
| [self._df[TARGET_COLUMNS], source_b_aligned[TARGET_COLUMNS]], | |
| ignore_index=True | |
| ).reset_index(drop=True) | |
| # Inject pre-computed dirty issues so grader baseline is correct | |
| dirty_merged = _META_TEMPLATE["dirty_merged"].copy() | |
| self._df = dirty_merged | |
| self._sources_merged = True | |
| self._source_b = None | |
| return ( | |
| f"Merged Source A ({n_a} rows) + Source B ({n_b} rows) → " | |
| f"{len(self._df)} rows with canonical schema. " | |
| f"Dataset now has dirty issues to clean: missing values, " | |
| f"mixed country case, mixed date formats, duplicate rows.", True | |
| ) | |
| def _fix_dtype(self, col, p) -> Tuple[str, bool]: | |
| if col is None or col not in self._df.columns: | |
| return f"Column '{col}' not found.", False | |
| dtype = str(p.get("dtype", "float")).lower() | |
| try: | |
| if dtype == "float": | |
| self._df[col] = pd.to_numeric(self._df[col], errors="coerce").astype(float) | |
| elif dtype == "int": | |
| self._df[col] = pd.to_numeric(self._df[col], errors="coerce") | |
| elif dtype == "str": | |
| self._df[col] = self._df[col].astype(str) | |
| else: | |
| return f"Unknown dtype '{dtype}'.", False | |
| return f"Converted '{col}' to {dtype}.", True | |
| except Exception as exc: | |
| return f"dtype conversion failed: {exc}", False |