Suresh Beekhani
Update app.py
d39a66d unverified
# Import necessary libraries and modules for various tasks
from dotenv import load_dotenv # For loading environment variables from a .env file (e.g., database credentials)
from langchain_core.messages import AIMessage, HumanMessage # For handling AI and user messages
from langchain_core.prompts import ChatPromptTemplate # To create templates for chatbot responses
from langchain_core.runnables import RunnablePassthrough # To chain different operations (e.g., inputs/outputs)
from langchain_community.utilities import SQLDatabase # Utility to connect to SQL databases using LangChain
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
from langchain_groq import ChatGroq # Integrates the Groq model for generating chat responses
import streamlit as st # Streamlit for building the web interface
import os # For accessing environment variables (e.g., credentials or other settings)
import psycopg2 # PostgreSQL database adapter for database connections
# Load environment variables (e.g., database credentials) from the .env file
load_dotenv()
# Function to initialize the database connection
def init_database() -> SQLDatabase:
try:
# Retrieve connection details from environment variables or set defaults
user = os.getenv("DB_USER", "postgres")
password = os.getenv("DB_PASSWORD", "beekhani143")
host = os.getenv("DB_HOST", "localhost")
port = os.getenv("DB_PORT", "5432")
database = os.getenv("DB_NAME", "")
# Construct the database URI for PostgreSQL connection
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
# Connect to the database using SQLDatabase utility and return the connection
return SQLDatabase.from_uri(db_uri)
except Exception as e:
# If connection fails, display error in the Streamlit UI
st.error(f"Failed to connect to the database: {e}")
return None
# Function to generate the SQL query chain based on user input and conversation history
def get_sql_chain(db):
# Define the prompt template for generating SQL queries based on schema and chat history
template = """
You are a data analyst. You are interacting with a user who is asking questions about the company's database.
Based on the table schema below, write a SQL query that answers the user's question. Consider the conversation history.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
Write only the SQL query and nothing else.
Question: {question}
SQL Query:
"""
# Create the prompt template from the instructions above
prompt = ChatPromptTemplate.from_template(template)
# Initialize the Groq model for generating SQL responses (deterministic output with temperature=0)
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
# Function to retrieve the database schema (table structure)
def get_schema(_):
return db.get_table_info()
# Build a chain of operations: get the schema, generate SQL query, and parse as plain text
return (
RunnablePassthrough.assign(schema=get_schema) # Pass schema to the chain
| prompt # Use the prompt template to guide query creation
| llm # Generate SQL query using the Groq model
| StrOutputParser() # Parse the output as a plain text SQL query
)
# Function to generate a natural language response based on SQL query and database results
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
# Generate the SQL query using the SQL chain
sql_chain = get_sql_chain(db)
# Define the template for generating natural language responses
template = """
You are a data analyst. Based on the schema, SQL query, and response, write a natural language response.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
SQL Query: <SQL>{query}</SQL>
User Question: {question}
SQL Response: {response}
"""
# Create the prompt template for generating a natural language response
prompt = ChatPromptTemplate.from_template(template)
# Initialize the Groq model for response generation
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
# Chain the steps: Generate SQL, execute it, and generate the natural language response
chain = (
RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
.assign(
schema=lambda _: db.get_table_info(), # Pass the schema for query execution
response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the query and clean the output
)
| prompt # Generate a response from the AI using the prompt
| llm # Get the final response from the model
| StrOutputParser() # Parse the output into plain text
)
# Invoke the chain to generate the response
result = chain.invoke({
"question": user_query,
"chat_history": chat_history,
})
# Debugging: Print the SQL query being executed
if isinstance(result, str):
print(f"SQL Query: {result}")
else:
sql_query = result.get('query', 'No query generated')
print(f"SQL Query: {sql_query}")
# Return the final natural language response
return result
# Initialize the chat session in Streamlit
if "chat_history" not in st.session_state:
st.session_state.chat_history = [
# Initial greeting from the AI assistant
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
]
# Configure the Streamlit page title and icon
st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
st.title("Chat with PostgreSQL") # Display the main title on the web page
# Sidebar for database connection settings
with st.sidebar:
st.subheader("Settings") # Display settings section header
st.write("Connect to your database and start chatting.") # Instructions for users
# Input fields for database connection details (host, port, user, password, and database name)
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
port = st.text_input("Port", value=os.getenv("DB_PORT", "5432"))
user = st.text_input("User", value=os.getenv("DB_USER", "postgres"))
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
# Button to initiate database connection
if st.button("Connect"):
with st.spinner("Connecting to the database..."):
db = init_database() # Attempt to connect to the database
if db:
st.session_state.db = db # Save the connection to session state
st.success("Connected to the database!") # Display success message
else:
st.error("Connection failed. Please check your settings.") # Error message if connection fails
# Display the chat history (both AI and user messages)
for message in st.session_state.chat_history:
if isinstance(message, AIMessage):
with st.chat_message("AI"):
st.markdown(message.content)
elif isinstance(message, HumanMessage):
with st.chat_message("Human"):
st.markdown(message.content)
# Input field for the user to type a message
user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save user query
with st.chat_message("Human"): # Display user's message in chat
st.markdown(user_query)
with st.chat_message("AI"): # Generate and display AI's response
response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get AI response
st.markdown(response)
st.session_state.chat_history.append(AIMessage(content=response)) # Save AI's response to chat history