Spaces:
Build error
Build error
Suresh Beekhani
commited on
Update app.py
Browse files- src/app.py +64 -70
src/app.py
CHANGED
|
@@ -1,44 +1,44 @@
|
|
| 1 |
# Import necessary libraries and modules for various tasks
|
| 2 |
-
from dotenv import load_dotenv # For loading environment variables from a .env file (
|
| 3 |
-
from langchain_core.messages import AIMessage, HumanMessage # For handling
|
| 4 |
-
from langchain_core.prompts import ChatPromptTemplate # To create templates
|
| 5 |
-
from langchain_core.runnables import RunnablePassthrough # To
|
| 6 |
-
from langchain_community.utilities import SQLDatabase #
|
| 7 |
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
|
| 8 |
-
from langchain_groq import ChatGroq #
|
| 9 |
-
import streamlit as st # Streamlit
|
| 10 |
-
import os #
|
| 11 |
-
import psycopg2 #
|
| 12 |
|
| 13 |
-
# Load environment variables (
|
| 14 |
load_dotenv()
|
| 15 |
|
| 16 |
-
# Function to
|
| 17 |
def init_database() -> SQLDatabase:
|
| 18 |
try:
|
| 19 |
-
# Retrieve
|
| 20 |
user = os.getenv("DB_USER", "postgres")
|
| 21 |
password = os.getenv("DB_PASSWORD", "beekhani143")
|
| 22 |
host = os.getenv("DB_HOST", "localhost")
|
| 23 |
port = os.getenv("DB_PORT", "5432")
|
| 24 |
database = os.getenv("DB_NAME", "")
|
| 25 |
|
| 26 |
-
# Construct the database URI
|
| 27 |
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
| 28 |
|
| 29 |
-
# Connect to the database using
|
| 30 |
return SQLDatabase.from_uri(db_uri)
|
| 31 |
except Exception as e:
|
| 32 |
-
# If connection fails, display
|
| 33 |
-
st.error(f"Failed to connect to database: {e}")
|
| 34 |
return None
|
| 35 |
|
| 36 |
-
# Function to
|
| 37 |
def get_sql_chain(db):
|
| 38 |
-
#
|
| 39 |
template = """
|
| 40 |
-
You are a data analyst
|
| 41 |
-
Based on the table schema below, write a SQL query that
|
| 42 |
|
| 43 |
<SCHEMA>{schema}</SCHEMA>
|
| 44 |
Conversation History: {chat_history}
|
|
@@ -48,62 +48,56 @@ def get_sql_chain(db):
|
|
| 48 |
SQL Query:
|
| 49 |
"""
|
| 50 |
|
| 51 |
-
# Create
|
| 52 |
prompt = ChatPromptTemplate.from_template(template)
|
| 53 |
-
# Initialize the Groq model for generating responses
|
| 54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 55 |
|
| 56 |
-
# Function to
|
| 57 |
def get_schema(_):
|
| 58 |
return db.get_table_info()
|
| 59 |
|
| 60 |
-
#
|
| 61 |
-
# 1. First, get the database schema.
|
| 62 |
-
# 2. Then, use the prompt template to guide query creation.
|
| 63 |
-
# 3. Finally, parse the output as a plain text SQL query.
|
| 64 |
return (
|
| 65 |
-
RunnablePassthrough.assign(schema=get_schema) # Pass
|
| 66 |
-
| prompt # Use the prompt template
|
| 67 |
-
| llm # Generate
|
| 68 |
-
| StrOutputParser() # Parse the
|
| 69 |
)
|
| 70 |
|
| 71 |
-
# Function to generate a natural language response based on SQL query and database
|
| 72 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
| 73 |
-
#
|
| 74 |
sql_chain = get_sql_chain(db)
|
| 75 |
|
| 76 |
-
#
|
| 77 |
template = """
|
| 78 |
-
You are a data analyst
|
| 79 |
<SCHEMA>{schema}</SCHEMA>
|
| 80 |
Conversation History: {chat_history}
|
| 81 |
SQL Query: <SQL>{query}</SQL>
|
| 82 |
-
User
|
| 83 |
SQL Response: {response}
|
| 84 |
"""
|
| 85 |
|
| 86 |
-
# Create
|
| 87 |
prompt = ChatPromptTemplate.from_template(template)
|
| 88 |
# Initialize the Groq model for response generation
|
| 89 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 90 |
|
| 91 |
-
# Chain the
|
| 92 |
-
# 1. Generate the SQL query using the earlier chain.
|
| 93 |
-
# 2. Get the schema and execute the query on the database.
|
| 94 |
-
# 3. Return the natural language response based on the query and its results.
|
| 95 |
chain = (
|
| 96 |
RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
|
| 97 |
.assign(
|
| 98 |
-
schema=lambda _: db.get_table_info(), # Pass the schema
|
| 99 |
-
response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the
|
| 100 |
)
|
| 101 |
-
| prompt #
|
| 102 |
-
| llm #
|
| 103 |
| StrOutputParser() # Parse the output into plain text
|
| 104 |
)
|
| 105 |
|
| 106 |
-
# Invoke the chain to generate the
|
| 107 |
result = chain.invoke({
|
| 108 |
"question": user_query,
|
| 109 |
"chat_history": chat_history,
|
|
@@ -116,24 +110,24 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
|
| 116 |
sql_query = result.get('query', 'No query generated')
|
| 117 |
print(f"SQL Query: {sql_query}")
|
| 118 |
|
| 119 |
-
# Return the
|
| 120 |
return result
|
| 121 |
|
| 122 |
-
# Initialize the chat session
|
| 123 |
if "chat_history" not in st.session_state:
|
| 124 |
st.session_state.chat_history = [
|
| 125 |
-
#
|
| 126 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
| 127 |
]
|
| 128 |
|
| 129 |
-
#
|
| 130 |
st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
|
| 131 |
-
st.title("Chat with PostgreSQL") # Display title on the
|
| 132 |
|
| 133 |
-
# Sidebar
|
| 134 |
with st.sidebar:
|
| 135 |
-
st.subheader("Settings") # Display
|
| 136 |
-
st.write("Connect to your database and start chatting.") #
|
| 137 |
|
| 138 |
# Input fields for database connection details (host, port, user, password, and database name)
|
| 139 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
|
@@ -142,35 +136,35 @@ with st.sidebar:
|
|
| 142 |
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
|
| 143 |
database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
|
| 144 |
|
| 145 |
-
# Button to
|
| 146 |
if st.button("Connect"):
|
| 147 |
-
with st.spinner("Connecting to database..."):
|
| 148 |
-
db = init_database() #
|
| 149 |
if db:
|
| 150 |
-
st.session_state.db = db # Save the connection
|
| 151 |
st.success("Connected to the database!") # Display success message
|
| 152 |
else:
|
| 153 |
-
st.error("Connection failed. Please check your settings.") #
|
| 154 |
|
| 155 |
# Display the chat history (both AI and user messages)
|
| 156 |
for message in st.session_state.chat_history:
|
| 157 |
if isinstance(message, AIMessage):
|
| 158 |
-
with st.chat_message("AI"):
|
| 159 |
st.markdown(message.content)
|
| 160 |
elif isinstance(message, HumanMessage):
|
| 161 |
-
with st.chat_message("Human"):
|
| 162 |
st.markdown(message.content)
|
| 163 |
|
| 164 |
-
# Input field for the user to type
|
| 165 |
-
user_query = st.chat_input("Type a message...")
|
| 166 |
-
if user_query and user_query.strip():
|
| 167 |
-
st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save
|
| 168 |
-
|
| 169 |
-
with st.chat_message("Human"): # Display
|
| 170 |
st.markdown(user_query)
|
| 171 |
|
| 172 |
-
with st.chat_message("AI"): # Generate and display
|
| 173 |
-
response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get
|
| 174 |
st.markdown(response)
|
| 175 |
|
| 176 |
-
st.session_state.chat_history.append(AIMessage(content=response)) #
|
|
|
|
| 1 |
# Import necessary libraries and modules for various tasks
|
| 2 |
+
from dotenv import load_dotenv # For loading environment variables from a .env file (e.g., database credentials)
|
| 3 |
+
from langchain_core.messages import AIMessage, HumanMessage # For handling AI and user messages
|
| 4 |
+
from langchain_core.prompts import ChatPromptTemplate # To create templates for chatbot responses
|
| 5 |
+
from langchain_core.runnables import RunnablePassthrough # To chain different operations (e.g., inputs/outputs)
|
| 6 |
+
from langchain_community.utilities import SQLDatabase # Utility to connect to SQL databases using LangChain
|
| 7 |
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
|
| 8 |
+
from langchain_groq import ChatGroq # Integrates the Groq model for generating chat responses
|
| 9 |
+
import streamlit as st # Streamlit for building the web interface
|
| 10 |
+
import os # For accessing environment variables (e.g., credentials or other settings)
|
| 11 |
+
import psycopg2 # PostgreSQL database adapter for database connections
|
| 12 |
|
| 13 |
+
# Load environment variables (e.g., database credentials) from the .env file
|
| 14 |
load_dotenv()
|
| 15 |
|
| 16 |
+
# Function to initialize the database connection
|
| 17 |
def init_database() -> SQLDatabase:
|
| 18 |
try:
|
| 19 |
+
# Retrieve connection details from environment variables or set defaults
|
| 20 |
user = os.getenv("DB_USER", "postgres")
|
| 21 |
password = os.getenv("DB_PASSWORD", "beekhani143")
|
| 22 |
host = os.getenv("DB_HOST", "localhost")
|
| 23 |
port = os.getenv("DB_PORT", "5432")
|
| 24 |
database = os.getenv("DB_NAME", "")
|
| 25 |
|
| 26 |
+
# Construct the database URI for PostgreSQL connection
|
| 27 |
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
| 28 |
|
| 29 |
+
# Connect to the database using SQLDatabase utility and return the connection
|
| 30 |
return SQLDatabase.from_uri(db_uri)
|
| 31 |
except Exception as e:
|
| 32 |
+
# If connection fails, display error in the Streamlit UI
|
| 33 |
+
st.error(f"Failed to connect to the database: {e}")
|
| 34 |
return None
|
| 35 |
|
| 36 |
+
# Function to generate the SQL query chain based on user input and conversation history
|
| 37 |
def get_sql_chain(db):
|
| 38 |
+
# Define the prompt template for generating SQL queries based on schema and chat history
|
| 39 |
template = """
|
| 40 |
+
You are a data analyst. You are interacting with a user who is asking questions about the company's database.
|
| 41 |
+
Based on the table schema below, write a SQL query that answers the user's question. Consider the conversation history.
|
| 42 |
|
| 43 |
<SCHEMA>{schema}</SCHEMA>
|
| 44 |
Conversation History: {chat_history}
|
|
|
|
| 48 |
SQL Query:
|
| 49 |
"""
|
| 50 |
|
| 51 |
+
# Create the prompt template from the instructions above
|
| 52 |
prompt = ChatPromptTemplate.from_template(template)
|
| 53 |
+
# Initialize the Groq model for generating SQL responses (deterministic output with temperature=0)
|
| 54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 55 |
|
| 56 |
+
# Function to retrieve the database schema (table structure)
|
| 57 |
def get_schema(_):
|
| 58 |
return db.get_table_info()
|
| 59 |
|
| 60 |
+
# Build a chain of operations: get the schema, generate SQL query, and parse as plain text
|
|
|
|
|
|
|
|
|
|
| 61 |
return (
|
| 62 |
+
RunnablePassthrough.assign(schema=get_schema) # Pass schema to the chain
|
| 63 |
+
| prompt # Use the prompt template to guide query creation
|
| 64 |
+
| llm # Generate SQL query using the Groq model
|
| 65 |
+
| StrOutputParser() # Parse the output as a plain text SQL query
|
| 66 |
)
|
| 67 |
|
| 68 |
+
# Function to generate a natural language response based on SQL query and database results
|
| 69 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
| 70 |
+
# Generate the SQL query using the SQL chain
|
| 71 |
sql_chain = get_sql_chain(db)
|
| 72 |
|
| 73 |
+
# Define the template for generating natural language responses
|
| 74 |
template = """
|
| 75 |
+
You are a data analyst. Based on the schema, SQL query, and response, write a natural language response.
|
| 76 |
<SCHEMA>{schema}</SCHEMA>
|
| 77 |
Conversation History: {chat_history}
|
| 78 |
SQL Query: <SQL>{query}</SQL>
|
| 79 |
+
User Question: {question}
|
| 80 |
SQL Response: {response}
|
| 81 |
"""
|
| 82 |
|
| 83 |
+
# Create the prompt template for generating a natural language response
|
| 84 |
prompt = ChatPromptTemplate.from_template(template)
|
| 85 |
# Initialize the Groq model for response generation
|
| 86 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
| 87 |
|
| 88 |
+
# Chain the steps: Generate SQL, execute it, and generate the natural language response
|
|
|
|
|
|
|
|
|
|
| 89 |
chain = (
|
| 90 |
RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
|
| 91 |
.assign(
|
| 92 |
+
schema=lambda _: db.get_table_info(), # Pass the schema for query execution
|
| 93 |
+
response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the query and clean the output
|
| 94 |
)
|
| 95 |
+
| prompt # Generate a response from the AI using the prompt
|
| 96 |
+
| llm # Get the final response from the model
|
| 97 |
| StrOutputParser() # Parse the output into plain text
|
| 98 |
)
|
| 99 |
|
| 100 |
+
# Invoke the chain to generate the response
|
| 101 |
result = chain.invoke({
|
| 102 |
"question": user_query,
|
| 103 |
"chat_history": chat_history,
|
|
|
|
| 110 |
sql_query = result.get('query', 'No query generated')
|
| 111 |
print(f"SQL Query: {sql_query}")
|
| 112 |
|
| 113 |
+
# Return the final natural language response
|
| 114 |
return result
|
| 115 |
|
| 116 |
+
# Initialize the chat session in Streamlit
|
| 117 |
if "chat_history" not in st.session_state:
|
| 118 |
st.session_state.chat_history = [
|
| 119 |
+
# Initial greeting from the AI assistant
|
| 120 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
| 121 |
]
|
| 122 |
|
| 123 |
+
# Configure the Streamlit page title and icon
|
| 124 |
st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
|
| 125 |
+
st.title("Chat with PostgreSQL") # Display the main title on the web page
|
| 126 |
|
| 127 |
+
# Sidebar for database connection settings
|
| 128 |
with st.sidebar:
|
| 129 |
+
st.subheader("Settings") # Display settings section header
|
| 130 |
+
st.write("Connect to your database and start chatting.") # Instructions for users
|
| 131 |
|
| 132 |
# Input fields for database connection details (host, port, user, password, and database name)
|
| 133 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
|
|
|
| 136 |
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
|
| 137 |
database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
|
| 138 |
|
| 139 |
+
# Button to initiate database connection
|
| 140 |
if st.button("Connect"):
|
| 141 |
+
with st.spinner("Connecting to the database..."):
|
| 142 |
+
db = init_database() # Attempt to connect to the database
|
| 143 |
if db:
|
| 144 |
+
st.session_state.db = db # Save the connection to session state
|
| 145 |
st.success("Connected to the database!") # Display success message
|
| 146 |
else:
|
| 147 |
+
st.error("Connection failed. Please check your settings.") # Error message if connection fails
|
| 148 |
|
| 149 |
# Display the chat history (both AI and user messages)
|
| 150 |
for message in st.session_state.chat_history:
|
| 151 |
if isinstance(message, AIMessage):
|
| 152 |
+
with st.chat_message("AI"):
|
| 153 |
st.markdown(message.content)
|
| 154 |
elif isinstance(message, HumanMessage):
|
| 155 |
+
with st.chat_message("Human"):
|
| 156 |
st.markdown(message.content)
|
| 157 |
|
| 158 |
+
# Input field for the user to type a message
|
| 159 |
+
user_query = st.chat_input("Type a message...")
|
| 160 |
+
if user_query and user_query.strip():
|
| 161 |
+
st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save user query
|
| 162 |
+
|
| 163 |
+
with st.chat_message("Human"): # Display user's message in chat
|
| 164 |
st.markdown(user_query)
|
| 165 |
|
| 166 |
+
with st.chat_message("AI"): # Generate and display AI's response
|
| 167 |
+
response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get AI response
|
| 168 |
st.markdown(response)
|
| 169 |
|
| 170 |
+
st.session_state.chat_history.append(AIMessage(content=response)) # Save AI's response to chat history
|