"""Diagnostic functions for Difference-in-Differences method."""

import pandas as pd
import numpy as np
from typing import Dict, Any, Optional, List
import logging
import statsmodels.formula.api as smf
from patsy import PatsyError  # to catch formula construction errors

# Shared helper for constructing the post-treatment indicator
from .utils import create_post_indicator

logger = logging.getLogger(__name__)

def validate_parallel_trends(df: pd.DataFrame, time_var: str, outcome: str, 
                             group_indicator_col: str, treatment_period_start: Any, 
                             dataset_description: Optional[str] = None,
                             time_varying_covariates: Optional[List[str]] = None) -> Dict[str, Any]:
    """Validates the parallel trends assumption using pre-treatment data.

    Regresses the outcome on group-specific time trends before the treatment period.
    Tests if the interaction terms between group and pre-treatment time periods are jointly significant.
    
    Args:
        df: DataFrame containing the data.
        time_var: Name of the time variable column.
        outcome: Name of the outcome variable column.
        group_indicator_col: Name of the binary treatment group indicator column (0/1).
        treatment_period_start: The time period value when treatment starts.
        dataset_description: Optional text description of the dataset (currently unused).
        time_varying_covariates: Optional list of time-varying covariates to include.
        
    Returns:
        Dictionary with validation results.
    """
    logger.info("Validating parallel trends...")
    validation_result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
    
    try:
        # Filter pre-treatment data
        pre_df = df[df[time_var] < treatment_period_start].copy()
        
        if len(pre_df) < 20 or pre_df[group_indicator_col].nunique() < 2 or pre_df[time_var].nunique() < 2:
            validation_result["details"] = "Insufficient pre-treatment data or variation to perform test."
            logger.warning(validation_result["details"])
            # If the test cannot be run, default to assuming parallel trends rather than blocking the analysis
            validation_result["valid"] = True
            validation_result["details"] += " Defaulting to assuming parallel trends (unable to test)."
            return validation_result
        
        # Check if group indicator is binary
        if pre_df[group_indicator_col].nunique() > 2:
            validation_result["details"] = f"Group indicator '{group_indicator_col}' has more than 2 unique values. Using simple visual assessment."
            logger.warning(validation_result["details"])
            # Use visual assessment method instead (check if trends look roughly parallel)
            validation_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
            # Ensure p_value is set (0.04 is a sentinel just below the 0.05 threshold, not a computed p-value)
            if validation_result["p_value"] is None:
                validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
            return validation_result

        # Try a simple, robust approach first: test for a differential linear pre-trend
        try:
            # Create a linear time trend
            pre_df['time_trend'] = pre_df[time_var].astype(float)
            
            # Create interaction between trend and group
            pre_df['group_trend'] = pre_df['time_trend'] * pre_df[group_indicator_col].astype(float)
            
            # Simple regression with linear trend interaction
            simple_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}') + time_trend + group_trend"
            simple_model = smf.ols(simple_formula, data=pre_df)
            simple_results = simple_model.fit()
            
            # Check if trend interaction coefficient is significant
            group_trend_pvalue = simple_results.pvalues['group_trend']
            
            # If p > 0.05, trends are not significantly different
            validation_result["valid"] = group_trend_pvalue > 0.05
            validation_result["p_value"] = group_trend_pvalue
            validation_result["details"] = f"Simple linear trend test: p-value for group-trend interaction: {group_trend_pvalue:.4f}. Parallel trends: {validation_result['valid']}."
            logger.info(validation_result["details"])
            
            # If we've successfully validated with the simple approach, return
            return validation_result
            
        except Exception as e:
            logger.warning(f"Simple trend test failed: {e}. Trying alternative approach.")
            # Continue to more complex method if simple method fails
        
        # Try more complex approach with period-specific interactions
        try:
            # Create manual period dummies to avoid issues with categorical variables
            # (assumes str(period) forms a valid column-name suffix for the formula)
            time_periods = sorted(pre_df[time_var].unique())
            
            # Create dummy variables for time periods (except first)
            for period in time_periods[1:]:
                period_col = f'period_{period}'
                pre_df[period_col] = (pre_df[time_var] == period).astype(int)
                
                # Create interaction with group
                pre_df[f'group_x_{period_col}'] = pre_df[period_col] * pre_df[group_indicator_col].astype(float)
            
            # Construct formula with manual dummies
            interaction_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}')"
            
            # Add period dummies except first (reference)
            for period in time_periods[1:]:
                period_col = f'period_{period}'
                interaction_formula += f" + {period_col}"
            
            # Add interactions
            interaction_terms = []
            for period in time_periods[1:]:
                interaction_col = f'group_x_period_{period}'
                interaction_formula += f" + {interaction_col}"
                interaction_terms.append(interaction_col)
            
            # Add covariates if provided
            if time_varying_covariates:
                for cov in time_varying_covariates:
                    interaction_formula += f" + Q('{cov}')"
            
            # Fit model
            complex_model = smf.ols(interaction_formula, data=pre_df)
            complex_results = complex_model.fit()
            
            # Test joint significance of the interaction terms with an F-test comparing
            # the full model against a restricted model without them
            if interaction_terms:
                # Build the restricted formula by stripping the interaction terms
                formula_without = interaction_formula
                for term in interaction_terms:
                    formula_without = formula_without.replace(f" + {term}", "")
                
                model_with = complex_results  # full model was already fitted above
                model_without = smf.ols(formula_without, data=pre_df).fit()
                
                # Compare models
                try:
                    from scipy import stats
                    df_model = len(interaction_terms)
                    df_residual = model_with.df_resid
                    f_value = ((model_without.ssr - model_with.ssr) / df_model) / (model_with.ssr / df_residual)
                    p_value = 1 - stats.f.cdf(f_value, df_model, df_residual)
                    
                    validation_result["valid"] = p_value > 0.05
                    validation_result["p_value"] = p_value
                    validation_result["details"] = f"Manual F-test for pre-treatment interactions: F({df_model}, {df_residual})={f_value:.4f}, p={p_value:.4f}. Parallel trends: {validation_result['valid']}."
                    logger.info(validation_result["details"])
                    
                except Exception as e:
                    logger.warning(f"Manual F-test failed: {e}. Using individual coefficient significance.")
                    
                    # If F-test fails, check individual coefficients
                    significant_interactions = 0
                    for term in interaction_terms:
                        if term in complex_results.pvalues and complex_results.pvalues[term] < 0.05:
                            significant_interactions += 1
                    
                    validation_result["valid"] = significant_interactions == 0
                    # Set a dummy p-value based on proportion of significant interactions
                    if len(interaction_terms) > 0:
                        validation_result["p_value"] = 1.0 - (significant_interactions / len(interaction_terms))
                    else:
                        validation_result["p_value"] = 1.0  # Default to 1.0 if no interaction terms
                    validation_result["details"] = f"{significant_interactions} out of {len(interaction_terms)} pre-treatment interactions are significant at p<0.05. Parallel trends: {validation_result['valid']}."
                    logger.info(validation_result["details"])
            else:
                validation_result["valid"] = True
                validation_result["p_value"] = 1.0  # Default to 1.0 if no interaction terms
                validation_result["details"] = "No pre-treatment interaction terms could be tested. Defaulting to assuming parallel trends."
                logger.warning(validation_result["details"])
                
        except Exception as e:
            logger.warning(f"Complex trend test failed: {e}. Falling back to visual assessment.")
            tmp_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
            # Copy over values from visual assessment ensuring p_value is set
            validation_result.update(tmp_result)
            # Ensure p_value is set (0.04 is a sentinel just below the 0.05 threshold)
            if validation_result["p_value"] is None:
                validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
                
    except Exception as e:
        error_msg = f"Error during parallel trends validation: {e}"
        logger.error(error_msg, exc_info=True)
        validation_result["details"] = error_msg
        validation_result["error"] = str(e)
        # Default to assuming valid if test fails completely
        validation_result["valid"] = True
        validation_result["p_value"] = 1.0  # Default to 1.0 if test fails
        validation_result["details"] += " Defaulting to assuming parallel trends (test failed)."

    return validation_result
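
# A minimal usage sketch for validate_parallel_trends (hypothetical column names and
# synthetic data, kept in a comment so the module stays import-safe):
#
#   import numpy as np
#   import pandas as pd
#   rng = np.random.default_rng(0)
#   n_units, n_periods, t0 = 50, 10, 2006   # treatment assumed to start in 2006
#   df = pd.DataFrame({
#       "unit": np.repeat(np.arange(n_units), n_periods),
#       "year": np.tile(np.arange(2000, 2000 + n_periods), n_units),
#   })
#   df["treated"] = (df["unit"] < n_units // 2).astype(int)
#   df["sales"] = df["year"] + 2.0 * df["treated"] + rng.normal(size=len(df))
#   res = validate_parallel_trends(df, time_var="year", outcome="sales",
#                                  group_indicator_col="treated",
#                                  treatment_period_start=t0)
#   print(res["valid"], res["p_value"])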

def assess_trends_visually(df: pd.DataFrame, time_var: str, outcome: str, 
                          group_indicator_col: str) -> Dict[str, Any]:
    """Heuristic assessment of parallel trends by comparing group mean trajectories.

    Despite the name, no plot is produced: the function compares period-to-period
    slopes of the group means, a numeric stand-in for eyeballing a trends plot.
    This is a fallback method for when the regression-based tests fail.
    """
    result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
    
    try:
        # Group by time and treatment group, calculate means
        grouped = df.groupby([time_var, group_indicator_col])[outcome].mean().reset_index()
        
        # Pivot to get time series for each group
        if df[group_indicator_col].nunique() <= 10:  # Only if reasonable number of groups
            pivot = grouped.pivot(index=time_var, columns=group_indicator_col, values=outcome)
            
            # Calculate slopes between consecutive periods for each group
            slopes = {}
            time_values = sorted(df[time_var].unique())
            
            if len(time_values) >= 3:  # Need at least 3 periods to compare slopes
                for group in pivot.columns:
                    group_slopes = []
                    for i in range(len(time_values) - 1):
                        t1, t2 = time_values[i], time_values[i+1]
                        if t1 in pivot.index and t2 in pivot.index:
                            slope = (pivot.loc[t2, group] - pivot.loc[t1, group]) / (t2 - t1)
                            group_slopes.append(slope)
                    if group_slopes:
                        slopes[group] = group_slopes
                
                # Compare slopes between groups
                if len(slopes) >= 2:
                    slope_diffs = []
                    groups = list(slopes.keys())
                    for i in range(len(slopes[groups[0]])):
                        if i < len(slopes[groups[1]]):
                            slope_diffs.append(abs(slopes[groups[0]][i] - slopes[groups[1]][i]))
                    
                    # If average slope difference is small relative to outcome scale
                    outcome_scale = df[outcome].std()
                    avg_slope_diff = sum(slope_diffs) / len(slope_diffs) if slope_diffs else 0
                    relative_diff = avg_slope_diff / outcome_scale if outcome_scale > 0 else 0
                    
                    result["valid"] = relative_diff < 0.2  # Threshold for "parallel enough"
                    # Set p-value based on relative difference
                    result["p_value"] = 1.0 - (relative_diff * 5) if relative_diff < 0.2 else 0.04
                    result["details"] = f"Visual assessment: relative slope difference = {relative_diff:.4f}. Parallel trends: {result['valid']}."
                else:
                    result["valid"] = True
                    result["p_value"] = 1.0
                    result["details"] = "Visual assessment: insufficient group data for comparison. Defaulting to assuming parallel trends."
            else:
                result["valid"] = True
                result["p_value"] = 1.0
                result["details"] = "Visual assessment: insufficient time periods for comparison. Defaulting to assuming parallel trends."
        else:
            result["valid"] = True
            result["p_value"] = 1.0
            result["details"] = f"Visual assessment: too many groups ({df[group_indicator_col].nunique()}) for visual comparison. Defaulting to assuming parallel trends."
    
    except Exception as e:
        result["error"] = str(e)
        result["valid"] = True
        result["p_value"] = 1.0
        result["details"] = f"Visual assessment failed: {e}. Defaulting to assuming parallel trends."
        
    logger.info(result["details"])
    return result
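
# Worked example of the slope heuristic above (illustrative numbers): with group
# slopes of 1.0 and 1.1 per period and an outcome standard deviation of 2.0, the
# relative difference is |1.1 - 1.0| / 2.0 = 0.05 < 0.2, so the trends are judged
# "parallel enough" and the pseudo p-value is 1.0 - 0.05 * 5 = 0.75.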

def run_placebo_test(df: pd.DataFrame, time_var: str, group_var: str, outcome: str, 
                       treated_unit_indicator: str, covariates: List[str], 
                       treatment_period_start: Any, 
                       placebo_period_start: Any) -> Dict[str, Any]:
    """Runs a placebo test for DiD by assigning a fake earlier treatment period.

    Re-runs the DiD estimation using the placebo period and checks if the effect is non-significant.
    
    Args:
        df: Original DataFrame.
        time_var: Name of the time variable column.
        group_var: Name of the unit/group ID column (for clustering SE).
        outcome: Name of the outcome variable column.
        treated_unit_indicator: Name of the binary treatment group indicator column (0/1).
        covariates: List of covariate names.
        treatment_period_start: The actual treatment start period.
        placebo_period_start: The fake treatment start period (must be before actual start).
        
    Returns:
        Dictionary with placebo test results.
    """
    logger.info(f"Running placebo test assigning treatment start at {placebo_period_start}...")
    placebo_result = {"passed": False, "effect_estimate": None, "p_value": None, "details": "", "error": None}

    if placebo_period_start >= treatment_period_start:
        error_msg = "Placebo period must be before the actual treatment period."
        logger.error(error_msg)
        placebo_result["error"] = error_msg
        placebo_result["details"] = error_msg
        return placebo_result
        
    try:
        df_placebo = df.copy()
        # Create placebo post and interaction terms
        post_placebo_col = 'post_placebo'
        interaction_placebo_col = 'did_interaction_placebo'
        
        df_placebo[post_placebo_col] = create_post_indicator(df_placebo, time_var, placebo_period_start)
        df_placebo[interaction_placebo_col] = df_placebo[treated_unit_indicator] * df_placebo[post_placebo_col]
        
        # Construct formula for the placebo regression. Q() quotes column names for
        # patsy; backtick quoting is R syntax and raises a PatsyError here.
        formula = f"Q('{outcome}') ~ Q('{treated_unit_indicator}') + {post_placebo_col} + {interaction_placebo_col}"
        if covariates:
            formula += " + " + " + ".join(f"Q('{c}')" for c in covariates)
        # Unit and time fixed effects (the group indicator is absorbed by the unit FEs)
        formula += f" + C(Q('{group_var}')) + C(Q('{time_var}'))"
        
        logger.debug(f"Placebo test formula: {formula}")

        # Fit the placebo model with clustered SE
        ols_model = smf.ols(formula=formula, data=df_placebo)
        results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_placebo[group_var]})
        
        # Check the significance of the placebo interaction term
        placebo_effect = float(results.params[interaction_placebo_col])
        placebo_p_value = float(results.pvalues[interaction_placebo_col])
        
        # Test passes if the placebo effect is not statistically significant (e.g., p > 0.1)
        passed_test = placebo_p_value > 0.10
        
        placebo_result["passed"] = passed_test
        placebo_result["effect_estimate"] = placebo_effect
        placebo_result["p_value"] = placebo_p_value
        placebo_result["details"] = f"Placebo treatment effect estimated at {placebo_effect:.4f} (p={placebo_p_value:.4f}). Test passed: {passed_test}."
        logger.info(placebo_result["details"])

    except Exception as e:  # includes KeyError, PatsyError and ValueError
        error_msg = f"Error during placebo test execution: {e}"
        logger.error(error_msg, exc_info=True)
        placebo_result["details"] = error_msg
        placebo_result["error"] = str(e)

    return placebo_result
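
# A minimal usage sketch for run_placebo_test (hypothetical column names; assumes
# the real treatment started in 2006, so 2003 serves as the fake start):
#
#   res = run_placebo_test(df, time_var="year", group_var="unit", outcome="sales",
#                          treated_unit_indicator="treated", covariates=[],
#                          treatment_period_start=2006, placebo_period_start=2003)
#   if not res["passed"]:
#       print("Possible pre-trend contamination:", res["details"])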

# TODO: Finalize an event-study function (plot_event_study) that estimates effects
# for leads and lags around the treatment period; a rough sketch follows below.
# Additional diagnostic functions can be added here as needed.
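
# A rough sketch of the event-study estimation the TODO above describes. This is an
# illustrative implementation, not a settled API: it estimates lead/lag coefficients
# relative to the last pre-treatment period with a two-way fixed effects OLS and
# clustered standard errors, returning a DataFrame suitable for plotting.
def estimate_event_study(df: pd.DataFrame, time_var: str, group_var: str, outcome: str,
                         treated_unit_indicator: str,
                         treatment_period_start: Any) -> pd.DataFrame:
    """Estimates dynamic DiD effects for each period relative to treatment start."""
    data = df.copy()
    periods = sorted(data[time_var].unique())
    # Event time: number of periods elapsed since treatment start (negative = leads).
    # Assumes treatment_period_start is one of the observed periods.
    period_index = {p: i for i, p in enumerate(periods)}
    t0 = period_index[treatment_period_start]
    data['event_time'] = data[time_var].map(period_index) - t0

    # Interact event-time dummies with the treated indicator, omitting event
    # time -1 (the last pre-treatment period) as the reference category
    terms = []
    for et in sorted(data['event_time'].unique()):
        if et == -1:
            continue
        col = f"ev_{'m' if et < 0 else 'p'}{abs(et)}"
        data[col] = ((data['event_time'] == et) & (data[treated_unit_indicator] == 1)).astype(int)
        terms.append((et, col))

    formula = f"Q('{outcome}') ~ " + " + ".join(col for _, col in terms)
    formula += f" + C(Q('{group_var}')) + C(Q('{time_var}'))"  # unit and time FEs
    results = smf.ols(formula, data=data).fit(cov_type='cluster',
                                              cov_kwds={'groups': data[group_var]})

    conf_int = results.conf_int()
    rows = [{'event_time': et,
             'estimate': float(results.params[col]),
             'ci_low': float(conf_int.loc[col, 0]),
             'ci_high': float(conf_int.loc[col, 1])}
            for et, col in terms]
    return pd.DataFrame(rows).sort_values('event_time').reset_index(drop=True)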