import os
import streamlit as st
import pandas as pd
import sqlite3
import logging
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms import OpenAI
from langchain.sql_database import SQLDatabase
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import load_evaluator
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Initialize logging
logging.basicConfig(level=logging.INFO)

# Initialize conversation history
if 'history' not in st.session_state:
    st.session_state.history = []

# OpenAI API key
openai_api_key = os.getenv("OPENAI_API_KEY")

# Check if the API key is set
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
# Step 1: Upload CSV data file (or use default)
st.title("Enhanced Natural Language to SQL Query App")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Ensure this file exists
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    table_name = csv_file.name.split('.')[0]
st.write(f"Data Preview ({csv_file.name}):") | |
st.dataframe(data.head()) | |
# Step 2: Load CSV data into SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
data.to_sql(table_name, conn, index=False, if_exists='replace')
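# Note: if_exists='replace' drops and recreates the table on every Streamlit rerun,
# so the SQLite copy always mirrors the currently loaded CSV.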
# Create SQLDatabase instance
engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])

# Initialize the LLM
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
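# temperature=0 keeps the generated SQL as deterministic as possible across reruns.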
# Step 3: Create the agent
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
sql_agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    max_iterations=5
)
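# AgentType.OPENAI_FUNCTIONS uses OpenAI's function-calling API, so it requires a chat
# model that supports function calls; max_iterations=5 caps the agent's tool-use loop.
# Hypothetical smoke test outside the Streamlit flow (not executed by this app):
#   sql_agent.run("How many rows does the table contain?")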
# Step 4: Define the callback function
def process_input():
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            # Append user message to history
            st.session_state.history.append({"role": "user", "content": user_prompt})

            # Use the agent to generate the SQL query and get the response
            with st.spinner("Processing..."):
                response = sql_agent.run(user_prompt)

            # Check if the response contains a SQL query
            if "```sql" in response:
                # Extract the SQL query
                start_index = response.find("```sql") + len("```sql")
                end_index = response.find("```", start_index)
                sql_query = response[start_index:end_index].strip()
            else:
                # If no SQL code block is found, assume the entire response is the SQL query
                sql_query = response.strip()

            logging.info(f"Generated SQL Query: {sql_query}")

            # Attempt to execute the SQL query and handle exceptions
            try:
                result = pd.read_sql_query(sql_query, conn)
                if result.empty:
                    assistant_response = "The query returned no results. Please try a different question."
                    st.session_state.history.append({"role": "assistant", "content": assistant_response})
                else:
                    # Limit the result to the first 10 rows for display
                    result_display = result.head(10)
                    st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
                    st.session_state.history.append({"role": "assistant", "content": result_display})

                    # Generate insights based on the query result
                    insights_template = """
                    You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.

                    User's Question: {question}

                    SQL Query Result:
                    {result}

                    Concise Analysis:
                    """
                    insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
                    insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
                    result_str = result_display.to_string(index=False)
                    insights = insights_chain.run({'question': user_prompt, 'result': result_str})

                    # Append the assistant's insights to the history
                    st.session_state.history.append({"role": "assistant", "content": insights})
            except Exception as e:
                logging.error(f"An error occurred during SQL execution: {e}")
                assistant_response = f"Error executing SQL query: {e}"
                st.session_state.history.append({"role": "assistant", "content": assistant_response})
        except Exception as e:
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})

    # Reset user input
    st.session_state['user_input'] = ''
# Step 5: Display conversation history
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        if isinstance(message['content'], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message['content'])
        else:
            st.markdown(f"**Assistant:** {message['content']}")

# Input field
st.text_input("Enter your message:", key='user_input', on_change=process_input)
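# To run locally (assuming this script is saved as app.py and OPENAI_API_KEY is exported):
#   streamlit run app.py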