# scenario_engine.py
from __future__ import annotations

from typing import Dict, List, Any, Tuple, Optional, Iterable
import re, math, ast

import numpy as np
import pandas as pd
# ========= Robust role/column resolver (safe with pandas.Index) =========
try:
    # If you have an external, richer resolver, it is used automatically.
    from column_resolver import resolve_one as _ext_resolve_one, resolve_cols as _ext_resolve_cols  # type: ignore
    _HAS_EXT_RESOLVER = True
except Exception:
    _HAS_EXT_RESOLVER = False
_ROLE_SYNONYMS_FALLBACK = {
    "facility": ["facility", "hospital", "centre", "center", "clinic", "site", "provider",
                 "settlement", "community", "location"],
    "community": ["community", "settlement", "reserve", "town", "village", "city", "region", "area"],
    "zone": ["zone", "region", "district", "area", "healthzone"],
    "specialty": ["specialty", "programme", "program", "service", "discipline", "department"],
    "period": ["period", "quarter", "year", "month", "time", "fiscal", "date"],
    "city": ["city", "town", "village"],
    "lat": ["latitude", "lat"],
    "lon": ["longitude", "lon", "lng"],
    "population": ["population", "members", "residents", "census"],
    "prevalence": ["prevalence", "rate", "risk", "pct", "percentage"],
    "volume": ["count", "visits", "clients", "volume", "n", "cases"],
    "cost": ["cost", "expense", "spend", "budget", "perclient", "startup"],
    "capacity": ["capacity", "throughput", "slots", "dailycapacity", "clientsperday"],
}
def _canon(s: str) -> str:
    return re.sub(r"[^a-z0-9]+", "", (s or "").lower())

def _to_list(x: Iterable | None) -> List:
    if x is None:
        return []
    try:
        return list(x)
    except Exception:
        return [x]
def resolve_one(want: str, columns: Iterable[str]) -> Optional[str]:
    """Return the best matching column for a semantic role or exact header. Safe for pandas.Index."""
    cols = _to_list(columns)
    if _HAS_EXT_RESOLVER:
        try:
            return _ext_resolve_one(want, cols)
        except Exception:
            pass
    if not cols:
        return None
    wcanon = _canon(want)
    if not wcanon:
        return None
    canon_cols = {_canon(c): c for c in cols if isinstance(c, str)}
    if wcanon in canon_cols:
        return canon_cols[wcanon]
    syns = _ROLE_SYNONYMS_FALLBACK.get((want or "").lower(), [])
    syns_canon = [_canon(s) for s in syns]
    best, score = None, -1
    for c in cols:
        if not isinstance(c, str):
            continue
        cc = _canon(c)
        sc = 0
        if wcanon and (cc == wcanon or cc.startswith(wcanon) or wcanon in cc):
            sc += 3
        for s in syns_canon:
            if not s:
                continue
            if cc == s:
                sc += 5
            elif cc.startswith(s):
                sc += 3
            elif s in cc:
                sc += 2
        if sc > score:
            best, score = c, sc
    return best if score >= 2 else None
def resolve_cols(requested: Iterable[str], columns: Iterable[str]) -> List[str]:
    """Resolve a list of roles/headers to existing columns, uniquely. Safe for pandas.Index."""
    reqs = _to_list(requested)
    cols = _to_list(columns)
    if _HAS_EXT_RESOLVER:
        try:
            return _ext_resolve_cols(reqs, cols)
        except Exception:
            pass
    out, seen = [], set()
    for r in reqs:
        col = resolve_one(r, cols)
        if col and col not in seen:
            out.append(col)
            seen.add(col)
    return out
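
# Illustrative usage of the fallback scorer (hypothetical headers):
#   resolve_one("facility", ["Hospital Name", "Visits"])              -> "Hospital Name"
#   resolve_cols(["facility", "volume"], ["Hospital Name", "Visits"]) -> ["Hospital Name", "Visits"]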
# ========= Safe expression evaluation (filters/derivations) =========
_ALLOWED_FUNCS = {
    "abs": abs, "round": round,
    "sqrt": np.sqrt, "log": np.log, "exp": np.exp,
    "min": np.minimum, "max": np.maximum,
    "mean": np.mean, "avg": np.mean, "median": np.median, "sum": np.sum,
    "count": lambda x: np.size(x),
    "p50": lambda x: np.percentile(x, 50),
    "p75": lambda x: np.percentile(x, 75),
    "p90": lambda x: np.percentile(x, 90),
    "p95": lambda x: np.percentile(x, 95),
}
class _SafeExpr(ast.NodeTransformer):
    """Whitelist-based validator for filter/derive expressions."""

    def __init__(self, allowed):
        self.allowed = allowed

    def visit_Name(self, node):
        if node.id not in self.allowed and node.id not in ("True", "False", "None"):
            raise ValueError(f"Unknown name: {node.id}")
        return node

    def visit_Attribute(self, node):
        raise ValueError("Attribute access is not allowed")

    def visit_Call(self, node):
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only simple function calls are allowed")
        if node.func.id not in _ALLOWED_FUNCS:
            raise ValueError(f"Function not allowed: {node.func.id}")
        return self.generic_visit(node)

    def generic_visit(self, node):
        allowed = (
            ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Call, ast.Name,
            ast.Load, ast.Constant, ast.And, ast.Or, ast.Not, ast.Add, ast.Sub, ast.Mult, ast.Div,
            ast.Mod, ast.Pow, ast.FloorDiv, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
            ast.USub, ast.UAdd,
            # Bitwise ops so vectorized boolean logic (`&`, `|`, `~`) works on Series;
            # Python's `and`/`or` raise "truth value is ambiguous" on a Series.
            ast.BitAnd, ast.BitOr, ast.Invert,
        )
        if not isinstance(node, allowed):
            raise ValueError(f"Unsupported syntax: {type(node).__name__}")
        return super().generic_visit(node)
def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
    # Allowed names must include the whitelisted function names: the validator
    # also visits the callee Name node, so `mean(...)` would otherwise be rejected.
    names = set(df.columns) | set(_ALLOWED_FUNCS) | {"True", "False", "None"}
    tree = ast.parse(expr, mode="eval")
    _SafeExpr(names).visit(tree)
    code = compile(tree, "<expr>", "eval")
    env = {**{k: df[k] for k in df.columns}, **_ALLOWED_FUNCS}
    val = eval(code, {"__builtins__": {}}, env)
    if isinstance(val, (pd.Series, np.ndarray, list)):
        return pd.Series(val, index=df.index)
    if isinstance(val, (bool, np.bool_)):
        return pd.Series([val] * len(df), index=df.index)
    raise ValueError("Expression must yield a vector or boolean")
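
# Illustrative filter expression (hypothetical columns wait_days, visits):
#   _eval_series_expr("(wait_days > 30) & (visits >= 5)", df)
# yields a boolean Series aligned to df.index. Note the bitwise `&`/`|`;
# `and`/`or` parse but raise on Series at evaluation time.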
# ========= Helpers =========
def _as_df(v: Any) -> Optional[pd.DataFrame]:
    if isinstance(v, pd.DataFrame):
        return v
    if isinstance(v, list):
        return pd.DataFrame(v) if v else pd.DataFrame()
    if isinstance(v, dict):
        flat = all(isinstance(val, (int, float, str, bool, type(None))) for val in v.values())
        return pd.DataFrame([v]) if flat else pd.DataFrame()
    return None

def _get_df(datasets: Dict[str, Any], key: Optional[str]) -> Optional[pd.DataFrame]:
    if key and key in datasets:
        v = datasets[key]
    else:
        # Fall back to the first non-null dataset when the key is absent.
        v = next((vv for vv in datasets.values() if vv is not None), None)
    return _as_df(v) if v is not None else None
def _auto_group_cols(df: pd.DataFrame) -> List[str]:
    prefs = ["facility", "community", "settlement", "provider", "zone", "region",
             "district", "specialty", "program", "service", "city"]
    for p in prefs:
        col = resolve_one(p, _to_list(df.columns))
        if col:
            return [col]
    obj_cols = [c for c in df.columns if df[c].dtype == "object"]
    return obj_cols[:1] if obj_cols else []
def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
    """
    "mean(wait_days), p90(wait_days), count(*)" -> [("mean_wait_days", "mean(wait_days)"), ...]
    A bare token "wait_days" becomes mean(wait_days).
    """
    if not spec:
        return []
    out: List[Tuple[str, str]] = []
    for it in [x.strip() for x in spec.split(",") if x.strip()]:
        if it.lower() in ("count", "count(*)"):
            out.append(("count_*", "count(*)"))
            continue
        m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]+)\)', it)
        if not m:
            arg = it
            out.append((f"mean_{arg}", f"mean({arg})"))
            continue
        func, arg = m.group(1).lower(), m.group(2).strip()
        out.append((f"{func}_{arg}", f"{func}({arg})"))
    return out
def _apply_agg_call(df: pd.DataFrame, call: str):
    # Lowercase only the function name, not the whole call: column names are
    # case-sensitive, and lowercasing the argument would break their lookup.
    call = call.strip()
    if call.lower() in ("count", "count(*)"):
        return int(len(df))
    m = re.match(r'([A-Za-z_][A-Za-z0-9_]*)\(([^)]+)\)', call)
    if not m:
        # Bare column name: default to its mean.
        arg = call
        if arg not in df.columns:
            return None
        col = pd.to_numeric(df[arg], errors="coerce").dropna()
        return float(col.mean()) if len(col) else float("nan")
    func, arg = m.group(1).lower(), m.group(2).strip()
    if arg not in df.columns:
        return None
    col = pd.to_numeric(df[arg], errors="coerce").dropna()
    if not len(col):
        return float("nan")
    if func in ("avg", "mean"):
        return float(col.mean())
    if func == "median":
        return float(np.median(col))
    if func == "sum":
        return float(col.sum())
    if func in ("min", "max"):
        return float(getattr(np, func)(col))
    if func.startswith("p") and func[1:].isdigit():
        return float(np.percentile(col, int(func[1:])))
    return None
def _apply_filter(df: pd.DataFrame, expr: str) -> pd.DataFrame:
    m = _eval_series_expr(expr, df)
    return df.loc[m.astype(bool)].copy()

def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
    # Supports "newcol = expr, other = expr2". Split only where the next
    # "name =" assignment begins, so commas inside calls like min(a, b) survive.
    df = df.copy()  # avoid mutating the caller's dataset in place
    parts = re.split(r'[;,]\s*(?=[A-Za-z_][A-Za-z0-9_]*\s*=)', spec)
    for p in parts:
        if "=" in p:
            col, expr = p.split("=", 1)
            df[col.strip()] = _eval_series_expr(expr.strip(), df)
    return df
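
# Illustrative derive spec (hypothetical column wait_days):
#   _apply_derive(df, "wait_weeks = wait_days / 7, long_wait = wait_days > 30")
# Each right-hand side is evaluated under the same safe-expression rules as filters.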
def _render_table(df: pd.DataFrame) -> str:
    if df is None or df.empty:
        return "_No rows._"
    dff = df.copy()
    for c in dff.columns:
        if pd.api.types.is_float_dtype(dff[c]) or pd.api.types.is_integer_dtype(dff[c]):
            dff[c] = dff[c].apply(lambda v: "NaN" if (isinstance(v, float) and math.isnan(v)) else f"{v:,.4g}")
    header = "| " + " | ".join(map(str, dff.columns)) + " |"
    sep = "|" + "|".join(["---"] * len(dff.columns)) + "|"
    rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
    return "\n".join([header, sep, *rows])
def _small_n_flags(df: pd.DataFrame, count_col: Optional[str] = None, threshold: int = 5) -> Optional[pd.Series]:
    if df is None or df.empty:
        return None
    if count_col and count_col in df.columns:
        return df[count_col].apply(
            lambda n: " (interpret cautiously: small n)" if pd.notnull(n) and float(n) < threshold else ""
        )
    return None

def _missingness(df: pd.DataFrame, metric_cols: List[str]) -> List[str]:
    notes = []
    for c in metric_cols:
        if c in df.columns:
            miss = df[c].isna().mean()
            if miss > 0:
                notes.append(f"{c}: missing {miss:.1%}")
    return notes
# ========= Scenario Engine =========
class ScenarioEngine:
    """
    Execute a ScenarioPlan (or dict) consisting of tasks that specify:
      - data_key: name of dataset in `datasets`
      - filter:   boolean/vectorized expression (safe-eval)
      - derive:   "new = expr, ..."
      - group_by: list of roles/column names (resolved dynamically)
      - agg:      "mean(col), p90(col), count(*)" (bare 'col' => mean(col))
      - sort_by / sort_dir
      - top
      - fields:   project/alias output columns by role/name (resolved dynamically)

    Returns markdown with:
      - a section per task
      - table output
      - Assumptions & Mappings
      - Data Quality notes
    """
    @staticmethod
    def _group_agg(df: pd.DataFrame,
                   group_by: Optional[List[str]],
                   agg_spec: Optional[str],
                   mapping_log: List[str]) -> pd.DataFrame:
        # Resolve grouping to existing columns; tolerate roles or wrong names.
        if group_by:
            gcols = resolve_cols(group_by, _to_list(df.columns))
            for want in group_by:
                got = resolve_one(want, _to_list(df.columns))
                mapping_log.append(f"group_by: {want} → {got if got else '(unresolved)'}")
        else:
            gcols = _auto_group_cols(df)
            if gcols:
                mapping_log.append(f"group_by: (auto) → {gcols[0]}")
            else:
                mapping_log.append("group_by: (auto) → (none)")
        aggs = _parse_aggs(agg_spec or "")
        # No grouping and no agg => just preview a slice.
        if not gcols:
            if not aggs:
                return df.head(50).copy()
            rec = {out_col: _apply_agg_call(df, call) for out_col, call in aggs}
            return pd.DataFrame([rec])
        if not aggs:
            # Default: mean of numeric columns plus count(*).
            num_cols = list(df.select_dtypes(include="number").columns)
            gb = df.groupby(gcols, dropna=False)
            if not num_cols:
                out = gb.size().reset_index(name="count_*")
                return out.sort_values("count_*", ascending=False)
            out = gb[num_cols].mean(numeric_only=True)
            out["count_*"] = gb.size()
            return out.reset_index()
        # Apply the requested aggregations per group.
        rows = []
        gb = df.groupby(gcols, dropna=False)
        for keys, g in gb:
            if not isinstance(keys, tuple):
                keys = (keys,)
            rec = {gcols[i]: keys[i] for i in range(len(gcols))}
            for out_col, call in aggs:
                rec[out_col] = _apply_agg_call(g, call)
            rows.append(rec)
        return pd.DataFrame(rows)
    @staticmethod
    def _project_fields(out_df: pd.DataFrame,
                        fields: Optional[List[str]],
                        mapping_log: List[str]) -> pd.DataFrame:
        if not isinstance(out_df, pd.DataFrame) or out_df.empty or not fields:
            return out_df
        cols = resolve_cols(fields, _to_list(out_df.columns))
        for want in fields:
            got = resolve_one(want, _to_list(out_df.columns))
            mapping_log.append(f"field: {want} → {got if got else '(unresolved)'}")
        if cols:
            return out_df[cols]
        return out_df
    @staticmethod
    def _data_quality_notes(out_df: pd.DataFrame) -> List[str]:
        notes: List[str] = []
        if out_df is None or out_df.empty:
            return notes
        # Small-n flag if a count column exists (guard against non-string headers).
        cnt_col = None
        for c in out_df.columns:
            if isinstance(c, str) and c.lower() in ("count", "count_*", "n", "records"):
                cnt_col = c
                break
        sn = _small_n_flags(out_df, count_col=cnt_col, threshold=5)
        if sn is not None and sn.any():
            n_small = (sn != "").sum()
            if n_small > 0:
                notes.append(f"{n_small} row(s) flagged as small-n (interpret cautiously).")
        # Missingness for numeric columns.
        metric_cols = [c for c in out_df.columns if pd.api.types.is_numeric_dtype(out_df[c])]
        notes.extend(_missingness(out_df, metric_cols))
        return notes
    @staticmethod
    def _exec_task(t: Any, datasets: Dict[str, Any]) -> str:
        title = getattr(t, "title", None) or (isinstance(t, dict) and t.get("title")) or "Task"
        section_lines: List[str] = [f"## {title}\n"]
        data_key = getattr(t, "data_key", None) or (isinstance(t, dict) and t.get("data_key"))
        df = _get_df(datasets, data_key)
        if df is None or df.empty:
            section_lines.append("_No matching data for this task._")
            return "\n".join(section_lines)

        # Filter(s)
        t_filter = getattr(t, "filter", None) or (isinstance(t, dict) and t.get("filter"))
        if t_filter:
            try:
                df = _apply_filter(df, t_filter)
            except Exception as e:
                section_lines.append(f"_Warning: filter ignored ({e})._")

        # Derive(s)
        t_derive = getattr(t, "derive", None) or (isinstance(t, dict) and t.get("derive"))
        if t_derive:
            for d in (t_derive if isinstance(t_derive, (list, tuple)) else [t_derive]):
                try:
                    df = _apply_derive(df, d)
                except Exception as e:
                    section_lines.append(f"_Warning: derive ignored ({e})._")

        # Group / Agg
        t_group_by = getattr(t, "group_by", None) or (isinstance(t, dict) and t.get("group_by"))
        if isinstance(t_group_by, str):
            t_group_by = [t_group_by]
        t_agg = getattr(t, "agg", None) or (isinstance(t, dict) and t.get("agg"))
        agg_spec = ", ".join(t_agg) if isinstance(t_agg, list) else (t_agg or None)
        mapping_log: List[str] = []
        out_df = ScenarioEngine._group_agg(df, t_group_by, agg_spec, mapping_log)

        # Sort / Top
        t_sort_by = getattr(t, "sort_by", None) or (isinstance(t, dict) and t.get("sort_by"))
        t_sort_dir = (getattr(t, "sort_dir", None) or (isinstance(t, dict) and t.get("sort_dir")) or "desc").lower()
        if isinstance(out_df, pd.DataFrame) and t_sort_by and t_sort_by in out_df.columns:
            out_df = out_df.sort_values(t_sort_by, ascending=(t_sort_dir == "asc"))
        t_top = getattr(t, "top", None) or (isinstance(t, dict) and t.get("top"))
        if isinstance(out_df, pd.DataFrame) and isinstance(t_top, int) and t_top > 0:
            out_df = out_df.head(t_top)

        # Field projection
        t_fields = getattr(t, "fields", None) or (isinstance(t, dict) and t.get("fields"))
        if isinstance(t_fields, str):
            t_fields = [t_fields]
        out_df = ScenarioEngine._project_fields(out_df, t_fields, mapping_log)

        # Render
        section_lines.append(_render_table(out_df))

        # Assumptions & Mappings
        if mapping_log:
            section_lines.append("\n**Assumptions & Mappings**")
            for line in mapping_log:
                section_lines.append(f"- {line}")

        # Data quality
        dq = ScenarioEngine._data_quality_notes(out_df)
        if dq:
            section_lines.append("\n**Data Quality Notes**")
            for n in dq:
                section_lines.append(f"- {n}")
        return "\n".join(section_lines)
    @staticmethod
    def execute_plan(plan: Any, datasets: Dict[str, Any]) -> str:
        """
        plan: object or dict with `tasks: List[Task]`.
        Each task may carry: title, data_key, filter, derive, group_by, agg,
        sort_by, sort_dir, top, fields.
        """
        sections: List[str] = ["# Scenario Output\n"]
        tasks = getattr(plan, "tasks", None) or (isinstance(plan, dict) and plan.get("tasks")) or []
        for t in tasks:
            sections.append(ScenarioEngine._exec_task(t, datasets))
        return "\n".join(sections).strip()