Spaces:
Sleeping
Sleeping
# from llama_index.core.query_engine import NLSQLTableQueryEngine | |
# from llama_index.core.indices.struct_store.sql_query import ( | |
# SQLTableRetrieverQueryEngine, | |
# NLSQLTableQueryEngine | |
# ) | |
from llama_index.core import set_global_handler | |
from dotenv import load_dotenv | |
load_dotenv() | |
set_global_handler("simple") | |
from llama_index.core.query_engine import RetryQueryEngine | |
from llama_index.core import SQLDatabase, VectorStoreIndex, Settings, PromptTemplate, Response | |
from llama_index.core.agent import QueryPipelineAgentWorker, AgentRunner, Task, AgentChatResponse | |
from llama_index.llms.gemini import Gemini | |
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema | |
from llama_index.core.callbacks import CallbackManager | |
from llama_index.core.agent.types import Task | |
from llama_index.embeddings.gemini import GeminiEmbedding | |
from llama_index.core.query_pipeline import AgentInputComponent, AgentFnComponent, QueryPipeline as QP | |
from custom_modules.custom_query_engine import SQLTableRetrieverQueryEngine | |
from llama_index.core.prompts.prompt_type import PromptType | |
import json, os, time, gradio as gr | |
from typing import Dict, Any, Tuple | |
from sqlalchemy import create_engine, inspect | |
# Global LlamaIndex model configuration: Gemini for generation and embeddings.
Settings.llm = Gemini(temperature=0.2)
Settings.embed_model = GeminiEmbedding()
llm = Settings.llm  # module-level alias used by the pipelines below

# Placeholder pipeline; the name ``qp`` is rebound to the retry pipeline later in this module.
qp = QP(verbose=True)

