# GenBIChatbot / app.py
# Author: arithescientist — last commit f0e4f1b ("Update app.py").
import logging
import os
import re
import sqlite3

import pandas as pd
import streamlit as st
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import load_evaluator
from langchain.llms import OpenAI
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage
from langchain.sql_database import SQLDatabase
# --- Application setup -------------------------------------------------------
# Chat transcript, kept in st.session_state so it survives script reruns.
if 'history' not in st.session_state:
    st.session_state.history = []

# Root logger at INFO so the generated SQL and evaluation results are visible.
logging.basicConfig(level=logging.INFO)

# The OpenAI key is taken from the environment rather than hard-coded.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    # Every LLM call below needs the key, so halt this script run right away.
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
# Step 1: Upload CSV data file (or use default)


def _to_table_name(file_name):
    """Derive a bare SQLite table identifier from an uploaded file name.

    The table name is interpolated unquoted into SQL text (the few-shot
    examples and the agent's queries). So the extension is removed with
    ``os.path.splitext`` (keeping inner dots' content), and every run of
    characters other than ASCII letters, digits and ``_`` becomes ``_``.
    Leading/trailing underscores are trimmed. A result that is empty or starts
    with a digit gets a ``t_`` prefix.

    Example: ``"sales data.2024.csv"`` -> ``"sales_data_2024"``.
    """
    stem = os.path.splitext(file_name)[0]
    name = re.sub(r"[^0-9A-Za-z_]+", "_", stem).strip("_")
    if not name or name[0].isdigit():
        name = f"t_{name}"
    return name


st.title("Enhanced Natural Language to SQL Query App")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
try:
    if csv_file is None:
        data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
        st.write("Using default_data.csv file.")
        table_name = "default_table"
    else:
        data = pd.read_csv(csv_file)
        table_name = _to_table_name(csv_file.name)
        st.write(f"Data Preview ({csv_file.name}):")
except FileNotFoundError:
    # Only the default path can be missing; the uploader above is already
    # rendered, so the user can still upload a file after this message.
    st.error("default_data.csv was not found. Please upload a CSV file.")
    st.stop()
except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
    st.error(f"Could not read the CSV file: {e}")
    st.stop()
st.dataframe(data.head())
# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
# `conn` is also used later by the text-input on_change callback
# (process_input). Streamlit runs that callback in a subsequent script run,
# which may be on a different thread than the one that created this
# connection. sqlite3 raises ProgrammingError for cross-thread use unless
# check_same_thread=False.
# NOTE(review): the file name is fixed and the table is replaced on every run,
# so concurrent sessions overwrite each other's data — consider a per-session DB.
conn = sqlite3.connect(db_file, check_same_thread=False)
data.to_sql(table_name, conn, index=False, if_exists='replace')
# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Create SQLDatabase instance restricted to the freshly loaded table
engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
# Step 3: Define the few-shot examples for the prompt
# (question, SQL template) pairs; "{t}" is filled with the active table name.
# NOTE(review): the column names used below (category, revenue, product_name,
# sales, order_date) are fixed and are not checked against valid_columns. They
# only match the schema when the loaded CSV happens to contain those columns.
_example_pairs = (
    (
        "What is the total revenue for each category?",
        "SELECT category, SUM(revenue) FROM {t} GROUP BY category;",
    ),
    (
        "Show the top 5 products by sales.",
        "SELECT product_name, sales FROM {t} ORDER BY sales DESC LIMIT 5;",
    ),
    (
        "How many orders were placed in the last month?",
        "SELECT COUNT(*) FROM {t} WHERE order_date >= DATE('now', '-1 month');",
    ),
)
few_shot_examples = [
    {"input": question, "query": sql.format(t=table_name)}
    for question, sql in _example_pairs
]
# Step 4: Define the prompt templates
# Instructions placed ahead of the worked examples in the system message.
system_prefix = """
You are an expert data analyst who can convert natural language questions into SQL queries.
Follow these guidelines:
1. Only use the columns and tables provided.
2. Use appropriate SQL syntax for SQLite.
3. Ensure string comparisons are case-insensitive.
4. Do not execute queries that could be harmful or unethical.
5. Provide clear and concise SQL queries.
"""
# Each example is rendered as a Question / SQL Query pair. The live question is
# appended in the same shape, with the SQL left for the model to complete.
_example_prompt = PromptTemplate.from_template("Question: {input}\nSQL Query: {query}")
few_shot_prompt = FewShotPromptTemplate(
    prefix=system_prefix,
    examples=few_shot_examples,
    example_prompt=_example_prompt,
    suffix="Question: {input}\nSQL Query:",
    input_variables=["input"],
)
# Step 5: Initialize the LLM and toolkit
# temperature=0 keeps SQL generation as deterministic as the model allows.
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
# Step 6: Create the agent
# The "openai-functions" agent feeds its intermediate tool calls and results
# back through an `agent_scratchpad` message slot. LangChain's
# create_openai_functions_agent rejects a custom prompt without that variable,
# so the placeholder must come after the user's message.
agent_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate(prompt=few_shot_prompt),
    HumanMessagePromptTemplate.from_template("{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])
sql_agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    prompt=agent_prompt,
    verbose=True,
    agent_type="openai-functions",
    max_iterations=5,  # bound the number of tool-call rounds per question
)
# Step 7: Define the callback function

