# from llama_index.core.query_engine import NLSQLTableQueryEngine
# from llama_index.core.indices.struct_store.sql_query import (
#     SQLTableRetrieverQueryEngine,
#     NLSQLTableQueryEngine,
# )
from llama_index.core import set_global_handler
from dotenv import load_dotenv

load_dotenv()
set_global_handler("simple")

from llama_index.core.query_engine import RetryQueryEngine
from llama_index.core import SQLDatabase, VectorStoreIndex, Settings, PromptTemplate, Response
from llama_index.core.agent import QueryPipelineAgentWorker, AgentRunner, Task, AgentChatResponse
from llama_index.llms.gemini import Gemini
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.callbacks import CallbackManager
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.core.query_pipeline import AgentInputComponent, AgentFnComponent, QueryPipeline as QP
from custom_modules.custom_query_engine import SQLTableRetrieverQueryEngine
from llama_index.core.prompts.prompt_type import PromptType
import json
import os
import time
import gradio as gr
from typing import Dict, Any, Tuple
from sqlalchemy import create_engine, inspect

Settings.llm = Gemini(temperature=0.2)
Settings.embed_model = GeminiEmbedding()
llm = Settings.llm

engine = create_engine("sqlite:///Llama_index_structured_sql/database/param4table.db")
sql_database = SQLDatabase(engine)

qa_prompt_tmpl_str = (
    "Given an input question, first create a syntactically correct {dialect} "
    "query to run, then look at the results of the query and return the answer. "
    "You can order the results by a relevant column to return the most "
    "interesting examples in the database.\n\n"
    "Never query for all the columns from a specific table, only ask for a "
    "few relevant columns given the question.\n\n"
    "Use only the column names that you can see in the schema description. "
    "Be careful to not query for columns that do not exist. "
    "Pay attention to which column is in which table. "
    "Also, qualify column names with the table name when needed. "
    "If the sales and inwards tables need to be joined, first take the DISTINCT "
    "str_id rows from inwards and then join them with the sales table. "
    # "WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards), "
    # "location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm) "
    # "SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls ORDER BY ls.tot DESC;"
    "Only use the context information and not prior knowledge. "
    "Some example question/SQL pairs are given below.\n"
    "{few_shot_examples}\n"
    "\nIf a convo history is given, check the previously inferred SQL query and "
    "return a modified query that avoids its problems.\n\n"
    "Convo history (failed attempts):\n{convo_history}\n"
    "You are required to use the following format, each taking one line:\n\n"
    "Question: Question here\n"
    "SQLQuery: SQL Query to run\n\n"
    # "SQLResult: Result of the SQLQuery\n"
    # "Answer: Final answer here\n\n"
    "Only use tables listed below.\n"
    "{schema}\n\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)

qa_prompt_tmpl = PromptTemplate(
    qa_prompt_tmpl_str,
    prompt_type=PromptType.TEXT_TO_SQL,
)

few_shot_examples = [
    # {"query": "How many states are present?", "response": "SELECT COUNT(DISTINCT lctn_stat_nm) AS num_states FROM inwards;"},
    # {"query": "Which colour is the most popular in terms of net sale value?", "response": "SELECT f.fabric_color, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY f.fabric_color ORDER BY total_net_sale_val DESC;"},
    # {"query": "How many departments?", "response": "SELECT COUNT(DISTINCT department) FROM styles;"},
    {"query": "States with above average sales.", "response": "WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards), location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm) SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls WHERE ls.tot > (SELECT AVG(tot) FROM location_sales) ORDER BY ls.tot DESC;"},
    # {"query": "Which stores had the highest sales per quantity?", "response": "SELECT s.customer, SUM(s.net_sale_val) AS tot_sales, SUM(s.sales_qty) AS tot_sales_qty, SUM(s.net_sale_val) / SUM(s.sales_qty) AS sales_value_per_quantity FROM sales AS s GROUP BY s.customer ORDER BY sales_value_per_quantity DESC;"},
    # {"query": "Which colored fabric was sold the most and least?", "response": "SELECT fabric_color, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY fabric_color ORDER BY total_sales_qty DESC;"},
    # {"query": "Which weave type was sold the most and the least?", "response": "SELECT weave_type, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY weave_type ORDER BY total_sales_qty DESC;"},
    # {"query": "What color and weave type together worked the best?", "response": "SELECT f.fabric_color, f.weave_type, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics AS f JOIN styles AS st ON f.fabric_code = st.fabric_id JOIN sales AS s ON st.product_id = s.material GROUP BY f.fabric_color, f.weave_type ORDER BY total_net_sale_val DESC;"},
    # {"query": "What colors and departments together did best?", "response": "WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) AS sales_value, fb.fabric_color AS clr, fb.department AS department FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.department ORDER BY sales_value DESC), most AS (SELECT clr, department, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT clr, department, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    {"query": "How has the month on month sales changed across departments in the last year?", "response": "WITH monthly_sales AS (SELECT st.department, STRFTIME('%m', s.cal_day) AS month, SUM(s.net_sale_val) AS monthly_sales FROM sales s JOIN styles st ON s.material = st.product_id WHERE s.cal_day >= DATE(CURRENT_DATE, '-1 year') GROUP BY st.department, STRFTIME('%m', s.cal_day)) SELECT department, month, monthly_sales, monthly_sales - LAG(monthly_sales, 1) OVER (PARTITION BY department ORDER BY month) AS month_on_month_change FROM monthly_sales ORDER BY department, month;"},
    # {"query": "What was the average sell through rate for fabric across various departments?", "response": ""},
    # {"query": "Which color had the highest sell through in each department in the last 3 months?", "response": ""},
    # {"query": "In shirts with a certain collar type (a styles table attribute), what is the average cotton percentage?", "response": ""},
    # {"query": "Which collar type had the most sales, most sell through?", "response": ""},
    {"query": "Which weave type was sold the most?", "response": "WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) AS sales_value, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.weave_type ORDER BY sales_value DESC), most AS (SELECT weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT weave_type, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    {"query": "What color and weave type together worked the best?", "response": "WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) AS sales_value, fb.fabric_color AS clr, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.weave_type ORDER BY sales_value DESC) SELECT clr, weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables;"},
    {"query": "List the combination of fabrics and brands with their total sales and total inwards quantity.", "response": "WITH sks AS (SELECT st.fabric_id AS fab_id, st.product_id AS prod_id, SUM(sa.net_sale_val) AS tot_sales FROM styles st JOIN sales sa ON st.product_id = sa.material GROUP BY st.product_id), sti AS (SELECT s.product_id AS prod_id, SUM(i.inwards_qty) AS tot_inwards_qty FROM styles s JOIN inwards i ON s.product_id = i.style_id GROUP BY s.product_id) SELECT sks.fab_id, f.brand, SUM(sks.tot_sales) AS tot_s, sti.tot_inwards_qty FROM sks JOIN sti ON sks.prod_id = sti.prod_id JOIN fabrics f ON f.fabric_code = sks.fab_id GROUP BY sks.fab_id, f.brand;"},
]

text2sql_prompt = qa_prompt_tmpl.partial_format(
    dialect=engine.dialect.name,
    few_shot_examples=few_shot_examples,
    convo_history=None,
)
tableinfo_dir = "Llama_index_structured_sql/param_enhanced_metadata"

insp = inspect(engine)
table_names = insp.get_table_names()

# Load the enhanced table metadata (summary + column descriptions) for each table.
table_infos = []
for fname in os.listdir(tableinfo_dir):
    with open(os.path.join(tableinfo_dir, fname), "r") as f:
        table_infos.append(json.load(f))

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t["table_name"], context_str=t["table_summary"] + str(t["columns"]))
    for t in table_infos
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
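# Illustrative sketch (an assumption, never called in this script): inspect which
# table schemas the object retriever would hand to the text-to-SQL step for a
# given question. The sample question passed by a caller would be arbitrary.
def _preview_table_retrieval(question: str, top_k: int = 4) -> None:
    for schema in obj_index.as_retriever(similarity_top_k=top_k).retrieve(question):
        print(schema.table_name, "->", schema.context_str[:80])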
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results. "
    "Please stick to the Indian number system and Indian currency. "
    "Round the sales value to 2 decimal places. If there are more than 2 columns, return the answer in a tabular format.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

sql_query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    text_to_sql_prompt=text2sql_prompt,
    table_retriever=obj_index.as_retriever(similarity_top_k=4),
    llm=llm,
    response_synthesis_prompt=response_synthesis_prompt,
)


def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict:
    """Agent input function: track the conversation history and refresh the
    text-to-SQL prompt with the failed attempts so far."""
    # initialize the conversation history on the first turn
    if "convo_history" not in state:
        state["convo_history"] = []
        state["count"] = 0
    state["convo_history"].append(f"User: {task.input}")
    convo_history_str = "\n".join(state["convo_history"]) or "None"

    text2sql_prompt = qa_prompt_tmpl.partial_format(
        dialect=engine.dialect.name,
        few_shot_examples=few_shot_examples,
        convo_history=convo_history_str,
    )
    sql_query_engine.sql_retriever._update_prompts({"text_to_sql_prompt": text2sql_prompt})
    print(sql_query_engine.sql_retriever.get_prompts())
    return {"input": task.input, "convo_history": convo_history_str}


agent_input_component = AgentInputComponent(fn=agent_input_fn)

## This prompt generates a transformed natural language query for the downstream text-to-SQL step.
retry_prompt_str = """\
You are trying to generate a proper natural language query given a user input.

This query will then be interpreted by a downstream text-to-SQL agent which will convert the query to a SQL statement. If the agent triggers an error, then that will be reflected in the current conversation history (see below).

If the conversation history is None, use the user input. If it's not None, generate a new natural language query that avoids the problems of the previous SQL query.

Input: {input}
Convo history (failed attempts):
{convo_history}

New input: """
retry_prompt = PromptTemplate(retry_prompt_str)

## This prompt validates whether the inferred SQL query, and the response from executing it, actually answer the user's query.
validate_prompt_str = """\
Given the user query, validate whether the inferred SQL query and the response from executing the query are correct and answer the query.
In the SQL Response, round any decimal values to 2 decimal places.
Answer with YES or NO.

Query: {input}
Inferred SQL query: {sql_query}
SQL Response: {sql_response}

Result: """
validate_prompt = PromptTemplate(validate_prompt_str)
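# Illustrative sketch (an assumption, never called here): run the validation prompt
# on its own for a single (question, SQL, response) triple. It mirrors the mini
# chain built inside agent_output_fn below; the caller supplies the values.
def _validate_once(question: str, sql_query: str, sql_response: str) -> bool:
    component = validate_prompt.as_query_component(
        partial={"sql_query": sql_query, "sql_response": sql_response}
    )
    out = QP(chain=[component, llm]).run(input=question)
    return "YES" in out.message.content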
MAX_ITER = 3
extracted_sql_query = None


def agent_output_fn(
    task: Task, state: Dict[str, Any], output: Response
) -> Tuple[AgentChatResponse, bool]:
    """Agent output component: log the attempt, validate the SQL answer, and decide whether to stop."""
    global extracted_sql_query
    print(f"> Inferred SQL Query:\n{output.metadata['sql_query']}")
    print(f"> SQL Response:\n{str(output)}")
    extracted_sql_query = output.metadata["sql_query"]
    state["convo_history"].append(
        f"Assistant (inferred SQL query): {output.metadata['sql_query']}"
    )
    state["convo_history"].append(f"Assistant (response): {str(output)}")

    # run a mini chain to validate the response
    validate_prompt_partial = validate_prompt.as_query_component(
        partial={
            "sql_query": output.metadata["sql_query"],
            "sql_response": str(output),
        }
    )
    qp = QP(chain=[validate_prompt_partial, llm])
    validate_output = qp.run(input=task.input)

    state["count"] += 1
    is_done = False
    if state["count"] >= MAX_ITER:
        is_done = True
    if "YES" in validate_output.message.content:
        is_done = True

    return AgentChatResponse(response=str(output)), is_done


agent_output_component = AgentFnComponent(fn=agent_output_fn)

qp = QP(
    modules={
        "input": agent_input_component,
        "retry_prompt": retry_prompt,
        "llm": llm,
        "sql_query_engine": sql_query_engine,
        "output_component": agent_output_component,
    },
    verbose=True,
)
qp.add_link("input", "retry_prompt", src_key="input", dest_key="input")
qp.add_link(
    "input", "retry_prompt", src_key="convo_history", dest_key="convo_history"
)
qp.add_chain(["retry_prompt", "llm", "sql_query_engine", "output_component"])

# from pyvis.network import Network
# net = Network(notebook=True, cdn_resources="in_line", directed=True)
# net.from_nx(qp.clean_dag)
# net.show("retry_agent_dag.html")

agent_worker = QueryPipelineAgentWorker(qp)
agent = AgentRunner(
    agent_worker=agent_worker, callback_manager=CallbackManager([]), verbose=True
)


def slow_echo(query, history):
    """Gradio chat handler: run the agent, then stream the answer and the generated SQL."""
    response = agent.chat(query)
    message = str(response).split(":")[-1].strip()
    res = (
        f"**Answer:**\n\n{message}\n\n\n\n"
        f"**LLM Generated SQL Query:**\n\n{extracted_sql_query}"
    )
    # stream one character at a time for a typing effect
    for i in range(len(res)):
        time.sleep(0.01)
        yield res[: i + 1]


demo = gr.ChatInterface(slow_echo).queue()

if __name__ == "__main__":
    demo.launch(share=True)
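# Example usage (a sketch, kept commented out so importing this module stays
# side-effect free): query the agent directly from a Python shell instead of
# through the Gradio UI. The question below is only illustrative.
#
#   response = agent.chat("Which weave type was sold the most?")
#   print(str(response))
#   print(extracted_sql_query)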