Suresh Beekhani commited on
Commit
0dd285c
·
unverified ·
1 Parent(s): 394d592

Update app.py

Browse files
Files changed (1) hide show
  1. src/app.py +97 -93
src/app.py CHANGED
@@ -1,40 +1,41 @@
1
- # Import necessary libraries and modules
2
- from dotenv import load_dotenv # For loading environment variables from .env
3
- from langchain_core.messages import AIMessage, HumanMessage # Message handling
4
- from langchain_core.prompts import ChatPromptTemplate # Prompt templates for generating responses
5
- from langchain_core.runnables import RunnablePassthrough # To chain operations
6
- from langchain_community.utilities import SQLDatabase # SQL database utility for LangChain
7
- from langchain_core.output_parsers import StrOutputParser # To parse outputs as strings
8
- # OpenAI model for chat (if used)
9
- from langchain_groq import ChatGroq # Groq model for chat (currently used)
10
- import streamlit as st # Streamlit for building the web app
11
- import os # To access environment variables
12
-
13
- # Load environment variables from the .env file (like API keys, database credentials)
14
  load_dotenv()
15
 
16
- # Function to initialize a connection to a MySQL database
17
  def init_database() -> SQLDatabase:
18
  try:
19
- # Load credentials from environment variables for better security
20
- user = os.getenv("DB_USER", "root")
21
- password = os.getenv("DB_PASSWORD", "admin")
22
  host = os.getenv("DB_HOST", "localhost")
23
- port = os.getenv("DB_PORT", "3306")
24
- database = os.getenv("DB_NAME", "Chinook")
25
 
26
- # Construct the database URI
27
- db_uri = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}"
28
 
29
- # Initialize and return the SQLDatabase instance
30
  return SQLDatabase.from_uri(db_uri)
31
  except Exception as e:
 
32
  st.error(f"Failed to connect to database: {e}")
33
  return None
34
 
35
- # Function to create a chain that generates SQL queries from user input and conversation history
36
  def get_sql_chain(db):
37
- # SQL prompt template
38
  template = """
39
  You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
40
  Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
@@ -47,33 +48,32 @@ def get_sql_chain(db):
47
  SQL Query:
48
  """
49
 
50
- # Create a prompt from the above template
51
  prompt = ChatPromptTemplate.from_template(template)
52
-
53
- # Initialize Groq model for generating SQL queries (can switch to OpenAI if needed)
54
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
55
 
56
- # Helper function to get schema info from the database
57
  def get_schema(_):
58
  return db.get_table_info()
59
 
60
- # Chain of operations:
61
- # 1. Assign schema information from the database
62
- # 2. Use the AI model to generate a SQL query
63
- # 3. Parse the result into a string
64
  return (
65
- RunnablePassthrough.assign(schema=get_schema) # Get schema info from the database
66
- | prompt # Generate SQL query from the prompt template
67
- | llm # Use Groq model to process the prompt and return a SQL query
68
- | StrOutputParser() # Parse the result as a string
69
  )
70
 
71
- # Function to generate a response in natural language based on the SQL query result
72
  def get_response(user_query: str, db: SQLDatabase, chat_history: list):
73
- # Generate the SQL query using the chain
74
  sql_chain = get_sql_chain(db)
75
-
76
- # Prompt template for natural language response based on SQL query and result
77
  template = """
78
  You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
79
  <SCHEMA>{schema}</SCHEMA>
@@ -83,90 +83,94 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
83
  SQL Response: {response}
84
  """
85
 
86
- # Create a natural language response prompt
87
  prompt = ChatPromptTemplate.from_template(template)
88
-
89
- # Initialize Groq model (alternative: OpenAI)
90
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
91
-
92
- # Build a chain: generate SQL query, run it on the database, generate a natural language response
 
 
 
93
  chain = (
94
- RunnablePassthrough.assign(query=sql_chain).assign(
95
- schema=lambda _: db.get_table_info(), # Get schema info
96
- response=lambda vars: db.run(vars["query"]), # Run SQL query on the database
 
97
  )
98
- | prompt # Use prompt to generate a natural language response
99
- | llm # Process prompt with Groq model
100
- | StrOutputParser() # Parse the final result as a string
101
  )
102
-
103
- # Execute the chain and return the response
104
- return chain.invoke({
105
  "question": user_query,
106
  "chat_history": chat_history,
107
  })
108
 
109
- # Initialize the Streamlit session
 
 
 
 
 
 
 
 
 
 
110
  if "chat_history" not in st.session_state:
111
- # Initialize chat history with a welcome message from AI
112
  st.session_state.chat_history = [
 
113
  AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
114
  ]
115
 
116
- # Set up the Streamlit web page configuration
117
- st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")
118
-
119
- # Streamlit app title
120
- st.title("Chat with MySQL")
121
 
122
- # Sidebar for database connection settings
123
  with st.sidebar:
