import os import streamlit as st import pandas as pd import sqlite3 import logging from langchain.agents import create_sql_agent from langchain.agents.agent_toolkits import SQLDatabaseToolkit from langchain.llms import OpenAI from langchain.sql_database import SQLDatabase from langchain.prompts import ( ChatPromptTemplate, FewShotPromptTemplate, PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder ) from langchain.schema import HumanMessage from langchain.chat_models import ChatOpenAI from langchain.evaluation import load_evaluator # Initialize logging logging.basicConfig(level=logging.INFO) # Initialize conversation history if 'history' not in st.session_state: st.session_state.history = [] # OpenAI API key (ensure it is securely stored) openai_api_key = os.getenv("OPENAI_API_KEY") # Check if the API key is set if not openai_api_key: st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.") st.stop() # Step 1: Upload CSV data file (or use default) st.title("Enhanced Natural Language to SQL Query App") st.write("Upload a CSV file to get started, or use the default dataset.") csv_file = st.file_uploader("Upload your CSV file", type=["csv"]) if csv_file is None: data = pd.read_csv("default_data.csv") # Ensure this file exists in your working directory st.write("Using default_data.csv file.") table_name = "default_table" else: data = pd.read_csv(csv_file) table_name = csv_file.name.split('.')[0] st.write(f"Data Preview ({csv_file.name}):") st.dataframe(data.head()) # Step 2: Load CSV data into a persistent SQLite database db_file = 'my_database.db' conn = sqlite3.connect(db_file) data.to_sql(table_name, conn, index=False, if_exists='replace') # SQL table metadata (for validation and schema) valid_columns = list(data.columns) st.write(f"Valid columns: {valid_columns}") # Create SQLDatabase instance with custom table info engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name]) # Step 3: Define the few-shot examples for the prompt few_shot_examples = [ { "input": "What is the total revenue for each category?", "query": f"SELECT category, SUM(revenue) FROM {table_name} GROUP BY category;" }, { "input": "Show the top 5 products by sales.", "query": f"SELECT product_name, sales FROM {table_name} ORDER BY sales DESC LIMIT 5;" }, { "input": "How many orders were placed in the last month?", "query": f"SELECT COUNT(*) FROM {table_name} WHERE order_date >= DATE('now', '-1 month');" } ] # Step 4: Define the prompt templates system_prefix = """ You are an expert data analyst who can convert natural language questions into SQL queries. Follow these guidelines: 1. Only use the columns and tables provided. 2. Use appropriate SQL syntax for SQLite. 3. Ensure string comparisons are case-insensitive. 4. Do not execute queries that could be harmful or unethical. 5. Provide clear and concise SQL queries. """ few_shot_prompt = FewShotPromptTemplate( example_prompt=PromptTemplate.from_template("Question: {input}\nSQL Query: {query}"), examples=few_shot_examples, prefix=system_prefix, suffix="Question: {input}\nSQL Query:", input_variables=["input"] ) # Step 5: Initialize the LLM and toolkit llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key) toolkit = SQLDatabaseToolkit(db=engine, llm=llm) # Step 6: Create the agent agent_prompt = ChatPromptTemplate.from_messages([ SystemMessagePromptTemplate(prompt=few_shot_prompt), HumanMessagePromptTemplate.from_template("{input}") ]) sql_agent = create_sql_agent( llm=llm, toolkit=toolkit, prompt=agent_prompt, verbose=True, agent_type="openai-functions", max_iterations=5 ) # Step 7: Define the callback function def process_input(): user_prompt = st.session_state['user_input'] if user_prompt: try: # Append user message to history st.session_state.history.append({"role": "user", "content": user_prompt}) # Use the agent to generate the SQL query with st.spinner("Generating SQL query..."): response = sql_agent.run(user_prompt) # Check if the response contains SQL code if "SELECT" in response.upper(): sql_query = response.strip() logging.info(f"Generated SQL Query: {sql_query}") # Attempt to execute SQL query and handle exceptions try: result = pd.read_sql_query(sql_query, conn) if result.empty: assistant_response = "The query returned no results. Please try a different question." st.session_state.history.append({"role": "assistant", "content": assistant_response}) else: # Limit the result to first 10 rows for display result_display = result.head(10) st.session_state.history.append({"role": "assistant", "content": "Here are the results:"}) st.session_state.history.append({"role": "assistant", "content": result_display}) # Generate insights based on the query result insights_template = """ You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words. User's Question: {question} SQL Query Result: {result} Concise Analysis: """ insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result']) insights_chain = LLMChain(llm=llm, prompt=insights_prompt) result_str = result_display.to_string(index=False) insights = insights_chain.run({'question': user_prompt, 'result': result_str}) # Append the assistant's insights to the history st.session_state.history.append({"role": "assistant", "content": insights}) except Exception as e: logging.error(f"An error occurred during SQL execution: {e}") assistant_response = f"Error executing SQL query: {e}" st.session_state.history.append({"role": "assistant", "content": assistant_response}) else: # Handle responses that do not contain SQL queries assistant_response = response st.session_state.history.append({"role": "assistant", "content": assistant_response}) # Evaluate the response for harmful content try: evaluator = load_evaluator("harmful_content", llm=llm) eval_result = evaluator.evaluate_strings( input=user_prompt, prediction=response ) if eval_result['flagged']: st.warning("The assistant's response may not be appropriate.") else: logging.info("Response evaluated as appropriate.") except Exception as e: logging.error(f"An error occurred during evaluation: {e}") except Exception as e: logging.error(f"An error occurred: {e}") assistant_response = f"Error: {e}" st.session_state.history.append({"role": "assistant", "content": assistant_response}) # Reset the user_input in session state st.session_state['user_input'] = '' # Step 8: Display the conversation history for message in st.session_state.history: if message['role'] == 'user': st.markdown(f"**User:** {message['content']}") elif message['role'] == 'assistant': if isinstance(message['content'], pd.DataFrame): st.markdown("**Assistant:** Query Results:") st.dataframe(message['content']) else: st.markdown(f"**Assistant:** {message['content']}") # Place the input field at the bottom with the callback st.text_input("Enter your message:", key='user_input', on_change=process_input)