prthm11 committed on
Commit
315b480
·
verified ·
1 Parent(s): e8f79fb

Update dynamicDB_gemini_sql_agent.py

Browse files
Files changed (1) hide show
  1. dynamicDB_gemini_sql_agent.py +236 -236
dynamicDB_gemini_sql_agent.py CHANGED
@@ -1,236 +1,236 @@
1
- from flask import Flask, request, jsonify, render_template
2
- from flask_socketio import SocketIO, emit
3
- from langchain_google_genai import ChatGoogleGenerativeAI
4
- from langchain.agents import AgentType
5
- from langchain_community.agent_toolkits import create_sql_agent
6
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
7
- from langchain_community.utilities import SQLDatabase
8
- from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
9
- import threading
10
- import os
11
- from dotenv import load_dotenv
12
- import secrets
13
- import re
14
- import traceback
15
- from werkzeug.exceptions import HTTPException
16
- from werkzeug.utils import secure_filename
17
-
18
- load_dotenv()
19
- os.environ["GEMINI_API_KEY"] = os.getenv("GEMINI_API_KEY")
20
-
21
- app = Flask(__name__)
22
- app.config['SECRET_KEY'] = secrets.token_hex(32)
23
- app.config['UPLOAD_FOLDER'] = 'uploads'
24
- app.config['ALLOWED_EXTENSIONS'] = {'db'}
25
- socketio = SocketIO(app, cors_allowed_origins="*")
26
-
27
- # Ensure upload folder exists
28
- os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
29
-
30
- llm = ChatGoogleGenerativeAI(temperature=0.2,
31
- model="gemini-2.0-flash",
32
- max_retires = 50,
33
- tool_choice="auto",
34
- # max_tokens=1024,
35
- # streaning =True,
36
- api_key=os.getenv("GEMINI_API_KEY"))
37
-
38
- db = None
39
- agent_executor = None
40
-
41
- def allowed_file(filename):
42
- return filename.lower().endswith('.db')
43
-
44
- def init_agent(db_uri):
45
- global db, agent_executor
46
- db = SQLDatabase.from_uri(db_uri)
47
- toolkit = SQLDatabaseToolkit(db=db, llm=llm)
48
-
49
- prefix = '''You are a helpful SQL expert agent that ALWAYS returns natural language answers using the tools.
50
- Always format your responses in Markdown. For example:
51
- - Use bullet points
52
- - Use bold for headers
53
- - Wrap code in triple backticks
54
- - Tables should use Markdown table syntax
55
-
56
- You must NEVER:
57
- - Show or mention SQL syntax.
58
- - Reveal table names, column names, or database schema.
59
- - Respond with any technical details or structure of the database.
60
- - Return code or tool names.
61
- - Give wrong Answers.
62
-
63
- You must ALWAYS:
64
- - Respond in plain, friendly language.
65
- - Don't Summarize the result for the user (e.g., "There are 9 tables in the system.")
66
- - If asked to list table names or schema, politely refuse and respond with:
67
- "I'm sorry, I can't share database structure information."
68
- - ALWAYS HAVE TO SOLVE COMPLEX USER QUERIES. FOR THAT, UNDERSTAND THE PROMPT, ANALYSE PROPER AND THEN GIVE ANSWER.
69
- - Your Answers should be correct, you have to do understand process well and give accurate answers
70
-
71
- Strict Rules You MUST Follow:
72
- - NEVER display or mention SQL queries.
73
- - NEVER explain SQL syntax or logic.
74
- - NEVER return technical or code-like responses.
75
- - ONLY respond in natural, human-friendly language.
76
- - You are not allow to give the name of any COLUMNS, TABLES, DATABASE, ENTITY, SYNTAX, STRUCTURE, DESIGN, ETC...
77
-
78
- If the user asks for anything other than retrieving data (SELECT), respond using this exact message:
79
- "I'm not allowed to perform operations other than SELECT queries. Please ask something that involves reading data."
80
-
81
- Do not return SQL queries or raw technical responses to the user.
82
-
83
- For example:
84
- Wrong: SELECT * FROM ...
85
- Correct: The user assigned to the cart is Alice Smith.
86
-
87
- Use the tools provided to get the correct data from the database and summarize the response clearly.
88
- If the input is unclear or lacks sufficient data, ask for clarification using the SubmitFinalAnswer tool.
89
- Never return SQL queries as your response.
90
-
91
- If you cannot find an answer,
92
- Double-check your query and running it again.
93
- - If a query fails, revise and try again.
94
- - Else 'No data found' using SubmitFinalAnswer.No SQL, no code. '''
95
-
96
- agent_executor = create_sql_agent(
97
- llm=llm,
98
- toolkit=toolkit,
99
- verbose=False,
100
- prefix=prefix,
101
- agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
102
- agent_executor_kwargs={"handle_parsing_errors": True},
103
- )
104
-
105
- # Simple schema‐leak check
106
- intent_prompt = ChatPromptTemplate.from_messages([
107
- ("system", "Classify if user is asking schema/structure info: YES or NO."),
108
- ("human", "{prompt}")
109
- ])
110
- intent_checker = intent_prompt | llm
111
-
112
- def is_schema_leak_request(prompt):
113
- classification = intent_checker.invoke({"prompt": prompt})
114
- return "yes" in classification.content.lower()
115
-
116
- def is_schema_request(prompt: str) -> bool:
117
- """
118
- Checks if the user prompt is trying to access schema or structure info.
119
- Returns True if it's about table names, schema, columns, etc.
120
- """
121
- pattern = re.compile(r'\b(schema|table names|tables|columns|structure|column names|show tables|describe table|metadata)\b', re.IGNORECASE)
122
- return bool(pattern.search(prompt))
123
-
124
- @app.errorhandler(Exception)
125
- def handle_all_errors(e):
126
- print(f"[ERROR] Global handler caught an exception: {str(e)}")
127
- traceback.print_exc()
128
-
129
- if isinstance(e, HTTPException):
130
- return jsonify({"status": "error", "message": e.description}), e.code
131
-
132
- return jsonify({"status": "error", "message": "An unexpected error occurred"}), 500
133
-
134
- @app.route("/")
135
- def index():
136
- return render_template("dynamicDB_index_test2.html")
137
-
138
- @app.route("/upload_db", methods=["POST"])
139
- def upload_db():
140
- file = request.files.get('file')
141
- if not file or file.filename == '':
142
- return jsonify(success=False, message="No file provided"), 400
143
- if not allowed_file(file.filename):
144
- return jsonify(success=False, message="Only .db files supported"), 400
145
-
146
- filename = secure_filename(file.filename)
147
- path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
148
- file.save(path)
149
-
150
- try:
151
- init_agent(f"sqlite:///{path}")
152
- return jsonify(success=True, message="Database uploaded and initialized"), 200
153
- except Exception as e:
154
- return jsonify(success=False, message=f"Init failed: {e}"), 500
155
-
156
- @app.route("/generate", methods=["POST"])
157
- def generate():
158
- try:
159
- data = request.get_json(force=True)
160
- prompt = data.get("prompt", "").strip()
161
- if not prompt:
162
- print("[WARN] Empty prompt received.")
163
- return jsonify({"status": "error", "message": "Prompt is required"}), 400
164
- except Exception as e:
165
- print(f"[ERROR] Invalid input format: {str(e)}")
166
- traceback.print_exc()
167
- return jsonify({"status": "error", "message": "Invalid input"}), 400
168
-
169
- if is_schema_leak_request(prompt):
170
- msg = "Sorry, I can't share schema or structure-related information."
171
- # socketio.emit("flash", {"message": msg})
172
- socketio.emit("final", {"message": msg})
173
- return {"status": "blocked", "message": msg}, 403
174
-
175
- if is_schema_request(prompt):
176
- # socketio.emit("flash", {"message": "⚠️ Access to schema or database structure is restricted."})
177
- socketio.emit("final", {"message": "I'm sorry, I can't share database structure information."})
178
- return jsonify({"status": "blocked", "message": "Schema request blocked"}), 403
179
-
180
- def run_agent():
181
- try:
182
- # socketio.emit("thought", {"message": f"Thinking about: {prompt}"})
183
- # result = agent_executor.run(prompt)
184
- result = agent_executor.invoke({"input": prompt})
185
-
186
- final_answer = result.get("output", "")
187
- intermediate_steps = result.get("intermediate_steps", [])
188
-
189
- # Try to extract table-like observation (from SQL tool)
190
- table_result = None
191
- for step in intermediate_steps:
192
- observation = step[1]
193
- if isinstance(observation, list):
194
- table_result = observation # Expecting a list of dicts or tuples
195
- break
196
- elif isinstance(observation, str) and "│" in observation:
197
- table_result = observation
198
- break
199
-
200
- if table_result:
201
- # Emit the table separately
202
- socketio.emit("table", {"data": table_result})
203
-
204
- # socketio.emit("final", {"message": result})
205
- socketio.emit("final", {"message": final_answer})
206
- except KeyError:
207
- print("[ERROR] Unexpected response format from agent.")
208
- traceback.print_exc()
209
- socketio.emit("final", {"message": "Unexpected response format. Please try again."})
210
- except TimeoutError:
211
- print("[ERROR] Request timed out.")
212
- traceback.print_exc()
213
- socketio.emit("final", {"message": "The request took too long. Please try again."})
214
-
215
- except Exception as e:
216
- err_msg = f"[ERROR]: {str(e)}"
217
- print(err_msg)
218
- if "429" in err_msg and "rate limit" in err_msg.lower():
219
- user_message = "Too many requests. Please wait a few seconds and try again."
220
- elif "ResourceExhausted" in err_msg:
221
- user_message = "Try again after some time."
222
- elif "rate_limit_exceeded" in err_msg:
223
- user_message = "You’re sending requests too fast. Please wait and try again shortly."
224
- else:
225
- user_message = "Agent processing failed."
226
-
227
- traceback.print_exc()
228
- socketio.emit("log", {"message": err_msg})
229
- socketio.emit("log", {"message": user_message})
230
- socketio.emit("final", {"message": user_message})
231
-
232
- threading.Thread(target=run_agent).start()
233
- return jsonify({"status": "ok"}), 200
234
-
235
- if __name__ == "__main__":
236
- socketio.run(app, debug=True)
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ from flask_socketio import SocketIO, emit
3
+ from langchain_google_genai import ChatGoogleGenerativeAI
4
+ from langchain.agents import AgentType
5
+ from langchain_community.agent_toolkits import create_sql_agent
6
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
7
+ from langchain_community.utilities import SQLDatabase
8
+ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
9
+ import threading
10
+ import os
11
+ from dotenv import load_dotenv
12
+ import secrets
13
+ import re
14
+ import traceback
15
+ from werkzeug.exceptions import HTTPException
16
+ from werkzeug.utils import secure_filename
17
+
18
# Load variables from a local .env file (if present) into the process
# environment before anything below reads them.
load_dotenv()

