{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Incoporating semantic similarity in tabular databases\n",
    "\n",
    "In this notebook we will cover how to run semantic search over a specific table column within a single SQL query, combining tabular query with RAG.\n",
    "\n",
    "\n",
    "### Overall workflow\n",
    "\n",
    "1. Generating embeddings for a specific column\n",
    "2. Storing the embeddings in a new column (if column has low cardinality, it's better to use another table containing unique values and their embeddings)\n",
    "3. Querying using standard SQL queries with [PGVector](https://github.com/pgvector/pgvector) extension which allows using L2 distance (`<->`), Cosine distance (`<=>` or cosine similarity using `1 - <=>`) and Inner product (`<#>`)\n",
    "4. Running standard SQL query\n",
    "\n",
    "### Requirements\n",
    "\n",
    "We will need a PostgreSQL database with [pgvector](https://github.com/pgvector/pgvector) extension enabled. For this example, we will use a `Chinook` database using a local PostgreSQL server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = os.environ.get(\"OPENAI_API_KEY\") or getpass.getpass(\n",
    "    \"OpenAI API Key:\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.sql_database import SQLDatabase\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "CONNECTION_STRING = \"postgresql+psycopg2://postgres:test@localhost:5432/vectordb\"  # Replace with your own\n",
    "db = SQLDatabase.from_uri(CONNECTION_STRING)"
   ]
  },
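  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the `vector` extension is not enabled on your database yet, the commented-out cell below sketches the one-time setup, assuming the pgvector extension is installed on the server. It also recalls the distance operators pgvector provides, which we will use later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-time setup (uncomment if needed), assuming pgvector is installed on the server.\n",
    "# pgvector provides the distance operators used later in this notebook:\n",
    "#   <->  L2 distance, <=>  cosine distance, <#>  (negative) inner product\n",
    "# db.run(\"CREATE EXTENSION IF NOT EXISTS vector;\")"
   ]
  },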
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embedding the song titles"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this example, we will run queries based on semantic meaning of song titles. In order to do this, let's start by adding a new column in the table for storing the embeddings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# db.run('ALTER TABLE \"Track\" ADD COLUMN \"embeddings\" vector;')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's generate the embedding for each *track title* and store it as a new column in our \"Track\" table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "embeddings_model = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3503"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tracks = db.run('SELECT \"Name\" FROM \"Track\"')\n",
    "song_titles = [s[0] for s in eval(tracks)]\n",
    "title_embeddings = embeddings_model.embed_documents(song_titles)\n",
    "len(title_embeddings)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's insert the embeddings in the into the new column from our table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "for i in tqdm(range(len(title_embeddings))):\n",
    "    title = song_titles[i].replace(\"'\", \"''\")\n",
    "    embedding = title_embeddings[i]\n",
    "    sql_command = (\n",
    "        f'UPDATE \"Track\" SET \"embeddings\" = ARRAY{embedding} WHERE \"Name\" ='\n",
    "        + f\"'{title}'\"\n",
    "    )\n",
    "    db.run(sql_command)"
   ]
  },
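  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Updating row by row keeps the example simple. For larger tables, pgvector can also build an approximate-nearest-neighbour index to speed up the `<->` searches; the commented-out sketch below assumes the `embeddings` column was declared with a fixed dimension (e.g. `vector(1536)` for the default OpenAI embeddings), which pgvector requires for indexing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: an IVFFlat index speeds up <-> searches on large tables.\n",
    "# This assumes \"embeddings\" was declared with a fixed dimension (e.g. vector(1536)),\n",
    "# since pgvector cannot index a dimensionless vector column.\n",
    "# db.run('CREATE INDEX ON \"Track\" USING ivfflat (\"embeddings\" vector_l2_ops) WITH (lists = 100);')"
   ]
  },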
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can test the semantic search running the following query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[(\"Tomorrow\\'s Dream\",), (\\'Remember Tomorrow\\',), (\\'Remember Tomorrow\\',), (\\'The Best Is Yet To Come\\',), (\"Thinking \\'Bout Tomorrow\",)]'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeded_title = embeddings_model.embed_query(\"hope about the future\")\n",
    "query = (\n",
    "    'SELECT \"Track\".\"Name\" FROM \"Track\" WHERE \"Track\".\"embeddings\" IS NOT NULL ORDER BY \"embeddings\" <-> '\n",
    "    + f\"'{embeded_title}' LIMIT 5\"\n",
    ")\n",
    "db.run(query)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the SQL Chain"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by defining useful functions to get info from database and running the query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_schema(_):\n",
    "    return db.get_table_info()\n",
    "\n",
    "\n",
    "def run_query(query):\n",
    "    return db.run(query)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's build the **prompt** we will use. This prompt is an extension from [text-to-postgres-sql](https://smith.langchain.com/hub/jacob/text-to-postgres-sql?organizationId=f9b614b8-5c3a-4e7c-afbc-6d7ad4fd8892) prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "template = \"\"\"You are a Postgres expert. Given an input question, first create a syntactically correct Postgres query to run, then look at the results of the query and return the answer to the input question.\n",
    "Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per Postgres. You can order the results to return the most informative data in the database.\n",
    "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
    "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
    "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
    "\n",
    "You can use an extra extension which allows you to run semantic similarity using <-> operator on tables containing columns named \"embeddings\".\n",
    "<-> operator can ONLY be used on embeddings columns.\n",
    "The embeddings value for a given row typically represents the semantic meaning of that row.\n",
    "The vector represents an embedding representation of the question, given below. \n",
    "Do NOT fill in the vector values directly, but rather specify a `[search_word]` placeholder, which should contain the word that would be embedded for filtering.\n",
    "For example, if the user asks for songs about 'the feeling of loneliness' the query could be:\n",
    "'SELECT \"[whatever_table_name]\".\"SongName\" FROM \"[whatever_table_name]\" ORDER BY \"embeddings\" <-> '[loneliness]' LIMIT 5'\n",
    "\n",
    "Use the following format:\n",
    "\n",
    "Question: <Question here>\n",
    "SQLQuery: <SQL Query to run>\n",
    "SQLResult: <Result of the SQLQuery>\n",
    "Answer: <Final answer here>\n",
    "\n",
    "Only use the following tables:\n",
    "\n",
    "{schema}\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"system\", template), (\"human\", \"{question}\")]\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we can create the chain using **[LangChain Expression Language](https://python.langchain.com/docs/expression_language/)**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "db = SQLDatabase.from_uri(\n",
    "    CONNECTION_STRING\n",
    ")  # We reconnect to db so the new columns are loaded as well.\n",
    "llm = ChatOpenAI(model=\"gpt-4\", temperature=0)\n",
    "\n",
    "sql_query_chain = (\n",
    "    RunnablePassthrough.assign(schema=get_schema)\n",
    "    | prompt\n",
    "    | llm.bind(stop=[\"\\nSQLResult:\"])\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'SQLQuery: SELECT \"Track\".\"Name\" FROM \"Track\" JOIN \"Genre\" ON \"Track\".\"GenreId\" = \"Genre\".\"GenreId\" WHERE \"Genre\".\"Name\" = \\'Rock\\' ORDER BY \"Track\".\"embeddings\" <-> \\'[dispair]\\' LIMIT 5'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_query_chain.invoke(\n",
    "    {\n",
    "        \"question\": \"Which are the 5 rock songs with titles about deep feeling of dispair?\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This chain simply generates the query. Now we will create the full chain that also handles the execution and the final result for the user:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "\n",
    "\n",
    "def replace_brackets(match):\n",
    "    words_inside_brackets = match.group(1).split(\", \")\n",
    "    embedded_words = [\n",
    "        str(embeddings_model.embed_query(word)) for word in words_inside_brackets\n",
    "    ]\n",
    "    return \"', '\".join(embedded_words)\n",
    "\n",
    "\n",
    "def get_query(query):\n",
    "    sql_query = re.sub(r\"\\[([\\w\\s,]+)\\]\", replace_brackets, query)\n",
    "    return sql_query\n",
    "\n",
    "\n",
    "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query: {query}\n",
    "SQL Response: {response}\"\"\"\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"system\", template), (\"human\", \"{question}\")]\n",
    ")\n",
    "\n",
    "full_chain = (\n",
    "    RunnablePassthrough.assign(query=sql_query_chain)\n",
    "    | RunnablePassthrough.assign(\n",
    "        schema=get_schema,\n",
    "        response=RunnableLambda(lambda x: db.run(get_query(x[\"query\"]))),\n",
    "    )\n",
    "    | prompt\n",
    "    | llm\n",
    ")"
   ]
  },
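  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can call `get_query` directly on a hypothetical query string to confirm that the `[search_word]` placeholder is replaced with an embedding vector before the SQL is executed (the query below is made up for illustration and mirrors the table and column names used above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical query, only to illustrate the placeholder substitution:\n",
    "# the bracketed word is embedded and spliced into the SQL before db.run is called.\n",
    "example_query = (\n",
    "    'SELECT \"Track\".\"Name\" FROM \"Track\" ORDER BY \"embeddings\" <-> '\n",
    "    + \"'[loneliness]' LIMIT 5\"\n",
    ")\n",
    "print(get_query(example_query)[:120], \"...\")"
   ]
  },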
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using the Chain"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 1: Filtering a column based on semantic meaning"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's say we want to retrieve songs that express `deep feeling of dispair`, but filtering based on genre:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"The 5 rock songs with titles that convey a deep feeling of despair are 'Sea Of Sorrow', 'Surrender', 'Indifference', 'Hard Luck Woman', and 'Desire'.\")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_chain.invoke(\n",
    "    {\n",
    "        \"question\": \"Which are the 5 rock songs with titles about deep feeling of dispair?\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What is substantially different in implementing this method is that we have combined:\n",
    "- Semantic search (songs that have titles with some semantic meaning)\n",
    "- Traditional tabular querying (running JOIN statements to filter track based on genre)\n",
    "\n",
    "This is something we _could_ potentially achieve using metadata filtering, but it's more complex to do so (we would need to use a vector database containing the embeddings, and use metadata filtering based on genre).\n",
    "\n",
    "However, for other use cases metadata filtering **wouldn't be enough**."
   ]
  },
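  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell sketches that metadata-filtering alternative for comparison. It is only illustrative: it assumes the `faiss-cpu` package is installed and re-embeds the track titles into a hypothetical in-memory vector store with the genre as metadata, an extra indexing step the SQL-based approach above avoids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch of the metadata-filtering alternative (assumes `faiss-cpu` is installed).\n",
    "# It re-embeds the track titles into an in-memory store and filters on genre metadata.\n",
    "from langchain_community.vectorstores import FAISS\n",
    "\n",
    "rows = eval(\n",
    "    db.run(\n",
    "        'SELECT \"Track\".\"Name\", \"Genre\".\"Name\" FROM \"Track\" JOIN \"Genre\" ON \"Track\".\"GenreId\" = \"Genre\".\"GenreId\"'\n",
    "    )\n",
    ")\n",
    "vectorstore = FAISS.from_texts(\n",
    "    texts=[name for name, _ in rows],\n",
    "    embedding=embeddings_model,\n",
    "    metadatas=[{\"genre\": genre} for _, genre in rows],\n",
    ")\n",
    "vectorstore.similarity_search(\"deep feeling of despair\", k=5, filter={\"genre\": \"Rock\"})"
   ]
  },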
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 2: Combining filters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"The three albums which have the most amount of songs in the top 150 saddest songs are 'International Superhits' with 5 songs, 'Ten' with 4 songs, and 'Album Of The Year' with 3 songs.\")"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_chain.invoke(\n",
    "    {\n",
    "        \"question\": \"I want to know the 3 albums which have the most amount of songs in the top 150 saddest songs\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So we have result for 3 albums with most amount of songs in top 150 saddest ones. This **wouldn't** be possible using only standard metadata filtering. Without this _hybdrid query_, we would need some postprocessing to get the result.\n",
    "\n",
    "Another similar exmaple:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"The 6 albums with the shortest titles that contain songs which are in the 20 saddest song list are 'Ten', 'Core', 'Big Ones', 'One By One', 'Black Album', and 'Miles Ahead'.\")"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_chain.invoke(\n",
    "    {\n",
    "        \"question\": \"I need the 6 albums with shortest title, as long as they contain songs which are in the 20 saddest song list.\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see what the query looks like to double check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WITH \"SadSongs\" AS (\n",
      "    SELECT \"TrackId\" FROM \"Track\" \n",
      "    ORDER BY \"embeddings\" <-> '[sad]' LIMIT 20\n",
      "),\n",
      "\"SadAlbums\" AS (\n",
      "    SELECT DISTINCT \"AlbumId\" FROM \"Track\" \n",
      "    WHERE \"TrackId\" IN (SELECT \"TrackId\" FROM \"SadSongs\")\n",
      ")\n",
      "SELECT \"Album\".\"Title\" FROM \"Album\" \n",
      "WHERE \"AlbumId\" IN (SELECT \"AlbumId\" FROM \"SadAlbums\") \n",
      "ORDER BY \"title_len\" ASC \n",
      "LIMIT 6\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    sql_query_chain.invoke(\n",
    "        {\n",
    "            \"question\": \"I need the 6 albums with shortest title, as long as they contain songs which are in the 20 saddest song list.\"\n",
    "        }\n",
    "    )\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 3: Combining two separate semantic searches\n",
    "\n",
    "One interesting aspect of this approach which is **substantially different from using standar RAG** is that we can even **combine** two semantic search filters:\n",
    "- _Get 5 saddest songs..._\n",
    "- _**...obtained from albums with \"lovely\" titles**_\n",
    "\n",
    "This could generalize to **any kind of combined RAG** (paragraphs discussing _X_ topic belonging from books about _Y_, replies to a tweet about _ABC_ topic that express _XYZ_ feeling)\n",
    "\n",
    "We will combine semantic search on songs and album titles, so we need to do the same for `Album` table:\n",
    "1. Generate the embeddings\n",
    "2. Add them to the table as a new column (which we need to add in the table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# db.run('ALTER TABLE \"Album\" ADD COLUMN \"embeddings\" vector;')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 347/347 [00:01<00:00, 179.64it/s]\n"
     ]
    }
   ],
   "source": [
    "albums = db.run('SELECT \"Title\" FROM \"Album\"')\n",
    "album_titles = [title[0] for title in eval(albums)]\n",
    "album_title_embeddings = embeddings_model.embed_documents(album_titles)\n",
    "for i in tqdm(range(len(album_title_embeddings))):\n",
    "    album_title = album_titles[i].replace(\"'\", \"''\")\n",
    "    album_embedding = album_title_embeddings[i]\n",
    "    sql_command = (\n",
    "        f'UPDATE \"Album\" SET \"embeddings\" = ARRAY{album_embedding} WHERE \"Title\" ='\n",
    "        + f\"'{album_title}'\"\n",
    "    )\n",
    "    db.run(sql_command)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[('Realize',), ('Morning Dance',), ('Into The Light',), ('New Adventures In Hi-Fi',), ('Miles Ahead',)]\""
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeded_title = embeddings_model.embed_query(\"hope about the future\")\n",
    "query = (\n",
    "    'SELECT \"Album\".\"Title\" FROM \"Album\" WHERE \"Album\".\"embeddings\" IS NOT NULL ORDER BY \"embeddings\" <-> '\n",
    "    + f\"'{embeded_title}' LIMIT 5\"\n",
    ")\n",
    "db.run(query)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can combine both filters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "db = SQLDatabase.from_uri(\n",
    "    CONNECTION_STRING\n",
    ")  # We reconnect to dbso the new columns are loaded as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='The songs about breakouts obtained from the top 5 albums about love are \\'Royal Orleans\\', \"Nobody\\'s Fault But Mine\", \\'Achilles Last Stand\\', \\'For Your Life\\', and \\'Hots On For Nowhere\\'.')"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_chain.invoke(\n",
    "    {\n",
    "        \"question\": \"I want to know songs about breakouts obtained from top 5 albums about love\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is something **different** that **couldn't be achieved** using standard metadata filtering over a vectordb."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}