# scenario_engine.py
from __future__ import annotations
from typing import Dict, List, Any, Tuple, Optional, Iterable
import re, math, ast
import numpy as np
import pandas as pd
# ========= Robust role/column resolver (safe with pandas.Index) =========
try:
# If you have an external, richer resolver, we will use it automatically.
from column_resolver import resolve_one as _ext_resolve_one, resolve_cols as _ext_resolve_cols # type: ignore
_HAS_EXT_RESOLVER = True
except Exception:
_HAS_EXT_RESOLVER = False
_ROLE_SYNONYMS_FALLBACK = {
"facility": ["facility", "hospital", "centre", "center", "clinic", "site", "provider",
"settlement", "community", "location"],
"community": ["community", "settlement", "reserve", "town", "village", "city", "region", "area"],
"zone": ["zone", "region", "district", "area", "healthzone"],
"specialty": ["specialty", "programme", "program", "service", "discipline", "department"],
"period": ["period", "quarter", "year", "month", "time", "fiscal", "date"],
"city": ["city", "town", "village"],
"lat": ["latitude", "lat"],
"lon": ["longitude", "lon", "lng"],
"population": ["population", "members", "residents", "census"],
"prevalence": ["prevalence", "rate", "risk", "pct", "percentage"],
"volume": ["count", "visits", "clients", "volume", "n", "cases"],
"cost": ["cost", "expense", "spend", "budget", "perclient", "startup"],
"capacity": ["capacity", "throughput", "slots", "dailycapacity", "clientsperday"],
}
def _canon(s: str) -> str:
return re.sub(r"[^a-z0-9]+", "", (s or "").lower())
def _to_list(x: Iterable | None) -> List:
if x is None:
return []
try:
return list(x)
except Exception:
return [x]
def resolve_one(want: str, columns: Iterable[str]) -> Optional[str]:
"""Return best matching column for a semantic role or exact header. Safe for pandas.Index."""
cols = _to_list(columns)
if _HAS_EXT_RESOLVER:
try:
return _ext_resolve_one(want, cols)
except Exception:
pass
if not cols:
return None
wcanon = _canon(want)
if not wcanon:
return None
canon_cols = { _canon(c): c for c in cols if isinstance(c, str) }
if wcanon in canon_cols:
return canon_cols[wcanon]
syns = _ROLE_SYNONYMS_FALLBACK.get((want or "").lower(), [])
syns_canon = [_canon(s) for s in syns]
best, score = None, -1
for c in cols:
if not isinstance(c, str):
continue
cc = _canon(c)
sc = 0
if wcanon and (cc == wcanon or cc.startswith(wcanon) or wcanon in cc):
sc += 3
for s in syns_canon:
if not s:
continue
if cc == s:
sc += 5
elif cc.startswith(s):
sc += 3
elif s in cc:
sc += 2
if sc > score:
best, score = c, sc
return best if score >= 2 else None
def resolve_cols(requested: Iterable[str], columns: Iterable[str]) -> List[str]:
"""Resolve a list of roles/headers to existing columns, uniquely. Safe for pandas.Index."""
reqs = _to_list(requested)
cols = _to_list(columns)
if _HAS_EXT_RESOLVER:
try:
return _ext_resolve_cols(reqs, cols)
except Exception:
pass
out, seen = [], set()
for r in reqs:
col = resolve_one(r, cols)
if col and col not in seen:
out.append(col)
seen.add(col)
return out
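# Illustrative use of the fallback resolver (hypothetical headers; with an
# external column_resolver installed the results may differ):
#   resolve_one("facility", ["Facility Name", "Zone", "Visits"])  -> "Facility Name"
#   resolve_cols(["facility", "volume"], ["Facility Name", "Zone", "Visits"])
#       -> ["Facility Name", "Visits"]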
# ========= Safe expression evaluation (filters/derivations) =========
_ALLOWED_FUNCS = {
"abs": abs, "round": round,
"sqrt": np.sqrt, "log": np.log, "exp": np.exp,
"min": np.minimum, "max": np.maximum,
"mean": np.mean, "avg": np.mean, "median": np.median, "sum": np.sum,
"count": lambda x: np.size(x),
"p50": lambda x: np.percentile(x, 50),
"p75": lambda x: np.percentile(x, 75),
"p90": lambda x: np.percentile(x, 90),
"p95": lambda x: np.percentile(x, 95),
}
class _SafeExpr(ast.NodeTransformer):
    def __init__(self, allowed): self.allowed = allowed
    def visit_Name(self, node):
        if node.id not in self.allowed and node.id not in ("True", "False", "None"):
            raise ValueError(f"Unknown name: {node.id}")
        return node
    def visit_Attribute(self, node):
        raise ValueError("Attribute access is not allowed")
    def visit_Call(self, node):
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only simple function calls are allowed")
        if node.func.id not in _ALLOWED_FUNCS:
            raise ValueError(f"Function not allowed: {node.func.id}")
        return self.generic_visit(node)
    def visit_BoolOp(self, node):
        # `and`/`or` raise "truth value is ambiguous" on pandas Series, so
        # rewrite them to the element-wise `&`/`|`.
        self.generic_visit(node)
        op = ast.BitAnd() if isinstance(node.op, ast.And) else ast.BitOr()
        expr = node.values[0]
        for right in node.values[1:]:
            expr = ast.copy_location(ast.BinOp(left=expr, op=op, right=right), node)
        return expr
    def visit_UnaryOp(self, node):
        # Likewise rewrite `not x` to the element-wise `~x`.
        self.generic_visit(node)
        if isinstance(node.op, ast.Not):
            return ast.copy_location(ast.UnaryOp(op=ast.Invert(), operand=node.operand), node)
        return node
    def generic_visit(self, node):
        allowed = (
            ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Call, ast.Name,
            ast.Load, ast.Constant, ast.And, ast.Or, ast.Not, ast.Add, ast.Sub, ast.Mult, ast.Div,
            ast.Mod, ast.Pow, ast.FloorDiv, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
            ast.USub, ast.UAdd, ast.BitAnd, ast.BitOr, ast.Invert,
        )
        if not isinstance(node, allowed):
            raise ValueError(f"Unsupported syntax: {type(node).__name__}")
        return super().generic_visit(node)
def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
    # Resolvable names: the frame's columns plus the whitelisted functions
    # (without the latter, calls like mean(x) would fail name validation).
    names = set(df.columns) | set(_ALLOWED_FUNCS) | {"True", "False", "None"}
    tree = ast.parse(expr, mode="eval")
    tree = ast.fix_missing_locations(_SafeExpr(names).visit(tree))
    code = compile(tree, "<expr>", "eval")
    env = {**{k: df[k] for k in df.columns}, **_ALLOWED_FUNCS}
    val = eval(code, {"__builtins__": {}}, env)
    if isinstance(val, (pd.Series, np.ndarray, list)):
        return pd.Series(val, index=df.index)
    if isinstance(val, (bool, np.bool_)):
        return pd.Series([val] * len(df), index=df.index)
    if isinstance(val, (int, float, np.integer, np.floating)):
        # Broadcast scalar results (e.g. mean(col)) across all rows.
        return pd.Series([val] * len(df), index=df.index)
    raise ValueError("Expression must yield a vector, boolean, or scalar")
# ========= Helpers =========
def _as_df(v: Any) -> Optional[pd.DataFrame]:
if isinstance(v, pd.DataFrame):
return v
if isinstance(v, list):
return pd.DataFrame(v) if v else pd.DataFrame()
if isinstance(v, dict):
flat = all(isinstance(val, (int,float,str,bool,type(None))) for val in v.values())
return pd.DataFrame([v]) if flat else pd.DataFrame()
return None
def _get_df(datasets: Dict[str, Any], key: Optional[str]) -> Optional[pd.DataFrame]:
if key and key in datasets:
v = datasets[key]
else:
v = next((vv for vv in datasets.values() if vv is not None), None)
return _as_df(v) if v is not None else None
def _auto_group_cols(df: pd.DataFrame) -> List[str]:
prefs = ["facility","community","settlement","provider","zone","region","district","specialty","program","service","city"]
for p in prefs:
col = resolve_one(p, _to_list(df.columns))
if col:
return [col]
obj_cols = [c for c in df.columns if df[c].dtype == "object"]
return obj_cols[:1] if obj_cols else []
def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
"""
"mean(wait_days), p90(wait_days), count(*)" -> [("mean_wait_days","mean(wait_days)"), ...]
bare token "wait_days" becomes mean(wait_days)
"""
if not spec:
return []
out: List[Tuple[str,str]] = []
for it in [x.strip() for x in spec.split(",") if x.strip()]:
if it.lower() in ("count", "count(*)"):
out.append(("count_*", "count(*)")); continue
m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]+)\)', it)
if not m:
arg = it
out.append((f"mean_{arg}", f"mean({arg})"))
continue
func, arg = m.group(1).lower(), m.group(2).strip()
out.append((f"{func}_{arg}", f"{func}({arg})"))
return out
def _apply_agg_call(df: pd.DataFrame, call: str):
    call = call.strip()
    if call.lower() in ("count", "count(*)"):
        return int(len(df))
    # Lower-case only the function name, so mixed-case column names still resolve.
    m = re.match(r'([A-Za-z_][A-Za-z0-9_]*)\(([^)]+)\)', call)
    if not m:
        arg = call
        if arg not in df.columns: return None
        col = pd.to_numeric(df[arg], errors="coerce").dropna()
        return float(col.mean()) if len(col) else float("nan")
    func, arg = m.group(1).lower(), m.group(2).strip()
    if arg not in df.columns:
        return None
    col = pd.to_numeric(df[arg], errors="coerce").dropna()
    if not len(col):
        return float("nan")
    if func in ("avg","mean"): return float(col.mean())
    if func == "median": return float(np.median(col))
    if func == "sum": return float(col.sum())
    if func in ("min","max"): return float(getattr(np, func)(col))
    if func.startswith("p") and func[1:].isdigit(): return float(np.percentile(col, int(func[1:])))
    return None
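# Illustrative (hypothetical frame):
#   df = pd.DataFrame({"wait_days": [2, 4, 6, 8]})
#   _apply_agg_call(df, "mean(wait_days)")  -> 5.0
#   _apply_agg_call(df, "p90(wait_days)")   -> 7.4  (linear-interpolated percentile)
#   _apply_agg_call(df, "count(*)")         -> 4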
def _apply_filter(df: pd.DataFrame, expr: str) -> pd.DataFrame:
m = _eval_series_expr(expr, df)
return df.loc[m.astype(bool)].copy()
def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
    # supports "newcol = expr, other = expr2"; split only on top-level ',' or ';'
    # so commas inside calls such as max(a, b) survive
    parts, depth, start = [], 0, 0
    for i, ch in enumerate(spec):
        if ch == "(": depth += 1
        elif ch == ")": depth -= 1
        elif ch in ",;" and depth == 0:
            parts.append(spec[start:i]); start = i + 1
    parts.append(spec[start:])
    for p in parts:
        m = re.match(r'\s*([A-Za-z_]\w*)\s*=(?!=)\s*(.+)$', p, flags=re.S)
        if m:
            df[m.group(1)] = _eval_series_expr(m.group(2).strip(), df)
    return df
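# Illustrative (hypothetical columns): both derivations below land in one call;
# the comma inside max(...) is preserved by the top-level split, and max(...)
# maps to np.maximum, so it vectorizes over the frame:
#   _apply_derive(df, "wait_weeks = wait_days / 7, backlog = max(cases - capacity, 0)")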
def _render_table(df: pd.DataFrame) -> str:
if df is None or df.empty:
return "_No rows._"
dff = df.copy()
for c in dff.columns:
if pd.api.types.is_float_dtype(dff[c]) or pd.api.types.is_integer_dtype(dff[c]):
dff[c] = dff[c].apply(lambda v: "NaN" if (isinstance(v,float) and math.isnan(v)) else f"{v:,.4g}")
header = "| " + " | ".join(map(str, dff.columns)) + " |"
sep = "|" + "|".join(["---"] * len(dff.columns)) + "|"
rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
return "\n".join([header, sep, *rows])
def _small_n_flags(df: pd.DataFrame, count_col: Optional[str] = None, threshold: int = 5) -> Optional[pd.Series]:
if df is None or df.empty:
return None
if count_col and count_col in df.columns:
return df[count_col].apply(lambda n: " (interpret cautiously: small n)" if pd.notnull(n) and float(n) < threshold else "")
return None
def _missingness(df: pd.DataFrame, metric_cols: List[str]) -> List[str]:
notes = []
for c in metric_cols:
if c in df.columns:
miss = df[c].isna().mean()
if miss > 0:
notes.append(f"{c}: missing {miss:.1%}")
return notes
# ========= Scenario Engine =========
class ScenarioEngine:
"""
Execute a ScenarioPlan (or dict) consisting of tasks that specify:
- data_key: name of dataset in `datasets`
- filter: boolean/vectorized expression (safe-eval)
- derive: "new = expr, ..."
- group_by: list of roles/column names (resolved dynamically)
- agg: "mean(col), p90(col), count(*)" (bare 'col' => mean(col))
- sort_by / sort_dir
- top
- fields: project/alias output columns by role/name (resolved dynamically)
Returns markdown with:
- task section
- table output
- Assumptions & Mappings
- Data Quality notes
"""
@staticmethod
def _group_agg(df: pd.DataFrame,
group_by: Optional[List[str]],
agg_spec: Optional[str],
mapping_log: List[str]) -> pd.DataFrame:
# Resolve grouping to existing columns; tolerate roles or wrong names
if group_by:
gcols = resolve_cols(group_by, _to_list(df.columns))
for want in (group_by or []):
got = resolve_one(want, _to_list(df.columns))
mapping_log.append(f"group_by: {want}{got if got else '(unresolved)'}")
else:
gcols = _auto_group_cols(df)
if gcols:
mapping_log.append(f"group_by: (auto) → {gcols[0]}")
else:
mapping_log.append("group_by: (auto) → (none)")
aggs = _parse_aggs(agg_spec or "")
# No grouping & no agg => just preview a slice
if not gcols:
if not aggs:
return df.head(50).copy()
rec = { out_col: _apply_agg_call(df, call) for out_col, call in aggs }
return pd.DataFrame([rec])
if not aggs:
            # default: mean of numeric cols (excluding group keys) + count(*)
            num_cols = [c for c in df.select_dtypes(include="number").columns if c not in gcols]
gb = df.groupby(gcols, dropna=False)
if not num_cols:
out = gb.size().reset_index(name="count_*")
return out.sort_values("count_*", ascending=False)
out = gb[num_cols].mean(numeric_only=True)
out["count_*"] = gb.size()
return out.reset_index()
# Apply requested aggs
rows = []
gb = df.groupby(gcols, dropna=False)
for keys, g in gb:
if not isinstance(keys, tuple): keys = (keys,)
rec = { gcols[i]: keys[i] for i in range(len(gcols)) }
for out_col, call in aggs:
rec[out_col] = _apply_agg_call(g, call)
rows.append(rec)
return pd.DataFrame(rows)
@staticmethod
def _project_fields(out_df: pd.DataFrame,
fields: Optional[List[str]],
mapping_log: List[str]) -> pd.DataFrame:
if not isinstance(out_df, pd.DataFrame) or out_df.empty or not fields:
return out_df
cols = resolve_cols(fields, _to_list(out_df.columns))
for want in fields:
got = resolve_one(want, _to_list(out_df.columns))
mapping_log.append(f"field: {want}{got if got else '(unresolved)'}")
if cols:
return out_df[cols]
return out_df
@staticmethod
def _data_quality_notes(out_df: pd.DataFrame) -> List[str]:
notes: List[str] = []
if out_df is None or out_df.empty:
return notes
# small-n flag if a count column exists
cnt_col = None
for c in out_df.columns:
if c.lower() in ("count", "count_*", "n", "records"):
cnt_col = c; break
sn = _small_n_flags(out_df, count_col=cnt_col, threshold=5)
if sn is not None and sn.any():
n_small = (sn != "").sum()
if n_small > 0:
notes.append(f"{n_small} row(s) flagged as small-n (interpret cautiously).")
# missingness for numeric columns
metric_cols = [c for c in out_df.columns if pd.api.types.is_numeric_dtype(out_df[c])]
notes.extend(_missingness(out_df, metric_cols))
return notes
@staticmethod
def _exec_task(t: Any, datasets: Dict[str, Any]) -> str:
title = getattr(t, "title", None) or (isinstance(t, dict) and t.get("title")) or "Task"
section_lines: List[str] = [f"## {title}\n"]
data_key = getattr(t, "data_key", None) or (isinstance(t, dict) and t.get("data_key"))
df = _get_df(datasets, data_key)
if df is None or df.empty:
section_lines.append("_No matching data for this task._")
return "\n".join(section_lines)
# Filter(s)
t_filter = getattr(t, "filter", None) or (isinstance(t, dict) and t.get("filter"))
if t_filter:
try:
df = _apply_filter(df, t_filter)
except Exception as e:
section_lines.append(f"_Warning: filter ignored ({e})._")
# Derive(s)
t_derive = getattr(t, "derive", None) or (isinstance(t, dict) and t.get("derive"))
if t_derive:
for d in (t_derive if isinstance(t_derive, (list, tuple)) else [t_derive]):
try:
df = _apply_derive(df, d)
except Exception as e:
section_lines.append(f"_Warning: derive ignored ({e})._")
# Group/Agg
t_group_by = getattr(t, "group_by", None) or (isinstance(t, dict) and t.get("group_by"))
if isinstance(t_group_by, str):
t_group_by = [t_group_by]
t_agg = getattr(t, "agg", None) or (isinstance(t, dict) and t.get("agg"))
agg_spec = ", ".join(t_agg) if isinstance(t_agg, list) else (t_agg or None)
mapping_log: List[str] = []
out_df = ScenarioEngine._group_agg(df, t_group_by, agg_spec, mapping_log)
# Sort / Top
t_sort_by = getattr(t, "sort_by", None) or (isinstance(t, dict) and t.get("sort_by"))
t_sort_dir = (getattr(t, "sort_dir", None) or (isinstance(t, dict) and t.get("sort_dir")) or "desc").lower()
if isinstance(out_df, pd.DataFrame) and t_sort_by and t_sort_by in out_df.columns:
out_df = out_df.sort_values(t_sort_by, ascending=(t_sort_dir=="asc"))
t_top = getattr(t, "top", None) or (isinstance(t, dict) and t.get("top"))
if isinstance(out_df, pd.DataFrame) and isinstance(t_top, int) and t_top > 0:
out_df = out_df.head(t_top)
# Field projection
t_fields = getattr(t, "fields", None) or (isinstance(t, dict) and t.get("fields"))
if isinstance(t_fields, str):
t_fields = [t_fields]
out_df = ScenarioEngine._project_fields(out_df, t_fields, mapping_log)
# Render
section_lines.append(_render_table(out_df))
# Assumptions & Mappings
if mapping_log:
section_lines.append("\n**Assumptions & Mappings**")
for line in mapping_log:
section_lines.append(f"- {line}")
# Data quality
dq = ScenarioEngine._data_quality_notes(out_df)
if dq:
section_lines.append("\n**Data Quality Notes**")
for n in dq:
section_lines.append(f"- {n}")
return "\n".join(section_lines)
@staticmethod
def execute_plan(plan: Any, datasets: Dict[str, Any]) -> str:
"""
plan: object or dict with `tasks: List[Task]`
Each Task can have: title, data_key, filter, derive, group_by, agg, sort_by, sort_dir, top, fields
"""
sections: List[str] = ["# Scenario Output\n"]
tasks = getattr(plan, "tasks", None) or (isinstance(plan, dict) and plan.get("tasks")) or []
for t in tasks:
sections.append(ScenarioEngine._exec_task(t, datasets))
return "\n".join(sections).strip()