Spaces:
Running
Running
from typing import List, Optional, Union, Dict, Any, Tuple | |
from pydantic import BaseModel, Field, validator | |
import json | |
# --- Pydantic models for LLM structured output --- | |
# These models are used by query_interpreter and potentially other components | |
# to structure the output received from Language Models. | |
class LLMSelectedVariable(BaseModel): | |
"""Pydantic model for selecting a single variable.""" | |
variable_name: Optional[str] = Field(None, description="The single best column name selected.") | |
class LLMSelectedCovariates(BaseModel): | |
"""Pydantic model for selecting a list of covariates.""" | |
covariates: List[str] = Field(default_factory=list, description="The list of selected covariate column names.") | |
class LLMIVars(BaseModel): | |
"""Pydantic model for identifying IVs.""" | |
instrument_variable: Optional[str] = Field(None, description="The identified instrumental variable column name.") | |
class LLMEstimand(BaseModel): | |
"""Pydantic model for identifying estimand""" | |
estimand: Optional[str] = Field(None, description="The identified estimand") | |
class LLMRDDVars(BaseModel): | |
"""Pydantic model for identifying RDD variables.""" | |
running_variable: Optional[str] = Field(None, description="The identified running variable column name.") | |
cutoff_value: Optional[Union[float, int]] = Field(None, description="The identified cutoff value.") | |
class LLMRCTCheck(BaseModel): | |
"""Pydantic model for checking if data is RCT.""" | |
is_rct: Optional[bool] = Field(None, description="True if the data is from a randomized controlled trial, False otherwise, None if unsure.") | |
reasoning: Optional[str] = Field(None, description="Brief reasoning for the RCT conclusion.") | |
class LLMTreatmentReferenceLevel(BaseModel): | |
reference_level: Optional[str] = Field(None, description="The identified reference/control level for the treatment variable, if specified in the query. Should be one of the actual values in the treatment column.") | |
reasoning: Optional[str] = Field(None, description="Brief reasoning for identifying this reference level.") | |
class LLMInteractionSuggestion(BaseModel): | |
"""Pydantic model for LLM suggestion on interaction terms.""" | |
interaction_needed: Optional[bool] = Field(None, description="True if an interaction term is strongly suggested by the query or context. LLM should provide true, false, or omit for None.") | |
interaction_variable: Optional[str] = Field(None, description="The name of the covariate that should interact with the treatment. Null if not applicable or if the interaction is complex/multiple.") | |
reasoning: Optional[str] = Field(None, description="Brief reasoning for the suggestion for or against an interaction term.") | |
# --- Pydantic models for Tool Inputs/Outputs and Data Structures --- | |
class TemporalStructure(BaseModel): | |
"""Represents detected temporal structure in the data.""" | |
has_temporal_structure: bool | |
temporal_columns: List[str] | |
is_panel_data: bool | |
id_column: Optional[str] = None | |
time_column: Optional[str] = None | |
time_periods: Optional[int] = None | |
units: Optional[int] = None | |
class DatasetInfo(BaseModel): | |
"""Basic information about the dataset file.""" | |
num_rows: int | |
num_columns: int | |
file_path: str | |
file_name: str | |
class DatasetAnalysis(BaseModel): | |
"""Results from the dataset analysis component.""" | |
dataset_info: DatasetInfo | |
columns: List[str] | |
potential_treatments: List[str] | |
potential_outcomes: List[str] | |
temporal_structure_detected: bool | |
panel_data_detected: bool | |
potential_instruments_detected: bool | |
discontinuities_detected: bool | |
temporal_structure: TemporalStructure | |
column_categories: Optional[Dict[str, str]] = None | |
column_nunique_counts: Optional[Dict[str, int]] = None | |
sample_size: int | |
num_covariates_estimate: int | |
per_group_summary_stats: Optional[Dict[str, Dict[str, Any]]] = None | |
potential_instruments: Optional[List[str]] = None | |
overlap_assessment: Optional[Dict[str, Any]] = None | |
# --- Model for Dataset Analyzer Tool Output --- | |
class DatasetAnalyzerOutput(BaseModel): | |
"""Structured output for the dataset analyzer tool.""" | |
analysis_results: DatasetAnalysis | |
dataset_description: Optional[str] = None | |
workflow_state: Dict[str, Any] | |
#TODO make query info consistent with the Data analysis out put | |
class QueryInfo(BaseModel): | |
"""Information extracted from the user's initial query.""" | |
query_text: str | |
potential_treatments: Optional[List[str]] = None | |
potential_outcomes: Optional[List[str]] = None | |
covariates_hints: Optional[List[str]] = None | |
instrument_hints: Optional[List[str]] = None | |
running_variable_hints: Optional[List[str]] = None | |
cutoff_value_hint: Optional[Union[float, int]] = None | |
class QueryInterpreterInput(BaseModel): | |
"""Input structure for the query interpreter tool.""" | |
query_info: QueryInfo | |
dataset_analysis: DatasetAnalysis | |
dataset_description: str | |
# Add original_query if it should be part of the standard input | |
original_query: Optional[str] = None | |
class Variables(BaseModel): | |
"""Structured variables identified by the query interpreter component.""" | |
treatment_variable: Optional[str] = None | |
treatment_variable_type: Optional[str] = Field(None, description="Type of the treatment variable (e.g., 'binary', 'continuous', 'categorical_multi_value')") | |
outcome_variable: Optional[str] = None | |
instrument_variable: Optional[str] = None | |
covariates: Optional[List[str]] = Field(default_factory=list) | |
time_variable: Optional[str] = None | |
group_variable: Optional[str] = None # Often the unit ID | |
running_variable: Optional[str] = None | |
cutoff_value: Optional[Union[float, int]] = None | |
is_rct: Optional[bool] = Field(False, description="Flag indicating if the dataset is from an RCT.") | |
treatment_reference_level: Optional[str] = Field(None, description="The specified reference/control level for a multi-valued treatment variable.") | |
interaction_term_suggested: Optional[bool] = Field(False, description="Whether the query or context suggests an interaction term with the treatment might be relevant.") | |
interaction_variable_candidate: Optional[str] = Field(None, description="The covariate identified as a candidate for interaction with the treatment.") | |
class QueryInterpreterOutput(BaseModel): | |
"""Structured output for the query interpreter tool.""" | |
variables: Variables | |
dataset_analysis: DatasetAnalysis | |
dataset_description: Optional[str] | |
workflow_state: Dict[str, Any] | |
original_query: Optional[str] = None | |
# Input model for Method Selector Tool | |
class MethodSelectorInput(BaseModel): | |
"""Input structure for the method selector tool.""" | |
variables: Variables# Uses the Variables model identified by QueryInterpreter | |
dataset_analysis: DatasetAnalysis # Uses the DatasetAnalysis model | |
dataset_description: Optional[str] = None | |
original_query: Optional[str] = None | |
# Note: is_rct is expected inside inputs.variables | |
# --- Models for Method Validator Tool --- | |
class MethodInfo(BaseModel): | |
"""Information about the selected causal inference method.""" | |
selected_method: Optional[str] = None | |
method_name: Optional[str] = None # Often a title-cased version for display | |
method_justification: Optional[str] = None | |
method_assumptions: Optional[List[str]] = Field(default_factory=list) | |
# Add alternative methods if it should be part of the standard info passed around | |
alternative_methods: Optional[List[str]] = Field(default_factory=list) | |
class MethodValidatorInput(BaseModel): | |
"""Input structure for the method validator tool.""" | |
method_info: MethodInfo | |
variables: Variables | |
dataset_analysis: DatasetAnalysis | |
dataset_description: Optional[str] = None | |
original_query: Optional[str] = None | |
# --- Model for Method Executor Tool --- | |
class MethodExecutorInput(BaseModel): | |
"""Input structure for the method executor tool.""" | |
method: str = Field(..., description="The causal method name (use recommended method if validation failed).") | |
variables: Variables # Contains T, O, C, etc. | |
dataset_path: str | |
dataset_analysis: DatasetAnalysis | |
dataset_description: Optional[str] = None | |
# Include validation_info from validator output if needed by estimator or LLM assist later? | |
validation_info: Optional[Any] = None | |
original_query: Optional[str] = None | |
# --- Model for Explanation Generator Tool --- | |
class ExplainerInput(BaseModel): | |
"""Input structure for the explanation generator tool.""" | |
# Based on expected output from method_executor_tool and validator | |
method_info: MethodInfo | |
validation_info: Optional[Dict[str, Any]] = None # From validator tool | |
variables: Variables | |
results: Dict[str, Any] # Numerical results from executor | |
dataset_analysis: DatasetAnalysis | |
dataset_description: Optional[str] = None | |
# Add original query if needed for explanation context | |
original_query: Optional[str] = None | |
# Add other shared models/schemas below as needed. | |
class FormattedOutput(BaseModel): | |
""" | |
Structured output containing the final formatted results and explanations | |
from a causal analysis run. | |
""" | |
query: str = Field(description="The original user query.") | |
method_used: str = Field(description="The user-friendly name of the causal inference method used.") | |
causal_effect: Optional[float] = Field(None, description="The point estimate of the causal effect.") | |
standard_error: Optional[float] = Field(None, description="The standard error of the causal effect estimate.") | |
confidence_interval: Optional[Tuple[Optional[float], Optional[float]]] = Field(None, description="The confidence interval for the causal effect (e.g., 95% CI).") | |
p_value: Optional[float] = Field(None, description="The p-value associated with the causal effect estimate.") | |
summary: str = Field(description="A concise summary paragraph interpreting the main findings.") | |
method_explanation: Optional[str] = Field("", description="Explanation of the causal inference method used.") | |
interpretation_guide: Optional[str] = Field("", description="Guidance on how to interpret the results.") | |
limitations: Optional[List[str]] = Field(default_factory=list, description="List of limitations or potential issues with the analysis.") | |
assumptions: Optional[str] = Field("", description="Discussion of the key assumptions underlying the method and their validity.") | |
practical_implications: Optional[str] = Field("", description="Discussion of the practical implications or significance of the findings.") | |
# Optionally add dataset_analysis and dataset_description if they should be part of the final structure | |
# dataset_analysis: Optional[DatasetAnalysis] = None # Example if using DatasetAnalysis model | |
# dataset_description: Optional[str] = None | |
# This model itself doesn't include workflow_state, as it represents the *content* | |
# The tool using this component will add the workflow_state separately. | |
class LLMParameterDetails(BaseModel): | |
parameter_name: str = Field(description="The full parameter name as found in the model results.") | |
estimate: float | |
p_value: float | |
conf_int_low: float | |
conf_int_high: float | |
std_err: float | |
reasoning: Optional[str] = Field(None, description="Brief reasoning for selecting this parameter and its values.") | |
class LLMTreatmentEffectResults(BaseModel): | |
effects: Optional[Dict[str, LLMParameterDetails]] = Field(description="Dictionary where keys are treatment level names (e.g., 'LevelA', 'LevelB' if multi-level) or a generic key like 'treatment_effect' for binary/continuous treatments. Values are the statistical details for that effect.") | |
all_parameters_successfully_identified: Optional[bool] = Field(description="True if all expected treatment effect parameters were identified and their values extracted, False otherwise.") | |
overall_reasoning: Optional[str] = Field(None, description="Overall reasoning for the extraction process or if issues were encountered.") | |
class RelevantParamInfo(BaseModel): | |
param_name: str = Field(description="The exact parameter name as it appears in the statsmodels results.") | |
param_index: int = Field(description="The index of this parameter in the original list of parameter names.") | |
class LLMIdentifiedRelevantParams(BaseModel): | |
identified_params: List[RelevantParamInfo] = Field(description="A list of parameters identified as relevant to the query or representing all treatment effects for a general query.") | |
all_parameters_successfully_identified: bool = Field(description="True if LLM is confident it identified all necessary params based on query type (e.g., all levels for a general query).") | |