Spaces:
Sleeping
Sleeping
File size: 15,469 Bytes
0c0cc35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 |
# from llama_index.core.query_engine import NLSQLTableQueryEngine
# from llama_index.core.indices.struct_store.sql_query import (
# SQLTableRetrieverQueryEngine,
# NLSQLTableQueryEngine
# )
from llama_index.core import set_global_handler
from dotenv import load_dotenv
load_dotenv()
set_global_handler("simple")
from llama_index.core.query_engine import RetryQueryEngine
from llama_index.core import SQLDatabase, VectorStoreIndex, Settings, PromptTemplate, Response
from llama_index.core.agent import QueryPipelineAgentWorker, AgentRunner, Task, AgentChatResponse
from llama_index.llms.gemini import Gemini
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.callbacks import CallbackManager
from llama_index.core.agent.types import Task
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.core.query_pipeline import AgentInputComponent, AgentFnComponent, QueryPipeline as QP
from custom_modules.custom_query_engine import SQLTableRetrieverQueryEngine
from llama_index.core.prompts.prompt_type import PromptType
import json, os, time, gradio as gr
from typing import Dict, Any, Tuple
from sqlalchemy import create_engine, inspect
# Global LlamaIndex configuration: Gemini for both generation and embeddings.
Settings.llm = Gemini(temperature=0.2)  # low temperature for more deterministic SQL
Settings.embed_model = GeminiEmbedding()
llm = Settings.llm
# NOTE(review): this pipeline instance is rebuilt near the bottom of the module;
# this early one appears unused.
qp = QP(verbose=True)
# SQLite database holding the four retail tables (sales, styles, fabrics, inwards).
engine = create_engine("sqlite:///Llama_index_structured_sql/database/param4table.db")
sql_database = SQLDatabase(engine)
# Text-to-SQL prompt template. Placeholders: {dialect}, {few_shot_examples},
# {convo_history}, {schema}, {query_str}. The model is asked to reply with a
# single "SQLQuery:" line; {convo_history} carries earlier failed attempts so
# the retry loop can steer away from them.
qa_prompt_tmpl_str = (
    "Given an input question, first create a syntactically correct {dialect} "
    "query to run, then look at the results of the query and return the answer. "
    "You can order the results by a relevant column to return the most "
    "interesting examples in the database.\n\n"
    "Never query for all the columns from a specific table, only ask for a "
    "few relevant columns given the question.\n\n"
    "Pay attention to use only the column names that you can see in the schema "
    "description. "
    "Be careful to not query for columns that do not exist. "
    "Pay attention to which column is in which table. "
    "Also, qualify column names with the table name when needed. "
    # Domain hint: de-duplicate store rows before joining inwards to sales.
    "If sales and inwards table needs to be connected, connect them "
    "taking distinct str_id from inwards and then join this with the sales table."
    "Only use the context information and not prior knowledge."
    # Fix: the previous text here was leftover from a citations demo and asked
    # for "a structured JSON ... list of authors as the citations", which
    # contradicts the SQL-only output format required below.
    "Some example questions with their corresponding SQL queries are given below."
    "{few_shot_examples}"
    "\nIf there is a convo history given, you have to check the Inferred SQL query and then return a modified query.\n\n"
    "Convo history (failed attempts):\n{convo_history}\n"
    "You are required to use the following format, each taking one line:\n\n"
    "Question: Question here\n"
    "SQLQuery: SQL Query to run\n\n"
    "Only use tables listed below.\n"
    "{schema}\n\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)
# Wrap the raw string; PromptType.TEXT_TO_SQL marks this as the text-to-SQL
# prompt for downstream components.
qa_prompt_tmpl = PromptTemplate(
    qa_prompt_tmpl_str,
    prompt_type=PromptType.TEXT_TO_SQL,
)
# Few-shot (question, SQL) pairs injected into the prompt via
# {few_shot_examples}. Each dict has a "query" (natural language) and a
# "response" (the reference SQLite query). Commented entries are a bank of
# alternatives kept for experimentation.
few_shot_examples = [
    # {"query":"How many states are present?","response":"SELECT COUNT(DISTINCT lctn_stat_nm) AS num_states FROM inwards;"},
    # {"query":"Which colour is the most popular in terms of net sale value?","response":"SELECT f.fabric_color, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY f.fabric_color ORDER BY total_net_sale_val DESC;"},
    # {"query":"How many departments?","response":"SELECT COUNT(DISTINCT department) from styles;"}
    {"query":"States with above average sales.","response":"WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards), location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm) SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls WHERE ls.tot > (SELECT AVG(tot) FROM location_sales) ORDER BY ls.tot DESC;"},
    # {"query":"Which stores had the highest sales per quantity?","response":"SELECT s.customer, SUM(s.net_sale_val) as tot_sales, SUM(s.sales_qty) AS tot_sales_qty, SUM(s.net_sale_val) / SUM(s.sales_qty) AS sales_value_per_quantity FROM sales AS s GROUP BY s.customer ORDER BY sales_value_per_quantity DESC;"},
    # {"query":"Which colored fabric was sold the most and least?","response":"SELECT fabric_color, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY fabric_color ORDER BY total_sales_qty DESC;"},
    # {"query":"Which weave type was sold the most and the least?","response":"SELECT weave_type, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY weave_type ORDER BY total_sales_qty DESC;"},
    # {"query":"What color and weave type together worked the best?","response":"SELECT f.fabric_color, f.weave_type, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics AS f JOIN styles AS st ON f.fabric_code = st.fabric_id JOIN sales AS s ON st.product_id = s.material GROUP BY f.fabric_color, f.weave_type ORDER BY total_net_sale_val DESC;"},
    # {"query":"What colors and departments together did best?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) AS sales_value, fb.fabric_color AS clr, fb.department AS department FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.department ORDER BY sales_value DESC), most AS (SELECT clr, department, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT clr, department, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    # Fix: removed the stray comma between the CTE's closing paren and the main
    # SELECT ("... cal_day)), SELECT ..."), which made the example SQL invalid.
    {"query":"How has the month on month sales changed across departments in the last year?","response":"WITH monthly_sales AS (SELECT st.department, STRFTIME('%m', s.cal_day) AS month, SUM(s.net_sale_val) AS monthly_sales FROM sales s JOIN styles st ON s.material = st.product_id WHERE s.cal_day >= DATE(CURRENT_DATE, '-1 year') GROUP BY department, STRFTIME('%m', s.cal_day)) SELECT department, month, monthly_sales, monthly_sales - LAG(monthly_sales, 1) OVER (PARTITION BY department ORDER BY month) AS month_on_month_change FROM monthly_sales ORDER BY department, month;"},
    # {"query":"What was the average sell through rate for fabric across various departments?","response":""},
    # {"query":"Which color that the highest sell through in each department in the last 3 months?","response":""},
    # {"query":"In shirts with a certain collar type (this is style table attribute) what is the average cotton percentage?","response":""},
    # {"query":"Which collar type had the most sales, most sell through?","response":""},
    {"query":"Which weave type was sold the most?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.weave_type ORDER BY sales_value DESC), most AS (SELECT weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT weave_type, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    {"query":"What color and weave type together worked the best?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.fabric_color AS clr, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.weave_type ORDER BY sales_value DESC) SELECT clr, weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables;"},
    {"query":"List the combination of fabrics and brands with their total sales and total inwards quantity.", "response":"WITH sks AS (SELECT st.fabric_id AS fab_id, st.product_id AS prod_id, SUM(sa.net_sale_val) AS tot_sales FROM styles st JOIN sales sa ON st.product_id = sa.material GROUP BY st.product_id), sti AS (SELECT s.product_id AS prod_id, SUM(i.inwards_qty) AS tot_inwards_qty FROM styles s JOIN inwards i ON s.product_id = i.style_id GROUP BY s.product_id) SELECT sks.fab_id, f.brand, SUM(sks.tot_sales) AS tot_s, sti.tot_inwards_qty FROM sks JOIN sti ON sks.prod_id = sti.prod_id JOIN fabrics f ON f.fabric_code = sks.fab_id GROUP BY sks.fab_id, f.brand;"}
]
# Pre-bind the static template fields. {convo_history} is re-bound on every
# request inside agent_input_fn, so None here is only an initial placeholder.
text2sql_prompt = qa_prompt_tmpl.partial_format(
    dialect=engine.dialect.name,
    few_shot_examples=few_shot_examples,
    convo_history=None
)
# print(sql_query_engine_nl)
# Load per-table JSON metadata (table_name, table_summary, columns) produced
# offline, then build a retrievable object index over the table schemas so the
# query engine can pick the most relevant tables per question.
tableinfo_dir = 'Llama_index_structured_sql/param_enhanced_metadata'
insp = inspect(engine)
table_names = insp.get_table_names()
table_infos = []
# Fix: the original loop opened the files with 'r+' (read/write) for a pure
# read and shadowed the loop variable `file` with the file handle; it also
# carried an unused enumerate index.
for fname in os.listdir(tableinfo_dir):
    with open(os.path.join(tableinfo_dir, fname), 'r') as fp:
        table_infos.append(json.load(fp))
sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t["table_name"], context_str=t["table_summary"] + str(t["columns"]))
    for t in table_infos
]  # one SQLTableSchema per table
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
# Prompt used to turn the executed SQL result into the final natural-language
# answer shown to the user (Indian number format, 2-decimal rounding, tabular
# output when more than 2 columns come back).
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results."
    "Please stick to the Indian number system and Indian currency."
    "Round the sales value to 2 decimal places. If there are more than 2 columns return in a tabular format.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)
