Spaces:
Sleeping
Sleeping
import chainlit as cl | |
import pandas as pd | |
import time | |
from typing import Dict, Any | |
from agents.table_selection import table_selection_agent | |
from agents.data_retrieval import sample_data_retrieval_agent | |
from agents.sql_generation import sql_generation_agent | |
from agents.validation import query_validation_and_optimization | |
from agents.execution import execution_agent | |
from utils.bigquery_utils import init_bigquery_connection | |
from utils.feedback_utils import save_feedback_to_bigquery | |
# NOTE(review): no Chainlit registration decorator is visible on this handler
# in this view -- confirm it is registered (e.g. as the chat-start hook).
async def on_chat_start():
    """Open a chat session: create the BigQuery client and greet the user.

    The client returned by ``init_bigquery_connection()`` is stored in the
    Chainlit user session under the key ``"client"``. Three messages are then
    sent, in order, under the author "SQL Assistant": a welcome line, an
    introduction to the examples, and a bulleted list of example questions.
    """
    cl.user_session.set("client", init_bigquery_connection())

    sample_questions = (
        "What are the top 5 products by revenue?",
        "How many orders were placed in the last month?",
        "Which customers spent the most in 2023?",
        "What is the average order value by product category?",
    )
    # One bullet per question, separated by blank lines, followed by a hint.
    bullet_list = "\n\n".join(f"• {question}" for question in sample_questions)

    outgoing = (
        "👋 Welcome to the Natural Language to SQL Query Assistant! Ask me any question about your e-commerce data.",
        "Here are some example questions you can ask:",
        bullet_list + "\n\n(You can copy and paste any of these examples to try them out)",
    )
    for text in outgoing:
        await cl.Message(content=text, author="SQL Assistant").send()
async def _send_text(content):
    """Send a plain chat message authored by "SQL Assistant"."""
    await cl.Message(content=content, author="SQL Assistant").send()


async def _set_status(status_msg, text):
    """Overwrite the in-place progress message with ``text``."""
    status_msg.content = text
    await status_msg.update()


async def _save_detailed_feedback(feedback_text):
    """Store the free-text follow-up to an earlier thumbs-down and acknowledge it."""
    session = cl.user_session
    # Leave feedback mode *before* saving: if the save raises, the session must
    # not stay stuck treating every later message as feedback.
    session.set("awaiting_feedback", False)
    saved = save_feedback_to_bigquery(
        session.get("client"),
        session.get("original_query"),
        session.get("generated_sql"),
        session.get("optimized_sql"),
        f"negative: {feedback_text}",
    )
    if saved:
        await _send_text("Thanks for your detailed feedback! I've saved it to improve future responses.")
    else:
        await _send_text("Thanks for your feedback! (Note: There was an issue saving it to the database)")


def _rows_to_dataframe(rows, client, sql):
    """Turn a non-empty execution result into a DataFrame with the best available column names."""
    if not isinstance(rows[0], tuple):
        return pd.DataFrame(rows)
    try:
        # NOTE(review): this submits ``sql`` a second time only to read the
        # result schema; it appears to duplicate the work of execution_agent.
        # Consider having the execution step return the column names.
        schema = client.query(sql).result().schema
        return pd.DataFrame(rows, columns=[field.name for field in schema])
    except Exception:
        # Best effort: any failure while fetching or applying the schema falls
        # back to positional column names.
        return pd.DataFrame(rows, columns=[f"Column_{i}" for i in range(len(rows[0]))])


