diff --git a/app.py b/app.py index 2394ec0ce3b1269fd19e7ab49ee231d5ce14d40e..fbb47e6591b88c69e3efd1994cf1fd99b9e1005a 100644 --- a/app.py +++ b/app.py @@ -1,26 +1,20 @@ -import os -import wandb - -wandb.login(relogin=True, key=os.getenv("WANDB_API_KEY")) - - import streamlit as st -import weave -from medrag_multi_modal.assistant import ( - FigureAnnotatorFromPageImage, - LLMClient, - MedQAAssistant, -) -from medrag_multi_modal.assistant.llm_client import ( - GOOGLE_MODELS, - MISTRAL_MODELS, - OPENAI_MODELS, +from medrag_multi_modal.assistant import LLMClient, MedQAAssistant +from medrag_multi_modal.retrieval.text_retrieval import ( + BM25sRetriever, + ContrieverRetriever, + MedCPTRetriever, + NVEmbed2Retriever, ) -from medrag_multi_modal.retrieval import MedCPTRetriever # Define constants -ALL_AVAILABLE_MODELS = GOOGLE_MODELS + MISTRAL_MODELS + OPENAI_MODELS +ALL_AVAILABLE_MODELS = [ + "gemini-1.5-flash-latest", + "gemini-1.5-pro-latest", + "gpt-4o", + "gpt-4o-mini", +] # Sidebar for configuration settings st.sidebar.title("Configuration Settings") @@ -30,68 +24,91 @@ project_name = st.sidebar.text_input( placeholder="wandb project name", help="format: wandb_username/wandb_project_name", ) -chunk_dataset_name = st.sidebar.text_input( - label="Text Chunk WandB Dataset Name", - value="grays-anatomy-chunks:v0", - placeholder="wandb dataset name", - help="format: wandb_dataset_name:version", +chunk_dataset_id = st.sidebar.selectbox( + label="Chunk Dataset ID", + options=["ashwiniai/medrag-text-corpus-chunks"], ) -index_artifact_address = st.sidebar.text_input( - label="WandB Index Artifact Address", - value="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0", - placeholder="wandb artifact address", - help="format: wandb_username/wandb_project_name/wandb_artifact_name:version", +llm_model = st.sidebar.selectbox( + label="LLM Model", + options=ALL_AVAILABLE_MODELS, ) -image_artifact_address = st.sidebar.text_input( - label="WandB Image Artifact Address", - value="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6", - placeholder="wandb artifact address", - help="format: wandb_username/wandb_project_name/wandb_artifact_name:version", +top_k_chunks_for_query = st.sidebar.slider( + label="Top K Chunks for Query", + min_value=1, + max_value=20, + value=5, ) -llm_client_model_name = st.sidebar.selectbox( - label="LLM Client Model Name", - options=ALL_AVAILABLE_MODELS, - index=ALL_AVAILABLE_MODELS.index("gemini-1.5-flash"), - help="select a model from the list", +top_k_chunks_for_options = st.sidebar.slider( + label="Top K Chunks for Options", + min_value=1, + max_value=20, + value=3, ) -figure_extraction_model_name = st.sidebar.selectbox( - label="Figure Extraction Model Name", - options=ALL_AVAILABLE_MODELS, - index=ALL_AVAILABLE_MODELS.index("pixtral-12b-2409"), - help="select a model from the list", +rely_only_on_context = st.sidebar.checkbox( + label="Rely Only on Context", + value=False, ) -structured_output_model_name = st.sidebar.selectbox( - label="Structured Output Model Name", - options=ALL_AVAILABLE_MODELS, - index=ALL_AVAILABLE_MODELS.index("gpt-4o"), - help="select a model from the list", +retriever_type = st.sidebar.selectbox( + label="Retriever Type", + options=[ + "", + "BM25S", + "Contriever", + "MedCPT", + "NV-Embed-v2", + ], ) -# Streamlit app layout -st.title("MedQA Assistant App") +if retriever_type != "": -# Initialize Weave -weave.init(project_name=project_name) + llm_model = LLMClient(model_name=llm_model) -# Initialize clients and assistants -llm_client = 
LLMClient(model_name=llm_client_model_name) -retriever = MedCPTRetriever.from_wandb_artifact( - chunk_dataset_name=chunk_dataset_name, - index_artifact_address=index_artifact_address, -) -figure_annotator = FigureAnnotatorFromPageImage( - figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name), - structured_output_llm_client=LLMClient(model_name=structured_output_model_name), - image_artifact_address=image_artifact_address, -) -medqa_assistant = MedQAAssistant( - llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator -) + retriever = None + + if retriever_type == "BM25S": + retriever = BM25sRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s" + ) + elif retriever_type == "Contriever": + retriever = ContrieverRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever", + chunk_dataset_id=chunk_dataset_id, + ) + elif retriever_type == "MedCPT": + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset_id=chunk_dataset_id, + ) + elif retriever_type == "NV-Embed-v2": + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset_id=chunk_dataset_id, + ) + + medqa_assistant = MedQAAssistant( + llm_client=llm_model, + retriever=retriever, + top_k_chunks_for_query=top_k_chunks_for_query, + top_k_chunks_for_options=top_k_chunks_for_options, + ) -query = st.chat_input("Enter your question here") -if query: - with st.chat_message("user"): - st.markdown(query) - response = medqa_assistant.predict(query=query) with st.chat_message("assistant"): - st.markdown(response) + st.markdown( + """ +Hi! I am Medrag, your medical assistant. You can ask me questions about medicine and the life sciences. +I am currently a work in progress, so please bear with my occasional mistakes and gaps in knowledge. + +**Note:** I am not a medical professional, so please do not rely on my answers for medical decisions. +Please consult a medical professional for any medical advice. + +To learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal). + """, + unsafe_allow_html=True, + ) + query = st.chat_input("Enter your question here") + if query: + with st.chat_message("user"): + st.markdown(query) + response = medqa_assistant.predict(query=query) + with st.chat_message("assistant"): + st.markdown(response.response) diff --git a/docs/app.md b/docs/app.md deleted file mode 100644 index 9efa67f6908da956fb3ca69e74c5a6df1041dcdf..0000000000000000000000000000000000000000 --- a/docs/app.md +++ /dev/null @@ -1,61 +0,0 @@ -# MedQA Assistant App - -The MedQA Assistant App is a Streamlit-based application designed to provide a chat interface for medical question answering. It leverages advanced language models (LLMs) and retrieval augmented generation (RAG) techniques to deliver accurate and informative responses to medical queries. - -## Features - -- **Interactive Chat Interface**: Engage with the app through a user-friendly chat interface. -- **Configurable Settings**: Customize model selection and data sources via the sidebar. -- **Retrieval-Augmented Generation**: Ensures precise and contextually relevant responses. -- **Figure Annotation Capabilities**: Extracts and annotates figures from medical texts. - -## Usage - -1. Install the package using: - ```bash - uv pip install . - ``` -1. 
**Launch the App**: Start the application using Streamlit: - ```bash - medrag run - ``` -2. **Configure Settings**: Adjust configuration settings in the sidebar to suit your needs. -3. **Ask a Question**: Enter your medical question in the chat input field. -4. **Receive a Response**: Get a detailed answer from the MedQA Assistant. - -## Configuration - -The app allows users to customize various settings through the sidebar: - -- **Project Name**: Specify the WandB project name. -- **Text Chunk WandB Dataset Name**: Define the dataset containing text chunks. -- **WandB Index Artifact Address**: Provide the address of the index artifact. -- **WandB Image Artifact Address**: Provide the address of the image artifact. -- **LLM Client Model Name**: Choose a language model for generating responses. -- **Figure Extraction Model Name**: Select a model for extracting figures from images. -- **Structured Output Model Name**: Choose a model for generating structured outputs. - -## Technical Details - -The app is built using the following components: - -- **Streamlit**: For the user interface. -- **Weave**: For project initialization and artifact management. -- **MedQAAssistant**: For processing queries and generating responses. -- **LLMClient**: For interacting with language models. -- **MedCPTRetriever**: For retrieving relevant text chunks. -- **FigureAnnotatorFromPageImage**: For annotating figures in medical texts. - -## Development and Deployment - -- **Environment Setup**: Ensure all dependencies are installed as per the `pyproject.toml`. -- **Running the App**: Use Streamlit to run the app locally. -- **Deployment**: coming soon... - -## Additional Resources - -For more detailed information on the components and their usage, refer to the following documentation sections: - -- [MedQA Assistant](/assistant/medqa_assistant) -- [LLM Client](/assistant/llm_client) -- [Figure Annotation](/assistant/figure_annotation) diff --git a/docs/assistant/figure_annotation.md b/docs/assistant/figure_annotation.md deleted file mode 100644 index 629198317146d6c028957ed2c75fd612a82dc712..0000000000000000000000000000000000000000 --- a/docs/assistant/figure_annotation.md +++ /dev/null @@ -1,3 +0,0 @@ -# Figure Annotation - -::: medrag_multi_modal.assistant.figure_annotation \ No newline at end of file diff --git a/docs/assistant/llm_client.md b/docs/assistant/llm_client.md deleted file mode 100644 index e7720b19172b6cefa6599ec3c6f384c2fcc93442..0000000000000000000000000000000000000000 --- a/docs/assistant/llm_client.md +++ /dev/null @@ -1,3 +0,0 @@ -# LLM Client - -::: medrag_multi_modal.assistant.llm_client \ No newline at end of file diff --git a/docs/assistant/medqa_assistant.md b/docs/assistant/medqa_assistant.md deleted file mode 100644 index e9fd06fa4065d4fcb0f69e2e819da042fd4b7b9d..0000000000000000000000000000000000000000 --- a/docs/assistant/medqa_assistant.md +++ /dev/null @@ -1,3 +0,0 @@ -# MedQA Assistant - -::: medrag_multi_modal.assistant.medqa_assistant diff --git a/docs/chunking.md b/docs/chunking.md deleted file mode 100644 index b63900feca9b2bae6761b67317f5849c560049e0..0000000000000000000000000000000000000000 --- a/docs/chunking.md +++ /dev/null @@ -1,3 +0,0 @@ -# Chunking - -::: medrag_multi_modal.semantic_chunking \ No newline at end of file diff --git a/docs/document_loader/image_loader/base_img_loader.md b/docs/document_loader/image_loader/base_img_loader.md deleted file mode 100644 index 6cff185f81bf2c6594e5f75e89b3f74f833860dc..0000000000000000000000000000000000000000 --- 
a/docs/document_loader/image_loader/base_img_loader.md +++ /dev/null @@ -1,3 +0,0 @@ -## Load images from PDF files - -::: medrag_multi_modal.document_loader.image_loader.base_img_loader diff --git a/docs/document_loader/image_loader/fitzpil_img_loader.md b/docs/document_loader/image_loader/fitzpil_img_loader.md deleted file mode 100644 index 7167b490cae8b84ad66806d00912868360a22c46..0000000000000000000000000000000000000000 --- a/docs/document_loader/image_loader/fitzpil_img_loader.md +++ /dev/null @@ -1,22 +0,0 @@ -# Load images from PDF files (using Fitz & PIL) - -??? note "Note" - **Underlying Library:** `fitz` & `pillow` - - Extract images from PDF files using `fitz` and `pillow`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader - ``` - - For more details, please refer to the sources below. - - **Sources:** - - - [Docs](https://pymupdf.readthedocs.io/en/latest/intro.html) - - [GitHub](https://github.com/kastman/fitz) - - [PyPI](https://pypi.org/project/fitz/) - - [PyPI](https://pypi.org/project/pillow/) - -::: medrag_multi_modal.document_loader.image_loader.fitzpil_img_loader diff --git a/docs/document_loader/image_loader/marker_img_loader.md b/docs/document_loader/image_loader/marker_img_loader.md deleted file mode 100644 index 392d89b6a333340bed5a6e152b6d2fadfb12bb87..0000000000000000000000000000000000000000 --- a/docs/document_loader/image_loader/marker_img_loader.md +++ /dev/null @@ -1,21 +0,0 @@ -# Load images from PDF files (using Marker) - -??? note "Note" - **Underlying Library:** `marker-pdf` - - Extract images from PDF files using `marker-pdf`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader - ``` - - For details, please refer to the sources below. - - **Sources:** - - - [DataLab](https://www.datalab.to) - - [GitHub](https://github.com/VikParuchuri/marker) - - [PyPI](https://pypi.org/project/marker-pdf/) - -::: medrag_multi_modal.document_loader.image_loader.marker_img_loader diff --git a/docs/document_loader/image_loader/pdf2image_img_loader.md b/docs/document_loader/image_loader/pdf2image_img_loader.md deleted file mode 100644 index fb721fa725ffe1a3a4aa83e38746dedb5ac05e8f..0000000000000000000000000000000000000000 --- a/docs/document_loader/image_loader/pdf2image_img_loader.md +++ /dev/null @@ -1,26 +0,0 @@ -# Load images from PDF files (using PDF2Image) - -!!! danger "Warning" - Unlike other image extraction methods in `document_loader.image_loader`, this loader does not extract embedded images from the PDF. - Instead, it creates a snapshot image version of each selected page from the PDF. - -??? note "Note" - **Underlying Library:** `pdf2image` - - Extract images from PDF files using `pdf2image`. - - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader - ``` - - For details and available `**kwargs`, please refer to the sources below. 
- - **Sources:** - - - [DataLab](https://www.datalab.to) - - [GitHub](https://github.com/VikParuchuri/marker) - - [PyPI](https://pypi.org/project/marker-pdf/) - -::: medrag_multi_modal.document_loader.image_loader.pdf2image_img_loader diff --git a/docs/document_loader/image_loader/pdfplumber_img_loader.md b/docs/document_loader/image_loader/pdfplumber_img_loader.md deleted file mode 100644 index 1f892e967eaf71d9bbb211f841d9a2ddd12ba75c..0000000000000000000000000000000000000000 --- a/docs/document_loader/image_loader/pdfplumber_img_loader.md +++ /dev/null @@ -1,22 +0,0 @@ -# Load images from PDF files (using PDFPlumber) - -??? note "Note" - **Underlying Library:** `pdfplumber` - - Extract images from PDF files using `pdfplumber`. - - You can interact with the underlying library and fine-tune the outputs via `**kwargs`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader - ``` - - For details, please refer to the sources below. - - **Sources:** - - - [GitHub](https://github.com/jsvine/pdfplumber) - - [PyPI](https://pypi.org/project/pdfplumber/) - -::: medrag_multi_modal.document_loader.image_loader.pdfplumber_img_loader diff --git a/docs/document_loader/image_loader/pymupdf_img_loader.md b/docs/document_loader/image_loader/pymupdf_img_loader.md deleted file mode 100644 index 968e11cc4629d8385ebadbefc050672337228c0c..0000000000000000000000000000000000000000 --- a/docs/document_loader/image_loader/pymupdf_img_loader.md +++ /dev/null @@ -1,23 +0,0 @@ -# Load images from PDF files (using PyMuPDF) - -??? note "Note" - **Underlying Library:** `pymupdf` - - PyMuPDF is a high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents. - - You can interact with the underlying library and fine-tune the outputs via `**kwargs`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader - ``` - - For details, please refer to the sources below. - - **Sources:** - - - [Docs](https://pymupdf.readthedocs.io/en/latest/) - - [GitHub](https://github.com/pymupdf/PyMuPDF) - - [PyPI](https://pypi.org/project/PyMuPDF/) - -::: medrag_multi_modal.document_loader.image_loader.pymupdf_img_loader diff --git a/docs/document_loader/text_loader/base_text_loader.md b/docs/document_loader/text_loader/base_text_loader.md deleted file mode 100644 index 221524ff37b234566a490ecc8fcc5a148daaa69c..0000000000000000000000000000000000000000 --- a/docs/document_loader/text_loader/base_text_loader.md +++ /dev/null @@ -1,3 +0,0 @@ -## Load text from PDF files - -::: medrag_multi_modal.document_loader.text_loader.base_text_loader diff --git a/docs/document_loader/text_loader/marker_text_loader.md b/docs/document_loader/text_loader/marker_text_loader.md deleted file mode 100644 index ec1afab992c30b9a00d04ce5c25e6118152d1dbc..0000000000000000000000000000000000000000 --- a/docs/document_loader/text_loader/marker_text_loader.md +++ /dev/null @@ -1,23 +0,0 @@ -## Load text from PDF files (using Marker) - -??? note "Note" - **Underlying Library:** `marker-pdf` - - Convert PDF to markdown quickly and accurately using a pipeline of deep learning models. - - You can interact with the underlying library and fine-tune the outputs via `**kwargs`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.text_loader import MarkerTextLoader - ``` - - For details and available `**kwargs`, please refer to the sources below. 
- - **Sources:** - - - [DataLab](https://www.datalab.to) - - [GitHub](https://github.com/VikParuchuri/marker) - - [PyPI](https://pypi.org/project/marker-pdf/) - -::: medrag_multi_modal.document_loader.text_loader.marker_text_loader diff --git a/docs/document_loader/text_loader/pdfplumber_text_loader.md b/docs/document_loader/text_loader/pdfplumber_text_loader.md deleted file mode 100644 index 53c137cca9738fc6925cfc35e1492f579e3f2974..0000000000000000000000000000000000000000 --- a/docs/document_loader/text_loader/pdfplumber_text_loader.md +++ /dev/null @@ -1,22 +0,0 @@ -## Load text from PDF files (using PDFPlumber) - -??? note "Note" - **Underlying Library:** `pdfplumber` - - Plumb a PDF for detailed information about each char, rectangle, line, et cetera — and easily extract text and tables. - - You can interact with the underlying library and fine-tune the outputs via `**kwargs`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.text_loader import PDFPlumberTextLoader - ``` - - For details and available `**kwargs`, please refer to the sources below. - - **Sources:** - - - [GitHub](https://github.com/jsvine/pdfplumber) - - [PyPI](https://pypi.org/project/pdfplumber/) - -::: medrag_multi_modal.document_loader.text_loader.pdfplumber_text_loader diff --git a/docs/document_loader/text_loader/pymupdf4llm_text_loader.md b/docs/document_loader/text_loader/pymupdf4llm_text_loader.md deleted file mode 100644 index a54e2c95a05f1f31b128f01c12ed1c9a2aa95509..0000000000000000000000000000000000000000 --- a/docs/document_loader/text_loader/pymupdf4llm_text_loader.md +++ /dev/null @@ -1,23 +0,0 @@ -## Load text from PDF files (using PyMuPDF4LLM) - -??? note "Note" - **Underlying Library:** `pymupdf4llm` - - PyMuPDF is a high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents. - - You can interact with the underlying library and fine-tune the outputs via `**kwargs`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.text_loader import PyMuPDF4LLMTextLoader - ``` - - For details and available `**kwargs`, please refer to the sources below. - - **Sources:** - - - [Docs](https://pymupdf.readthedocs.io/en/latest/pymupdf4llm/) - - [GitHub](https://github.com/pymupdf/PyMuPDF) - - [PyPI](https://pypi.org/project/pymupdf4llm/) - -::: medrag_multi_modal.document_loader.text_loader.pymupdf4llm_text_loader diff --git a/docs/document_loader/text_loader/pypdf2_text_loader.md b/docs/document_loader/text_loader/pypdf2_text_loader.md deleted file mode 100644 index b5db21b90009832b5d049ad9efef27d6da6f3e43..0000000000000000000000000000000000000000 --- a/docs/document_loader/text_loader/pypdf2_text_loader.md +++ /dev/null @@ -1,23 +0,0 @@ -## Load text from PDF files (using PyPDF2) - -??? note "Note" - **Underlying Library:** `pypdf2` - - A pure-python PDF library capable of splitting, merging, cropping, and transforming the pages of PDF files - - You can interact with the underlying library and fine-tune the outputs via `**kwargs`. - - Use it in our library with: - ```python - from medrag_multi_modal.document_loader.text_loader import PyPDF2TextLoader - ``` - - For details and available `**kwargs`, please refer to the sources below. 
- - **Sources:** - - - [Docs](https://pypdf2.readthedocs.io/en/3.x/) - - [GitHub](https://github.com/py-pdf/pypdf) - - [PyPI](https://pypi.org/project/PyPDF2/) - -::: medrag_multi_modal.document_loader.text_loader.pypdf2_text_loader diff --git a/docs/index.md b/docs/index.md deleted file mode 100644 index daa05d2ac273b70725708fffa9029437bb2f704f..0000000000000000000000000000000000000000 --- a/docs/index.md +++ /dev/null @@ -1,40 +0,0 @@ -# MedRAG Multi-Modal - -Multi-modal RAG for medical docmain. - -## Installation - -### For Development - -For MacOS, you need to run - -```bash -brew install poppler -``` - -For Debian/Ubuntu, you need to run - -```bash -sudo apt-get install -y poppler-utils -``` - -Then, you can install the dependencies using uv in the virtual environment `.venv` using - -```bash -git clone https://github.com/soumik12345/medrag-multi-modal -cd medrag-multi-modal -pip install -U pip uv -uv sync -``` - -After this, you need to activate the virtual environment using - -```bash -source .venv/bin/activate -``` - -In the activated virtual environment, you can optionally install Flash Attention (required for ColPali) using - -```bash -uv pip install flash-attn --no-build-isolation -``` diff --git a/docs/installation/development.md b/docs/installation/development.md deleted file mode 100644 index 1c40ea4970ca771c0ad1e7d6e17a7ee45dd7b051..0000000000000000000000000000000000000000 --- a/docs/installation/development.md +++ /dev/null @@ -1,40 +0,0 @@ -# Setting up the development environment - -## Install Poppler - -For MacOS, you need to run - -```bash -brew install poppler -``` - -For Debian/Ubuntu, you need to run - -```bash -sudo apt-get install -y poppler-utils -``` - -## Install the dependencies - -Then, you can install the dependencies using uv in the virtual environment `.venv` using - -```bash -git clone https://github.com/soumik12345/medrag-multi-modal -cd medrag-multi-modal -pip install -U pip uv -uv sync -``` - -After this, you need to activate the virtual environment using - -```bash -source .venv/bin/activate -``` - -## [Optional] Install Flash Attention - -In the activated virtual environment, you can optionally install Flash Attention (required for ColPali) using - -```bash -uv pip install flash-attn --no-build-isolation -``` \ No newline at end of file diff --git a/docs/installation/install.md b/docs/installation/install.md deleted file mode 100644 index 250e0ede9709f476cc966e6fd7e31b6a65bb8f1e..0000000000000000000000000000000000000000 --- a/docs/installation/install.md +++ /dev/null @@ -1,9 +0,0 @@ -# Installation - -You just need to clone the repository and run the install.sh script - -```bash -git clone https://github.com/soumik12345/medrag-multi-modal -cd medrag-multi-modal -sh install.sh -``` diff --git a/docs/retreival/bm25s.md b/docs/retreival/bm25s.md deleted file mode 100644 index f382558babf5d3b0c0a42df4e45e1e8420506b90..0000000000000000000000000000000000000000 --- a/docs/retreival/bm25s.md +++ /dev/null @@ -1,3 +0,0 @@ -# BM25-Sparse Retrieval - -::: medrag_multi_modal.retrieval.bm25s_retrieval \ No newline at end of file diff --git a/docs/retreival/colpali.md b/docs/retreival/colpali.md deleted file mode 100644 index a602f9aeece1159efa12c127dd7ea3ba37a2afa6..0000000000000000000000000000000000000000 --- a/docs/retreival/colpali.md +++ /dev/null @@ -1,3 +0,0 @@ -# ColPali Retrieval - -::: medrag_multi_modal.retrieval.colpali_retrieval \ No newline at end of file diff --git a/docs/retreival/contriever.md b/docs/retreival/contriever.md deleted file mode 
100644 index 986fdc2866c9566d4dd0c66bc6f49b693132b0e5..0000000000000000000000000000000000000000 --- a/docs/retreival/contriever.md +++ /dev/null @@ -1,3 +0,0 @@ -# Contriever Retrieval - -::: medrag_multi_modal.retrieval.contriever_retrieval \ No newline at end of file diff --git a/docs/retreival/medcpt.md b/docs/retreival/medcpt.md deleted file mode 100644 index ea157fa842e57395c68c85385f5b58c11abaff1e..0000000000000000000000000000000000000000 --- a/docs/retreival/medcpt.md +++ /dev/null @@ -1,3 +0,0 @@ -# MedCPT Retrieval - -::: medrag_multi_modal.retrieval.medcpt_retrieval \ No newline at end of file diff --git a/docs/retreival/nv_embed_2.md b/docs/retreival/nv_embed_2.md deleted file mode 100644 index cb446ceaa025d9faffba24875d8deeec84d39c53..0000000000000000000000000000000000000000 --- a/docs/retreival/nv_embed_2.md +++ /dev/null @@ -1,3 +0,0 @@ -# NV-Embed-v2 Retrieval - -::: medrag_multi_modal.retrieval.nv_embed_2 \ No newline at end of file diff --git a/install.sh b/install.sh deleted file mode 100644 index 5b3c1c9e531413050632ce113d68739cd84aa703..0000000000000000000000000000000000000000 --- a/install.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -OS_TYPE=$(uname -s) - -if [ "$OS_TYPE" = "Darwin" ]; then - echo "Detected macOS." - brew install poppler -elif [ "$OS_TYPE" = "Linux" ]; then - if [ -f /etc/os-release ]; then - . /etc/os-release - if [ "$ID" = "ubuntu" ] || [ "$ID" = "debian" ]; then - echo "Detected Ubuntu/Debian." - sudo apt-get update - sudo apt-get install -y poppler-utils - else - echo "Unsupported Linux distribution: $ID" - exit 1 - fi - else - echo "Cannot detect Linux distribution." - exit 1 - fi -else - echo "Unsupported OS: $OS_TYPE" - exit 1 -fi - -git clone https://github.com/soumik12345/medrag-multi-modal -cd medrag-multi-modal -pip install -U .[core] diff --git a/medrag_multi_modal/assistant/figure_annotation.py b/medrag_multi_modal/assistant/figure_annotation.py index 614623049ee2b79d244d7046baca22459a2b5f94..fb3838004688355f117e922720b2c8558917e0e0 100644 --- a/medrag_multi_modal/assistant/figure_annotation.py +++ b/medrag_multi_modal/assistant/figure_annotation.py @@ -5,19 +5,10 @@ from typing import Optional, Union import cv2 import weave from PIL import Image -from pydantic import BaseModel -from ..utils import get_wandb_artifact, read_jsonl_file -from .llm_client import LLMClient - - -class FigureAnnotation(BaseModel): - figure_id: str - figure_description: str - - -class FigureAnnotations(BaseModel): - annotations: list[FigureAnnotation] +from medrag_multi_modal.assistant.llm_client import LLMClient +from medrag_multi_modal.assistant.schema import FigureAnnotations +from medrag_multi_modal.utils import get_wandb_artifact, read_jsonl_file class FigureAnnotatorFromPageImage(weave.Model): @@ -108,7 +99,7 @@ Here are some clues you need to follow: ) @weave.op() - def predict(self, page_idx: int) -> dict[int, list[FigureAnnotation]]: + def predict(self, page_idx: int) -> dict[int, list[FigureAnnotations]]: """ Predicts figure annotations for a specific page in a document. 
diff --git a/medrag_multi_modal/assistant/llm_client.py b/medrag_multi_modal/assistant/llm_client.py index e4c0c7e4282b87a0a26842019acde00235ab692a..ee8ff5a637a5b342ac1689dc85fee903dff64b27 100644 --- a/medrag_multi_modal/assistant/llm_client.py +++ b/medrag_multi_modal/assistant/llm_client.py @@ -1,3 +1,4 @@ +import json import os from enum import Enum from typing import Any, Optional, Union @@ -93,6 +94,7 @@ class LLMClient(weave.Model): schema: Optional[Any] = None, ) -> Union[str, Any]: import google.generativeai as genai + from google.generativeai.types import HarmBlockThreshold, HarmCategory system_prompt = ( [system_prompt] if isinstance(system_prompt, str) else system_prompt @@ -100,18 +102,25 @@ class LLMClient(weave.Model): user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt genai.configure(api_key=os.environ.get("GOOGLE_API_KEY")) - model = genai.GenerativeModel(self.model_name) + model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt) generation_config = ( None if schema is None else genai.GenerationConfig( - response_mime_type="application/json", response_schema=list[schema] + response_mime_type="application/json", response_schema=schema ) ) response = model.generate_content( - system_prompt + user_prompt, generation_config=generation_config + user_prompt, + generation_config=generation_config, + # This is necessary in order to answer questions about anatomy, sexual diseases, + # medical devices, medicines, etc. + safety_settings={ + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + }, ) - return response.text if schema is None else response + return response.text if schema is None else json.loads(response.text) @weave.op() def execute_mistral_sdk( @@ -146,14 +155,13 @@ class LLMClient(weave.Model): client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY")) client = instructor.from_mistral(client) if schema is not None else client - response = ( - client.chat.complete(model=self.model_name, messages=messages) - if schema is None - else client.messages.create( - response_model=schema, messages=messages, temperature=0 + if schema is not None: + raise NotImplementedError( + "Mistral does not support structured output using a schema" ) - ) - return response.choices[0].message.content + else: + response = client.chat.complete(model=self.model_name, messages=messages) + return response.choices[0].message.content @weave.op() def execute_openai_sdk( diff --git a/medrag_multi_modal/assistant/medqa_assistant.py b/medrag_multi_modal/assistant/medqa_assistant.py index be336f05aa8bb59c3aa67521d2f78d175de00859..95cc5e539958e9046aee6f69045f85bb86f1cf37 100644 --- a/medrag_multi_modal/assistant/medqa_assistant.py +++ b/medrag_multi_modal/assistant/medqa_assistant.py @@ -1,8 +1,16 @@ +from typing import Optional + import weave -from ..retrieval import SimilarityMetric -from .figure_annotation import FigureAnnotatorFromPageImage -from .llm_client import LLMClient +from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage +from medrag_multi_modal.assistant.llm_client import LLMClient +from medrag_multi_modal.assistant.schema import ( + MedQACitation, + MedQAMCQResponse, + MedQAResponse, +) +from medrag_multi_modal.retrieval.common import SimilarityMetric +from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever class MedQAAssistant(weave.Model): @@ -47,39 +55,68 @@ 
llm_client (LLMClient): The language model client used to generate responses. retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document. figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages. - top_k_chunks (int): The number of top chunks to retrieve based on similarity metric. + top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query. + top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options. retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval. """ llm_client: LLMClient retriever: weave.Model - figure_annotator: FigureAnnotatorFromPageImage - top_k_chunks: int = 2 + figure_annotator: Optional[FigureAnnotatorFromPageImage] = None + top_k_chunks_for_query: int = 2 + top_k_chunks_for_options: int = 2 + rely_only_on_context: bool = True retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE @weave.op() - def predict(self, query: str) -> str: + def retrieve_chunks_for_query(self, query: str) -> list[dict]: + retriever_kwargs = {"top_k": self.top_k_chunks_for_query} + if not isinstance(self.retriever, BM25sRetriever): + retriever_kwargs["metric"] = self.retrieval_similarity_metric + return self.retriever.predict(query, **retriever_kwargs) + + @weave.op() + def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]: + retriever_kwargs = {"top_k": self.top_k_chunks_for_options} + if not isinstance(self.retriever, BM25sRetriever): + retriever_kwargs["metric"] = self.retrieval_similarity_metric + retrieved_chunks = [] + for option in options: + retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs) + return retrieved_chunks + + @weave.op() + def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse: """ Generates a response to a medical query by retrieving relevant text chunks and figure descriptions from a medical document and using a language model to generate the final response. This function performs the following steps: - 1. Retrieves relevant text chunks from the medical document based on the query using the retriever model. + 1. Retrieves relevant text chunks from the medical document based on the query and any provided options + using the retriever model. 2. Extracts the text and page indices from the retrieved chunks. 3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator. - 4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions. - 5. Uses the language model client to generate a response based on the constructed prompts. - 6. Appends the source information (page numbers) to the generated response. + 4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks, + and figure descriptions. + 5. Uses the language model client to generate a response based on the constructed prompts, either choosing + from provided options or generating a free-form response. + 6. Returns the generated response, which includes the answer and explanation if options were provided. 
+ + The function can operate in two modes: + - Multiple choice: When options are provided, it selects the best answer from the options and explains the choice + - Free response: When no options are provided, it generates a comprehensive response based on the context Args: query (str): The medical query to be answered. + options (Optional[list[str]]): The list of options to choose from. + rely_only_on_context (bool): Whether to rely only on the context provided or not during response generation. Returns: - str: The generated response to the query, including source information. + MedQAResponse: The generated response to the query, including source information. """ - retrieved_chunks = self.retriever.predict( - query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric - ) + retrieved_chunks = self.retrieve_chunks_for_query(query) + options = options or [] + retrieved_chunks += self.retrieve_chunks_for_options(options) retrieved_chunk_texts = [] page_indices = set() @@ -88,21 +125,50 @@ class MedQAAssistant(weave.Model): page_indices.add(int(chunk["page_idx"])) figure_descriptions = [] - for page_idx in page_indices: - figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[ - page_idx - ] - figure_descriptions += [ - item["figure_description"] for item in figure_annotations - ] - - system_prompt = """ - You are an expert in medical science. You are given a query and a list of chunks from a medical document. + if self.figure_annotator is not None: + for page_idx in page_indices: + figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[ + page_idx + ] + figure_descriptions += [ + item["figure_description"] for item in figure_annotations + ] + + system_prompt = """You are an expert in medical science. You are given a question +and a list of excerpts from various medical documents. + """ + query = f"""# Question +{query} """ + + if len(options) > 0: + system_prompt += """\nYou are also given a list of options to choose your answer from. +You are supposed to choose the best possible option based on the context provided. You should also +explain your answer to justify why you chose that option. +""" + query += "## Options\n" + for option in options: + query += f"- {option}\n" + else: + system_prompt += "\nYou are supposed to answer the question based on the context provided." + + if self.rely_only_on_context: + system_prompt += """\n\nYou are only allowed to use the context provided to answer the question. +You are not allowed to use any external knowledge to answer the question. 
+""" + response = self.llm_client.predict( system_prompt=system_prompt, user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions], + schema=MedQAMCQResponse if len(options) > 0 else None, ) - page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices]) - response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy" - return response + + # TODO: Add figure citations + # TODO: Add source document name from retrieved chunks as citations + citations = [] + for page_idx in page_indices: + citations.append( + MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy") + ) + + return MedQAResponse(response=response, citations=citations) diff --git a/medrag_multi_modal/assistant/schema.py b/medrag_multi_modal/assistant/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4ca6cfb280436a70aa04a50e7d48bd6dbc08f1 --- /dev/null +++ b/medrag_multi_modal/assistant/schema.py @@ -0,0 +1,27 @@ +from typing import Union + +from pydantic import BaseModel + + +class FigureAnnotation(BaseModel): + figure_id: str + figure_description: str + + +class FigureAnnotations(BaseModel): + annotations: list[FigureAnnotation] + + +class MedQAMCQResponse(BaseModel): + answer: str + explanation: str + + +class MedQACitation(BaseModel): + page_number: int + document_name: str + + +class MedQAResponse(BaseModel): + response: Union[str, MedQAMCQResponse] + citations: list[MedQACitation] diff --git a/medrag_multi_modal/cli.py b/medrag_multi_modal/cli.py index 62b1d7e76e9a68d68923c1e8a8a28de0daa98173..419362447c14f4a3994fd7563818603bd271e24a 100644 --- a/medrag_multi_modal/cli.py +++ b/medrag_multi_modal/cli.py @@ -1,16 +1,67 @@ import argparse +import os import subprocess import sys def main(): parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI") - parser.add_argument("command", choices=["run"], help="Command to execute") + subparsers = parser.add_subparsers(dest="command", required=True) + + # Run subcommand + run_parser = subparsers.add_parser("run", help="Run the Streamlit application") + run_parser.add_argument( + "--port", type=int, default=8501, help="Port to run Streamlit on" + ) + + # Evaluate subcommand + eval_parser = subparsers.add_parser("evaluate", help="Run evaluation tests") + eval_parser.add_argument( + "--test-file", + default=os.path.join("tests", "evals", "test_assistant_mmlu_anatomy.py"), + help="Path to test file", + ) + eval_parser.add_argument( + "--test-case", + type=str, + help="Only run tests which match the given substring expression", + ) + eval_parser.add_argument( + "--model-name", + type=str, + default="gemini-1.5-flash", + help="Model name to use for evaluation", + ) + args = parser.parse_args() if args.command == "run": - # Assuming your Streamlit app is in app.py - subprocess.run([sys.executable, "-m", "streamlit", "run", "app.py"]) + subprocess.run( + [ + sys.executable, + "-m", + "streamlit", + "run", + "app.py", + "--server.port", + str(args.port), + ] + ) + + elif args.command == "evaluate": + test_file = ( + args.test_file + "::" + args.test_case if args.test_case else args.test_file + ) + cmd = [ + sys.executable, + "-m", + "pytest", + "-s", + test_file, + "-v", + f"--model-name={args.model_name}", + ] + subprocess.run(cmd) if __name__ == "__main__": diff --git a/medrag_multi_modal/document_loader/image_loader/base_img_loader.py b/medrag_multi_modal/document_loader/image_loader/base_img_loader.py index 
30c9fb68a804d1bcd39a45ff920182122a96224a..bdc99e99fb6102812b4106e3a5ae6eae73a1836e 100644 --- a/medrag_multi_modal/document_loader/image_loader/base_img_loader.py +++ b/medrag_multi_modal/document_loader/image_loader/base_img_loader.py @@ -1,11 +1,21 @@ import asyncio import os from abc import abstractmethod +from glob import glob from typing import Dict, List, Optional +import huggingface_hub import jsonlines import rich -import wandb +from datasets import ( + Dataset, + Features, + Image, + Sequence, + Value, + concatenate_datasets, + load_dataset, +) from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( BaseTextLoader, @@ -36,14 +46,72 @@ class BaseImageLoader(BaseTextLoader): """ pass + def save_as_dataset( + self, + start_page: int, + end_page: int, + image_save_dir: str, + dataset_repo_id: Optional[str] = None, + overwrite_dataset: bool = False, + ): + features = Features( + { + "page_image": Image(decode=True), + "page_figure_images": Sequence(Image(decode=True)), + "document_name": Value(dtype="string"), + "page_idx": Value(dtype="int32"), + } + ) + + all_examples = [] + for page_idx in range(start_page, end_page): + page_image_file_paths = glob( + os.path.join(image_save_dir, f"page{page_idx}*.png") + ) + if len(page_image_file_paths) > 0: + page_image_path = page_image_file_paths[0] + figure_image_paths = [ + image_file_path + for image_file_path in glob( + os.path.join(image_save_dir, f"page{page_idx}*_fig*.png") + ) + ] + + example = { + "page_image": page_image_path, + "page_figure_images": figure_image_paths, + "document_name": self.document_name, + "page_idx": page_idx, + } + all_examples.append(example) + + dataset = Dataset.from_list(all_examples, features=features) + + if dataset_repo_id: + if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"): + if not overwrite_dataset: + dataset = concatenate_datasets( + [dataset, load_dataset(dataset_repo_id)["corpus"]] + ) + + dataset.push_to_hub(dataset_repo_id, split="corpus") + + return dataset + + def cleanup_image_dir(self, image_save_dir: str = "./images"): + for file in os.listdir(image_save_dir): + file_path = os.path.join(image_save_dir, file) + if os.path.isfile(file_path): + os.remove(file_path) + async def load_data( self, start_page: Optional[int] = None, end_page: Optional[int] = None, - wandb_artifact_name: Optional[str] = None, + dataset_repo_id: Optional[str] = None, + overwrite_dataset: bool = False, image_save_dir: str = "./images", exclude_file_extensions: list[str] = [], - cleanup: bool = False, **kwargs, ) -> List[Dict[str, str]]: """ @@ -65,21 +133,15 @@ class BaseImageLoader(BaseTextLoader): Args: start_page (Optional[int]): The starting page index (0-based) to process. end_page (Optional[int]): The ending page index (0-based) to process. - wandb_artifact_name (Optional[str]): The name of the WandB artifact to publish the pages to, if provided. + dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided. + overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False. image_save_dir (str): The directory to save the extracted images. exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir. - cleanup (bool): Whether to remove extracted images from `image_save_dir`, if uploading to wandb artifact. **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library. 
Returns: - List[Dict[str, Any]]: A list of dictionaries, each containing the image and metadata for a processed page. - Each dictionary will have the following keys and values: - - - "page_idx": (int) the index of the page. - - "document_name": (str) the name of the document. - - "file_path": (str) the local file path where the PDF is stored. - - "file_url": (str) the URL of the PDF file. - - "image_file_path" or "image_file_paths": (str) the local file path where the image/images are stored. + Dataset: A HuggingFace dataset containing the processed pages. + Raises: ValueError: If the specified start_page or end_page is out of bounds of the document's page count. """ @@ -111,19 +173,8 @@ class BaseImageLoader(BaseTextLoader): if file.endswith(tuple(exclude_file_extensions)): os.remove(os.path.join(image_save_dir, file)) - if wandb_artifact_name: - artifact = wandb.Artifact( - name=wandb_artifact_name, - type="dataset", - metadata={"loader_name": self.__class__.__name__}, - ) - artifact.add_dir(local_path=image_save_dir) - artifact.save() - rich.print("Artifact saved and uploaded to wandb!") - - if cleanup: - for file in os.listdir(image_save_dir): - file_path = os.path.join(image_save_dir, file) - if os.path.isfile(file_path): - os.remove(file_path) - return pages + dataset = self.save_as_dataset( + start_page, end_page, image_save_dir, dataset_repo_id, overwrite_dataset + ) + + return dataset diff --git a/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py b/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py index b6319b054e5876bc88138e5637263282f02f3b1c..836c89e78b014df842328c2eaac8b6f20216ea39 100644 --- a/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py +++ b/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py @@ -3,9 +3,12 @@ import os from typing import Any, Dict import fitz +from pdf2image.pdf2image import convert_from_path from PIL import Image, ImageOps, UnidentifiedImageError -from .base_img_loader import BaseImageLoader +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) class FitzPILImageLoader(BaseImageLoader): @@ -20,27 +23,16 @@ class FitzPILImageLoader(BaseImageLoader): ```python import asyncio - import weave - - import wandb from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + loader = FitzPILImageLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=32, - end_page=37, - wandb_artifact_name="grays-anatomy-images-fitzpil", - cleanup=False, - ) - ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) ``` Args: @@ -118,6 +110,14 @@ class FitzPILImageLoader(BaseImageLoader): pdf_document.close() + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + return { "page_idx": page_idx, "document_name": self.document_name, diff --git a/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py 
b/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py index bfc63f8d7a0b5dd6e85c4f61eb51727b21650107..e66cf0af74985c58200099c50ab103d6c8af6250 100644 --- a/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py +++ b/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py @@ -5,7 +5,9 @@ from marker.convert import convert_single_pdf from marker.models import load_all_models from pdf2image.pdf2image import convert_from_path -from .base_img_loader import BaseImageLoader +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -22,27 +24,16 @@ class MarkerImageLoader(BaseImageLoader): ```python import asyncio - import weave - - import wandb from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + loader = MarkerImageLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=31, - end_page=36, - wandb_artifact_name="grays-anatomy-images-marker", - cleanup=False, - ) - ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) ``` Args: @@ -84,7 +75,7 @@ class MarkerImageLoader(BaseImageLoader): - "file_url": (str) the URL of the PDF file. - "image_file_path": (str) the local file path where the image is stored. """ - _, images, out_meta = convert_single_pdf( + _, images, _ = convert_single_pdf( self.document_file_path, self.model_lst, max_pages=1, @@ -101,14 +92,13 @@ class MarkerImageLoader(BaseImageLoader): image.save(image_file_path, "png") image_file_paths.append(image_file_path) - if self.save_page_image: - page_image = convert_from_path( - self.document_file_path, - first_page=page_idx + 1, - last_page=page_idx + 1, - **kwargs, - )[0] - page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) return { "page_idx": page_idx, @@ -116,7 +106,6 @@ class MarkerImageLoader(BaseImageLoader): "file_path": self.document_file_path, "file_url": self.url, "image_file_paths": os.path.join(image_save_dir, "*.png"), - "meta": out_meta, } def load_data( diff --git a/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py index 7c7d5c6462344c6a4ec16c423d946e468706b54b..bd5abaad407c8e7781275fc30a19a771f902cf74 100644 --- a/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py +++ b/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py @@ -3,7 +3,9 @@ from typing import Any, Dict from pdf2image.pdf2image import convert_from_path -from .base_img_loader import BaseImageLoader +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) class PDF2ImageLoader(BaseImageLoader): @@ -19,27 +21,16 @@ class PDF2ImageLoader(BaseImageLoader): ```python import asyncio - import weave - - import wandb from 
medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + loader = PDF2ImageLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=31, - end_page=36, - wandb_artifact_name="grays-anatomy-images-pdf2image", - cleanup=False, - ) - ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) ``` Args: diff --git a/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py index 8d96ece41bd1682c8102d731994deb134edb0cb2..2635071c6b58751c8791c07e0c96877e8210f3c1 100644 --- a/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py +++ b/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py @@ -2,8 +2,11 @@ import os from typing import Any, Dict import pdfplumber +from pdf2image.pdf2image import convert_from_path -from .base_img_loader import BaseImageLoader +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) class PDFPlumberImageLoader(BaseImageLoader): @@ -18,27 +21,16 @@ class PDFPlumberImageLoader(BaseImageLoader): ```python import asyncio - import weave - - import wandb from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + loader = PDFPlumberImageLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=32, - end_page=37, - wandb_artifact_name="grays-anatomy-images-pdfplumber", - cleanup=False, - ) - ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) ``` Args: @@ -92,6 +84,14 @@ class PDFPlumberImageLoader(BaseImageLoader): extracted_image.save(image_file_path, "png") image_file_paths.append(image_file_path) + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + return { "page_idx": page_idx, "document_name": self.document_name, diff --git a/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py index 4eeedb0fcc3ef80a4e2139be6c7a9d0c53b5f62e..336b8afc01fa421f8b2b7ae4d6beedd3cdf54ace 100644 --- a/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py +++ b/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py @@ -3,9 +3,12 @@ import os from typing import Any, Dict import fitz +from pdf2image.pdf2image import convert_from_path from PIL import Image -from .base_img_loader import BaseImageLoader +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) class 
PyMuPDFImageLoader(BaseImageLoader): @@ -20,27 +23,16 @@ class PyMuPDFImageLoader(BaseImageLoader): ```python import asyncio - import weave - - import wandb from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + loader = PyMuPDFImageLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=32, - end_page=37, - wandb_artifact_name="grays-anatomy-images-pymupdf", - cleanup=False, - ) - ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) ``` Args: @@ -115,6 +107,14 @@ class PyMuPDFImageLoader(BaseImageLoader): pdf_document.close() + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + return { "page_idx": page_idx, "document_name": self.document_name, diff --git a/medrag_multi_modal/document_loader/text_loader/base_text_loader.py b/medrag_multi_modal/document_loader/text_loader/base_text_loader.py index 4137becf1966224bc8d89206d25be85d0bf1d098..b6bc4dc455223eaaf75b61de500bf740d5fe9446 100644 --- a/medrag_multi_modal/document_loader/text_loader/base_text_loader.py +++ b/medrag_multi_modal/document_loader/text_loader/base_text_loader.py @@ -1,12 +1,13 @@ import asyncio import os from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Any, Dict, Optional +import huggingface_hub import PyPDF2 -import rich -import weave +from datasets import Dataset, concatenate_datasets, load_dataset from firerequests import FireRequests +from rich.progress import Progress class BaseTextLoader(ABC): @@ -22,14 +23,22 @@ class BaseTextLoader(ABC): url (str): The URL of the PDF file to download if not present locally. document_name (str): The name of the document for metadata purposes. document_file_path (str): The local file path where the PDF is stored or will be downloaded. + metadata (Optional[dict[str, any]]): Additional metadata to be added to each row of the dataset. """ - def __init__(self, url: str, document_name: str, document_file_path: str): + def __init__( + self, + url: str, + document_name: str, + document_file_path: str, + metadata: Optional[dict[str, Any]] = None, + ): self.url = url self.document_name = document_name self.document_file_path = document_file_path + self.metadata = metadata or {} if not os.path.exists(self.document_file_path): - FireRequests().download(url, filename=self.document_file_path) + FireRequests().download(url, filenames=self.document_file_path) with open(self.document_file_path, "rb") as file: pdf_reader = PyPDF2.PdfReader(file) self.page_count = len(pdf_reader.pages) @@ -85,9 +94,11 @@ class BaseTextLoader(ABC): self, start_page: Optional[int] = None, end_page: Optional[int] = None, - weave_dataset_name: Optional[str] = None, + exclude_pages: Optional[list[int]] = None, + dataset_repo_id: Optional[str] = None, + overwrite_dataset: bool = False, **kwargs, - ) -> List[Dict[str, str]]: + ) -> Dataset: """ Asynchronously loads text from a PDF file specified by a URL or local file path. 
The overridden abstract method then processes the text into markdown format, @@ -102,23 +113,26 @@ class BaseTextLoader(ABC): each page, extract the text from the PDF, and convert it to markdown. It processes pages concurrently using `asyncio` for efficiency. - If a weave_dataset_name is provided, the processed pages are published to a Weave dataset. + If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset. Args: start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page. end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page. - weave_dataset_name (Optional[str]): The name of the Weave dataset to publish the pages to, if provided. + exclude_pages (Optional[list[int]]): The list of page indices to exclude from processing. + dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided. + overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False. **kwargs: Additional keyword arguments that will be passed to the extract_page_data method and the underlying library. Returns: - List[Dict[str, str]]: A list of dictionaries, each containing the text and metadata for a processed page. - Each dictionary will have the following keys and values: + Dataset: A HuggingFace Dataset object containing the text and metadata for processed pages. + Each entry in the dataset will have the following keys and values: - "text": (str) the processed page data in markdown format. - "page_idx": (int) the index of the page. - "document_name": (str) the name of the document. - "file_path": (str) the local file path where the PDF is stored. - "file_url": (str) the URL of the PDF file. + - "loader_name": (str) the name of the loader class used to process the page. Raises: ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
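As a minimal usage sketch of the `load_data` contract described above (HuggingFace `Dataset` return value, `exclude_pages`, `dataset_repo_id`, `overwrite_dataset`, and the `metadata` constructor argument): the loader class and source URL are taken from the examples elsewhere in this diff, while the `metadata` value and the `dataset_repo_id` are placeholders for illustration only.

```python
import asyncio

from medrag_multi_modal.document_loader import PyPDF2TextLoader

URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"

loader = PyPDF2TextLoader(
    url=URL,
    document_name="Gray's Anatomy",
    document_file_path="grays_anatomy.pdf",
    # Extra metadata is merged into every page row that does not already define the key.
    metadata={"source": "archive.org"},
)

# Pages are 0-based and end_page is inclusive; page 33 is skipped via exclude_pages.
# The dataset_repo_id below is a placeholder, not a real repository.
dataset = asyncio.run(
    loader.load_data(
        start_page=31,
        end_page=36,
        exclude_pages=[33],
        dataset_repo_id="<hf-username>/grays-anatomy-text",
        overwrite_dataset=True,
    )
)

# Expected columns include: text, page_idx, document_name, file_path, file_url, loader_name.
print(dataset.column_names)
```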
@@ -127,21 +141,45 @@ class BaseTextLoader(ABC): pages = [] processed_pages_counter: int = 1 total_pages = end_page - start_page + exclude_pages = exclude_pages or [] async def process_page(page_idx): nonlocal processed_pages_counter page_data = await self.extract_page_data(page_idx, **kwargs) page_data["loader_name"] = self.__class__.__name__ + for key, value in self.metadata.items(): + if key not in page_data: + page_data[key] = value pages.append(page_data) - rich.print( - f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}" + progress.update( + task_id, + advance=1, + description=f"Loading page {page_idx} using {self.__class__.__name__}", ) processed_pages_counter += 1 - tasks = [process_page(page_idx) for page_idx in range(start_page, end_page)] - for task in asyncio.as_completed(tasks): - await task - - if weave_dataset_name: - weave.publish(weave.Dataset(name=weave_dataset_name, rows=pages)) - return pages + progress = Progress() + with progress: + task_id = progress.add_task("Starting...", total=total_pages) + tasks = [ + process_page(page_idx) + for page_idx in range(start_page, end_page + 1) + if page_idx not in exclude_pages + ] + for task in asyncio.as_completed(tasks): + await task + + pages.sort(key=lambda x: x["page_idx"]) + + dataset = Dataset.from_list(pages) + if dataset_repo_id: + if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"): + print("Dataset already exists") + if not overwrite_dataset: + print("Not overwriting dataset") + dataset = concatenate_datasets( + [dataset, load_dataset(dataset_repo_id, split="corpus")] + ) + dataset.push_to_hub(repo_id=dataset_repo_id, split="corpus", private=False) + + return dataset diff --git a/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py b/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py index 34a94746391c8700ab7f13fb30feda3af01783a0..16a19e3343bcd2d44f1481def4c7ad031838b815 100644 --- a/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py +++ b/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py @@ -4,7 +4,9 @@ from typing import Dict from marker.convert import convert_single_pdf from marker.models import load_all_models -from .base_text_loader import BaseTextLoader +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -26,24 +28,16 @@ class MarkerTextLoader(BaseTextLoader): ```python import asyncio - import weave + from medrag_multi_modal.document_loader import MarkerTextLoader - from medrag_multi_modal.document_loader.text_loader import MarkerTextLoader + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" - weave.init(project_name="ml-colabs/medrag-multi-modal") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" loader = MarkerTextLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=31, - end_page=36, - weave_dataset_name="grays-anatomy-text", - ) - ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) ``` Args: @@ -76,7 +70,7 @@ class MarkerTextLoader(BaseTextLoader): """ model_lst = load_all_models() - text, _, out_meta = convert_single_pdf( + text, _, _ = convert_single_pdf( self.document_file_path, model_lst, max_pages=1, @@ -92,5 +86,4 @@ class 
MarkerTextLoader(BaseTextLoader): "document_name": self.document_name, "file_path": self.document_file_path, "file_url": self.url, - "meta": out_meta, } diff --git a/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py index 9fa4a9c39259f079025688fdc19ff9bc63ad50b2..337aed66516170baded9a6faa997c7a2e2503319 100644 --- a/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py +++ b/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py @@ -2,7 +2,9 @@ from typing import Dict import pdfplumber -from .base_text_loader import BaseTextLoader +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) class PDFPlumberTextLoader(BaseTextLoader): @@ -22,24 +24,16 @@ class PDFPlumberTextLoader(BaseTextLoader): ```python import asyncio - import weave + from medrag_multi_modal.document_loader import PDFPlumberTextLoader - from medrag_multi_modal.document_loader.text_loader import PDFPlumberTextLoader + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" - weave.init(project_name="ml-colabs/medrag-multi-modal") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" loader = PDFPlumberTextLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=31, - end_page=36, - weave_dataset_name="grays-anatomy-text", - ) - ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) ``` Args: diff --git a/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py index a02e817bccd12867c3ee6d31681704e1b9700e4f..05493656683fa7378dff415e3861663667f1340a 100644 --- a/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py +++ b/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py @@ -2,7 +2,9 @@ from typing import Dict import pymupdf4llm -from .base_text_loader import BaseTextLoader +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) class PyMuPDF4LLMTextLoader(BaseTextLoader): @@ -20,26 +22,16 @@ class PyMuPDF4LLMTextLoader(BaseTextLoader): ```python import asyncio - import weave + from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader - from medrag_multi_modal.document_loader.text_loader import ( - PyMuPDF4LLMTextLoader - ) + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" - weave.init(project_name="ml-colabs/medrag-multi-modal") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" loader = PyMuPDF4LLMTextLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=31, - end_page=36, - weave_dataset_name="grays-anatomy-text", - ) - ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) ``` Args: diff --git a/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py index 7a568b0ee89190ceb5cc92c6ac49dc3e442fe6f1..df6cc011e1a93d9c5f8dd1d25d88f41dff928623 100644 --- a/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py 
+++ b/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py @@ -2,7 +2,9 @@ from typing import Dict import PyPDF2 -from .base_text_loader import BaseTextLoader +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) class PyPDF2TextLoader(BaseTextLoader): @@ -22,24 +24,16 @@ class PyPDF2TextLoader(BaseTextLoader): ```python import asyncio - import weave + from medrag_multi_modal.document_loader import PyPDF2TextLoader - from medrag_multi_modal.document_loader.text_loader import PyPDF2TextLoader + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" - weave.init(project_name="ml-colabs/medrag-multi-modal") - url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" loader = PyPDF2TextLoader( - url=url, + url=URL, document_name="Gray's Anatomy", document_file_path="grays_anatomy.pdf", ) - asyncio.run( - loader.load_data( - start_page=31, - end_page=36, - weave_dataset_name="grays-anatomy-text", - ) - ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) ``` Args: diff --git a/medrag_multi_modal/metrics/__init__.py b/medrag_multi_modal/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18b7bf585104b0331b74407570a1d6744c836320 --- /dev/null +++ b/medrag_multi_modal/metrics/__init__.py @@ -0,0 +1,3 @@ +from .mmlu import MMLUOptionAccuracy + +__all__ = ["MMLUOptionAccuracy"] diff --git a/medrag_multi_modal/metrics/base.py b/medrag_multi_modal/metrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e16cb9219d1f47115fc497f4ddc4f49d376a7c63 --- /dev/null +++ b/medrag_multi_modal/metrics/base.py @@ -0,0 +1,108 @@ +from typing import Optional + +import numpy as np +import weave + + +class BaseAccuracyMetric(weave.Scorer): + """ + BaseAccuracyMetric is a class that extends the + [`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers) + to provide a comprehensive evaluation of accuracy metrics for a given set of score rows. + + This class is designed to process a list of score rows, each containing a + 'correct' key that indicates whether a particular prediction was correct. + The `summarize` method calculates various statistical measures and metrics + based on this data, including: + + - True and false counts: The number of true and false predictions. + - True and false fractions: The proportion of true and false predictions. + - Standard error: The standard error of the mean for the true predictions. + - Precision: The ratio of true positive predictions to the total number of + positive predictions. + - Recall: The ratio of true positive predictions to the total number of + actual positives. + - F1 Score: The harmonic mean of precision and recall, providing a balance + between the two metrics. + + The `summarize` method returns a dictionary containing these metrics, + allowing for a detailed analysis of the model's performance. + + Methods: + summarize(score_rows: list) -> Optional[dict]: + Processes the input score rows to compute and return a dictionary + of accuracy metrics. + """ + @weave.op() + def summarize(self, score_rows: list) -> Optional[dict]: + """ + Summarizes the accuracy metrics from a list of score rows. + + This method processes a list of score rows, each containing a 'correct' key + that indicates whether a particular prediction was correct. 
It calculates + various statistical measures and metrics based on this data, including: + + - True and false counts: The number of true and false predictions. + - True and false fractions: The proportion of true and false predictions. + - Standard error: The standard error of the mean for the true predictions. + - Precision: The ratio of true positive predictions to the total number of + positive predictions. + - Recall: The ratio of true positive predictions to the total number of + actual positives. + - F1 Score: The harmonic mean of precision and recall, providing a balance + between the two metrics. + + The method returns a dictionary containing these metrics, allowing for a + detailed analysis of the model's performance. + + Args: + score_rows (list): A list of dictionaries, each containing a 'correct' + key with a boolean value indicating the correctness of a prediction. + + Returns: + Optional[dict]: A dictionary containing the calculated accuracy metrics, + or None if the input list is empty. + """ + valid_data = [ + x.get("correct") for x in score_rows if x.get("correct") is not None + ] + count_true = list(valid_data).count(True) + int_data = [int(x) for x in valid_data] + + sample_mean = np.mean(int_data) if int_data else 0 + sample_variance = np.var(int_data) if int_data else 0 + sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0 + + # Calculate precision, recall, and F1 score + true_positives = count_true + false_positives = len(valid_data) - count_true + false_negatives = len(score_rows) - len(valid_data) + + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else 0 + ) + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) + f1_score = ( + (2 * precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + + return { + "correct": { + "true_count": count_true, + "false_count": len(score_rows) - count_true, + "true_fraction": float(sample_mean), + "false_fraction": 1.0 - float(sample_mean), + "stderr": float(sample_error), + "precision": precision, + "recall": recall, + "f1_score": f1_score, + } + } diff --git a/medrag_multi_modal/metrics/mmlu.py b/medrag_multi_modal/metrics/mmlu.py new file mode 100644 index 0000000000000000000000000000000000000000..3e182084fd5cecef8834710d611ae0b5680dfe4b --- /dev/null +++ b/medrag_multi_modal/metrics/mmlu.py @@ -0,0 +1,24 @@ +import weave + +from medrag_multi_modal.assistant.schema import MedQAResponse +from medrag_multi_modal.metrics.base import BaseAccuracyMetric + + +class MMLUOptionAccuracy(BaseAccuracyMetric): + """ + MMLUOptionAccuracy is a metric class that inherits from `BaseAccuracyMetric`. + + This class is designed to evaluate the accuracy of a multiple-choice question + response by comparing the provided answer with the correct answer from the + given options. It uses the MedQAResponse schema to extract the response + and checks if it matches the correct answer. + + Methods: + -------- + score(output: MedQAResponse, options: list[str], answer: str) -> dict: + Compares the provided answer with the correct answer and returns a + dictionary indicating whether the answer is correct. 
+ """ + @weave.op() + def score(self, output: MedQAResponse, options: list[str], answer: str): + return {"correct": options[answer] == output.response.answer} diff --git a/medrag_multi_modal/retrieval/__init__.py b/medrag_multi_modal/retrieval/__init__.py index 8b052f78129652ce25e8f7a32f166de66a2aa402..20f0a3bdfd27fd93ea9681dc031ead5b885c1909 100644 --- a/medrag_multi_modal/retrieval/__init__.py +++ b/medrag_multi_modal/retrieval/__init__.py @@ -1,15 +1,3 @@ -from .bm25s_retrieval import BM25sRetriever from .colpali_retrieval import CalPaliRetriever -from .common import SimilarityMetric -from .contriever_retrieval import ContrieverRetriever -from .medcpt_retrieval import MedCPTRetriever -from .nv_embed_2 import NVEmbed2Retriever -__all__ = [ - "CalPaliRetriever", - "BM25sRetriever", - "ContrieverRetriever", - "SimilarityMetric", - "MedCPTRetriever", - "NVEmbed2Retriever", -] +__all__ = ["CalPaliRetriever"] diff --git a/medrag_multi_modal/retrieval/colpali_retrieval.py b/medrag_multi_modal/retrieval/colpali_retrieval.py index d82c046f100637d657e545139413d019b3059f79..522d964058abda29649115fa732f5a30ae659607 100644 --- a/medrag_multi_modal/retrieval/colpali_retrieval.py +++ b/medrag_multi_modal/retrieval/colpali_retrieval.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: import wandb from PIL import Image -from ..utils import get_wandb_artifact +from medrag_multi_modal.utils import get_wandb_artifact class CalPaliRetriever(weave.Model): diff --git a/medrag_multi_modal/retrieval/common.py b/medrag_multi_modal/retrieval/common.py index 2170118065cb0e440f605902733326c57ff0719e..9a0f1244bb83c595a5f2427e6ecf86b9334bc4c7 100644 --- a/medrag_multi_modal/retrieval/common.py +++ b/medrag_multi_modal/retrieval/common.py @@ -1,10 +1,5 @@ from enum import Enum -import safetensors -import safetensors.torch -import torch -import wandb - class SimilarityMetric(Enum): COSINE = "cosine" @@ -24,21 +19,3 @@ def argsort_scores(scores: list[float], descending: bool = False): list(enumerate(scores)), key=lambda x: x[1], reverse=descending ) ] - - -def save_vector_index( - vector_index: torch.Tensor, - type: str, - index_name: str, - metadata: dict, - filename: str = "vector_index.safetensors", -): - safetensors.torch.save_file({"vector_index": vector_index.cpu()}, filename) - if wandb.run: - artifact = wandb.Artifact( - name=index_name, - type=type, - metadata=metadata, - ) - artifact.add_file(filename) - artifact.save() diff --git a/medrag_multi_modal/retrieval/text_retrieval/__init__.py b/medrag_multi_modal/retrieval/text_retrieval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8482ab713102ca550af488a830af1aacc7fee3 --- /dev/null +++ b/medrag_multi_modal/retrieval/text_retrieval/__init__.py @@ -0,0 +1,11 @@ +from .bm25s_retrieval import BM25sRetriever +from .contriever_retrieval import ContrieverRetriever +from .medcpt_retrieval import MedCPTRetriever +from .nv_embed_2 import NVEmbed2Retriever + +__all__ = [ + "BM25sRetriever", + "ContrieverRetriever", + "MedCPTRetriever", + "NVEmbed2Retriever", +] diff --git a/medrag_multi_modal/retrieval/bm25s_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py similarity index 55% rename from medrag_multi_modal/retrieval/bm25s_retrieval.py rename to medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py index c61ed74b678a618c2f424d95f0512ac79f82cd88..b5b528262a62c1ae3c962a42877d4cd26ef6d75d 100644 --- a/medrag_multi_modal/retrieval/bm25s_retrieval.py +++ b/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py 
@@ -1,12 +1,17 @@ +import json import os -from glob import glob -from typing import Optional +import shutil +from typing import Optional, Union import bm25s -import wandb +import huggingface_hub import weave +from bm25s import BM25 +from datasets import Dataset, load_dataset from Stemmer import Stemmer +from medrag_multi_modal.utils import fetch_from_huggingface, save_to_huggingface + LANGUAGE_DICT = { "english": "en", "french": "fr", @@ -26,49 +31,60 @@ class BM25sRetriever(weave.Model): a new instance is created. """ - language: str - use_stemmer: bool - _retriever: Optional[bm25s.BM25] + language: Optional[str] + use_stemmer: bool = True + _retriever: Optional[BM25] def __init__( self, language: str = "english", use_stemmer: bool = True, - retriever: Optional[bm25s.BM25] = None, + retriever: Optional[BM25] = None, ): super().__init__(language=language, use_stemmer=use_stemmer) - self._retriever = retriever or bm25s.BM25() + self._retriever = retriever or BM25() - def index(self, chunk_dataset_name: str, index_name: Optional[str] = None): + def index( + self, + chunk_dataset: Union[Dataset, str], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + ): """ Indexes a dataset of text chunks using the BM25 algorithm. - This function takes a dataset of text chunks identified by `chunk_dataset_name`, - tokenizes the text using the BM25 tokenizer with optional stemming, and indexes - the tokenized text using the BM25 retriever. If an `index_name` is provided, the - index is saved to disk and logged as a Weights & Biases artifact. + This method retrieves a dataset of text chunks from a specified source, tokenizes + the text using the BM25 tokenizer with optional stemming, and indexes the tokenized + text using the BM25 retriever. If an `index_repo_id` is provided, the index is saved + to disk and optionally logged as a Huggingface artifact. !!! example "Example Usage" ```python import weave from dotenv import load_dotenv - import wandb - from medrag_multi_modal.retrieval import BM25sRetriever + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever load_dotenv() weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="bm25s-index") retriever = BM25sRetriever() - retriever.index(chunk_dataset_name="grays-anatomy-text:v13", index_name="grays-anatomy-bm25s") + retriever.index( + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + index_repo_id="geekyrakshit/grays-anatomy-index", + ) ``` Args: - chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed. - index_name (Optional[str]): The name to save the index under. If provided, the index - is saved to disk and logged as a Weights & Biases artifact. + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + cleanup (bool, optional): Whether to delete the local index directory after saving the vector index. 
""" - chunk_dataset = weave.ref(chunk_dataset_name).get().rows + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) corpus = [row["text"] for row in chunk_dataset] corpus_tokens = bm25s.tokenize( corpus, @@ -76,28 +92,40 @@ class BM25sRetriever(weave.Model): stemmer=Stemmer(self.language) if self.use_stemmer else None, ) self._retriever.index(corpus_tokens) - if index_name: + if index_repo_id: + os.makedirs(".huggingface", exist_ok=True) + index_save_dir = os.path.join(".huggingface", index_repo_id.split("/")[-1]) self._retriever.save( - index_name, corpus=[dict(row) for row in chunk_dataset] + index_save_dir, corpus=[dict(row) for row in chunk_dataset] + ) + commit_type = ( + "update" + if huggingface_hub.repo_exists(index_repo_id, repo_type="model") + else "add" ) - if wandb.run: - artifact = wandb.Artifact( - name=index_name, - type="bm25s-index", - metadata={ + with open(os.path.join(index_save_dir, "config.json"), "w") as config_file: + json.dump( + { "language": self.language, "use_stemmer": self.use_stemmer, }, + config_file, + indent=4, ) - artifact.add_dir(index_name, name=index_name) - artifact.save() + save_to_huggingface( + index_repo_id, + index_save_dir, + commit_message=f"{commit_type}: BM25s index", + ) + if cleanup: + shutil.rmtree(index_save_dir) @classmethod - def from_wandb_artifact(cls, index_artifact_address: str): + def from_index(cls, index_repo_id: str): """ - Creates an instance of the class from a Weights & Biases artifact. + Creates an instance of the class from a Huggingface repository. - This class method retrieves a BM25 index artifact from Weights & Biases, + This class method retrieves a BM25 index artifact from a Huggingface repository, downloads the artifact, and loads the BM25 retriever with the index and its associated corpus. The method also extracts metadata from the artifact to initialize the class instance with the appropriate language and stemming @@ -108,41 +136,26 @@ class BM25sRetriever(weave.Model): import weave from dotenv import load_dotenv - from medrag_multi_modal.retrieval import BM25sRetriever + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever load_dotenv() weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = BM25sRetriever.from_wandb_artifact( - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:latest" - ) + retriever = BM25sRetriever() + retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index") ``` Args: - index_artifact_address (str): The address of the Weights & Biases artifact - containing the BM25 index. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. Returns: An instance of the class initialized with the BM25 retriever and metadata from the artifact. 
""" - if wandb.run: - artifact = wandb.run.use_artifact( - index_artifact_address, type="bm25s-index" - ) - artifact_dir = artifact.download() - else: - api = wandb.Api() - artifact = api.artifact(index_artifact_address) - artifact_dir = artifact.download() - retriever = bm25s.BM25.load( - glob(os.path.join(artifact_dir, "*"))[0], load_corpus=True - ) - metadata = artifact.metadata - return cls( - language=metadata["language"], - use_stemmer=metadata["use_stemmer"], - retriever=retriever, - ) + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") + retriever = bm25s.BM25.load(index_dir, load_corpus=True) + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + config = json.load(config_file) + return cls(retriever=retriever, **config) @weave.op() def retrieve(self, query: str, top_k: int = 2): @@ -155,6 +168,20 @@ class BM25sRetriever(weave.Model): The results are returned as a list of dictionaries, each containing a chunk and its corresponding relevance score. + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = BM25sRetriever() + retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index") + retrieved_chunks = retriever.retrieve(query="What are Ribosomes?") + ``` + Args: query (str): The input query string to search for relevant chunks. top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. @@ -192,13 +219,12 @@ class BM25sRetriever(weave.Model): import weave from dotenv import load_dotenv - from medrag_multi_modal.retrieval import BM25sRetriever + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever load_dotenv() weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = BM25sRetriever.from_wandb_artifact( - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:latest" - ) + retriever = BM25sRetriever() + retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index") retrieved_chunks = retriever.predict(query="What are Ribosomes?") ``` diff --git a/medrag_multi_modal/retrieval/contriever_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py similarity index 55% rename from medrag_multi_modal/retrieval/contriever_retrieval.py rename to medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py index c2120f17fe0a4546968f8e0dc295928b7cb7b532..77da42e2693828232048723718ca29ae996c788f 100644 --- a/medrag_multi_modal/retrieval/contriever_retrieval.py +++ b/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py @@ -1,11 +1,16 @@ +import json import os -from typing import Optional +import shutil +from typing import Optional, Union +import huggingface_hub import safetensors import safetensors.torch import torch import torch.nn.functional as F import weave +from datasets import Dataset, load_dataset +from rich.progress import track from transformers import ( AutoModel, AutoTokenizer, @@ -13,8 +18,16 @@ from transformers import ( PreTrainedTokenizerFast, ) -from ..utils import get_torch_backend, get_wandb_artifact -from .common import SimilarityMetric, argsort_scores, mean_pooling, save_vector_index +from medrag_multi_modal.retrieval.common import ( + SimilarityMetric, + argsort_scores, + mean_pooling, +) +from medrag_multi_modal.utils import ( + fetch_from_huggingface, + 
get_torch_backend, + save_to_huggingface, +) class ContrieverRetriever(weave.Model): @@ -45,18 +58,35 @@ class ContrieverRetriever(weave.Model): ): super().__init__(model_name=model_name) self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) - self._model = AutoModel.from_pretrained(self.model_name) + self._model = AutoModel.from_pretrained(self.model_name).to(get_torch_backend()) self._vector_index = vector_index self._chunk_dataset = chunk_dataset - def encode(self, corpus: list[str]) -> torch.Tensor: - inputs = self._tokenizer( - corpus, padding=True, truncation=True, return_tensors="pt" - ) - outputs = self._model(**inputs) - return mean_pooling(outputs[0], inputs["attention_mask"]) - - def index(self, chunk_dataset_name: str, index_name: Optional[str] = None): + def encode(self, corpus: list[str], batch_size: int) -> torch.Tensor: + embeddings = [] + iterable = track( + range(0, len(corpus), batch_size), + description=f"Encoding corpus using {self.model_name}", + ) if batch_size > 1 else range(0, len(corpus), batch_size) + for idx in iterable: + batch = corpus[idx : idx + batch_size] + inputs = self._tokenizer( + batch, padding=True, truncation=True, return_tensors="pt" + ).to(get_torch_backend()) + with torch.no_grad(): + outputs = self._model(**inputs) + batch_embeddings = mean_pooling(outputs[0], inputs["attention_mask"]) + embeddings.append(batch_embeddings) + embeddings = torch.cat(embeddings, dim=0) + return embeddings + + def index( + self, + chunk_dataset: Union[str, Dataset], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + batch_size: int = 32, + ): """ Indexes a dataset of text chunks and optionally saves the vector index to a file. @@ -68,43 +98,64 @@ class ContrieverRetriever(weave.Model): !!! example "Example Usage" ```python - import weave - from dotenv import load_dotenv + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever - import wandb - from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric - - load_dotenv() - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="contriever-index") - retriever = ContrieverRetriever(model_name="facebook/contriever") + retriever = ContrieverRetriever() retriever.index( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_name="grays-anatomy-contriever", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever", + batch_size=256, ) ``` Args: - chunk_dataset_name (str): The name of the Weave dataset containing the text chunks - to be indexed. - index_name (Optional[str]): The name of the index artifact to be saved. If provided, - the vector index is saved to a file and logged as an artifact to Weave. + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + cleanup (bool, optional): Whether to delete the local index directory after saving the vector index. + batch_size (int, optional): The batch size to use for encoding the corpus. 
""" - self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows + self._chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) corpus = [row["text"] for row in self._chunk_dataset] with torch.no_grad(): - vector_index = self.encode(corpus) + vector_index = self.encode(corpus, batch_size) self._vector_index = vector_index - if index_name: - save_vector_index( - self._vector_index, - "contriever-index", - index_name, - {"model_name": self.model_name}, + if index_repo_id: + index_save_dir = os.path.join( + ".huggingface", index_repo_id.split("/")[-1] + ) + os.makedirs(index_save_dir, exist_ok=True) + safetensors.torch.save_file( + {"vector_index": vector_index.cpu()}, + os.path.join(index_save_dir, "vector_index.safetensors"), + ) + commit_type = ( + "update" + if huggingface_hub.repo_exists(index_repo_id, repo_type="model") + else "add" + ) + with open( + os.path.join(index_save_dir, "config.json"), "w" + ) as config_file: + json.dump( + {"model_name": self.model_name}, + config_file, + indent=4, + ) + save_to_huggingface( + index_repo_id, + index_save_dir, + commit_message=f"{commit_type}: Contriever index", ) + if cleanup: + shutil.rmtree(index_save_dir) @classmethod - def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str): + def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str): """ Creates an instance of the class from a Weave artifact. @@ -120,35 +171,38 @@ class ContrieverRetriever(weave.Model): import weave from dotenv import load_dotenv - from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever load_dotenv() - weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = ContrieverRetriever.from_wandb_artifact( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-contriever:v1", + retriever = ContrieverRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-contriever", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", ) ``` Args: - chunk_dataset_name (str): The name of the Weave dataset containing the text chunks. - index_artifact_address (str): The address of the Weave artifact containing the - vector index. + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. Returns: An instance of the class initialized with the retrieved model name, vector index, and chunk dataset. 
""" - artifact_dir, metadata = get_wandb_artifact( - index_artifact_address, "contriever-index", get_metadata=True - ) + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") with safetensors.torch.safe_open( - os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt" + os.path.join(index_dir, "vector_index.safetensors"), framework="pt" ) as f: vector_index = f.get_tensor("vector_index") device = torch.device(get_torch_backend()) vector_index = vector_index.to(device) - chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows] + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + metadata = json.load(config_file) return cls( model_name=metadata["model_name"], vector_index=vector_index, @@ -170,6 +224,22 @@ class ContrieverRetriever(weave.Model): cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores are returned as a list of dictionaries, each containing a chunk and its corresponding score. + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = ContrieverRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-contriever", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + ) + retrieved_chunks = retriever.retrieve(query="What are Ribosomes?") + ``` + Args: query (str): The input query string to search for relevant chunks. top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. 
@@ -181,7 +251,7 @@ class ContrieverRetriever(weave.Model): query = [query] device = torch.device(get_torch_backend()) with torch.no_grad(): - query_embedding = self.encode(query).to(device) + query_embedding = self.encode(query, batch_size=1).to(device) if metric == SimilarityMetric.EUCLIDEAN: scores = torch.squeeze(query_embedding @ self._vector_index.T) else: @@ -218,15 +288,15 @@ class ContrieverRetriever(weave.Model): import weave from dotenv import load_dotenv - from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever load_dotenv() weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = ContrieverRetriever.from_wandb_artifact( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-contriever:v1", + retriever = ContrieverRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-contriever", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", ) - scores = retriever.predict(query="What are Ribosomes?", metric=SimilarityMetric.COSINE) + retrieved_chunks = retriever.predict(query="What are Ribosomes?") ``` Args: diff --git a/medrag_multi_modal/retrieval/medcpt_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py similarity index 51% rename from medrag_multi_modal/retrieval/medcpt_retrieval.py rename to medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py index ec6adb30173560bf0623a64d3b594b749178e541..ac400dacf329a532e99b3cac7967326b8c3d2c0f 100644 --- a/medrag_multi_modal/retrieval/medcpt_retrieval.py +++ b/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py @@ -1,11 +1,16 @@ +import json import os -from typing import Optional +import shutil +from typing import Optional, Union +import huggingface_hub import safetensors import safetensors.torch import torch import torch.nn.functional as F import weave +from datasets import Dataset, load_dataset +from rich.progress import track from transformers import ( AutoModel, AutoTokenizer, @@ -13,8 +18,12 @@ from transformers import ( PreTrainedTokenizerFast, ) -from ..utils import get_torch_backend, get_wandb_artifact -from .common import SimilarityMetric, argsort_scores, save_vector_index +from medrag_multi_modal.retrieval.common import SimilarityMetric, argsort_scores +from medrag_multi_modal.utils import ( + fetch_from_huggingface, + get_torch_backend, + save_to_huggingface, +) class MedCPTRetriever(weave.Model): @@ -45,8 +54,8 @@ class MedCPTRetriever(weave.Model): def __init__( self, - query_encoder_model_name: str, - article_encoder_model_name: str, + query_encoder_model_name: str = "ncbi/MedCPT-Query-Encoder", + article_encoder_model_name: str = "ncbi/MedCPT-Article-Encoder", chunk_size: Optional[int] = None, vector_index: Optional[torch.Tensor] = None, chunk_dataset: Optional[list[dict]] = None, @@ -64,119 +73,157 @@ class MedCPTRetriever(weave.Model): ) self._query_encoder_model = AutoModel.from_pretrained( self.query_encoder_model_name - ) + ).to(get_torch_backend()) self._article_encoder_model = AutoModel.from_pretrained( self.article_encoder_model_name - ) + ).to(get_torch_backend()) self._chunk_dataset = chunk_dataset self._vector_index = vector_index - def index(self, chunk_dataset_name: str, index_name: Optional[str] = None): + def index( + self, + chunk_dataset: Union[str, Dataset], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + batch_size: int = 32, + ): """ - Indexes a 
dataset of text chunks and optionally saves the vector index. + Indexes a dataset of text chunks using the MedCPT model and optionally saves the vector index. - This method retrieves a dataset of text chunks from a Weave reference, encodes the text - chunks using the article encoder model, and stores the resulting vector index. If an - index name is provided, the vector index is saved to a file using the `save_vector_index` - function. + This method retrieves a dataset of text chunks from a specified source, encodes the text + chunks into vector representations using the article encoder model, and stores the + resulting vector index. If an `index_repo_id` is provided, the vector index is saved + to disk in the safetensors format and optionally logged as a Huggingface artifact. !!! example "Example Usage" ```python import weave from dotenv import load_dotenv - import wandb - from medrag_multi_modal.retrieval import MedCPTRetriever + from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever load_dotenv() - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="medcpt-index") - retriever = MedCPTRetriever( - query_encoder_model_name="ncbi/MedCPT-Query-Encoder", - article_encoder_model_name="ncbi/MedCPT-Article-Encoder", - ) + retriever = MedCPTRetriever() retriever.index( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_name="grays-anatomy-medcpt", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + index_repo_id="geekyrakshit/grays-anatomy-index-medcpt", ) ``` Args: - chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed. - index_name (Optional[str]): The name to use when saving the vector index. If not provided, - the vector index is not saved. + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + cleanup (bool, optional): Whether to delete the local index directory after saving the vector index. + batch_size (int, optional): The batch size to use for encoding the corpus. 
""" - self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows + self._chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) corpus = [row["text"] for row in self._chunk_dataset] + vector_indices = [] with torch.no_grad(): - encoded = self._article_tokenizer( - corpus, - truncation=True, - padding=True, - return_tensors="pt", - max_length=self.chunk_size, - ) - vector_index = ( - self._article_encoder_model(**encoded) - .last_hidden_state[:, 0, :] - .contiguous() - ) + for idx in track( + range(0, len(corpus), batch_size), + description="Encoding corpus using MedCPT", + ): + batch = corpus[idx : idx + batch_size] + encoded = self._article_tokenizer( + batch, + truncation=True, + padding=True, + return_tensors="pt", + max_length=self.chunk_size, + ).to(get_torch_backend()) + batch_vectors = ( + self._article_encoder_model(**encoded) + .last_hidden_state[:, 0, :] + .contiguous() + ) + vector_indices.append(batch_vectors) + + vector_index = torch.cat(vector_indices, dim=0) self._vector_index = vector_index - if index_name: - save_vector_index( - self._vector_index, - "medcpt-index", - index_name, - { - "query_encoder_model_name": self.query_encoder_model_name, - "article_encoder_model_name": self.article_encoder_model_name, - "chunk_size": self.chunk_size, - }, + if index_repo_id: + index_save_dir = os.path.join( + ".huggingface", index_repo_id.split("/")[-1] ) + os.makedirs(index_save_dir, exist_ok=True) + safetensors.torch.save_file( + {"vector_index": self._vector_index.cpu()}, + os.path.join(index_save_dir, "vector_index.safetensors"), + ) + commit_type = ( + "update" + if huggingface_hub.repo_exists(index_repo_id, repo_type="model") + else "add" + ) + with open( + os.path.join(index_save_dir, "config.json"), "w" + ) as config_file: + json.dump( + { + "query_encoder_model_name": self.query_encoder_model_name, + "article_encoder_model_name": self.article_encoder_model_name, + "chunk_size": self.chunk_size, + }, + config_file, + indent=4, + ) + save_to_huggingface( + index_repo_id, + index_save_dir, + commit_message=f"{commit_type}: Contriever index", + ) + if cleanup: + shutil.rmtree(index_save_dir) @classmethod - def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str): + def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str): """ - Initializes an instance of the class from a Weave artifact. + Creates an instance of the class from a Huggingface repository. - This method retrieves a precomputed vector index and its associated metadata from a Weave artifact - stored in Weights & Biases (wandb). It then loads the vector index into memory and initializes an - instance of the class with the retrieved model names, vector index, and chunk dataset. + This method retrieves a vector index and metadata from a Huggingface repository. + It also retrieves a dataset of text chunks from the specified source. The vector + index is loaded from a safetensors file and moved to the appropriate device (CPU or GPU). + The method then returns an instance of the class initialized with the retrieved + model names, vector index, and chunk dataset. !!! 
example "Example Usage" ```python - import weave - from dotenv import load_dotenv - - import wandb - from medrag_multi_modal.retrieval import MedCPTRetriever + from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever - load_dotenv() - weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = MedCPTRetriever.from_wandb_artifact( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0", + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", ) ``` Args: - chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed. - index_artifact_address (str): The address of the Weave artifact containing the precomputed vector index. + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. Returns: An instance of the class initialized with the retrieved model name, vector index, and chunk dataset. """ - artifact_dir, metadata = get_wandb_artifact( - index_artifact_address, "medcpt-index", get_metadata=True - ) + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") with safetensors.torch.safe_open( - os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt" + os.path.join(index_dir, "vector_index.safetensors"), framework="pt" ) as f: vector_index = f.get_tensor("vector_index") device = torch.device(get_torch_backend()) vector_index = vector_index.to(device) - chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows] + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + metadata = json.load(config_file) + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) return cls( query_encoder_model_name=metadata["query_encoder_model_name"], article_encoder_model_name=metadata["article_encoder_model_name"], @@ -200,6 +247,19 @@ class MedCPTRetriever(weave.Model): cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores are returned as a list of dictionaries, each containing a chunk and its corresponding score. + !!! example "Example Usage" + ```python + import weave + from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + retriever.retrieve(query="What is ribosome?") + ``` + Args: query (str): The input query string to search for relevant chunks. top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. @@ -216,7 +276,7 @@ class MedCPTRetriever(weave.Model): truncation=True, padding=True, return_tensors="pt", - ) + ).to(device) query_embedding = self._query_encoder_model(**encoded).last_hidden_state[ :, 0, : ] @@ -254,18 +314,14 @@ class MedCPTRetriever(weave.Model): !!! 
example "Example Usage" ```python import weave - from dotenv import load_dotenv + from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever - import wandb - from medrag_multi_modal.retrieval import MedCPTRetriever - - load_dotenv() weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = MedCPTRetriever.from_wandb_artifact( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0", + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", ) - retriever.predict(query="What are Ribosomes?") + retriever.predict(query="What is ribosome?") ``` Args: diff --git a/medrag_multi_modal/retrieval/nv_embed_2.py b/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py similarity index 56% rename from medrag_multi_modal/retrieval/nv_embed_2.py rename to medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py index 937fc8775addd232aa87b4bf58669801b24054f2..67e61883f8e801f105a2f7e039e3ed16b89dd2e9 100644 --- a/medrag_multi_modal/retrieval/nv_embed_2.py +++ b/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py @@ -1,14 +1,23 @@ +import json import os -from typing import Optional +import shutil +from typing import Optional, Union +import huggingface_hub import safetensors import torch import torch.nn.functional as F import weave +from datasets import Dataset, load_dataset +from rich.progress import track from sentence_transformers import SentenceTransformer -from ..utils import get_torch_backend, get_wandb_artifact -from .common import SimilarityMetric, argsort_scores, save_vector_index +from medrag_multi_modal.retrieval.common import SimilarityMetric, argsort_scores +from medrag_multi_modal.utils import ( + fetch_from_huggingface, + get_torch_backend, + save_to_huggingface, +) class NVEmbed2Retriever(weave.Model): @@ -33,7 +42,7 @@ class NVEmbed2Retriever(weave.Model): def __init__( self, - model_name: str = "sentence-transformers/nvembed2-nli-v1", + model_name: str = "nvidia/NV-Embed-v2", vector_index: Optional[torch.Tensor] = None, chunk_dataset: Optional[list[dict]] = None, ): @@ -56,31 +65,29 @@ class NVEmbed2Retriever(weave.Model): ] return input_examples - def index(self, chunk_dataset_name: str, index_name: Optional[str] = None): + def index( + self, + chunk_dataset: Union[str, Dataset], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + batch_size: int = 8, + ): """ - Indexes a dataset of text chunks and optionally saves the vector index to a file. + Indexes a dataset of text chunks and optionally saves the vector index to a Huggingface repository. - This method retrieves a dataset of text chunks from a Weave reference, encodes the + This method retrieves a dataset of text chunks from a specified source, encodes the text chunks into vector representations using the NV-Embed-v2 model, and stores the - resulting vector index. If an index name is provided, the vector index is saved to - a file in the safetensors format. Additionally, if a Weave run is active, the vector - index file is logged as an artifact to Weave. + resulting vector index. If an index repository ID is provided, the vector index is saved to + a file in the safetensors format within the specified Huggingface repository. !!! 
example "Example Usage" ```python - import weave - from dotenv import load_dotenv + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever - import wandb - from medrag_multi_modal.retrieval import NVEmbed2Retriever - - load_dotenv() - weave.init(project_name="ml-colabs/medrag-multi-modal") - wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="nvembed2-index") - retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2") + retriever = NVEmbed2Retriever() retriever.index( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_name="grays-anatomy-nvembed2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", ) ``` @@ -93,31 +100,68 @@ class NVEmbed2Retriever(weave.Model): ``` Args: - chunk_dataset_name (str): The name of the Weave dataset containing the text chunks - to be indexed. - index_name (Optional[str]): The name of the index artifact to be saved. If provided, - the vector index is saved to a file and logged as an artifact to Weave. + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + cleanup (bool, optional): Whether to delete the local index directory after saving the vector index. + batch_size (int, optional): The batch size to use for encoding the corpus. """ - self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows - corpus = [row["text"] for row in self._chunk_dataset] - self._vector_index = self._model.encode( - self.add_eos(corpus), batch_size=len(corpus), normalize_embeddings=True + self._chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset ) + corpus = [row["text"] for row in self._chunk_dataset] + vector_indices = [] + + for idx in track( + range(0, len(corpus), batch_size), + description="Encoding corpus using NV-Embed-v2", + ): + batch = corpus[idx : idx + batch_size] + batch_embeddings = self._model.encode( + self.add_eos(batch), batch_size=len(batch), normalize_embeddings=True + ) + vector_indices.append(torch.tensor(batch_embeddings)) + + self._vector_index = torch.cat(vector_indices, dim=0) with torch.no_grad(): - if index_name: - save_vector_index( - torch.from_numpy(self._vector_index), - "nvembed2-index", - index_name, - {"model_name": self.model_name}, + if index_repo_id: + index_save_dir = os.path.join( + ".huggingface", index_repo_id.split("/")[-1] + ) + os.makedirs(index_save_dir, exist_ok=True) + safetensors.torch.save_file( + {"vector_index": self._vector_index.cpu()}, + os.path.join(index_save_dir, "vector_index.safetensors"), ) + commit_type = ( + "update" + if huggingface_hub.repo_exists(index_repo_id, repo_type="model") + else "add" + ) + with open( + os.path.join(index_save_dir, "config.json"), "w" + ) as config_file: + json.dump( + {"model_name": self.model_name}, + config_file, + indent=4, + ) + save_to_huggingface( + index_repo_id, + index_save_dir, + commit_message=f"{commit_type}: Contriever index", + ) + if cleanup: + shutil.rmtree(index_save_dir) @classmethod - def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str): + def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str): """ - Creates an instance of the class from a Weave artifact. + Creates an instance of the class from a Huggingface repository. 
- This method retrieves a vector index and metadata from a Weave artifact stored in + This method retrieves a vector index and metadata from a Huggingface repository. It also retrieves a dataset of text chunks from a Huggingface dataset repository. The vector index is loaded from a safetensors file and moved to the appropriate device (CPU or GPU). The text chunks are converted into a list of dictionaries. The method then returns an instance of the class initialized with the retrieved model name, vector index, and chunk dataset. Weights & Biases (wandb). It also retrieves a dataset of text chunks from a Weave reference. The vector index is loaded from a safetensors file and moved to the appropriate device (CPU or GPU). The text chunks are converted into a list of @@ -127,21 +171,11 @@ class NVEmbed2Retriever(weave.Model): !!! example "Example Usage" ```python import weave - from dotenv import load_dotenv - - import wandb - from medrag_multi_modal.retrieval import NVEmbed2Retriever + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever - load_dotenv() - weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2") - retriever.index( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_name="grays-anatomy-nvembed2", - ) - retriever = NVEmbed2Retriever.from_wandb_artifact( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0", + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", ) ``` @@ -154,24 +188,28 @@ class NVEmbed2Retriever(weave.Model): ``` Args: - chunk_dataset_name (str): The name of the Weave dataset containing the text chunks. - index_artifact_address (str): The address of the Weave artifact containing the - vector index. + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. Returns: An instance of the class initialized with the retrieved model name, vector index, and chunk dataset. """ - artifact_dir, metadata = get_wandb_artifact( - index_artifact_address, "nvembed2-index", get_metadata=True - ) + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") with safetensors.torch.safe_open( - os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt" + os.path.join(index_dir, "vector_index.safetensors"), framework="pt" ) as f: vector_index = f.get_tensor("vector_index") device = torch.device(get_torch_backend()) vector_index = vector_index.to(device) - chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows] + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + metadata = json.load(config_file) return cls( model_name=metadata["model_name"], vector_index=vector_index, @@ -193,6 +231,27 @@ class NVEmbed2Retriever(weave.Model): cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores are returned as a list of dictionaries, each containing a chunk and its corresponding score. + !!! 
example "Example Usage" + ```python + import weave + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + retriever.retrieve(query="What is ribosome?") + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + Args: query (list[str]): The input query strings to search for relevant chunks. top_k (int, optional): The number of top relevant chunks to retrieve. @@ -240,23 +299,14 @@ class NVEmbed2Retriever(weave.Model): !!! example "Example Usage" ```python import weave - from dotenv import load_dotenv - - import wandb - from medrag_multi_modal.retrieval import NVEmbed2Retriever + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever - load_dotenv() weave.init(project_name="ml-colabs/medrag-multi-modal") - retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2") - retriever.index( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_name="grays-anatomy-nvembed2", - ) - retriever = NVEmbed2Retriever.from_wandb_artifact( - chunk_dataset_name="grays-anatomy-chunks:v0", - index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0", + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", ) - retriever.predict(query="What are Ribosomes?") + retriever.predict(query="What is ribosome?") ``` ??? note "Optional Speedup using Flash Attention" diff --git a/medrag_multi_modal/semantic_chunking.py b/medrag_multi_modal/semantic_chunking.py index 1eb5ac82d8aa7f159e367e60882b3f50371754f5..d7b82b3e7c40f9b80cd127a1646c71330f739457 100644 --- a/medrag_multi_modal/semantic_chunking.py +++ b/medrag_multi_modal/semantic_chunking.py @@ -1,9 +1,11 @@ +import asyncio from typing import Callable, Optional, Union +import huggingface_hub import semchunk import tiktoken import tokenizers -import weave +from datasets import Dataset, concatenate_datasets, load_dataset from rich.progress import track from transformers import PreTrainedTokenizer @@ -28,17 +30,13 @@ class SemanticChunker: !!! example "Example Usage" ```python - import weave - from dotenv import load_dotenv - from medrag_multi_modal.semantic_chunking import SemanticChunker - load_dotenv() - weave.init(project_name="ml-colabs/medrag-multi-modal") + chunker = SemanticChunker(chunk_size=256) - chunker.chunk_and_publish( - document_dataset_name="grays-anatomy-text:v13", - chunk_dataset_name="grays-anatomy-chunks", + chunker.chunk( + document_dataset="geekyrakshit/grays-anatomy-test", + chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test", ) ``` @@ -67,22 +65,71 @@ class SemanticChunker: memoize=memoize, ) - def chunk_and_publish( - self, document_dataset_name: str, chunk_dataset_name: Optional[str] = None - ) -> None: - document_dataset = weave.ref(document_dataset_name).get().rows + def chunk( + self, + document_dataset: Union[Dataset, str], + chunk_dataset_repo_id: Optional[str] = None, + overwrite_dataset: bool = False, + ) -> Dataset: + """ + Chunks a document dataset into smaller segments and publishes them as a new dataset. 
+ + This function takes a document dataset, either as a HuggingFace Dataset object or a string + representing the dataset repository ID, and chunks the documents into smaller segments using + the specified chunker. The resulting chunks are then optionally published to a HuggingFace + dataset repository. + + Args: + document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either + a HuggingFace Dataset object or a string representing the dataset repository ID. + chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish + the chunks to, if provided. Defaults to None. + overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False. + + Returns: + Dataset: A HuggingFace Dataset object containing the chunks. + """ + document_dataset = ( + load_dataset(document_dataset, split="corpus") + if isinstance(document_dataset, str) + else document_dataset + ).to_list() + chunks = [] - for idx, document in track( - enumerate(document_dataset), description="Chunking documents" - ): + + async def process_document(idx, document): document_chunks = self.chunker.chunk(str(document["text"])) for chunk in document_chunks: - chunks.append( - { - "document_idx": idx, - "document_name": document["document_name"], - "page_idx": document["page_idx"], - "text": chunk, - } - ) - weave.publish(weave.Dataset(name=chunk_dataset_name, rows=chunks)) + chunk_dict = {"document_idx": idx, "text": chunk} + for key, value in document.items(): + if key not in chunk_dict: + chunk_dict[key] = value + chunks.append(chunk_dict) + + async def process_all_documents(): + tasks = [] + for idx, document in track( + enumerate(document_dataset), + total=len(document_dataset), + description="Chunking documents", + ): + tasks.append(process_document(idx, document)) + await asyncio.gather(*tasks) + + asyncio.run(process_all_documents()) + + chunks.sort(key=lambda x: x["document_idx"]) + + dataset = Dataset.from_list(chunks) + if chunk_dataset_repo_id: + if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"): + if not overwrite_dataset: + dataset = concatenate_datasets( + [ + dataset, + load_dataset(chunk_dataset_repo_id, split="chunks"), + ] + ) + dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks") + + return dataset diff --git a/medrag_multi_modal/utils.py b/medrag_multi_modal/utils.py index 4dcb6caf7604838069feb58a46ff687b53139247..d3b6237a4879199653c77235eb3ee9b0d32c764a 100644 --- a/medrag_multi_modal/utils.py +++ b/medrag_multi_modal/utils.py @@ -4,6 +4,7 @@ import io import jsonlines import torch import wandb +from huggingface_hub import HfApi from PIL import Image @@ -50,3 +51,36 @@ def read_jsonl_file(file_path: str) -> list[dict[str, any]]: with jsonlines.open(file_path) as reader: for obj in reader: return obj + + +def save_to_huggingface( + repo_id: str, local_dir: str, commit_message: str, private: bool = False +): + api = HfApi() + repo_url = api.create_repo( + repo_id=repo_id, + token=api.token, + private=private, + repo_type="model", + exist_ok=True, + ) + repo_id = repo_url.repo_id + api.upload_folder( + repo_id=repo_id, + commit_message=commit_message, + token=api.token, + folder_path=local_dir, + repo_type=repo_url.repo_type, + ) + + +def fetch_from_huggingface(repo_id: str, local_dir: str) -> str: + api = HfApi() + repo_url = api.repo_info(repo_id) + if repo_url is None: + raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") + + snapshot = 
api.snapshot_download(repo_id, revision=None, local_dir=local_dir) + if snapshot is None: + raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") + return snapshot diff --git a/mkdocs.yml b/mkdocs.yml deleted file mode 100644 index cdd7a9645fe0e539321dd7f85a15ce3598b721ce..0000000000000000000000000000000000000000 --- a/mkdocs.yml +++ /dev/null @@ -1,93 +0,0 @@ -# mkdocs.yml -site_name: Medrag Multi Modal - -theme: - name: material - palette: - # Palette toggle for light mode - - scheme: default - toggle: - icon: material/brightness-7 - name: Switch to dark mode - # Palette toggle for dark mode - - scheme: slate - toggle: - icon: material/brightness-4 - name: Switch to light mode - features: - - content.code.annotate - - content.code.copy - - content.code.select - - content.tabs.link - - content.tooltips - - navigation.tracking - -plugins: - - mkdocstrings - - search - - minify - - glightbox - - mkdocs-jupyter: - include_source: True - - -markdown_extensions: - - attr_list - - pymdownx.emoji: - emoji_index: !!python/name:material.extensions.emoji.twemoji - emoji_generator: !!python/name:material.extensions.emoji.to_svg - - pymdownx.arithmatex: - generic: true - - pymdownx.highlight: - anchor_linenums: true - line_spans: __span - pygments_lang_class: true - - pymdownx.tabbed: - alternate_style: true - - pymdownx.details - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences - - admonition - - attr_list - - md_in_html - -extra_javascript: - - javascripts/mathjax.js - - https://polyfill.io/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js - -nav: - - Home: 'index.md' - - Setup: - - Installation: 'installation/install.md' - - Development: 'installation/development.md' - - App: - - MedQA Assistant: 'app.md' - - Document Loader: - - Text Loader: - - Base: 'document_loader/text_loader/base_text_loader.md' - - PyMuPDF4LLM: 'document_loader/text_loader/pymupdf4llm_text_loader.md' - - PyPDF2: 'document_loader/text_loader/pypdf2_text_loader.md' - - PDFPlumber: 'document_loader/text_loader/pdfplumber_text_loader.md' - - Marker: 'document_loader/text_loader/marker_text_loader.md' - - Image Loader: - - Base: 'document_loader/image_loader/base_img_loader.md' - - PDF2Image: 'document_loader/image_loader/pdf2image_img_loader.md' - - Marker: 'document_loader/image_loader/marker_img_loader.md' - - PDFPlumber: 'document_loader/image_loader/pdfplumber_img_loader.md' - - PyMuPDF: 'document_loader/image_loader/pymupdf_img_loader.md' - - FitzPIL: 'document_loader/image_loader/fitzpil_img_loader.md' - - Chunking: 'chunking.md' - - Retrieval: - - BM25-Sparse: 'retreival/bm25s.md' - - ColPali: 'retreival/colpali.md' - - Contriever: 'retreival/contriever.md' - - MedCPT: 'retreival/medcpt.md' - - NV-Embed-v2: 'retreival/nv_embed_2.md' - - Assistant: - - MedQA Assistant: 'assistant/medqa_assistant.md' - - Figure Annotation: 'assistant/figure_annotation.md' - - LLM Client: 'assistant/llm_client.md' - -repo_url: https://github.com/soumik12345/medrag-multi-modal diff --git a/pyproject.toml b/pyproject.toml index c93ad4d31bd42d1574db03d1a0b6e1cda910963f..c92b8110073e43000d923779828818dd796c888a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,16 +5,13 @@ description = "" readme = "README.md" requires-python = ">=3.10" dependencies = [ - "adapters>=1.0.0", "bm25s[full]>=0.2.2", - "datasets>=3.0.1", + "datasets>=3.1.0", "einops>=0.8.0", "firerequests>=0.0.7", - "jax[cpu]>=0.4.34", "pdf2image>=1.17.0", "python-dotenv>=1.0.1", 
"pymupdf4llm>=0.0.17", - "torch>=2.4.1", "weave>=0.51.14", "pip>=24.2", "uv>=0.4.20", @@ -52,12 +49,10 @@ app = [ "streamlit>=1.39.0", ] core = [ - "adapters>=1.0.0", "bm25s[full]>=0.2.2", - "datasets>=3.0.1", + "datasets>=3.1.0", "einops>=0.8.0", "firerequests>=0.0.7", - "jax[cpu]>=0.4.34", "marker-pdf>=0.2.17", "pdf2image>=1.17.0", "pdfplumber>=0.11.4", @@ -68,8 +63,7 @@ core = [ "safetensors>=0.4.5", "semchunk>=2.2.0", "tiktoken>=0.8.0", - "torch>=2.4.1", - "weave>=0.51.14", + "weave>=0.51.18", "sentence-transformers>=3.2.0", "google-generativeai>=0.8.3", "mistralai>=1.1.0", @@ -100,6 +94,8 @@ medrag = "medrag_multi_modal.cli:main" [tool.pytest.ini_options] pythonpath = "." +testpaths = ["tests"] +filterwarnings = "ignore::DeprecationWarning" [tool.setuptools] -py-modules = ["medrag_multi_modal"] +py-modules = ["medrag_multi_modal"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 79681b5f31c0e36bfe04b4e0e3544cf233b962a0..d8aa85b060cff542af5b7e00093598f992ee01fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,10 @@ -adapters>=1.0.0 bm25s[full]>=0.2.2 -datasets>=3.0.1 +datasets>=3.1.0 einops>=0.8.0 firerequests>=0.0.7 -jax[cpu]>=0.4.34 pdf2image>=1.17.0 python-dotenv>=1.0.1 pymupdf4llm>=0.0.17 -torch>=2.4.1 weave>=0.51.14 pip>=24.2 uv>=0.4.20 @@ -19,6 +16,14 @@ isort>=5.13.2 black>=24.10.0 ruff>=0.6.9 marker-pdf>=0.2.17 +mkdocs>=1.6.1 +mkdocstrings>=0.26.1 +mkdocstrings-python>=1.11.1 +mkdocs-material>=9.5.39 +mkdocs-minify-plugin>=0.8.0 +mkdocs-glightbox>=0.4.0 +mkdocs-jupyter>=0.25.0 +jupyter>=1.1.1 pdfplumber>=0.11.4 semchunk>=2.2.0 tiktoken>=0.8.0 diff --git a/tests/assistant/test_llm_client.py b/tests/assistant/test_llm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0ddf830f018d2264e8f550bf28a320270c0165 --- /dev/null +++ b/tests/assistant/test_llm_client.py @@ -0,0 +1,64 @@ +from PIL import Image +from pydantic import BaseModel + +from medrag_multi_modal.assistant.llm_client import ClientType, LLMClient + + +class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + + +class ImageDescription(BaseModel): + description: str + + +def test_openai_llm_client(): + llm_client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI) + event = llm_client.predict( + system_prompt="Extract the event information", + user_prompt="Alice and Bob are going to a science fair on Friday.", + schema=CalendarEvent, + ) + assert event.name.lower() == "science fair" + assert event.date.lower() == "friday" + assert [item.lower() for item in event.participants] == ["alice", "bob"] + + +def test_openai_image_description(): + llm_client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI) + description = llm_client.predict( + system_prompt="Describe the image", + user_prompt=[Image.open("./assets/test_image.png")], + schema=ImageDescription, + ) + assert "astronaut" in description.description.lower() + + +def test_google_llm_client(): + llm_client = LLMClient( + model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI + ) + event = llm_client.predict( + system_prompt="Extract the event information", + user_prompt="Alice and Bob are going to a science fair on Friday.", + schema=CalendarEvent, + ) + event = event[0] if isinstance(event, list) else event + assert event["name"].lower() == "science fair" + assert event["date"].lower() == "friday" + assert [item.lower() for item in event["participants"]] == ["alice", "bob"] + + +def test_google_image_client(): + 
llm_client = LLMClient( + model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI + ) + description = llm_client.predict( + system_prompt="Describe the image", + user_prompt=[Image.open("./assets/test_image.png")], + schema=ImageDescription, + ) + description = description[0] if isinstance(description, list) else description + assert "astronaut" in description["description"].lower() diff --git a/tests/assistant/test_medqa_assistant.py b/tests/assistant/test_medqa_assistant.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0407f60c3756f481f250a1aa965723d863113a --- /dev/null +++ b/tests/assistant/test_medqa_assistant.py @@ -0,0 +1,30 @@ +import pytest + +from medrag_multi_modal.assistant import LLMClient, MedQAAssistant +from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + +@pytest.mark.retry(max_attempts=5) +def test_medqa_assistant(): + retriever = BM25sRetriever().from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s" + ) + llm_client = LLMClient(model_name="gemini-1.5-flash") + medqa_assistant = MedQAAssistant( + llm_client=llm_client, + retriever=retriever, + top_k_chunks_for_query=5, + top_k_chunks_for_options=3, + ) + options = [ + "The first pharyngeal arch", + "The first and second pharyngeal arches", + "The second pharyngeal arch", + "The second and third pharyngeal arches", + ] + response = medqa_assistant.predict( + query="What is the embryological origin of the hyoid bone?", + options=options, + ) + assert response.response.answer in options + assert response.response.answer.lower() == options[2].lower() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..d607031a4a4253dd1698ac47327d1e5f5fb20fb2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--model-name", + action="store", + default="gemini-1.5-flash", + help="Model name to use for evaluation", + ) + + +def pytest_configure(config): + # Add model_name to pytest namespace for access in tests + config.addinivalue_line( + "markers", "model_name: mark test to run with specific model name" + ) + + +@pytest.fixture +def model_name(request): + return request.config.getoption("--model-name") diff --git a/tests/document_loader/image_loader.py b/tests/document_loader/image_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d10545713e7f0850b7265f8f06bbd901c2373e --- /dev/null +++ b/tests/document_loader/image_loader.py @@ -0,0 +1,59 @@ +import asyncio + +from medrag_multi_modal.document_loader.image_loader import ( + FitzPILImageLoader, + PDF2ImageLoader, + PDFPlumberImageLoader, + PyMuPDFImageLoader, +) + +URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" +COLUMN_NAMES = ["page_image", "page_figure_images", "document_name", "page_idx"] + + +def test_fitzpil_img_loader(): + loader = FitzPILImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + assert dataset.num_rows == 5 + assert dataset.column_names == COLUMN_NAMES + loader.cleanup_image_dir() + + +def test_pdf2image_img_loader(): + loader = PDF2ImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + assert dataset.num_rows == 5 + assert 
dataset.column_names == COLUMN_NAMES + loader.cleanup_image_dir() + + +def test_pdfplumber_img_loader(): + loader = PDFPlumberImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + assert dataset.num_rows == 5 + assert dataset.column_names == COLUMN_NAMES + loader.cleanup_image_dir() + + +def test_pymupdf_img_loader(): + loader = PyMuPDFImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + assert dataset.num_rows == 5 + assert dataset.column_names == COLUMN_NAMES + loader.cleanup_image_dir() diff --git a/tests/document_loader/text_loader.py b/tests/document_loader/text_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9987af0e9672a20b4bba87582ed0493416634670 --- /dev/null +++ b/tests/document_loader/text_loader.py @@ -0,0 +1,50 @@ +import asyncio + +from medrag_multi_modal.document_loader import ( + PDFPlumberTextLoader, + PyMuPDF4LLMTextLoader, + PyPDF2TextLoader, +) + +URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" +COLUMN_NAMES = [ + "text", + "page_idx", + "document_name", + "file_path", + "file_url", + "loader_name", +] + + +def test_pdfplumber_text_loader(): + loader = PDFPlumberTextLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) + assert dataset.num_rows == 6 + assert dataset.column_names == COLUMN_NAMES + + +def test_pymupdf_text_loader(): + loader = PyMuPDF4LLMTextLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) + assert dataset.num_rows == 6 + assert dataset.column_names == COLUMN_NAMES + + +def test_pypdf2_text_loader(): + loader = PyPDF2TextLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) + assert dataset.num_rows == 6 + assert dataset.column_names == COLUMN_NAMES diff --git a/tests/evals/test_assistant_mmlu_anatomy.py b/tests/evals/test_assistant_mmlu_anatomy.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccefcb0052f3e72af917748e600b07d9db5e019 --- /dev/null +++ b/tests/evals/test_assistant_mmlu_anatomy.py @@ -0,0 +1,147 @@ +import asyncio + +import weave + +from medrag_multi_modal.assistant import LLMClient, MedQAAssistant +from medrag_multi_modal.metrics import MMLUOptionAccuracy +from medrag_multi_modal.retrieval.text_retrieval import ( + BM25sRetriever, + ContrieverRetriever, + MedCPTRetriever, + NVEmbed2Retriever, +) + + +def test_mmlu_correctness_anatomy_bm25s(model_name: str): + weave.init("ml-colabs/medrag-multi-modal") + retriever = BM25sRetriever().from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s" + ) + llm_client = LLMClient(model_name=model_name) + medqa_assistant = MedQAAssistant( + llm_client=llm_client, + retriever=retriever, + top_k_chunks_for_query=5, + top_k_chunks_for_options=3, + ) + dataset = weave.ref("mmlu-anatomy-test:v2").get() + with weave.attributes( + {"retriever": retriever.__class__.__name__, "llm": llm_client.model_name} + ): + evaluation = weave.Evaluation( + dataset=dataset, + scorers=[MMLUOptionAccuracy()], + 
name="MMLU-Anatomy-BM25s", + ) + summary = asyncio.run( + evaluation.evaluate( + medqa_assistant, + __weave={"display_name": evaluation.name + ":" + llm_client.model_name}, + ) + ) + assert ( + summary["MMLUOptionAccuracy"]["correct"]["true_count"] + > summary["MMLUOptionAccuracy"]["correct"]["false_count"] + ) + + +def test_mmlu_correctness_anatomy_contriever(model_name: str): + weave.init("ml-colabs/medrag-multi-modal") + retriever = ContrieverRetriever().from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + llm_client = LLMClient(model_name=model_name) + medqa_assistant = MedQAAssistant( + llm_client=llm_client, + retriever=retriever, + top_k_chunks_for_query=5, + top_k_chunks_for_options=3, + ) + dataset = weave.ref("mmlu-anatomy-test:v2").get() + with weave.attributes( + {"retriever": retriever.__class__.__name__, "llm": llm_client.model_name} + ): + evaluation = weave.Evaluation( + dataset=dataset, + scorers=[MMLUOptionAccuracy()], + name="MMLU-Anatomy-Contriever", + ) + summary = asyncio.run( + evaluation.evaluate( + medqa_assistant, + __weave={"display_name": evaluation.name + ":" + llm_client.model_name}, + ) + ) + assert ( + summary["MMLUOptionAccuracy"]["correct"]["true_count"] + > summary["MMLUOptionAccuracy"]["correct"]["false_count"] + ) + + +def test_mmlu_correctness_anatomy_medcpt(model_name: str): + weave.init("ml-colabs/medrag-multi-modal") + retriever = MedCPTRetriever().from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + llm_client = LLMClient(model_name=model_name) + medqa_assistant = MedQAAssistant( + llm_client=llm_client, + retriever=retriever, + top_k_chunks_for_query=5, + top_k_chunks_for_options=3, + ) + dataset = weave.ref("mmlu-anatomy-test:v2").get() + with weave.attributes( + {"retriever": retriever.__class__.__name__, "llm": llm_client.model_name} + ): + evaluation = weave.Evaluation( + dataset=dataset, + scorers=[MMLUOptionAccuracy()], + name="MMLU-Anatomy-MedCPT", + ) + summary = asyncio.run( + evaluation.evaluate( + medqa_assistant, + __weave={"display_name": evaluation.name + ":" + llm_client.model_name}, + ) + ) + assert ( + summary["MMLUOptionAccuracy"]["correct"]["true_count"] + > summary["MMLUOptionAccuracy"]["correct"]["false_count"] + ) + + +def test_mmlu_correctness_anatomy_nvembed2(model_name: str): + weave.init("ml-colabs/medrag-multi-modal") + retriever = NVEmbed2Retriever().from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + llm_client = LLMClient(model_name=model_name) + medqa_assistant = MedQAAssistant( + llm_client=llm_client, + retriever=retriever, + top_k_chunks_for_query=5, + top_k_chunks_for_options=3, + ) + dataset = weave.ref("mmlu-anatomy-test:v2").get() + with weave.attributes( + {"retriever": retriever.__class__.__name__, "llm": llm_client.model_name} + ): + evaluation = weave.Evaluation( + dataset=dataset, + scorers=[MMLUOptionAccuracy()], + name="MMLU-Anatomy-NVEmbed2", + ) + summary = asyncio.run( + evaluation.evaluate( + medqa_assistant, + __weave={"display_name": evaluation.name + ":" + llm_client.model_name}, + ) + ) + assert ( + summary["MMLUOptionAccuracy"]["correct"]["true_count"] + > summary["MMLUOptionAccuracy"]["correct"]["false_count"] + ) diff --git a/tests/retrieval/test_bm25s.py b/tests/retrieval/test_bm25s.py new file mode 100644 index 
0000000000000000000000000000000000000000..32b9d1b08234dcbae5f1ff3d47fc704873e15d20 --- /dev/null +++ b/tests/retrieval/test_bm25s.py @@ -0,0 +1,14 @@ +from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + +def test_bm25s_retriever(): + retriever = BM25sRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index" + ) + retrieved_chunks = retriever.predict(query="What are Ribosomes?", top_k=2) + assert len(retrieved_chunks) == 2 + for chunk in retrieved_chunks: + assert "score" in chunk + assert "text" in chunk + assert chunk["score"] > 0 + assert "ribosomes" in chunk["text"].lower() diff --git a/tests/retrieval/test_contriever.py b/tests/retrieval/test_contriever.py new file mode 100644 index 0000000000000000000000000000000000000000..433fc2b2bb25b773a49de04880004954208368e5 --- /dev/null +++ b/tests/retrieval/test_contriever.py @@ -0,0 +1,15 @@ +from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever + + +def test_contriever_retriever(): + retriever = ContrieverRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-contriever", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + ) + retrieved_chunks = retriever.retrieve(query="What are Ribosomes?") + assert len(retrieved_chunks) == 2 + for chunk in retrieved_chunks: + assert "score" in chunk + assert "text" in chunk + assert chunk["score"] > 0 + assert "ribosomes" in chunk["text"].lower() diff --git a/tests/retrieval/test_medcpt.py b/tests/retrieval/test_medcpt.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bfe8f55f4200f8b16af68da33d7733305c3482 --- /dev/null +++ b/tests/retrieval/test_medcpt.py @@ -0,0 +1,15 @@ +from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever + + +def test_medcpt_retriever(): + retriever = MedCPTRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-medcpt", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + ) + retrieved_chunks = retriever.retrieve(query="What are Ribosomes?") + assert len(retrieved_chunks) == 2 + for chunk in retrieved_chunks: + assert "score" in chunk + assert "text" in chunk + assert chunk["score"] > 0 + assert "ribosomes" in chunk["text"].lower() diff --git a/tests/retrieval/test_nv_embed.py b/tests/retrieval/test_nv_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..7387b0d3b699281bf69479dd7f6435965d6cdebe --- /dev/null +++ b/tests/retrieval/test_nv_embed.py @@ -0,0 +1,15 @@ +from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever + + +def test_nvembed2_retriever(): + retriever = NVEmbed2Retriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-nvembed2", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + ) + retrieved_chunks = retriever.retrieve(query="What are Ribosomes?") + assert len(retrieved_chunks) == 2 + for chunk in retrieved_chunks: + assert "score" in chunk + assert "text" in chunk + assert chunk["score"] > 0 + assert "ribosomes" in chunk["text"].lower() diff --git a/tests/semantic_chunking.py b/tests/semantic_chunking.py new file mode 100644 index 0000000000000000000000000000000000000000..029b6a1d841896898d9e325cb896f8fa468a3b3d --- /dev/null +++ b/tests/semantic_chunking.py @@ -0,0 +1,16 @@ +from medrag_multi_modal.semantic_chunking import SemanticChunker + + +def test_semantic_chunking(): + chunker = SemanticChunker(chunk_size=256) + dataset = chunker.chunk(document_dataset="geekyrakshit/grays-anatomy-test") + assert dataset.num_rows == 49 + assert
dataset.column_names == [ + "document_idx", + "text", + "page_idx", + "document_name", + "file_path", + "file_url", + "loader_name", + ]
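The new `SemanticChunker.chunk` introduced in this diff also accepts an in-memory `datasets.Dataset` and simply returns the chunked dataset when no `chunk_dataset_repo_id` is given. The sketch below is illustrative only and is not part of the patch; the toy rows and column names are assumptions chosen to mirror the text-loader output.

```python
from datasets import Dataset

from medrag_multi_modal.semantic_chunking import SemanticChunker

# Build a tiny in-memory corpus whose columns mimic the text loaders
# (a "text" column plus per-page metadata). These rows are synthetic.
documents = Dataset.from_list(
    [
        {
            "text": "The ribosome is the site of protein synthesis. " * 50,
            "document_name": "demo-document",
            "page_idx": 1,
        }
    ]
)

# Chunk locally; without `chunk_dataset_repo_id` nothing is pushed to the Hub
# and the chunked dataset is returned directly.
chunker = SemanticChunker(chunk_size=256)
chunks = chunker.chunk(document_dataset=documents)
print(chunks.num_rows, chunks.column_names)
```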
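The `save_to_huggingface` and `fetch_from_huggingface` helpers added to `medrag_multi_modal/utils.py` are what the retrievers use to round-trip their indices through the Hub. A minimal sketch of that round trip, under assumptions not stated in the diff: a Hugging Face token is already configured, and `your-username/demo-vector-index` is a placeholder repository ID.

```python
import os

import safetensors.torch
import torch

from medrag_multi_modal.utils import fetch_from_huggingface, save_to_huggingface

# Save a toy tensor index locally in the same layout the retrievers use.
index_save_dir = os.path.join(".huggingface", "demo-vector-index")
os.makedirs(index_save_dir, exist_ok=True)
safetensors.torch.save_file(
    {"vector_index": torch.randn(8, 4)},
    os.path.join(index_save_dir, "vector_index.safetensors"),
)

# Push the directory to a (placeholder) model repository on the Hub.
save_to_huggingface(
    "your-username/demo-vector-index",
    index_save_dir,
    commit_message="add: demo vector index",
)

# Later, pull the snapshot back into a local directory and reload the tensor.
index_dir = fetch_from_huggingface("your-username/demo-vector-index", ".huggingface")
with safetensors.torch.safe_open(
    os.path.join(index_dir, "vector_index.safetensors"), framework="pt"
) as f:
    vector_index = f.get_tensor("vector_index")
print(vector_index.shape)
```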
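The `--model-name` option registered in `tests/conftest.py` feeds the `model_name` fixture consumed by the MMLU evaluation tests. As a rough illustration (not part of the patch), the suite could be driven programmatically as shown below; the chosen model name is only an example, and the tests still require the relevant API keys and Weave access.

```python
import pytest

# Run the anatomy evaluation tests against a specific model via the custom
# --model-name option defined in tests/conftest.py.
exit_code = pytest.main(
    [
        "tests/evals/test_assistant_mmlu_anatomy.py",
        "--model-name",
        "gemini-1.5-flash",
    ]
)
raise SystemExit(exit_code)
```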