Draichi commited on
Commit
59fc0cc
1 Parent(s): 9a1be75

feat: init advanced_text_to_SQL.ipynb

Browse files
multi-agents-analysis/advanced_text_to_SQL.ipynb ADDED
@@ -0,0 +1,943 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Query Pipeline for Advanced Text-to-SQL¶\n",
8
+ "\n",
9
+ "In this guide we show you how to setup a text-to-SQL pipeline over your data with our query pipeline syntax.\n",
10
+ "\n",
11
+ "This gives you flexibility to enhance text-to-SQL with additional techniques. We show these in the below sections:\n",
12
+ "\n",
13
+ "1. Query-Time Table Retrieval: Dynamically retrieve relevant tables in the text-to-SQL prompt.\n",
14
+ "2. Query-Time Sample Row retrieval: Embed/Index each row, and dynamically retrieve example rows for each table in the text-to-SQL prompt.\n",
15
+ " Our out-of-the box pipelines include our NLSQLTableQueryEngine and SQLTableRetrieverQueryEngine. (if you want to check out our text-to-SQL guide using these modules, take a look here). This guide implements an advanced version of those modules, giving you the utmost flexibility to apply this to your own setting.\n",
16
+ "\n",
17
+ "NOTE: Any Text-to-SQL application should be aware that executing arbitrary SQL queries can be a security risk. It is recommended to take precautions as needed, such as using restricted roles, read-only databases, sandboxing, etc.\n",
18
+ "\n",
19
+ "## Load and Ingest Data\n",
20
+ "\n",
21
+ "### Load Data\n",
22
+ "\n",
23
+ "We use the [WikiTableQuestions](https://github.com/ppasupat/WikiTableQuestions/releases) dataset (Pasupat and Liang 2015) as our test dataset.\n",
24
+ "\n",
25
+ "We go through all the csv's in one folder, store each in a sqlite database (we will then build an object index over each table schema).\n"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 9,
31
+ "metadata": {},
32
+ "outputs": [
33
+ {
34
+ "name": "stdout",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "processing file: WikiTableQuestions/csv/200-csv/0.csv\n",
38
+ "processing file: WikiTableQuestions/csv/200-csv/1.csv\n",
39
+ "processing file: WikiTableQuestions/csv/200-csv/10.csv\n",
40
+ "processing file: WikiTableQuestions/csv/200-csv/11.csv\n",
41
+ "processing file: WikiTableQuestions/csv/200-csv/12.csv\n",
42
+ "processing file: WikiTableQuestions/csv/200-csv/14.csv\n",
43
+ "processing file: WikiTableQuestions/csv/200-csv/15.csv\n",
44
+ "Error parsing WikiTableQuestions/csv/200-csv/15.csv: Error tokenizing data. C error: Expected 4 fields in line 16, saw 5\n",
45
+ "\n",
46
+ "processing file: WikiTableQuestions/csv/200-csv/17.csv\n",
47
+ "Error parsing WikiTableQuestions/csv/200-csv/17.csv: Error tokenizing data. C error: Expected 6 fields in line 5, saw 7\n",
48
+ "\n",
49
+ "processing file: WikiTableQuestions/csv/200-csv/18.csv\n",
50
+ "processing file: WikiTableQuestions/csv/200-csv/20.csv\n",
51
+ "processing file: WikiTableQuestions/csv/200-csv/22.csv\n",
52
+ "processing file: WikiTableQuestions/csv/200-csv/24.csv\n",
53
+ "processing file: WikiTableQuestions/csv/200-csv/25.csv\n",
54
+ "processing file: WikiTableQuestions/csv/200-csv/26.csv\n",
55
+ "processing file: WikiTableQuestions/csv/200-csv/28.csv\n",
56
+ "processing file: WikiTableQuestions/csv/200-csv/29.csv\n",
57
+ "processing file: WikiTableQuestions/csv/200-csv/3.csv\n",
58
+ "processing file: WikiTableQuestions/csv/200-csv/30.csv\n",
59
+ "processing file: WikiTableQuestions/csv/200-csv/31.csv\n",
60
+ "processing file: WikiTableQuestions/csv/200-csv/32.csv\n",
61
+ "processing file: WikiTableQuestions/csv/200-csv/33.csv\n",
62
+ "processing file: WikiTableQuestions/csv/200-csv/34.csv\n",
63
+ "Error parsing WikiTableQuestions/csv/200-csv/34.csv: Error tokenizing data. C error: Expected 4 fields in line 6, saw 13\n",
64
+ "\n",
65
+ "processing file: WikiTableQuestions/csv/200-csv/35.csv\n",
66
+ "processing file: WikiTableQuestions/csv/200-csv/36.csv\n",
67
+ "processing file: WikiTableQuestions/csv/200-csv/37.csv\n",
68
+ "processing file: WikiTableQuestions/csv/200-csv/38.csv\n",
69
+ "processing file: WikiTableQuestions/csv/200-csv/4.csv\n",
70
+ "processing file: WikiTableQuestions/csv/200-csv/41.csv\n",
71
+ "processing file: WikiTableQuestions/csv/200-csv/42.csv\n",
72
+ "processing file: WikiTableQuestions/csv/200-csv/44.csv\n",
73
+ "processing file: WikiTableQuestions/csv/200-csv/45.csv\n",
74
+ "processing file: WikiTableQuestions/csv/200-csv/46.csv\n",
75
+ "processing file: WikiTableQuestions/csv/200-csv/47.csv\n",
76
+ "processing file: WikiTableQuestions/csv/200-csv/48.csv\n",
77
+ "processing file: WikiTableQuestions/csv/200-csv/7.csv\n",
78
+ "processing file: WikiTableQuestions/csv/200-csv/8.csv\n",
79
+ "processing file: WikiTableQuestions/csv/200-csv/9.csv\n"
80
+ ]
81
+ }
82
+ ],
83
+ "source": [
84
+ "import pandas as pd\n",
85
+ "from pathlib import Path\n",
86
+ "\n",
87
+ "data_dir = Path(\"./WikiTableQuestions/csv/200-csv\")\n",
88
+ "csv_files = sorted([f for f in data_dir.glob(\"*.csv\")])\n",
89
+ "dfs = []\n",
90
+ "for csv_file in csv_files:\n",
91
+ " print(f\"processing file: {csv_file}\")\n",
92
+ " try:\n",
93
+ " df = pd.read_csv(csv_file)\n",
94
+ " dfs.append(df)\n",
95
+ " except Exception as e:\n",
96
+ " print(f\"Error parsing {csv_file}: {str(e)}\")"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "metadata": {},
102
+ "source": [
103
+ "### Extract Table Name and Summary from each Table\n",
104
+ "\n",
105
+ "Here we use gpt-3.5 to extract a table name (with underscores) and summary from each table with our Pydantic program.\n"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "from llama_index.core.program import LLMTextCompletionProgram\n",
115
+ "from llama_index.core.bridge.pydantic import BaseModel, Field\n",
116
+ "from llama_index.llms.openai import OpenAI\n",
117
+ "\n",
118
+ "\n",
119
+ "class TableInfo(BaseModel):\n",
120
+ " \"\"\"Information regarding a structured table.\"\"\"\n",
121
+ "\n",
122
+ " table_name: str = Field(\n",
123
+ " ..., description=\"table name (must be underscores and NO spaces)\"\n",
124
+ " )\n",
125
+ " table_summary: str = Field(\n",
126
+ " ..., description=\"short, concise summary/caption of the table\"\n",
127
+ " )\n",
128
+ "\n",
129
+ "\n",
130
+ "prompt_str = \"\"\"\\\n",
131
+ "Give me a summary of the table with the following JSON format.\n",
132
+ "\n",
133
+ "- The table name must be unique to the table and describe it while being concise. \n",
134
+ "- Do NOT output a generic table name (e.g. table, my_table).\n",
135
+ "\n",
136
+ "Do NOT make the table name one of the following: {exclude_table_name_list}\n",
137
+ "\n",
138
+ "Table:\n",
139
+ "{table_str}\n",
140
+ "\n",
141
+ "Summary: \"\"\"\n",
142
+ "\n",
143
+ "program = LLMTextCompletionProgram.from_defaults(\n",
144
+ " output_cls=TableInfo,\n",
145
+ " llm=OpenAI(model=\"gpt-3.5-turbo\"),\n",
146
+ " prompt_template_str=prompt_str,\n",
147
+ ")\n",
148
+ "\n",
149
+ "print(program)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "import json\n",
159
+ "\n",
160
+ "\n",
161
+ "def _get_tableinfo_with_index(idx: int):\n",
162
+ " results_gen = Path(\"WikiTableQuestions_TableInfo\").glob(f\"{idx}_*\")\n",
163
+ " results_list = list(results_gen)\n",
164
+ " if len(results_list) == 0:\n",
165
+ " return None\n",
166
+ " if len(results_list) == 1:\n",
167
+ " path = results_list[0]\n",
168
+ " return TableInfo.parse_file(path)\n",
169
+ " else:\n",
170
+ " raise ValueError(\n",
171
+ " f\"More than one file matching index: {list(results_gen)}\"\n",
172
+ " )\n",
173
+ "\n",
174
+ "\n",
175
+ "table_names = set()\n",
176
+ "table_infos = []\n",
177
+ "for idx, df in enumerate(dfs):\n",
178
+ " table_info = _get_tableinfo_with_index(idx)\n",
179
+ " if table_info:\n",
180
+ " table_infos.append(table_info)\n",
181
+ " else:\n",
182
+ " while True:\n",
183
+ " df_str = df.head(10).to_csv()\n",
184
+ " table_info = program(\n",
185
+ " table_str=df_str,\n",
186
+ " exclude_table_name_list=str(list(table_names)),\n",
187
+ " )\n",
188
+ " table_name = table_info.table_name\n",
189
+ " print(f\"Processed table: {table_name}\")\n",
190
+ " if table_name not in table_names:\n",
191
+ " table_names.add(table_name)\n",
192
+ " break\n",
193
+ " else:\n",
194
+ " # try again\n",
195
+ " print(f\"Table name {table_name} already exists, trying again.\")\n",
196
+ " pass\n",
197
+ "\n",
198
+ " out_file = f\"WikiTableQuestions_TableInfo/{idx}_{table_name}.json\"\n",
199
+ " json.dump(table_info.dict(), open(out_file, \"w\"))\n",
200
+ " table_infos.append(table_info)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "metadata": {},
206
+ "source": [
207
+ "### Put Data in SQL Database\n",
208
+ "\n",
209
+ "We use sqlalchemy, a popular SQL database toolkit, to load all the tables.\n"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "# put data into sqlite db\n",
219
+ "from sqlalchemy import (\n",
220
+ " create_engine,\n",
221
+ " MetaData,\n",
222
+ " Table,\n",
223
+ " Column,\n",
224
+ " String,\n",
225
+ " Integer,\n",
226
+ ")\n",
227
+ "import re\n",
228
+ "\n",
229
+ "\n",
230
+ "# Function to create a sanitized column name\n",
231
+ "def sanitize_column_name(col_name):\n",
232
+ " # Remove special characters and replace spaces with underscores\n",
233
+ " return re.sub(r\"\\W+\", \"_\", col_name)\n",
234
+ "\n",
235
+ "\n",
236
+ "# Function to create a table from a DataFrame using SQLAlchemy\n",
237
+ "def create_table_from_dataframe(\n",
238
+ " df: pd.DataFrame, table_name: str, engine, metadata_obj\n",
239
+ "):\n",
240
+ " # Sanitize column names\n",
241
+ " sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}\n",
242
+ " df = df.rename(columns=sanitized_columns)\n",
243
+ "\n",
244
+ " # Dynamically create columns based on DataFrame columns and data types\n",
245
+ " columns = [\n",
246
+ " Column(col, String if dtype == \"object\" else Integer)\n",
247
+ " for col, dtype in zip(df.columns, df.dtypes)\n",
248
+ " ]\n",
249
+ "\n",
250
+ " # Create a table with the defined columns\n",
251
+ " table = Table(table_name, metadata_obj, *columns)\n",
252
+ "\n",
253
+ " # Create the table in the database\n",
254
+ " metadata_obj.create_all(engine)\n",
255
+ "\n",
256
+ " # Insert data from DataFrame into the table\n",
257
+ " with engine.connect() as conn:\n",
258
+ " for _, row in df.iterrows():\n",
259
+ " insert_stmt = table.insert().values(**row.to_dict())\n",
260
+ " conn.execute(insert_stmt)\n",
261
+ " conn.commit()\n",
262
+ "\n",
263
+ "\n",
264
+ "engine = create_engine(\"sqlite:///:memory:\")\n",
265
+ "metadata_obj = MetaData()\n",
266
+ "for idx, df in enumerate(dfs):\n",
267
+ " tableinfo = _get_tableinfo_with_index(idx)\n",
268
+ " print(f\"Creating table: {tableinfo.table_name}\")\n",
269
+ " create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)"
270
+ ]
271
+ },
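+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As the NOTE at the top recommends, it is safer to execute LLM-generated SQL against a read-only connection. A minimal sketch (assuming you persist the tables to a file-based SQLite database instead of `:memory:`):\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hypothetical hardened setup: ingest with a writable engine...\n",
+ "write_engine = create_engine(\"sqlite:///wiki_table_questions.db\")\n",
+ "\n",
+ "# ...and point the query pipeline at a read-only engine (SQLite URI mode),\n",
+ "# so generated SQL cannot modify the data.\n",
+ "readonly_engine = create_engine(\n",
+ "    \"sqlite:///file:wiki_table_questions.db?mode=ro&uri=true\"\n",
+ ")"
+ ]
+ },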
272
+ {
273
+ "cell_type": "markdown",
274
+ "metadata": {},
275
+ "source": [
276
+ "Setup Arize Phoenix for observability\n"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": null,
282
+ "metadata": {},
283
+ "outputs": [],
284
+ "source": [
285
+ "from openinference.instrumentation.llama_index import LlamaIndexInstrumentor\n",
286
+ "from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter\n",
287
+ "from opentelemetry.sdk import trace as trace_sdk\n",
288
+ "from opentelemetry.sdk.trace.export import SimpleSpanProcessor\n",
289
+ "\n",
290
+ "endpoint = \"http://127.0.0.1:6006/v1/traces\" # Phoenix receiver address\n",
291
+ "\n",
292
+ "tracer_provider = trace_sdk.TracerProvider()\n",
293
+ "tracer_provider.add_span_processor(\n",
294
+ " SimpleSpanProcessor(OTLPSpanExporter(endpoint)))\n",
295
+ "\n",
296
+ "LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)"
297
+ ]
298
+ },
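+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The exporter above assumes a Phoenix instance is already listening on port 6006. If you don't have one running, a minimal sketch (assuming the `arize-phoenix` package is installed) is to launch it in-process:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import phoenix as px\n",
+ "\n",
+ "# Starts a local Phoenix app serving the UI and the OTLP receiver on port 6006\n",
+ "px.launch_app()"
+ ]
+ },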
299
+ {
300
+ "cell_type": "markdown",
301
+ "metadata": {},
302
+ "source": [
303
+ "## Advanced Capability 1: Text-to-SQL with Query-Time Table Retrieval.\n",
304
+ "\n",
305
+ "We now show you how to setup an e2e text-to-SQL with table retrieval.\n",
306
+ "\n",
307
+ "Here we define the core modules.\n",
308
+ "\n",
309
+ "1. Object index + retriever to store table schemas\n",
310
+ "2. SQLDatabase object to connect to the above tables + SQLRetriever.\n",
311
+ "3. Text-to-SQL Prompt\n",
312
+ "4. Response synthesis Prompt\n",
313
+ "5. LLM\n",
314
+ "\n",
315
+ "### 1. Object index, retriever, SQLDatabase\n"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": null,
321
+ "metadata": {},
322
+ "outputs": [],
323
+ "source": [
324
+ "from llama_index.core.objects import (\n",
325
+ " SQLTableNodeMapping,\n",
326
+ " ObjectIndex,\n",
327
+ " SQLTableSchema,\n",
328
+ ")\n",
329
+ "from llama_index.core import SQLDatabase, VectorStoreIndex\n",
330
+ "\n",
331
+ "sql_database = SQLDatabase(engine)\n",
332
+ "\n",
333
+ "table_node_mapping = SQLTableNodeMapping(sql_database)\n",
334
+ "\n",
335
+ "table_schema_objs = [\n",
336
+ " SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)\n",
337
+ " for t in table_infos\n",
338
+ "] # add a SQLTableSchema for each table\n",
339
+ "\n",
340
+ "obj_index = ObjectIndex.from_objects(objects=table_schema_objs,\n",
341
+ " object_mapping=table_node_mapping,\n",
342
+ " index_cls=VectorStoreIndex,\n",
343
+ " )\n",
344
+ "obj_retriever = obj_index.as_retriever(similarity_top_k=3)"
345
+ ]
346
+ },
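+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, the out-of-the-box modules mentioned in the intro (NLSQLTableQueryEngine and SQLTableRetrieverQueryEngine) wrap most of the pipeline we are about to build by hand. A rough sketch, not used in the rest of this guide:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llama_index.core.query_engine import (\n",
+ "    NLSQLTableQueryEngine,\n",
+ "    SQLTableRetrieverQueryEngine,\n",
+ ")\n",
+ "\n",
+ "# Text-to-SQL over an explicitly listed table...\n",
+ "simple_query_engine = NLSQLTableQueryEngine(\n",
+ "    sql_database=sql_database,\n",
+ "    tables=[table_infos[0].table_name],\n",
+ ")\n",
+ "\n",
+ "# ...or with table schemas retrieved from the object index at query time.\n",
+ "retriever_query_engine = SQLTableRetrieverQueryEngine(\n",
+ "    sql_database,\n",
+ "    obj_index.as_retriever(similarity_top_k=1),\n",
+ ")"
+ ]
+ },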
347
+ {
348
+ "cell_type": "markdown",
349
+ "metadata": {},
350
+ "source": [
351
+ "### 2. SQLRetriever + Table Parser\n"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": null,
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "from llama_index.core.retrievers import SQLRetriever\n",
361
+ "from typing import List\n",
362
+ "from llama_index.core.query_pipeline import FnComponent\n",
363
+ "\n",
364
+ "sql_retriever = SQLRetriever(sql_database)\n",
365
+ "\n",
366
+ "\n",
367
+ "def get_table_context_str(table_schema_objs: List[SQLTableSchema]):\n",
368
+ " \"\"\"Get table context string.\"\"\"\n",
369
+ " context_strs = []\n",
370
+ " for table_schema_obj in table_schema_objs:\n",
371
+ " table_info = sql_database.get_single_table_info(\n",
372
+ " table_schema_obj.table_name\n",
373
+ " )\n",
374
+ " if table_schema_obj.context_str:\n",
375
+ " table_opt_context = \" The table description is: \"\n",
376
+ " table_opt_context += table_schema_obj.context_str\n",
377
+ " table_info += table_opt_context\n",
378
+ "\n",
379
+ " context_strs.append(table_info)\n",
380
+ " return \"\\n\\n\".join(context_strs)\n",
381
+ "\n",
382
+ "\n",
383
+ "table_parser_component = FnComponent(fn=get_table_context_str)"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "markdown",
388
+ "metadata": {},
389
+ "source": [
390
+ "### 3. Text-to-SQL Prompt + Output Parser\n"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": null,
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": [
399
+ "from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT\n",
400
+ "from llama_index.core import PromptTemplate\n",
401
+ "from llama_index.core.query_pipeline import FnComponent\n",
402
+ "from llama_index.core.llms import ChatResponse\n",
403
+ "\n",
404
+ "\n",
405
+ "def extract_sql_query(content: str) -> str:\n",
406
+ " sql_query_start = content.find(\"SQLQuery:\")\n",
407
+ " if sql_query_start == -1:\n",
408
+ " raise ValueError(\"No 'SQLQuery:' marker found in the response content\")\n",
409
+ "\n",
410
+ " query_content = content[sql_query_start + len(\"SQLQuery:\"):]\n",
411
+ " sql_result_start = query_content.find(\"SQLResult:\")\n",
412
+ "\n",
413
+ " if sql_result_start != -1:\n",
414
+ " query_content = query_content[:sql_result_start]\n",
415
+ "\n",
416
+ " return query_content\n",
417
+ "\n",
418
+ "\n",
419
+ "def clean_sql_query(query: str) -> str:\n",
420
+ " return query.strip().strip(\"```\").strip()\n",
421
+ "\n",
422
+ "\n",
423
+ "def parse_response_to_sql(response: ChatResponse) -> str:\n",
424
+ " \"\"\"\n",
425
+ " Parse a ChatResponse object to extract the SQL query.\n",
426
+ "\n",
427
+ " This function takes a ChatResponse object, which is expected to contain\n",
428
+ " an SQL query within its content, and extracts the SQL query string.\n",
429
+ " The function looks for specific markers ('SQLQuery:' and 'SQLResult:')\n",
430
+ " to identify the SQL query portion of the response.\n",
431
+ "\n",
432
+ " Args:\n",
433
+ " response (ChatResponse): A ChatResponse object containing the response\n",
434
+ " from a text-to-SQL model.\n",
435
+ "\n",
436
+ " Returns:\n",
437
+ " str: The extracted SQL query as a string, with surrounding whitespace\n",
438
+ " and code block markers (```) removed.\n",
439
+ "\n",
440
+ " Raises:\n",
441
+ " AttributeError: If the input doesn't have the expected 'message.content' attribute.\n",
442
+ " ValueError: If no 'SQLQuery:' marker is found in the response content.\n",
443
+ "\n",
444
+ " Note:\n",
445
+ " - The function assumes that the SQL query is preceded by 'SQLQuery:' \n",
446
+ " and optionally followed by 'SQLResult:'.\n",
447
+ " - Any content before 'SQLQuery:' or after 'SQLResult:' is discarded.\n",
448
+ " - The function removes leading/trailing whitespace and code block markers.\n",
449
+ "\n",
450
+ " Example:\n",
451
+ " >>> response = ChatResponse(message=Message(content=\"Some text\\nSQLQuery: SELECT * FROM table\\nSQLResult: ...\"))\n",
452
+ " >>> sql_query = parse_response_to_sql(response)\n",
453
+ " >>> print(sql_query)\n",
454
+ " SELECT * FROM table\n",
455
+ " \"\"\"\n",
456
+ " try:\n",
457
+ " content = str(response.message.content)\n",
458
+ " except AttributeError:\n",
459
+ " raise ValueError(\n",
460
+ " \"Input must be a ChatResponse object with a 'message.content' attribute\")\n",
461
+ "\n",
462
+ " sql_query = extract_sql_query(content)\n",
463
+ " return clean_sql_query(sql_query)\n",
464
+ "\n",
465
+ "\n",
466
+ "sql_parser_component = FnComponent(fn=parse_response_to_sql)\n",
467
+ "\n",
468
+ "text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(\n",
469
+ " dialect=engine.dialect.name\n",
470
+ ")\n",
471
+ "print(text2sql_prompt.template)"
472
+ ]
473
+ },
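+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick sanity check of the parsing helpers on a hypothetical completion string (the table and column names here are made up):\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The default prompt asks the LLM to answer in the form\n",
+ "# \"SQLQuery: ... SQLResult: ...\", which is what these helpers strip apart.\n",
+ "sample_completion = \"SQLQuery: SELECT year FROM albums WHERE artist = 'X' SQLResult: 1993\"\n",
+ "print(clean_sql_query(extract_sql_query(sample_completion)))"
+ ]
+ },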
474
+ {
475
+ "cell_type": "markdown",
476
+ "metadata": {},
477
+ "source": [
478
+ "### 4. Response Synthesis Prompt\n"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "response_synthesis_prompt_str = (\n",
488
+ " \"Given an input question, synthesize a response from the query results.\\n\"\n",
489
+ " \"Query: {query_str}\\n\"\n",
490
+ " \"SQL: {sql_query}\\n\"\n",
491
+ " \"SQL Response: {context_str}\\n\"\n",
492
+ " \"Response: \"\n",
493
+ ")\n",
494
+ "response_synthesis_prompt = PromptTemplate(\n",
495
+ " response_synthesis_prompt_str,\n",
496
+ ")"
497
+ ]
498
+ },
499
+ {
500
+ "cell_type": "markdown",
501
+ "metadata": {},
502
+ "source": [
503
+ "### 5. LLM\n"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": null,
509
+ "metadata": {},
510
+ "outputs": [],
511
+ "source": [
512
+ "llm = OpenAI(model=\"gpt-3.5-turbo\")"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "markdown",
517
+ "metadata": {},
518
+ "source": [
519
+ "#### Define Query Pipeline\n",
520
+ "\n",
521
+ "Now that the components are in place, let's define the query pipeline!\n"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": null,
527
+ "metadata": {},
528
+ "outputs": [],
529
+ "source": [
530
+ "from llama_index.core.query_pipeline import (\n",
531
+ " QueryPipeline as QP,\n",
532
+ " Link,\n",
533
+ " InputComponent,\n",
534
+ " CustomQueryComponent,\n",
535
+ ")\n",
536
+ "\n",
537
+ "qp = QP(\n",
538
+ " modules={\n",
539
+ " \"input\": InputComponent(),\n",
540
+ " \"table_retriever\": obj_retriever,\n",
541
+ " \"table_output_parser\": table_parser_component,\n",
542
+ " \"text2sql_prompt\": text2sql_prompt,\n",
543
+ " \"text2sql_llm\": llm,\n",
544
+ " \"sql_output_parser\": sql_parser_component,\n",
545
+ " \"sql_retriever\": sql_retriever,\n",
546
+ " \"response_synthesis_prompt\": response_synthesis_prompt,\n",
547
+ " \"response_synthesis_llm\": llm,\n",
548
+ " },\n",
549
+ " verbose=True,\n",
550
+ ")\n",
551
+ "qp"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "metadata": {},
558
+ "outputs": [],
559
+ "source": [
560
+ "qp.add_chain([\"input\", \"table_retriever\", \"table_output_parser\"])\n",
561
+ "qp.add_link(\"input\", \"text2sql_prompt\", dest_key=\"query_str\")\n",
562
+ "qp.add_link(\"table_output_parser\", \"text2sql_prompt\", dest_key=\"schema\")\n",
563
+ "qp.add_chain(\n",
564
+ " [\"text2sql_prompt\", \"text2sql_llm\", \"sql_output_parser\", \"sql_retriever\"]\n",
565
+ ")\n",
566
+ "qp.add_link(\n",
567
+ " \"sql_output_parser\", \"response_synthesis_prompt\", dest_key=\"sql_query\"\n",
568
+ ")\n",
569
+ "qp.add_link(\n",
570
+ " \"sql_retriever\", \"response_synthesis_prompt\", dest_key=\"context_str\"\n",
571
+ ")\n",
572
+ "qp.add_link(\"input\", \"response_synthesis_prompt\", dest_key=\"query_str\")\n",
573
+ "qp.add_link(\"response_synthesis_prompt\", \"response_synthesis_llm\")"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "markdown",
578
+ "metadata": {},
579
+ "source": [
580
+ "#### Visualize Query Pipeline\n",
581
+ "\n",
582
+ "A really nice property of the query pipeline syntax is you can easily visualize it in a graph via networkx.\n"
583
+ ]
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "execution_count": null,
588
+ "metadata": {},
589
+ "outputs": [],
590
+ "source": [
591
+ "from pyvis.network import Network\n",
592
+ "\n",
593
+ "net = Network(notebook=True, cdn_resources=\"in_line\", directed=True)\n",
594
+ "net.from_nx(qp.dag)"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": null,
600
+ "metadata": {},
601
+ "outputs": [],
602
+ "source": [
603
+ "# Save the network as \"text2sql_dag.html\"\n",
604
+ "net.write_html(\"text2sql_dag.html\")"
605
+ ]
606
+ },
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": null,
610
+ "metadata": {},
611
+ "outputs": [],
612
+ "source": [
613
+ "from IPython.display import display, HTML\n",
614
+ "\n",
615
+ "# Read the contents of the HTML file\n",
616
+ "with open(\"text2sql_dag.html\", \"r\") as file:\n",
617
+ " html_content = file.read()\n",
618
+ "\n",
619
+ "# Display the HTML content\n",
620
+ "display(HTML(html_content))"
621
+ ]
622
+ },
623
+ {
624
+ "cell_type": "markdown",
625
+ "metadata": {},
626
+ "source": [
627
+ "### Run Some Queries!\n",
628
+ "\n",
629
+ "Now we're ready to run some queries across this entire pipeline.\n"
630
+ ]
631
+ },
632
+ {
633
+ "cell_type": "code",
634
+ "execution_count": null,
635
+ "metadata": {},
636
+ "outputs": [],
637
+ "source": [
638
+ "response = qp.run(\n",
639
+ " query=\"What was the year that The Notorious B.I.G was signed to Bad Boy?\"\n",
640
+ ")\n",
641
+ "print(str(response))"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "metadata": {},
648
+ "outputs": [],
649
+ "source": [
650
+ "response = qp.run(query=\"Who won best director in the 1972 academy awards\")\n",
651
+ "print(str(response))"
652
+ ]
653
+ },
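+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you want to inspect what each module produced (retrieved tables, generated SQL, raw rows), the query pipeline can also return intermediate outputs. A sketch, assuming a recent llama-index version that exposes `run_with_intermediates`:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output, intermediates = qp.run_with_intermediates(\n",
+ "    query=\"Who won best director in the 1972 academy awards\"\n",
+ ")\n",
+ "# e.g. the SQL string emitted by the sql_output_parser module\n",
+ "print(intermediates[\"sql_output_parser\"].outputs)"
+ ]
+ },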
654
+ {
655
+ "cell_type": "markdown",
656
+ "metadata": {},
657
+ "source": [
658
+ "## Advanced Capability 2: Text-to-SQL with Query-Time Row Retrieval (along with Table Retrieval)\n",
659
+ "\n",
660
+ "One problem in the previous example is that if the user asks a query that asks for \"The Notorious BIG\" but the artist is stored as \"The Notorious B.I.G\", then the generated SELECT statement will likely not return any matches.\n",
661
+ "\n",
662
+ "We can alleviate this problem by fetching a small number of example rows per table. A naive option would be to just take the first k rows. Instead, we embed, index, and retrieve k relevant rows given the user query to give the text-to-SQL LLM the most contextually relevant information for SQL generation.\n",
663
+ "\n",
664
+ "We now extend our query pipeline.\n",
665
+ "\n",
666
+ "## Index Each Table\n",
667
+ "\n",
668
+ "We embed/index the rows of each table, resulting in one index per table.\n"
669
+ ]
670
+ },
671
+ {
672
+ "cell_type": "code",
673
+ "execution_count": null,
674
+ "metadata": {},
675
+ "outputs": [],
676
+ "source": [
677
+ "import logging\n",
678
+ "from pathlib import Path\n",
679
+ "from typing import Dict, Optional\n",
680
+ "from llama_index.core import VectorStoreIndex, load_index_from_storage\n",
681
+ "from llama_index.core.schema import TextNode\n",
682
+ "from llama_index.core import StorageContext\n",
683
+ "from sqlalchemy.exc import SQLAlchemyError\n",
684
+ "from sqlalchemy import text\n",
685
+ "\n",
686
+ "logger = logging.getLogger(__name__)\n",
687
+ "\n",
688
+ "\n",
689
+ "def get_table_rows(engine, table_name: str):\n",
690
+ " try:\n",
691
+ " with engine.connect() as conn:\n",
692
+ " cursor = conn.execute(text(f'SELECT * FROM \"{table_name}\"'))\n",
693
+ " return [tuple(row) for row in cursor.fetchall()]\n",
694
+ " except SQLAlchemyError as e:\n",
695
+ " logger.error(f\"Error fetching rows from table {table_name}: {str(e)}\")\n",
696
+ " raise\n",
697
+ "\n",
698
+ "\n",
699
+ "def create_index(rows, index_path: Path):\n",
700
+ " nodes = [TextNode(text=str(t)) for t in rows]\n",
701
+ " index = VectorStoreIndex(nodes)\n",
702
+ " index.set_index_id(\"vector_index\")\n",
703
+ " index.storage_context.persist(str(index_path))\n",
704
+ " return index\n",
705
+ "\n",
706
+ "\n",
707
+ "def load_existing_index(index_path: Path):\n",
708
+ " storage_context = StorageContext.from_defaults(persist_dir=str(index_path))\n",
709
+ " return load_index_from_storage(storage_context, index_id=\"vector_index\")\n",
710
+ "\n",
711
+ "\n",
712
+ "def index_all_tables(\n",
713
+ " sql_database,\n",
714
+ " table_index_dir: str = \"table_index_dir\",\n",
715
+ " force_refresh: bool = False,\n",
716
+ " tables_to_index: Optional[list] = None\n",
717
+ ") -> Dict[str, VectorStoreIndex]:\n",
718
+ " \"\"\"\n",
719
+ " Create or load vector store indexes for specified tables in the given SQL database.\n",
720
+ "\n",
721
+ " Args:\n",
722
+ " sql_database: An instance of SQLDatabase containing the tables to be indexed.\n",
723
+ " table_index_dir (str): The directory where the indexes will be stored.\n",
724
+ " force_refresh (bool): If True, recreate all indexes even if they already exist.\n",
725
+ " tables_to_index (Optional[list]): List of table names to index. If None, index all usable tables.\n",
726
+ "\n",
727
+ " Returns:\n",
728
+ " Dict[str, VectorStoreIndex]: A dictionary of table names to their VectorStoreIndex objects.\n",
729
+ "\n",
730
+ " Raises:\n",
731
+ " OSError: If there's an error creating or accessing the table_index_dir.\n",
732
+ " SQLAlchemyError: If there's an error connecting to the database or executing SQL queries.\n",
733
+ " \"\"\"\n",
734
+ " index_dir = Path(table_index_dir)\n",
735
+ " index_dir.mkdir(parents=True, exist_ok=True)\n",
736
+ "\n",
737
+ " vector_index_dict = {}\n",
738
+ " tables = tables_to_index or sql_database.get_usable_table_names()\n",
739
+ "\n",
740
+ " for table_name in tables:\n",
741
+ " index_path = index_dir / table_name\n",
742
+ " logger.info(f\"Processing table: {table_name}\")\n",
743
+ "\n",
744
+ " try:\n",
745
+ " if not index_path.exists() or force_refresh:\n",
746
+ " logger.info(f\"Creating new index for table: {table_name}\")\n",
747
+ " rows = get_table_rows(sql_database.engine, table_name)\n",
748
+ " index = create_index(rows, index_path)\n",
749
+ " else:\n",
750
+ " logger.info(f\"Loading existing index for table: {table_name}\")\n",
751
+ " index = load_existing_index(index_path)\n",
752
+ "\n",
753
+ " vector_index_dict[table_name] = index\n",
754
+ "\n",
755
+ " except (OSError, SQLAlchemyError) as e:\n",
756
+ " logger.error(f\"Error processing table {table_name}: {str(e)}\")\n",
757
+ " # Decide whether to continue with other tables or raise the exception\n",
758
+ "\n",
759
+ " return vector_index_dict\n",
760
+ "\n",
761
+ "\n",
762
+ "vector_index_dict = index_all_tables(sql_database)"
763
+ ]
764
+ },
765
+ {
766
+ "cell_type": "code",
767
+ "execution_count": null,
768
+ "metadata": {},
769
+ "outputs": [],
770
+ "source": [
771
+ "test_retriever = vector_index_dict[\"Bad_Boy_Artists\"].as_retriever(\n",
772
+ " similarity_top_k=1\n",
773
+ ")\n",
774
+ "nodes = test_retriever.retrieve(\"P. Diddy\")\n",
775
+ "print(nodes[0].get_content())"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "markdown",
780
+ "metadata": {},
781
+ "source": [
782
+ "### Define Expanded Table Parser Component\n",
783
+ "\n",
784
+ "We expand the capability of our table_parser_component to not only return the relevant table schemas, but also return relevant rows per table schema.\n",
785
+ "\n",
786
+ "It now takes in both table_schema_objs (output of table retriever), but also the original query_str which will then be used for vector retrieval of relevant rows.\n"
787
+ ]
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "execution_count": null,
792
+ "metadata": {},
793
+ "outputs": [],
794
+ "source": [
795
+ "from llama_index.core.retrievers import SQLRetriever\n",
796
+ "from typing import List\n",
797
+ "from llama_index.core.query_pipeline import FnComponent\n",
798
+ "\n",
799
+ "sql_retriever = SQLRetriever(sql_database)\n",
800
+ "\n",
801
+ "\n",
802
+ "def get_table_context_and_rows_str(\n",
803
+ " query_str: str, table_schema_objs: List[SQLTableSchema]\n",
804
+ "):\n",
805
+ " \"\"\"Get table context string.\"\"\"\n",
806
+ " context_strs = []\n",
807
+ " for table_schema_obj in table_schema_objs:\n",
808
+ " # first append table info + additional context\n",
809
+ " table_info = sql_database.get_single_table_info(\n",
810
+ " table_schema_obj.table_name\n",
811
+ " )\n",
812
+ " if table_schema_obj.context_str:\n",
813
+ " table_opt_context = \" The table description is: \"\n",
814
+ " table_opt_context += table_schema_obj.context_str\n",
815
+ " table_info += table_opt_context\n",
816
+ "\n",
817
+ " # also lookup vector index to return relevant table rows\n",
818
+ " vector_retriever = vector_index_dict[\n",
819
+ " table_schema_obj.table_name\n",
820
+ " ].as_retriever(similarity_top_k=2)\n",
821
+ " relevant_nodes = vector_retriever.retrieve(query_str)\n",
822
+ " if len(relevant_nodes) > 0:\n",
823
+ " table_row_context = \"\\nHere are some relevant example rows (values in the same order as columns above)\\n\"\n",
824
+ " for node in relevant_nodes:\n",
825
+ " table_row_context += str(node.get_content()) + \"\\n\"\n",
826
+ " table_info += table_row_context\n",
827
+ "\n",
828
+ " context_strs.append(table_info)\n",
829
+ " return \"\\n\\n\".join(context_strs)\n",
830
+ "\n",
831
+ "\n",
832
+ "table_parser_component = FnComponent(fn=get_table_context_and_rows_str)"
833
+ ]
834
+ },
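+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see what the text-to-SQL prompt will now receive as its schema context (table info plus retrieved example rows), you can call the parser directly. A small sketch using the objects defined above:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sample_query = \"What year was The Notorious BIG signed to Bad Boy?\"\n",
+ "# the object retriever returns SQLTableSchema objects for the top tables\n",
+ "sample_schemas = obj_retriever.retrieve(sample_query)\n",
+ "print(get_table_context_and_rows_str(sample_query, sample_schemas))"
+ ]
+ },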
835
+ {
836
+ "cell_type": "markdown",
837
+ "metadata": {},
838
+ "source": [
839
+ "### Define Expanded Query Pipeline\n",
840
+ "\n",
841
+ "This looks similar to the query pipeline in section 1, but with an upgraded table_parser_component.\n"
842
+ ]
843
+ },
844
+ {
845
+ "cell_type": "code",
846
+ "execution_count": null,
847
+ "metadata": {},
848
+ "outputs": [],
849
+ "source": [
850
+ "from llama_index.core.query_pipeline import (\n",
851
+ " QueryPipeline as QP,\n",
852
+ " Link,\n",
853
+ " InputComponent,\n",
854
+ " CustomQueryComponent,\n",
855
+ ")\n",
856
+ "\n",
857
+ "qp = QP(\n",
858
+ " modules={\n",
859
+ " \"input\": InputComponent(),\n",
860
+ " \"table_retriever\": obj_retriever,\n",
861
+ " \"table_output_parser\": table_parser_component,\n",
862
+ " \"text2sql_prompt\": text2sql_prompt,\n",
863
+ " \"text2sql_llm\": llm,\n",
864
+ " \"sql_output_parser\": sql_parser_component,\n",
865
+ " \"sql_retriever\": sql_retriever,\n",
866
+ " \"response_synthesis_prompt\": response_synthesis_prompt,\n",
867
+ " \"response_synthesis_llm\": llm,\n",
868
+ " },\n",
869
+ " verbose=True,\n",
870
+ ")\n",
871
+ "qp"
872
+ ]
873
+ },
874
+ {
875
+ "cell_type": "code",
876
+ "execution_count": null,
877
+ "metadata": {},
878
+ "outputs": [],
879
+ "source": [
880
+ "qp.add_link(\"input\", \"table_retriever\")\n",
881
+ "qp.add_link(\"input\", \"table_output_parser\", dest_key=\"query_str\")\n",
882
+ "qp.add_link(\n",
883
+ " \"table_retriever\", \"table_output_parser\", dest_key=\"table_schema_objs\"\n",
884
+ ")\n",
885
+ "qp.add_link(\"input\", \"text2sql_prompt\", dest_key=\"query_str\")\n",
886
+ "qp.add_link(\"table_output_parser\", \"text2sql_prompt\", dest_key=\"schema\")\n",
887
+ "qp.add_chain(\n",
888
+ " [\"text2sql_prompt\", \"text2sql_llm\", \"sql_output_parser\", \"sql_retriever\"]\n",
889
+ ")\n",
890
+ "qp.add_link(\n",
891
+ " \"sql_output_parser\", \"response_synthesis_prompt\", dest_key=\"sql_query\"\n",
892
+ ")\n",
893
+ "qp.add_link(\n",
894
+ " \"sql_retriever\", \"response_synthesis_prompt\", dest_key=\"context_str\"\n",
895
+ ")\n",
896
+ "qp.add_link(\"input\", \"response_synthesis_prompt\", dest_key=\"query_str\")\n",
897
+ "qp.add_link(\"response_synthesis_prompt\", \"response_synthesis_llm\")"
898
+ ]
899
+ },
900
+ {
901
+ "cell_type": "markdown",
902
+ "metadata": {},
903
+ "source": [
904
+ "### Run Some Queries\n",
905
+ "\n",
906
+ "We can now ask about relevant entries even if it doesn't exactly match the entry in the database.\n"
907
+ ]
908
+ },
909
+ {
910
+ "cell_type": "code",
911
+ "execution_count": null,
912
+ "metadata": {},
913
+ "outputs": [],
914
+ "source": [
915
+ "response = qp.run(\n",
916
+ " query=\"What was the year that The Notorious BIG was signed to Bad Boy?\"\n",
917
+ ")\n",
918
+ "print(str(response))"
919
+ ]
920
+ }
921
+ ],
922
+ "metadata": {
923
+ "kernelspec": {
924
+ "display_name": "llama",
925
+ "language": "python",
926
+ "name": "python3"
927
+ },
928
+ "language_info": {
929
+ "codemirror_mode": {
930
+ "name": "ipython",
931
+ "version": 3
932
+ },
933
+ "file_extension": ".py",
934
+ "mimetype": "text/x-python",
935
+ "name": "python",
936
+ "nbconvert_exporter": "python",
937
+ "pygments_lexer": "ipython3",
938
+ "version": "3.11.9"
939
+ }
940
+ },
941
+ "nbformat": 4,
942
+ "nbformat_minor": 2
943
+ }