Parthiban97 commited on
Commit
85397cc
·
verified ·
1 Parent(s): 3821cdd

Upload 3 files

Browse files
Files changed (3) hide show
  1. .streamlit/secrets.toml +5 -0
  2. app.py +246 -0
  3. requirements.txt +20 -0
.streamlit/secrets.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Host = "34.70.107.70"
2
+ Port = "3306"
3
+ User = "root"
4
+ Password = "mysql"
5
+ Databases = "atliq_tshirts"
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from langchain_core.messages import AIMessage, HumanMessage
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.runnables import RunnablePassthrough
6
+ from langchain_community.utilities import SQLDatabase
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_groq import ChatGroq
10
+ import toml
11
+
12
+ # Function to update secrets.toml file
13
+ def update_secrets_file(data):
14
+ secrets_file_path = ".streamlit/secrets.toml"
15
+ if os.path.exists(secrets_file_path):
16
+ with open(secrets_file_path, "r") as file:
17
+ secrets_data = toml.load(file)
18
+ else:
19
+ secrets_data = {}
20
+
21
+ secrets_data.update(data)
22
+
23
+ with open(secrets_file_path, "w") as file:
24
+ toml.dump(secrets_data, file)
25
+
26
+ # Initialize database connections
27
+ def init_databases():
28
+ secrets_file_path = ".streamlit/secrets.toml"
29
+ with open(secrets_file_path, "r") as file:
30
+ secrets_data = toml.load(file)
31
+
32
+ db_connections = {}
33
+ for database in secrets_data["Databases"].split(','):
34
+ database = database.strip()
35
+ db_uri = f"mysql+mysqlconnector://{secrets_data['User']}:{secrets_data['Password']}@{secrets_data['Host']}:{secrets_data['Port']}/{database}"
36
+ db_connections[database] = SQLDatabase.from_uri(db_uri)
37
+ return db_connections
38
+
39
+ # Function to get SQL chain
40
+ def get_sql_chain(dbs, llm):
41
+ template = """
42
+ You are a Senior and vastly experienced Data analyst at a company with around 20 years of experience.
43
+ You are interacting with a user who is asking you questions about the company's databases.
44
+ Based on the table schemas below, write SQL queries that would answer the user's question. Take the conversation history into account.
45
+
46
+ <SCHEMAS>{schemas}</SCHEMAS>
47
+
48
+ Conversation History: {chat_history}
49
+
50
+ Write the SQL queries for each relevant database, prefixed by the database name (e.g., DB1: SELECT * FROM ...; DB2: SELECT * FROM ...).
51
+ Do not wrap the SQL queries in any other text, not even backticks.
52
+
53
+ For example:
54
+ Question: which 3 artists have the most tracks?
55
+ SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
56
+ Question: Name 10 artists
57
+ SQL Query: SELECT Name FROM Artist LIMIT 10;
58
+ Question: How much is the price of the inventory for all small size t-shirts?
59
+ SQL Query: SELECT SUM(price * stock_quantity) FROM t_shirts WHERE size = 'S';
60
+ Question: If we have to sell all the Levi's T-shirts today with discounts applied. How much revenue our store will generate (post discounts)?
61
+ SQL Query: SELECT SUM(a.total_amount * ((100 - COALESCE(discounts.pct_discount, 0)) / 100)) AS total_revenue
62
+ FROM (SELECT SUM(price * stock_quantity) AS total_amount, t_shirt_id
63
+ FROM t_shirts
64
+ WHERE brand = 'Levi' GROUP BY t_shirt_id) a
65
+ LEFT JOIN discounts ON a.t_shirt_id = discounts.t_shirt_id;
66
+ Question: For each brand, find the total revenue generated from t-shirts with a discount applied, grouped by the discount percentage.
67
+ SQL Query: SELECT brand, COALESCE(discounts.pct_discount, 0) AS discount_pct, SUM(t.price * t.stock_quantity * (1 - COALESCE(discounts.pct_discount, 0) / 100)) AS total_revenue
68
+ FROM t_shirts t
69
+ LEFT JOIN discounts ON t.t_shirt_id = discounts.t_shirt_id
70
+ GROUP BY brand, COALESCE(discounts.pct_discount, 0);
71
+ Question: Find the top 3 most popular colors for each brand, based on the total stock quantity.
72
+ SQL Query: SELECT brand, color, SUM(stock_quantity) AS total_stock
73
+ FROM t_shirts
74
+ GROUP BY brand, color
75
+ ORDER BY brand, total_stock DESC;
76
+
77
+ Question: Calculate the average price per size for each brand, excluding sizes with less than 10 t-shirts in stock.
78
+ SQL Query: SELECT brand, size, AVG(price) AS avg_price
79
+ FROM t_shirts
80
+ WHERE stock_quantity >= 10
81
+ GROUP BY brand, size
82
+ HAVING COUNT(*) >= 10;
83
+
84
+ Question: Find the brand and color combination with the highest total revenue, considering discounts.
85
+ SQL Query: SELECT brand, color, SUM(t.price * t.stock_quantity * (1 - COALESCE(d.pct_discount, 0) / 100)) AS total_revenue
86
+ FROM t_shirts t
87
+ LEFT JOIN discounts d ON t.t_shirt_id = d.t_shirt_id
88
+ GROUP BY brand, color
89
+ ORDER BY total_revenue DESC
90
+ LIMIT 1;
91
+
92
+ Question: Create a view that shows the total stock quantity and revenue for each brand, size, and color combination.
93
+ SQL Query: CREATE VIEW brand_size_color_stats AS
94
+ SELECT brand, size, color, SUM(stock_quantity) AS total_stock, SUM(price * stock_quantity) AS total_revenue
95
+ FROM t_shirts
96
+ GROUP BY brand, size, color;
97
+
98
+ Question: How much is the price of the inventory for all varients t-shirts and group them y brands?
99
+ SQL Query: SELECT brand, SUM(price * stock_quantity) FROM t_shirts GROUP BY brand;
100
+
101
+ Question: List the total revenue of t-shirts of L size for all brands
102
+ SQL Query: SELECT brand, SUM(price * stock_quantity) AS total_revenue FROM t_shirts WHERE size = 'L' GROUP BY brand;
103
+
104
+ Question: How many shirts are available in stock grouped by colours from each size and finally show me all brands?
105
+ SQL Query: SELECT brand, color, size, SUM(stock_quantity) AS total_stock FROM t_shirts GROUP BY brand, color, size
106
+
107
+ Your turn:
108
+
109
+ Question: {question}
110
+ SQL Queries:
111
+ """
112
+
113
+ prompt = ChatPromptTemplate.from_template(template)
114
+ llm = llm
115
+
116
+ def get_schema(_):
117
+ schemas = {db_name: db.get_table_info() for db_name, db in dbs.items()}
118
+ return schemas
119
+
120
+ return (
121
+ RunnablePassthrough.assign(schemas=get_schema)
122
+ | prompt
123
+ | llm
124
+ | StrOutputParser()
125
+ | (lambda result: {line.split(":")[0]: line.split(":")[1].strip() for line in result.strip().split("\n") if ":" in line and line.strip()})
126
+ )
127
+
128
+ # Function to get response
129
+ def get_response(user_query, dbs, chat_history, llm):
130
+ sql_chain = get_sql_chain(dbs, llm)
131
+
132
+ template = """
133
+ You are a Senior and vastly experienced Data analyst at a company with around 20 years of experience.
134
+ You are interacting with a user who is asking you questions about the company's databases.
135
+ Based on the table schemas below, question, sql queries, and sql responses, write an
136
+ accurate natural language response so that the end user can understand things
137
+ and make sure do not include words like "Based on the SQL queries I ran".
138
+ Just provide only the answer with some text that the user expects.
139
+ <SCHEMAS>{schemas}</SCHEMAS>
140
+ Conversation History: {chat_history}
141
+ SQL Queries: <SQL>{queries}</SQL>
142
+ User question: {question}
143
+ SQL Responses: {responses}"""
144
+
145
+ prompt = ChatPromptTemplate.from_template(template)
146
+ llm = llm
147
+
148
+ def run_queries(var):
149
+ responses = {}
150
+ for db_name, query in var["queries"].items():
151
+ responses[db_name] = dbs[db_name].run(query)
152
+ return responses
153
+
154
+ chain = (
155
+ RunnablePassthrough.assign(queries=sql_chain).assign(
156
+ schemas=lambda _: {db_name: db.get_table_info() for db_name, db in dbs.items()},
157
+ responses=run_queries) # The comma at the end of the assign() method call is used to indicate that there may be more keyword arguments or method calls following it
158
+ | prompt
159
+ | llm
160
+ | StrOutputParser()
161
+ )
162
+
163
+ return chain.invoke({
164
+ "question": user_query,
165
+ "chat_history": chat_history,
166
+ })
167
+
168
+ # Streamlit app configuration
169
+ if "chat_history" not in st.session_state:
170
+ st.session_state.chat_history = [
171
+ AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
172
+ ]
173
+
174
+ st.set_page_config(page_title="Chat with MySQL", page_icon="🛢️")
175
+ st.title("Chat with MySQL")
176
+
177
+ with st.sidebar:
178
+ st.subheader("Settings")
179
+ st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
180
+
181
+ if "db" not in st.session_state:
182
+ st.session_state.Host = st.text_input("Host", value=st.secrets.get("Host", ""))
183
+ st.session_state.Port = st.text_input("Port", value=st.secrets.get("Port", ""))
184
+ st.session_state.User = st.text_input("User", value=st.secrets.get("User", ""))
185
+ st.session_state.Password = st.text_input("Password", type="password", value=st.secrets.get("Password", ""))
186
+ st.session_state.Databases = st.text_input("Databases", placeholder="Enter DB's separated by (,)", value=st.secrets.get("Databases", ""))
187
+ st.session_state.openai_api_key = st.text_input("OpenAI API Key", type="password", help="Get your API key from [OpenAI Website](https://platform.openai.com/api-keys)", value=st.secrets.get("openai_api_key", ""))
188
+ st.session_state.groq_api_key = st.text_input("Groq API Key", type="password", help="Get your API key from [GROQ Console](https://console.groq.com/keys)", value=st.secrets.get("groq_api_key", ""))
189
+
190
+ st.info("Note: For interacting multiple databases, GPT-4 Model is recommended for accurate results else proceed with Groq Model")
191
+
192
+ os.environ["OPENAI_API_KEY"] = str(st.session_state.openai_api_key)
193
+
194
+ if st.button("Connect"):
195
+ with st.spinner("Connecting to databases..."):
196
+
197
+ # Update secrets.toml with connection details
198
+ update_secrets_file({
199
+ "Host": st.session_state.Host,
200
+ "Port": st.session_state.Port,
201
+ "User": st.session_state.User,
202
+ "Password": st.session_state.Password,
203
+ "Databases": st.session_state.Databases
204
+ })
205
+
206
+ dbs = init_databases()
207
+ st.session_state.dbs = dbs
208
+
209
+ if len(dbs) > 1:
210
+ st.success(f"Connected to {len(dbs)} databases")
211
+ else:
212
+ st.success("Connected to database")
213
+
214
+
215
+
216
+ if st.session_state.openai_api_key == "" and st.session_state.groq_api_key == "":
217
+ st.error("Enter one API Key At least")
218
+ elif st.session_state.openai_api_key:
219
+ st.session_state.llm = ChatOpenAI(model="gpt-4-turbo", api_key=st.session_state.openai_api_key)
220
+ elif st.session_state.groq_api_key:
221
+ st.session_state.llm = ChatGroq(model="llama3-70b-8192", temperature=0.4, api_key=st.session_state.groq_api_key)
222
+ else:
223
+ pass
224
+
225
+ # Display chat messages
226
+ for message in st.session_state.chat_history:
227
+ if isinstance(message, AIMessage):
228
+ with st.chat_message("AI"):
229
+ st.markdown(message.content)
230
+ elif isinstance(message, HumanMessage):
231
+ with st.chat_message("Human"):
232
+ st.markdown(message.content)
233
+
234
+ # Handle user input
235
+ user_query = st.chat_input("Type a message...")
236
+ if user_query is not None and user_query.strip() != "":
237
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
238
+
239
+ with st.chat_message("Human"):
240
+ st.markdown(user_query)
241
+
242
+ with st.chat_message("AI"):
243
+ response = get_response(user_query, st.session_state.dbs, st.session_state.chat_history, st.session_state.llm)
244
+ st.markdown(response)
245
+
246
+ st.session_state.chat_history.append(AIMessage(content=response))
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <<<<<<< HEAD
2
+ streamlit==1.31.1
3
+ langchain==0.1.8
4
+ langchain-community==0.0.21
5
+ langchain-core==0.1.24
6
+ langchain-openai==0.0.6
7
+ mysql-connector-python==8.3.0
8
+ groq==0.4.2
9
+ langchain-groq==0.0.1
10
+ =======
11
+ streamlit==1.31.1
12
+ langchain==0.1.8
13
+ langchain-community==0.0.21
14
+ langchain-core==0.1.24
15
+ langchain-openai==0.0.6
16
+ mysql-connector-python==8.3.0
17
+ groq==0.4.2
18
+ langchain-groq==0.0.1
19
+ toml
20
+ >>>>>>> ac80cc2d6f7f9213dc646047cf721e1a35cc0808