# Fail fast with a readable message when the key is missing. Assigning the
# result of os.getenv() straight back into os.environ raises an opaque
# TypeError when the variable is unset, because environment values must be str.
_gemini_api_key = os.getenv("GEMINI_API_KEY")
if not _gemini_api_key:
    raise RuntimeError(
        "GEMINI_API_KEY is not set; define it in the environment or in a .env file."
    )
os.environ["GEMINI_API_KEY"] = _gemini_api_key

app = Flask(__name__)
# A fresh random key is generated on every process start, so anything signed
# with it does not stay valid across restarts.
app.config['SECRET_KEY'] = secrets.token_hex(32)
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['ALLOWED_EXTENSIONS'] = {'db'}
# NOTE(review): "*" accepts Socket.IO connections from any origin; confirm
# this is intended for the deployment.
socketio = SocketIO(app, cors_allowed_origins="*")

# Ensure upload folder exists
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

llm = ChatGoogleGenerativeAI(
    temperature=0.2,
    model="gemini-2.0-flash",
    # Previously spelled "max_retires", which is not the documented
    # `max_retries` field, so the intended retry budget was never applied.
    max_retries=50,
    # NOTE(review): tool_choice does not look like a constructor field of this
    # class - confirm it has any effect here.
    tool_choice="auto",
    # max_tokens=1024,
    # streaming=True,
    api_key=_gemini_api_key,
)