124
- st.subheader("Settings")
125
- st.write("Connect to your database and start chatting.")
126
 
127
- # Database connection input fields
128
  host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
129
- port = st.text_input("Port", value=os.getenv("DB_PORT", "3306"))
130
- user = st.text_input("User", value=os.getenv("DB_USER", "root"))
131
- password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "admin"))
132
- database = st.text_input("Database", value=os.getenv("DB_NAME", "Chinook"))
133
 
134
- # Button to connect to the database
135
  if st.button("Connect"):
136
- with st.spinner("Connecting to database..."):
137
- # Initialize the database connection and store in session state
138
- db = init_database()
139
  if db:
140
- st.session_state.db = db
141
- st.success("Connected to the database!")
142
  else:
143
- st.error("Connection failed. Please check your settings.")
144
 
145
- # Display chat history
146
  for message in st.session_state.chat_history:
147
  if isinstance(message, AIMessage):
148
- # Display AI message
149
- with st.chat_message("AI"):
150
  st.markdown(message.content)
151
  elif isinstance(message, HumanMessage):
152
- # Display human message
153
- with st.chat_message("Human"):
154
  st.markdown(message.content)
155
 
156
- # Input field for user's message
157
- user_query = st.chat_input("Type a message...")
158
- if user_query and user_query.strip():
159
- # Add user's query to the chat history
160
- st.session_state.chat_history.append(HumanMessage(content=user_query))
161
 
162
- # Display user's message in the chat
163
- with st.chat_message("Human"):
164
  st.markdown(user_query)
165
 
166
- # Generate and display AI's response based on the query
167
- with st.chat_message("AI"):
168
- response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
169
  st.markdown(response)
170
 
171
- # Add AI's response to the chat history
172
- st.session_state.chat_history.append(AIMessage(content=response))
 
1
+ # Import necessary libraries and modules for various tasks
2
+ from dotenv import load_dotenv # For loading environment variables from a .env file (such as database credentials)
3
+ from langchain_core.messages import AIMessage, HumanMessage # For handling messages from the AI and user
4
+ from langchain_core.prompts import ChatPromptTemplate # To create templates that will guide the chatbot's responses
5
+ from langchain_core.runnables import RunnablePassthrough # To enable chaining of different operations (like inputs/outputs)
6
+ from langchain_community.utilities import SQLDatabase # A tool to help connect to SQL databases using LangChain
7
+ from langchain_core.output_parsers import StrOutputParser # To parse outputs into plain text
8
+ from langchain_groq import ChatGroq # This integrates the Groq model for generating chat responses
9
+ import streamlit as st # Streamlit is used for building the web app (user interface)
10
+ import os # To access environment variables (e.g., credentials or other settings)
11
+ import psycopg2 # A PostgreSQL database adapter to enable connections to the database
12
+
13
+ # Load environment variables (such as DB credentials) from the .env file
14
  load_dotenv()
15
 
16
+ # Function to establish a connection to the PostgreSQL database
17
  def init_database() -> SQLDatabase:
18
  try:
19
+ # Retrieve database connection details from environment variables, or set default values
20
+ user = os.getenv("DB_USER", "postgres")
21
+ password = os.getenv("DB_PASSWORD", "beekhani143")
22
  host = os.getenv("DB_HOST", "localhost")
23
+ port = os.getenv("DB_PORT", "5432")
24
+ database = os.getenv("DB_NAME", "")
25
 
26
+ # Construct the database URI (a URL-like string) with the necessary credentials for PostgreSQL
27
+ db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
28
 
29
+ # Connect to the database using the SQLDatabase utility and return the instance
30
  return SQLDatabase.from_uri(db_uri)
31
  except Exception as e:
32
+ # If connection fails, display an error message on the Streamlit UI
33
  st.error(f"Failed to connect to database: {e}")
34
  return None
35
 
36
+ # Function to create a process (chain) that generates SQL queries based on user input and previous conversation
37
  def get_sql_chain(db):
38
+ # Template to guide how SQL queries are generated. The bot receives table schema and conversation history.
39
  template = """
40
  You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
41
  Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
 
48
  SQL Query:
49
  """
50
 
51
+ # Create a prompt template from the above instructions
52
  prompt = ChatPromptTemplate.from_template(template)
53
+ # Initialize the Groq model for generating responses with low randomness (temperature=0 for more deterministic outputs)
 
54
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
55
 
56
+ # Function to get the schema (structure) of the tables in the database
57
  def get_schema(_):
58
  return db.get_table_info()
59
 
60
+ # Create a chain of operations:
61
+ # 1. First, get the database schema.
62
+ # 2. Then, use the prompt template to guide query creation.
63
+ # 3. Finally, parse the output as a plain text SQL query.
64
  return (
65
+ RunnablePassthrough.assign(schema=get_schema) # Pass the schema into the chain
66
+ | prompt # Use the prompt template
67
+ | llm # Generate a response using the Groq model
68
+ | StrOutputParser() # Parse the response as a string (SQL query)
69
  )