# SQLite database queried by the text-to-SQL engine (relative path, resolved
# against the current working directory).
engine = create_engine(
    "sqlite:///Llama_index_structured_sql/database/param4table.db"
)
sql_database = SQLDatabase(engine)
# Text-to-SQL prompt template.
# {dialect}, {few_shot_examples} and {convo_history} are bound with
# partial_format; {schema} and {query_str} are presumably filled in by the SQL
# retriever at query time.
# Every sentence ends with an explicit space/newline because adjacent string
# literals are concatenated with no separator.
qa_prompt_tmpl_str = (
    "Given an input question, first create a syntactically correct {dialect} "
    "query to run, then look at the results of the query and return the answer. "
    "You can order the results by a relevant column to return the most "
    "interesting examples in the database.\n\n"
    "Never query for all the columns from a specific table, only ask for a "
    "few relevant columns given the question.\n\n"
    "Pay attention to use only the column names that you can see in the schema "
    "description. "
    "Be careful to not query for columns that do not exist. "
    "Pay attention to which column is in which table. "
    "Also, qualify column names with the table name when needed. "
    "If the sales and inwards tables need to be connected, connect them by "
    "taking the distinct str_id values from inwards and then joining them with "
    "the sales table. "
    "Only use the context information and not prior knowledge.\n\n"
    # The examples are SQL answers; the previous text here asked for a JSON list
    # of "authors as the citations", contradicting the SQLQuery output format.
    "Some example questions ('query') and their SQL queries ('response') are "
    "given below.\n"
    "{few_shot_examples}\n"
    "\nIf there is a convo history given, you have to check the Inferred SQL query and then return a modified query.\n\n"
    "Convo history (failed attempts):\n{convo_history}\n"
    "You are required to use the following format, each taking one line:\n\n"
    "Question: Question here\n"
    "SQLQuery: SQL Query to run\n\n"
    "Only use tables listed below.\n"
    "{schema}\n\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)
# Wrap the raw template string, tagged as a text-to-SQL prompt.
qa_prompt_tmpl = PromptTemplate(
    template=qa_prompt_tmpl_str, prompt_type=PromptType.TEXT_TO_SQL
)
# Few-shot (question -> SQL) examples injected into the text-to-SQL prompt.
# Every active "response" must be valid SQLite for the tables sales, styles,
# inwards and fabrics; commented entries are alternates kept for reference.
few_shot_examples = [
    # {"query":"How many states are present?","response":"SELECT COUNT(DISTINCT lctn_stat_nm) AS num_states FROM inwards;"},
    # {"query":"Which colour is the most popular in terms of net sale value?","response":"SELECT f.fabric_color, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY f.fabric_color ORDER BY total_net_sale_val DESC;"},
    # {"query":"How many departments?","response":"SELECT COUNT(DISTINCT department) from styles;"}
    {"query":"States with above average sales.","response":"WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards), location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm) SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls WHERE ls.tot > (SELECT AVG(tot) FROM location_sales) ORDER BY ls.tot DESC;"},
    # {"query":"Which stores had the highest sales per quantity?","response":"SELECT s.customer, SUM(s.net_sale_val) as tot_sales, SUM(s.sales_qty) AS tot_sales_qty, SUM(s.net_sale_val) / SUM(s.sales_qty) AS sales_value_per_quantity FROM sales AS s GROUP BY s.customer ORDER BY sales_value_per_quantity DESC;"},
    # {"query":"Which colored fabric was sold the most and least?","response":"SELECT fabric_color, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY fabric_color ORDER BY total_sales_qty DESC;"},
    # {"query":"Which weave type was sold the most and the least?","response":"SELECT weave_type, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY weave_type ORDER BY total_sales_qty DESC;"},
    # {"query":"What color and weave type together worked the best?","response":"SELECT f.fabric_color, f.weave_type, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics AS f JOIN styles AS st ON f.fabric_code = st.fabric_id JOIN sales AS s ON st.product_id = s.material GROUP BY f.fabric_color, f.weave_type ORDER BY total_net_sale_val DESC;"},
    # {"query":"What colors and departments together did best?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) AS sales_value, fb.fabric_color AS clr, fb.department AS department FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.department ORDER BY sales_value DESC), most AS (SELECT clr, department, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT clr, department, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    # Fixed: the stray comma before the final SELECT made this invalid SQL, and
    # grouping by '%m' alone merged the same month of different years and broke
    # the LAG ordering across a year boundary, so '%Y-%m' is used instead.
    {"query":"How has the month on month sales changed across departments in the last year?","response":"WITH monthly_sales AS (SELECT st.department, STRFTIME('%Y-%m', s.cal_day) AS month, SUM(s.net_sale_val) AS monthly_sales FROM sales s JOIN styles st ON s.material = st.product_id WHERE s.cal_day >= DATE(CURRENT_DATE, '-1 year') GROUP BY st.department, STRFTIME('%Y-%m', s.cal_day)) SELECT department, month, monthly_sales, monthly_sales - LAG(monthly_sales, 1) OVER (PARTITION BY department ORDER BY month) AS month_on_month_change FROM monthly_sales ORDER BY department, month;"},
    # {"query":"What was the average sell through rate for fabric across various departments?","response":""},
    # {"query":"Which color that the highest sell through in each department in the last 3 months?","response":""},
    # {"query":"In shirts with a certain collar type (this is style table attribute) what is the average cotton percentage?","response":""},
    # {"query":"Which collar type had the most sales, most sell through?","response":""},
    {"query":"Which weave type was sold the most?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.weave_type ORDER BY sales_value DESC), most AS (SELECT weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT weave_type, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    {"query":"What color and weave type together worked the best?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.fabric_color AS clr, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.weave_type ORDER BY sales_value DESC) SELECT clr, weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables;"},
    # Fixed: tot_inwards_qty was selected un-aggregated under GROUP BY fab_id,
    # brand (SQLite then picks an arbitrary row); it is now summed per group.
    {"query":"List the combination of fabrics and brands with their total sales and total inwards quantity.", "response":"WITH sks AS (SELECT st.fabric_id AS fab_id, st.product_id AS prod_id, SUM(sa.net_sale_val) AS tot_sales FROM styles st JOIN sales sa ON st.product_id = sa.material GROUP BY st.product_id), sti AS (SELECT s.product_id AS prod_id, SUM(i.inwards_qty) AS tot_inwards_qty FROM styles s JOIN inwards i ON s.product_id = i.style_id GROUP BY s.product_id) SELECT sks.fab_id, f.brand, SUM(sks.tot_sales) AS tot_s, SUM(sti.tot_inwards_qty) AS tot_inwards_qty FROM sks JOIN sti ON sks.prod_id = sti.prod_id JOIN fabrics f ON f.fabric_code = sks.fab_id GROUP BY sks.fab_id, f.brand;"}
]
# Initial text-to-SQL prompt: bind the engine's SQL dialect and the few-shot
# examples; no conversation history exists yet (None). agent_input_fn installs
# a freshly bound prompt on the retriever on every pipeline pass.
text2sql_prompt = qa_prompt_tmpl.partial_format(
    convo_history=None,
    few_shot_examples=few_shot_examples,
    dialect=engine.dialect.name,
)
# print(sql_query_engine_nl) | |
# Directory holding one JSON metadata document per table; each document is
# read for the keys "table_name", "table_summary" and "columns".
tableinfo_dir = 'Llama_index_structured_sql/param_enhanced_metadata'
# NOTE(review): insp/table_names are not used below; kept so the module-level
# names stay available.
insp = inspect(engine)
table_names = insp.get_table_names()

# Load the metadata documents:
# - sorted so the index is built in a deterministic order;
# - opened read-only (the former 'r+' mode needlessly required write access);
# - skipping hidden entries and anything that is not a regular file (e.g. a
#   .ipynb_checkpoints directory), which would otherwise crash open()/json.load().
table_infos = []
for info_name in sorted(os.listdir(tableinfo_dir)):
    info_path = os.path.join(tableinfo_dir, info_name)
    if info_name.startswith(".") or not os.path.isfile(info_path):
        continue
    with open(info_path, "r", encoding="utf-8") as info_file:
        table_infos.append(json.load(info_file))

# Reuse the SQLDatabase built at startup instead of reflecting the DB a second time.
table_node_mapping = SQLTableNodeMapping(sql_database)
# One SQLTableSchema per table; the context string is what gets embedded and
# shown to the LLM, so the summary and the column list are separated by a newline.
table_schema_objs = [
    SQLTableSchema(
        table_name=t["table_name"],
        context_str=f'{t["table_summary"]}\n{t["columns"]}',
    )
    for t in table_infos
]
# Vector index over the table schemas, used to retrieve the relevant tables per question.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
# Prompt used to turn the SQL result into the final natural-language answer.
# Adjacent literals are concatenated with no separator, so each sentence ends
# with an explicit space (previously "results.Please" / "currency.Round").
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results. "
    "Please stick to the Indian number system and Indian currency. "
    "Round the sales value to 2 decimal places. If there are more than 2 columns, return them in a tabular format.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(template=response_synthesis_prompt_str)

# Text-to-SQL engine (project-local custom_modules class). Candidate table
# schemas come from a top-4 similarity retriever over obj_index.
sql_query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    llm=llm,
    table_retriever=obj_index.as_retriever(similarity_top_k=4),
    text_to_sql_prompt=text2sql_prompt,
    response_synthesis_prompt=response_synthesis_prompt,
)
def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict:
    """Build the inputs for one pass of the retry pipeline.

    On the first call for a task, ``state`` gets an empty ``convo_history`` list
    and a ``count`` of 0. Every call appends the user's question to the
    history, installs a text-to-SQL prompt bound to the current history on the
    SQL engine's retriever, and returns the question plus the joined history.

    Args:
        task: the agent task; only ``task.input`` (the user's question) is read.
        state: mutable per-task state shared with ``agent_output_fn``.

    Returns:
        ``{"input": <question>, "convo_history": <history text>}`` for ``retry_prompt``.
    """
    if "convo_history" not in state:
        state["convo_history"], state["count"] = [], 0

    history = state["convo_history"]
    history.append(f"User: {task.input}")
    history_text = "\n".join(history) or "None"

    # Re-bind the history into the SQL prompt so the text-to-SQL LLM sees
    # previous (failed) attempts.
    bound_prompt = qa_prompt_tmpl.partial_format(
        dialect=engine.dialect.name,
        few_shot_examples=few_shot_examples,
        convo_history=history_text,
    )
    sql_query_engine.sql_retriever._update_prompts({"text_to_sql_prompt": bound_prompt})
    print(sql_query_engine.sql_retriever.get_prompts())

    return {"input": task.input, "convo_history": history_text}
