""" | |
decision tree component for selecting causal inference methods | |
this module implements the decision tree logic to select the most appropriate | |
causal inference method based on dataset characteristics and available variables | |
""" | |
import logging | |
from typing import Dict, List, Any, Optional | |
import pandas as pd | |
# Define method names
BACKDOOR_ADJUSTMENT = "backdoor_adjustment"
LINEAR_REGRESSION = "linear_regression"
DIFF_IN_MEANS = "diff_in_means"
DIFF_IN_DIFF = "difference_in_differences"
REGRESSION_DISCONTINUITY = "regression_discontinuity_design"
PROPENSITY_SCORE_MATCHING = "propensity_score_matching"
INSTRUMENTAL_VARIABLE = "instrumental_variable"
CORRELATION_ANALYSIS = "correlation_analysis"
PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting"
GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score"
FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment"

logger = logging.getLogger(__name__)
# Method assumptions mapping
METHOD_ASSUMPTIONS = {
    BACKDOOR_ADJUSTMENT: [
        "no unmeasured confounders (conditional ignorability given covariates)",
        "correct model specification for outcome conditional on treatment and covariates",
        "positivity/overlap (for all covariate values, units could potentially receive either treatment level)"
    ],
    LINEAR_REGRESSION: [
        "linear relationship between treatment, covariates, and outcome",
        "no unmeasured confounders (if observational)",
        "correct model specification",
        "homoscedasticity of errors",
        "normally distributed errors (for inference)"
    ],
    DIFF_IN_MEANS: [
        "treatment is randomly assigned (or as-if random)",
        "no spillover effects",
        "stable unit treatment value assumption (SUTVA)"
    ],
    DIFF_IN_DIFF: [
        "parallel trends between treatment and control groups before treatment",
        "no spillover effects between groups",
        "no anticipation effects before treatment",
        "stable composition of treatment and control groups",
        "treatment timing is exogenous"
    ],
    REGRESSION_DISCONTINUITY: [
        "units cannot precisely manipulate the running variable around the cutoff",
        "continuity of conditional expectation functions of potential outcomes at the cutoff",
        "no other changes occurring precisely at the cutoff"
    ],
    PROPENSITY_SCORE_MATCHING: [
        "no unmeasured confounders (conditional ignorability)",
        "sufficient overlap (common support) between treatment and control groups",
        "correct propensity score model specification"
    ],
    INSTRUMENTAL_VARIABLE: [
        "instrument is correlated with treatment (relevance)",
        "instrument affects outcome only through treatment (exclusion restriction)",
        "instrument is independent of unmeasured confounders (exogeneity/independence)"
    ],
    CORRELATION_ANALYSIS: [
        "data represents a sample from the population of interest",
        "variables are measured appropriately"
    ],
    PROPENSITY_SCORE_WEIGHTING: [
        "no unmeasured confounders (conditional ignorability)",
        "sufficient overlap (common support) between treatment and control groups",
        "correct propensity score model specification",
        "weights correctly specified (e.g., ATE, ATT)"
    ],
    GENERALIZED_PROPENSITY_SCORE: [
        "conditional mean independence",
        "positivity/common support for GPS",
        "correct specification of the GPS model",
        "correct specification of the outcome model",
        "no unmeasured confounders affecting both treatment and outcome, given X",
        "treatment variable is continuous"
    ],
    FRONTDOOR_ADJUSTMENT: [
        "mediator is affected by treatment and affects outcome",
        "mediator is not affected by any confounders of the treatment-outcome relationship"
    ]
}
def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Select the most appropriate causal inference method for a dataset.

    Args:
        dataset_properties (Dict): dataset characteristics (treatment/outcome variables,
            instrument, running variable, temporal structure, overlap score, etc.)
        excluded_methods (List[str], optional): methods to exclude from selection

    Returns:
        Dict with the selected method, its justification and assumptions,
        ranked alternatives, and the applied exclusions.
    """
    excluded_methods = set(excluded_methods or [])
    logger.info(f"Excluded methods: {sorted(excluded_methods)}")

    treatment = dataset_properties.get("treatment_variable")
    outcome = dataset_properties.get("outcome_variable")
    if not treatment or not outcome:
        raise ValueError("Both treatment and outcome variables must be specified")

    instrument_var = dataset_properties.get("instrument_variable")
    running_var = dataset_properties.get("running_variable")
    cutoff_val = dataset_properties.get("cutoff_value")
    time_var = dataset_properties.get("time_variable")
    is_rct = dataset_properties.get("is_rct", False)
    has_temporal = dataset_properties.get("has_temporal_structure", False)
    frontdoor = dataset_properties.get("frontdoor_criterion", False)
    covariate_overlap_result = dataset_properties.get("covariate_overlap_score")
    covariates = dataset_properties.get("covariates", [])
    treatment_variable_type = dataset_properties.get("treatment_variable_type", "binary")

    # Helpers to collect candidates
    candidates = []  # list of (method, priority_index)
    justifications: Dict[str, str] = {}
    assumptions: Dict[str, List[str]] = {}

    def add(method: str, justification: str, prio_order: List[str]):
        if method in justifications:  # already added
            return
        justifications[method] = justification
        assumptions[method] = METHOD_ASSUMPTIONS[method]
        # Priority index from provided order (fallback to a large index if not present)
        try:
            idx = prio_order.index(method)
        except ValueError:
            idx = 10**6
        candidates.append((method, idx))
    # ----- Build candidate set (no returns here) -----
    # RCT branch
    if is_rct:
        logger.info("Dataset is from a randomized controlled trial (RCT)")
        rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS]
        if instrument_var and instrument_var != treatment:
            add(INSTRUMENTAL_VARIABLE,
                f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.",
                rct_priority)
        if covariates:
            add(LINEAR_REGRESSION,
                "RCT with covariates—use OLS for precision.",
                rct_priority)
        else:
            add(DIFF_IN_MEANS,
                "Pure RCT without covariates—difference-in-means.",
                rct_priority)

    # Observational branch
    obs_priority_binary = [
        INSTRUMENTAL_VARIABLE,
        PROPENSITY_SCORE_MATCHING,
        PROPENSITY_SCORE_WEIGHTING,
        FRONTDOOR_ADJUSTMENT,
        LINEAR_REGRESSION,
    ]
    obs_priority_nonbinary = [
        INSTRUMENTAL_VARIABLE,
        FRONTDOOR_ADJUSTMENT,
        LINEAR_REGRESSION,
    ]

    # Common early structural signals first (still only added as candidates)
    if has_temporal and time_var:
        add(DIFF_IN_DIFF,
            f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).",
            [DIFF_IN_DIFF])  # highest among itself
    if running_var and cutoff_val is not None:
        add(REGRESSION_DISCONTINUITY,
            f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.",
            [REGRESSION_DISCONTINUITY])

    # Binary vs non-binary pathways
    if treatment_variable_type == "binary":
        if instrument_var:
            add(INSTRUMENTAL_VARIABLE,
                f"Instrumental variable '{instrument_var}' available.",
                obs_priority_binary)
        # Propensity score methods only if covariates exist
        if covariates:
            if covariate_overlap_result is not None:
                ps_method = (PROPENSITY_SCORE_WEIGHTING
                             if covariate_overlap_result < 0.1
                             else PROPENSITY_SCORE_MATCHING)
            else:
                ps_method = PROPENSITY_SCORE_MATCHING
            add(ps_method,
                "Covariates observed; PS method chosen based on overlap.",
                obs_priority_binary)
        if frontdoor:
            add(FRONTDOOR_ADJUSTMENT,
                "Front-door criterion satisfied.",
                obs_priority_binary)
        add(LINEAR_REGRESSION,
            "OLS as a fallback specification.",
            obs_priority_binary)
    else:
        logger.info(f"Non-binary treatment variable detected: {treatment_variable_type}")
        if instrument_var:
            add(INSTRUMENTAL_VARIABLE,
                f"Instrument '{instrument_var}' candidate for non-binary treatment.",
                obs_priority_nonbinary)
        if frontdoor:
            add(FRONTDOOR_ADJUSTMENT,
                "Front-door criterion satisfied.",
                obs_priority_nonbinary)
        add(LINEAR_REGRESSION,
            "Fallback for non-binary treatment without stronger identification.",
            obs_priority_nonbinary)
    # ----- Centralized exclusion handling -----
    # Remove excluded methods from the candidate set
    filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods]

    # If nothing survives, attempt a safe fallback that is not excluded
    # (even if it was not among the original candidates)
    if not filtered:
        logger.warning(
            f"All candidates excluded. Candidates were: {[m for m, _ in candidates]}. "
            f"Excluded: {sorted(excluded_methods)}"
        )
        fallback_order = [
            LINEAR_REGRESSION,
            DIFF_IN_MEANS,
            PROPENSITY_SCORE_MATCHING,
            PROPENSITY_SCORE_WEIGHTING,
            DIFF_IN_DIFF,
            REGRESSION_DISCONTINUITY,
            INSTRUMENTAL_VARIABLE,
            FRONTDOOR_ADJUSTMENT,
        ]
        fallback = next((m for m in fallback_order if m not in excluded_methods), None)
        if not fallback:
            # Truly nothing left; raise with context
            raise RuntimeError("No viable method remains after exclusions.")
        selected_method = fallback
        alternatives = []
        justifications.setdefault(selected_method, "Fallback after exclusions.")
        assumptions.setdefault(selected_method, METHOD_ASSUMPTIONS[selected_method])
    else:
        # Pick by smallest priority index, then stable by insertion order
        filtered.sort(key=lambda x: x[1])
        selected_method = filtered[0][0]
        alternatives = [m for (m, _) in filtered[1:] if m != selected_method]

    logger.info(f"Selected method: {selected_method}; alternatives: {alternatives}")
    return {
        "selected_method": selected_method,
        "method_justification": justifications[selected_method],
        "method_assumptions": assumptions[selected_method],
        "alternatives": alternatives,
        "excluded_methods": sorted(excluded_methods),
    }
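
# Illustrative sketch of how select_method behaves on a simple observational
# dataset with a binary treatment and measured covariates. The variable names
# ("discount", "revenue", "region", "age") are hypothetical. With no overlap
# score supplied, the tree defaults to propensity score matching, with OLS
# listed as the alternative.
#
#     result = select_method({
#         "treatment_variable": "discount",
#         "outcome_variable": "revenue",
#         "covariates": ["region", "age"],
#         "treatment_variable_type": "binary",
#     })
#     result["selected_method"]  # -> "propensity_score_matching"
#     result["alternatives"]     # -> ["linear_regression"]
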
def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None):
    """
    Wrapper function to select a causal method based on dataset properties and the query.

    Args:
        dataset_analysis (Dict): results of dataset analysis
        variables (Dict): dictionary of variable names and types
        is_rct (bool): whether the dataset is from a randomized controlled trial
        llm (BaseChatModel): language model instance for generating prompts
        dataset_description (str): description of the dataset
        original_query (str): the original user query
        excluded_methods (List[str], optional): list of methods to exclude from selection

    Returns:
        Dict[str, Any]: the method selection result produced by select_method
    """
    logger.info("Running rule-based method selection")
    properties = {
        "treatment_variable": variables.get("treatment_variable"),
        "outcome_variable": variables.get("outcome_variable"),
        "instrument_variable": variables.get("instrument_variable"),
        "covariates": variables.get("covariates", []),
        "time_variable": variables.get("time_variable"),
        "running_variable": variables.get("running_variable"),
        "treatment_variable_type": variables.get("treatment_variable_type", "binary"),
        "has_temporal_structure": dataset_analysis.get("temporal_structure", {}).get("has_temporal_structure", False),
        "frontdoor_criterion": variables.get("frontdoor_criterion", False),
        "cutoff_value": variables.get("cutoff_value"),
        "covariate_overlap_score": variables.get("covariate_overlap_result"),
        "is_rct": is_rct,
    }
    logger.info(f"Dataset properties for method selection: {properties}")
    return select_method(properties, excluded_methods)
class DecisionTreeEngine:
    """
    Engine for applying decision trees to select appropriate causal methods.

    This class wraps the functional decision tree implementation to provide
    an object-oriented interface for method selection.
    """

    def __init__(self, verbose=False):
        self.verbose = verbose

    def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str],
                      dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply the decision tree to select an appropriate causal method.
        """
        if self.verbose:
            print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}")
            print(f"Available covariates: {covariates}")

        treatment_variable_type = query_details.get("treatment_variable_type")
        covariate_overlap_result = query_details.get("covariate_overlap_result")
        info = {
            "treatment_variable": treatment,
            "outcome_variable": outcome,
            "covariates": covariates,
            "time_variable": query_details.get("time_variable"),
            "group_variable": query_details.get("group_variable"),
            "instrument_variable": query_details.get("instrument_variable"),
            "running_variable": query_details.get("running_variable"),
            "cutoff_value": query_details.get("cutoff_value"),
            "is_rct": query_details.get("is_rct", False),
            "has_temporal_structure": dataset_analysis.get("temporal_structure", {}).get("has_temporal_structure", False),
            "frontdoor_criterion": query_details.get("frontdoor_criterion", False),
            "covariate_overlap_score": covariate_overlap_result,
            "treatment_variable_type": treatment_variable_type,
        }
        result = select_method(info)
        if self.verbose:
            print(f"Selected method: {result['selected_method']}")
            print(f"Justification: {result['method_justification']}")
        result["decision_path"] = self._get_decision_path(result["selected_method"])
        return result
    def _get_decision_path(self, method):
        if method == LINEAR_REGRESSION:
            return ["Check if randomized experiment", "Data appears to be from a randomized experiment with covariates"]
        elif method == PROPENSITY_SCORE_MATCHING:
            return ["Check if randomized experiment", "Data is observational",
                    "Check for sufficient covariate overlap", "Sufficient overlap exists"]
        elif method == PROPENSITY_SCORE_WEIGHTING:
            return ["Check if randomized experiment", "Data is observational",
                    "Check for sufficient covariate overlap", "Low overlap—weighting preferred"]
        elif method == BACKDOOR_ADJUSTMENT:
            return ["Check if randomized experiment", "Data is observational",
                    "Check for sufficient covariate overlap", "Adjusting for covariates"]
        elif method == INSTRUMENTAL_VARIABLE:
            return ["Check if randomized experiment", "Data is observational",
                    "Check for instrumental variables", "Instrument is available"]
        elif method == REGRESSION_DISCONTINUITY:
            return ["Check if randomized experiment", "Data is observational",
                    "Check for discontinuity", "Discontinuity exists"]
        elif method == DIFF_IN_DIFF:
            return ["Check if randomized experiment", "Data is observational",
                    "Check for temporal structure", "Panel data structure exists"]
        elif method == FRONTDOOR_ADJUSTMENT:
            return ["Check if randomized experiment", "Data is observational",
                    "Check front-door criterion", "Front-door path identified"]
        elif method == DIFF_IN_MEANS:
            return ["Check if randomized experiment", "Pure RCT without covariates"]
        else:
            return ["Default method selection"]