{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "c9493399", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#Install Packages\n", "!pip install faiss-cpu\n", "!pip install sentence-transformers" ] }, { "cell_type": "code", "execution_count": 2, "id": "c49be142", "metadata": {}, "outputs": [], "source": [ "# import necessary libraries\n", "import pandas as pd\n", "pd.set_option('display.max_colwidth', 100)" ] }, { "cell_type": "code", "execution_count": 3, "id": "f5a30989", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(8, 2)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(\"sample_text.csv\")\n", "df.shape" ] }, { "cell_type": "code", "execution_count": 49, "id": "b72e2ecb", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textcategory
0Meditation and yoga can improve mental healthHealth
1Fruits, whole grains and vegetables helps control blood pressureHealth
2These are the latest fashion trends for this weekFashion
3Vibrant color jeans for male are becoming a trendFashion
4The concert starts at 7 PM tonightEvent
5Navaratri dandiya program at Expo center in Mumbai this octoberEvent
6Exciting vacation destinations for your next tripTravel
7Maldives and Srilanka are gaining popularity in terms of low budget vacation placesTravel
\n", "
" ], "text/plain": [ " text \\\n", "0 Meditation and yoga can improve mental health \n", "1 Fruits, whole grains and vegetables helps control blood pressure \n", "2 These are the latest fashion trends for this week \n", "3 Vibrant color jeans for male are becoming a trend \n", "4 The concert starts at 7 PM tonight \n", "5 Navaratri dandiya program at Expo center in Mumbai this october \n", "6 Exciting vacation destinations for your next trip \n", "7 Maldives and Srilanka are gaining popularity in terms of low budget vacation places \n", "\n", " category \n", "0 Health \n", "1 Health \n", "2 Fashion \n", "3 Fashion \n", "4 Event \n", "5 Event \n", "6 Travel \n", "7 Travel " ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "markdown", "id": "2d935944", "metadata": {}, "source": [ "### Step 1 : Create source embeddings for the text column" ] }, { "cell_type": "code", "execution_count": 5, "id": "cd04834b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\dhava\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from sentence_transformers import SentenceTransformer" ] }, { "cell_type": "code", "execution_count": 6, "id": "03ed4874", "metadata": {}, "outputs": [], "source": [ "# Load the pretrained sentence-embedding model and embed every row of the text column.\n", "# A plain list is passed because encode() is documented for a str or a list of str;\n", "# a Series could be indexed by label instead of by position if df had a non-default index.\n", "encoder = SentenceTransformer(\"all-mpnet-base-v2\")\n", "vectors = encoder.encode(df[\"text\"].tolist())" ] }, { "cell_type": "code", "execution_count": 7, "id": "b8b8c1ce", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(8, 768)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vectors.shape" ] }, { "cell_type": "code", "execution_count": 8, "id": "8e5c7da8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "768" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# embedding width; the FAISS index must be created with the same dimension\n", "dim = vectors.shape[1]\n", "dim" ] }, { "cell_type": "markdown", "id": "149e6b32", "metadata": {}, "source": [ "### Step 2 : Build a FAISS Index for vectors" ] }, { "cell_type": "code", "execution_count": 9, "id": "1033b6bd", "metadata": {}, "outputs": [], "source": [ "import faiss\n", "\n", "# flat index that compares vectors by L2 distance\n", "index = faiss.IndexFlatL2(dim)" ] }, { "cell_type": "markdown", "id": "76ad509d", "metadata": {}, "source": [ "### Step 3 : Add the source vectors to the index (they are added as encoded; no explicit normalization is applied even though L2 distance is used to measure similarity)" ] }, { "cell_type": "code", "execution_count": 10, "id": "90b527fc", "metadata": {}, "outputs": [], "source": [ "index.add(vectors)" ] }, { "cell_type": "code", "execution_count": 11, "id": "7ac0b8ef", "metadata": {}, "outputs": [ { "data": { "text/plain": [ " >" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "index" ] }, { "cell_type": "markdown", "id": "6c42234c", "metadata": {}, "source": [ "### Step 4 : Encode search text using the same encoder (normalizing the output vector is optional and is left commented out below)" ] }, { "cell_type": "code", "execution_count": 64, "id": "018faf33", 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "(768,)" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search_query = \"I want to buy a polo t-shirt\"\n", "# search_query = \"looking for places to visit during the holidays\"\n", "# search_query = \"An apple a day keeps the doctor away\"\n", "vec = encoder.encode(search_query)\n", "vec.shape" ] }, { "cell_type": "code", "execution_count": 66, "id": "af05bce3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 768)" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "svec = np.array(vec).reshape(1,-1)\n", "svec.shape" ] }, { "cell_type": "code", "execution_count": 67, "id": "84275adf", "metadata": {}, "outputs": [], "source": [ "# faiss.normalize_L2(svec)" ] }, { "cell_type": "markdown", "id": "90c0cdd8", "metadata": {}, "source": [ "### Step 5: Search for similar vector in the FAISS index created" ] }, { "cell_type": "code", "execution_count": 68, "id": "3d5a0e69", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1.3844836, 1.4039096]], dtype=float32)" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "distances, I = index.search(new_vec, k=2)\n", "distances" ] }, { "cell_type": "code", "execution_count": 69, "id": "7ef978ca", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[3, 2]], dtype=int64)" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "I" ] }, { "cell_type": "code", "execution_count": 70, "id": "e2fceefd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[3, 2]]" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "I.tolist()" ] }, { "cell_type": "code", "execution_count": 71, "id": "68f88083", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[3, 2]" ] }, "execution_count": 71, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "row_indices = I.tolist()[0]\n", "row_indices" ] }, { "cell_type": "code", "execution_count": 72, "id": "d856895d", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textcategory
3Vibrant color jeans for male are becoming a trendFashion
2These are the latest fashion trends for this weekFashion
\n", "
" ], "text/plain": [ " text category\n", "3 Vibrant color jeans for male are becoming a trend Fashion\n", "2 These are the latest fashion trends for this week Fashion" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.loc[row_indices]" ] }, { "cell_type": "code", "execution_count": 73, "id": "b65050a9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'I want to buy a polo t-shirt'" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search_query" ] }, { "cell_type": "markdown", "id": "e066c78d", "metadata": {}, "source": [ "You can see that the two results from the dataframe are similar to a search_query" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 5 }