Asif988 committed on
Commit
aabeba4
·
verified ·
1 Parent(s): 71ebc1b
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/qus_example.index filter=lfs diff=lfs merge=lfs -text
37
+ data/sql_examples.index filter=lfs diff=lfs merge=lfs -text
38
+ data/train_spider.json filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base image
2
+ FROM python:3.10-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /gemini_sql_rag_1_ext
6
+
7
+ # Copy everything into the container
8
+ COPY . /gemini_sql_rag_1_ext
9
+
10
+ # Install dependencies
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Expose port for Streamlit
14
+ EXPOSE 7860
15
+
16
+ # Run the app
17
+ CMD ["streamlit", "run", "gemini_sql_rag_1_ext.py", "--server.port=7860", "--server.address=0.0.0.0"]
data/business_qus_data.json ADDED
The diff for this file is too large to render. See raw diff
 
data/example_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74f185645d851ea920b4e40b3f729214eada9bc0404752cfb089905ca37cece8
3
+ size 21504128
data/qus_example.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbf229ec13553ae0c92c812b0843c28ead8ace2cf9675ce9f5b9b3ea7452d7ef
3
+ size 4012077
data/qus_example_embedding.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0030dcdcf073bd6b17fadcca71143b852cbaf6839d70de73df8d534b1e979bfa
3
+ size 4012160
data/sql_examples.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fc0df192a89a0ca61a9e079959747abd403cad9d2f1e2b63b0aae6dd20c53e8
3
+ size 21504045
data/tables.json ADDED
The diff for this file is too large to render. See raw diff
 
data/train_spider.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c43d0d72e59e1a9e1a60837da9bf70d5a6277226bdb7f634d544f380646f527a
3
+ size 24928884
gemini_sql_rag_1_ext.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import google.generativeai as genai
2
+ import re
3
+ import streamlit as st
4
+ import numpy as np # linear algebra
5
+ import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
6
+ from sentence_transformers import SentenceTransformer
7
+ import faiss
8
+ import re
9
+ import json
10
+ import psycopg2
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ def load_schema():
15
+ with open("data/tables.json", "r") as f:
16
+ tables = json.load(f)
17
+
18
+ # Create schema dictionary
19
+ schema_dict = {}
20
+ for db in tables:
21
+ db_id = db["db_id"]
22
+ schema_parts = []
23
+ for table_idx, table_name in enumerate(db["table_names_original"]):
24
+ cols = []
25
+ for col in db["column_names_original"]:
26
+ if col[0] == table_idx:
27
+ col_name = col[1]
28
+ col_type = db["column_types"][db["column_names_original"].index(col)]
29
+ cols.append(f"{col_name} {col_type.upper()}")
30
+ if cols:
31
+ schema_parts.append(f"{table_name}({', '.join(cols)})")
32
+
33
+ # Add foreign keys if available
34
+ fk_str = ""
35
+ if "foreign_keys" in db:
36
+ for fk in db["foreign_keys"]:
37
+ from_col = db["column_names_original"][fk[0]][1]
38
+ to_col = db["column_names_original"][fk[1]][1]
39
+ from_table = db["table_names_original"][db["column_names_original"][fk[0]][0]]
40
+ to_table = db["table_names_original"][db["column_names_original"][fk[1]][0]]
41
+ fk_str += f"FOREIGN KEY {from_table}({from_col}) REFERENCES {to_table}({to_col}); "
42
+
43
+ schema_str = f"{db_id}({', '.join(schema_parts)})"
44
+ if fk_str:
45
+ schema_str += f" {fk_str.strip()}"
46
+
47
+ schema_dict[db_id] = schema_str
48
+ return schema_dict
49
+
50
def data_examples(schema_dict):
    """Format the Spider training set as few-shot example strings.

    Each example is "Schema: ...\nRequest: ...\nQuery: ...;" where the schema
    text comes from *schema_dict* (keyed by db_id), with a placeholder when
    the db_id is unknown.
    """
    with open("data/train_spider.json", "r") as f:
        train_data = json.load(f)

    formatted = []
    for example in train_data:
        db_id = example["db_id"]
        schema_str = schema_dict.get(db_id, f"{db_id} (Schema details not found)")
        formatted.append(
            f"Schema: {schema_str}\nRequest: {example['question']}\nQuery: {example['query']};"
        )
    return formatted
64
+
65
@st.cache_resource
def prepare_sql_examples():
    """Build (and cache for the session) the formatted SQL few-shot examples."""
    return data_examples(load_schema())
