Spaces:
Sleeping
Sleeping
Suresh Beekhani
commited on
Update app.py
Browse files- src/app.py +97 -93
src/app.py
CHANGED
@@ -1,40 +1,41 @@
|
|
1 |
-
# Import necessary libraries and modules
|
2 |
-
from dotenv import load_dotenv
|
3 |
-
from langchain_core.messages import AIMessage, HumanMessage #
|
4 |
-
from langchain_core.prompts import ChatPromptTemplate #
|
5 |
-
from langchain_core.runnables import RunnablePassthrough # To
|
6 |
-
from langchain_community.utilities import SQLDatabase # SQL
|
7 |
-
from langchain_core.output_parsers import StrOutputParser # To parse outputs
|
8 |
-
#
|
9 |
-
|
10 |
-
import
|
11 |
-
import
|
12 |
-
|
13 |
-
# Load environment variables from the .env file
|
14 |
load_dotenv()
|
15 |
|
16 |
-
# Function to
|
17 |
def init_database() -> SQLDatabase:
|
18 |
try:
|
19 |
-
#
|
20 |
-
user = os.getenv("DB_USER", "
|
21 |
-
password = os.getenv("DB_PASSWORD", "
|
22 |
host = os.getenv("DB_HOST", "localhost")
|
23 |
-
port = os.getenv("DB_PORT", "
|
24 |
-
database = os.getenv("DB_NAME", "
|
25 |
|
26 |
-
# Construct the database URI
|
27 |
-
db_uri = f"
|
28 |
|
29 |
-
#
|
30 |
return SQLDatabase.from_uri(db_uri)
|
31 |
except Exception as e:
|
|
|
32 |
st.error(f"Failed to connect to database: {e}")
|
33 |
return None
|
34 |
|
35 |
-
# Function to create a chain that generates SQL queries
|
36 |
def get_sql_chain(db):
|
37 |
-
# SQL
|
38 |
template = """
|
39 |
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
|
40 |
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
|
@@ -47,33 +48,32 @@ def get_sql_chain(db):
|
|
47 |
SQL Query:
|
48 |
"""
|
49 |
|
50 |
-
# Create a prompt from the above
|
51 |
prompt = ChatPromptTemplate.from_template(template)
|
52 |
-
|
53 |
-
# Initialize Groq model for generating SQL queries (can switch to OpenAI if needed)
|
54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
55 |
|
56 |
-
#
|
57 |
def get_schema(_):
|
58 |
return db.get_table_info()
|
59 |
|
60 |
-
#
|
61 |
-
# 1.
|
62 |
-
# 2.
|
63 |
-
# 3.
|
64 |
return (
|
65 |
-
RunnablePassthrough.assign(schema=get_schema) #
|
66 |
-
| prompt #
|
67 |
-
| llm #
|
68 |
-
| StrOutputParser() # Parse the
|
69 |
)
|
70 |
|
71 |
-
# Function to generate a
|
72 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
73 |
-
#
|
74 |
sql_chain = get_sql_chain(db)
|
75 |
-
|
76 |
-
#
|
77 |
template = """
|
78 |
You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
|
79 |
<SCHEMA>{schema}</SCHEMA>
|
@@ -83,90 +83,94 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
|
83 |
SQL Response: {response}
|
84 |
"""
|
85 |
|
86 |
-
# Create a
|
87 |
prompt = ChatPromptTemplate.from_template(template)
|
88 |
-
|
89 |
-
# Initialize Groq model (alternative: OpenAI)
|
90 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
91 |
-
|
92 |
-
#
|
|
|
|
|
|
|
93 |
chain = (
|
94 |
-
RunnablePassthrough.assign(query=sql_chain)
|
95 |
-
|
96 |
-
|
|
|
97 |
)
|
98 |
-
| prompt # Use prompt
|
99 |
-
| llm #
|
100 |
-
| StrOutputParser() # Parse the
|
101 |
)
|
102 |
-
|
103 |
-
#
|
104 |
-
|
105 |
"question": user_query,
|
106 |
"chat_history": chat_history,
|
107 |
})
|
108 |
|
109 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
if "chat_history" not in st.session_state:
|
111 |
-
# Initialize chat history with a welcome message from AI
|
112 |
st.session_state.chat_history = [
|
|
|
113 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
114 |
]
|
115 |
|
116 |
-
#
|
117 |
-
st.set_page_config(page_title="Chat with
|
118 |
-
|
119 |
-
# Streamlit app title
|
120 |
-
st.title("Chat with MySQL")
|
121 |
|
122 |
-
# Sidebar for database connection settings
|
123 |
with st.sidebar:
|
124 |
-
st.subheader("Settings")
|
125 |
-
st.write("Connect to your database and start chatting.")
|
126 |
|
127 |
-
#
|
128 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
129 |
-
port = st.text_input("Port", value=os.getenv("DB_PORT", "
|
130 |
-
user = st.text_input("User", value=os.getenv("DB_USER", "
|
131 |
-
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "
|
132 |
-
database = st.text_input("Database", value=os.getenv("DB_NAME", "
|
133 |
|
134 |
-
# Button to
|
135 |
if st.button("Connect"):
|
136 |
-
with st.spinner("Connecting to database..."):
|
137 |
-
#
|
138 |
-
db = init_database()
|
139 |
if db:
|
140 |
-
st.session_state.db = db
|
141 |
-
st.success("Connected to the database!")
|
142 |
else:
|
143 |
-
st.error("Connection failed. Please check your settings.")
|
144 |
|
145 |
-
# Display chat history
|
146 |
for message in st.session_state.chat_history:
|
147 |
if isinstance(message, AIMessage):
|
148 |
-
# Display AI
|
149 |
-
with st.chat_message("AI"):
|
150 |
st.markdown(message.content)
|
151 |
elif isinstance(message, HumanMessage):
|
152 |
-
# Display human
|
153 |
-
with st.chat_message("Human"):
|
154 |
st.markdown(message.content)
|
155 |
|
156 |
-
# Input field for user
|
157 |
-
user_query = st.chat_input("Type a message...")
|
158 |
-
if user_query and user_query.strip():
|
159 |
-
#
|
160 |
-
st.session_state.chat_history.append(HumanMessage(content=user_query))
|
161 |
|
162 |
-
# Display user's message in the chat
|
163 |
-
with st.chat_message("Human"):
|
164 |
st.markdown(user_query)
|
165 |
|
166 |
-
# Generate and display AI
|
167 |
-
|
168 |
-
response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
|
169 |
st.markdown(response)
|
170 |
|
171 |
-
# Add AI's response to the chat history
|
172 |
-
st.session_state.chat_history.append(AIMessage(content=response))
|
|
|
1 |
+
# Import necessary libraries and modules for various tasks
|
2 |
+
from dotenv import load_dotenv # For loading environment variables from a .env file (such as database credentials)
|
3 |
+
from langchain_core.messages import AIMessage, HumanMessage # For handling messages from the AI and user
|
4 |
+
from langchain_core.prompts import ChatPromptTemplate # To create templates that will guide the chatbot's responses
|
5 |
+
from langchain_core.runnables import RunnablePassthrough # To enable chaining of different operations (like inputs/outputs)
|
6 |
+
from langchain_community.utilities import SQLDatabase # A tool to help connect to SQL databases using LangChain
|
7 |
+
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
|
8 |
+
from langchain_groq import ChatGroq # This integrates the Groq model for generating chat responses
|
9 |
+
import streamlit as st # Streamlit is used for building the web app (user interface)
|
10 |
+
import os # To access environment variables (e.g., credentials or other settings)
|
11 |
+
import psycopg2 # A PostgreSQL database adapter to enable connections to the database
|
12 |
+
|
13 |
+
# Load environment variables (such as DB credentials) from the .env file
|
14 |
load_dotenv()
|
15 |
|
16 |
+
# Function to establish a connection to the PostgreSQL database
|
17 |
def init_database() -> SQLDatabase:
|
18 |
try:
|
19 |
+
# Retrieve database connection details from environment variables, or set default values
|
20 |
+
user = os.getenv("DB_USER", "postgres")
|
21 |
+
password = os.getenv("DB_PASSWORD", "beekhani143")
|
22 |
host = os.getenv("DB_HOST", "localhost")
|
23 |
+
port = os.getenv("DB_PORT", "5432")
|
24 |
+
database = os.getenv("DB_NAME", "")
|
25 |
|
26 |
+
# Construct the database URI (a URL-like string) with the necessary credentials for PostgreSQL
|
27 |
+
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
28 |
|
29 |
+
# Connect to the database using the SQLDatabase utility and return the instance
|
30 |
return SQLDatabase.from_uri(db_uri)
|
31 |
except Exception as e:
|
32 |
+
# If connection fails, display an error message on the Streamlit UI
|
33 |
st.error(f"Failed to connect to database: {e}")
|
34 |
return None
|
35 |
|
36 |
+
# Function to create a process (chain) that generates SQL queries based on user input and previous conversation
|
37 |
def get_sql_chain(db):
|
38 |
+
# Template to guide how SQL queries are generated. The bot receives table schema and conversation history.
|
39 |
template = """
|
40 |
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
|
41 |
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
|
|
|
48 |
SQL Query:
|
49 |
"""
|
50 |
|
51 |
+
# Create a prompt template from the above instructions
|
52 |
prompt = ChatPromptTemplate.from_template(template)
|
53 |
+
# Initialize the Groq model for generating responses with low randomness (temperature=0 for more deterministic outputs)
|
|
|
54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
55 |
|
56 |
+
# Function to get the schema (structure) of the tables in the database
|
57 |
def get_schema(_):
|
58 |
return db.get_table_info()
|
59 |
|
60 |
+
# Create a chain of operations:
|
61 |
+
# 1. First, get the database schema.
|
62 |
+
# 2. Then, use the prompt template to guide query creation.
|
63 |
+
# 3. Finally, parse the output as a plain text SQL query.
|
64 |
return (
|
65 |
+
RunnablePassthrough.assign(schema=get_schema) # Pass the schema into the chain
|
66 |
+
| prompt # Use the prompt template
|
67 |
+
| llm # Generate a response using the Groq model
|
68 |
+
| StrOutputParser() # Parse the response as a string (SQL query)
|
69 |
)
|
70 |
|
71 |
+
# Function to generate a natural language response based on SQL query and database result
|
72 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
73 |
+
# First, get the SQL chain (responsible for generating SQL queries)
|
74 |
sql_chain = get_sql_chain(db)
|
75 |
+
|
76 |
+
# Template to guide how the AI responds to the user's query based on the SQL results
|
77 |
template = """
|
78 |
You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
|
79 |
<SCHEMA>{schema}</SCHEMA>
|
|
|
83 |
SQL Response: {response}
|
84 |
"""
|
85 |
|
86 |
+
# Create a new prompt template for generating a response
|
87 |
prompt = ChatPromptTemplate.from_template(template)
|
88 |
+
# Initialize the Groq model for response generation
|
|
|
89 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
90 |
+
|
91 |
+
# Chain the following tasks:
|
92 |
+
# 1. Generate the SQL query using the earlier chain.
|
93 |
+
# 2. Get the schema and execute the query on the database.
|
94 |
+
# 3. Return the natural language response based on the query and its results.
|
95 |
chain = (
|
96 |
+
RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
|
97 |
+
.assign(
|
98 |
+
schema=lambda _: db.get_table_info(), # Pass the schema to the next step
|
99 |
+
response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the SQL query and clean up backslashes
|
100 |
)
|
101 |
+
| prompt # Use the prompt template for generating natural language response
|
102 |
+
| llm # Generate the final response using the model
|
103 |
+
| StrOutputParser() # Parse the output into plain text
|
104 |
)
|
105 |
+
|
106 |
+
# Invoke the chain to generate the final response based on the user query and history
|
107 |
+
result = chain.invoke({
|
108 |
"question": user_query,
|
109 |
"chat_history": chat_history,
|
110 |
})
|
111 |
|
112 |
+
# Debugging: Print the SQL query being executed
|
113 |
+
if isinstance(result, str):
|
114 |
+
print(f"SQL Query: {result}")
|
115 |
+
else:
|
116 |
+
sql_query = result.get('query', 'No query generated')
|
117 |
+
print(f"SQL Query: {sql_query}")
|
118 |
+
|
119 |
+
# Return the result (natural language response)
|
120 |
+
return result
|
121 |
+
|
122 |
+
# Initialize the chat session when Streamlit app starts
|
123 |
if "chat_history" not in st.session_state:
|
|
|
124 |
st.session_state.chat_history = [
|
125 |
+
# First message from AI assistant
|
126 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
127 |
]
|
128 |
|
129 |
+
# Streamlit page configuration: Set the page title and icon
|
130 |
+
st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
|
131 |
+
st.title("Chat with PostgreSQL") # Display title on the webpage
|
|
|
|
|
132 |
|
133 |
+
# Sidebar configuration for database connection settings
|
134 |
with st.sidebar:
|
135 |
+
st.subheader("Settings") # Display a heading for the settings section
|
136 |
+
st.write("Connect to your database and start chatting.") # Instruction text for users
|
137 |
|
138 |
+
# Input fields for database connection details (host, port, user, password, and database name)
|
139 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
140 |
+
port = st.text_input("Port", value=os.getenv("DB_PORT", "5432"))
|
141 |
+
user = st.text_input("User", value=os.getenv("DB_USER", "postgres"))
|
142 |
+
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
|
143 |
+
database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
|
144 |
|
145 |
+
# Button to attempt database connection
|
146 |
if st.button("Connect"):
|
147 |
+
with st.spinner("Connecting to database..."): # Show a spinner while connecting
|
148 |
+
db = init_database() # Call the function to connect to the database
|
|
|
149 |
if db:
|
150 |
+
st.session_state.db = db # Save the connection in session state
|
151 |
+
st.success("Connected to the database!") # Display success message
|
152 |
else:
|
153 |
+
st.error("Connection failed. Please check your settings.") # Display error message if connection fails
|
154 |
|
155 |
+
# Display the chat history (both AI and user messages)
|
156 |
for message in st.session_state.chat_history:
|
157 |
if isinstance(message, AIMessage):
|
158 |
+
with st.chat_message("AI"): # Display AI messages in the chat
|
|
|
159 |
st.markdown(message.content)
|
160 |
elif isinstance(message, HumanMessage):
|
161 |
+
with st.chat_message("Human"): # Display human messages in the chat
|
|
|
162 |
st.markdown(message.content)
|
163 |
|
164 |
+
# Input field for the user to type their message
|
165 |
+
user_query = st.chat_input("Type a message...") # Field to capture user query
|
166 |
+
if user_query and user_query.strip(): # If the user entered a valid query
|
167 |
+
st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save the user query in chat history
|
|
|
168 |
|
169 |
+
with st.chat_message("Human"): # Display the user's message in the chat
|
|
|
170 |
st.markdown(user_query)
|
171 |
|
172 |
+
with st.chat_message("AI"): # Generate and display the AI response
|
173 |
+
response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get the AI's response
|
|
|
174 |
st.markdown(response)
|
175 |
|
176 |
+
st.session_state.chat_history.append(AIMessage(content=response)) # Add the AI's response to the chat history
|
|