# Author: Suresh Beekhani
# Commit: 5b26f53 ("Your commit message")
# Import necessary libraries and modules
from dotenv import load_dotenv # For loading environment variables from .env
from langchain_core.messages import AIMessage, HumanMessage # Message handling
from langchain_core.prompts import ChatPromptTemplate # Prompt templates for generating responses
from langchain_core.runnables import RunnablePassthrough # To chain operations
from langchain_community.utilities import SQLDatabase # SQL database utility for LangChain
from langchain_core.output_parsers import StrOutputParser # To parse outputs as strings
# An OpenAI chat model could be imported here as an alternative to Groq (not currently used).
from langchain_groq import ChatGroq # Groq model for chat (currently used)
import streamlit as st # Streamlit for building the web app
import os # To access environment variables
# Load variables from a .env file into the process environment (like API keys
# and the DB_* database credentials that are read via os.getenv below).
load_dotenv()
def build_db_uri(user: str, password: str, host: str, port: "str | int", database: str) -> str:
    """Return a SQLAlchemy ``mysql+mysqlconnector`` URI for the given settings.

    The user name and password are percent-encoded so that characters such as
    ``@``, ``:`` or ``/`` inside credentials are not mistaken for URI
    delimiters (``safe=""`` also encodes ``/`` and encodes spaces as ``%20``).
    """
    from urllib.parse import quote

    safe_user = quote(user, safe="")
    safe_password = quote(password, safe="")
    return f"mysql+mysqlconnector://{safe_user}:{safe_password}@{host}:{port}/{database}"


def init_database(
    *,
    user: "str | None" = None,
    password: "str | None" = None,
    host: "str | None" = None,
    port: "str | None" = None,
    database: "str | None" = None,
) -> "SQLDatabase | None":
    """Connect to a MySQL database and wrap it in a LangChain ``SQLDatabase``.

    Every setting may be passed explicitly. Any argument left as ``None`` is
    read from its environment variable (``DB_USER``, ``DB_PASSWORD``,
    ``DB_HOST``, ``DB_PORT``, ``DB_NAME``), falling back to a built-in default.
    Calling ``init_database()`` with no arguments therefore behaves exactly as
    before.

    Returns:
        The connected ``SQLDatabase``, or ``None`` if connecting failed. In
        that case the error is also reported in the UI via ``st.error``.
    """
    # Explicit arguments win; ``is not None`` keeps a deliberately empty
    # password from silently falling back to the default.
    # NOTE(review): the root/admin fallbacks are only suitable for local development.
    db_uri = build_db_uri(
        user if user is not None else os.getenv("DB_USER", "root"),
        password if password is not None else os.getenv("DB_PASSWORD", "admin"),
        host if host is not None else os.getenv("DB_HOST", "localhost"),
        port if port is not None else os.getenv("DB_PORT", "3306"),
        database if database is not None else os.getenv("DB_NAME", "Chinook"),
    )
    try:
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:  # UI boundary: show the failure instead of crashing the page
        st.error(f"Failed to connect to database: {e}")
        return None
