Spaces:
Sleeping
Sleeping
# routers/treatment_routes.py
from flask import Blueprint, request, jsonify | |
import pandas as pd | |
from utils.treatment_effects import TreatmentEffectAlgorithms | |
import logging | |
# Flask blueprint created under the name 'treatment'.
# NOTE(review): no route registration on this blueprint is visible in this
# section of the file — confirm where handlers are attached.
treatment_bp = Blueprint('treatment', __name__)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Single TreatmentEffectAlgorithms instance, constructed once at import time
# and used by estimate_ate() to run the selected estimator.
treatment_effect_algorithms = TreatmentEffectAlgorithms()
# NOTE(review): no @treatment_bp.route(...) decorator is visible on this
# handler in this section of the file — confirm it is registered on the
# blueprint.
def estimate_ate():
    """
    Estimate an Average Treatment Effect (ATE) or Conditional Treatment
    Effect (CATE) from the JSON body of the current Flask request.

    Expected JSON object keys:
        data: rows loaded into a pandas DataFrame (documented as a list of
            dicts).
        treatment_col: name of the treatment column.
        outcome_col: name of the outcome column.
        covariates: list of covariate column names.
        method: optional, case-insensitive estimator name. One of
            "linear_regression" (default), "propensity_score_matching",
            "inverse_propensity_weighting", "t_learner", "s_learner".

    Returns:
        ``jsonify({"result": <estimator output>})`` on success.
        ``(jsonify({"detail": <message>}), 400)`` when the body is missing,
        is not a JSON object, lacks a required key, names columns that are
        not in the data, or selects an unsupported method.
        ``(jsonify({"detail": <message>}), 500)`` when anything else raises.
    """
    required_keys = ("data", "treatment_col", "outcome_col", "covariates")
    invalid_columns_msg = "Invalid column names provided for ATE estimation."

    # Maps the public method name to the attribute on
    # treatment_effect_algorithms. The attribute is looked up lazily so only
    # the selected estimator has to exist. The original code marked every
    # estimator except linear regression as a placeholder.
    estimator_attrs = {
        "linear_regression": "linear_regression_ate",
        "propensity_score_matching": "propensity_score_matching",
        "inverse_propensity_weighting": "inverse_propensity_weighting",
        "t_learner": "t_learner",
        "s_learner": "s_learner",
    }

    try:
        # silent=True returns None for a missing, non-JSON or malformed body
        # instead of raising, so a bad request is answered with 400 below
        # rather than being converted to a 500 by the broad except.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or any(
            key not in payload for key in required_keys
        ):
            return jsonify({"detail": "Missing required ATE estimation parameters."}), 400

        treatment_col = payload["treatment_col"]
        outcome_col = payload["outcome_col"]
        covariates = payload["covariates"]
        raw_method = payload.get("method", "linear_regression")  # Default to linear regression

        # A non-string method (number, list, null, ...) cannot name an
        # estimator; reject it instead of failing on .lower().
        if not isinstance(raw_method, str):
            return jsonify({"detail": f"Unsupported treatment effect estimation method: {raw_method}"}), 400
        method = raw_method.lower()

        # covariates must be a list so it can be concatenated below, and no
        # column name may be a JSON array/object (unhashable, so the
        # membership test against df.columns would raise).
        if not isinstance(covariates, list):
            return jsonify({"detail": invalid_columns_msg}), 400
        requested_cols = [treatment_col, outcome_col] + covariates
        if any(isinstance(col, (list, dict)) for col in requested_cols):
            return jsonify({"detail": invalid_columns_msg}), 400

        df = pd.DataFrame(payload["data"])

        logger.info(
            "ATE/CATE request: treatment=%s, outcome=%s, method=%s, data shape: %s",
            treatment_col,
            outcome_col,
            method,
            df.shape,
        )

        if not all(col in df.columns for col in requested_cols):
            return jsonify({"detail": invalid_columns_msg}), 400

        attr_name = estimator_attrs.get(method)
        if attr_name is None:
            return jsonify({"detail": f"Unsupported treatment effect estimation method: {method}"}), 400

        estimator = getattr(treatment_effect_algorithms, attr_name)
        result = estimator(df, treatment_col, outcome_col, covariates)

        logger.info("Estimated ATE/CATE using %s: %s", method, result)
        return jsonify({"result": result})
    except Exception as e:
        # Handler boundary: log the traceback and report the failure to the
        # client as a 500 with the exception text, as the original did.
        logger.exception("Error in ATE/CATE estimation: %s", e)
        return jsonify({"detail": f"ATE/CATE estimation failed: {e}"}), 500