"""Flask + Socket.IO front end for a Gemini-backed LangChain SQL agent.

HTTP routes accept a natural-language prompt, screen it for requests that
would expose database structure, and run the SQL agent on a background
thread. Agent results are pushed to clients over the Socket.IO events
``table``, ``log`` and ``final``.
"""

import os
import re
import secrets
import threading
import traceback

from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
from flask_socketio import SocketIO, emit
from langchain.agents import AgentType
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
# NOTE(review): `emit`, `PromptTemplate` and `pipeline` are not used in the code
# visible here; importing transformers is slow. Kept in case other code relies on them.
from transformers import pipeline
from werkzeug.exceptions import HTTPException

load_dotenv()

api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    raise EnvironmentError("GEMINI_API_KEY is not set. Please set it as an environment variable.")

app = Flask(__name__)
# Regenerated on every start, so any signed session data does not survive a restart.
app.config['SECRET_KEY'] = secrets.token_hex(32)
# NOTE(review): "*" accepts Socket.IO connections from any origin; restrict in production.
socketio = SocketIO(app, cors_allowed_origins="*")

# Active backend, plus the database handle and agent that init_agent() builds for it.
current_db_type = "mysql"
db = None
agent_executor = None

# SQLAlchemy connection URI for each supported backend.
# NOTE(review): MySQL credentials are hard-coded; consider reading them from the environment.
DB_URIS = {
    "sqlite": "sqlite:///ecommerce_system_2.db",
    "mysql": "mysql+pymysql://root:root@mysql/employee_sys",
}

# Gemini chat model shared by the SQL agent and the intent classifier.
llm = ChatGoogleGenerativeAI(
    temperature=0.2,
    model="gemini-2.0-flash",
    max_retries=50,
    tool_choice="auto",
    api_key=api_key,
)

# System-prompt prefix handed to create_sql_agent(); it constrains the agent to
# read-only, schema-free, natural-language answers.
AGENT_PREFIX = '''You are a helpful SQL expert agent that ALWAYS returns natural language answers using the tools.
Always format your responses in Markdown. For example:
- Use bullet points
- Use bold for headers
- Wrap code in triple backticks
- Tables should use Markdown table syntax

You must NEVER:
- Show or mention SQL syntax.
- Reveal table names, column names, or database schema.
- Respond with any technical details or structure of the database.
- Return code or tool names.
- Give wrong Answers.

You must ALWAYS:
- Respond in plain, friendly language.
- Don't Summarize the result for the user (e.g., "There are 9 tables in the system.")
- If asked to list table names or schema, politely refuse and respond with: "I'm sorry, I can't share database structure information."
- ALWAYS HAVE TO SOLVE COMPLEX USER QUERIES. FOR THAT, UNDERSTAND THE PROMPT, ANALYSE PROPER AND THEN GIVE ANSWER.
- Your Answers should be correct, you have to do understand process well and give accurate answers

Strict Rules You MUST Follow:
- NEVER display or mention SQL queries.
- NEVER explain SQL syntax or logic.
- NEVER return technical or code-like responses.
- ONLY respond in natural, human-friendly language.
- You are not allow to give the name of any COLUMNS, TABLES, DATABASE, ENTITY, SYNTAX, STRUCTURE, DESIGN, ETC...

If the user asks for anything other than retrieving data (SELECT), respond using this exact message:
"I'm not allowed to perform operations other than SELECT queries. Please ask something that involves reading data."

Do not return SQL queries or raw technical responses to the user.
For example:
Wrong: SELECT * FROM ...
Correct: The user assigned to the cart is Alice Smith.
Use the tools provided to get the correct data from the database and summarize the response clearly.
If the input is unclear or lacks sufficient data, ask for clarification using the SubmitFinalAnswer tool.
Never return SQL queries as your response.
If you cannot find an answer, Double-check your query and running it again.
- If a query fails, revise and try again.
- Else 'No data found' using SubmitFinalAnswer.No SQL, no code.
'''