# Populated by init_agent() once a database has been uploaded.
db = None
agent_executor = None
40
+
41
def allowed_file(filename, allowed_extensions=("db",)):
    """Return True when *filename* ends in one of the accepted extensions.

    Args:
        filename: Name supplied by the uploader. Anything that is not a
            string (for example ``None``) is rejected instead of raising.
        allowed_extensions: Collection of lower-case extensions without the
            leading dot. Defaults to ``("db",)``, which matches the previous
            hard-coded behaviour.

    Returns:
        bool: True if the text after the last dot, compared
        case-insensitively, is one of ``allowed_extensions``.
    """
    if not isinstance(filename, str):
        return False
    # rpartition yields an empty separator when there is no dot at all, so a
    # bare "db" is not mistaken for a file with a ".db" extension.
    _stem, dot, extension = filename.rpartition(".")
    return bool(dot) and extension.lower() in allowed_extensions
43
+
44
# System prompt prepended to the SQL agent's instructions. The wording is kept
# exactly as it was; only its location changed (hoisted out of init_agent).
# NOTE(review): the "You must ALWAYS" list contains "Don't Summarize the
# result", while a later paragraph asks the agent to "summarize the response
# clearly" - confirm which one is intended.
# NOTE(review): the text refers to a "SubmitFinalAnswer" tool; nothing in this
# file defines one - confirm the toolkit provides it.
_AGENT_PREFIX = '''You are a helpful SQL expert agent that ALWAYS returns natural language answers using the tools.
Always format your responses in Markdown. For example:
- Use bullet points
- Use bold for headers
- Wrap code in triple backticks
- Tables should use Markdown table syntax

You must NEVER:
- Show or mention SQL syntax.
- Reveal table names, column names, or database schema.
- Respond with any technical details or structure of the database.
- Return code or tool names.
- Give wrong Answers.

You must ALWAYS:
- Respond in plain, friendly language.
- Don't Summarize the result for the user (e.g., "There are 9 tables in the system.")
- If asked to list table names or schema, politely refuse and respond with:
"I'm sorry, I can't share database structure information."
- ALWAYS HAVE TO SOLVE COMPLEX USER QUERIES. FOR THAT, UNDERSTAND THE PROMPT, ANALYSE PROPER AND THEN GIVE ANSWER.
- Your Answers should be correct, you have to do understand process well and give accurate answers

Strict Rules You MUST Follow:
- NEVER display or mention SQL queries.
- NEVER explain SQL syntax or logic.
- NEVER return technical or code-like responses.
- ONLY respond in natural, human-friendly language.
- You are not allow to give the name of any COLUMNS, TABLES, DATABASE, ENTITY, SYNTAX, STRUCTURE, DESIGN, ETC...

If the user asks for anything other than retrieving data (SELECT), respond using this exact message:
"I'm not allowed to perform operations other than SELECT queries. Please ask something that involves reading data."

Do not return SQL queries or raw technical responses to the user.

For example:
Wrong: SELECT * FROM ...
Correct: The user assigned to the cart is Alice Smith.

Use the tools provided to get the correct data from the database and summarize the response clearly.
If the input is unclear or lacks sufficient data, ask for clarification using the SubmitFinalAnswer tool.
Never return SQL queries as your response.

If you cannot find an answer,
Double-check your query and running it again.
- If a query fails, revise and try again.
- Else 'No data found' using SubmitFinalAnswer.No SQL, no code. '''