70
 
71
+ # Function to generate a natural language response based on SQL query and database result
72
  def get_response(user_query: str, db: SQLDatabase, chat_history: list):
73
+ # First, get the SQL chain (responsible for generating SQL queries)
74
  sql_chain = get_sql_chain(db)
75
+
76
+ # Template to guide how the AI responds to the user's query based on the SQL results
77
  template = """
78
  You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
79
  <SCHEMA>{schema}</SCHEMA>
 
83
  SQL Response: {response}
84
  """
85
 
86
+ # Create a new prompt template for generating a response
87
  prompt = ChatPromptTemplate.from_template(template)
88
+ # Initialize the Groq model for response generation
 
89
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
90
+
91
+ # Chain the following tasks:
92
+ # 1. Generate the SQL query using the earlier chain.
93
+ # 2. Get the schema and execute the query on the database.
94
+ # 3. Return the natural language response based on the query and its results.
95
  chain = (
96
+ RunnablePassthrough.assign(query=sql_chain) # Generate SQL query
97
+ .assign(
98
+ schema=lambda _: db.get_table_info(), # Pass the schema to the next step
99
+ response=lambda vars: db.run(vars["query"].replace("\\", "")), # Execute the SQL query and clean up backslashes
100
  )
101
+ | prompt # Use the prompt template for generating natural language response
102
+ | llm # Generate the final response using the model
103
+ | StrOutputParser() # Parse the output into plain text
104
  )
105
+
106
+ # Invoke the chain to generate the final response based on the user query and history
107
+ result = chain.invoke({
108
  "question": user_query,
109
  "chat_history": chat_history,
110
  })
111
 
112
+ # Debugging: Print the SQL query being executed
113
+ if isinstance(result, str):
114
+ print(f"SQL Query: {result}")
115
+ else:
116
+ sql_query = result.get('query', 'No query generated')
117
+ print(f"SQL Query: {sql_query}")
118
+
119
+ # Return the result (natural language response)
120
+ return result
121
+
122
+ # Initialize the chat session when Streamlit app starts
123
  if "chat_history" not in st.session_state:
 
124
  st.session_state.chat_history = [
125
+ # First message from AI assistant
126
  AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
127
  ]
128
 
129
+ # Streamlit page configuration: Set the page title and icon
130
+ st.set_page_config(page_title="Chat with PostgreSQL", page_icon=":speech_balloon:")
131
+ st.title("Chat with PostgreSQL") # Display title on the webpage
 
 
132
 
133
+ # Sidebar configuration for database connection settings
134
  with st.sidebar:
135
+ st.subheader("Settings") # Display a heading for the settings section
136
+ st.write("Connect to your database and start chatting.") # Instruction text for users
137
 
138
+ # Input fields for database connection details (host, port, user, password, and database name)
139
  host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
140
+ port = st.text_input("Port", value=os.getenv("DB_PORT", "5432"))
141
+ user = st.text_input("User", value=os.getenv("DB_USER", "postgres"))
142
+ password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "beekhani143"))
143
+ database = st.text_input("Database", value=os.getenv("DB_NAME", "db"))
144
 
145
+ # Button to attempt database connection
146
  if st.button("Connect"):
147
+ with st.spinner("Connecting to database..."): # Show a spinner while connecting
148
+ db = init_database() # Call the function to connect to the database
 
149
  if db:
150
+ st.session_state.db = db # Save the connection in session state
151
+ st.success("Connected to the database!") # Display success message
152
  else:
153
+ st.error("Connection failed. Please check your settings.") # Display error message if connection fails
154
 
155
+ # Display the chat history (both AI and user messages)
156
  for message in st.session_state.chat_history:
157
  if isinstance(message, AIMessage):
158
+ with st.chat_message("AI"): # Display AI messages in the chat
 
159
  st.markdown(message.content)
160
  elif isinstance(message, HumanMessage):
161
+ with st.chat_message("Human"): # Display human messages in the chat
 
162
  st.markdown(message.content)
163
 
164
+ # Input field for the user to type their message
165
+ user_query = st.chat_input("Type a message...") # Field to capture user query
166
+ if user_query and user_query.strip(): # If the user entered a valid query
167
+ st.session_state.chat_history.append(HumanMessage(content=user_query)) # Save the user query in chat history
 
168
 
169
+ with st.chat_message("Human"): # Display the user's message in the chat
 
170
  st.markdown(user_query)
171
 
172
+ with st.chat_message("AI"): # Generate and display the AI response
173
+ response = get_response(user_query, st.session_state.db, st.session_state.chat_history) # Get the AI's response
 
174
  st.markdown(response)
175
 
176
+ st.session_state.chat_history.append(AIMessage(content=response)) # Add the AI's response to the chat history