def init_agent(db_type):
    """Connect to the ``db_type`` backend and rebuild the global SQL agent.

    The module-level ``db`` and ``agent_executor`` are replaced only after
    both have been built successfully. A failed switch therefore leaves the
    previously working pair in place.

    Args:
        db_type: A key of ``DB_URIS`` ("sqlite" or "mysql").

    Raises:
        RuntimeError: If ``db_type`` is unsupported, or if connecting or
            building the agent fails. The original exception is chained.
    """
    global db, agent_executor
    try:
        uri = DB_URIS.get(db_type)
        if uri is None:
            raise ValueError("Unsupported DB type")
        new_db = SQLDatabase.from_uri(uri)
        toolkit = SQLDatabaseToolkit(db=new_db, llm=llm)
        new_agent = create_sql_agent(
            llm=llm,
            toolkit=toolkit,
            verbose=False,
            prefix=AGENT_PREFIX,
            agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            # Feed malformed LLM output back to the agent instead of raising.
            agent_executor_kwargs={"handle_parsing_errors": True},
        )
    except Exception as e:
        print(f"[ERROR] Agent Initialized failed : {str(e)}")
        traceback.print_exc()
        raise RuntimeError(f"Agent Initialized failed : {str(e)}") from e
    db, agent_executor = new_db, new_agent


# Build the agent for the default backend at import time; a failure aborts startup.
init_agent(current_db_type)

# LLM-based classifier used to detect requests for schema/structure information.
intent_prompt = ChatPromptTemplate.from_messages([
    ("system", "Classify the intent of the user prompt. If it asks for schema, table names, column names, or database structure, return YES. Otherwise return NO."),
    ("human", "{prompt}"),
])
intent_checker = intent_prompt | llm

_YES_WORD = re.compile(r"\byes\b", re.IGNORECASE)
_SCHEMA_KEYWORDS = re.compile(
    r'\b(schema|table names|tables|columns|structure|column names|show tables|describe table|metadata)\b',
    re.IGNORECASE,
)


def is_schema_leak_request(prompt):
    """Ask the LLM classifier whether ``prompt`` seeks database structure.

    Makes one LLM call. Returns True when the reply contains the word "yes"
    (any case).
    """
    classification = intent_checker.invoke({"prompt": prompt})
    content = classification.content
    # Message content is not guaranteed to be a plain string; coerce before matching.
    if not isinstance(content, str):
        content = str(content)
    return bool(_YES_WORD.search(content))


def is_schema_request(prompt: str) -> bool:
    """
    Checks if the user prompt is trying to access schema or structure info.
    Returns True if it's about table names, schema, columns, etc.
    """
    return bool(_SCHEMA_KEYWORDS.search(prompt))


def _find_table_observation(intermediate_steps):
    """Return the first tool observation that looks like tabular data, else None.

    An observation qualifies when it is a list, or a string containing the
    box-drawing character "│". Each step is indexed as ``step[1]`` for its
    observation.
    """
    for step in intermediate_steps:
        observation = step[1]
        if isinstance(observation, list):
            return observation
        if isinstance(observation, str) and "│" in observation:
            return observation
    return None


@app.errorhandler(Exception)
def handle_all_errors(e):
    """Render any uncaught exception as JSON.

    HTTP errors keep their own status code and description; anything else
    becomes a generic 500.
    """
    print(f"[ERROR] Global handler caught an exception: {str(e)}")
    traceback.print_exc()
    if isinstance(e, HTTPException):
        return jsonify({"status": "error", "message": e.description}), e.code
    return jsonify({"status": "error", "message": "An unexpected error occurred"}), 500


@app.route("/")
def index():
    """Serve the frontend page."""
    return render_template("index_test.html")