def strip_sql_fences(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence.

    A chat model may wrap its answer in a fence (```sql ... ```) despite the
    prompt asking for only the query. The fence markers are not valid SQL, so
    they are removed here. Unfenced text is returned with outer whitespace
    stripped.
    """
    cleaned = text.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        cleaned = cleaned[3:-3]
        first_line, newline, rest = cleaned.partition("\n")
        # The opening fence line may carry nothing but a language tag such as "sql".
        if newline and first_line.strip().lower() in {"", "sql", "mysql"}:
            cleaned = rest
    return cleaned.strip()


def get_sql_chain(db):
    """Build a runnable that turns a question plus chat history into a SQL query.

    The chain expects a mapping with ``question`` and ``chat_history`` keys.
    It adds the database's table info as ``schema``, prompts the Groq model,
    and returns the generated query as a plain string with any Markdown code
    fence removed.
    """
    # SQL prompt template
    template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
Write only the SQL query and nothing else.
Question: {question}
SQL Query:
"""
    prompt = ChatPromptTemplate.from_template(template)
    # temperature=0 keeps the generated SQL as deterministic as the model allows.
    # NOTE(review): confirm this model ID is still served by Groq.
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    def get_schema(_):
        # The incoming mapping is ignored; the schema always comes from the live database.
        return db.get_table_info()

    return (
        RunnablePassthrough.assign(schema=get_schema)  # add "schema" to the input mapping
        | prompt
        | llm
        | StrOutputParser()
        | strip_sql_fences  # plain functions are coerced to runnables when piped
    )
# First keywords accepted by ensure_read_only_sql.
READ_ONLY_SQL_KEYWORDS = frozenset({"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"})


def ensure_read_only_sql(query: str) -> str:
    """Return *query* unchanged if it looks like a single read-only statement.

    The SQL is written by an LLM from free-form user chat, so it is untrusted.
    This is a best-effort keyword check and NOT a substitute for connecting
    with a read-only database account. For example, a MySQL ``WITH ... DELETE``
    still passes it.

    Raises:
        ValueError: if the statement does not start with a read-only keyword,
            or if it contains more than one statement.
    """
    statement = query.strip().rstrip(";").rstrip()
    words = statement.split(None, 1)
    first_word = words[0].upper() if words else ""
    # Any ";" left after removing the trailing one means stacked statements.
    # This rejection is conservative: it also rejects ";" inside string literals.
    if first_word not in READ_ONLY_SQL_KEYWORDS or ";" in statement:
        raise ValueError(f"Refusing to run SQL that is not a single read-only query: {query}")
    return query


def format_chat_history(chat_history) -> str:
    """Render chat messages as a "Role: text" transcript, one message per line.

    LangChain messages report their role via ``.type`` ("ai", "human", ...).
    Objects without ``.type``/``.content`` are labelled "message" and
    stringified. Interpolating the raw list would paste each message object's
    repr into the prompt.
    """
    role_labels = {"ai": "AI", "human": "Human"}
    lines = []
    for message in chat_history:
        role = getattr(message, "type", "message")
        text = getattr(message, "content", message)
        lines.append(f"{role_labels.get(role, role)}: {text}")
    return "\n".join(lines)


def get_response(user_query: str, db: "SQLDatabase", chat_history: list) -> str:
    """Answer *user_query* in natural language using data from *db*.

    The model first writes a SQL query (via ``get_sql_chain``). The query is
    checked to be read-only and run against *db*. A second prompt then turns
    the schema, query and raw result into a natural-language reply.

    Args:
        user_query: The user's question.
        db: Connected database used for the schema and for running the query.
        chat_history: Prior conversation messages, rendered into both prompts.

    Returns:
        The model's natural-language answer.

    Raises:
        ValueError: if the generated SQL is not a single read-only statement.
        Errors from the model or the database are not caught here.
    """
    sql_chain = get_sql_chain(db)
    # Prompt template for natural language response based on SQL query and result
    template = """
You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
SQL Query: <SQL>{query}</SQL>
User question: {question}
SQL Response: {response}
"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    # Generate SQL, then add the schema and the (guarded) query result, then summarize.
    chain = (
        RunnablePassthrough.assign(query=sql_chain).assign(
            schema=lambda _: db.get_table_info(),
            response=lambda inputs: db.run(ensure_read_only_sql(inputs["query"])),
        )
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain.invoke({
        "question": user_query,
        "chat_history": format_chat_history(chat_history),
    })
# Configure the page before anything else is rendered.
st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")

# Seed the conversation with a greeting the first time this session runs.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
    ]

st.title("Chat with MySQL")

# Sidebar for database connection settings
with st.sidebar:
    st.subheader("Settings")
    st.write("Connect to your database and start chatting.")
    # NOTE(review): these values are displayed but not passed on. init_database()
    # is called without arguments below, so it reads the DB_* environment
    # variables instead. TODO: wire the entered values through.
    host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
    port = st.text_input("Port", value=os.getenv("DB_PORT", "3306"))
    user = st.text_input("User", value=os.getenv("DB_USER", "root"))
    password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "admin"))
    database = st.text_input("Database", value=os.getenv("DB_NAME", "Chinook"))
    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            db = init_database()
            # init_database signals failure with None; don't rely on object truthiness.
            if db is not None:
                st.session_state.db = db
                st.success("Connected to the database!")
            else:
                st.error("Connection failed. Please check your settings.")

# Replay the stored conversation.
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    # Snapshot the earlier turns before recording the new question. That way
    # the question is passed once, as "question", and not again in the history.
    previous_turns = list(st.session_state.chat_history)
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        # st.session_state.db only exists after a successful "Connect".
        if "db" not in st.session_state:
            st.warning("Please connect to a database using the sidebar before asking questions.")
        else:
            try:
                response = get_response(user_query, st.session_state.db, previous_turns)
            except Exception as e:  # UI boundary: model/SQL errors shouldn't crash the page
                st.error(f"Sorry, something went wrong while answering your question: {e}")
            else:
                st.markdown(response)
                st.session_state.chat_history.append(AIMessage(content=response))