diff --git "a/docs/langchain/cookbook/Multi_modal_RAG_google.ipynb" "b/docs/langchain/cookbook/Multi_modal_RAG_google.ipynb" new file mode 100644--- /dev/null +++ "b/docs/langchain/cookbook/Multi_modal_RAG_google.ipynb" @@ -0,0 +1,697 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "SzvBjdID1V3m", + "metadata": { + "id": "SzvBjdID1V3m" + }, + "source": [ + "# Multi-modal RAG with Google Cloud" + ] + }, + { + "cell_type": "markdown", + "id": "4tfidrmE1Zlo", + "metadata": { + "id": "4tfidrmE1Zlo" + }, + "source": [ + "This tutorial demonstrates how to implement the Option 2 described [here](https://github.com/langchain-ai/langchain/blob/master/cookbook/Multi_modal_RAG.ipynb) with Generative API on Google Cloud." + ] + }, + { + "cell_type": "markdown", + "id": "84fcd59f-2eaf-4a76-ad1a-96d6db70bf42", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Install the required dependencies, and create an API key for your Google service." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b1e10dd-25de-4c0a-9577-f36e72518f89", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -U --quiet langchain langchain_community openai chromadb langchain-experimental\n", + "%pip install --quiet \"unstructured[all-docs]\" pypdf pillow pydantic lxml pillow matplotlib chromadb tiktoken" + ] + }, + { + "cell_type": "markdown", + "id": "pSInKtCZ32mt", + "metadata": { + "id": "pSInKtCZ32mt" + }, + "source": [ + "## Data loading" + ] + }, + { + "cell_type": "markdown", + "id": "Iv2R8-lJ37dG", + "metadata": { + "id": "Iv2R8-lJ37dG" + }, + "source": [ + "We use a zip file with a sub-set of the extracted images and pdf from [this](https://cloudedjudgement.substack.com/p/clouded-judgement-111023) blog post. If you want to follow the full flow, please, use the original [example](https://github.com/langchain-ai/langchain/blob/master/cookbook/Multi_modal_RAG.ipynb)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d999f3fe-c165-4772-b63e-ffe4dd5b03cf", + "metadata": {}, + "outputs": [], + "source": [ + "# First, download and unpack the data\n", + "import logging\n", + "import zipfile\n", + "\n", + "import requests\n", + "\n", + "logging.basicConfig(level=logging.INFO)\n", + "\n", + "data_url = \"https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/cj.zip\"\n", + "result = requests.get(data_url)\n", + "filename = \"cj.zip\"\n", + "with open(filename, \"wb\") as file:\n", + " file.write(result.content)\n", + "\n", + "with zipfile.ZipFile(filename, \"r\") as zip_ref:\n", + " zip_ref.extractall()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "eGUfuevMUA6R", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.document_loaders import PyPDFLoader\n", + "\n", + "loader = PyPDFLoader(\"./cj/cj.pdf\")\n", + "docs = loader.load()\n", + "tables = [] # PyPDFLoader does not extract tables; see the original example for table parsing\n", + "texts = [d.page_content for d in docs]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "Fst17fNHWYcq", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(texts)" + ] + }, + { + "cell_type": "markdown", + "id": "vjfcg_Vn3_1C", + "metadata": { + "id": "vjfcg_Vn3_1C" + }, + "source": [ + "## Multi-vector retriever" + ] + }, + { + "cell_type": "markdown", + "id": "1ynRqJn04BFG", + "metadata": { + "id": "1ynRqJn04BFG" + }, + "source": [ + "Let's generate text and image summaries and save them to a ChromaDB vectorstore." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "kWDWfSDBMPl8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:numexpr.utils:Note: NumExpr detected 12 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n", + "INFO:numexpr.utils:NumExpr defaulting to 8 threads.\n" + ] + } + ], + "source": [ + "from langchain.prompts import PromptTemplate\n", + "from langchain_community.chat_models import ChatVertexAI\n", + "from langchain_community.llms import VertexAI\n", + "from langchain_core.messages import AIMessage\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.runnables import RunnableLambda\n", + "\n", + "\n", + "# Generate summaries of text elements\n", + "def generate_text_summaries(texts, tables, summarize_texts=False):\n", + " \"\"\"\n", + " Summarize text elements\n", + " texts: List of str\n", + " tables: List of str\n", + " summarize_texts: Bool to summarize texts\n", + " \"\"\"\n", + "\n", + " # Prompt\n", + " prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n", + " These summaries will be embedded and used to retrieve the raw text or table elements. \\\n", + " Give a concise summary of the table or text that is well optimized for retrieval. 
Table or text: {element} \"\"\"\n", + " prompt = PromptTemplate.from_template(prompt_text)\n", + " empty_response = RunnableLambda(\n", + " lambda x: AIMessage(content=\"Error processing document\")\n", + " )\n", + " # Text summary chain\n", + " model = VertexAI(\n", + " temperature=0, model_name=\"gemini-pro\", max_output_tokens=1024\n", + " ).with_fallbacks([empty_response])\n", + " summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n", + "\n", + " # Initialize empty summaries\n", + " text_summaries = []\n", + " table_summaries = []\n", + "\n", + " # Apply to text if texts are provided and summarization is requested\n", + " if texts and summarize_texts:\n", + " text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 1})\n", + " elif texts:\n", + " text_summaries = texts\n", + "\n", + " # Apply to tables if tables are provided\n", + " if tables:\n", + " table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 1})\n", + "\n", + " return text_summaries, table_summaries\n", + "\n", + "\n", + "# Get text, table summaries\n", + "text_summaries, table_summaries = generate_text_summaries(\n", + " texts, tables, summarize_texts=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "F0NnyUl48yYb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(text_summaries)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "PeK9bzXv3olF", + "metadata": {}, + "outputs": [], + "source": [ + "import base64\n", + "import os\n", + "\n", + "from langchain_core.messages import HumanMessage\n", + "\n", + "\n", + "def encode_image(image_path):\n", + " \"\"\"Return the base64 string for an image file\"\"\"\n", + " with open(image_path, \"rb\") as image_file:\n", + " return base64.b64encode(image_file.read()).decode(\"utf-8\")\n", + "\n", + "\n", + "def image_summarize(img_base64, prompt):\n", + " \"\"\"Generate a summary for an image\"\"\"\n", + " model = ChatVertexAI(model_name=\"gemini-pro-vision\", max_output_tokens=1024)\n", + "\n", + " msg = model.invoke(\n", + " [\n", + " HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": prompt},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n", + " },\n", + " ]\n", + " )\n", + " ]\n", + " )\n", + " return msg.content\n", + "\n", + "\n", + "def generate_img_summaries(path):\n", + " \"\"\"\n", + " Generate summaries and base64 encoded strings for images\n", + " path: Path to the directory of .jpg files extracted by Unstructured\n", + " \"\"\"\n", + "\n", + " # Store base64 encoded images\n", + " img_base64_list = []\n", + "\n", + " # Store image summaries\n", + " image_summaries = []\n", + "\n", + " # Prompt\n", + " prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n", + " These summaries will be embedded and used to retrieve the raw image. 
\\\n", + " Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n", + "\n", + " # Apply to images\n", + " for img_file in sorted(os.listdir(path)):\n", + " if img_file.endswith(\".jpg\"):\n", + " img_path = os.path.join(path, img_file)\n", + " base64_image = encode_image(img_path)\n", + " img_base64_list.append(base64_image)\n", + " image_summaries.append(image_summarize(base64_image, prompt))\n", + "\n", + " return img_base64_list, image_summaries\n", + "\n", + "\n", + "# Image summaries\n", + "img_base64_list, image_summaries = generate_img_summaries(\"./cj\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6WDYpDFzjocl", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(image_summaries)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cWyWfZ-XB6cS", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.\n" + ] + } + ], + "source": [ + "import uuid\n", + "\n", + "from langchain.retrievers.multi_vector import MultiVectorRetriever\n", + "from langchain.storage import InMemoryStore\n", + "from langchain_community.embeddings import VertexAIEmbeddings\n", + "from langchain_community.vectorstores import Chroma\n", + "from langchain_core.documents import Document\n", + "\n", + "\n", + "def create_multi_vector_retriever(\n", + " vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images\n", + "):\n", + " \"\"\"\n", + " Create retriever that indexes summaries, but returns raw images or texts\n", + " \"\"\"\n", + "\n", + " # Initialize the storage layer\n", + " store = InMemoryStore()\n", + " id_key = \"doc_id\"\n", + "\n", + " # Create the multi-vector retriever\n", + " retriever = MultiVectorRetriever(\n", + " vectorstore=vectorstore,\n", + " docstore=store,\n", + " id_key=id_key,\n", + " )\n", + "\n", + " # Helper function to add documents to the vectorstore and docstore\n", + " def add_documents(retriever, doc_summaries, doc_contents):\n", + " doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n", + " summary_docs = [\n", + " Document(page_content=s, metadata={id_key: doc_ids[i]})\n", + " for i, s in enumerate(doc_summaries)\n", + " ]\n", + " retriever.vectorstore.add_documents(summary_docs)\n", + " retriever.docstore.mset(list(zip(doc_ids, doc_contents)))\n", + "\n", + " # Add texts, tables, and images\n", + " # Check that text_summaries is not empty before adding\n", + " if text_summaries:\n", + " add_documents(retriever, text_summaries, texts)\n", + " # Check that table_summaries is not empty before adding\n", + " if table_summaries:\n", + " add_documents(retriever, table_summaries, tables)\n", + " # Check that image_summaries is not empty before adding\n", + " if image_summaries:\n", + " add_documents(retriever, image_summaries, images)\n", + "\n", + " return retriever\n", + "\n", + "\n", + "# The vectorstore to use to index the summaries\n", + "vectorstore = Chroma(\n", + " collection_name=\"mm_rag_cj_blog\",\n", + " embedding_function=VertexAIEmbeddings(model_name=\"textembedding-gecko@latest\"),\n", + ")\n", + "\n", + "# Create retriever\n", + "retriever_multi_vector_img = create_multi_vector_retriever(\n", + " vectorstore,\n", + " text_summaries,\n", + " texts,\n", + " 
+ { + "cell_type": "markdown", + "id": "NGDkkMFfCg4j", + "metadata": { + "id": "NGDkkMFfCg4j" + }, + "source": [ + "## Building a RAG" + ] + }, + { + "cell_type": "markdown", + "id": "8TzOcHVsCmBc", + "metadata": { + "id": "8TzOcHVsCmBc" + }, + "source": [ + "Let's define the helpers and assemble the RAG chain:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "GlwCErBaCKQW", + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "import re\n", + "\n", + "from IPython.display import HTML, display\n", + "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n", + "from PIL import Image\n", + "\n", + "\n", + "def plt_img_base64(img_base64):\n", + " \"\"\"Display a base64-encoded string as an image\"\"\"\n", + " # Create an HTML img tag with the base64 string as the source\n", + " image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n", + " # Display the image by rendering the HTML\n", + " display(HTML(image_html))\n", + "\n", + "\n", + "def looks_like_base64(sb):\n", + " \"\"\"Check if the string looks like base64\"\"\"\n", + " return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n", + "\n", + "\n", + "def is_image_data(b64data):\n", + " \"\"\"\n", + " Check if the base64 data is an image by looking at the start of the data\n", + " \"\"\"\n", + " image_signatures = {\n", + " b\"\\xff\\xd8\\xff\": \"jpg\",\n", + " b\"\\x89\\x50\\x4e\\x47\\x0d\\x0a\\x1a\\x0a\": \"png\",\n", + " b\"\\x47\\x49\\x46\\x38\": \"gif\",\n", + " b\"\\x52\\x49\\x46\\x46\": \"webp\",\n", + " }\n", + " try:\n", + " header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n", + " for sig, fmt in image_signatures.items():\n", + " if header.startswith(sig):\n", + " return True\n", + " return False\n", + " except Exception:\n", + " return False\n", + "\n", + "\n", + "def resize_base64_image(base64_string, size=(128, 128)):\n", + " \"\"\"\n", + " Resize an image encoded as a Base64 string\n", + " \"\"\"\n", + " # Decode the Base64 string\n", + " img_data = base64.b64decode(base64_string)\n", + " img = Image.open(io.BytesIO(img_data))\n", + "\n", + " # Resize the image\n", + " resized_img = img.resize(size, Image.LANCZOS)\n", + "\n", + " # Save the resized image to a bytes buffer\n", + " buffered = io.BytesIO()\n", + " resized_img.save(buffered, format=img.format)\n", + "\n", + " # Encode the resized image to Base64\n", + " return base64.b64encode(buffered.getvalue()).decode(\"utf-8\")\n", + "\n", + "\n", + "def split_image_text_types(docs):\n", + " \"\"\"\n", + " Split base64-encoded images and texts\n", + " \"\"\"\n", + " b64_images = []\n", + " texts = []\n", + " for doc in docs:\n", + " # Check if the document is of type Document and extract page_content if so\n", + " if isinstance(doc, Document):\n", + " doc = doc.page_content\n", + " if looks_like_base64(doc) and is_image_data(doc):\n", + " doc = resize_base64_image(doc, size=(1300, 600))\n", + " b64_images.append(doc)\n", + " else:\n", + " texts.append(doc)\n", + " # If any image was retrieved, keep only the first image and drop the texts\n", + " if len(b64_images) > 0:\n", + " return {\"images\": b64_images[:1], \"texts\": []}\n", + " return {\"images\": b64_images, \"texts\": texts}\n", + "\n", + "\n", + "def img_prompt_func(data_dict):\n", + " \"\"\"\n", + " Assemble the multi-modal prompt messages from the retrieved context\n", + " \"\"\"\n", + " formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n", + " messages = []\n", + "\n", + " # Adding the text for analysis\n", + " text_message = {\n", + " \"type\": \"text\",\n", + " \"text\": (\n", + " \"You are a financial analyst tasked with providing investment advice.\\n\"\n", + " \"You will be given a mix of text, tables, and image(s), usually charts or graphs.\\n\"\n", + " \"Use this information to provide investment advice related to the user's question.\\n\"\n", + " f\"User-provided question: {data_dict['question']}\\n\\n\"\n", + " \"Text and/or tables:\\n\"\n", + " f\"{formatted_texts}\"\n", + " ),\n", + " }\n", + " messages.append(text_message)\n", + " # Adding image(s) to the messages if present\n", + " if data_dict[\"context\"][\"images\"]:\n", + " for image in data_dict[\"context\"][\"images\"]:\n", + " image_message = {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n", + " }\n", + " messages.append(image_message)\n", + " return [HumanMessage(content=messages)]\n", + "\n", + "\n", + "def multi_modal_rag_chain(retriever):\n", + " \"\"\"\n", + " Multi-modal RAG chain\n", + " \"\"\"\n", + "\n", + " # Multi-modal LLM\n", + " model = ChatVertexAI(temperature=0, model_name=\"gemini-pro-vision\", max_output_tokens=1024)\n", + "\n", + " # RAG pipeline\n", + " chain = (\n", + " {\n", + " \"context\": retriever | RunnableLambda(split_image_text_types),\n", + " \"question\": RunnablePassthrough(),\n", + " }\n", + " | RunnableLambda(img_prompt_func)\n", + " | model\n", + " | StrOutputParser()\n", + " )\n", + "\n", + " return chain\n", + "\n", + "\n", + "# Create RAG chain\n", + "chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)" + ] + }, + { + "cell_type": "markdown", + "id": "BS4hNKqCCp8u", + "metadata": { + "id": "BS4hNKqCCp8u" + }, + "source": [ + "Let's check that we get images as documents:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "Q7GrwFC_FGwr", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = \"What are the EV / NTM and NTM rev growth for MongoDB, Cloudflare, and Datadog?\"\n", + "docs = retriever_multi_vector_img.invoke(query, limit=1)\n", + "\n", + "# We get 4 docs back\n", + "len(docs)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "unnxB5M_FLCD", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt_img_base64(docs[0])" + ] + }, + { + "cell_type": "markdown", + "id": "YUkGZXqsCtF6", + "metadata": { + "id": "YUkGZXqsCtF6" + }, + "source": [ + "And let's run our RAG on the same query:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "LsPTehdK-T-_", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "' | Company | EV / NTM Rev | NTM Rev Growth |\\n|---|---|---|\\n| MongoDB | 14.6x | 17% |\\n| Cloudflare | 13.4x | 28% |\\n| Datadog | 13.1x | 19% |'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain_multimodal_rag.invoke(query)" + ] + }, + { + "cell_type": "markdown", + "id": "XpLQB6dEfQX-", + "metadata": { + "id": "XpLQB6dEfQX-" + }, + "source": [ + "As we can see, the model was able to figure out the right values relevant to answering the question."
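+ ] + }, + { + "cell_type": "markdown", + "id": "retrieval-split-note", + "metadata": {}, + "source": [ + "As an optional sanity check, we can inspect how the retrieved context for a new question is split between images and text, reusing the helpers defined above. A quick sketch (the follow-up question is only an illustrative example):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "retrieval-split-check", + "metadata": {}, + "outputs": [], + "source": [ + "# Hypothetical follow-up question, used only to illustrate the retrieval step\n", + "followup = \"Which company has the highest NTM revenue growth?\"\n", + "\n", + "# Retrieve raw documents and split them into image and text context\n", + "docs = retriever_multi_vector_img.invoke(followup, limit=6)\n", + "split = split_image_text_types(docs)\n", + "print(len(split[\"images\"]), \"image(s) and\", len(split[\"texts\"]), \"text chunk(s) in context\")" + ] + }, + { + "cell_type": "markdown", + "id": "retrieval-split-followup", + "metadata": {}, + "source": [ + "Whenever an image is retrieved, split_image_text_types keeps only the first image and drops the text chunks, so the multi-modal model answers from the chart alone."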
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}