71
+
72
@st.cache_resource
def qus_data_examples():
    """Load business Q&A records and format them as few-shot example strings.

    Returns:
        list[str]: one "Data: ...\nQuestion: ...\nAnswer: ...;" string per record.
    """
    with open("data/business_qus_data.json", "r") as f:
        qus_answer = json.load(f)

    # NOTE(review): the original copied every record one-by-one into a list
    # named ``unique`` without actually deduplicating anything; iterate the
    # loaded data directly instead.
    qus_examples = []
    for ex in qus_answer:
        qus_examples.append(
            f"Data: {ex['Data']}\nQuestion: {ex['Question']}\nAnswer: {ex['Answer']};"
        )
    return qus_examples
92
+
93
@st.cache_resource
def load_embedder():
    """Load the BGE sentence-embedding model once and cache it for the session."""
    model_name = "BAAI/bge-base-en-v1.5"
    return SentenceTransformer(model_name)
96
+
97
@st.cache_resource
def load_indexes():
    """Read the prebuilt FAISS indexes (SQL examples first, Q&A examples second)."""
    sql_index = faiss.read_index("data/sql_examples.index")
    qus_index = faiss.read_index("data/qus_example.index")
    return sql_index, qus_index
102
+
103
+ # def embed_store_tokens(sql_examples,qus_examples):
104
+ #
105
+ # example_embeddings = np.load("data/example_embeddings.npy")
106
+ #
107
+ # dimension = example_embeddings.shape[1] # Embedding size (e.g., 768 for bge-base)
108
+ # index = faiss.IndexFlatL2(dimension) # Simple L2 distance index; for large datasets, consider IndexIVFFlat for faster search
109
+ # index.add(example_embeddings) # Add vectors to index
110
+ #
111
+ # qus_example_embedding = np.load("data/qus_example_embedding.npy")
112
+ #
113
+ # dimension_qus = qus_example_embedding.shape[1] # Embedding size (e.g., 768 for bge-base)
114
+ # index_qus = faiss.IndexFlatL2(dimension_qus) # Simple L2 distance index; for large datasets, consider IndexIVFFlat for faster search
115
+ # index_qus.add(qus_example_embedding) # Add vectors to index
116
+ #
117
+ # return index,index_qus
118
+
119
+
120
def generate_sql_with_custom_rag(gemini_model, schema, embedder, request, conversation_history, sql_examples, index, max_length=1024, temperature=0.4, top_p=0.9, k=3):
    """Generate a SQL query with Gemini, grounded by FAISS-retrieved examples.

    Args:
        gemini_model: configured Gemini model with a ``generate_content`` method.
        schema: schema description string for the target database.
        embedder: sentence embedder with an ``encode`` method.
        request: the user's natural-language request.
        conversation_history: formatted prior chat turns.
        sql_examples: formatted example strings, aligned with ``index`` rows.
        index: FAISS index over embeddings of ``sql_examples``.
        max_length: unused; kept for backward compatibility with callers.
        temperature: Gemini sampling temperature.
        top_p: Gemini nucleus-sampling parameter.
        k: number of similar examples to retrieve.

    Returns:
        The generated SQL string, or an ``"Error: ..."`` string on failure.
    """
    try:
        # Step 1: Embed the combined schema/request/context as the retrieval query.
        query_text = f"Schema: {schema}\nRequest: {request}\nContext: {conversation_history}"
        query_embedding = embedder.encode([query_text], convert_to_tensor=False)
        query_embedding = np.array(query_embedding).astype('float32')

        # Step 2: Retrieve top-k similar examples (FAISS pads missing hits with -1).
        distances, indices = index.search(query_embedding, k)
        retrieved_examples = [sql_examples[idx] for idx in indices[0] if idx != -1]

        # Step 3: Format retrieved examples for the prompt.
        examples_str = "\n\n".join(retrieved_examples) if retrieved_examples else "No similar examples found."

        # Step 4: Build prompt.
        prompt = f"""
You are a SQL expert.
Use the following examples, schema, and conversation context to generate a single, correct SQL query.
Assume a standard SQL database (PostgreSQL/MySQL).
Return only the SQL query — no explanations.

Examples:
{examples_str}

Database Schema:
{schema}

Conversation Context:
{conversation_history}

Request:
{request}
"""

        # Step 5: Generate SQL using Gemini.
        response = gemini_model.generate_content(
            prompt,
            generation_config={
                "temperature": temperature,
                "top_p": top_p,
                "max_output_tokens": 300,
            }
        )

        text = response.text.strip()

        # Strip markdown code fences the model may wrap the query in.
        text = re.sub(r"^```(?:sql)?\s*|\s*```$", "", text, flags=re.IGNORECASE).strip()

        # BUGFIX: the original pattern r"(SELECT.*?\n)" required a trailing
        # newline (missing single-line output) and stopped at the FIRST
        # newline, truncating multi-line queries. Capture from the first
        # SELECT to the end instead.
        sql_match = re.search(r"(SELECT\b.*)", text, re.DOTALL | re.IGNORECASE)
        if sql_match:
            text = sql_match.group(1).strip()

        return text
    except Exception as e:
        # Surface failures as a string so the Streamlit UI can display them.
        return f"Error: {str(e)}"