agent_input_component = AgentInputComponent(fn=agent_input_fn)

# Rewrites the user's question into a (possibly corrected) natural-language
# query for the downstream text-to-SQL engine, using the history of failed
# attempts. The last instruction previously asked for "a new SQL query", which
# contradicted the prompt's own goal: its output is fed to the text-to-SQL
# engine as the question, so it must stay natural language.
retry_prompt_str = """\
You are trying to generate a proper natural language query given a user input.
This query will then be interpreted by a downstream text-to-SQL agent which
will convert the query to a SQL statement. If the agent triggers an error,
then that will be reflected in the current conversation history (see below).
If the conversation history is None, use the user input. If it's not None,
generate a new natural language query that avoids the problems of the previous SQL query.
Input: {input}
Convo history (failed attempts):
{convo_history}
New input: """
retry_prompt = PromptTemplate(retry_prompt_str)
# Asks the LLM to judge whether the inferred SQL query and its result answer
# the user's question; agent_output_fn looks for a YES in the reply to stop retrying.
validate_prompt_str = """\
Given the user query, validate whether the inferred SQL query and response from executing the query is correct and answers the query. In SQL Response, round any decimal values to 2 decimal places.
Answer with YES or NO.
Query: {input}
Inferred SQL query: {sql_query}
SQL Response: {sql_response}
Result: """
validate_prompt = PromptTemplate(template=validate_prompt_str)

