"""
decision tree component for selecting causal inference methods
this module implements the decision tree logic to select the most appropriate
causal inference method based on dataset characteristics and available variables
"""
import logging
from typing import Dict, List, Any, Optional
import pandas as pd
# define method names
BACKDOOR_ADJUSTMENT = "backdoor_adjustment"
LINEAR_REGRESSION = "linear_regression"
DIFF_IN_MEANS = "diff_in_means"
DIFF_IN_DIFF = "difference_in_differences"
REGRESSION_DISCONTINUITY = "regression_discontinuity_design"
PROPENSITY_SCORE_MATCHING = "propensity_score_matching"
INSTRUMENTAL_VARIABLE = "instrumental_variable"
CORRELATION_ANALYSIS = "correlation_analysis"
PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting"
GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score"
FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment"
logger = logging.getLogger(__name__)
# method assumptions mapping
METHOD_ASSUMPTIONS = {
BACKDOOR_ADJUSTMENT: [
"no unmeasured confounders (conditional ignorability given covariates)",
"correct model specification for outcome conditional on treatment and covariates",
"positivity/overlap (for all covariate values, units could potentially receive either treatment level)"
],
LINEAR_REGRESSION: [
"linear relationship between treatment, covariates, and outcome",
"no unmeasured confounders (if observational)",
"correct model specification",
"homoscedasticity of errors",
"normally distributed errors (for inference)"
],
DIFF_IN_MEANS: [
"treatment is randomly assigned (or as-if random)",
"no spillover effects",
"stable unit treatment value assumption (SUTVA)"
],
DIFF_IN_DIFF: [
"parallel trends between treatment and control groups before treatment",
"no spillover effects between groups",
"no anticipation effects before treatment",
"stable composition of treatment and control groups",
"treatment timing is exogenous"
],
REGRESSION_DISCONTINUITY: [
"units cannot precisely manipulate the running variable around the cutoff",
"continuity of conditional expectation functions of potential outcomes at the cutoff",
"no other changes occurring precisely at the cutoff"
],
PROPENSITY_SCORE_MATCHING: [
"no unmeasured confounders (conditional ignorability)",
"sufficient overlap (common support) between treatment and control groups",
"correct propensity score model specification"
],
INSTRUMENTAL_VARIABLE: [
"instrument is correlated with treatment (relevance)",
"instrument affects outcome only through treatment (exclusion restriction)",
"instrument is independent of unmeasured confounders (exogeneity/independence)"
],
CORRELATION_ANALYSIS: [
"data represents a sample from the population of interest",
"variables are measured appropriately"
],
PROPENSITY_SCORE_WEIGHTING: [
"no unmeasured confounders (conditional ignorability)",
"sufficient overlap (common support) between treatment and control groups",
"correct propensity score model specification",
"weights correctly specified (e.g., ATE, ATT)"
],
GENERALIZED_PROPENSITY_SCORE: [
"conditional mean independence",
"positivity/common support for GPS",
"correct specification of the GPS model",
"correct specification of the outcome model",
"no unmeasured confounders affecting both treatment and outcome, given X",
"treatment variable is continuous"
],
    FRONTDOOR_ADJUSTMENT: [
        "the mediator intercepts all causal paths from treatment to outcome",
        "no unmeasured confounding of the treatment-mediator relationship",
        "no unmeasured confounding of the mediator-outcome relationship, given treatment"
    ]
}
def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
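    """
    Select the most appropriate causal inference method from dataset properties.

    Builds a set of candidate methods from the available variables and structural
    signals (instrument, running variable/cutoff, temporal structure, front-door
    criterion, covariate overlap), removes any excluded methods, and returns the
    highest-priority remaining candidate together with its justification,
    assumptions, and alternatives.

    Raises:
        ValueError: if the treatment or outcome variable is missing.
        RuntimeError: if no viable method remains after exclusions.
    """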
excluded_methods = set(excluded_methods or [])
logger.info(f"Excluded methods: {sorted(excluded_methods)}")
treatment = dataset_properties.get("treatment_variable")
outcome = dataset_properties.get("outcome_variable")
if not treatment or not outcome:
raise ValueError("Both treatment and outcome variables must be specified")
instrument_var = dataset_properties.get("instrument_variable")
running_var = dataset_properties.get("running_variable")
cutoff_val = dataset_properties.get("cutoff_value")
time_var = dataset_properties.get("time_variable")
is_rct = dataset_properties.get("is_rct", False)
has_temporal = dataset_properties.get("has_temporal_structure", False)
frontdoor = dataset_properties.get("frontdoor_criterion", False)
covariate_overlap_result = dataset_properties.get("covariate_overlap_score")
covariates = dataset_properties.get("covariates", [])
treatment_variable_type = dataset_properties.get("treatment_variable_type", "binary")
# Helpers to collect candidates
candidates = [] # list of (method, priority_index)
justifications: Dict[str, str] = {}
assumptions: Dict[str, List[str]] = {}
def add(method: str, justification: str, prio_order: List[str]):
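        """Register `method` as a candidate with its justification and its priority within prio_order."""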
if method in justifications: # already added
return
justifications[method] = justification
assumptions[method] = METHOD_ASSUMPTIONS[method]
# priority index from provided order (fallback large if not present)
try:
idx = prio_order.index(method)
except ValueError:
idx = 10**6
candidates.append((method, idx))
# ----- Build candidate set (no returns here) -----
# RCT branch
if is_rct:
logger.info("Dataset is from a randomized controlled trial (RCT)")
rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS]
if instrument_var and instrument_var != treatment:
add(INSTRUMENTAL_VARIABLE,
f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.",
rct_priority)
if covariates:
add(LINEAR_REGRESSION,
"RCT with covariates—use OLS for precision.",
rct_priority)
else:
add(DIFF_IN_MEANS,
"Pure RCT without covariates—difference-in-means.",
rct_priority)
    # Observational-method candidates (collected even when is_rct is True, so they can surface as alternatives)
obs_priority_binary = [
INSTRUMENTAL_VARIABLE,
PROPENSITY_SCORE_MATCHING,
PROPENSITY_SCORE_WEIGHTING,
FRONTDOOR_ADJUSTMENT,
LINEAR_REGRESSION,
]
obs_priority_nonbinary = [
INSTRUMENTAL_VARIABLE,
FRONTDOOR_ADJUSTMENT,
LINEAR_REGRESSION,
]
# Common early structural signals first (still only add as candidates)
if has_temporal and time_var:
add(DIFF_IN_DIFF,
f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).",
[DIFF_IN_DIFF]) # highest among itself
if running_var and cutoff_val is not None:
add(REGRESSION_DISCONTINUITY,
f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.",
[REGRESSION_DISCONTINUITY])
# Binary vs non-binary pathways
if treatment_variable_type == "binary":
if instrument_var:
add(INSTRUMENTAL_VARIABLE,
f"Instrumental variable '{instrument_var}' available.",
obs_priority_binary)
# Propensity score methods only if covariates exist
if covariates:
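            # Heuristic: treat a covariate overlap score below 0.1 as poor overlap and
            # prefer weighting over matching in that case; otherwise (or when the score
            # is unavailable) default to propensity score matching.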
if covariate_overlap_result is not None:
ps_method = (PROPENSITY_SCORE_WEIGHTING
if covariate_overlap_result < 0.1
else PROPENSITY_SCORE_MATCHING)
else:
ps_method = PROPENSITY_SCORE_MATCHING
add(ps_method,
"Covariates observed; PS method chosen based on overlap.",
obs_priority_binary)
if frontdoor:
add(FRONTDOOR_ADJUSTMENT,
"Front-door criterion satisfied.",
obs_priority_binary)
add(LINEAR_REGRESSION,
"OLS as a fallback specification.",
obs_priority_binary)
else:
logger.info(f"Non-binary treatment variable detected: {treatment_variable_type}")
if instrument_var:
add(INSTRUMENTAL_VARIABLE,
f"Instrument '{instrument_var}' candidate for non-binary treatment.",
obs_priority_nonbinary)
if frontdoor:
add(FRONTDOOR_ADJUSTMENT,
"Front-door criterion satisfied.",
obs_priority_nonbinary)
add(LINEAR_REGRESSION,
"Fallback for non-binary treatment without stronger identification.",
obs_priority_nonbinary)
# ----- Centralized exclusion handling -----
# Remove excluded
filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods]
# If nothing survives, attempt a safe fallback not excluded
if not filtered:
logger.warning(f"All candidates excluded. Candidates were: {[m for m,_ in candidates]}. Excluded: {sorted(excluded_methods)}")
fallback_order = [
LINEAR_REGRESSION,
DIFF_IN_MEANS,
PROPENSITY_SCORE_MATCHING,
PROPENSITY_SCORE_WEIGHTING,
DIFF_IN_DIFF,
REGRESSION_DISCONTINUITY,
INSTRUMENTAL_VARIABLE,
FRONTDOOR_ADJUSTMENT,
]
        fallback = next((m for m in fallback_order if m not in excluded_methods), None)
        if not fallback:
            # truly nothing left; raise with context
            raise RuntimeError("No viable method remains after exclusions.")
        selected_method = fallback
        alternatives = []
        justifications.setdefault(selected_method, "Fallback after exclusions.")
        assumptions.setdefault(selected_method, METHOD_ASSUMPTIONS[selected_method])
else:
# Pick by smallest priority index, then stable by insertion
filtered.sort(key=lambda x: x[1])
selected_method = filtered[0][0]
alternatives = [m for (m, _) in filtered[1:] if m != selected_method]
logger.info(f"Selected method: {selected_method}; alternatives: {alternatives}")
return {
"selected_method": selected_method,
"method_justification": justifications[selected_method],
"method_assumptions": assumptions[selected_method],
"alternatives": alternatives,
"excluded_methods": sorted(excluded_methods),
}
def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None):
"""
Wrapped function to select causal method based on dataset properties and query
Args:
dataset_analysis (Dict): results of dataset analysis
variables (Dict): dictionary of variable names and types
is_rct (bool): whether the dataset is from a randomized controlled trial
llm (BaseChatModel): language model instance for generating prompts
dataset_description (str): description of the dataset
original_query (str): the original user query
excluded_methods (List[str], optional): list of methods to exclude from selection
"""
logger.info("Running rule-based method selection")
    properties = {
        "treatment_variable": variables.get("treatment_variable"),
        "outcome_variable": variables.get("outcome_variable"),
        "instrument_variable": variables.get("instrument_variable"),
        "covariates": variables.get("covariates", []),
        "time_variable": variables.get("time_variable"),
        "running_variable": variables.get("running_variable"),
        "treatment_variable_type": variables.get("treatment_variable_type", "binary"),
        "has_temporal_structure": dataset_analysis.get("temporal_structure", {}).get("has_temporal_structure", False),
        "frontdoor_criterion": variables.get("frontdoor_criterion", False),
        "cutoff_value": variables.get("cutoff_value"),
        "covariate_overlap_score": variables.get("covariate_overlap_result"),
        "is_rct": is_rct,
    }
logger.info(f"Dataset properties for method selection: {properties}")
return select_method(properties, excluded_methods)
class DecisionTreeEngine:
"""
Engine for applying decision trees to select appropriate causal methods.
This class wraps the functional decision tree implementation to provide
an object-oriented interface for method selection.
"""
def __init__(self, verbose=False):
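        """
        Args:
            verbose (bool): if True, print details of the selection process.
        """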
self.verbose = verbose
def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str],
dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply decision tree to select appropriate causal method.
"""
if self.verbose:
print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}")
print(f"Available covariates: {covariates}")
treatment_variable_type = query_details.get("treatment_variable_type")
covariate_overlap_result = query_details.get("covariate_overlap_result")
        info = {
            "treatment_variable": treatment,
            "outcome_variable": outcome,
            "covariates": covariates,
            "time_variable": query_details.get("time_variable"),
            "group_variable": query_details.get("group_variable"),
            "instrument_variable": query_details.get("instrument_variable"),
            "running_variable": query_details.get("running_variable"),
            "cutoff_value": query_details.get("cutoff_value"),
            "is_rct": query_details.get("is_rct", False),
            "has_temporal_structure": dataset_analysis.get("temporal_structure", {}).get("has_temporal_structure", False),
            "frontdoor_criterion": query_details.get("frontdoor_criterion", False),
            "covariate_overlap_score": covariate_overlap_result,
            "treatment_variable_type": treatment_variable_type,
        }
result = select_method(info)
if self.verbose:
print(f"Selected method: {result['selected_method']}")
print(f"Justification: {result['method_justification']}")
result["decision_path"] = self._get_decision_path(result["selected_method"])
return result
def _get_decision_path(self, method):
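        """Return a human-readable decision path for the selected method."""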
if method == "linear_regression":
return ["Check if randomized experiment", "Data appears to be from a randomized experiment with covariates"]
elif method == "propensity_score_matching":
return ["Check if randomized experiment", "Data is observational",
"Check for sufficient covariate overlap", "Sufficient overlap exists"]
elif method == "propensity_score_weighting":
return ["Check if randomized experiment", "Data is observational",
"Check for sufficient covariate overlap", "Low overlap—weighting preferred"]
elif method == "backdoor_adjustment":
return ["Check if randomized experiment", "Data is observational",
"Check for sufficient covariate overlap", "Adjusting for covariates"]
elif method == "instrumental_variable":
return ["Check if randomized experiment", "Data is observational",
"Check for instrumental variables", "Instrument is available"]
elif method == "regression_discontinuity_design":
return ["Check if randomized experiment", "Data is observational",
"Check for discontinuity", "Discontinuity exists"]
elif method == "difference_in_differences":
return ["Check if randomized experiment", "Data is observational",
"Check for temporal structure", "Panel data structure exists"]
elif method == "frontdoor_adjustment":
return ["Check if randomized experiment", "Data is observational",
"Check front-door criterion", "Front-door path identified"]
elif method == "diff_in_means":
return ["Check if randomized experiment", "Pure RCT without covariates"]
else:
return ["Default method selection"]