def init_agent(db_uri):
    """Connect to *db_uri* and rebuild the module-level SQL agent around it.

    Args:
        db_uri: SQLAlchemy-style database URI handed to
            ``SQLDatabase.from_uri``.

    Side effects:
        Rebinds the module globals ``db`` and ``agent_executor``.

    Raises:
        Whatever ``SQLDatabase.from_uri`` or ``create_sql_agent`` raise; in
        that case the previous ``db`` / ``agent_executor`` pair is left
        untouched.
    """
    global db, agent_executor

    # Build everything in locals first. Previously `db` was rebound before the
    # agent was created, so a failure in create_sql_agent left `db` pointing at
    # the new database while `agent_executor` still served the old one.
    new_db = SQLDatabase.from_uri(db_uri)
    new_executor = create_sql_agent(
        llm=llm,
        toolkit=SQLDatabaseToolkit(db=new_db, llm=llm),
        verbose=False,
        prefix=_AGENT_PREFIX,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        agent_executor_kwargs={"handle_parsing_errors": True},
    )

    # Publish both halves together only once construction has succeeded.
    db, agent_executor = new_db, new_executor
104
+
105
# Schema-leak screening: a two-message chat prompt (fixed system instruction
# plus the user's text) piped into the shared LLM.
intent_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Classify if user is asking schema/structure info: YES or NO."),
        ("human", "{prompt}"),
    ]
)
intent_checker = intent_prompt | llm
111
+
112
def is_schema_leak_request(prompt):
    """Ask the LLM classifier whether *prompt* requests schema/structure info.

    Args:
        prompt: The user's question as plain text.

    Returns:
        bool: True when the classifier's reply contains the standalone word
        "yes" (case-insensitive).

    Raises:
        Whatever ``intent_checker.invoke`` raises; failures of the
        classification call are not handled here.
    """
    classification = intent_checker.invoke({"prompt": prompt})
    verdict = classification.content
    # The reply content is not guaranteed to be a plain string; stringify it
    # so the check below cannot fail on a missing str method.
    if not isinstance(verdict, str):
        verdict = str(verdict)
    # Match "yes" as a whole word. The previous substring test also fired on
    # replies such as "No, this is about yesterday's sales".
    return re.search(r"\byes\b", verdict, re.IGNORECASE) is not None