# Custom query engine (see custom_modules): retrieves the top-4 most relevant
# table schemas from obj_index, generates + executes SQL with text2sql_prompt,
# then synthesizes the answer with response_synthesis_prompt.
sql_query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    text_to_sql_prompt=text2sql_prompt,
    table_retriever=obj_index.as_retriever(similarity_top_k=4),
    llm=llm,
    response_synthesis_prompt=response_synthesis_prompt
)
def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict:
    """Seed pipeline state, record the user turn, and refresh the SQL prompt.

    On the first call the conversation history and retry counter are
    initialized; every call appends the user input to the history and re-binds
    {convo_history} in the text-to-SQL prompt so retries can see earlier
    failed attempts.
    """
    if "convo_history" not in state:
        state["convo_history"] = []
        state["count"] = 0
    state["convo_history"].append(f"User: {task.input}")
    history_text = "\n".join(state["convo_history"]) or "None"
    refreshed_prompt = qa_prompt_tmpl.partial_format(
        dialect=engine.dialect.name,
        few_shot_examples=few_shot_examples,
        convo_history=history_text,
    )
    # Push the refreshed prompt into the retriever in place (private API).
    sql_query_engine.sql_retriever._update_prompts({'text_to_sql_prompt': refreshed_prompt})
    print(sql_query_engine.sql_retriever.get_prompts())
    return {"input": task.input, "convo_history": history_text}
