Aayush4396 committed on
Commit
0c0cc35
1 Parent(s): 3198886

uploaded retry agent

Browse files
Files changed (1) hide show
  1. demo_gradio_retryAgent.py +305 -0
demo_gradio_retryAgent.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from llama_index.core.query_engine import NLSQLTableQueryEngine
2
+ # from llama_index.core.indices.struct_store.sql_query import (
3
+ # SQLTableRetrieverQueryEngine,
4
+ # NLSQLTableQueryEngine
5
+ # )
6
+ from llama_index.core import set_global_handler
7
+ from dotenv import load_dotenv
8
+ load_dotenv()
9
+
10
+ set_global_handler("simple")
11
+
12
+ from llama_index.core.query_engine import RetryQueryEngine
13
+ from llama_index.core import SQLDatabase, VectorStoreIndex, Settings, PromptTemplate, Response
14
+ from llama_index.core.agent import QueryPipelineAgentWorker, AgentRunner, Task, AgentChatResponse
15
+ from llama_index.llms.gemini import Gemini
16
+ from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
17
+ from llama_index.core.callbacks import CallbackManager
18
+ from llama_index.core.agent.types import Task
19
+ from llama_index.embeddings.gemini import GeminiEmbedding
20
+ from llama_index.core.query_pipeline import AgentInputComponent, AgentFnComponent, QueryPipeline as QP
21
+ from custom_modules.custom_query_engine import SQLTableRetrieverQueryEngine
22
+ from llama_index.core.prompts.prompt_type import PromptType
23
+
24
+ import json, os, time, gradio as gr
25
+ from typing import Dict, Any, Tuple
26
+ from sqlalchemy import create_engine, inspect
27
+
28
+ Settings.llm = Gemini(temperature=0.2)
29
+ Settings.embed_model = GeminiEmbedding()
30
+ llm = Settings.llm
31
+
32
+ qp = QP(verbose=True)
33
+
34
+ engine = create_engine("sqlite:///Llama_index_structured_sql/database/param4table.db")
35
+ sql_database = SQLDatabase(engine)
36
+
37
# Text-to-SQL prompt template.
# Placeholders: {dialect}, {few_shot_examples}, {convo_history}, {schema},
# {query_str}.
#
# Adjacent string literals are concatenated verbatim, so every fragment ends
# with an explicit space or newline; backslash continuations inside a literal
# are avoided because they embed the source indentation in the prompt.
qa_prompt_tmpl_str = (
    "Given an input question, first create a syntactically correct {dialect} "
    "query to run, then look at the results of the query and return the answer. "
    "You can order the results by a relevant column to return the most "
    "interesting examples in the database.\n\n"
    "Never query for all the columns from a specific table, only ask for a "
    "few relevant columns given the question.\n\n"
    "Pay attention to use only the column names that you can see in the schema "
    "description. "
    "Be careful to not query for columns that do not exist. "
    "Pay attention to which column is in which table. "
    "Also, qualify column names with the table name when needed. "

    "If sales and inwards table needs to be connected, connect them "
    "taking distinct str_id from inwards and then join this with the sales table. "
    # "WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards),:
    # location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm)
    # SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls
    # ORDER BY ls.tot DESC;"

    "Only use the context information and not prior knowledge. "
    # answer the query asking about citations over different topics."
    # NOTE(review): the next sentence asks for a JSON list of author
    # citations, which looks like a leftover from another prompt and conflicts
    # with the "SQLQuery:" output format requested below - confirm and remove.
    "Please provide your answer in the form of a structured JSON format containing "
    "a list of authors as the citations. Some examples are given below.\n\n"

    "{few_shot_examples}"

    "\nIf there is a convo history given, you have to check the Inferred SQL query and then return a modified query.\n\n"
    "Convo history (failed attempts):\n{convo_history}\n"

    "You are required to use the following format, each taking one line:\n\n"
    "Question: Question here\n"
    "SQLQuery: SQL Query to run\n\n"
    # "SQLResult: Result of the SQLQuery\n"
    # "Answer: Final answer here\n\n"
    "Only use tables listed below.\n"
    "{schema}\n\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)
77
+
78
+ qa_prompt_tmpl = PromptTemplate(
79
+ qa_prompt_tmpl_str,
80
+ prompt_type=PromptType.TEXT_TO_SQL,
81
+ )
82
+
83
# Few-shot question -> SQL pairs injected into the text-to-SQL prompt.
# Each active entry must be a statement SQLite can execute, because the model
# imitates these examples. Commented-out entries are disabled examples.
few_shot_examples = [
    # {"query":"How many states are present?","response":"SELECT COUNT(DISTINCT lctn_stat_nm) AS num_states FROM inwards;"},
    # {"query":"Which colour is the most popular in terms of net sale value?","response":"SELECT f.fabric_color, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY f.fabric_color ORDER BY total_net_sale_val DESC;"},
    # {"query":"How many departments?","response":"SELECT COUNT(DISTINCT department) from styles;"}
    {
        "query": "States with above average sales.",
        "response": "WITH distinct_locations AS (SELECT DISTINCT str_id, lctn_stat_nm FROM inwards), location_sales AS (SELECT dl.lctn_stat_nm, SUM(s.net_sale_val) AS tot FROM distinct_locations dl JOIN sales s ON s.customer = dl.str_id GROUP BY dl.lctn_stat_nm) SELECT ls.lctn_stat_nm, ls.tot FROM location_sales ls WHERE ls.tot > (SELECT AVG(tot) FROM location_sales) ORDER BY ls.tot DESC;",
    },
    # {"query":"Which stores had the highest sales per quantity?","response":"SELECT s.customer, SUM(s.net_sale_val) as tot_sales, SUM(s.sales_qty) AS tot_sales_qty, SUM(s.net_sale_val) / SUM(s.sales_qty) AS sales_value_per_quantity FROM sales AS s GROUP BY s.customer ORDER BY sales_value_per_quantity DESC;"},
    # {"query":"Which colored fabric was sold the most and least?","response":"SELECT fabric_color, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY fabric_color ORDER BY total_sales_qty DESC;"},
    # {"query":"Which weave type was sold the most and the least?","response":"SELECT weave_type, SUM(sales_qty) AS total_sales_qty FROM fabrics f JOIN styles st ON f.fabric_code = st.fabric_id JOIN sales s ON st.product_id = s.material GROUP BY weave_type ORDER BY total_sales_qty DESC;"},
    # {"query":"What color and weave type together worked the best?","response":"SELECT f.fabric_color, f.weave_type, SUM(s.net_sale_val) AS total_net_sale_val FROM fabrics AS f JOIN styles AS st ON f.fabric_code = st.fabric_id JOIN sales AS s ON st.product_id = s.material GROUP BY f.fabric_color, f.weave_type ORDER BY total_net_sale_val DESC;"},
    # {"query":"What colors and departments together did best?","response":"WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) AS sales_value, fb.fabric_color AS clr, fb.department AS department FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.department ORDER BY sales_value DESC), most AS (SELECT clr, department, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT clr, department, ROUND(sales_value, 3) FROM most WHERE rnk = 1;"},
    {
        "query": "How has the month on month sales changed across departments in the last year?",
        # No comma may follow the last CTE: "WITH cte AS (...) SELECT ...".
        "response": "WITH monthly_sales AS (SELECT st.department, STRFTIME('%m', s.cal_day) AS month, SUM(s.net_sale_val) AS monthly_sales FROM sales s JOIN styles st ON s.material = st.product_id WHERE s.cal_day >= DATE(CURRENT_DATE, '-1 year') GROUP BY department, STRFTIME('%m', s.cal_day)) SELECT department, month, monthly_sales, monthly_sales - LAG(monthly_sales, 1) OVER (PARTITION BY department ORDER BY month) AS month_on_month_change FROM monthly_sales ORDER BY department, month;",
    },
    # {"query":"What was the average sell through rate for fabric across various departments?","response":""},
    # {"query":"Which color that the highest sell through in each department in the last 3 months?","response":""},
    # {"query":"In shirts with a certain collar type (this is style table attribute) what is the average cotton percentage?","response":""},
    # {"query":"Which collar type had the most sales, most sell through?","response":""},

    {
        "query": "Which weave type was sold the most?",
        "response": "WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.weave_type ORDER BY sales_value DESC), most AS (SELECT weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables) SELECT weave_type, ROUND(sales_value, 3) FROM most WHERE rnk = 1;",
    },
    {
        "query": "What color and weave type together worked the best?",
        "response": "WITH sales_val AS (SELECT s.net_sale_val, st.fabric_id FROM sales s JOIN styles st ON s.material = st.product_id), combine_tables AS (SELECT SUM(sv.net_sale_val) sales_value, fb.fabric_color AS clr, fb.weave_type AS weave_type FROM sales_val sv JOIN fabrics fb ON sv.fabric_id = fb.fabric_code GROUP BY fb.fabric_color, fb.weave_type ORDER BY sales_value DESC) SELECT clr, weave_type, sales_value, DENSE_RANK() OVER (ORDER BY sales_value DESC) AS rnk FROM combine_tables;",
    },
    {
        "query": "List the combination of fabrics and brands with their total sales and total inwards quantity.",
        "response": "WITH sks AS (SELECT st.fabric_id AS fab_id, st.product_id AS prod_id, SUM(sa.net_sale_val) AS tot_sales FROM styles st JOIN sales sa ON st.product_id = sa.material GROUP BY st.product_id), sti AS (SELECT s.product_id AS prod_id, SUM(i.inwards_qty) AS tot_inwards_qty FROM styles s JOIN inwards i ON s.product_id = i.style_id GROUP BY s.product_id) SELECT sks.fab_id, f.brand, SUM(sks.tot_sales) AS tot_s, sti.tot_inwards_qty FROM sks JOIN sti ON sks.prod_id = sti.prod_id JOIN fabrics f ON f.fabric_code = sks.fab_id GROUP BY sks.fab_id, f.brand;",
    },
]
103
+
104
+ text2sql_prompt = qa_prompt_tmpl.partial_format(
105
+ dialect=engine.dialect.name,
106
+ few_shot_examples=few_shot_examples,
107
+ convo_history=None
108
+ )
109
+
110
+ # print(sql_query_engine_nl)
111
+
112
+ tableinfo_dir='Llama_index_structured_sql/param_enhanced_metadata'
113
+
114
+ insp = inspect(engine)
115
+ table_names = insp.get_table_names()
116
+
117
# Load every table-metadata JSON document found in `tableinfo_dir`.
# Each parsed document is appended to `table_infos`.
table_infos = []

# os.listdir returns entries in arbitrary order; sort so the resulting
# table order is stable from run to run.
files = sorted(os.listdir(tableinfo_dir))

for file_name in files:
    # Opened read-only ("r", not "r+") so the load also works when the
    # metadata directory is not writable; the handle gets its own name
    # instead of shadowing the loop variable.
    with open(os.path.join(tableinfo_dir, file_name), "r", encoding="utf-8") as fh:
        table_infos.append(json.load(fh))
125
+
126
+ sql_database = SQLDatabase(engine)
127
+
128
+ table_node_mapping = SQLTableNodeMapping(sql_database)
129
+ table_schema_objs = [
130
+ SQLTableSchema(table_name=t["table_name"], context_str=t["table_summary"]+str(t["columns"]))
131
+ for t in table_infos
132
+ ] # add a SQLTableSchema for each table
133
+
134
+ # print(table_schema_objs)
135
+
136
+
137
+ obj_index = ObjectIndex.from_objects(
138
+ table_schema_objs,
139
+ table_node_mapping,
140
+ VectorStoreIndex,
141
+ )
142
+
143
+
144
# Prompt used to turn the raw SQL result into the final natural-language
# answer. Placeholders: {query_str}, {sql_query}, {context_str}.
# Adjacent literals are concatenated verbatim, so each sentence carries its
# own trailing space to keep the sentences from running together.
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results. "
    "Please stick to the Indian number system and Indian currency. "
    "Round the sales value to 2 decimal places. If there are more than 2 columns return in a tabular format.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
153
+
154
+ response_synthesis_prompt = PromptTemplate(
155
+ response_synthesis_prompt_str,
156
+ )
157
+
158
+ sql_query_engine = SQLTableRetrieverQueryEngine(
159
+
160
+ sql_database,
161
+ text_to_sql_prompt=text2sql_prompt,
162
+ table_retriever=obj_index.as_retriever(similarity_top_k=4),
163
+ llm=llm,
164
+ response_synthesis_prompt=response_synthesis_prompt
165
+ )
166
+
167
def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict:
    """Build the pipeline inputs for one agent step.

    Appends the user's input to the running conversation history kept in
    ``state``, re-renders the text-to-SQL prompt with that history, and
    pushes the refreshed prompt into the SQL retriever.

    Args:
        task: The current agent task; only ``task.input`` is read.
        state: Mutable per-task state. ``convo_history`` (list of str) and
            ``count`` (int) are created on first use.

    Returns:
        A dict with the keys ``"input"`` and ``"convo_history"``.
    """
    # First step for this state: create the transcript and attempt counter.
    if "convo_history" not in state:
        state["convo_history"] = []
        state["count"] = 0

    transcript = state["convo_history"]
    transcript.append(f"User: {task.input}")
    convo_history_str = "\n".join(transcript) or "None"

    # Re-render the text-to-SQL prompt so it carries the latest history.
    refreshed_prompt = qa_prompt_tmpl.partial_format(
        dialect=engine.dialect.name,
        few_shot_examples=few_shot_examples,
        convo_history=convo_history_str,
    )
    sql_query_engine.sql_retriever._update_prompts(
        {'text_to_sql_prompt': refreshed_prompt}
    )
    print(sql_query_engine.sql_retriever.get_prompts())

    return {"input": task.input, "convo_history": convo_history_str}
184
+
185
+
186
+ agent_input_component = AgentInputComponent(fn=agent_input_fn)
187
+
188
+ ## This prompt will generate a transformed natural language query for further downstream tasks.
189
+ retry_prompt_str = """\
190
+ You are trying to generate a proper natural language query given a user input.
191
+
192
+ This query will then be interpreted by a downstream text-to-SQL agent which
193
+ will convert the query to a SQL statement. If the agent triggers an error,
194
+ then that will be reflected in the current conversation history (see below).
195
+
196
+ If the conversation history is None, use the user input. If its not None,
197
+ generate a new SQL query that avoids the problems of the previous SQL query.
198
+
199
+ Input: {input}
200
+ Convo history (failed attempts):
201
+ {convo_history}
202
+
203
+ New input: """
204
+
205
+ retry_prompt = PromptTemplate(retry_prompt_str)
206
+
207
+ ## This prompt will validate whether the inferred SQL query and response from executing the query is correct and answers the query
208
+ validate_prompt_str = """\
209
+ Given the user query, validate whether the inferred SQL query and response from executing the query is correct and answers the query. In SQL Response, round any decimal values to 2 decimal places.
210
+
211
+ Answer with YES or NO.
212
+
213
+ Query: {input}
214
+ Inferred SQL query: {sql_query}
215
+ SQL Response: {sql_response}
216
+
217
+ Result: """
218
+
219
+ validate_prompt = PromptTemplate(validate_prompt_str)
220
+
221
+ MAX_ITER = 3
222
+
223
+ extracted_sql_query = None
224
+
225
def agent_output_fn(
    task: Task, state: Dict[str, Any], output: Response
) -> Tuple[AgentChatResponse, bool]:
    """Record one SQL attempt and decide whether the agent is finished.

    Logs the inferred SQL query and its response, stores the query in the
    module-level ``extracted_sql_query``, appends both to the conversation
    history, and asks the LLM to validate the attempt.

    Args:
        task: The current agent task; ``task.input`` is the user query.
        state: Mutable per-task state holding ``convo_history`` and ``count``.
        output: Query-engine response; ``output.metadata["sql_query"]`` must
            be present.

    Returns:
        ``(chat_response, is_done)`` where ``is_done`` is true once
        ``MAX_ITER`` attempts have been made or the validation reply
        contains ``"YES"``.
    """
    global extracted_sql_query

    sql_query = output.metadata['sql_query']
    print(f"> Inferred SQL Query:\n{sql_query}")
    sql_response = str(output)
    print(f"> SQL Response:\n{sql_response}")

    extracted_sql_query = sql_query
    state["convo_history"].extend(
        [
            f"Assistant (inferred SQL query): {sql_query}",
            f"Assistant (response): {sql_response}",
        ]
    )

    # Mini chain: validation prompt (pre-filled with this attempt) -> LLM.
    validation_component = validate_prompt.as_query_component(
        partial={
            "sql_query": sql_query,
            "sql_response": sql_response,
        }
    )
    validation_pipeline = QP(chain=[validation_component, llm])
    verdict = validation_pipeline.run(input=task.input)

    state["count"] += 1
    attempts_exhausted = state["count"] >= MAX_ITER
    approved = "YES" in verdict.message.content

    return AgentChatResponse(response=sql_response), attempts_exhausted or approved
258
+
259
+
260
+ agent_output_component = AgentFnComponent(fn=agent_output_fn)
261
+
262
+ qp = QP(
263
+ modules={
264
+ "input": agent_input_component,
265
+ "retry_prompt": retry_prompt,
266
+ "llm": llm,
267
+ "sql_query_engine": sql_query_engine,
268
+ "output_component": agent_output_component,
269
+ },
270
+ verbose=True,
271
+ )
272
+ qp.add_link("input", "retry_prompt", src_key="input", dest_key="input")
273
+ qp.add_link(
274
+ "input", "retry_prompt", src_key="convo_history", dest_key="convo_history"
275
+ )
276
+ qp.add_chain(["retry_prompt", "llm", "sql_query_engine", "output_component"])
277
+
278
+
279
+ # from pyvis.network import Network
280
+
281
+ # net = Network(notebook=True, cdn_resources="in_line", directed=True)
282
+ # net.from_nx(qp.clean_dag)
283
+ # net.show("retry_agent_dag.html")
284
+
285
+ agent_worker = QueryPipelineAgentWorker(qp)
286
+
287
+ agent = AgentRunner(
288
+ agent_worker=agent_worker,
289
+ callback_manager=CallbackManager([]), verbose=True
290
+ )
291
+
292
def slow_echo(query,history):
    """Stream the agent's answer to the Gradio chat one character at a time.

    Args:
        query: The user's message, forwarded to ``agent.chat``.
        history: Chat history supplied by Gradio; not used here.

    Yields:
        Progressively longer prefixes of the formatted reply, which contains
        the answer followed by the current value of the module-level
        ``extracted_sql_query``.
    """
    response = agent.chat(query)
    # Keep only the text after the last ':' in the response.
    # NOTE(review): this also cuts answers that themselves contain a colon
    # (for example markdown table alignment rows) - confirm this is intended.
    message = str(response).rsplit(':', 1)[-1].strip()
    res = f'**Answer:**\n\n{message}\n\n\n\n**LLM Generated SQL Query:**\n\n{extracted_sql_query}'

    # Emit one more character per tick to give a typing effect.
    for shown in range(1, len(res) + 1):
        time.sleep(0.01)
        yield res[:shown]
301
+
302
+ demo = gr.ChatInterface(slow_echo).queue()
303
+
304
+ if __name__ == "__main__":
305
+ demo.launch(share=True)