causalscience commited on
Commit
9bf8127
·
verified ·
1 Parent(s): 3551cf7

Aug 2025 Bug Fixes

Browse files
Files changed (1) hide show
  1. models/propensity.py +736 -127
models/propensity.py CHANGED
@@ -1,127 +1,736 @@
1
- # causalscience/models/propensity.py
2
-
3
- import numpy as np
4
- import pandas as pd
5
- from sklearn.linear_model import LogisticRegression
6
- import seaborn as sns
7
- import matplotlib.pyplot as plt
8
- from io import BytesIO
9
- from PIL import Image
10
- from scipy.special import logit # MODIFIED: import for logit transform
11
-
12
- # no imports from models.propensity here! # MODIFIED
13
-
14
- from utils.helpers import detect_column_type
15
- from utils.plotting import calculate_standardized_differences, love_plot
16
-
17
-
18
- def fit_propensity_score(
19
- df,
20
- treatment_col,
21
- feature_cols=None,
22
- C=1e6,
23
- max_iter=1000
24
- ):
25
- """
26
- Fit a logistic regression model to estimate propensity scores.
27
-
28
- # MODIFIED: C (regularization strength) and max_iter are now configurable parameters.
29
- """
30
- if feature_cols is None:
31
- feature_cols = [col for col in df.columns if col != treatment_col]
32
-
33
- X = df[feature_cols].copy()
34
- for col in X.columns:
35
- if not pd.api.types.is_numeric_dtype(X[col]):
36
- try:
37
- X[col] = pd.to_numeric(X[col], errors='coerce')
38
- except:
39
- X = X.drop(columns=[col])
40
- if feature_cols:
41
- feature_cols.remove(col)
42
-
43
- y = df[treatment_col].astype(int)
44
-
45
- model = LogisticRegression(C=C, max_iter=max_iter)
46
- model.fit(X, y)
47
-
48
- df_scores = df.copy()
49
- df_scores['propensity_score'] = model.predict_proba(X)[:, 1]
50
- return model, df_scores
51
-
52
-
53
- def match_with_caliper(
54
- df,
55
- treatment_col,
56
- caliper_width=0.2,
57
- with_replacement=True,
58
- use_logit_caliper=False # MODIFIED: option to compute caliper on logit scale
59
- ):
60
- """
61
- Perform 1:1 nearest neighbor matching on propensity scores within a caliper.
62
- """
63
- # select scale for caliper
64
- if use_logit_caliper:
65
- scores = logit(df['propensity_score']) # MODIFIED
66
- else:
67
- scores = df['propensity_score']
68
-
69
- ps_std = scores.std()
70
- caliper = caliper_width * ps_std
71
-
72
- treated = df[df[treatment_col] == 1].copy()
73
- control = df[df[treatment_col] == 0].copy()
74
-
75
- matches = []
76
- pair_id = 0
77
- for idx, row in treated.iterrows():
78
- if use_logit_caliper:
79
- tgt = logit(row['propensity_score']) # MODIFIED
80
- diffs = (logit(control['propensity_score']) - tgt).abs() # MODIFIED
81
- else:
82
- diffs = (control['propensity_score'] - row['propensity_score']).abs()
83
-
84
- eligible = control[diffs <= caliper]
85
- if eligible.empty:
86
- continue
87
-
88
- best_idx = diffs[diffs <= caliper].idxmin() # MODIFIED
89
- match = control.loc[best_idx].copy()
90
- match['pair_id'] = pair_id
91
- treated.at[idx, 'pair_id'] = pair_id
92
- matches.append(match)
93
-
94
- pair_id += 1
95
- if not with_replacement:
96
- control = control.drop(best_idx)
97
-
98
- matched_controls = pd.DataFrame(matches)
99
- matched_df = pd.concat([
100
- treated.dropna(subset=['pair_id']),
101
- matched_controls
102
- ], ignore_index=True)
103
- return matched_df
104
-
105
-
106
- def assess_balance(
107
- df,
108
- matched_df,
109
- treatment_col,
110
- covariate_cols,
111
- threshold=0.1
112
- ):
113
- """
114
- Compute standardized differences and create a Love plot.
115
- """
116
- covariates = covariate_cols # MODIFIED: use explicit covariates list
117
-
118
- std_unadj = calculate_standardized_differences(df, covariates, treatment_col)
119
- std_matched = calculate_standardized_differences(matched_df, covariates, treatment_col)
120
-
121
- love_img = love_plot(
122
- [std_unadj, std_matched],
123
- labels=['Unadjusted', 'Matched'],
124
- threshold=threshold,
125
- abs_val=True
126
- )
127
- return love_img, std_unadj, std_matched
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Runtime-safe installs
3
+ try:
4
+ import numpy # noqa
5
+ import pandas # noqa
6
+ import sklearn # noqa
7
+ import matplotlib # noqa
8
+ import PIL # noqa
9
+ except Exception:
10
+ import sys, subprocess
11
+ subprocess.run(
12
+ [sys.executable, "-m", "pip", "install", "-q",
13
+ "numpy", "pandas", "scikit-learn", "matplotlib", "pillow"],
14
+ check=False
15
+ )
16
+
17
+ # models/propensity.py
18
+ from dataclasses import dataclass
19
+ from typing import List, Tuple, Optional, Union, Dict
20
+
21
+ import io
22
+ import numpy as np
23
+ import pandas as pd
24
+ from PIL import Image
25
+
26
+ import matplotlib
27
+ matplotlib.use("Agg")
28
+ import matplotlib.pyplot as plt
29
+
30
+ from sklearn.linear_model import LogisticRegression
31
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
32
+ from sklearn.compose import ColumnTransformer
33
+ from sklearn.pipeline import Pipeline
34
+
35
+ # -----------------------------
36
+ # Helpers
37
+ # -----------------------------
38
+
39
+ def _ensure_binary(series: pd.Series) -> pd.Series:
40
+ s = series.copy()
41
+ if s.dtype == bool:
42
+ return s.astype(int)
43
+ if s.dtype == object:
44
+ mapping = {"t":1,"true":1,"yes":1,"y":1,"1":1,"f":0,"false":0,"no":0,"n":0,"0":0}
45
+ m = s.astype(str).str.strip().str.lower().map(mapping).astype("Int64")
46
+ if m.isna().any():
47
+ try:
48
+ sn = pd.to_numeric(s, errors="raise")
49
+ if set(pd.unique(sn)) <= {0,1}:
50
+ return sn.astype(int)
51
+ except Exception:
52
+ pass
53
+ raise ValueError("Treatment column cannot be coerced to binary 0/1.")
54
+ return m.astype(int)
55
+ uniq = set(pd.unique(s.dropna()))
56
+ if uniq <= {0,1} or uniq <= {0.0,1.0}:
57
+ return s.astype(int)
58
+ raise ValueError("Treatment column must contain values that map to 0/1.")
59
+
60
+ def _select_features(df: pd.DataFrame, feature_cols: List[str], outcome_col: Optional[str]) -> List[str]:
61
+ cols = [c for c in feature_cols if c in df.columns]
62
+ if outcome_col and outcome_col in df.columns and outcome_col not in cols:
63
+ cols.append(outcome_col) # include outcome for balance view only
64
+ return cols
65
+
66
+ def _split_num_cat(df: pd.DataFrame, cols: List[str]) -> Tuple[List[str], List[str]]:
67
+ num, cat = [], []
68
+ for c in cols:
69
+ (num if pd.api.types.is_numeric_dtype(df[c]) else cat).append(c)
70
+ return num, cat
71
+
72
+ # -----------------------------
73
+ # Propensity model
74
+ # -----------------------------
75
+
76
+ def _fit_propensity(df: pd.DataFrame, treatment_col: str, features: List[str]) -> Tuple[np.ndarray, Pipeline]:
77
+ y = _ensure_binary(df[treatment_col])
78
+ num, cat = _split_num_cat(df, features)
79
+ transformers = []
80
+ if num:
81
+ transformers.append(("num", StandardScaler(), num))
82
+ if cat:
83
+ try:
84
+ ohe = OneHotEncoder(handle_unknown="ignore", drop="first", sparse_output=False) # sklearn>=1.2
85
+ except TypeError:
86
+ ohe = OneHotEncoder(handle_unknown="ignore", drop="first", sparse=False)
87
+ transformers.append(("cat", ohe, cat))
88
+ pre = ColumnTransformer(transformers, remainder="drop", verbose_feature_names_out=False)
89
+ clf = LogisticRegression(max_iter=2000, solver="liblinear")
90
+ pipe = Pipeline([("pre", pre), ("clf", clf)])
91
+ pipe.fit(df[features], y.values)
92
+ ps = pipe.predict_proba(df[features])[:, 1]
93
+ return ps, pipe
94
+
95
+ # -----------------------------
96
+ # Matching
97
+ # -----------------------------
98
+
99
+ @dataclass
100
+ class MatchSummary:
101
+ method: str
102
+ treated_rows: int
103
+ control_rows: int
104
+ unique_controls: int
105
+ min_controls: int
106
+ max_controls: int
107
+ replacement: bool
108
+ caliper: Optional[float]
109
+ n_strata: Optional[int] = None
110
+ caliper_dropped_treated: int = 0 # diagnostic
111
+
112
+ def _greedy_nearest(
113
+ df: pd.DataFrame,
114
+ ps_col: str,
115
+ treatment_col: str,
116
+ min_controls: int,
117
+ max_controls: int,
118
+ replacement: bool,
119
+ caliper: Optional[float] = None,
120
+ ) -> Tuple[pd.DataFrame, MatchSummary]:
121
+ work = df.copy()
122
+ work["__rowid__"] = np.arange(len(work))
123
+ treated = work[work[treatment_col] == 1]
124
+ control = work[work[treatment_col] == 0].copy()
125
+ if control.empty or treated.empty:
126
+ return pd.DataFrame(), MatchSummary("nearest", 0, 0, 0, min_controls, max_controls, replacement, caliper)
127
+
128
+ used = set()
129
+ pairs: List[Tuple[int, int]] = []
130
+ dropped_due_to_caliper = 0
131
+
132
+ for _, t in treated.iterrows():
133
+ diffs = control.copy()
134
+ diffs["__dist__"] = (diffs[ps_col] - t[ps_col]).abs()
135
+ if caliper is not None and caliper >= 0:
136
+ diffs = diffs[diffs["__dist__"] <= caliper]
137
+ if not replacement:
138
+ diffs = diffs[~diffs["__rowid__"].isin(used)]
139
+ diffs = diffs.sort_values("__dist__", ascending=True).head(max_controls)
140
+ if len(diffs) < min_controls:
141
+ if caliper is not None and caliper >= 0:
142
+ dropped_due_to_caliper += 1
143
+ continue
144
+ for _, c in diffs.iterrows():
145
+ pairs.append((int(t["__rowid__"]), int(c["__rowid__"])))
146
+ if not replacement:
147
+ used.add(int(c["__rowid__"]))
148
+
149
+ if not pairs:
150
+ return pd.DataFrame(), MatchSummary("nearest", 0, 0, 0, min_controls, max_controls, replacement, caliper, caliper_dropped_treated=dropped_due_to_caliper)
151
+
152
+ idx_t = [p[0] for p in pairs]
153
+ idx_c = [p[1] for p in pairs]
154
+ wsi = work.set_index("__rowid__")
155
+
156
+ # Build a UNIQUE mapping of treated rowid -> group id (in first-seen order)
157
+ # This avoids Pandas "Reindexing only valid with uniquely valued Index objects" during map().
158
+ treated_seq = idx_t # sequence with duplicates, one per pair
159
+ seen = set()
160
+ unique_treated_in_order = []
161
+ for t_id in treated_seq:
162
+ if t_id not in seen:
163
+ unique_treated_in_order.append(t_id)
164
+ seen.add(t_id)
165
+ group_map = {t_id: gid for gid, t_id in enumerate(unique_treated_in_order)}
166
+
167
+ mt = wsi.loc[idx_t].copy()
168
+ mt["__role__"] = "treated"
169
+ mt["__match_group__"] = [group_map[t_id] for t_id in mt.index]
170
+
171
+ mc = wsi.loc[idx_c].copy()
172
+ mc["__role__"] = "control"
173
+ # align groups to each pair order (idx_t and idx_c are parallel)
174
+ mc["__match_group__"] = [group_map[t_id] for t_id in idx_t]
175
+
176
+ matched_stack = pd.concat([mt, mc], ignore_index=True)
177
+
178
+ summary = MatchSummary(
179
+ method="nearest",
180
+ treated_rows=mt.shape[0],
181
+ control_rows=mc.shape[0],
182
+ unique_controls=len(set(idx_c)),
183
+ min_controls=min_controls, max_controls=max_controls,
184
+ replacement=replacement, caliper=caliper,
185
+ n_strata=None,
186
+ caliper_dropped_treated=dropped_due_to_caliper,
187
+ )
188
+ return matched_stack, summary
189
+
190
+ def _caliper_matching(
191
+ df: pd.DataFrame,
192
+ ps_col: str,
193
+ treatment_col: str,
194
+ min_controls: int,
195
+ max_controls: int,
196
+ replacement: bool,
197
+ caliper: float,
198
+ ) -> Tuple[pd.DataFrame, MatchSummary]:
199
+ if caliper is None or caliper < 0:
200
+ raise ValueError("`caliper` must be a non-negative float for caliper matching.")
201
+ matched, base_summary = _greedy_nearest(
202
+ df, ps_col, treatment_col, min_controls, max_controls, replacement, caliper=caliper
203
+ )
204
+ summary = MatchSummary(
205
+ method="caliper", # MODIFIED: correct method label
206
+ treated_rows=base_summary.treated_rows,
207
+ control_rows=base_summary.control_rows,
208
+ unique_controls=base_summary.unique_controls,
209
+ min_controls=min_controls, max_controls=max_controls,
210
+ replacement=replacement, caliper=caliper,
211
+ n_strata=None,
212
+ caliper_dropped_treated=base_summary.caliper_dropped_treated,
213
+ )
214
+ return matched, summary
215
+
216
+ def _stratification(
217
+ df: pd.DataFrame,
218
+ ps_col: str,
219
+ treatment_col: str,
220
+ n_strata: int,
221
+ ) -> Tuple[pd.DataFrame, MatchSummary]:
222
+ # MODIFIED: Stratification implemented with ATT weights per stratum
223
+ work = df.copy()
224
+ work["__stratum__"] = pd.qcut(work[ps_col], q=n_strata, labels=False, duplicates="drop")
225
+ work["__role__"] = work[treatment_col].apply(lambda x: "treated" if int(x) == 1 else "control")
226
+ work["__match_group__"] = work["__stratum__"]
227
+
228
+ work["__weight__"] = 1.0
229
+ for s in sorted(work["__stratum__"].dropna().unique()):
230
+ sm = work["__stratum__"] == s
231
+ n_treated = int((work.loc[sm, treatment_col] == 1).sum())
232
+ n_control = int((work.loc[sm, treatment_col] == 0).sum())
233
+ if n_treated > 0 and n_control > 0:
234
+ work.loc[sm & (work[treatment_col] == 1), "__weight__"] = 1.0
235
+ work.loc[sm & (work[treatment_col] == 0), "__weight__"] = n_treated / n_control
236
+ else:
237
+ work.loc[sm, "__weight__"] = 1.0 # MODIFIED: if unbalanced stratum, keep neutral weights
238
+
239
+ # Diagnostics — which strata contain both treated and control
240
+ strata_balance = work.groupby("__stratum__")[treatment_col].agg(["sum", "count"])
241
+ balanced_strata = strata_balance[(strata_balance["sum"] > 0) & (strata_balance["sum"] < strata_balance["count"])].index
242
+ work["__balanced_stratum__"] = work["__stratum__"].isin(balanced_strata)
243
+
244
+ treated_count = int((work[treatment_col] == 1).sum())
245
+ control_count = int((work[treatment_col] == 0).sum())
246
+
247
+ summary = MatchSummary(
248
+ method="stratification",
249
+ treated_rows=treated_count,
250
+ control_rows=control_count,
251
+ unique_controls=control_count, # all controls retained
252
+ min_controls=0,
253
+ max_controls=0,
254
+ replacement=True,
255
+ caliper=None,
256
+ n_strata=n_strata,
257
+ )
258
+ return work, summary
259
+
260
+ # -----------------------------
261
+ # Balance + plotting
262
+ # -----------------------------
263
+
264
+ def _standardized_mean_differences(df: pd.DataFrame, treatment_col: str, covariates: List[str]) -> pd.DataFrame:
265
+ # MODIFIED: support optional weighting via "__weight__" if present (for stratification)
266
+ if df is None or len(df) == 0:
267
+ return pd.DataFrame(columns=["variable", "smd", "abs_smd"])
268
+ out = []
269
+ tmask = df[treatment_col] == 1
270
+ cmask = df[treatment_col] == 0
271
+
272
+ has_w = "__weight__" in df.columns
273
+ wt = df["__weight__"] if has_w else None
274
+
275
+ for v in covariates:
276
+ if v not in df.columns: # guard for missing cols
277
+ continue
278
+ a = pd.to_numeric(df.loc[tmask, v], errors="coerce")
279
+ b = pd.to_numeric(df.loc[cmask, v], errors="coerce")
280
+ if has_w:
281
+ wa = pd.to_numeric(wt.loc[tmask], errors="coerce")
282
+ wb = pd.to_numeric(wt.loc[cmask], errors="coerce")
283
+ # Drop NaNs aligned
284
+ am = a.notna() & wa.notna()
285
+ bm = b.notna() & wb.notna()
286
+ a, wa = a[am], wa[am]
287
+ b, wb = b[bm], wb[bm]
288
+ def wmean(x, w):
289
+ sw = float(w.sum())
290
+ return np.nan if sw == 0 else float(np.sum(w * x) / sw)
291
+ def wvar(x, w, mean):
292
+ sw = float(w.sum())
293
+ return np.nan if sw == 0 else float(np.sum(w * (x - mean) ** 2) / sw)
294
+ ma = wmean(a, wa); mb = wmean(b, wb)
295
+ va = wvar(a, wa, ma); vb = wvar(b, wb, mb)
296
+ else:
297
+ ma, mb = a.mean(), b.mean()
298
+ va, vb = a.var(ddof=1), b.var(ddof=1)
299
+ denom = np.sqrt(np.nanmean([va, vb])) if not (np.isnan(va) and np.isnan(vb)) else np.nan
300
+ smd = np.nan if (denom == 0 or np.isnan(denom)) else (ma - mb) / float(denom)
301
+ out.append((v, smd, abs(smd) if pd.notna(smd) else np.nan))
302
+ return pd.DataFrame(out, columns=["variable", "smd", "abs_smd"])
303
+
304
+ def _plot_love_before_after(
305
+ smd_pre: pd.DataFrame,
306
+ smd_post: pd.DataFrame,
307
+ title: str,
308
+ *,
309
+ empty_msg: Optional[str] = None,
310
+ xmax: Optional[float] = None,
311
+ fixed_order: Optional[List[str]] = None
312
+ ) -> Image.Image:
313
+ # Reconstructed plotting helper (equivalent to prior version)
314
+ def frame(df: Optional[pd.DataFrame], key: str) -> pd.DataFrame:
315
+ if df is None or df.empty:
316
+ return pd.DataFrame(columns=["variable", key])
317
+ return df[["variable", "abs_smd"]].rename(columns={"abs_smd": key})
318
+
319
+ a = frame(smd_pre, "before")
320
+ b = frame(smd_post, "after")
321
+ m = pd.merge(a, b, on="variable", how="outer")
322
+ m = m[~(m["before"].isna() & m["after"].isna())]
323
+ if m.empty:
324
+ fig, ax = plt.subplots(figsize=(6.5, 3.0))
325
+ ax.text(0.5, 0.5, empty_msg or "No balance data to plot.", ha="center", va="center", transform=ax.transAxes)
326
+ ax.axis("off")
327
+ buf = io.BytesIO(); fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
328
+ plt.close(fig); buf.seek(0)
329
+ return Image.open(buf)
330
+
331
+ # MODIFIED: stable ordering across methods
332
+ if fixed_order:
333
+ cat = pd.Categorical(m["variable"], categories=fixed_order, ordered=True)
334
+ m = m.assign(_ord=cat).sort_values("_ord").drop(columns=["_ord"])
335
+ else:
336
+ m["_sort"] = m["before"].fillna(-np.inf)
337
+ m.sort_values(["_sort"], ascending=[False], inplace=True)
338
+ m.drop(columns=["_sort"], inplace=True)
339
+
340
+ y = np.arange(len(m))
341
+ fig, ax = plt.subplots(figsize=(7.5, max(3.0, 0.6 * len(m))))
342
+
343
+ for i, row in m.reset_index(drop=True).iterrows():
344
+ bi, ai = row["before"], row["after"]
345
+ if pd.notna(bi) and pd.notna(ai):
346
+ ax.plot([bi, ai], [i, i], linewidth=1)
347
+
348
+ ax.scatter(m["before"], y, label="Before", zorder=3)
349
+ ax.scatter(m["after"], y, label="After", zorder=3, marker="s")
350
+
351
+ ax.set_yticks(y)
352
+ ax.set_yticklabels(m["variable"].tolist())
353
+ ax.invert_yaxis()
354
+ ax.set_xlabel("|SMD|")
355
+ ax.set_title(title)
356
+ ax.axvline(0.10, linestyle="--")
357
+ ax.grid(axis="x", linestyle=":", alpha=0.4)
358
+
359
+ if xmax is not None:
360
+ ax.set_xlim(0, xmax)
361
+
362
+ ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
363
+ fig.tight_layout(rect=[0.0, 0.0, 0.82, 1.0])
364
+
365
+ buf = io.BytesIO(); fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
366
+ plt.close(fig); buf.seek(0)
367
+ return Image.open(buf)
368
+
369
+ # -----------------------------
370
+ # Public pipeline (legacy)
371
+ # -----------------------------
372
+
373
+ def run_propensity_analysis(
374
+ data: Union[pd.DataFrame, str],
375
+ treatment_col: str,
376
+ feature_cols: List[str],
377
+ outcome_col: str = "",
378
+ matching_method: str = "nearest",
379
+ caliper: Optional[float] = None,
380
+ min_controls: int = 1,
381
+ max_controls: int = 1,
382
+ replacement: bool = True,
383
+ n_strata: int = 5,
384
+ include_balance: bool = True,
385
+ ) -> Tuple[str, Optional[Image.Image]]:
386
+ try:
387
+ # Load data
388
+ if isinstance(data, str):
389
+ if data.lower().endswith(".csv"):
390
+ df = pd.read_csv(data)
391
+ else:
392
+ raise ValueError("Only CSV paths are supported when passing a string to `data`.")
393
+ elif isinstance(data, pd.DataFrame):
394
+ df = data.copy()
395
+ else:
396
+ raise ValueError("`data` must be a pandas DataFrame or a CSV file path.")
397
+
398
+ if treatment_col not in df.columns:
399
+ raise ValueError(f"Treatment column '{treatment_col}' not found in data.")
400
+
401
+ covariates_used = _select_features(df, feature_cols, outcome_col if outcome_col else None)
402
+
403
+ # Fit propensity model
404
+ df[treatment_col] = _ensure_binary(df[treatment_col])
405
+ ps, _ = _fit_propensity(df, treatment_col, covariates_used)
406
+ df["__ps__"] = ps
407
+
408
+ # Matching
409
+ method = (matching_method or "nearest").lower()
410
+ matched = pd.DataFrame()
411
+ summary: Optional[MatchSummary] = None
412
+
413
+ if method == "nearest":
414
+ matched, summary = _greedy_nearest(
415
+ df, "__ps__", treatment_col, min_controls, max_controls, replacement, caliper=None # MODIFIED: force pure nearest
416
+ )
417
+ elif method == "caliper":
418
+ if caliper is None or caliper < 0:
419
+ raise ValueError("Caliper matching requires a non-negative `caliper`.")
420
+ matched, summary = _caliper_matching(
421
+ df, "__ps__", treatment_col, min_controls, max_controls, replacement, caliper
422
+ )
423
+ elif method == "stratification":
424
+ matched, summary = _stratification(df, "__ps__", treatment_col, n_strata)
425
+ else:
426
+ raise ValueError("matching_method must be one of {'nearest','caliper','stratification'}.")
427
+
428
+ # Build report
429
+ report_lines = [
430
+ f"Matching Method: {summary.method if summary else method}",
431
+ f"Knobs -> min_controls={summary.min_controls if summary else min_controls}, "
432
+ f"max_controls={summary.max_controls if summary else max_controls}, "
433
+ f"replacement={summary.replacement if summary else replacement}, "
434
+ f"caliper={summary.caliper if summary else caliper}, "
435
+ f"n_strata={summary.n_strata if summary else (n_strata if method=='stratification' else None)}",
436
+ f"Matched counts -> treated_rows={summary.treated_rows if summary else 0}, "
437
+ f"control_rows={summary.control_rows if summary else 0}, "
438
+ f"unique_controls={summary.unique_controls if summary else 0}",
439
+ ]
440
+
441
+ if summary and summary.caliper is not None:
442
+ binding = bool(summary.caliper_dropped_treated > 0)
443
+ report_lines.append(f"caliper_dropped_treated={summary.caliper_dropped_treated}")
444
+ report_lines.append(f"caliper_binding={'True' if binding else 'False'}")
445
+
446
+ love_img = None
447
+ if include_balance:
448
+ smd_pre = _standardized_mean_differences(df, treatment_col, covariates_used)
449
+ smd_post = _standardized_mean_differences(matched, treatment_col, covariates_used) if not matched.empty else pd.DataFrame()
450
+
451
+ # MODIFIED: compute unified x-axis from PRE (identical across methods) + small headroom
452
+ if not smd_pre.empty and smd_pre["abs_smd"].notna().any():
453
+ max_pre = float(np.nanmax(smd_pre["abs_smd"]))
454
+ xmax = max(0.10, max_pre) * 1.1
455
+ else:
456
+ xmax = 0.5 # safe fallback
457
+
458
+ # MODIFIED: fixed order = by PRE imbalance ensures all methods align
459
+ fixed_order = smd_pre.sort_values("abs_smd", ascending=False)["variable"].tolist()
460
+
461
+ love_img = _plot_love_before_after(
462
+ smd_pre, smd_post,
463
+ title=f"Love Plot — {summary.method.title() if summary else method.title()} Matching",
464
+ empty_msg="No matched sample to assess." if matched.empty else None,
465
+ xmax=xmax, # MODIFIED
466
+ fixed_order=fixed_order # MODIFIED
467
+ )
468
+
469
+ preview = (smd_post if not smd_post.empty else smd_pre).sort_values("abs_smd", ascending=False).head(10)
470
+ report_lines.append("\nBalance (|SMD|) summary (first 10):")
471
+ for _, r in preview.iterrows():
472
+ val = (np.round(r["abs_smd"], 4) if pd.notna(r["abs_smd"]) else "NaN")
473
+ report_lines.append(f" {r['variable']}: {val}")
474
+
475
+ return "\n".join(report_lines), love_img
476
+
477
+ except Exception as e:
478
+ return f"An unexpected error occurred: {e}", None
479
+
480
+ # -----------------------------
481
+ # MODIFIED: Data export helpers for v2 (Edges & Units)
482
+ # -----------------------------
483
+
484
+ _EDGES_COLUMNS = [
485
+ "method", "group_id", "treated_unit_id", "control_unit_id",
486
+ "neighbor_rank", "distance", "edge_weight",
487
+ "replacement", "min_controls", "max_controls", "caliper", "n_strata"
488
+ ] # MODIFIED
489
+
490
+ def _build_edges_units_nearest_caliper( # MODIFIED: new helper builds export frames
491
+ df: pd.DataFrame,
492
+ ps_col: str,
493
+ treatment_col: str,
494
+ min_controls: int,
495
+ max_controls: int,
496
+ replacement: bool,
497
+ caliper: Optional[float],
498
+ method: str,
499
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, MatchSummary]:
500
+ work = df.copy()
501
+ work["__unit_id__"] = np.arange(len(work)) # MODIFIED: stable integer id derived from row order
502
+ treated = work[work[treatment_col] == 1].copy()
503
+ control = work[work[treatment_col] == 0].copy()
504
+
505
+ used = set()
506
+ edges_records: List[Dict] = []
507
+ dropped_due_to_caliper = 0
508
+
509
+ # Build edges per treated → top-K nearest controls (respect replacement & optional caliper)
510
+ for _, t in treated.iterrows():
511
+ diffs = control.copy()
512
+ diffs["__dist__"] = (diffs[ps_col] - t[ps_col]).abs()
513
+ if caliper is not None and caliper >= 0:
514
+ diffs = diffs[diffs["__dist__"] <= caliper]
515
+ if not replacement:
516
+ diffs = diffs[~diffs["__unit_id__"].isin(used)]
517
+ diffs = diffs.sort_values("__dist__", ascending=True).head(max_controls)
518
+
519
+ if len(diffs) < min_controls:
520
+ if caliper is not None and caliper >= 0:
521
+ dropped_due_to_caliper += 1
522
+ continue
523
+
524
+ # neighbor_rank assigned in sorted order
525
+ for rank, (_, crow) in enumerate(diffs.iterrows(), start=1):
526
+ edges_records.append({
527
+ "method": method,
528
+ "group_id": int(t["__unit_id__"]),
529
+ "treated_unit_id": int(t["__unit_id__"]),
530
+ "control_unit_id": int(crow["__unit_id__"]),
531
+ "neighbor_rank": int(rank),
532
+ "distance": float(crow["__dist__"]),
533
+ # edge_weight filled later after we know k-per-group
534
+ "edge_weight": np.nan,
535
+ "replacement": bool(replacement),
536
+ "min_controls": int(min_controls),
537
+ "max_controls": int(max_controls),
538
+ "caliper": (float(caliper) if caliper is not None else np.nan),
539
+ "n_strata": np.nan,
540
+ })
541
+ if not replacement:
542
+ used.add(int(crow["__unit_id__"]))
543
+
544
+ if not edges_records:
545
+ # Empty frames with proper schema
546
+ edges_df = pd.DataFrame(columns=_EDGES_COLUMNS)
547
+ units_df = pd.DataFrame(columns=["unit_id", "role", "ps", "treatment", "group_id", "method",
548
+ "replacement", "min_controls", "max_controls", "caliper", "n_strata"])
549
+ summary = MatchSummary(method=method, treated_rows=0, control_rows=0, unique_controls=0,
550
+ min_controls=min_controls, max_controls=max_controls,
551
+ replacement=replacement, caliper=caliper, n_strata=None,
552
+ caliper_dropped_treated=dropped_due_to_caliper)
553
+ return units_df, edges_df, summary
554
+
555
+ edges_df = pd.DataFrame.from_records(edges_records)[_EDGES_COLUMNS]
556
+
557
+ # Fill edge_weight = 1/k within each treated group (synthetic control equal weights)
558
+ sizes = edges_df.groupby("group_id")["control_unit_id"].transform("count")
559
+ edges_df["edge_weight"] = 1.0 / sizes
560
+
561
+ included_groups = edges_df["group_id"].unique()
562
+ # Treated rows (one per group)
563
+ tre = work[work["__unit_id__"].isin(included_groups)].copy()
564
+ tre_df = tre.assign(
565
+ unit_id=tre["__unit_id__"].astype(int),
566
+ role="treated",
567
+ ps=tre[ps_col].astype(float),
568
+ treatment=1,
569
+ group_id=tre["__unit_id__"].astype(int),
570
+ method=method,
571
+ replacement=bool(replacement),
572
+ min_controls=int(min_controls),
573
+ max_controls=int(max_controls),
574
+ caliper=(float(caliper) if caliper is not None else np.nan),
575
+ n_strata=np.nan,
576
+ )
577
+
578
+ # Controls (one row per edge, allows replacement across groups)
579
+ ctrl_rows = []
580
+ for _, e in edges_df.iterrows():
581
+ c = work.loc[work["__unit_id__"] == e["control_unit_id"]].iloc[0]
582
+ ctrl_rows.append({
583
+ **{col: c[col] for col in work.columns}, # original columns
584
+ "unit_id": int(c["__unit_id__"]),
585
+ "role": "control",
586
+ "ps": float(c[ps_col]),
587
+ "treatment": 0,
588
+ "group_id": int(e["group_id"]),
589
+ "method": method,
590
+ "replacement": bool(replacement),
591
+ "min_controls": int(min_controls),
592
+ "max_controls": int(max_controls),
593
+ "caliper": (float(caliper) if caliper is not None else np.nan),
594
+ "n_strata": np.nan,
595
+ })
596
+ ctrl_df = pd.DataFrame(ctrl_rows) if ctrl_rows else pd.DataFrame(columns=list(tre_df.columns))
597
+
598
+ base_cols = [c for c in work.columns if not c.startswith("__")]
599
+ export_cols = ["unit_id", "role", "ps", "treatment", "group_id", "method",
600
+ "replacement", "min_controls", "max_controls", "caliper", "n_strata"]
601
+ tre_df = tre_df[base_cols + export_cols]
602
+ if not ctrl_df.empty:
603
+ ctrl_df = ctrl_df[base_cols + export_cols]
604
+ units_df = pd.concat([tre_df, ctrl_df], ignore_index=True)
605
+
606
+ summary = MatchSummary(
607
+ method=method,
608
+ treated_rows=tre_df.shape[0],
609
+ control_rows=ctrl_df.shape[0],
610
+ unique_controls=int(edges_df["control_unit_id"].nunique()),
611
+ min_controls=min_controls,
612
+ max_controls=max_controls,
613
+ replacement=replacement,
614
+ caliper=caliper,
615
+ n_strata=None,
616
+ caliper_dropped_treated=dropped_due_to_caliper,
617
+ )
618
+ return units_df, edges_df, summary
619
+
620
+ def _build_units_stratification( # MODIFIED: new helper for stratification exports
621
+ df: pd.DataFrame,
622
+ ps_col: str,
623
+ treatment_col: str,
624
+ n_strata: int,
625
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, MatchSummary]:
626
+ work = df.copy()
627
+ work["__unit_id__"] = np.arange(len(work))
628
+ strat_df, summary = _stratification(work, ps_col, treatment_col, n_strata)
629
+
630
+ base_cols = [c for c in strat_df.columns if not c.startswith("__")] + ["__unit_id__"]
631
+ units = strat_df.copy()
632
+ units = units.assign(
633
+ unit_id=units["__unit_id__"].astype(int),
634
+ role=units["__role__"],
635
+ ps=units[ps_col].astype(float),
636
+ treatment=units[treatment_col].astype(int),
637
+ group_id=units["__stratum__"].astype(int),
638
+ method="stratification",
639
+ replacement=True,
640
+ min_controls=0,
641
+ max_controls=0,
642
+ caliper=np.nan,
643
+ n_strata=int(n_strata),
644
+ weight=units["__weight__"].astype(float),
645
+ balanced_stratum=units["__balanced_stratum__"].astype(bool),
646
+ stratum=units["__stratum__"].astype(int),
647
+ )
648
+
649
+ # Order and drop helpers
650
+ keep_export = [c for c in base_cols if c != "__unit_id__"] + [
651
+ "unit_id", "role", "ps", "treatment", "group_id", "method",
652
+ "replacement", "min_controls", "max_controls", "caliper", "n_strata",
653
+ "weight", "balanced_stratum", "stratum"
654
+ ]
655
+ units_df = units[keep_export]
656
+ edges_df = pd.DataFrame(columns=_EDGES_COLUMNS) # no combinatorial pairings for stratification
657
+ return units_df, edges_df, summary
658
+
659
+ # -----------------------------
660
+ # API returning exportable DataFrames
661
+ # -----------------------------
662
+
663
+ def run_propensity_analysis_v2( # MODIFIED: new function; keeps legacy API intact
664
+ data: Union[pd.DataFrame, str],
665
+ treatment_col: str,
666
+ feature_cols: List[str],
667
+ outcome_col: str = "",
668
+ matching_method: str = "nearest",
669
+ caliper: Optional[float] = None,
670
+ min_controls: int = 1,
671
+ max_controls: int = 1,
672
+ replacement: bool = True,
673
+ n_strata: int = 5,
674
+ include_balance: bool = True,
675
+ return_dataframes: bool = True,
676
+ ) -> Tuple[str, Optional[Image.Image], Optional[pd.DataFrame], Optional[pd.DataFrame]]:
677
+ """
678
+ Returns:
679
+ report (str), love_plot (PIL.Image or None), units_df (or None), edges_df (or None)
680
+ """
681
+ # Load data (same behavior as legacy)
682
+ if isinstance(data, str):
683
+ if data.lower().endswith(".csv"):
684
+ df = pd.read_csv(data)
685
+ else:
686
+ raise ValueError("Only CSV paths are supported when passing a string to `data`.")
687
+ elif isinstance(data, pd.DataFrame):
688
+ df = data.copy()
689
+ else:
690
+ raise ValueError("`data` must be a pandas DataFrame or a CSV file path.")
691
+
692
+ if treatment_col not in df.columns:
693
+ raise ValueError(f"Treatment column '{treatment_col}' not found in data.")
694
+
695
+ # Prepare covariates and PS
696
+ covariates_used = _select_features(df, feature_cols, outcome_col if outcome_col else None)
697
+ df[treatment_col] = _ensure_binary(df[treatment_col])
698
+ ps, _ = _fit_propensity(df, treatment_col, covariates_used)
699
+ df["__ps__"] = ps
700
+
701
+ method = (matching_method or "nearest").lower()
702
+ units_df: Optional[pd.DataFrame] = None
703
+ edges_df: Optional[pd.DataFrame] = None
704
+
705
+ # Build export frames by method (correct multi-match semantics)
706
+ if method == "nearest":
707
+ units_df, edges_df, _ = _build_edges_units_nearest_caliper(
708
+ df, "__ps__", treatment_col, min_controls, max_controls, replacement, caliper=None, method="nearest"
709
+ )
710
+ elif method == "caliper":
711
+ if caliper is None or caliper < 0:
712
+ raise ValueError("Caliper matching requires a non-negative `caliper`.")
713
+ units_df, edges_df, _ = _build_edges_units_nearest_caliper(
714
+ df, "__ps__", treatment_col, min_controls, max_controls, replacement, caliper=caliper, method="caliper"
715
+ )
716
+ elif method == "stratification":
717
+ units_df, edges_df, _ = _build_units_stratification(df, "__ps__", treatment_col, n_strata=n_strata)
718
+ else:
719
+ raise ValueError("matching_method must be one of {'nearest','caliper','stratification'}.")
720
+
721
+ # Produce report + plot using the legacy function to preserve visuals/diagnostics
722
+ report, love_img = run_propensity_analysis(
723
+ data=df, # already a DataFrame with __ps__
724
+ treatment_col=treatment_col,
725
+ feature_cols=feature_cols or [],
726
+ outcome_col=outcome_col or "",
727
+ matching_method=method,
728
+ caliper=(caliper if method == "caliper" else None),
729
+ min_controls=min_controls if method in ("nearest", "caliper") else 0,
730
+ max_controls=max_controls if method in ("nearest", "caliper") else 0,
731
+ replacement=replacement if method in ("nearest", "caliper") else True,
732
+ n_strata=n_strata if method == "stratification" else 5,
733
+ include_balance=include_balance,
734
+ )
735
+
736
+ return report, love_img, (units_df if return_dataframes else None), (edges_df if return_dataframes else None)