Suresh Beekhani commited on
Commit
d39a66d
·
unverified ·
1 Parent(s): 55fd932

Update app.py

Browse files
Files changed (1) hide show
  1. src/app.py +64 -70
src/app.py CHANGED
@@ -1,44 +1,44 @@
1
  # Import necessary libraries and modules for various tasks
2
- from dotenv import load_dotenv # For loading environment variables from a .env file (such as database credentials)
3
- from langchain_core.messages import AIMessage, HumanMessage # For handling messages from the AI and user
4
- from langchain_core.prompts import ChatPromptTemplate # To create templates that will guide the chatbot's responses
5
- from langchain_core.runnables import RunnablePassthrough # To enable chaining of different operations (like inputs/outputs)
6
- from langchain_community.utilities import SQLDatabase # A tool to help connect to SQL databases using LangChain
7
  from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
8
- from langchain_groq import ChatGroq # This integrates the Groq model for generating chat responses
9
- import streamlit as st # Streamlit is used for building the web app (user interface)
10
- import os # To access environment variables (e.g., credentials or other settings)
11
- import psycopg2 # A PostgreSQL database adapter to enable connections to the database
12
 
13
- # Load environment variables (such as DB credentials) from the .env file
14
  load_dotenv()
15
 
16
- # Function to establish a connection to the PostgreSQL database
17
  def init_database() -> SQLDatabase:
18
  try:
19
- # Retrieve database connection details from environment variables, or set default values
20
  user = os.getenv("DB_USER", "postgres")
21
  password = os.getenv("DB_PASSWORD", "beekhani143")
22
  host = os.getenv("DB_HOST", "localhost")
23
  port = os.getenv("DB_PORT", "5432")
24
  database = os.getenv("DB_NAME", "")
25
 
26
- # Construct the database URI (a URL-like string) with the necessary credentials for PostgreSQL
27
  db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
28
 
29
- # Connect to the database using the SQLDatabase utility and return the instance
30
  return SQLDatabase.from_uri(db_uri)
31
  except Exception as e:
32
- # If connection fails, display an error message on the Streamlit UI
33
- st.error(f"Failed to connect to database: {e}")
34
  return None
35
 
36
- # Function to create a process (chain) that generates SQL queries based on user input and previous conversation
37
  def get_sql_chain(db):
38
- # Template to guide how SQL queries are generated. The bot receives table schema and conversation history.
39
  template = """
40
- You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
41
- Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
42
 
43
  <SCHEMA>{schema}</SCHEMA>
44
  Conversation History: {chat_history}
@@ -48,62 +48,56 @@ def get_sql_chain(db):
48
  SQL Query:
49
  """
50
 
51
- # Create a prompt template from the above instructions
52
  prompt = ChatPromptTemplate.from_template(template)
53
- # Initialize the Groq model for generating responses with low randomness (temperature=0 for more deterministic outputs)
54
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
55
 
56
- # Function to get the schema (structure) of the tables in the database
57
  def get_schema(_):
58
  return db.get_table_info()
59
 
60
- # Create a chain of operations:
61
- # 1. First, get the database schema.
62
- # 2. Then, use the prompt template to guide query creation.
63
- # 3. Finally, parse the output as a plain text SQL query.
64
  return (
65
- RunnablePassthrough.assign(schema=get_schema) # Pass the schema into the chain
66
- | prompt # Use the prompt template
67
- | llm # Generate a response using the Groq model
68
- | StrOutputParser() # Parse the response as a string (SQL query)
69
  )
70
 
71
- # Function to generate a natural language response based on SQL query and database result
72
  def get_response(user_query: str, db: SQLDatabase, chat_history: list):
73
- # First, get the SQL chain (responsible for generating SQL queries)
74
  sql_chain = get_sql_chain(db)
75
 
76
- # Template to guide how the AI responds to the user's query based on the SQL results
77
  template = """
78
- You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
79
  <SCHEMA>{schema}</SCHEMA>
80
  Conversation History: {chat_history}
81
  SQL Query: <SQL>{query}</SQL>
82
- User question: {question}
83
  SQL Response: {response}
84
  """
85
 
86
- # Create a new prompt template for generating a response
87
  prompt = ChatPromptTemplate.from_template(template)
88
  # Initialize the Groq model for response generation
89
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
90
 
91
- # Chain the following tasks:
92
- # 1. Generate the SQL query using the earlier chain.
93
- # 2. Get the schema and execute the query on the database.
94
- # 3. Return the natural language response based on the query and its results.
95
  chain = (
96
  RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
97
  .assign(
98
- schema=lambda _: db.get_table_info(), # Pass the schema to the next step
99
- response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the SQL query and clean up backslashes
100
  )
101
- | prompt # Use the prompt template for generating natural language response
102
- | llm # Generate the final response using the model
103
  | StrOutputParser() # Parse the output into plain text
104
  )
