# Streamlit app (deployed as a Hugging Face Space): natural-language-to-SQL assistant.
import logging
import os
import sqlite3

import pandas as pd
import streamlit as st
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import load_evaluator
from langchain.llms import OpenAI
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain.schema import HumanMessage
from langchain.sql_database import SQLDatabase
# Initialize logging at INFO so generated SQL and errors are traceable.
logging.basicConfig(level=logging.INFO)

# Initialize conversation history once; st.session_state persists it across
# Streamlit reruns so the chat transcript survives widget interactions.
if 'history' not in st.session_state:
    st.session_state.history = []

# OpenAI API key — read from the environment, never hard-coded.
openai_api_key = os.getenv("OPENAI_API_KEY")

# Fail fast with a visible message if the key is missing; st.stop() halts
# the script so nothing below runs without credentials.
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
# Step 1: Upload CSV data file (or use default)
st.title("Enhanced Natural Language to SQL Query App")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    # NOTE(review): the raw file stem becomes the SQL table name; names with
    # spaces or punctuation may not be valid SQLite identifiers — confirm.
    table_name = csv_file.name.split('.')[0]

# Guard the label: csv_file is None on the default-dataset path, so
# csv_file.name would raise AttributeError there.
preview_label = "default_data.csv" if csv_file is None else csv_file.name
st.write(f"Data Preview ({preview_label}):")
st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database; if_exists='replace'
# keeps the table in sync with whatever file was just uploaded.
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (shown to the user for reference when phrasing questions)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")

# Wrap the SQLite file as a LangChain SQLDatabase, restricted to our table.
engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
# Step 3: Few-shot examples for the prompt. Queries are built against the
# table created above via an f-string; the column names here are illustrative
# and may not exist in the uploaded CSV — they only teach the model the
# question -> SQL mapping style.
few_shot_examples = [
    {
        "input": "What is the total revenue for each category?",
        "query": f"SELECT category, SUM(revenue) FROM {table_name} GROUP BY category;",
    },
    {
        "input": "Show the top 5 products by sales.",
        "query": f"SELECT product_name, sales FROM {table_name} ORDER BY sales DESC LIMIT 5;",
    },
    {
        "input": "How many orders were placed in the last month?",
        "query": f"SELECT COUNT(*) FROM {table_name} WHERE order_date >= DATE('now', '-1 month');",
    },
]
# Step 4: Prompt templates. The system prefix states the ground rules; the
# few-shot template renders each example as a Question/SQL pair, then ends
# with the user's question so the model completes the SQL.
system_prefix = """
You are an expert data analyst who can convert natural language questions into SQL queries.
Follow these guidelines:
1. Only use the columns and tables provided.
2. Use appropriate SQL syntax for SQLite.
3. Ensure string comparisons are case-insensitive.
4. Do not execute queries that could be harmful or unethical.
5. Provide clear and concise SQL queries.
"""

few_shot_prompt = FewShotPromptTemplate(
    example_prompt=PromptTemplate.from_template("Question: {input}\nSQL Query: {query}"),
    examples=few_shot_examples,
    prefix=system_prefix,
    suffix="Question: {input}\nSQL Query:",
    input_variables=["input"],
)
# Step 5: Initialize the LLM (temperature=0 for deterministic SQL) and the
# SQL toolkit that gives the agent schema-inspection and query tools.
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)

# Step 6: Create the agent. The few-shot prompt becomes the system message;
# the human message carries the raw user question.
agent_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate(prompt=few_shot_prompt),
    HumanMessagePromptTemplate.from_template("{input}"),
])

sql_agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    prompt=agent_prompt,
    verbose=True,
    agent_type="openai-functions",
    max_iterations=5,  # cap tool-use loops so a confused agent can't spin forever
)
# Step 7: Callback invoked by the text input's on_change handler.
def process_input():
    """Handle one user message: generate SQL via the agent, execute it,
    post the results and an LLM-written analysis into the chat history.

    Reads st.session_state['user_input'], appends dicts of the form
    {"role": ..., "content": ...} to st.session_state.history (content may
    be a string or a DataFrame), and clears the input box when done.
    """
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            # Record the user's message first so it renders even on failure.
            st.session_state.history.append({"role": "user", "content": user_prompt})
            # Ask the agent to turn the question into SQL (or a direct answer).
            with st.spinner("Generating SQL query..."):
                response = sql_agent.run(user_prompt)
            # Heuristic: treat the response as SQL only if it contains SELECT.
            if "SELECT" in response.upper():
                sql_query = response.strip()
                logging.info(f"Generated SQL Query: {sql_query}")
                # NOTE(review): LLM-generated SQL is executed directly against
                # the database; the SELECT check above is the only guard —
                # confirm this is acceptable for the deployment.
                try:
                    result = pd.read_sql_query(sql_query, conn)
                    if result.empty:
                        assistant_response = "The query returned no results. Please try a different question."
                        st.session_state.history.append({"role": "assistant", "content": assistant_response})
                    else:
                        # Limit the result to first 10 rows for display.
                        result_display = result.head(10)
                        st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
                        st.session_state.history.append({"role": "assistant", "content": result_display})
                        # Generate insights based on the query result.
                        insights_template = """
You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
User's Question: {question}
SQL Query Result:
{result}
Concise Analysis:
"""
                        insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
                        insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
                        result_str = result_display.to_string(index=False)
                        insights = insights_chain.run({'question': user_prompt, 'result': result_str})
                        st.session_state.history.append({"role": "assistant", "content": insights})
                except Exception as e:
                    # Bad SQL (hallucinated columns, syntax) lands here; report
                    # it in-chat rather than crashing the app.
                    logging.error(f"An error occurred during SQL execution: {e}")
                    assistant_response = f"Error executing SQL query: {e}"
                    st.session_state.history.append({"role": "assistant", "content": assistant_response})
            else:
                # The agent answered conversationally; relay it verbatim.
                assistant_response = response
                st.session_state.history.append({"role": "assistant", "content": assistant_response})
            # Best-effort moderation pass over the assistant's reply.
            try:
                # NOTE(review): "harmful_content" is not a standard LangChain
                # evaluator id (the criteria evaluator uses "harmfulness") and
                # evaluate_strings results expose scores, not a 'flagged' key —
                # verify against the installed langchain version.
                evaluator = load_evaluator("harmful_content", llm=llm)
                eval_result = evaluator.evaluate_strings(
                    input=user_prompt,
                    prediction=response
                )
                if eval_result['flagged']:
                    st.warning("The assistant's response may not be appropriate.")
                else:
                    logging.info("Response evaluated as appropriate.")
            except Exception as e:
                logging.error(f"An error occurred during evaluation: {e}")
        except Exception as e:
            # Top-level boundary: log and surface any unexpected failure.
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})
        # Clear the input box so the widget is empty on the next rerun.
        st.session_state['user_input'] = ''
# Step 8: Render the conversation history. Assistant entries may be plain
# strings or DataFrames (query results); DataFrames get a table widget.
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        if isinstance(message['content'], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message['content'])
        else:
            st.markdown(f"**Assistant:** {message['content']}")

# Input field last so it sits at the bottom; on_change fires process_input
# when the user submits a message.
st.text_input("Enter your message:", key='user_input', on_change=process_input)