@app.route("/set_db", methods=["POST"])
def set_db():
    """Switch the agent to the backend named by the JSON field ``db_type``.

    Defaults to "sqlite" when the field is absent. Returns 400 if the body
    is not a JSON object. ``current_db_type`` changes only after the new
    agent has been built; on failure, returns 500 with the error message.
    """
    global current_db_type
    payload = request.json
    if not isinstance(payload, dict):
        return jsonify({"status": "error", "message": "Invalid input"}), 400
    db_type = payload.get("db_type", "sqlite")
    try:
        init_agent(db_type)
        current_db_type = db_type
        return jsonify({"status": "ok", "message": f"Switched to {db_type}"}), 200
    except Exception as e:
        print(f"[ERROR] Failed to switch DB: {str(e)}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500


@app.route("/generate", methods=["POST"])
def generate():
    """Screen a prompt and, if allowed, answer it with the SQL agent asynchronously.

    Responds immediately:
    - 400 for a missing, empty or non-string ``prompt``;
    - 403 when the prompt is classified as a schema/structure request;
    - otherwise 200, after starting a thread that emits the agent's answer
      on the ``final`` Socket.IO event, plus ``table``/``log`` events where
      applicable.

    NOTE(review): ``socketio.emit`` here has no ``to=`` target, so every
    connected client receives every user's answer.
    """
    try:
        data = request.get_json()
        prompt = data.get("prompt", "")
        if not prompt:
            print("[WARN] Empty prompt received.")
            return jsonify({"status": "error", "message": "Prompt is required"}), 400
        if not isinstance(prompt, str):
            # Non-string prompts would later crash the regex screen with a 500.
            raise TypeError(f"prompt must be a string, got {type(prompt).__name__}")
    except Exception as e:
        print(f"[ERROR] Invalid input format: {str(e)}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": "Invalid input"}), 400

    if is_schema_leak_request(prompt):
        msg = "Sorry, I can't share schema or structure-related information."
        socketio.emit("final", {"message": msg})
        return jsonify({"status": "blocked", "message": msg}), 403

    if is_schema_request(prompt):
        socketio.emit("final", {"message": "I'm sorry, I can't share database structure information."})
        return jsonify({"status": "blocked", "message": "Schema request blocked"}), 403

    def run_agent():
        """Run the agent on ``prompt`` and emit the result or a friendly error."""
        try:
            result = agent_executor.invoke({"input": prompt})
            final_answer = result.get("output", "")
            # NOTE(review): AgentExecutor only returns "intermediate_steps" when built
            # with return_intermediate_steps=True, which init_agent does not set, so
            # this list is presumably always empty and no "table" event is sent.
            table_result = _find_table_observation(result.get("intermediate_steps", []))
            if table_result:
                socketio.emit("table", {"data": table_result})
            socketio.emit("final", {"message": final_answer})
        except KeyError:
            print("[ERROR] Unexpected response format from agent.")
            traceback.print_exc()
            socketio.emit("final", {"message": "Unexpected response format. Please try again."})
        except TimeoutError:
            print("[ERROR] Request timed out.")
            traceback.print_exc()
            socketio.emit("final", {"message": "The request took too long. Please try again."})
        except Exception as e:
            err_msg = f"[ERROR]: {str(e)}"
            print(err_msg)
            if "429" in err_msg and "rate limit" in err_msg.lower():
                user_message = "Too many requests. Please wait a few seconds and try again."
            elif "rate_limit_exceeded" in err_msg:
                user_message = "You’re sending requests too fast. Please wait and try again shortly."
            else:
                user_message = "Agent processing failed."
            traceback.print_exc()
            socketio.emit("log", {"message": err_msg})
            socketio.emit("log", {"message": user_message})
            socketio.emit("final", {"message": user_message})

    threading.Thread(target=run_agent).start()
    return jsonify({"status": "ok"}), 200


if __name__ == "__main__":
    socketio.run(app, host="0.0.0.0", port=7860, allow_unsafe_werkzeug=True)