115
+
116
# Compiled once at import time instead of on every request. Multi-word terms
# use \s+ so that extra spaces, tabs or line breaks between the words do not
# slip past the filter.
_SCHEMA_TERMS_RE = re.compile(
    r'\b(schema|table\s+names|tables|columns|structure|column\s+names'
    r'|show\s+tables|describe\s+table|metadata)\b',
    re.IGNORECASE,
)


def is_schema_request(prompt: str) -> bool:
    """Report whether *prompt* mentions a schema/structure keyword.

    The check is a case-insensitive, whole-word search for: schema,
    table names, tables, columns, structure, column names, show tables,
    describe table, metadata.

    Args:
        prompt: The user's question as plain text.

    Returns:
        True if any of the keywords occurs in ``prompt``, otherwise False.
    """
    return _SCHEMA_TERMS_RE.search(prompt) is not None
123
+
124
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Application-wide fallback: log the exception and answer with JSON.

    HTTP errors keep their own description and status code; every other
    exception is reported as a generic 500 without exposing its details.
    """
    print(f"[ERROR] Global handler caught an exception: {e!s}")
    traceback.print_exc()

    # Anything that is not an HTTP error is reported as an opaque 500.
    if not isinstance(e, HTTPException):
        return jsonify({"status": "error", "message": "An unexpected error occurred"}), 500

    payload = {"status": "error", "message": e.description}
    return jsonify(payload), e.code
133
+
134
@app.route("/")
def index():
    """Render the application's landing page template."""
    template_name = "dynamicDB_index_test2.html"
    return render_template(template_name)
137
+
138
@app.route("/upload_db", methods=["POST"])
def upload_db():
    """Accept an uploaded ``.db`` file, store it, and point the agent at it.

    Expects a multipart form upload under the field name ``file``.

    Returns:
        A ``(json, status)`` pair:
        * 400 when no file was sent or its name is not an accepted ``.db`` name,
        * 500 when the agent could not be initialised for the saved file,
        * 200 once the database has been saved and the agent rebuilt.
    """
    file = request.files.get('file')
    if not file or file.filename == '':
        return jsonify(success=False, message="No file provided"), 400

    # Check the sanitized name as well as the raw one: sanitizing can strip
    # leading dots, so a raw name like ".db" would otherwise pass the
    # extension test and then be stored as an extension-less "db".
    filename = secure_filename(file.filename)
    if not allowed_file(file.filename) or not allowed_file(filename):
        return jsonify(success=False, message="Only .db files supported"), 400

    path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(path)

    try:
        init_agent(f"sqlite:///{path}")
    except Exception as e:
        # Keep the full traceback in the server log; the response only
        # carries the exception text.
        # NOTE(review): the exception text is returned to the client and may
        # expose internal details - confirm this is acceptable.
        traceback.print_exc()
        return jsonify(success=False, message=f"Init failed: {e}"), 500

    return jsonify(success=True, message="Database uploaded and initialized"), 200
155
+
156
def _extract_table_result(intermediate_steps):
    """Return the first table-like tool observation, or None if there is none.

    An observation counts as table-like when it is a list, or a string
    containing the box-drawing character "│".
    """
    for step in intermediate_steps:
        observation = step[1]
        if isinstance(observation, list):
            return observation  # Expecting a list of dicts or tuples
        if isinstance(observation, str) and "│" in observation:
            return observation
    return None


