Spaces:
Running
Running
File size: 8,250 Bytes
1721aea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
"""
Method Selector Tool for selecting causal inference methods.
This module provides a LangChain tool for selecting appropriate
causal inference methods based on dataset characteristics and query details.
"""
import logging # Add logging
from typing import Dict, List, Any, Optional, Union
from langchain_core.tools import tool # Use langchain_core
# Import component function and central LLM factory
from auto_causal.components.decision_tree import rule_based_select_method # Rule-based
from auto_causal.components.decision_tree_llm import DecisionTreeLLMEngine # LLM-based
from auto_causal.config import get_llm_client # Updated import path
from auto_causal.components.state_manager import create_workflow_state_update
# Import shared models from central location
from auto_causal.models import (
Variables,
DatasetAnalysis,
MethodSelectorInput # Still needed for args_schema
)
logger = logging.getLogger(__name__)
# Toggle between the LLM-based and the original rule-based decision tree engines.
_USE_LLM_DECISION_TREE = False


def _error_payload(
    error_message: str,
    workflow_error: str,
    next_tool: str,
    reason: str,
    variables_dict: Dict[str, Any],
    dataset_analysis_dict: Dict[str, Any],
    dataset_description: Optional[str],
) -> Dict[str, Any]:
    """Build the standard error return payload shared by all failure paths.

    Args:
        error_message: Value placed under the top-level "error" key for the agent.
        workflow_error: Error string recorded in the workflow state update.
        next_tool: Tool the workflow should route to after this failure.
        reason: Human-readable reason for the routing decision.
        variables_dict: Identified variables, passed through for downstream context.
        dataset_analysis_dict: Dataset analysis results, passed through.
        dataset_description: Optional dataset description, passed through.

    Returns:
        Dict combining the error, the pass-through context, and the workflow state.
    """
    workflow_update = create_workflow_state_update(
        current_step="method_selection",
        step_completed_flag=False,
        next_tool=next_tool,
        next_step_reason=reason,
        error=workflow_error,
    )
    return {
        "error": error_message,
        "variables": variables_dict,
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description,
        **workflow_update.get("workflow_state", {}),
    }


def _run_selection(
    dataset_analysis_dict: Dict[str, Any],
    variables_dict: Dict[str, Any],
    is_rct_flag: Optional[bool],
    llm_instance: Any,
    dataset_description: Optional[str],
    original_query: Optional[str],
    excluded_methods: Optional[List[str]],
) -> Dict[str, Any]:
    """Dispatch to the configured decision-tree engine and return its selection dict.

    Raises whatever the underlying engine raises; the caller converts that
    into an error payload.
    """
    # Treat a missing/None is_rct flag as "not an RCT".
    is_rct = is_rct_flag if isinstance(is_rct_flag, bool) else False

    if _USE_LLM_DECISION_TREE:
        logger.info("Using LLM-based Decision Tree Engine for method selection.")
        if not llm_instance:
            # DecisionTreeLLMEngine handles a missing LLM itself; just warn here.
            logger.warning("LLM instance is required for DecisionTreeLLMEngine but not available. Falling back to rule-based or error.")
        llm_engine = DecisionTreeLLMEngine(verbose=True)  # Adjust verbosity as needed
        return llm_engine.select_method_llm(
            dataset_analysis=dataset_analysis_dict,
            variables=variables_dict,
            is_rct=is_rct,
            llm=llm_instance,
            excluded_methods=excluded_methods,
        )

    logger.info("Using Rule-based Decision Tree Engine for method selection.")
    return rule_based_select_method(
        dataset_analysis=dataset_analysis_dict,
        variables=variables_dict,
        is_rct=is_rct,
        llm=llm_instance,
        dataset_description=dataset_description,
        original_query=original_query,
        excluded_methods=excluded_methods,
    )