async def _request_feedback(df, client, original_query, generated_sql, optimized_sql):
    """Ask whether the displayed result was helpful and record the answer."""
    res = await cl.AskActionMessage(
        content=f"The query returned {len(df)} rows and {len(df.columns)} columns.\n\nWas this result helpful?",
        actions=[
            cl.Action(name="feedback", payload={"value": "positive"}, label="👍 Good results"),
            cl.Action(name="feedback", payload={"value": "negative"}, label="👎 Not what I wanted"),
        ],
    ).send()
    if not res:
        # No answer was returned; nothing to record.
        return

    feedback_value = res.get("payload", {}).get("value")
    if feedback_value == "positive":
        saved = save_feedback_to_bigquery(
            client, original_query, generated_sql, optimized_sql, "positive"
        )
        if saved:
            await _send_text("Thanks for your positive feedback! I've saved it to improve future responses.")
        else:
            await _send_text("Thanks for your feedback! (Note: There was an issue saving it to the database)")
    elif feedback_value == "negative":
        await _send_text("I'm sorry the results weren't what you expected. Please type your feedback about what was wrong.")
        # The next user message is consumed as feedback text by on_message.
        cl.user_session.set("awaiting_feedback", True)
        # Record the bare rating now in case no follow-up text ever arrives.
        save_feedback_to_bigquery(
            client, original_query, generated_sql, optimized_sql, "negative"
        )


# NOTE(review): no Chainlit registration decorator is visible on this handler
# in this view -- confirm it is registered (e.g. as the message hook).
async def on_message(message: cl.Message):
    """Handle one user message: either stored feedback text or a new question.

    If the session flag ``"awaiting_feedback"`` is set, the message text is
    saved as ``"negative: <text>"`` feedback for the previous query and the
    handler returns.

    Otherwise the text is treated as a natural-language question and run
    through the pipeline: table selection, sample-data retrieval, SQL
    generation, validation/optimization, execution. Each agent receives the
    accumulated ``state`` dict and its returned dict is merged back in. The
    generated and optimized SQL are shown to the user and stored in the
    session (``"generated_sql"``, ``"optimized_sql"``) next to
    ``"original_query"``. Results are shown as a dataframe element, followed
    by a thumbs-up / thumbs-down prompt.

    Args:
        message: Incoming Chainlit message; only ``message.content`` is read.

    Errors raised anywhere in the pipeline are caught and reported in the
    chat rather than propagated.
    """
    query = message.content

    if cl.user_session.get("awaiting_feedback", False):
        await _save_detailed_feedback(query)
        return

    client = cl.user_session.get("client")
    # Kept in the session so a later feedback message can reference it.
    cl.user_session.set("original_query", query)

    thinking_msg = await cl.Message(content="🤔 Thinking...", author="SQL Assistant").send()
    try:
        # Step 1: pick the tables relevant to the question.
        await _set_status(thinking_msg, "🔍 Analyzing relevant tables...")
        state = {"sql_query": query, "client": client}
        tables_state = table_selection_agent(state)
        relevant_tables = tables_state.get("relevant_tables", [])
        # Short pauses between steps are deliberate pacing for the UI.
        await cl.sleep(1)
        if relevant_tables:
            tables_text = "I've identified these relevant tables for your query:\n\n"
            tables_text += "\n".join(f"- `{table}`" for table in relevant_tables)
            await _send_text(tables_text)

        # Step 2: fetch sample rows for those tables.
        await _set_status(thinking_msg, "📊 Retrieving sample data...")
        await cl.sleep(1)
        state.update(tables_state)
        sample_data_state = sample_data_retrieval_agent(state)

        # Step 3: generate SQL.
        await _set_status(thinking_msg, "💻 Generating SQL query...")
        await cl.sleep(1)
        state.update(sample_data_state)
        sql_state = sql_generation_agent(state)
        generated_sql = sql_state.get("generated_sql", "No SQL generated")
        cl.user_session.set("generated_sql", generated_sql)
        await _send_text(f"Here's the SQL query I generated:\n\n```sql\n{generated_sql}\n```")

        # Step 4: validate / optimize the SQL.
        await _set_status(thinking_msg, "🔧 Optimizing the query...")
        await cl.sleep(1)
        state.update(sql_state)
        optimization_state = query_validation_and_optimization(state)
        optimized_sql = optimization_state.get("optimized_sql", "No optimized SQL")
        cl.user_session.set("optimized_sql", optimized_sql)
        await _send_text(f"Here's the optimized version of the query:\n\n```sql\n{optimized_sql}\n```")

        # Step 5: execute.
        await _set_status(thinking_msg, "⚙️ Executing query...")
        await cl.sleep(1)
        state.update(optimization_state)
        execution_state = execution_agent(state)
        execution_result = execution_state.get("execution_result", {})

        if isinstance(execution_result, dict) and "error" in execution_result:
            error_msg = execution_result.get("error", "Unknown error occurred")
            await _send_text(f"❌ Error executing query: {error_msg}")
        elif not execution_result:
            await _send_text("✅ Query executed successfully but returned no results.")
        else:
            try:
                df = _rows_to_dataframe(execution_result, client, optimized_sql)
                await _send_text("✅ Query executed successfully! Here are the results:")
                await cl.Message(
                    content="",
                    elements=[cl.Dataframe(data=df)],
                    author="SQL Assistant",
                ).send()
                # Pass this call's own values rather than re-reading the
                # session after an await, so the rating stays attached to the
                # query that produced the displayed result.
                await _request_feedback(df, client, query, generated_sql, optimized_sql)
            except Exception as exc:
                await _send_text(f"❌ Error formatting results: {exc}")
    except Exception as exc:
        # Top-level boundary for the pipeline: surface the failure in the chat.
        await _set_status(thinking_msg, f"❌ Error: {exc}")
        await _send_text(f"I encountered an error while processing your query: {exc}")
