causalscience committed
Commit 3551cf7 · verified · 1 Parent(s): 46b4d92

Aug 25 Bug Fixes

Files changed (1)
  1. models/its.py +534 -228
models/its.py CHANGED
@@ -1,228 +1,534 @@
- import numpy as np
- import pandas as pd
- from io import BytesIO
- from PIL import Image
- import causalpy as cp
- import matplotlib.pyplot as plt
- from sklearn.linear_model import LinearRegression
- import statsmodels.api as sm  # For sm.add_constant
- from scipy import stats
-
-
- def enhanced_its_analysis(file, target_col, date_col, pre_dates, post_dates, freq_input="D", control_vars=""):
-     """
-     Performs interrupted time series analysis using causalpy on user-provided frequency.
-     Uses DatetimeIndex for CausalPy compatibility with Timestamp treatment_time.
-     Ensures predicted counterfactual is 1D for statistics.
-
-     Parameters:
-     -----------
-     file : file object
-         The uploaded CSV file
-     target_col : str
-         The column name for the outcome variable
-     date_col : str
-         The column name for the date variable
-     pre_dates : str
-         Comma-separated start and end dates for pre-intervention period
-     post_dates : str
-         Comma-separated start and end dates for post-intervention period
-     freq_input : str, default="D"
-         Pandas frequency alias for the time series
-     control_vars : str, default=""
-         Comma-separated list of control variable column names
-     """
-     if file is None:
-         return "Error: No file uploaded.", None
-     try:
-         df = pd.read_csv(file.name)
-         if date_col not in df.columns:
-             return f"Error: Date column '{date_col}' not found.", None
-
-         df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
-         df = df.dropna(subset=[date_col])
-         if df.empty:
-             return f"Error: No valid dates found in column '{date_col}' after parsing.", None
-
-         if target_col not in df.columns:
-             return f"Error: Column '{target_col}' not found.", None
-
-         if not pd.api.types.is_numeric_dtype(df[target_col]):
-             df[target_col] = pd.to_numeric(df[target_col], errors='coerce')
-         df = df.dropna(subset=[target_col])
-         if df.empty:
-             return f"Error: No valid data in target column '{target_col}' after NA removal.", None
-
-         # Process control variables
-         control_columns = []
-         if control_vars and control_vars.strip():
-             control_columns = [col.strip() for col in control_vars.split(',')]
-             # Validate control variables exist in the dataframe
-             missing_cols = [col for col in control_columns if col not in df.columns]
-             if missing_cols:
-                 return f"Error: Control variable column(s) not found: {', '.join(missing_cols)}", None
-
-             # Convert control variables to numeric if needed
-             for col in control_columns:
-                 if not pd.api.types.is_numeric_dtype(df[col]):
-                     df[col] = pd.to_numeric(df[col], errors='coerce')
-                     # Check if conversion resulted in all NaN values
-                     if df[col].isna().all():
-                         return f"Error: Control variable '{col}' could not be converted to numeric values.", None
-
-         pre_list = [d.strip() for d in pre_dates.split(',')]
-         post_list = [d.strip() for d in post_dates.split(',')]
-
-         start_pre, end_pre = pd.to_datetime(pre_list[0], errors='coerce'), pd.to_datetime(pre_list[1], errors='coerce')
-         start_post, end_post = pd.to_datetime(post_list[0], errors='coerce'), pd.to_datetime(post_list[1], errors='coerce')
-
-         if pd.NaT in [start_pre, end_pre, start_post, end_post]:
-             return "Error: One or more pre/post period boundary dates are invalid. Use YYYY-MM-DD.", None
-
-         if start_post <= end_pre:  # Basic sanity check for period ordering
-             return f"Error: Post-intervention start date ({start_post.date()}) must be after pre-intervention end date ({end_pre.date()}).", None
-
-         mask = (df[date_col] >= start_pre) & (df[date_col] <= end_post)
-         analysis_df_filtered = df.loc[mask].copy()
-
-         if analysis_df_filtered.empty:
-             return "Error: No data in the specified overall date range (pre-start to post-end).", None
-
-         analysis_df_for_cp = analysis_df_filtered.sort_values(date_col)
-         analysis_df_for_cp = analysis_df_for_cp.set_index(date_col, drop=True)
-
-         analysis_df_for_cp['time_index'] = range(len(analysis_df_for_cp))
-         analysis_df_for_cp = analysis_df_for_cp.rename(columns={target_col: 'y'})
-
-         pre_df = analysis_df_for_cp.loc[start_pre:end_pre]
-         post_df = analysis_df_for_cp.loc[start_post:end_post]
-
-         if pre_df.empty:
-             return f"Error: Pre-intervention period ({start_pre.date()} to {end_pre.date()}) contains no data after filtering.", None
-         if post_df.empty:
-             return f"Error: Post-intervention period ({start_post.date()} to {end_post.date()}) contains no data after filtering.", None
-
-         # Build formula with control variables if provided
-         formula = 'y ~ 1 + time_index'
-         if control_columns:
-             formula += ' + ' + ' + '.join(control_columns)
-
-         # Check for missing values in control variables for the analysis period
-         if control_columns:
-             for col in control_columns:
-                 if analysis_df_for_cp[col].isna().any():
-                     return f"Error: Control variable '{col}' contains missing values in the analysis period.", None
-
-         its_model = cp.InterruptedTimeSeries(
-             data=analysis_df_for_cp,
-             formula=formula,  # Now includes control variables if specified
-             treatment_time=start_post,
-             model=LinearRegression(),
-             freq=freq_input
-         )
-
-         # For prediction, we need to prepare X_post with control variables
-         if not control_columns:
-             X_post = sm.add_constant(post_df[['time_index']])
-         else:
-             # Ensure control columns are present in post_df (they should be, as post_df is a slice of analysis_df_for_cp)
-             missing_in_post_df = [col for col in control_columns if col not in post_df.columns]
-             if missing_in_post_df:  # Should ideally not happen if logic is correct
-                 return f"Error: Control variable(s) {', '.join(missing_in_post_df)} missing from post-intervention data slice.", None
-             X_post = sm.add_constant(post_df[['time_index'] + control_columns])
-
-         pred_cf_array = its_model.model.predict(X_post)
-
-         # --- FIX: Ensure pred_cf_array is 1D ---
-         if pred_cf_array.ndim > 1:
-             pred_cf_array = pred_cf_array.squeeze()
-         # --- End of FIX ---
-
-         pred_cf = pd.Series(pred_cf_array, index=post_df.index, name='y_fc')
-         observed_post = post_df['y']
-
-         post_mean = observed_post.mean()
-         cf_mean = pred_cf.mean()
-         effect = post_mean - cf_mean
-
-         if len(observed_post) < 2 or len(pred_cf) < 2:
-             post_se, cf_se, eff_se, t_stat, p_value, ci_low, ci_high = [np.nan] * 7
-             df_t = 0
-         else:
-             post_se = observed_post.std(ddof=1) / np.sqrt(len(observed_post))
-             cf_se = pred_cf.std(ddof=1) / np.sqrt(len(pred_cf))
-             eff_se = np.sqrt(post_se**2 + cf_se**2) if (not np.isnan(post_se) and not np.isnan(cf_se)) else np.nan
-             df_t = min(len(post_df) - 1, len(pred_cf) - 1)
-             if df_t < 1: df_t = 1
-
-         if np.isnan(eff_se) or eff_se == 0:
-             t_stat, p_value, ci_low, ci_high = [np.nan] * 4
-         else:
-             t_stat = effect / eff_se
-             p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=df_t))
-             ci_low, ci_high = stats.t.interval(0.95, df_t, loc=effect, scale=eff_se)
-
-         # Enhanced report to include control variables
-         report_lines = [
-             f"ITS Analysis for: 'y' (originally '{target_col}')",
-             f"Intervention at: {start_post.strftime('%Y-%m-%d')}",
-             f"Pre-period: {start_pre.strftime('%Y-%m-%d')} to {end_pre.strftime('%Y-%m-%d')}",
-             f"Post-period: {start_post.strftime('%Y-%m-%d')} to {end_post.strftime('%Y-%m-%d')}"
-         ]
-
-         if control_columns:
-             report_lines.append(f"Control Variables: {', '.join(control_columns)}")
-
-         report_lines.extend([
-             "--- Average Effect Estimation ---",
-             f"Observed post-intervention mean: {post_mean:.3f}",
-             f"Estimated counterfactual mean: {cf_mean:.3f}",
-             f"Estimated average effect: {effect:.3f}"
-         ])
-
-         if not np.isnan(p_value):
-             report_lines.append(f" 95% CI: [{ci_low:.3f}, {ci_high:.3f}]")
-             report_lines.append(f" t-statistic: {t_stat:.3f}, p-value: {p_value:.4f} (df={df_t})")
-         else:
-             report_lines.append(" (CI/p-value not computed due to insufficient data or variability)")
-         report = "\n".join(report_lines)
-
-         fig, ax = its_model.plot(plot_predict_all=False, plot_show_params=True)
-         buf = BytesIO()
-         fig.savefig(buf, format='png', bbox_inches='tight')
-         plt.close(fig)
-         buf.seek(0)
-         img = Image.open(buf)
-
-         return report, img
-
-     except Exception as e:
-         # import traceback  # For debugging
-         # print("--- TRACEBACK ---")
-         # traceback.print_exc()
-         # print("--- END TRACEBACK ---")
-         return f"An unexpected error occurred: {str(e)}", None
-
-
- def run_its_analysis(file, target_col, date_col, pre_dates, post_dates, freq_input, control_vars=""):
-     """
-     Wrapper function for the enhanced_its_analysis function.
-
-     Parameters:
-     -----------
-     file : file object
-         The uploaded CSV file
-     target_col : str
-         The column name for the outcome variable
-     date_col : str
-         The column name for the date variable
-     pre_dates : str
-         Comma-separated start and end dates for pre-intervention period
-     post_dates : str
-         Comma-separated start and end dates for post-intervention period
-     freq_input : str
-         Pandas frequency alias for the time series
-     control_vars : str, default=""
-         Comma-separated list of control variable column names
-     """
-     return enhanced_its_analysis(file, target_col, date_col, pre_dates, post_dates, freq_input, control_vars)
+ from io import BytesIO
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+ import matplotlib.gridspec as gridspec
+
+ import causalpy as cp
+ import patsy
+ import statsmodels.api as sm
+ from scipy import stats
+ from sklearn.linear_model import LinearRegression
+ from statsmodels.stats.diagnostic import acorr_ljungbox
+ from statsmodels.tsa.stattools import acf
+ import statsmodels.stats.stattools as smt
+ import traceback
+
+ # ==================== Global knobs ====================
+ SEASONALITY_ENABLED = True
+ SEASONALITY_METHOD = "fourier"  # "fourier" or "dummies"
+ FOURIER_WEEKLY_K = 3  # number of sin/cos pairs for the short (weekly) cycle
+ FOURIER_YEARLY_K = 5  # number of sin/cos pairs for the long (yearly) cycle
+
+
+ HAC_ENABLED: bool = True
+ HAC_MAXLAGS: Union[str, int] = "auto"  # "auto" = plug-in rule; or set an int (e.g., 8 for ~two months on weekly data)
+ HAC_SMALL_SAMPLE_CORR: bool = True  # use the finite-sample correction in statsmodels
+
+
+ # -------------------- rendering helpers --------------------
+
+ def _fig_to_pil(fig: plt.Figure, dpi: int = 110) -> Image.Image:
+     """Save a figure to a PIL image with outer padding and an opaque background."""
+     buf = BytesIO()
+     fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight", pad_inches=0.40, facecolor="white")
+     plt.close(fig)
+     buf.seek(0)
+     return Image.open(buf).convert("RGB")
+
+
+ def _stack_images_vertical(images: List[Image.Image], pad: int = 22, bg=(255, 255, 255)) -> Optional[Image.Image]:
+     """Stack PIL images top to bottom, centered horizontally, with padding between them."""
+     if not images:
+         return None
+     max_w = max(im.width for im in images)
+     total_h = sum(im.height for im in images) + pad * (len(images) - 1)
+     out = Image.new("RGB", (max_w, total_h), bg)
+     y = 0
+     for im in images:
+         x = (max_w - im.width) // 2
+         out.paste(im, (x, y))
+         y += im.height + pad
+     return out
+
+
+ def _rotate_all_xticklabels(fig: plt.Figure) -> None:
+     """Rotate x-tick labels on every axes in a figure."""
+     for ax in fig.axes:
+         for lbl in ax.get_xticklabels():
+             lbl.set_rotation(45)
+             lbl.set_ha("right")
+     try:
+         fig.autofmt_xdate()
+     except Exception:
+         pass
+
+
+ # -------------------- frequency-aware seasonality (originally opt-in; default ON now) --------------------
+
+ def _add_frequency_aware_seasonality(df: pd.DataFrame, freq_input: str) -> Tuple[pd.DataFrame, List[str]]:
+     """
+     Return (df_with_terms, season_terms_list) based on freq_input.
+     The df index must be a DatetimeIndex and df must contain 'time_index'.
+     Operates on a copy, so there are no side effects. All terms are deterministic
+     functions of time, so they introduce no leakage from the post period.
+     """
+     df = df.copy()
+     added: List[str] = []
+
+     # Helper to add Fourier sin/cos pairs
+     def add_fourier(prefix: str, period: Optional[float], K: int) -> None:
+         if K <= 0 or period is None:
+             return
+         for k in range(1, K + 1):
+             s = f"{prefix}_sin_{k}"
+             c = f"{prefix}_cos_{k}"
+             df[s] = np.sin(2 * np.pi * k * df["time_index"] / period)
+             df[c] = np.cos(2 * np.pi * k * df["time_index"] / period)
+             added.extend([s, c])
+
+     f = (freq_input or "").upper()
+
+     if SEASONALITY_METHOD == "dummies":
+         if f == "M":
+             df["month"] = df.index.month
+             added.append("C(month)")
+         elif f == "Q":
+             df["quarter"] = df.index.quarter
+             added.append("C(quarter)")
+         elif f == "D":
+             df["dow"] = df.index.dayofweek
+             df["month"] = df.index.month
+             added.extend(["C(dow)", "C(month)"])
+         elif f == "W":
+             df["weekofyear"] = df.index.isocalendar().week.astype(int)
+             added.append("C(weekofyear)")
+         else:
+             # Unknown/other alias: no-op
+             pass
+
+     else:  # "fourier" (smooth & compact)
+         if f == "D":
+             add_fourier("wk", period=7.0, K=FOURIER_WEEKLY_K)  # weekly cycle
+             add_fourier("yr", period=365.25, K=FOURIER_YEARLY_K)  # yearly cycle
+         elif f == "W":
+             add_fourier("yr", period=52.1775, K=FOURIER_YEARLY_K)  # annual cycle on a weekly cadence
+         elif f == "M":
+             add_fourier("yr", period=12.0, K=FOURIER_YEARLY_K)  # annual cycle on a monthly cadence
+         elif f == "Q":
+             add_fourier("yr", period=4.0, K=max(1, min(FOURIER_YEARLY_K, 2)))  # annual cycle on a quarterly cadence
+         else:
+             # Fallback: do nothing if the alias is not recognized
+             pass
+
+     return df, added
+
+
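
A minimal sketch of what the Fourier branch above generates (illustrative; assumes `_add_frequency_aware_seasonality` is in scope), using two years of weekly data:

    import numpy as np
    import pandas as pd

    # Toy frame with the required DatetimeIndex and 'time_index' column.
    idx = pd.date_range("2024-01-01", periods=104, freq="W")
    toy = pd.DataFrame({"y": np.zeros(104), "time_index": np.arange(104)}, index=idx)

    out, terms = _add_frequency_aware_seasonality(toy, "W")
    print(terms)  # ['yr_sin_1', 'yr_cos_1', ..., 'yr_sin_5', 'yr_cos_5'] with the default K=5
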
+ # -------------------- HAC utilities --------------------
+
+ def _nw_auto_maxlags(n: int) -> int:
+     """
+     Newey–West plug-in bandwidth: floor(4 * (n/100)^(2/9)), at least 1.
+     """
+     if n <= 1:
+         return 1
+     return max(1, int(np.floor(4.0 * (n / 100.0) ** (2.0 / 9.0))))
+
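
A worked example of the plug-in rule (a sketch; the values follow directly from the formula above):

    # n = 104 weekly observations: floor(4 * (104/100) ** (2/9)) = floor(4.035) = 4
    assert _nw_auto_maxlags(104) == 4
    # Degenerate inputs are clamped to a single lag.
    assert _nw_auto_maxlags(1) == 1
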
+
+ # Bartlett-weighted (Newey–West) SE for the mean of a time series
+ def _nw_se_of_mean(series: pd.Series, maxlags: int) -> float:
+     """
+     Compute the Newey–West standard error of the sample mean with Bartlett weights.
+     Var(mean) ≈ (1/n) * [γ0 + 2 * sum_{k=1..L} w_k * γ_k], where w_k = 1 - k/(L+1)
+     and γ_k is the sample autocovariance at lag k with divisor n (not n-1).
+     """
+     x = np.asarray(series, dtype=float)
+     n = x.shape[0]
+     if n <= 1:
+         return np.nan
+     x = x - x.mean()
+     # autocovariances γ_k
+     gamma0 = np.dot(x, x) / n
+     lrvar = gamma0
+     L = min(maxlags, n - 1) if n > 1 else 0
+     for k in range(1, L + 1):
+         w = 1.0 - k / (L + 1.0)
+         cov = np.dot(x[k:], x[:-k]) / n
+         lrvar += 2.0 * w * cov
+     var_mean = lrvar / n
+     return float(np.sqrt(var_mean))
+
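
A quick behavioral check (a sketch with simulated AR(1) noise; the seed and rho are arbitrary): under positive autocorrelation, the Newey–West SE of the mean comes out wider than the naive iid SE.

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(42)
    n, rho = 200, 0.6
    x = np.zeros(n)
    for t in range(1, n):
        x[t] = rho * x[t - 1] + rng.normal()  # AR(1) with rho = 0.6
    s = pd.Series(x)

    naive = s.std(ddof=1) / np.sqrt(n)
    nw = _nw_se_of_mean(s, maxlags=_nw_auto_maxlags(n))
    print(f"naive SE {naive:.3f} vs NW SE {nw:.3f}")  # the NW SE is the larger of the two here
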
+
+ # -------------------- diagnostics & comparisons --------------------
+
+ def add_diagnostic_tests(sm_model, pre_data, formula, report_lines):
+     """
+     Build diagnostic figures and append textual test results to report_lines.
+     Returns a dict of {name: Figure}.
+     """
+     diagnostic_plots = {}
+
+     report_lines.append("\n--- Diagnostic Tests ---")
+     try:
+         residuals = sm_model.resid
+
+         # Durbin–Watson
+         dw = smt.durbin_watson(residuals)
+         report_lines.append(f"Durbin-Watson statistic: {dw:.3f}")
+         if dw < 1.5:
+             report_lines.append(" ⚠️ Positive autocorrelation detected (DW < 1.5)")
+         elif dw > 2.5:
+             report_lines.append(" ⚠️ Negative autocorrelation detected (DW > 2.5)")
+         else:
+             report_lines.append(" No significant autocorrelation (1.5 < DW < 2.5)")
+
+         # Ljung–Box
+         if len(residuals) > 10:
+             lb_lags = min(10, len(residuals) // 4)
+             lb = acorr_ljungbox(residuals, lags=lb_lags, return_df=True)
+             sig = lb[lb["lb_pvalue"] < 0.05]
+             if len(sig) > 0:
+                 report_lines.append(f" Ljung-Box: significant autocorrelation at lags {list(sig.index)}")
+             else:
+                 report_lines.append(f" Ljung-Box: no significant autocorrelation up to lag {lb_lags}")
+
+         # ACF
+         if len(residuals) > 20:
+             fig_acf, ax = plt.subplots(figsize=(11, 6))
+             acf_vals = acf(residuals, nlags=min(20, len(residuals) // 4))
+             ax.bar(range(len(acf_vals)), acf_vals, alpha=0.85)
+             ax.axhline(0, linewidth=0.5)
+             ci = 1.96 / np.sqrt(len(residuals))
+             ax.axhline(ci, linestyle="--", alpha=0.7)
+             ax.axhline(-ci, linestyle="--", alpha=0.7)
+             ax.set_title("Autocorrelation Function (ACF) of Residuals")
+             ax.set_xlabel("Lag"); ax.set_ylabel("Autocorrelation")
+             ax.grid(True, alpha=0.3)
+             fig_acf.tight_layout(pad=1.2)
+             diagnostic_plots["acf"] = fig_acf
+
+     except Exception as e:
+         report_lines.append(f" Could not perform autocorrelation tests: {e}")
+
+     # Model fit stats
+     report_lines.append("\n--- Model Fit Statistics ---")
+     report_lines.append(f"R-squared: {sm_model.rsquared:.3f}")
+     report_lines.append(f"Adjusted R-squared: {sm_model.rsquared_adj:.3f}")
+     report_lines.append(f"AIC: {sm_model.aic:.2f}")
+     report_lines.append(f"BIC: {sm_model.bic:.2f}")
+
+     # Residuals figure (6 panels) with generous spacing
+     try:
+         fig_resid = plt.figure(figsize=(13.5, 10.5), constrained_layout=False)  # MODIFIED: bigger
+         gs = gridspec.GridSpec(3, 2, figure=fig_resid, hspace=0.85, wspace=0.55)  # MODIFIED: more space
+
+         # Residuals vs Fitted
+         ax1 = fig_resid.add_subplot(gs[0, 0])
+         ax1.scatter(sm_model.fittedvalues, sm_model.resid, alpha=0.65)
+         ax1.axhline(0, linestyle="--", alpha=0.7)
+         ax1.set_title("Residuals vs Fitted Values"); ax1.set_xlabel("Fitted Values"); ax1.set_ylabel("Residuals")
+         ax1.grid(True, alpha=0.3)
+
+         # Normal Q–Q
+         ax2 = fig_resid.add_subplot(gs[0, 1])
+         stats.probplot(sm_model.resid, dist="norm", plot=ax2)
+         ax2.set_title("Normal Q-Q Plot"); ax2.grid(True, alpha=0.3)
+
+         # Histogram
+         ax3 = fig_resid.add_subplot(gs[1, 0])
+         ax3.hist(sm_model.resid, bins=20, edgecolor="black", alpha=0.75)
+         ax3.set_title("Histogram of Residuals"); ax3.set_xlabel("Residuals"); ax3.set_ylabel("Count")
+         ax3.grid(True, alpha=0.3)
+
+         # Residuals over time
+         ax4 = fig_resid.add_subplot(gs[1, 1])
+         ax4.plot(pre_data.index, sm_model.resid, marker="o", alpha=0.7)
+         ax4.axhline(0, linestyle="--", alpha=0.7)
+         ax4.set_title("Residuals Over Time"); ax4.set_xlabel("Date"); ax4.set_ylabel("Residuals")
+         for lbl in ax4.get_xticklabels():
+             lbl.set_rotation(45); lbl.set_ha("right")
+         ax4.grid(True, alpha=0.3)
+
+         # Scale–Location
+         ax5 = fig_resid.add_subplot(gs[2, 0])
+         std_resid = sm_model.resid / sm_model.resid.std()
+         ax5.scatter(sm_model.fittedvalues, np.sqrt(np.abs(std_resid)), alpha=0.65)
+         ax5.set_title("Scale-Location Plot"); ax5.set_xlabel("Fitted Values"); ax5.set_ylabel("√|Standardized Residuals|")
+         ax5.grid(True, alpha=0.3)
+
+         # Influence (Cook's Distance)
+         ax6 = fig_resid.add_subplot(gs[2, 1])
+         try:
+             from statsmodels.stats.outliers_influence import OLSInfluence
+             infl = OLSInfluence(sm_model)
+             ax6.scatter(range(len(infl.cooks_distance[0])), infl.cooks_distance[0], alpha=0.65)
+             ax6.axhline(4 / len(sm_model.resid), linestyle="--", alpha=0.7, label="4/n threshold")
+             ax6.legend()
+         except Exception:
+             ax6.text(0.5, 0.5, "Influence plot unavailable", ha="center", va="center")
+         ax6.set_title("Cook's Distance (Influence Plot)"); ax6.set_xlabel("Observation Index"); ax6.set_ylabel("Cook's Distance")
+         ax6.grid(True, alpha=0.3)
+
+         fig_resid.subplots_adjust(top=0.92, bottom=0.20, left=0.10, right=0.98, hspace=0.85, wspace=0.55)  # MODIFIED
+         diagnostic_plots["residuals"] = fig_resid
+
+     except Exception as e:
+         report_lines.append(f" Could not create residual diagnostic plots: {e}")
+
+     return diagnostic_plots
+
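
The diagnostics can be exercised in isolation; a minimal sketch (the trend-only formula and toy data are illustrative):

    import numpy as np
    import pandas as pd
    import statsmodels.api as sm

    idx = pd.date_range("2024-01-01", periods=60, freq="D")
    pre = pd.DataFrame({"time_index": np.arange(60)}, index=idx)
    pre["y"] = 0.5 * pre["time_index"] + np.random.default_rng(0).normal(size=60)

    fit = sm.OLS(pre["y"], sm.add_constant(pre["time_index"])).fit()
    lines = []
    figs = add_diagnostic_tests(fit, pre, "y ~ 1 + time_index", lines)
    print("\n".join(lines))  # DW, Ljung-Box, R-squared, AIC/BIC, ...
    print(sorted(figs))      # ['acf', 'residuals']
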
+
+ def compare_model_specifications(pre_data: pd.DataFrame, formula_base: str,
+                                  control_columns: List[str], report_lines: List[str]):
+     """
+     Fit linear vs. polynomial pre-period trends; return a dict possibly
+     containing a 'comparison_plot' Figure. Appends summary lines to the report.
+     """
+     report_lines.append("\n--- Model Specification Comparison ---")
+     try:
+         formula_linear = formula_base
+         # Extend the base formula (which already carries any seasonal and control
+         # terms) so the three specifications differ only in the trend polynomial.
+         formula_quad = formula_base + ' + I(time_index**2)'
+         formula_cubic = formula_base + ' + I(time_index**2) + I(time_index**3)'
+
+         models, formulas = {}, {'Linear': formula_linear, 'Quadratic': formula_quad, 'Cubic': formula_cubic}
+         for name, fml in formulas.items():
+             try:
+                 y, X = patsy.dmatrices(fml, data=pre_data, return_type='dataframe')
+                 models[name] = sm.OLS(y, X).fit()
+                 report_lines.append(f"{name}: R² {models[name].rsquared:.3f}, AIC {models[name].aic:.1f}, BIC {models[name].bic:.1f}")
+             except Exception:
+                 report_lines.append(f"{name}: could not fit")
+
+         out = {}
+         if 'Linear' in models:
+             fig, ax = plt.subplots(figsize=(11, 6))
+             ax.scatter(pre_data.index, pre_data['y'], alpha=0.6, label='Actual', color='black')
+             colors = {'Linear': 'tab:blue', 'Quadratic': 'tab:red', 'Cubic': 'tab:green'}
+             for name, mdl in models.items():
+                 if hasattr(mdl, 'fittedvalues'):
+                     ax.plot(pre_data.index, mdl.fittedvalues, label=f'{name} fit', color=colors.get(name, None), linewidth=2, alpha=0.85)
+             ax.set_title('Model Specification Comparison (Pre Period)')
+             ax.set_xlabel('Date'); ax.set_ylabel('Outcome'); ax.grid(True, alpha=0.3); ax.legend()
+             fig.tight_layout(pad=1.2)
+             out['comparison_plot'] = fig
+
+         if len(models) >= 2:
+             best = min(models.items(), key=lambda kv: kv[1].bic if hasattr(kv[1], 'bic') else np.inf)
+             report_lines.append(f"Recommended (by BIC): {best[0]}")
+
+         return out
+
+     except Exception as e:
+         report_lines.append(f" Could not compare model specifications: {e}")
+         return {}
+
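
For instance (a sketch on toy data with genuine curvature, so the quadratic spec should win by BIC):

    import numpy as np
    import pandas as pd

    idx = pd.date_range("2023-01-01", periods=80, freq="D")
    pre = pd.DataFrame({"time_index": np.arange(80)}, index=idx)
    pre["y"] = 0.02 * pre["time_index"] ** 2 + np.random.default_rng(1).normal(size=80)

    lines = []
    figs = compare_model_specifications(pre, "y ~ 1 + time_index", [], lines)
    print("\n".join(lines))  # per-spec R²/AIC/BIC, then "Recommended (by BIC): ..."
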
+
+ # -------------------- analysis --------------------
+
+ def enhanced_its_analysis(file, target_col, date_col, pre_dates, post_dates, freq_input="D",
+                           control_vars="", run_diagnostics=True, show_formulas=False):
+     """
+     ITS analysis using CausalPy with diagnostics and an optional model-spec comparison.
+     Returns (report_text, stacked_image).
+     """
+     if file is None:
+         return "Error: No file uploaded.", None
+
+     try:
+         # Load & validate
+         df = pd.read_csv(file.name)
+         if date_col not in df.columns:
+             return f"Error: Date column '{date_col}' not found.", None
+         if target_col not in df.columns:
+             return f"Error: Column '{target_col}' not found.", None
+
+         df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
+         df = df.dropna(subset=[date_col]).sort_values(date_col).set_index(date_col)
+         df[target_col] = pd.to_numeric(df[target_col], errors="coerce")
+         df = df.dropna(subset=[target_col]).rename(columns={target_col: "y"})
+         df["time_index"] = np.arange(len(df), dtype=int)
+
+         # Periods
+         try:
+             pre_s, pre_e = [pd.to_datetime(s.strip(), errors="raise") for s in pre_dates.split(",")]
+             post_s, post_e = [pd.to_datetime(s.strip(), errors="raise") for s in post_dates.split(",")]
+         except Exception:
+             return "Error: Use 'YYYY-MM-DD,YYYY-MM-DD' for the pre/post periods.", None
+         if not (pre_s <= pre_e < post_s <= post_e):
+             return "Error: Dates must satisfy pre_start <= pre_end < post_start <= post_end.", None
+
+         df = df.loc[(df.index >= pre_s) & (df.index <= post_e)].copy()
+         if df.empty:
+             return "Error: No data in the specified overall date range.", None
+
+         # Controls
+         control_columns: List[str] = []
+         if control_vars and control_vars.strip():
+             control_columns = [c.strip() for c in control_vars.split(",") if c.strip()]
+             missing = [c for c in control_columns if c not in df.columns]
+             if missing:
+                 return f"Error: Control variable(s) not found: {', '.join(missing)}", None
+             for c in control_columns:
+                 df[c] = pd.to_numeric(df[c], errors="coerce")
+             df = df.dropna(subset=control_columns)
+             if df.empty:
+                 return "Error: Data empty after removing NA rows for controls.", None
+
+         season_terms: List[str] = []
+         if SEASONALITY_ENABLED:
+             df, season_terms = _add_frequency_aware_seasonality(df, freq_input)
+
+         # Formula (base + seasonal + controls)
+         formula = "y ~ 1 + time_index"
+         if season_terms:
+             season_rhs = " + ".join(season_terms)  # 'C(...)' tokens or column names
+             formula += " + " + season_rhs
+         if control_columns:
+             formula += " + " + " + ".join(control_columns)
+
+         # Fit pre-period OLS for the inference components. Note: any observations
+         # between pre_e and post_s are treated as pre-period here.
+         pre_df = df.loc[df.index < post_s]
+         if pre_df.empty:
+             return "Error: Pre-intervention period is empty after filtering.", None
+         y_pre, X_pre = patsy.dmatrices(formula, data=pre_df, return_type="dataframe")
+         if X_pre.shape[0] <= X_pre.shape[1]:
+             return f"Error: Not enough pre-period observations ({X_pre.shape[0]}) to estimate {X_pre.shape[1]} parameters.", None
+
+         # HAC bandwidth (auto or user-provided)
+         if HAC_ENABLED:
+             if isinstance(HAC_MAXLAGS, str) and HAC_MAXLAGS.lower() == "auto":
+                 hac_lags = _nw_auto_maxlags(len(pre_df))
+             else:
+                 hac_lags = int(HAC_MAXLAGS)
+             sm_ols = sm.OLS(y_pre, X_pre).fit(
+                 cov_type="HAC",
+                 cov_kwds={"maxlags": hac_lags, "use_correction": HAC_SMALL_SAMPLE_CORR}
+             )
+             inference_note = f"HAC (Newey–West), maxlags={hac_lags}"
+         else:
+             sm_ols = sm.OLS(y_pre, X_pre).fit()
+             inference_note = "OLS (iid errors assumption)"
+
+         # Post-period design matrix for the counterfactual mean + inference
+         post_df = df.loc[(df.index >= post_s) & (df.index <= post_e)]
+         X_post = patsy.dmatrix(formula.split("~", 1)[1], data=post_df, return_type="dataframe")
+         pred_cf = sm_ols.predict(X_post)
+         observed_post = post_df['y']
+
+         post_mean = float(observed_post.mean())
+         cf_mean = float(np.asarray(pred_cf).mean())
+         effect = post_mean - cf_mean
+
+         if HAC_ENABLED:
+             if isinstance(HAC_MAXLAGS, str) and HAC_MAXLAGS.lower() == "auto":
+                 post_lags = _nw_auto_maxlags(len(observed_post))
+             else:
+                 post_lags = int(HAC_MAXLAGS)
+             post_se = _nw_se_of_mean(observed_post, maxlags=post_lags)
+         else:
+             post_se = float(observed_post.std(ddof=1) / np.sqrt(len(observed_post))) if len(observed_post) >= 2 else np.nan
+
+         # SE(counterfactual mean) via the delta method using the (robust) cov(beta).
+         # reindex aligns X_bar with cov_beta's columns; a column missing from the
+         # post design (e.g., an unseen categorical level) yields NaN, and the
+         # inference is then reported as unavailable below.
+         cov_beta = sm_ols.cov_params()  # MODIFIED: robust if HAC_ENABLED=True
+         X_bar = X_post.mean(axis=0).reindex(cov_beta.columns)
+         var_cf_mean = float(X_bar.T @ cov_beta @ X_bar)
+         cf_se = float(np.sqrt(var_cf_mean)) if var_cf_mean >= 0 else np.nan
+
+         # Combine SEs (independence approximation between observed and counterfactual)
+         eff_se = float(np.sqrt(post_se**2 + cf_se**2)) if (np.isfinite(post_se) and np.isfinite(cf_se)) else np.nan
+
+         # Test statistic & CI — normal (z) under HAC, t otherwise
+         if np.isfinite(eff_se) and eff_se > 0:
+             if HAC_ENABLED:
+                 z_stat = effect / eff_se
+                 p_value = 2 * (1 - stats.norm.cdf(abs(z_stat)))
+                 ci_margin = 1.96 * eff_se
+                 ci_low, ci_high = effect - ci_margin, effect + ci_margin
+                 test_line = f" z-statistic: {z_stat:.3f}"
+             else:
+                 df_t = max(int(sm_ols.df_resid), 1)
+                 t_stat = effect / eff_se
+                 p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=df_t))
+                 ci_low, ci_high = stats.t.interval(0.95, df_t, loc=effect, scale=eff_se)
+                 test_line = f" t-statistic: {t_stat:.3f}"
+         else:
+             p_value = ci_low = ci_high = np.nan
+             test_line = " (Test statistic unavailable)"
+
+         # Build report
+         report_lines = [
+             "=" * 60,
+             "INTERRUPTED TIME SERIES ANALYSIS REPORT",
+             "=" * 60,
+             f"\nOutcome: {target_col}",
+             f"Pre: {pre_s.date()} to {pre_e.date()} | Post: {post_s.date()} to {post_e.date()}",
+             f"Frequency: {freq_input}",
+             f"Formula: {formula}",
+             f"Inference method: {inference_note}",  # MODIFIED: document inference method
+             "\n" + "=" * 60,
+             "MAIN RESULTS",
+             "=" * 60,
+             f"Observed post-intervention mean: {post_mean:.3f}",
+             f"Estimated counterfactual mean: {cf_mean:.3f}",
+             f"**Estimated average effect: {effect:.3f}**",
+         ]
+         if np.isfinite(eff_se):
+             report_lines += [
+                 "\nStatistical inference:",
+                 f" Standard error: {eff_se:.3f}",
+                 f" 95% Confidence interval: [{ci_low:.3f}, {ci_high:.3f}]",
+                 test_line,
+                 f" p-value: {p_value:.4f}",
+             ]
+         else:
+             report_lines.append("\n(Statistical inference unavailable due to insufficient data)")
+
+         # Main ITS (CausalPy) figure
+         its = cp.InterruptedTimeSeries(
+             data=df, formula=formula, treatment_time=post_s, model=LinearRegression(), freq=freq_input
+         )
+         result = its.plot(plot_predict_all=False, plot_show_params=True)
+         fig_main = result[0] if isinstance(result, tuple) else result
+         # MODIFIED: make sure the ITS composite is large; rotate ticks & add spacing
+         try:
+             fig_main.set_size_inches(14, 9, forward=True)
+         except Exception:
+             pass
+         _rotate_all_xticklabels(fig_main)
+         try:
+             fig_main.tight_layout(pad=1.3)
+             fig_main.subplots_adjust(top=0.92, bottom=0.20, left=0.08, right=0.98, hspace=0.42)
+         except Exception:
+             pass
+
+         images: List[Image.Image] = [_fig_to_pil(fig_main)]
+
+         # Diagnostics + comparison
+         if run_diagnostics:
+             diag_figs = add_diagnostic_tests(sm_ols, pre_df, formula, report_lines)
+             if "acf" in diag_figs:
+                 images.append(_fig_to_pil(diag_figs["acf"]))
+             if "residuals" in diag_figs:
+                 images.append(_fig_to_pil(diag_figs["residuals"]))
+         else:
+             report_lines.append("\n(Diagnostic plots disabled)")
+
+         if show_formulas:
+             cmp_figs = compare_model_specifications(pre_df, formula, control_columns, report_lines)
+             if "comparison_plot" in cmp_figs:
+                 images.append(_fig_to_pil(cmp_figs["comparison_plot"]))
+         else:
+             report_lines.append("\n(Model specification comparison disabled)")
+
+         final_img = images[0] if len(images) == 1 else _stack_images_vertical(images, pad=22)
+
+         return "\n".join(report_lines), final_img
+
+     except Exception as e:
+         return f"An unexpected error occurred: {e}\n{traceback.format_exc()}", None
+
+
+ def run_its_analysis(file, target_col, date_col, pre_dates, post_dates, freq_input,
+                      control_vars="", run_diagnostics=True, show_formulas=False):
+     """Public entrypoint used by the UI."""
+     return enhanced_its_analysis(file, target_col, date_col, pre_dates, post_dates,
+                                  freq_input, control_vars, run_diagnostics, show_formulas)
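
End to end, the entrypoint can be smoke-tested as below. This is a sketch only: the CSV path and column names are hypothetical, and SimpleNamespace stands in for the UI's upload object (the function only reads file.name).

    from types import SimpleNamespace

    upload = SimpleNamespace(name="sales_weekly.csv")  # hypothetical file
    report, image = run_its_analysis(
        upload,
        target_col="sales",                   # hypothetical column names
        date_col="week",
        pre_dates="2023-01-01,2023-12-31",
        post_dates="2024-01-07,2024-06-30",
        freq_input="W",
        control_vars="promo_spend",
    )
    print(report)
    if image is not None:
        image.save("its_report.png")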