# from llama_index.core.query_engine import NLSQLTableQueryEngine
# from llama_index.core.indices.struct_store.sql_query import (
#     SQLTableRetrieverQueryEngine,
#     NLSQLTableQueryEngine
# )
from llama_index.core import set_global_handler
from dotenv import load_dotenv
load_dotenv()

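# The "simple" global handler prints LLM prompts and completions to stdout,
# which is useful for inspecting the generated SQL during development.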
set_global_handler("simple")

from llama_index.core import SQLDatabase, VectorStoreIndex, Settings, PromptTemplate, Response
from llama_index.core.agent import QueryPipelineAgentWorker, AgentRunner, Task, AgentChatResponse
from llama_index.llms.gemini import Gemini
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.callbacks import CallbackManager
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.core.query_pipeline import AgentInputComponent, AgentFnComponent, QueryPipeline as QP
from custom_modules.custom_query_engine import SQLTableRetrieverQueryEngine
from llama_index.core.prompts.prompt_type import PromptType

import json, os, time, gradio as gr
from typing import Dict, Any, Tuple
from sqlalchemy import create_engine, inspect

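# Use Gemini for both the LLM and the embedding model; a low temperature keeps
# SQL generation more deterministic.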
Settings.llm = Gemini(temperature=0.2)
Settings.embed_model = GeminiEmbedding()
llm = Settings.llm

engine = create_engine("sqlite:///Llama_index_structured_sql/database/param4table.db")
sql_database = SQLDatabase(engine)

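# Text-to-SQL prompt template. Placeholders: {dialect}, {few_shot_examples},
# {convo_history} (failed attempts from earlier retries), {schema} (retrieved
# table descriptions) and {query_str} (the user question).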
qa_prompt_tmpl_str = (
"Given an input question, first create a syntactically correct {dialect} "
"query to run, then look at the results of the query and return the answer. "
"You can order the results by a relevant column to return the most "
"interesting examples in the database.\n\n"
"Never query for all the columns from a specific table, only ask for a "
"few relevant columns given the question.\n\n"
"Pay attention to use only the column names that you can see in the schema "
"description. "
"Be careful to not query for columns that do not exist. "
"Pay attention to which column is in which table. "
"Also, qualify column names with the table name when needed. "

"If sales and inwards table needs to be connected, connect them\
taking distinct str_id from inwards and then join this with the sales table."
# "WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards),:\
# location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm)\
# SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls\
# ORDER BY ls.tot DESC;"

"Only use the context information and not prior knowledge."
# answer the query asking about citations over different topics."
"Please provide your answer in the form of a structured JSON format containing \
a list of authors as the citations. Some examples are given below."

"{few_shot_examples}"

"\nIf there is a convo history given, you have to check the Inferred SQL query and then return a modified query.\n\n"
"Convo history (failed attempts):\n{convo_history}\n"

"You are required to use the following format, each taking one line:\n\n"
"Question: Question here\n"
"SQLQuery: SQL Query to run\n\n"
# "SQLResult: Result of the SQLQuery\n"
# "Answer: Final answer here\n\n"
"Only use tables listed below.\n"
"{schema}\n\n"
"Question: {query_str}\n"
"SQLQuery: "
)

qa_prompt_tmpl = PromptTemplate(
    qa_prompt_tmpl_str,
    prompt_type=PromptType.TEXT_TO_SQL,
)

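# Few-shot question -> SQL examples for the sales/inwards/styles/fabrics schema.
# Only the uncommented entries are injected into the prompt; the commented ones
# are kept for reference.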
few_shot_examples = [
    # {"query":"How many states are present?","response":"SELECT COUNT(DISTINCT lctn_stat_nm) AS num_states FROM inwards;"},
    # {"query":"Which colour is the most popular in terms of net sale value?","response":"SELECT f.fabric_color, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY f.fabric_color ORDER BY total_net_sale_val DESC;"},
    # {"query":"How many departments?","response":"SELECT COUNT(DISTINCT department) from styles;"}
    {"query":"States with above average sales.","response":"WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards), location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm) SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls WHERE ls.tot > (SELECT AVG(tot) FROM location_sales) ORDER BY ls.tot DESC;"},
    # {"query":"Which stores had the highest sales per quantity?","response":"SELECT s.customer, SUM(s.net_sale_val) as tot_sales, SUM(s.sales_qty) AS tot_sales_qty, SUM(s.net_sale_val) / SUM(s.sales_qty) AS sales_value_per_quantity FROM sales AS s GROUP BY s.customer ORDER BY sales_value_per_quantity DESC;"},
    # {"query":"Which colored fabric was sold the most and least?","response":"SELECT fabric_color, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY fabric_color ORDER BY total_sales_qty DESC;"},
    # {"query":"Which weave type was sold the most and the least?","response":"SELECT weave_type, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY weave_type ORDER BY total_sales_qty DESC;"},
    # {"query":"What color and weave type together worked the best?","response":"SELECT f.fabric_color, f.weave_type, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics AS f JOIN styles AS st ON f.fabric_code = st.fabric_id JOIN sales AS s ON st.product_id = s.material GROUP BY f.fabric_color, f.weave_type ORDER BY total_net_sale_val DESC;"}, 
    # {"query":"What colors and departments together did best?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) AS sales_value, fb.fabric_color AS clr, fb.department AS department FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.department ORDER BY sales_value DESC), most AS (SELECT clr, department, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT clr, department, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    {"query":"How has the month on month sales changed across departments in the last year?","response":"WITH monthly_sales AS (SELECT st.department, STRFTIME('%m', s.cal_day) AS month, SUM(s.net_sale_val) AS monthly_sales FROM sales s JOIN styles st ON s.material = st.product_id WHERE s.cal_day >= DATE(CURRENT_DATE, '-1 year') GROUP BY department, STRFTIME('%m', s.cal_day)), SELECT department, month, monthly_sales, monthly_sales - LAG(monthly_sales, 1) OVER (PARTITION BY department ORDER BY month) AS month_on_month_change FROM monthly_sales ORDER BY department, month;"},
    # {"query":"What was the average sell through rate for fabric across various departments?","response":""},
    # {"query":"Which color that the highest sell through in each department in the last 3 months?","response":""},
    # {"query":"In shirts with a certain collar type (this is style table attribute) what is the average cotton percentage?","response":""},
    # {"query":"Which collar type had the most sales, most sell through?","response":""},

    {"query":"Which weave type was sold the most?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.weave_type ORDER BY sales_value DESC), most AS (SELECT weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT weave_type, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    {"query":"What color and weave type together worked the best?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.fabric_color AS clr, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.weave_type ORDER BY sales_value DESC) SELECT clr, weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables;"},
    {"query":"List the combination of fabrics and brands with their total sales and total inwards quantity.", "response":"WITH sks AS (SELECT st.fabric_id AS fab_id, st.product_id AS prod_id, SUM(sa.net_sale_val) AS tot_sales FROM styles st JOIN sales sa ON st.product_id = sa.material GROUP BY st.product_id), sti AS (SELECT s.product_id AS prod_id, SUM(i.inwards_qty) AS tot_inwards_qty FROM styles s JOIN inwards i ON s.product_id = i.style_id GROUP BY s.product_id) SELECT sks.fab_id, f.brand, SUM(sks.tot_sales) AS tot_s, sti.tot_inwards_qty FROM sks JOIN sti ON sks.prod_id = sti.prod_id JOIN fabrics f ON f.fabric_code = sks.fab_id GROUP BY sks.fab_id, f.brand;"}
                     ]

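# Pre-fill the static parts of the prompt (dialect and few-shot examples).
# The conversation history is refreshed on every retry inside agent_input_fn.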
text2sql_prompt = qa_prompt_tmpl.partial_format(
    dialect=engine.dialect.name,
    few_shot_examples=few_shot_examples,
    convo_history=None
)

# print(sql_query_engine_nl)

tableinfo_dir = 'Llama_index_structured_sql/param_enhanced_metadata'

insp = inspect(engine)
table_names = insp.get_table_names()

table_infos = []

# Load the pre-generated table metadata files; each JSON is expected to contain
# "table_name", "table_summary" and "columns" keys (used when building the
# SQLTableSchema objects below).
for fname in os.listdir(tableinfo_dir):
    with open(os.path.join(tableinfo_dir, fname), 'r') as f:
        table_infos.append(json.load(f))

sql_database = SQLDatabase(engine)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t["table_name"], context_str=t["table_summary"]+str(t["columns"]))
    for t in table_infos
]  # add a SQLTableSchema for each table

# print(table_schema_objs)


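# Index the table schemas so that only the most relevant tables are retrieved
# for each question (see similarity_top_k=4 on the retriever below).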
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)


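# Prompt used to turn the SQL result set into a natural-language answer.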
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results."
    "Please stick to the Indian number system and Indian currency."
    "Round the sales value to 2 decimal places. If there are more than 2 columns return in a tabular format.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)

response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

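# Query engine (custom SQLTableRetrieverQueryEngine) wired with the table
# retriever, the text-to-SQL prompt and the response-synthesis prompt: it
# retrieves candidate table schemas, generates SQL, executes it and synthesizes
# the final answer.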
sql_query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    text_to_sql_prompt=text2sql_prompt,
    table_retriever=obj_index.as_retriever(similarity_top_k=4),
    llm=llm,
    response_synthesis_prompt=response_synthesis_prompt
)

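# Pipeline entry point: records the user input in the per-task conversation
# history and re-renders the text-to-SQL prompt so failed attempts from earlier
# iterations are visible to the model.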
def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict:
    """Agent input function."""
    # Initialize the conversation history and retry counter on the first call.
    if "convo_history" not in state:
        state["convo_history"] = []
        state["count"] = 0
    state["convo_history"].append(f"User: {task.input}")
    convo_history_str = "\n".join(state["convo_history"]) or "None"
    # print(state)
    text2sql_prompt = qa_prompt_tmpl.partial_format(
        dialect=engine.dialect.name,
        few_shot_examples=few_shot_examples,
        convo_history=convo_history_str
    )
    sql_query_engine.sql_retriever._update_prompts({'text_to_sql_prompt':text2sql_prompt})
    print(sql_query_engine.sql_retriever.get_prompts())
    return {"input": task.input, "convo_history": convo_history_str}


agent_input_component = AgentInputComponent(fn=agent_input_fn)

## This prompt will generate a transformed natural language query for further downstream tasks.
retry_prompt_str = """\
You are trying to generate a proper natural language query given a user input.

This query will then be interpreted by a downstream text-to-SQL agent which
will convert the query to a SQL statement. If the agent triggers an error,
then that will be reflected in the current conversation history (see below).

If the conversation history is None, use the user input. If it's not None,
generate a new natural language query that avoids the problems of the previous SQL query.

Input: {input}
Convo history (failed attempts): 
{convo_history}

New input: """

retry_prompt = PromptTemplate(retry_prompt_str)

## This prompt validates whether the inferred SQL query and the response from executing it are correct and actually answer the user query.
validate_prompt_str = """\
Given the user query, validate whether the inferred SQL query and the response from executing the query are correct and answer the query. In the SQL Response, round any decimal values to 2 decimal places.

Answer with YES or NO.

Query: {input}
Inferred SQL query: {sql_query}
SQL Response: {sql_response}

Result: """

validate_prompt = PromptTemplate(validate_prompt_str)

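# Maximum number of SQL generation attempts before the agent returns its last answer.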
MAX_ITER = 3

extracted_sql_query = None 

def agent_output_fn(
    task: Task, state: Dict[str, Any], output: Response
) -> Tuple[AgentChatResponse, bool]:
    """Agent output component."""

    global extracted_sql_query
    print(f"> Inferred SQL Query:\n{output.metadata['sql_query']}")
    print(f"> SQL Response:\n{str(output)}")
    extracted_sql_query = output.metadata['sql_query']
    state["convo_history"].append(
        f"Assistant (inferred SQL query): {output.metadata['sql_query']}"
    )
    state["convo_history"].append(f"Assistant (response): {str(output)}")

    # run a mini chain to get response
    validate_prompt_partial = validate_prompt.as_query_component(
        partial={
            "sql_query": output.metadata["sql_query"],
            "sql_response": str(output),
        }
    )
    qp = QP(chain=[validate_prompt_partial, llm])
    validate_output = qp.run(input=task.input)
    # print(validate_output)

    state["count"] += 1
    is_done = False
    if state["count"] >= MAX_ITER:
        is_done = True
    if "YES" in validate_output.message.content:
        is_done = True

    return AgentChatResponse(response=str(output)), is_done


agent_output_component = AgentFnComponent(fn=agent_output_fn)

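# Retry pipeline: input -> retry_prompt -> llm -> sql_query_engine -> output_component.
# The output component validates the SQL response and decides whether to stop or retry.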
qp = QP(
    modules={
        "input": agent_input_component,
        "retry_prompt": retry_prompt,
        "llm": llm,
        "sql_query_engine": sql_query_engine,
        "output_component": agent_output_component,
    },
    verbose=True,
)
qp.add_link("input", "retry_prompt", src_key="input", dest_key="input")
qp.add_link(
    "input", "retry_prompt", src_key="convo_history", dest_key="convo_history"
)
qp.add_chain(["retry_prompt", "llm", "sql_query_engine", "output_component"])


# from pyvis.network import Network

# net = Network(notebook=True, cdn_resources="in_line", directed=True)
# net.from_nx(qp.clean_dag)
# net.show("retry_agent_dag.html")

agent_worker = QueryPipelineAgentWorker(qp)

agent = AgentRunner(
    agent_worker=agent_worker,
    callback_manager=CallbackManager([]), verbose=True
)
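
# Illustrative usage (not part of the app flow): the agent can also be queried
# directly, without the Gradio UI, e.g.
#   response = agent.chat("Which weave type was sold the most?")
#   print(response, extracted_sql_query)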

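# Gradio chat handler: run the agent, keep the text after the last ':' in the
# response, and stream the answer plus the generated SQL one character at a time.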
def slow_echo(query,history):
    # print(extracted_sql_query)
    response = agent.chat(query)
    message = str(response).split(':')[-1].strip()
    res = f'**Answer:**\n\n{message}\n\n\n\n**LLM Generated SQL Query:**\n\n{extracted_sql_query}'

    for i in range(len(res)):
        time.sleep(0.01)
        yield res[: i + 1]

demo = gr.ChatInterface(slow_echo).queue()

if __name__ == "__main__":
    demo.launch(share=True)