173
+
174
+
175
def fetch_data_from_database(sql_query: str):
    """Run *sql_query* against the Postgres database and return rows as JSON text.

    Args:
        sql_query: a complete SQL statement (assumed trusted; it comes from
            the model-generation step, not raw user input).

    Returns:
        str: the result set serialized as a JSON array of row objects.

    SECURITY(review): the connection credentials were hard-coded in source
    (a leaked secret). They now prefer DB_* environment variables; the old
    values remain only as fallbacks and must be rotated and removed.
    """
    import os  # local import: keep the credential lookup self-contained

    conn = psycopg2.connect(
        host=os.environ.get("DB_HOST", "ep-long-tooth-a1zzotwg-pooler.ap-southeast-1.aws.neon.tech"),
        dbname=os.environ.get("DB_NAME", "neondb"),
        user=os.environ.get("DB_USER", "neondb_owner"),
        password=os.environ.get("DB_PASSWORD", "npg_Bd06StQryYlV"),
        sslmode="require")

    # The original created (and discarded) a cursor here; pandas manages its
    # own cursor, so that call was dead code. Close the connection even if
    # the query fails.
    try:
        df = pd.read_sql(sql_query, conn)
    finally:
        conn.close()

    records = df.to_dict(orient="records")
    json_data = json.dumps(records, indent=2)

    return json_data
190
+
191
+
192
def generate_answer_from_json_data(gemini_model,json_data,embedder, request,conversation_history,qus_examples,index_qus, max_length=1024, temperature=0.5, top_p=0.9, k=3):
    """Answer the user's question from query-result JSON, using retrieved Q&A examples.

    Args:
        gemini_model: configured Gemini model exposing ``generate_content``.
        json_data: JSON string of database rows (output of the SQL step).
        embedder: sentence embedder exposing ``encode``.
        request: the user's natural-language question.
        conversation_history: formatted prior chat turns.
        qus_examples: formatted Q&A example strings, aligned with ``index_qus``.
        index_qus: FAISS index over embeddings of ``qus_examples``.
        max_length: unused here; presumably kept for signature symmetry with
            the SQL-generation function — TODO confirm before removing.
        temperature: Gemini sampling temperature.
        top_p: Gemini nucleus-sampling parameter.
        k: number of similar examples to retrieve.

    Returns:
        The model's answer text, or an ``"Error: ..."`` string on failure.
    """
    try:
        # Step 1: Create query for retrieval
        query_text = f"Data: {json_data}\nQuestion: {request}\nContext: {conversation_history}"
        query_embedding = embedder.encode([query_text], convert_to_tensor=False)
        query_embedding = np.array(query_embedding).astype('float32')

        # Step 2: Retrieve similar examples using FAISS (-1 marks missing hits)
        distances, indices = index_qus.search(query_embedding, k)
        retrieved_examples = [qus_examples[idx] for idx in indices[0] if idx != -1]
        examples_str = "\n\n".join(retrieved_examples) if retrieved_examples else "No similar examples found."

        # Step 3: Build prompt for Gemini
        prompt = f"""
You are a helpful AI assistant.
Use the provided data and conversation context to answer the question.
Be concise and human-readable.
Do not include extra commentary or repeat data.

Examples:
{examples_str}

Data:
{json_data}

Conversation Context:
{conversation_history}

Question:
{request}
"""

        # Step 4: Generate answer using Gemini
        response = gemini_model.generate_content(
            prompt,
            generation_config={
                "temperature": temperature,
                "top_p": top_p,
                "max_output_tokens": 300,
            }
        )

        text = response.text.strip()

        # Optional cleanup for safety: drop a leading "Answer:" label if the
        # model emits one (the optional group means this always matches and
        # otherwise leaves the text unchanged).
        answer_match = re.search(r'(?i)(answer:)?\s*(.*)', text, re.DOTALL)
        if answer_match:
            text = answer_match.group(2).strip()

        return text
    except Exception as e:
        # Surface failures as a string so the Streamlit UI can display them.
        return f"Error: {str(e)}"
