# utils/causal_chatbot.py
import os
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from utils.preprocessor import summarize_dataframe_for_chatbot
from utils.graph_utils import get_graph_summary_for_chatbot
import pandas as pd
load_dotenv()
# Configure Groq API Key
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    print("ERROR: GROQ_API_KEY environment variable not set.")
    raise ValueError("GROQ_API_KEY is required.")
# Debug: Print API key details
print(f"Loaded GROQ_API_KEY: {GROQ_API_KEY[:5]}...{GROQ_API_KEY[-5:]}")
print(f"API Key Length: {len(GROQ_API_KEY)}")
# Initialize the Groq model with LangChain
try:
    model = ChatGroq(
        model_name="llama-3.3-70b-versatile",
        temperature=0.7,
        groq_api_key=GROQ_API_KEY
    )
except Exception as e:
    print(f"Error configuring Groq API: {e}")
    model = None
def assess_causal_compatibility(data_json: list) -> str:
    """
    Assesses the dataset's compatibility for causal inference analysis.

    Args:
        data_json: List of dictionaries representing the dataset.

    Returns:
        String describing the dataset's suitability for causal analysis.
    """
    if not data_json:
        return "No dataset provided for compatibility assessment."
    try:
        df = pd.DataFrame(data_json)
        num_rows, num_cols = df.shape
        numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns
        missing_values = df.isnull().sum().sum()
        assessment = [
            f"Dataset has {num_rows} rows and {num_cols} columns.",
            f"Numeric columns ({len(numeric_cols)}): {', '.join(numeric_cols) if len(numeric_cols) > 0 else 'None'}.",
            f"Categorical columns ({len(categorical_cols)}): {', '.join(categorical_cols) if len(categorical_cols) > 0 else 'None'}.",
            f"Missing values: {missing_values}."
        ]
        # Causal compatibility insights
        if num_cols < 3:
            assessment.append("Warning: Dataset has fewer than 3 columns, which may limit causal analysis (e.g., no room for treatment, outcome, and confounders).")
        if len(numeric_cols) == 0:
            assessment.append("Warning: No numeric columns detected. Causal inference often requires numeric variables for treatment or outcome.")
        if missing_values > 0:
            assessment.append("Note: Missing values detected. Preprocessing (e.g., imputation) may be needed for accurate causal analysis.")
        if len(numeric_cols) >= 2 and num_rows > 100:
            assessment.append("Positive: Dataset has multiple numeric columns and sufficient rows, suitable for causal inference with proper preprocessing.")
        else:
            assessment.append("Note: Ensure at least two numeric columns (e.g., treatment and outcome) and sufficient data points for robust causal analysis.")
        return "\n".join(assessment)
    except Exception as e:
        print(f"Error in assess_causal_compatibility: {e}")
        return "Unable to assess dataset compatibility due to processing error."
# Define tools using LangChain's @tool decorator
@tool
def get_dataset_info() -> dict:
    """
    Provides summary information and causal compatibility assessment for the currently loaded dataset.
    The dataset is provided by the backend session context.

    Returns:
        Dictionary containing the dataset summary and compatibility assessment.
    """
    return {"summary": "Dataset will be provided by session context"}


@tool
def get_causal_graph_info() -> dict:
    """
    Provides summary information about the currently discovered causal graph.
    The graph data is provided by the backend session context.

    Returns:
        Dictionary containing the graph summary.
    """
    return {"summary": "Graph data will be provided by session context"}
# Bind tools to the model
tools = [get_dataset_info, get_causal_graph_info]
if model:
    model_with_tools = model.bind_tools(tools)
else:
    # Keep the name defined so later references fail gracefully instead of raising NameError
    model_with_tools = None
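# For reference, a tool call emitted by the tool-bound model arrives (in
# LangChain's tool_calls convention) roughly as a dict like the one below;
# the id value here is purely illustrative:
#   {"name": "get_dataset_info", "args": {}, "id": "call_abc123"}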
def get_chatbot_response(user_message: str, session_context: dict) -> str:
"""
Gets a response from the Groq chatbot, handling tool calls.
Args:
user_message: The message from the user.
session_context: Dictionary containing current session data
(e.g., processed_data, causal_graph_adj, causal_graph_nodes).
Returns:
The chatbot's response message.
"""
    if model is None:
        return "Chatbot is not configured correctly. Please check Groq API key."
    try:
        # Create a prompt template to guide the model's behavior
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are CausalBox Assistant, an AI that helps users analyze datasets and causal graphs.
Use the provided tools to access dataset or graph information. Do NOT generate or guess parameters for tool calls; the backend will provide all necessary data (e.g., dataset or graph details).
For dataset queries (e.g., "read the dataset", "dataset compatibility"), call `get_dataset_info` without arguments.
For graph queries (e.g., "describe the causal graph"), call `get_causal_graph_info` without arguments.
For other questions (e.g., "what is a confounder?"), respond directly with clear, accurate explanations.
When you receive tool results, provide a comprehensive analysis and explanation to help the user understand their data and causal analysis possibilities.
Examples:
- User: "Tell me about the dataset" -> Call `get_dataset_info`.
- User: "Check dataset compatibility for causal analysis" -> Call `get_dataset_info`.
- User: "Describe the causal graph" -> Call `get_causal_graph_info`.
- User: "What is a confounder?" -> Respond: "A confounder is a variable that influences both the treatment and outcome, causing a spurious association."
"""),
            ("human", "{user_message}")
        ])
        # Chain the prompt with the model
        chain = prompt | model_with_tools
        # Log the user message and session context
        print(f"Processing user message: {user_message}")
        print(f"Session context keys: {list(session_context.keys())}")
        # Invoke the chain with the user message
        response = chain.invoke({"user_message": user_message})
        print(f"Model response: {response}")
        # Handle tool calls if present
        if response.tool_calls:
            tool_call = response.tool_calls[0]
            function_name = tool_call["name"]
            function_args = tool_call["args"]
            print(f"Chatbot calling tool: {function_name} with args: {function_args}")
            # Map session context to tool arguments
            tool_output = {}
            if function_name == "get_dataset_info":
                data_json = session_context.get("processed_data", [])
                if not isinstance(data_json, list) or not data_json:
                    print(f"Invalid or empty data_json: {data_json}")
                    return "Error: No valid dataset available."
                tool_output = get_dataset_info.invoke({})
                tool_output["summary"] = summarize_dataframe_for_chatbot(data_json)
                tool_output["causal_compatibility"] = assess_causal_compatibility(data_json)
            elif function_name == "get_causal_graph_info":
                graph_adj = session_context.get("causal_graph_adj", [])
                nodes = session_context.get("causal_graph_nodes", [])
                if not graph_adj or not nodes:
                    print("No causal graph data available")
                    return "Error: No causal graph available."
                tool_output = get_causal_graph_info.invoke({})
                tool_output["summary"] = get_graph_summary_for_chatbot(graph_adj, nodes)
            else:
                print(f"Unknown tool: {function_name}")
                return f"Error: Unknown tool {function_name}."
            print(f"Tool output: {tool_output}")
            # Create the tool output text
            output_text = tool_output["summary"]
            if tool_output.get("causal_compatibility"):
                output_text += "\n\nCausal Compatibility Assessment:\n" + tool_output["causal_compatibility"]
            # Create messages capturing the tool-call exchange for the final response
            messages = [
                HumanMessage(content=user_message),
                AIMessage(content="", tool_calls=[tool_call]),
                ToolMessage(content=output_text, tool_call_id=tool_call["id"])
            ]
            # Create a follow-up prompt to ensure the model provides a comprehensive response
            follow_up_prompt = ChatPromptTemplate.from_messages([
                ("system", """You are CausalBox Assistant. Based on the tool results, provide a comprehensive, helpful response to the user's question.
Explain the dataset characteristics, causal compatibility, and provide actionable insights for causal analysis.
Be specific about what the data shows and what causal analysis approaches would be suitable.
Always provide a complete response, not just acknowledgment."""),
                ("human", "{original_question}"),
                ("assistant", "I'll analyze the dataset information for you."),
                ("human", "Here's the dataset analysis: {tool_results}\n\nPlease provide a comprehensive explanation of this data and its suitability for causal analysis.")
            ])
            # Get final response from the model with explicit prompting
            print("Invoking model with tool response messages")
            try:
                final_chain = follow_up_prompt | model
                final_response = final_chain.invoke({
                    "original_question": user_message,
                    "tool_results": output_text
                })
                print(f"Final response content: {final_response.content}")
                if final_response.content and final_response.content.strip():
                    return final_response.content
                else:
                    # Fallback response if the model still returns empty
                    return create_fallback_response(output_text, user_message)
            except Exception as e:
                print(f"Error in final response generation: {e}")
                return create_fallback_response(output_text, user_message)
        else:
            print("No tool calls, returning direct response")
            if response.content and response.content.strip():
                return response.content
            else:
                return "I'm ready to help you with causal analysis. Please ask me about your dataset, causal graphs, or any causal inference concepts you'd like to understand."
    except Exception as e:
        print(f"Error communicating with Groq: {e}")
        return f"Sorry, I'm having trouble processing your request: {str(e)}"
def create_fallback_response(tool_output: str, user_message: str) -> str:
    """
    Creates a fallback response when the model returns empty content.
    """
    response_parts = ["Based on your dataset analysis:\n"]
    if "Dataset Summary:" in tool_output:
        response_parts.append("**Dataset Overview:**")
        summary_part = tool_output.split("Dataset Summary:")[1].split("Causal Compatibility Assessment:")[0]
        response_parts.append(summary_part.strip())
        response_parts.append("")
    if "Causal Compatibility Assessment:" in tool_output:
        response_parts.append("**Causal Analysis Compatibility:**")
        compatibility_part = tool_output.split("Causal Compatibility Assessment:")[1]
        response_parts.append(compatibility_part.strip())
        response_parts.append("")
    # Add specific insights based on the data
    if "FinalExamScore" in tool_output:
        response_parts.append("**Key Insights for Causal Analysis:**")
        response_parts.append("- Your dataset appears to be education-related with variables like FinalExamScore, StudyHours, and TuitionHours")
        response_parts.append("- This is excellent for causal analysis as you can explore questions like:")
        response_parts.append("  • Does increasing study hours causally improve exam scores?")
        response_parts.append("  • What's the causal effect of tutoring (TuitionHours) on performance?")
        response_parts.append("  • How does parental education influence student outcomes?")
        response_parts.append("")
        response_parts.append("**Next Steps:**")
        response_parts.append("- Consider identifying your treatment variable (e.g., TuitionHours)")
        response_parts.append("- Define your outcome variable (likely FinalExamScore)")
        response_parts.append("- Identify potential confounders (ParentalEducation, SchoolType)")
    return "\n".join(response_parts)