def _user_message_for_error(err_msg):
    """Map a raw error string to the message shown to the user."""
    if "429" in err_msg and "rate limit" in err_msg.lower():
        return "Too many requests. Please wait a few seconds and try again."
    if "ResourceExhausted" in err_msg:
        return "Try again after some time."
    if "rate_limit_exceeded" in err_msg:
        return "You’re sending requests too fast. Please wait and try again shortly."
    return "Agent processing failed."


@app.route("/generate", methods=["POST"])
def generate():
    """Validate a prompt, screen it for schema requests, and run the agent.

    Expects a JSON body with a non-empty string under ``"prompt"``. The agent
    runs on a background thread; its answer is delivered through Socket.IO
    events ("table", "final", "log"), not in the HTTP response.

    Returns:
        A ``(json, status)`` pair:
        * 400 for a missing/invalid prompt or when no database is loaded yet,
        * 403 when the prompt is classified as a schema/structure request,
        * 200 ``{"status": "ok"}`` once the background run has been started.
    """
    try:
        data = request.get_json(force=True)
        prompt = data.get("prompt", "").strip()
        if not prompt:
            print("[WARN] Empty prompt received.")
            return jsonify({"status": "error", "message": "Prompt is required"}), 400
    except Exception as e:
        print(f"[ERROR] Invalid input format: {str(e)}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": "Invalid input"}), 400

    # Without an uploaded database there is no agent to run. Previously this
    # case answered 200 and then failed inside the worker thread with a
    # generic "Agent processing failed." message.
    executor = agent_executor
    if executor is None:
        msg = "No database is loaded yet. Please upload a .db file first."
        socketio.emit("final", {"message": msg})
        return jsonify({"status": "error", "message": msg}), 400

    # NOTE(review): socketio.emit() is called without a recipient throughout
    # this function, so events are not addressed to the requesting client
    # specifically - confirm that is acceptable when several users are online.
    if is_schema_leak_request(prompt):
        msg = "Sorry, I can't share schema or structure-related information."
        socketio.emit("final", {"message": msg})
        # jsonify for consistency with every other response in this module.
        return jsonify({"status": "blocked", "message": msg}), 403

    if is_schema_request(prompt):
        socketio.emit("final", {"message": "I'm sorry, I can't share database structure information."})
        return jsonify({"status": "blocked", "message": "Schema request blocked"}), 403

    def run_agent():
        """Run the agent for this prompt and push the outcome over Socket.IO."""
        try:
            result = executor.invoke({"input": prompt})

            final_answer = result.get("output", "")
            # NOTE(review): "intermediate_steps" appears to be present only
            # when the executor is built with return_intermediate_steps=True,
            # which init_agent does not pass - confirm the table path is live.
            intermediate_steps = result.get("intermediate_steps", [])

            # Try to extract table-like observation (from SQL tool)
            table_result = _extract_table_result(intermediate_steps)
            if table_result:
                # Emit the table separately
                socketio.emit("table", {"data": table_result})

            socketio.emit("final", {"message": final_answer})
        except KeyError:
            print("[ERROR] Unexpected response format from agent.")
            traceback.print_exc()
            socketio.emit("final", {"message": "Unexpected response format. Please try again."})
        except TimeoutError:
            print("[ERROR] Request timed out.")
            traceback.print_exc()
            socketio.emit("final", {"message": "The request took too long. Please try again."})
        except Exception as e:
            err_msg = f"[ERROR]: {str(e)}"
            print(err_msg)
            user_message = _user_message_for_error(err_msg)

            traceback.print_exc()
            # NOTE(review): the raw exception text is sent to the browser via
            # the "log" event and may contain internal details - confirm.
            socketio.emit("log", {"message": err_msg})
            socketio.emit("log", {"message": user_message})
            socketio.emit("final", {"message": user_message})

    threading.Thread(target=run_agent).start()
    return jsonify({"status": "ok"}), 200
234
+
235
if __name__ == "__main__":
    # Listen on every interface on port 7860; allow_unsafe_werkzeug permits
    # running on the Werkzeug server.
    run_options = {"host": "0.0.0.0", "port": 7860, "allow_unsafe_werkzeug": True}
    socketio.run(app, **run_options)