# Callback handlers for actions | |
# NOTE(review): no Chainlit registration decorator is visible on this handler
# in this view -- confirm how (or whether) it is wired to an action.
async def on_feedback_action(action):
    """Record a positive rating carried by ``action`` and acknowledge it.

    Only ``action.payload["value"] == "positive"`` is acted on; any other
    value results in no save and no reply. The query context is read from the
    Chainlit user session.

    Args:
        action: Object whose ``payload`` mapping holds the rating under
            the key ``"value"``.
    """
    verdict = action.payload.get("value")
    session = cl.user_session
    context = [
        session.get(key)
        for key in ("client", "original_query", "generated_sql", "optimized_sql")
    ]
    if verdict != "positive":
        return

    saved = save_feedback_to_bigquery(*context, "positive")
    reply = (
        "Thanks for your positive feedback! I've saved it to improve future responses."
        if saved
        else "Thanks for your feedback! (Note: There was an issue saving it to the database)"
    )
    await cl.Message(content=reply, author="SQL Assistant").send()
# NOTE(review): no Chainlit registration decorator is visible on this handler
# in this view -- confirm how (or whether) it is wired to an action.
async def on_feedback_bad(action):
    """Ask the user what was wrong with the result and save their answer.

    The user is prompted with a free-text question (300 s timeout). Their
    reply is stored as ``"negative: <text>"``; if no reply arrives, a bare
    ``"negative"`` rating is stored instead. The query context is read from
    the Chainlit user session.

    Args:
        action: The triggering action; not inspected here.
    """
    # AskUserMessage already collects typed text from the user. The previous
    # ``elements=[cl.Textarea(...)]`` argument is dropped: the call would
    # fail before the prompt was ever shown.
    res = await cl.AskUserMessage(
        content="I'm sorry the results weren't what you expected. Could you please provide more details about what was wrong?",
        author="SQL Assistant",
        timeout=300,
    ).send()

    feedback_details = "negative"
    # The typed reply is returned under "output"; the old lookup of
    # "feedback_details" never matched, so details were silently discarded.
    typed_reply = res.get("output") if res else None
    if typed_reply:
        feedback_details = f"negative: {typed_reply}"

    session = cl.user_session
    saved = save_feedback_to_bigquery(
        session.get("client"),
        session.get("original_query"),
        session.get("generated_sql"),
        session.get("optimized_sql"),
        feedback_details,
    )
    if saved:
        await cl.Message(content="Thanks for your detailed feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
    else:
        await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()
# Chainlit apps are started through the Chainlit CLI, not by executing this
# module directly, e.g.:
#     chainlit run new_app.py -w
# The guard below is therefore an intentional no-op.
if __name__ == "__main__":
    pass