{ "cells": [ { "cell_type": "markdown", "id": "a0f21cb1-fbc8-4282-b902-f47d92974df8", "metadata": {}, "source": [ "# Pre-requisites" ] }, { "cell_type": "markdown", "id": "5f625807-0707-4e2f-a0e0-8fbcdf08c865", "metadata": {}, "source": [ "## Why TEI\n", "There are 2 **unsung** challenges with RAG at scale:\n", "1. Getting the embeddings efficiently\n", "1. Efficient ingestion into the vector DB\n", "\n", "The issue with `1.` is that there are techniques but they are not widely *applied*. TEI solves a number of aspects:\n", "- Token Based Dynamic Batching\n", "- Using latest optimizations (Flash Attention, Candle and cuBLASLt)\n", "- Fast loading with safetensors\n", "\n", "The issue with `2.` is that it takes a bit of planning. We wont go much into that side of things here though." ] }, { "cell_type": "markdown", "id": "3102abce-ea42-4da6-8c98-c6dd4edf7f0b", "metadata": {}, "source": [ "## Start TEI\n", "Run [TEI](https://github.com/huggingface/text-embeddings-inference#docker), I have this running in a nvidia-docker container, but you can install as you like. Note that I ran this in a different terminal for monitoring and seperation. \n", "\n", "Note that as its running, its always going to pull the latest. Its at a very early stage at the time of writing. \n", "\n", "I chose the smaller [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) instead of the large. Its just as good on [mteb/leaderboard](https://huggingface.co/spaces/mteb/leaderboard) but its faster and smaller. TEI is fast, but this will make our life easier for storage and retrieval.\n", "\n", "I use the `revision=refs/pr/1` because this has the pull request with [safetensors](https://github.com/huggingface/safetensors) which is required by TEI. Check out the [pull request](https://huggingface.co/BAAI/bge-base-en-v1.5/discussions/1) if you want to use a different embedding model and it doesnt have safetensors." 
] }, { "cell_type": "code", "execution_count": 1, "id": "7e873652-8257-4aae-92bc-94e1bac54b73", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%bash\n", "\n", "# volume=$PWD/data\n", "# model=BAAI/bge-base-en-v1.5\n", "# revision=refs/pr/1\n", "# docker run \\\n", "# --gpus all \\\n", "# -p 8080:80 \\\n", "# -v $volume:/data \\\n", "# --pull always \\\n", "# ghcr.io/huggingface/text-embeddings-inference:latest \\\n", "# --model-id $model \\\n", "# --revision $revision \\\n", "# --pooling cls \\\n", "# --max-batch-tokens 65536" ] }, { "cell_type": "markdown", "id": "86a5ff83-1038-4880-8c90-dc3cab75cb49", "metadata": {}, "source": [ "## Test Endpoint" ] }, { "cell_type": "code", "execution_count": 2, "id": "52edfc97-5b6f-44f9-8d89-8578cf79fae9", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "passed\n" ] } ], "source": [ "%%bash\n", "\n", "response_code=$(curl -s -o /dev/null -w \"%{http_code}\" 127.0.0.1:8080/embed \\\n", " -X POST \\\n", " -d '{\"inputs\":\"What is Deep Learning?\"}' \\\n", " -H 'Content-Type: application/json')\n", "\n", "if [ \"$response_code\" -eq 200 ]; then\n", " echo \"passed\"\n", "else\n", " echo \"failed\"\n", "fi" ] }, { "cell_type": "markdown", "id": "b1b28232-b65d-41ce-88de-fd70b93a528d", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 3, "id": "88408486-566a-4791-8ef2-5ee3e6941156", "metadata": { "tags": [] }, "outputs": [], "source": [ "from IPython.core.interactiveshell import InteractiveShell\n", "InteractiveShell.ast_node_interactivity = 'all'" ] }, { "cell_type": "code", "execution_count": 4, "id": "abb5186b-ee67-4e1e-882d-3d8d5b4575d4", "metadata": { "tags": [] }, "outputs": [], "source": [ "import asyncio\n", "from pathlib import Path\n", "import pickle\n", "\n", "import aiohttp\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "code", "execution_count": 5, "id": "c4b82ea2-8b30-4c2e-99f0-9a30f2f1bfb7", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ec2-user/RAGDemo\n" ] } ], "source": [ "proj_dir = Path.cwd().parent\n", "print(proj_dir)" ] }, { "cell_type": "markdown", "id": "76119e74-f601-436d-a253-63c5a19d1c83", "metadata": {}, "source": [ "# Config" ] }, { "cell_type": "markdown", "id": "0d2bcda7-b245-45e3-a347-34166f217e1e", "metadata": {}, "source": [ "I'm putting the documents in pickle files. The compression is nice, though its important to note pickles are known to be a security risk." ] }, { "cell_type": "code", "execution_count": 6, "id": "f6f74545-54a7-4f41-9f02-96964e1417f0", "metadata": { "tags": [] }, "outputs": [], "source": [ "file_in = proj_dir / 'data/processed/simple_wiki_processed.pkl'\n", "file_out = proj_dir / 'data/processed/simple_wiki_embeddings.pkl'" ] }, { "cell_type": "markdown", "id": "d2dd0df0-4274-45b3-9ee5-0205494e4d75", "metadata": { "tags": [] }, "source": [ "# Setup\n", "Read in our list of documents and convert them to dictionaries for processing." 
] }, { "cell_type": "code", "execution_count": 7, "id": "3c08e039-3686-4eca-9f87-7c469e3f19bc", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.24 s, sys: 928 ms, total: 7.17 s\n", "Wall time: 6.61 s\n" ] } ], "source": [ "%%time\n", "with open(file_in, 'rb') as handle:\n", " documents = pickle.load(handle)\n", "\n", "documents = [document.to_dict() for document in documents]" ] }, { "cell_type": "markdown", "id": "5e73235d-6274-4958-9e57-977afeeb5f1b", "metadata": {}, "source": [ "# Embed\n", "## Strategy\n", "TEI allows multiple concurrent requests, so its important that we dont waste the potential we have. I used the default `max-concurrent-requests` value of `512`, so I want to use that many `MAX_WORKERS`.\n", "\n", "Im using an `async` way of making requests that uses `aiohttp` as well as a nice progress bar. " ] }, { "cell_type": "code", "execution_count": 8, "id": "949d6bf8-804f-496b-a59a-834483cc7073", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Constants\n", "ENDPOINT = \"http://127.0.0.1:8080/embed\"\n", "HEADERS = {'Content-Type': 'application/json'}\n", "MAX_WORKERS = 512" ] }, { "cell_type": "markdown", "id": "cf3da8cc-1651-4704-9091-39c2a1b835be", "metadata": {}, "source": [ "Note that Im using `'truncate':True` as even with our `350` word split earlier, there are always exceptions. Its important that as this scales we have as few issues as possible when embedding. " ] }, { "cell_type": "code", "execution_count": 9, "id": "3353c849-a36c-4047-bb81-93dac6c49b68", "metadata": { "tags": [] }, "outputs": [], "source": [ "async def fetch(session, url, document):\n", " payload = {\"inputs\": [document[\"content\"]], 'truncate':True}\n", " async with session.post(url, json=payload) as response:\n", " if response.status == 200:\n", " resp_json = await response.json()\n", " # Assuming the server's response contains an 'embedding' field\n", " document[\"embedding\"] = resp_json[0]\n", " else:\n", " print(f\"Error {response.status}: {await response.text()}\")\n", " # Handle error appropriately if needed\n", "\n", "async def main(documents):\n", " async with aiohttp.ClientSession(headers=HEADERS) as session:\n", " tasks = [fetch(session, ENDPOINT, doc) for doc in documents]\n", " await asyncio.gather(*tasks)" ] }, { "cell_type": "code", "execution_count": 10, "id": "f0d17264-72dc-40be-aa46-17cde38c8189", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f0ff772e915f4432971317e2150b60f2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Processing documents: 0%| | 0/526 [00:00