Spaces:
Sleeping
Sleeping
Suresh Beekhani
commited on
Update app.py
Browse files- src/app.py +64 -70
src/app.py
CHANGED
@@ -1,44 +1,44 @@
|
|
1 |
# Import necessary libraries and modules for various tasks
|
2 |
-
from dotenv import load_dotenv # For loading environment variables from a .env file (
|
3 |
-
from langchain_core.messages import AIMessage, HumanMessage # For handling
|
4 |
-
from langchain_core.prompts import ChatPromptTemplate # To create templates
|
5 |
-
from langchain_core.runnables import RunnablePassthrough # To
|
6 |
-
from langchain_community.utilities import SQLDatabase #
|
7 |
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
|
8 |
-
from langchain_groq import ChatGroq #
|
9 |
-
import streamlit as st # Streamlit
|
10 |
-
import os #
|
11 |
-
import psycopg2 #
|
12 |
|
13 |
-
# Load environment variables (
|
14 |
load_dotenv()
|
15 |
|
16 |
-
# Function to
|
17 |
def init_database() -> SQLDatabase:
|
18 |
try:
|
19 |
-
# Retrieve
|
20 |
user = os.getenv("DB_USER", "postgres")
|
21 |
password = os.getenv("DB_PASSWORD", "beekhani143")
|
22 |
host = os.getenv("DB_HOST", "localhost")
|
23 |
port = os.getenv("DB_PORT", "5432")
|
24 |
database = os.getenv("DB_NAME", "")
|
25 |
|
26 |
-
# Construct the database URI
|
27 |
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
28 |
|
29 |
-
# Connect to the database using
|
30 |
return SQLDatabase.from_uri(db_uri)
|
31 |
except Exception as e:
|
32 |
-
# If connection fails, display
|
33 |
-
st.error(f"Failed to connect to database: {e}")
|
34 |
return None
|
35 |
|
36 |
-
# Function to
|
37 |
def get_sql_chain(db):
|
38 |
-
#
|
39 |
template = """
|
40 |
-
You are a data analyst
|
41 |
-
Based on the table schema below, write a SQL query that
|
42 |
|
43 |
<SCHEMA>{schema}</SCHEMA>
|
44 |
Conversation History: {chat_history}
|
@@ -48,62 +48,56 @@ def get_sql_chain(db):
|
|
48 |
SQL Query:
|
49 |
"""
|
50 |
|
51 |
-
# Create
|
52 |
prompt = ChatPromptTemplate.from_template(template)
|
53 |
-
# Initialize the Groq model for generating responses
|
54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
55 |
|
56 |
-
# Function to
|
57 |
def get_schema(_):
|
58 |
return db.get_table_info()
|
59 |
|
60 |
-
#
|
61 |
-
# 1. First, get the database schema.
|
62 |
-
# 2. Then, use the prompt template to guide query creation.
|
63 |
-
# 3. Finally, parse the output as a plain text SQL query.
|
64 |
return (
|
65 |
-
RunnablePassthrough.assign(schema=get_schema) # Pass
|
66 |
-
| prompt # Use the prompt template
|
67 |
-
| llm # Generate
|
68 |
-
| StrOutputParser() # Parse the
|
69 |
)
|
70 |
|
71 |
-
# Function to generate a natural language response based on SQL query and database
|
72 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
73 |
-
#
|
74 |
sql_chain = get_sql_chain(db)
|
75 |
|
76 |
-
#
|
77 |
template = """
|
78 |
-
You are a data analyst
|
79 |
<SCHEMA>{schema}</SCHEMA>
|
80 |
Conversation History: {chat_history}
|
81 |
SQL Query: <SQL>{query}</SQL>
|
82 |
-
User
|
83 |
SQL Response: {response}
|
84 |
"""
|
85 |
|
86 |
-
# Create
|
87 |
prompt = ChatPromptTemplate.from_template(template)
|
88 |
# Initialize the Groq model for response generation
|
89 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
90 |
|
91 |
-
# Chain the
|
92 |
-
# 1. Generate the SQL query using the earlier chain.
|
93 |
-
# 2. Get the schema and execute the query on the database.
|
94 |
-
# 3. Return the natural language response based on the query and its results.
|
95 |
chain = (
|
96 |
RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
|
97 |
.assign(
|
98 |
-
schema=lambda _: db.get_table_info(), # Pass the schema
|
99 |
-
response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the
|
100 |
)
|
101 |
-
| prompt #
|
102 |
-
| llm #
|
103 |
| StrOutputParser() # Parse the output into plain text
|
104 |
)
|
105 |
|
106 |
-
# Invoke the chain to generate the
|
107 |
result = chain.invoke({
|
108 |
"question": user_query,
|
109 |
"chat_history": chat_history,
|
@@ -116,24 +110,24 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
|
116 |
sql_query = result.get('query', 'No query generated')
|
117 |
print(f"SQL Query: {sql_query}")
|
118 |
|
119 |
-
# Return the
|
120 |
return result
|
121 |
|
122 |
-
# Initialize the chat session
|
123 |
if "chat_history" not in st.session_state:
|
124 |
st.session_state.chat_history = [
|
125 |
-
#
|
126 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
127 |
]
|
128 |
|
129 |
-
#
|
130 |
st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
|
131 |
-
st.title("Chat with PostgreSQL") # Display title on the
|
132 |
|
133 |
-
# Sidebar
|
134 |
with st.sidebar:
|
135 |
-
st.subheader("Settings") # Display
|
136 |
-
st.write("Connect to your database and start chatting.") #
|
137 |
|
138 |
# Input fields for database connection details (host, port, user, password, and database name)
|
139 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
@@ -142,35 +136,35 @@ with st.sidebar:
|
|
142 |
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
|
143 |
database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
|
144 |
|
145 |
-
# Button to
|
146 |
if st.button("Connect"):
|
147 |
-
with st.spinner("Connecting to database..."):
|
148 |
-
db = init_database() #
|
149 |
if db:
|
150 |
-
st.session_state.db = db # Save the connection
|
151 |
st.success("Connected to the database!") # Display success message
|
152 |
else:
|
153 |
-
st.error("Connection failed. Please check your settings.") #
|
154 |
|
155 |
# Display the chat history (both AI and user messages)
|
156 |
for message in st.session_state.chat_history:
|
157 |
if isinstance(message, AIMessage):
|
158 |
-
with st.chat_message("AI"):
|
159 |
st.markdown(message.content)
|
160 |
elif isinstance(message, HumanMessage):
|
161 |
-
with st.chat_message("Human"):
|
162 |
st.markdown(message.content)
|
163 |
|
164 |
-
# Input field for the user to type
|
165 |
-
user_query = st.chat_input("Type a message...")
|
166 |
-
if user_query and user_query.strip():
|
167 |
-
st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save
|
168 |
-
|
169 |
-
with st.chat_message("Human"): # Display
|
170 |
st.markdown(user_query)
|
171 |
|
172 |
-
with st.chat_message("AI"): # Generate and display
|
173 |
-
response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get
|
174 |
st.markdown(response)
|
175 |
|
176 |
-
st.session_state.chat_history.append(AIMessage(content=response)) #
|
|
|
1 |
# Import necessary libraries and modules for various tasks
|
2 |
+
from dotenv import load_dotenv # For loading environment variables from a .env file (e.g., database credentials)
|
3 |
+
from langchain_core.messages import AIMessage, HumanMessage # For handling AI and user messages
|
4 |
+
from langchain_core.prompts import ChatPromptTemplate # To create templates for chatbot responses
|
5 |
+
from langchain_core.runnables import RunnablePassthrough # To chain different operations (e.g., inputs/outputs)
|
6 |
+
from langchain_community.utilities import SQLDatabase # Utility to connect to SQL databases using LangChain
|
7 |
from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
|
8 |
+
from langchain_groq import ChatGroq # Integrates the Groq model for generating chat responses
|
9 |
+
import streamlit as st # Streamlit for building the web interface
|
10 |
+
import os # For accessing environment variables (e.g., credentials or other settings)
|
11 |
+
import psycopg2 # PostgreSQL database adapter for database connections
|
12 |
|
13 |
+
# Load environment variables (e.g., database credentials) from the .env file
|
14 |
load_dotenv()
|
15 |
|
16 |
+
# Function to initialize the database connection
|
17 |
def init_database() -> SQLDatabase:
|
18 |
try:
|
19 |
+
# Retrieve connection details from environment variables or set defaults
|
20 |
user = os.getenv("DB_USER", "postgres")
|
21 |
password = os.getenv("DB_PASSWORD", "beekhani143")
|
22 |
host = os.getenv("DB_HOST", "localhost")
|
23 |
port = os.getenv("DB_PORT", "5432")
|
24 |
database = os.getenv("DB_NAME", "")
|
25 |
|
26 |
+
# Construct the database URI for PostgreSQL connection
|
27 |
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
28 |
|
29 |
+
# Connect to the database using SQLDatabase utility and return the connection
|
30 |
return SQLDatabase.from_uri(db_uri)
|
31 |
except Exception as e:
|
32 |
+
# If connection fails, display error in the Streamlit UI
|
33 |
+
st.error(f"Failed to connect to the database: {e}")
|
34 |
return None
|
35 |
|
36 |
+
# Function to generate the SQL query chain based on user input and conversation history
|
37 |
def get_sql_chain(db):
|
38 |
+
# Define the prompt template for generating SQL queries based on schema and chat history
|
39 |
template = """
|
40 |
+
You are a data analyst. You are interacting with a user who is asking questions about the company's database.
|
41 |
+
Based on the table schema below, write a SQL query that answers the user's question. Consider the conversation history.
|
42 |
|
43 |
<SCHEMA>{schema}</SCHEMA>
|
44 |
Conversation History: {chat_history}
|
|
|
48 |
SQL Query:
|
49 |
"""
|
50 |
|
51 |
+
# Create the prompt template from the instructions above
|
52 |
prompt = ChatPromptTemplate.from_template(template)
|
53 |
+
# Initialize the Groq model for generating SQL responses (deterministic output with temperature=0)
|
54 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
55 |
|
56 |
+
# Function to retrieve the database schema (table structure)
|
57 |
def get_schema(_):
|
58 |
return db.get_table_info()
|
59 |
|
60 |
+
# Build a chain of operations: get the schema, generate SQL query, and parse as plain text
|
|
|
|
|
|
|
61 |
return (
|
62 |
+
RunnablePassthrough.assign(schema=get_schema) # Pass schema to the chain
|
63 |
+
| prompt # Use the prompt template to guide query creation
|
64 |
+
| llm # Generate SQL query using the Groq model
|
65 |
+
| StrOutputParser() # Parse the output as a plain text SQL query
|
66 |
)
|
67 |
|
68 |
+
# Function to generate a natural language response based on SQL query and database results
|
69 |
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
|
70 |
+
# Generate the SQL query using the SQL chain
|
71 |
sql_chain = get_sql_chain(db)
|
72 |
|
73 |
+
# Define the template for generating natural language responses
|
74 |
template = """
|
75 |
+
You are a data analyst. Based on the schema, SQL query, and response, write a natural language response.
|
76 |
<SCHEMA>{schema}</SCHEMA>
|
77 |
Conversation History: {chat_history}
|
78 |
SQL Query: <SQL>{query}</SQL>
|
79 |
+
User Question: {question}
|
80 |
SQL Response: {response}
|
81 |
"""
|
82 |
|
83 |
+
# Create the prompt template for generating a natural language response
|
84 |
prompt = ChatPromptTemplate.from_template(template)
|
85 |
# Initialize the Groq model for response generation
|
86 |
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
87 |
|
88 |
+
# Chain the steps: Generate SQL, execute it, and generate the natural language response
|
|
|
|
|
|
|
89 |
chain = (
|
90 |
RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
|
91 |
.assign(
|
92 |
+
schema=lambda _: db.get_table_info(), # Pass the schema for query execution
|
93 |
+
response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the query and clean the output
|
94 |
)
|
95 |
+
| prompt # Generate a response from the AI using the prompt
|
96 |
+
| llm # Get the final response from the model
|
97 |
| StrOutputParser() # Parse the output into plain text
|
98 |
)
|
99 |
|
100 |
+
# Invoke the chain to generate the response
|
101 |
result = chain.invoke({
|
102 |
"question": user_query,
|
103 |
"chat_history": chat_history,
|
|
|
110 |
sql_query = result.get('query', 'No query generated')
|
111 |
print(f"SQL Query: {sql_query}")
|
112 |
|
113 |
+
# Return the final natural language response
|
114 |
return result
|
115 |
|
116 |
+
# Initialize the chat session in Streamlit
|
117 |
if "chat_history" not in st.session_state:
|
118 |
st.session_state.chat_history = [
|
119 |
+
# Initial greeting from the AI assistant
|
120 |
AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
|
121 |
]
|
122 |
|
123 |
+
# Configure the Streamlit page title and icon
|
124 |
st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
|
125 |
+
st.title("Chat with PostgreSQL") # Display the main title on the web page
|
126 |
|
127 |
+
# Sidebar for database connection settings
|
128 |
with st.sidebar:
|
129 |
+
st.subheader("Settings") # Display settings section header
|
130 |
+
st.write("Connect to your database and start chatting.") # Instructions for users
|
131 |
|
132 |
# Input fields for database connection details (host, port, user, password, and database name)
|
133 |
host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
|
|
|
136 |
password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
|
137 |
database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
|
138 |
|
139 |
+
# Button to initiate database connection
|
140 |
if st.button("Connect"):
|
141 |
+
with st.spinner("Connecting to the database..."):
|
142 |
+
db = init_database() # Attempt to connect to the database
|
143 |
if db:
|
144 |
+
st.session_state.db = db # Save the connection to session state
|
145 |
st.success("Connected to the database!") # Display success message
|
146 |
else:
|
147 |
+
st.error("Connection failed. Please check your settings.") # Error message if connection fails
|
148 |
|
149 |
# Display the chat history (both AI and user messages)
|
150 |
for message in st.session_state.chat_history:
|
151 |
if isinstance(message, AIMessage):
|
152 |
+
with st.chat_message("AI"):
|
153 |
st.markdown(message.content)
|
154 |
elif isinstance(message, HumanMessage):
|
155 |
+
with st.chat_message("Human"):
|
156 |
st.markdown(message.content)
|
157 |
|
158 |
+
# Input field for the user to type a message
|
159 |
+
user_query = st.chat_input("Type a message...")
|
160 |
+
if user_query and user_query.strip():
|
161 |
+
st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save user query
|
162 |
+
|
163 |
+
with st.chat_message("Human"): # Display user's message in chat
|
164 |
st.markdown(user_query)
|
165 |
|
166 |
+
with st.chat_message("AI"): # Generate and display AI's response
|
167 |
+
response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get AI response
|
168 |
st.markdown(response)
|
169 |
|
170 |
+
st.session_state.chat_history.append(AIMessage(content=response)) # Save AI's response to chat history
|