# Maximum number of pipeline passes (SQL attempts) per question.
MAX_ITER = 3
# Most recent SQL query produced by the agent: written by agent_output_fn,
# read by the Gradio handler.
extracted_sql_query = None
def agent_output_fn(
    task: Task, state: Dict[str, Any], output: Response
) -> Tuple[AgentChatResponse, bool]:
    """Record one SQL attempt, have the LLM validate it, and decide whether to stop.

    Args:
        task: the agent task; ``task.input`` is the user's question.
        state: per-task state initialised by ``agent_input_fn``; its
            ``convo_history`` list is extended and ``count`` is incremented.
        output: the ``sql_query_engine`` response; ``output.metadata["sql_query"]``
            holds the inferred SQL and ``str(output)`` the answer text.

    Returns:
        ``(AgentChatResponse(response=<answer text>), is_done)`` where ``is_done``
        is True once the validator answers YES or ``MAX_ITER`` attempts were made.

    Side effects:
        Prints the SQL and answer, and stores the SQL in the module-level
        ``extracted_sql_query`` for the UI.
    """
    import re

    global extracted_sql_query
    sql_query = output.metadata["sql_query"]
    sql_response = str(output)
    print(f"> Inferred SQL Query:\n{sql_query}")
    print(f"> SQL Response:\n{sql_response}")
    extracted_sql_query = sql_query
    state["convo_history"].append(f"Assistant (inferred SQL query): {sql_query}")
    state["convo_history"].append(f"Assistant (response): {sql_response}")

    # Mini pipeline: validation prompt (SQL and result pre-bound) -> LLM.
    validate_prompt_partial = validate_prompt.as_query_component(
        partial={"sql_query": sql_query, "sql_response": sql_response}
    )
    # Named distinctly so it is not confused with the module-level agent pipeline ``qp``.
    validation_pipeline = QP(chain=[validate_prompt_partial, llm])
    validate_output = validation_pipeline.run(input=task.input)

    state["count"] += 1
    verdict = validate_output.message.content or ""
    # Whole-word, case-insensitive match: accepts "Yes"/"**YES**" but not
    # words that merely contain the letters, such as "yesterday".
    approved = re.search(r"\bYES\b", verdict, flags=re.IGNORECASE) is not None
    is_done = state["count"] >= MAX_ITER or approved
    return AgentChatResponse(response=sql_response), is_done
agent_output_component = AgentFnComponent(fn=agent_output_fn)

# Retry pipeline: input -> retry_prompt -> llm -> sql_query_engine -> output_component.
qp = QP(
    modules=dict(
        input=agent_input_component,
        retry_prompt=retry_prompt,
        llm=llm,
        sql_query_engine=sql_query_engine,
        output_component=agent_output_component,
    ),
    verbose=True,
)
# The input component feeds both template variables of the retry prompt.
for _link_key in ("input", "convo_history"):
    qp.add_link("input", "retry_prompt", src_key=_link_key, dest_key=_link_key)
qp.add_chain(["retry_prompt", "llm", "sql_query_engine", "output_component"])
# from pyvis.network import Network | |
# net = Network(notebook=True, cdn_resources="in_line", directed=True) | |
# net.from_nx(qp.clean_dag) | |
# net.show("retry_agent_dag.html") | |
# Expose the retry pipeline as a chat agent (slow_echo calls agent.chat).
agent_worker = QueryPipelineAgentWorker(qp)
agent = AgentRunner(
    agent_worker=agent_worker,
    verbose=True,
    callback_manager=CallbackManager([]),
)
def slow_echo(query, history):
    """Answer ``query`` with the SQL agent and stream the reply one character at a time.

    Gradio ``ChatInterface`` callback. ``history`` is part of the callback
    signature but is not used; each question goes to ``agent.chat`` on its own.

    Yields:
        Progressively longer prefixes of a Markdown message containing the
        answer and the SQL query last recorded by ``agent_output_fn``.
    """
    response = agent.chat(query)
    message = str(response).strip()
    # The synthesis prompt ends with "Response: ", so the model may echo a
    # leading label; strip only that label. Splitting on every ':' (as before)
    # truncated any answer that itself contains a colon, e.g. "Maharashtra: ₹1,23,456".
    for label in ("Response:", "Answer:"):
        if message.lower().startswith(label.lower()):
            message = message[len(label):].strip()
            break
    # NOTE(review): extracted_sql_query is a module-level global set during the
    # agent run, so concurrent chat sessions could see each other's query.
    res = f'**Answer:**\n\n{message}\n\n\n\n**LLM Generated SQL Query:**\n\n{extracted_sql_query}'
    for end in range(1, len(res) + 1):
        time.sleep(0.01)
        yield res[:end]
# Gradio chat UI backed by the streaming slow_echo handler, with request queuing enabled.
demo = gr.ChatInterface(slow_echo)
demo = demo.queue()

if __name__ == "__main__":
    # share=True requests a public Gradio share link.
    demo.launch(share=True)