# Hugging Face Space status: Sleeping
| import os | |
| import gradio as gr | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_community.utilities.sql_database import SQLDatabase | |
| from langchain_community.agent_toolkits import create_sql_agent | |
| from langchain_openai import AzureChatOpenAI | |
# Path to the SQLite file holding the credit-card data (relative to the working directory).
ccms_db_loc = 'ccms.db'
# Open the database read-only through a SQLite URI filename. The system prompt only
# *asks* the LLM not to modify data, and user questions reach the LLM verbatim, so
# mode=ro enforces that rule at the connection level. It also turns a missing
# database file into an error instead of silently creating a new, empty database.
ccms_db = SQLDatabase.from_uri(f"sqlite:///file:{ccms_db_loc}?mode=ro&uri=true")
# Chat model served from Azure OpenAI. Credentials come from the environment;
# os.environ[...] raises KeyError at startup if either variable is missing.
gpt4o_azure = AzureChatOpenAI(
    model_name='gpt-4o-mini',
    api_key=os.environ["AZURE_OPENAI_KEY"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_version="2024-02-01",
    # Deterministic-leaning sampling for SQL generation.
    temperature=0,
)
# Schema description of the database as reported by LangChain's SQLDatabase
# (presumably CREATE TABLE statements plus a few sample rows — confirm the
# sample_rows_in_table_info setting if that matters).
context = ccms_db.get_context()
# system_message is passed to ChatPromptTemplate.from_messages, which treats it
# as an f-string template. Any literal "{" or "}" inside the schema text (e.g. in
# sample row values) would otherwise be parsed as a template variable and break
# prompt formatting, so double them; the template renders "{{"/"}}" as literal braces.
database_schema = context['table_info'].replace("{", "{{").replace("}", "}}")
system_message = f"""You are a SQLite expert agent designed to interact with a SQLite database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results using the LIMIT clause as per SQLite. You can order the results by a relevant column to return the most informative examples in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML or DDL statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
Only use the following tables:
{database_schema}
"""
# Prompt for the SQL agent, in order:
#   1. the fixed system instructions (system_message),
#   2. the user's question, bound to the "input" template variable,
#   3. a slot named "agent_scratchpad" for the agent's intermediate
#      tool-call messages, which the agent runtime fills in on each step.
full_prompt = ChatPromptTemplate.from_messages(
    messages=[
        ("system", system_message),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)
# Tool-calling ("openai-tools") SQL agent over ccms_db, driven by full_prompt.
# - handle_parsing_errors=True asks the executor to recover from malformed
#   model output rather than raising.
# - max_iterations caps the number of agent steps per question.
# - verbose=True prints each intermediate step to the console.
sqlite_agent = create_sql_agent(
    agent_type="openai-tools",
    llm=gpt4o_azure,
    db=ccms_db,
    prompt=full_prompt,
    max_iterations=10,
    verbose=True,
    agent_executor_kwargs=dict(handle_parsing_errors=True),
)
def predict(user_input):
    """Answer a natural-language question about the credit-card database.

    This is the Gradio callback: it runs the module-level ``sqlite_agent`` on
    the question and returns the agent's final answer.

    Args:
        user_input: The question typed into the textbox (may be empty or None).

    Returns:
        Always a ``str``: the agent's answer, a prompt to enter a question when
        the input is blank, or an ``"Error: ..."`` message if the agent run fails.
    """
    import logging

    # Avoid a pointless (and billed) LLM round-trip when there is nothing to ask.
    if not user_input or not user_input.strip():
        return "Please enter a question about the credit card database."
    try:
        # Key the input explicitly by the prompt's "{input}" template variable.
        response = sqlite_agent.invoke({"input": user_input})
        prediction = response['output']
    except Exception as e:  # top-level UI boundary: report the failure, don't crash
        logging.getLogger(__name__).exception("SQL agent run failed")
        # Return text (not the exception object) so the "text" output renders cleanly.
        prediction = f"Error: {type(e).__name__}: {e}"
    return prediction
# UI
# Fail fast when the login secret is missing: with a None password no login
# attempt could ever succeed, and an empty password would be trivially guessable.
# This mirrors the os.environ[...] fail-fast used for the Azure credentials.
demo_password = os.environ.get('PASSWD')
if not demo_password:
    raise RuntimeError("Set the PASSWD environment variable to the demo login password.")

# Multi-line input box for the natural-language question.
textbox = gr.Textbox(placeholder="Enter your query here", lines=6)
# HTML article shown below the interface: two schema diagram images plus a link
# to the source repository of the dataset.
schema = 'The schema for the database is presented below: \n <img src="https://cdn-uploads.huggingface.co/production/uploads/64118e60756b9e455c7eddd6/S1alVt_D88qatd-N4Dkjd.png" > \n<img src="https://cdn-uploads.huggingface.co/production/uploads/64118e60756b9e455c7eddd6/81ggHEjrt6wFrMyXJtHVS.png" > (Source: https://github.com/shrivastavasatyam/Credit-Card-Management-System)'
demo = gr.Interface(
    inputs=textbox, fn=predict, outputs="text",
    title="Query a Credit Card Database",
    description="This web API presents an interface to ask questions on information stored in a credit card database.",
    article=schema,
    # One value per example row: the interface has exactly one input (the textbox).
    examples=[
        ["Who are the top 5 merchants by total transactions?"],
        ["Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"],
        ["Which is the highest spend month and amount for each card type?"],
        ["Which was the city with the lowest percentage spend for the Gold card type?"],
        ["What was the percentage contribution of spends by females for each card type?"],
        ["Which city has the highest spend to transaction ratio on weekends?"],
        ["Which was the city to reach 500 transactions the fastest?"],
    ],
    # Don't precompute example outputs at startup (each would be a paid agent run).
    cache_examples=False,
    theme=gr.themes.Base(),
    # At most 8 predict() calls run concurrently.
    concurrency_limit=8,
)
demo.queue()
demo.launch(auth=("demouser", demo_password))