# Pipeline entry point: wraps agent_input_fn as a QueryPipeline component.
agent_input_component = AgentInputComponent(fn=agent_input_fn)
## This prompt will generate a transformed natural language query for further downstream tasks.
retry_prompt_str = """\
You are trying to generate a proper natural language query given a user input.
This query will then be interpreted by a downstream text-to-SQL agent which
will convert the query to a SQL statement. If the agent triggers an error,
then that will be reflected in the current conversation history (see below).
If the conversation history is None, use the user input. If its not None,
generate a new SQL query that avoids the problems of the previous SQL query.
Input: {input}
Convo history (failed attempts):
{convo_history}
New input: """
retry_prompt = PromptTemplate(retry_prompt_str)
## This prompt will validate whether the inferred SQL query and response from executing the query is correct and answers the query
validate_prompt_str = """\
Given the user query, validate whether the inferred SQL query and response from executing the query is correct and answers the query. In SQL Response, round any decimal values to 2 decimal places.
Answer with YES or NO.
Query: {input}
Inferred SQL query: {sql_query}
SQL Response: {sql_response}
Result: """
validate_prompt = PromptTemplate(validate_prompt_str)
# Maximum number of generate-validate attempts before the agent gives up.
MAX_ITER = 3
# Last SQL query inferred by the engine; written by agent_output_fn and read
# by slow_echo to display the query in the UI.
extracted_sql_query = None
def agent_output_fn(
    task: Task, state: Dict[str, Any], output: Response
) -> Tuple[AgentChatResponse, bool]:
    """Record the engine output, LLM-validate it, and decide completion.

    Appends the inferred SQL and its response to the conversation history,
    publishes the SQL via the module-global `extracted_sql_query`, then runs a
    mini validation chain asking the LLM for a YES/NO verdict. Returns the
    chat response plus a done-flag that is True once the validator says YES or
    MAX_ITER attempts have been consumed.
    """
    global extracted_sql_query
    inferred_sql = output.metadata['sql_query']
    print(f"> Inferred SQL Query:\n{inferred_sql}")
    print(f"> SQL Response:\n{str(output)}")
    extracted_sql_query = inferred_sql
    state["convo_history"].append(
        f"Assistant (inferred SQL query): {inferred_sql}"
    )
    state["convo_history"].append(f"Assistant (response): {str(output)}")
    # Mini chain: bind the SQL/response into the validate prompt, then ask the LLM.
    bound_validator = validate_prompt.as_query_component(
        partial={
            "sql_query": inferred_sql,
            "sql_response": str(output),
        }
    )
    validation_pipeline = QP(chain=[bound_validator, llm])
    verdict = validation_pipeline.run(input=task.input)
    state["count"] += 1
    is_done = (
        state["count"] >= MAX_ITER
        or "YES" in verdict.message.content
    )
    return AgentChatResponse(response=str(output)), is_done