# A fenced code block (``` or ```sql ...), whose content is searched first.
_SQL_FENCE_RE = re.compile(r"```(?:sql)?\s*(.*?)```", re.IGNORECASE | re.DOTALL)
# A SELECT/WITH statement starting at the beginning of a line. It runs up to the
# first ';' or the end of the text; the ';' itself is not captured.
_SQL_STATEMENT_RE = re.compile(
    r"^\s*((?:SELECT|WITH)\b.*?)(?:;|\Z)",
    re.IGNORECASE | re.DOTALL | re.MULTILINE,
)

_INSIGHTS_TEMPLATE = """
You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
User's Question: {question}
SQL Query Result:
{result}
Concise Analysis:
"""


def _extract_sql(text):
    """Return the first SELECT/WITH statement in *text* (without ';'), or None.

    If *text* contains a fenced code block, only its content is searched.
    Prose that merely contains the word "select" mid-sentence is not treated
    as SQL, because the keyword must start a line.
    """
    fenced = _SQL_FENCE_RE.search(text)
    candidate = fenced.group(1) if fenced else text
    match = _SQL_STATEMENT_RE.search(candidate)
    return match.group(1).strip() if match else None


def _answer_with_query(question, sql_query):
    """Run *sql_query* on ``conn`` and append results plus LLM insights to history."""
    history = st.session_state.history
    try:
        result = pd.read_sql_query(sql_query, conn)
    except Exception as e:
        logging.error(f"An error occurred during SQL execution: {e}")
        history.append({"role": "assistant", "content": f"Error executing SQL query: {e}"})
        return
    if result.empty:
        history.append({"role": "assistant", "content": "The query returned no results. Please try a different question."})
        return
    # Only the first 10 rows are displayed and sent to the LLM for analysis.
    result_display = result.head(10)
    history.append({"role": "assistant", "content": "Here are the results:"})
    history.append({"role": "assistant", "content": result_display})
    try:
        insights_prompt = PromptTemplate(template=_INSIGHTS_TEMPLATE, input_variables=['question', 'result'])
        insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
        insights = insights_chain.run({'question': question, 'result': result_display.to_string(index=False)})
    except Exception as e:
        # The results are already shown; report the analysis failure separately.
        logging.error(f"An error occurred while generating insights: {e}")
        history.append({"role": "assistant", "content": f"Could not generate insights: {e}"})
        return
    history.append({"role": "assistant", "content": insights})


def _warn_if_harmful(question, answer):
    """Best-effort harmfulness screen of *answer*; logs and never raises."""
    try:
        evaluator = load_evaluator("criteria", criteria="harmfulness", llm=llm)
        eval_result = evaluator.evaluate_strings(input=question, prediction=answer)
    except Exception as e:
        logging.error(f"An error occurred during evaluation: {e}")
        return
    # The criteria evaluator scores 1 ("Y") when the criterion holds, i.e. when
    # the answer is judged harmful. A missing or None score counts as not flagged.
    if eval_result.get("score") == 1:
        st.warning("The assistant's response may not be appropriate.")
    else:
        logging.info("Response evaluated as appropriate.")


def process_input():
    """Handle a submitted message (``on_change`` callback of the input box).

    The text is read from ``st.session_state['user_input']``, sent to
    ``sql_agent``, and the turns are appended to ``st.session_state.history``:

    * If the agent's reply contains a SELECT/WITH statement, that statement is
      executed against ``conn``. Up to 10 result rows are stored as a DataFrame,
      followed by an LLM-written analysis.
    * Otherwise the reply is stored as text and screened for harmful content.

    Errors are logged and reported in the history instead of propagating, and
    the input box is cleared afterwards.
    """
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            st.session_state.history.append({"role": "user", "content": user_prompt})
            with st.spinner("Generating SQL query..."):
                response = sql_agent.run(user_prompt)
            sql_query = _extract_sql(response)
            if sql_query:
                logging.info(f"Generated SQL Query: {sql_query}")
                _answer_with_query(user_prompt, sql_query)
            else:
                st.session_state.history.append({"role": "assistant", "content": response})
                _warn_if_harmful(user_prompt, response)
        except Exception as e:
            # Top-level boundary of the callback: surface the failure in the chat.
            logging.error(f"An error occurred: {e}")
            st.session_state.history.append({"role": "assistant", "content": f"Error: {e}"})
    # Reset the user_input in session state (allowed inside on_change callbacks)
    st.session_state['user_input'] = ''
# Step 8: Display the conversation history
# Assistant entries hold either text or a DataFrame of query results.
# Entries with any other role are skipped.
for entry in st.session_state.history:
    speaker = entry['role']
    if speaker == 'user':
        st.markdown(f"**User:** {entry['content']}")
    elif speaker == 'assistant':
        body = entry['content']
        if not isinstance(body, pd.DataFrame):
            st.markdown(f"**Assistant:** {body}")
        else:
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(body)
# Input box last so it sits below the transcript; submitting runs process_input.
st.text_input("Enter your message:", key='user_input', on_change=process_input)