105
 
106
- # Invoke the chain to generate the final response based on the user query and history
107
  result = chain.invoke({
108
  "question": user_query,
109
  "chat_history": chat_history,
@@ -116,24 +110,24 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
116
  sql_query = result.get('query', 'No query generated')
117
  print(f"SQL Query: {sql_query}")
118
 
119
- # Return the result (natural language response)
120
  return result
121
 
122
- # Initialize the chat session when Streamlit app starts
123
  if "chat_history" not in st.session_state:
124
  st.session_state.chat_history = [
125
- # First message from AI assistant
126
  AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
127
  ]
128
 
129
- # Streamlit page configuration: Set the page title and icon
130
  st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
131
- st.title("Chat with PostgreSQL") # Display title on the webpage
132
 
133
- # Sidebar configuration for database connection settings
134
  with st.sidebar:
135
- st.subheader("Settings") # Display a heading for the settings section
136
- st.write("Connect to your database and start chatting.") # Instruction text for users
137
 
138
  # Input fields for database connection details (host, port, user, password, and database name)
139
  host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
@@ -142,35 +136,35 @@ with st.sidebar:
142
  password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
143
  database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
144
 
145
- # Button to attempt database connection
146
  if st.button("Connect"):
147
- with st.spinner("Connecting to database..."): # Show a spinner while connecting
148
- db = init_database() # Call the function to connect to the database
149
  if db:
150
- st.session_state.db = db # Save the connection in session state
151
  st.success("Connected to the database!") # Display success message
152
  else:
153
- st.error("Connection failed. Please check your settings.") # Display error message if connection fails
154
 
155
  # Display the chat history (both AI and user messages)
156
  for message in st.session_state.chat_history:
157
  if isinstance(message, AIMessage):
158
- with st.chat_message("AI"): # Display AI messages in the chat
159
  st.markdown(message.content)
160
  elif isinstance(message, HumanMessage):
161
- with st.chat_message("Human"): # Display human messages in the chat
162
  st.markdown(message.content)
163
 
164
- # Input field for the user to type their message
165
- user_query = st.chat_input("Type a message...") # Field to capture user query
166
- if user_query and user_query.strip(): # If the user entered a valid query
167
- st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save the user query in chat history
168
-
169
- with st.chat_message("Human"): # Display the user's message in the chat
170
  st.markdown(user_query)
171
 
172
- with st.chat_message("AI"): # Generate and display the AI response
173
- response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get the AI's response
174
  st.markdown(response)
175
 
176
- st.session_state.chat_history.append(AIMessage(content=response)) # Add the AI's response to the chat history
 
1
  # Import necessary libraries and modules for various tasks
2
+ from dotenv import load_dotenv # For loading environment variables from a .env file (e.g., database credentials)
3
+ from langchain_core.messages import AIMessage, HumanMessage # For handling AI and user messages
4
+ from langchain_core.prompts import ChatPromptTemplate # To create templates for chatbot responses
5
+ from langchain_core.runnables import RunnablePassthrough # To chain different operations (e.g., inputs/outputs)
6
+ from langchain_community.utilities import SQLDatabase # Utility to connect to SQL databases using LangChain
7
  from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
8
+ from langchain_groq import ChatGroq # Integrates the Groq model for generating chat responses
9
+ import streamlit as st # Streamlit for building the web interface
10
+ import os # For accessing environment variables (e.g., credentials or other settings)
11
+ import psycopg2 # PostgreSQL database adapter for database connections
12
 
13
+ # Load environment variables (e.g., database credentials) from the .env file
14
  load_dotenv()
15
 
16
+ # Function to initialize the database connection
17
  def init_database() -> SQLDatabase:
18
  try:
19
+ # Retrieve connection details from environment variables or set defaults
20
  user = os.getenv("DB_USER", "postgres")
21
  password = os.getenv("DB_PASSWORD", "beekhani143")
22
  host = os.getenv("DB_HOST", "localhost")
23
  port = os.getenv("DB_PORT", "5432")
24
  database = os.getenv("DB_NAME", "")
25
 
26
+ # Construct the database URI for PostgreSQL connection
27
  db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
28
 
29
+ # Connect to the database using SQLDatabase utility and return the connection
30
  return SQLDatabase.from_uri(db_uri)
31
  except Exception as e:
32
+ # If connection fails, display error in the Streamlit UI
33
+ st.error(f"Failed to connect to the database: {e}")
34
  return None
35
 
36
+ # Function to generate the SQL query chain based on user input and conversation history
37
  def get_sql_chain(db):
38
+ # Define the prompt template for generating SQL queries based on schema and chat history
39
  template = """
40
+ You are a data analyst. You are interacting with a user who is asking questions about the company's database.
41
+ Based on the table schema below, write a SQL query that answers the user's question. Consider the conversation history.
42
 
43
  <SCHEMA>{schema}</SCHEMA>
44
  Conversation History: {chat_history}
 
48
  SQL Query:
49
  """
50
 
51
+ # Create the prompt template from the instructions above
52
  prompt = ChatPromptTemplate.from_template(template)
53
+ # Initialize the Groq model for generating SQL responses (deterministic output with temperature=0)
54
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
55
 
56
+ # Function to retrieve the database schema (table structure)
57
  def get_schema(_):
58
  return db.get_table_info()
59
 
60
+ # Build a chain of operations: get the schema, generate SQL query, and parse as plain text
 
 
 
61
  return (
62
+ RunnablePassthrough.assign(schema=get_schema) # Pass schema to the chain
63
+ | prompt # Use the prompt template to guide query creation
64
+ | llm # Generate SQL query using the Groq model
65
+ | StrOutputParser() # Parse the output as a plain text SQL query
66
  )
67
 
68
+ # Function to generate a natural language response based on SQL query and database results
69
  def get_response(user_query: str, db: SQLDatabase, chat_history: list):
70
+ # Generate the SQL query using the SQL chain
71
  sql_chain = get_sql_chain(db)
72
 
73
+ # Define the template for generating natural language responses
74
  template = """
75
+ You are a data analyst. Based on the schema, SQL query, and response, write a natural language response.
76
  <SCHEMA>{schema}</SCHEMA>
77
  Conversation History: {chat_history}
78
  SQL Query: <SQL>{query}</SQL>
79
+ User Question: {question}
80
  SQL Response: {response}
81
  """
82
 
83
+ # Create the prompt template for generating a natural language response
84
  prompt = ChatPromptTemplate.from_template(template)
85
  # Initialize the Groq model for response generation
86
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
87
 
88
+ # Chain the steps: Generate SQL, execute it, and generate the natural language response
 
 
 
89
  chain = (
90
  RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
91
  .assign(
92
+ schema=lambda _: db.get_table_info(), # Pass the schema for query execution
93
+ response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the query and clean the output
94
  )
95
+ | prompt # Generate a response from the AI using the prompt
96
+ | llm # Get the final response from the model
97
  | StrOutputParser() # Parse the output into plain text
98
  )
99
 
100
+ # Invoke the chain to generate the response
101
  result = chain.invoke({
102
  "question": user_query,
103
  "chat_history": chat_history,
 
110
  sql_query = result.get('query', 'No query generated')
111
  print(f"SQL Query: {sql_query}")
112
 
113
+ # Return the final natural language response
114
  return result
115
 
116
+ # Initialize the chat session in Streamlit
117
  if "chat_history" not in st.session_state:
118
  st.session_state.chat_history = [
119
+ # Initial greeting from the AI assistant
120
  AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
121
  ]
122
 
123
+ # Configure the Streamlit page title and icon
124
  st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
125
+ st.title("Chat with PostgreSQL") # Display the main title on the web page
126
 
127
+ # Sidebar for database connection settings
128
  with st.sidebar:
129
+ st.subheader("Settings") # Display settings section header
130
+ st.write("Connect to your database and start chatting.") # Instructions for users
131
 
132
  # Input fields for database connection details (host, port, user, password, and database name)
133
  host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
 
136
  password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
137
  database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
138
 
139
+ # Button to initiate database connection
140
  if st.button("Connect"):
141
+ with st.spinner("Connecting to the database..."):
142
+ db = init_database() # Attempt to connect to the database
143
  if db:
144
+ st.session_state.db = db # Save the connection to session state
145
  st.success("Connected to the database!") # Display success message
146
  else:
147
+ st.error("Connection failed. Please check your settings.") # Error message if connection fails
148
 
149
  # Display the chat history (both AI and user messages)
150
  for message in st.session_state.chat_history:
151
  if isinstance(message, AIMessage):
152
+ with st.chat_message("AI"):
153
  st.markdown(message.content)
154
  elif isinstance(message, HumanMessage):
155
+ with st.chat_message("Human"):
156
  st.markdown(message.content)
157
 
158
+ # Input field for the user to type a message
159
+ user_query = st.chat_input("Type a message...")
160
+ if user_query and user_query.strip():
161
+ st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save user query
162
+
163
+ with st.chat_message("Human"): # Display user's message in chat
164
  st.markdown(user_query)
165
 
166
+ with st.chat_message("AI"): # Generate and display AI's response
167
+ response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get AI response
168
  st.markdown(response)
169
 
170
+ st.session_state.chat_history.append(AIMessage(content=response)) # Save AI's response to chat history