agent_output_component = AgentFnComponent(fn=agent_output_fn)
# Retry pipeline: input -> retry_prompt (rewrites the NL query using the failed
# history) -> llm -> sql_query_engine -> output_component (validates and
# decides whether to stop).
qp = QP(
    modules={
        "input": agent_input_component,
        "retry_prompt": retry_prompt,
        "llm": llm,
        "sql_query_engine": sql_query_engine,
        "output_component": agent_output_component,
    },
    verbose=True,
)
# The input component fans out both of its outputs into the retry prompt.
qp.add_link("input", "retry_prompt", src_key="input", dest_key="input")
qp.add_link(
    "input", "retry_prompt", src_key="convo_history", dest_key="convo_history"
)
qp.add_chain(["retry_prompt", "llm", "sql_query_engine", "output_component"])
# from pyvis.network import Network
# net = Network(notebook=True, cdn_resources="in_line", directed=True)
# net.from_nx(qp.clean_dag)
# net.show("retry_agent_dag.html")
# Drive the pipeline as an agent so it loops until agent_output_fn reports done.
agent_worker = QueryPipelineAgentWorker(qp)
agent = AgentRunner(
    agent_worker=agent_worker,
    callback_manager=CallbackManager([]), verbose=True
)
def slow_echo(query,history):
# print(extracted_sql_query)
response = agent.chat(query)
message = str(response).split(':')[-1].strip()
res = f'**Answer:**\n\n{message}\n\n\n\n**LLM Generated SQL Query:**\n\n{extracted_sql_query}'
for i in range(len(res)):
time.sleep(0.01)
yield res[: i + 1]
# Build the chat UI; .queue() enables the request queue required for
# generator-based (streaming) handlers like slow_echo.
demo = gr.ChatInterface(slow_echo).queue()
if __name__ == "__main__":
    demo.launch(share=True)  # share=True exposes a temporary public gradio.live URL