244
+
245
@st.cache_resource
def load_llm_model():
    """Configure Gemini and return a GenerativeModel, cached for the session.

    SECURITY(review): the API key was hard-coded in source (a leaked secret).
    It now prefers the GOOGLE_API_KEY environment variable; the previously
    committed key remains only as a fallback and must be rotated and removed.
    """
    import os  # local import: avoid touching the module import block

    genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", "AIzaSyCiGgeMMHrELnvKg-1ydHCVWlFm9LFLYpU"))
    # Choose model
    return genai.GenerativeModel("gemini-2.0-flash")
251
+
252
def generate_text(gemini_model, schema, embedder, request, conversation_history, sql_examples, index, qus_examples, index_qus):
    """Full RAG pipeline: request -> SQL -> database rows -> natural-language answer."""
    # 1) Translate the request into SQL, grounded by retrieved examples.
    sql_query = generate_sql_with_custom_rag(
        gemini_model, schema, embedder, request, conversation_history, sql_examples, index
    )

    # 2) Execute the SQL and serialize the result set to JSON.
    result_data = fetch_data_from_database(sql_query)

    # 3) Summarize the JSON rows as a conversational answer.
    return generate_answer_from_json_data(
        gemini_model, result_data, embedder, request, conversation_history, qus_examples, index_qus
    )
263
+
264
def format_conversation_history(conversation_history):
    """Render the session's message dict as "Role: content" lines for the prompt."""
    lines = [
        f"{entry['role'].capitalize()}: {entry['content']}"
        for entry in conversation_history["messages"]
    ]
    # strip() matches the original behavior of trimming surrounding whitespace.
    return "\n".join(lines).strip()
270
+
271
if __name__ == "__main__":
    st.set_page_config(page_title="SQL Chatbot", page_icon="🤖", layout="centered")
    st.title("🤖 SQL Chatbot (Gemini + RAG Ready)")
    st.caption("Ask me anything about your database. Type below to start chatting!")

    # Hard-coded schema of the live demo database the chatbot queries.
    schema = """ecommerce(customers(customer_id INT, first_name TEXT, last_name TEXT, email TEXT, phone TEXT, address TEXT, city TEXT, country TEXT, created_at TIMESTAMP)
,orders(order_id INT, customer_id INT, order_date TIMESTAMP, status TEXT, amount DECIMAL))"""

    # Cached resources — Streamlit re-runs this script on every interaction,
    # so these loaders are @st.cache_resource-backed.
    sql_examples = prepare_sql_examples()
    qus_examples = qus_data_examples()
    embedder = load_embedder()
    index, index_qus = load_indexes()

    # Load model
    gemini_model = load_llm_model()

    # Ensure the history always has the {"messages": [...]} shape.
    if "conversation_history" not in st.session_state or not isinstance(st.session_state.conversation_history, dict):
        st.session_state.conversation_history = {"messages": []}
    elif "messages" not in st.session_state.conversation_history:
        st.session_state.conversation_history["messages"] = []

    # Replay previous messages.
    for msg in st.session_state.conversation_history["messages"]:
        if msg["role"] == "user":
            st.chat_message("user").write(msg["content"])
        elif msg["role"] == "assistant":
            st.chat_message("assistant").write(msg["content"])

    # --- Chat input box ---
    if user_input := st.chat_input("Type your question or SQL request..."):
        # Record and echo the user turn.
        st.session_state.conversation_history["messages"].append({"role": "user", "content": user_input})
        st.chat_message("user").write(user_input)

        # Format history for the prompt.
        history_text = format_conversation_history(st.session_state.conversation_history)

        # Generate the model answer via the full RAG pipeline.
        response = generate_text(gemini_model, schema, embedder, user_input, history_text, sql_examples, index,
                                 qus_examples, index_qus)

        # Record and echo the assistant turn.
        st.session_state.conversation_history["messages"].append({"role": "assistant", "content": response})
        st.chat_message("assistant").write(response)

    # Sidebar options
    st.sidebar.header("⚙️ Settings")
    if st.sidebar.button("🧹 Clear Conversation"):
        # BUGFIX: the original reset the history to a plain list ([]), which
        # breaks the {"messages": [...]} shape the rest of the script expects
        # (it only worked because the startup guard repaired it after rerun).
        st.session_state.conversation_history = {"messages": []}
        st.rerun()

    st.sidebar.markdown("---")
    st.sidebar.info("Built with ❤️ using Streamlit + Python\n\nModel backend: Gemini + Custom RAG")
328
+
329
+
330
+
331
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ numpy~=2.3.3
2
+ protobuf~=5.29.5
3
+ streamlit~=1.50.0
4
+ pandas~=2.3.3
5
+ faiss-cpu~=1.12.0
6
+ psycopg2~=2.9.11
7
+ sentence-transformers~=5.1.1
8
+ torch
9
+ google-generativeai