@tool(args_schema=MethodSelectorInput)
def method_selector_tool(
    variables: Variables,
    dataset_analysis: DatasetAnalysis,
    dataset_description: Optional[str] = None,
    original_query: Optional[str] = None,
    excluded_methods: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Select the most appropriate causal inference method based on structured input.

    Applies decision logic (rule-based or LLM-based, per the module-level
    toggle) to the dataset analysis and identified variables (including is_rct).

    Args:
        variables: Pydantic model containing identified variables (T, O, C, IV, RDD, is_rct, etc.).
        dataset_analysis: Pydantic model containing results of dataset analysis.
        dataset_description: Optional textual description of the dataset.
        original_query: Optional original user query string.
        excluded_methods: Optional list of method names to exclude from selection.

    Returns:
        Dictionary with method selection details ("method_info"), pass-through
        context for the next step, and the updated workflow state.
    """
    logger.info("Running method_selector_tool with individual args...")

    # The decision-tree components expect plain dicts, not Pydantic models.
    variables_dict = variables.model_dump()
    dataset_analysis_dict = dataset_analysis.model_dump()
    is_rct_flag = variables.is_rct

    # Basic validation: treatment and outcome must both have been identified upstream.
    if not all([variables_dict.get("treatment_variable"), variables_dict.get("outcome_variable")]):
        logger.error("Missing treatment or outcome variable in input.")
        return _error_payload(
            error_message="Missing treatment/outcome",
            workflow_error="Missing treatment/outcome variable in input",
            next_tool="method_selector_tool",
            reason="Missing treatment/outcome variable in input",
            variables_dict=variables_dict,
            dataset_analysis_dict=dataset_analysis_dict,
            dataset_description=dataset_description,
        )

    # LLM is optional for the rule-based path; degrade gracefully if unavailable.
    try:
        llm_instance = get_llm_client()
    except Exception as e:
        logger.warning(f"Failed to initialize LLM for method_selector_tool: {e}. Proceeding without LLM features.")
        llm_instance = None

    try:
        method_selection_dict = _run_selection(
            dataset_analysis_dict=dataset_analysis_dict,
            variables_dict=variables_dict,
            is_rct_flag=is_rct_flag,
            llm_instance=llm_instance,
            dataset_description=dataset_description,
            original_query=original_query,
            excluded_methods=excluded_methods,
        )
    except Exception as e:
        logger.error(f"Error during method selection execution: {e}", exc_info=True)
        return _error_payload(
            error_message=f"Method selection logic failed: {e}",
            workflow_error=f"Component failed: {e}",
            next_tool="error_handler_tool",
            reason=f"Component failed: {e}",
            variables_dict=variables_dict,
            dataset_analysis_dict=dataset_analysis_dict,
            dataset_description=dataset_description,
        )

    # A valid selection is any method name other than the sentinel "Error".
    selected = method_selection_dict.get("selected_method")
    method_selected_flag = bool(selected and selected != "Error")

    # 'method_info' is the sub-dictionary shape required by the validator tool;
    # include alternatives when the selection engine provided them.
    method_info = {
        "selected_method": selected,
        "method_name": selected.replace("_", " ").title() if method_selected_flag else None,
        "method_justification": method_selection_dict.get("method_justification"),
        "method_assumptions": method_selection_dict.get("method_assumptions", []),
        "alternative_methods": method_selection_dict.get("alternatives", []),
    }

    # Final output dictionary for the agent, carrying context forward.
    result = {
        "method_info": method_info,
        "variables": variables_dict,
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description,
        "original_query": original_query,
    }

    # Route to the validator on success, to the error handler otherwise.
    if method_selected_flag:
        next_tool_name = "method_validator_tool"
        next_reason = "Now we need to validate the assumptions of the selected method"
    else:
        next_tool_name = "error_handler_tool"
        next_reason = "Method selection failed or returned an error."
    workflow_update = create_workflow_state_update(
        current_step="method_selection",
        step_completed_flag=method_selected_flag,
        next_tool=next_tool_name,
        next_step_reason=next_reason,
    )
    result.update(workflow_update.get("workflow_state", {}))

    logger.info(f"method_selector_tool finished. Selected: {method_info.get('selected_method')}")
    return result