Spaces:
Running
Running
geekyrakshit
commited on
add: files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +7 -7
- app.py +114 -0
- medrag_multi_modal/__init__.py +0 -0
- medrag_multi_modal/__pycache__/__init__.cpython-310.pyc +0 -0
- medrag_multi_modal/__pycache__/__init__.cpython-39.pyc +0 -0
- medrag_multi_modal/__pycache__/cli.cpython-310.pyc +0 -0
- medrag_multi_modal/__pycache__/utils.cpython-310.pyc +0 -0
- medrag_multi_modal/__pycache__/utils.cpython-39.pyc +0 -0
- medrag_multi_modal/assistant/__init__.py +5 -0
- medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc +0 -0
- medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc +0 -0
- medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc +0 -0
- medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc +0 -0
- medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc +0 -0
- medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc +0 -0
- medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc +0 -0
- medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc +0 -0
- medrag_multi_modal/assistant/figure_annotation.py +147 -0
- medrag_multi_modal/assistant/llm_client.py +245 -0
- medrag_multi_modal/assistant/medqa_assistant.py +174 -0
- medrag_multi_modal/assistant/schema.py +27 -0
- medrag_multi_modal/cli.py +68 -0
- medrag_multi_modal/document_loader/__init__.py +25 -0
- medrag_multi_modal/document_loader/image_loader/__init__.py +13 -0
- medrag_multi_modal/document_loader/image_loader/base_img_loader.py +180 -0
- medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py +127 -0
- medrag_multi_modal/document_loader/image_loader/marker_img_loader.py +131 -0
- medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py +83 -0
- medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py +101 -0
- medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py +124 -0
- medrag_multi_modal/document_loader/text_loader/__init__.py +11 -0
- medrag_multi_modal/document_loader/text_loader/base_text_loader.py +185 -0
- medrag_multi_modal/document_loader/text_loader/marker_text_loader.py +89 -0
- medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py +76 -0
- medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py +73 -0
- medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py +77 -0
- medrag_multi_modal/metrics/__init__.py +3 -0
- medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
- medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc +0 -0
- medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc +0 -0
- medrag_multi_modal/metrics/base.py +108 -0
- medrag_multi_modal/metrics/mmlu.py +24 -0
- medrag_multi_modal/retrieval/__init__.py +3 -0
- medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc +0 -0
- medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc +0 -0
- medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc +0 -0
- medrag_multi_modal/retrieval/colpali_retrieval.py +255 -0
- medrag_multi_modal/retrieval/common.py +21 -0
- medrag_multi_modal/retrieval/text_retrieval/__init__.py +11 -0
- medrag_multi_modal/retrieval/text_retrieval/__pycache__/__init__.cpython-310.pyc +0 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
short_description: Multi-modal assistant for medical professionals
|
11 |
---
|
|
|
12 |
|
13 |
-
|
|
|
1 |
---
|
2 |
+
title: MedRAG Multi-Modal
|
3 |
+
emoji: 🩺
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: pink
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: "1.39.0"
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
+
# MedRAG Multi-Modal
|
12 |
|
13 |
+
Multi-modal RAG for medical docmain.
|
app.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
|
4 |
+
from medrag_multi_modal.retrieval.text_retrieval import (
|
5 |
+
BM25sRetriever,
|
6 |
+
ContrieverRetriever,
|
7 |
+
MedCPTRetriever,
|
8 |
+
NVEmbed2Retriever,
|
9 |
+
)
|
10 |
+
|
11 |
+
# Define constants
|
12 |
+
ALL_AVAILABLE_MODELS = [
|
13 |
+
"gemini-1.5-flash-latest",
|
14 |
+
"gemini-1.5-pro-latest",
|
15 |
+
"gpt-4o",
|
16 |
+
"gpt-4o-mini",
|
17 |
+
]
|
18 |
+
|
19 |
+
# Sidebar for configuration settings
|
20 |
+
st.sidebar.title("Configuration Settings")
|
21 |
+
project_name = st.sidebar.text_input(
|
22 |
+
label="Project Name",
|
23 |
+
value="ml-colabs/medrag-multi-modal",
|
24 |
+
placeholder="wandb project name",
|
25 |
+
help="format: wandb_username/wandb_project_name",
|
26 |
+
)
|
27 |
+
chunk_dataset_id = st.sidebar.selectbox(
|
28 |
+
label="Chunk Dataset ID",
|
29 |
+
options=["ashwiniai/medrag-text-corpus-chunks"],
|
30 |
+
)
|
31 |
+
llm_model = st.sidebar.selectbox(
|
32 |
+
label="LLM Model",
|
33 |
+
options=ALL_AVAILABLE_MODELS,
|
34 |
+
)
|
35 |
+
top_k_chunks_for_query = st.sidebar.slider(
|
36 |
+
label="Top K Chunks for Query",
|
37 |
+
min_value=1,
|
38 |
+
max_value=20,
|
39 |
+
value=5,
|
40 |
+
)
|
41 |
+
top_k_chunks_for_options = st.sidebar.slider(
|
42 |
+
label="Top K Chunks for Options",
|
43 |
+
min_value=1,
|
44 |
+
max_value=20,
|
45 |
+
value=3,
|
46 |
+
)
|
47 |
+
rely_only_on_context = st.sidebar.checkbox(
|
48 |
+
label="Rely Only on Context",
|
49 |
+
value=False,
|
50 |
+
)
|
51 |
+
retriever_type = st.sidebar.selectbox(
|
52 |
+
label="Retriever Type",
|
53 |
+
options=[
|
54 |
+
"",
|
55 |
+
"BM25S",
|
56 |
+
"Contriever",
|
57 |
+
"MedCPT",
|
58 |
+
"NV-Embed-v2",
|
59 |
+
],
|
60 |
+
)
|
61 |
+
|
62 |
+
if retriever_type != "":
|
63 |
+
|
64 |
+
llm_model = LLMClient(model_name=llm_model)
|
65 |
+
|
66 |
+
retriever = None
|
67 |
+
|
68 |
+
if retriever_type == "BM25S":
|
69 |
+
retriever = BM25sRetriever.from_index(
|
70 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
|
71 |
+
)
|
72 |
+
elif retriever_type == "Contriever":
|
73 |
+
retriever = ContrieverRetriever.from_index(
|
74 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
|
75 |
+
chunk_dataset_id=chunk_dataset_id,
|
76 |
+
)
|
77 |
+
elif retriever_type == "MedCPT":
|
78 |
+
retriever = MedCPTRetriever.from_index(
|
79 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
|
80 |
+
chunk_dataset_id=chunk_dataset_id,
|
81 |
+
)
|
82 |
+
elif retriever_type == "NV-Embed-v2":
|
83 |
+
retriever = NVEmbed2Retriever.from_index(
|
84 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
|
85 |
+
chunk_dataset_id=chunk_dataset_id,
|
86 |
+
)
|
87 |
+
|
88 |
+
medqa_assistant = MedQAAssistant(
|
89 |
+
llm_client=llm_model,
|
90 |
+
retriever=retriever,
|
91 |
+
top_k_chunks_for_query=top_k_chunks_for_query,
|
92 |
+
top_k_chunks_for_options=top_k_chunks_for_options,
|
93 |
+
)
|
94 |
+
|
95 |
+
with st.chat_message("assistant"):
|
96 |
+
st.markdown(
|
97 |
+
"""
|
98 |
+
Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
|
99 |
+
I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.
|
100 |
+
|
101 |
+
**Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
|
102 |
+
Please consult a medical professional for any medical advice.
|
103 |
+
|
104 |
+
In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
|
105 |
+
""",
|
106 |
+
unsafe_allow_html=True,
|
107 |
+
)
|
108 |
+
query = st.chat_input("Enter your question here")
|
109 |
+
if query:
|
110 |
+
with st.chat_message("user"):
|
111 |
+
st.markdown(query)
|
112 |
+
response = medqa_assistant.predict(query=query)
|
113 |
+
with st.chat_message("assistant"):
|
114 |
+
st.markdown(response.response)
|
medrag_multi_modal/__init__.py
ADDED
File without changes
|
medrag_multi_modal/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (186 Bytes). View file
|
|
medrag_multi_modal/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (184 Bytes). View file
|
|
medrag_multi_modal/__pycache__/cli.cpython-310.pyc
ADDED
Binary file (1.5 kB). View file
|
|
medrag_multi_modal/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.49 kB). View file
|
|
medrag_multi_modal/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (2.47 kB). View file
|
|
medrag_multi_modal/assistant/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .figure_annotation import FigureAnnotatorFromPageImage
|
2 |
+
from .llm_client import ClientType, LLMClient
|
3 |
+
from .medqa_assistant import MedQAAssistant
|
4 |
+
|
5 |
+
__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
|
medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (426 Bytes). View file
|
|
medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (424 Bytes). View file
|
|
medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc
ADDED
Binary file (6.68 kB). View file
|
|
medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc
ADDED
Binary file (6.67 kB). View file
|
|
medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc
ADDED
Binary file (7.3 kB). View file
|
|
medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc
ADDED
Binary file (7.06 kB). View file
|
|
medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc
ADDED
Binary file (7.45 kB). View file
|
|
medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc
ADDED
Binary file (1.27 kB). View file
|
|
medrag_multi_modal/assistant/figure_annotation.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
from typing import Optional, Union
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import weave
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from medrag_multi_modal.assistant.llm_client import LLMClient
|
10 |
+
from medrag_multi_modal.assistant.schema import FigureAnnotations
|
11 |
+
from medrag_multi_modal.utils import get_wandb_artifact, read_jsonl_file
|
12 |
+
|
13 |
+
|
14 |
+
class FigureAnnotatorFromPageImage(weave.Model):
|
15 |
+
"""
|
16 |
+
`FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
|
17 |
+
figures from a page image of a scientific textbook.
|
18 |
+
|
19 |
+
!!! example "Example Usage"
|
20 |
+
```python
|
21 |
+
import weave
|
22 |
+
from dotenv import load_dotenv
|
23 |
+
|
24 |
+
from medrag_multi_modal.assistant import (
|
25 |
+
FigureAnnotatorFromPageImage, LLMClient
|
26 |
+
)
|
27 |
+
|
28 |
+
load_dotenv()
|
29 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
30 |
+
figure_annotator = FigureAnnotatorFromPageImage(
|
31 |
+
figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
|
32 |
+
structured_output_llm_client=LLMClient(model_name="gpt-4o"),
|
33 |
+
image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
|
34 |
+
)
|
35 |
+
annotations = figure_annotator.predict(page_idx=34)
|
36 |
+
```
|
37 |
+
|
38 |
+
Args:
|
39 |
+
figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
|
40 |
+
from the page image.
|
41 |
+
structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
|
42 |
+
annotations into a structured format.
|
43 |
+
image_artifact_address (Optional[str]): The address of the image artifact containing the
|
44 |
+
page images.
|
45 |
+
"""
|
46 |
+
|
47 |
+
figure_extraction_llm_client: LLMClient
|
48 |
+
structured_output_llm_client: LLMClient
|
49 |
+
_artifact_dir: str
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
figure_extraction_llm_client: LLMClient,
|
54 |
+
structured_output_llm_client: LLMClient,
|
55 |
+
image_artifact_address: Optional[str] = None,
|
56 |
+
):
|
57 |
+
super().__init__(
|
58 |
+
figure_extraction_llm_client=figure_extraction_llm_client,
|
59 |
+
structured_output_llm_client=structured_output_llm_client,
|
60 |
+
)
|
61 |
+
self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
|
62 |
+
|
63 |
+
@weave.op()
|
64 |
+
def annotate_figures(
|
65 |
+
self, page_image: Image.Image
|
66 |
+
) -> dict[str, Union[Image.Image, str]]:
|
67 |
+
annotation = self.figure_extraction_llm_client.predict(
|
68 |
+
system_prompt="""
|
69 |
+
You are an expert in the domain of scientific textbooks, especially medical texts.
|
70 |
+
You are presented with a page from a scientific textbook from the domain of biology, specifically anatomy.
|
71 |
+
You are to first identify all the figures in the page image, which could be images or biological diagrams, charts, graphs, etc.
|
72 |
+
Then you are to identify the figure IDs associated with each figure in the page image.
|
73 |
+
Then, you are to extract only the exact figure descriptions from the page image.
|
74 |
+
You need to output the figure IDs and figure descriptions only, in a structured manner as a JSON object.
|
75 |
+
|
76 |
+
Here are some clues you need to follow:
|
77 |
+
1. Figure IDs are unique identifiers for each figure in the page image.
|
78 |
+
2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure.
|
79 |
+
3. Figure IDs are in the form "Fig X.Y" where X and Y are integers. For example, 1.1, 1.2, 1.3, etc.
|
80 |
+
4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID.
|
81 |
+
5. The text in the page image is written in English and is present in a two-column format.
|
82 |
+
6. There is a clear distinction between the figure caption and the regular text in the page image in the form of extra white space.
|
83 |
+
You are to carefully identify all the figures in the page image.
|
84 |
+
7. There might be multiple figures or even no figures present in the page image. Sometimes the figures can be present side-by-side
|
85 |
+
or one above the other.
|
86 |
+
8. The figures may or may not have a distinct border against a white background.
|
87 |
+
10. You are not supposed to alter the figure description in any way present in the page image and you are to extract it as is.
|
88 |
+
""",
|
89 |
+
user_prompt=[page_image],
|
90 |
+
)
|
91 |
+
return {"page_image": page_image, "annotations": annotation}
|
92 |
+
|
93 |
+
@weave.op
|
94 |
+
def extract_structured_output(self, annotations: str) -> FigureAnnotations:
|
95 |
+
return self.structured_output_llm_client.predict(
|
96 |
+
system_prompt="You are suppossed to extract a list of figure annotations consisting of figure IDs and corresponding figure descriptions.",
|
97 |
+
user_prompt=[annotations],
|
98 |
+
schema=FigureAnnotations,
|
99 |
+
)
|
100 |
+
|
101 |
+
@weave.op()
|
102 |
+
def predict(self, page_idx: int) -> dict[int, list[FigureAnnotations]]:
|
103 |
+
"""
|
104 |
+
Predicts figure annotations for a specific page in a document.
|
105 |
+
|
106 |
+
This function retrieves the artifact directory from the given image artifact address,
|
107 |
+
reads the metadata from the 'metadata.jsonl' file, and iterates through the metadata
|
108 |
+
to find the specified page index. If the page index matches, it reads the page image
|
109 |
+
and associated figure images, and then uses the `annotate_figures` method to extract
|
110 |
+
figure annotations from the page image. The extracted annotations are then structured
|
111 |
+
using the `extract_structured_output` method and returned as a dictionary.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
page_idx (int): The index of the page to annotate.
|
115 |
+
image_artifact_address (str): The address of the image artifact containing the
|
116 |
+
page images.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
dict: A dictionary containing the page index as the key and the extracted figure
|
120 |
+
annotations as the value.
|
121 |
+
"""
|
122 |
+
|
123 |
+
metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
|
124 |
+
annotations = {}
|
125 |
+
for item in metadata:
|
126 |
+
if item["page_idx"] == page_idx:
|
127 |
+
page_image_file = os.path.join(
|
128 |
+
self._artifact_dir, f"page{item['page_idx']}.png"
|
129 |
+
)
|
130 |
+
figure_image_files = glob(
|
131 |
+
os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
|
132 |
+
)
|
133 |
+
if len(figure_image_files) > 0:
|
134 |
+
page_image = cv2.imread(page_image_file)
|
135 |
+
page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
|
136 |
+
page_image = Image.fromarray(page_image)
|
137 |
+
figure_extracted_annotations = self.annotate_figures(
|
138 |
+
page_image=page_image
|
139 |
+
)
|
140 |
+
figure_extracted_annotations = self.extract_structured_output(
|
141 |
+
figure_extracted_annotations["annotations"]
|
142 |
+
).model_dump()
|
143 |
+
annotations[item["page_idx"]] = figure_extracted_annotations[
|
144 |
+
"annotations"
|
145 |
+
]
|
146 |
+
break
|
147 |
+
return annotations
|
medrag_multi_modal/assistant/llm_client.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from enum import Enum
|
4 |
+
from typing import Any, Optional, Union
|
5 |
+
|
6 |
+
import instructor
|
7 |
+
import weave
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from ..utils import base64_encode_image
|
11 |
+
|
12 |
+
|
13 |
+
class ClientType(str, Enum):
|
14 |
+
GEMINI = "gemini"
|
15 |
+
MISTRAL = "mistral"
|
16 |
+
OPENAI = "openai"
|
17 |
+
|
18 |
+
|
19 |
+
GOOGLE_MODELS = [
|
20 |
+
"gemini-1.0-pro-latest",
|
21 |
+
"gemini-1.0-pro",
|
22 |
+
"gemini-pro",
|
23 |
+
"gemini-1.0-pro-001",
|
24 |
+
"gemini-1.0-pro-vision-latest",
|
25 |
+
"gemini-pro-vision",
|
26 |
+
"gemini-1.5-pro-latest",
|
27 |
+
"gemini-1.5-pro-001",
|
28 |
+
"gemini-1.5-pro-002",
|
29 |
+
"gemini-1.5-pro",
|
30 |
+
"gemini-1.5-pro-exp-0801",
|
31 |
+
"gemini-1.5-pro-exp-0827",
|
32 |
+
"gemini-1.5-flash-latest",
|
33 |
+
"gemini-1.5-flash-001",
|
34 |
+
"gemini-1.5-flash-001-tuning",
|
35 |
+
"gemini-1.5-flash",
|
36 |
+
"gemini-1.5-flash-exp-0827",
|
37 |
+
"gemini-1.5-flash-002",
|
38 |
+
"gemini-1.5-flash-8b",
|
39 |
+
"gemini-1.5-flash-8b-001",
|
40 |
+
"gemini-1.5-flash-8b-latest",
|
41 |
+
"gemini-1.5-flash-8b-exp-0827",
|
42 |
+
"gemini-1.5-flash-8b-exp-0924",
|
43 |
+
]
|
44 |
+
|
45 |
+
MISTRAL_MODELS = [
|
46 |
+
"ministral-3b-latest",
|
47 |
+
"ministral-8b-latest",
|
48 |
+
"mistral-large-latest",
|
49 |
+
"mistral-small-latest",
|
50 |
+
"codestral-latest",
|
51 |
+
"pixtral-12b-2409",
|
52 |
+
"open-mistral-nemo",
|
53 |
+
"open-codestral-mamba",
|
54 |
+
"open-mistral-7b",
|
55 |
+
"open-mixtral-8x7b",
|
56 |
+
"open-mixtral-8x22b",
|
57 |
+
]
|
58 |
+
|
59 |
+
OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]
|
60 |
+
|
61 |
+
|
62 |
+
class LLMClient(weave.Model):
|
63 |
+
"""
|
64 |
+
LLMClient is a class that interfaces with different large language model (LLM) providers
|
65 |
+
such as Google Gemini, Mistral, and OpenAI. It abstracts the complexity of interacting with
|
66 |
+
these different APIs and provides a unified interface for making predictions.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
model_name (str): The name of the model to be used for predictions.
|
70 |
+
client_type (Optional[ClientType]): The type of client (e.g., GEMINI, MISTRAL, OPENAI).
|
71 |
+
If not provided, it is inferred from the model_name.
|
72 |
+
"""
|
73 |
+
|
74 |
+
model_name: str
|
75 |
+
client_type: Optional[ClientType]
|
76 |
+
|
77 |
+
def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
|
78 |
+
if client_type is None:
|
79 |
+
if model_name in GOOGLE_MODELS:
|
80 |
+
client_type = ClientType.GEMINI
|
81 |
+
elif model_name in MISTRAL_MODELS:
|
82 |
+
client_type = ClientType.MISTRAL
|
83 |
+
elif model_name in OPENAI_MODELS:
|
84 |
+
client_type = ClientType.OPENAI
|
85 |
+
else:
|
86 |
+
raise ValueError(f"Invalid model name: {model_name}")
|
87 |
+
super().__init__(model_name=model_name, client_type=client_type)
|
88 |
+
|
89 |
+
@weave.op()
|
90 |
+
def execute_gemini_sdk(
|
91 |
+
self,
|
92 |
+
user_prompt: Union[str, list[str]],
|
93 |
+
system_prompt: Optional[Union[str, list[str]]] = None,
|
94 |
+
schema: Optional[Any] = None,
|
95 |
+
) -> Union[str, Any]:
|
96 |
+
import google.generativeai as genai
|
97 |
+
from google.generativeai.types import HarmBlockThreshold, HarmCategory
|
98 |
+
|
99 |
+
system_prompt = (
|
100 |
+
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
101 |
+
)
|
102 |
+
user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
|
103 |
+
|
104 |
+
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
|
105 |
+
model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt)
|
106 |
+
generation_config = (
|
107 |
+
None
|
108 |
+
if schema is None
|
109 |
+
else genai.GenerationConfig(
|
110 |
+
response_mime_type="application/json", response_schema=schema
|
111 |
+
)
|
112 |
+
)
|
113 |
+
response = model.generate_content(
|
114 |
+
user_prompt,
|
115 |
+
generation_config=generation_config,
|
116 |
+
# This is necessary in order to answer questions about anatomy, sexual diseases,
|
117 |
+
# medical devices, medicines, etc.
|
118 |
+
safety_settings={
|
119 |
+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
120 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
121 |
+
},
|
122 |
+
)
|
123 |
+
return response.text if schema is None else json.loads(response.text)
|
124 |
+
|
125 |
+
@weave.op()
|
126 |
+
def execute_mistral_sdk(
|
127 |
+
self,
|
128 |
+
user_prompt: Union[str, list[str]],
|
129 |
+
system_prompt: Optional[Union[str, list[str]]] = None,
|
130 |
+
schema: Optional[Any] = None,
|
131 |
+
) -> Union[str, Any]:
|
132 |
+
from mistralai import Mistral
|
133 |
+
|
134 |
+
system_prompt = (
|
135 |
+
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
136 |
+
)
|
137 |
+
user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
|
138 |
+
system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
|
139 |
+
user_messages = []
|
140 |
+
for prompt in user_prompt:
|
141 |
+
if isinstance(prompt, Image.Image):
|
142 |
+
user_messages.append(
|
143 |
+
{
|
144 |
+
"type": "image_url",
|
145 |
+
"image_url": base64_encode_image(prompt, "image/png"),
|
146 |
+
}
|
147 |
+
)
|
148 |
+
else:
|
149 |
+
user_messages.append({"type": "text", "text": prompt})
|
150 |
+
messages = [
|
151 |
+
{"role": "system", "content": system_messages},
|
152 |
+
{"role": "user", "content": user_messages},
|
153 |
+
]
|
154 |
+
|
155 |
+
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
|
156 |
+
client = instructor.from_mistral(client) if schema is not None else client
|
157 |
+
|
158 |
+
if schema is None:
|
159 |
+
raise NotImplementedError(
|
160 |
+
"Mistral does not support structured output using a schema"
|
161 |
+
)
|
162 |
+
else:
|
163 |
+
response = client.chat.complete(model=self.model_name, messages=messages)
|
164 |
+
return response.choices[0].message.content
|
165 |
+
|
166 |
+
@weave.op()
|
167 |
+
def execute_openai_sdk(
|
168 |
+
self,
|
169 |
+
user_prompt: Union[str, list[str]],
|
170 |
+
system_prompt: Optional[Union[str, list[str]]] = None,
|
171 |
+
schema: Optional[Any] = None,
|
172 |
+
) -> Union[str, Any]:
|
173 |
+
from openai import OpenAI
|
174 |
+
|
175 |
+
system_prompt = (
|
176 |
+
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
177 |
+
)
|
178 |
+
user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
|
179 |
+
|
180 |
+
system_messages = [
|
181 |
+
{"role": "system", "content": prompt} for prompt in system_prompt
|
182 |
+
]
|
183 |
+
user_messages = []
|
184 |
+
for prompt in user_prompt:
|
185 |
+
if isinstance(prompt, Image.Image):
|
186 |
+
user_messages.append(
|
187 |
+
{
|
188 |
+
"type": "image_url",
|
189 |
+
"image_url": {
|
190 |
+
"url": base64_encode_image(prompt, "image/png"),
|
191 |
+
},
|
192 |
+
},
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
user_messages.append({"type": "text", "text": prompt})
|
196 |
+
messages = system_messages + [{"role": "user", "content": user_messages}]
|
197 |
+
|
198 |
+
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
199 |
+
|
200 |
+
if schema is None:
|
201 |
+
completion = client.chat.completions.create(
|
202 |
+
model=self.model_name, messages=messages
|
203 |
+
)
|
204 |
+
return completion.choices[0].message.content
|
205 |
+
|
206 |
+
completion = weave.op()(client.beta.chat.completions.parse)(
|
207 |
+
model=self.model_name, messages=messages, response_format=schema
|
208 |
+
)
|
209 |
+
return completion.choices[0].message.parsed
|
210 |
+
|
211 |
+
@weave.op()
|
212 |
+
def predict(
|
213 |
+
self,
|
214 |
+
user_prompt: Union[str, list[str]],
|
215 |
+
system_prompt: Optional[Union[str, list[str]]] = None,
|
216 |
+
schema: Optional[Any] = None,
|
217 |
+
) -> Union[str, Any]:
|
218 |
+
"""
|
219 |
+
Predicts the response from a language model based on the provided prompts and schema.
|
220 |
+
|
221 |
+
This function determines the client type and calls the appropriate SDK execution function
|
222 |
+
to get the response from the language model. It supports multiple client types including
|
223 |
+
GEMINI, MISTRAL, and OPENAI. Depending on the client type, it calls the corresponding
|
224 |
+
execution function with the provided user and system prompts, and an optional schema.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
user_prompt (Union[str, list[str]]): The user prompt(s) to be sent to the language model.
|
228 |
+
system_prompt (Optional[Union[str, list[str]]]): The system prompt(s) to be sent to the language model.
|
229 |
+
schema (Optional[Any]): The schema to be used for parsing the response, if applicable.
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
Union[str, Any]: The response from the language model, which could be a string or any other type
|
233 |
+
depending on the schema provided.
|
234 |
+
|
235 |
+
Raises:
|
236 |
+
ValueError: If the client type is invalid.
|
237 |
+
"""
|
238 |
+
if self.client_type == ClientType.GEMINI:
|
239 |
+
return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
|
240 |
+
elif self.client_type == ClientType.MISTRAL:
|
241 |
+
return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
|
242 |
+
elif self.client_type == ClientType.OPENAI:
|
243 |
+
return self.execute_openai_sdk(user_prompt, system_prompt, schema)
|
244 |
+
else:
|
245 |
+
raise ValueError(f"Invalid client type: {self.client_type}")
|
medrag_multi_modal/assistant/medqa_assistant.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import weave
|
4 |
+
|
5 |
+
from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
|
6 |
+
from medrag_multi_modal.assistant.llm_client import LLMClient
|
7 |
+
from medrag_multi_modal.assistant.schema import (
|
8 |
+
MedQACitation,
|
9 |
+
MedQAMCQResponse,
|
10 |
+
MedQAResponse,
|
11 |
+
)
|
12 |
+
from medrag_multi_modal.retrieval.common import SimilarityMetric
|
13 |
+
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
|
14 |
+
|
15 |
+
|
16 |
+
class MedQAAssistant(weave.Model):
|
17 |
+
"""
|
18 |
+
`MedQAAssistant` is a class designed to assist with medical queries by leveraging a
|
19 |
+
language model client, a retriever model, and a figure annotator.
|
20 |
+
|
21 |
+
!!! example "Usage Example"
|
22 |
+
```python
|
23 |
+
import weave
|
24 |
+
from dotenv import load_dotenv
|
25 |
+
|
26 |
+
from medrag_multi_modal.assistant import (
|
27 |
+
FigureAnnotatorFromPageImage,
|
28 |
+
LLMClient,
|
29 |
+
MedQAAssistant,
|
30 |
+
)
|
31 |
+
from medrag_multi_modal.retrieval import MedCPTRetriever
|
32 |
+
|
33 |
+
load_dotenv()
|
34 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
35 |
+
|
36 |
+
llm_client = LLMClient(model_name="gemini-1.5-flash")
|
37 |
+
|
38 |
+
retriever=MedCPTRetriever.from_wandb_artifact(
|
39 |
+
chunk_dataset_name="grays-anatomy-chunks:v0",
|
40 |
+
index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
|
41 |
+
)
|
42 |
+
|
43 |
+
figure_annotator=FigureAnnotatorFromPageImage(
|
44 |
+
figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
|
45 |
+
structured_output_llm_client=LLMClient(model_name="gpt-4o"),
|
46 |
+
image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
|
47 |
+
)
|
48 |
+
medqa_assistant = MedQAAssistant(
|
49 |
+
llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
|
50 |
+
)
|
51 |
+
medqa_assistant.predict(query="What is ribosome?")
|
52 |
+
```
|
53 |
+
|
54 |
+
Args:
|
55 |
+
llm_client (LLMClient): The language model client used to generate responses.
|
56 |
+
retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
|
57 |
+
figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
|
58 |
+
top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query.
|
59 |
+
top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options.
|
60 |
+
retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
|
61 |
+
"""
|
62 |
+
|
63 |
+
llm_client: LLMClient
|
64 |
+
retriever: weave.Model
|
65 |
+
figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
|
66 |
+
top_k_chunks_for_query: int = 2
|
67 |
+
top_k_chunks_for_options: int = 2
|
68 |
+
rely_only_on_context: bool = True
|
69 |
+
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
|
70 |
+
|
71 |
+
@weave.op()
|
72 |
+
def retrieve_chunks_for_query(self, query: str) -> list[dict]:
|
73 |
+
retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
|
74 |
+
if not isinstance(self.retriever, BM25sRetriever):
|
75 |
+
retriever_kwargs["metric"] = self.retrieval_similarity_metric
|
76 |
+
return self.retriever.predict(query, **retriever_kwargs)
|
77 |
+
|
78 |
+
@weave.op()
|
79 |
+
def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
|
80 |
+
retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
|
81 |
+
if not isinstance(self.retriever, BM25sRetriever):
|
82 |
+
retriever_kwargs["metric"] = self.retrieval_similarity_metric
|
83 |
+
retrieved_chunks = []
|
84 |
+
for option in options:
|
85 |
+
retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
|
86 |
+
return retrieved_chunks
|
87 |
+
|
88 |
+
@weave.op()
|
89 |
+
def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
|
90 |
+
"""
|
91 |
+
Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
|
92 |
+
from a medical document and using a language model to generate the final response.
|
93 |
+
|
94 |
+
This function performs the following steps:
|
95 |
+
1. Retrieves relevant text chunks from the medical document based on the query and any provided options
|
96 |
+
using the retriever model.
|
97 |
+
2. Extracts the text and page indices from the retrieved chunks.
|
98 |
+
3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
|
99 |
+
4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
|
100 |
+
and figure descriptions.
|
101 |
+
5. Uses the language model client to generate a response based on the constructed prompts, either choosing
|
102 |
+
from provided options or generating a free-form response.
|
103 |
+
6. Returns the generated response, which includes the answer and explanation if options were provided.
|
104 |
+
|
105 |
+
The function can operate in two modes:
|
106 |
+
- Multiple choice: When options are provided, it selects the best answer from the options and explains the choice
|
107 |
+
- Free response: When no options are provided, it generates a comprehensive response based on the context
|
108 |
+
|
109 |
+
Args:
|
110 |
+
query (str): The medical query to be answered.
|
111 |
+
options (Optional[list[str]]): The list of options to choose from.
|
112 |
+
rely_only_on_context (bool): Whether to rely only on the context provided or not during response generation.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
MedQAResponse: The generated response to the query, including source information.
|
116 |
+
"""
|
117 |
+
retrieved_chunks = self.retrieve_chunks_for_query(query)
|
118 |
+
options = options or []
|
119 |
+
retrieved_chunks += self.retrieve_chunks_for_options(options)
|
120 |
+
|
121 |
+
retrieved_chunk_texts = []
|
122 |
+
page_indices = set()
|
123 |
+
for chunk in retrieved_chunks:
|
124 |
+
retrieved_chunk_texts.append(chunk["text"])
|
125 |
+
page_indices.add(int(chunk["page_idx"]))
|
126 |
+
|
127 |
+
figure_descriptions = []
|
128 |
+
if self.figure_annotator is not None:
|
129 |
+
for page_idx in page_indices:
|
130 |
+
figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
|
131 |
+
page_idx
|
132 |
+
]
|
133 |
+
figure_descriptions += [
|
134 |
+
item["figure_description"] for item in figure_annotations
|
135 |
+
]
|
136 |
+
|
137 |
+
system_prompt = """You are an expert in medical science. You are given a question
|
138 |
+
and a list of excerpts from various medical documents.
|
139 |
+
"""
|
140 |
+
query = f"""# Question
|
141 |
+
{query}
|
142 |
+
"""
|
143 |
+
|
144 |
+
if len(options) > 0:
|
145 |
+
system_prompt += """\nYou are also given a list of options to choose your answer from.
|
146 |
+
You are supposed to choose the best possible option based on the context provided. You should also
|
147 |
+
explain your answer to justify why you chose that option.
|
148 |
+
"""
|
149 |
+
query += "## Options\n"
|
150 |
+
for option in options:
|
151 |
+
query += f"- {option}\n"
|
152 |
+
else:
|
153 |
+
system_prompt += "\nYou are supposed to answer the question based on the context provided."
|
154 |
+
|
155 |
+
if self.rely_only_on_context:
|
156 |
+
system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
|
157 |
+
You are not allowed to use any external knowledge to answer the question.
|
158 |
+
"""
|
159 |
+
|
160 |
+
response = self.llm_client.predict(
|
161 |
+
system_prompt=system_prompt,
|
162 |
+
user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
|
163 |
+
schema=MedQAMCQResponse if len(options) > 0 else None,
|
164 |
+
)
|
165 |
+
|
166 |
+
# TODO: Add figure citations
|
167 |
+
# TODO: Add source document name from retrieved chunks as citations
|
168 |
+
citations = []
|
169 |
+
for page_idx in page_indices:
|
170 |
+
citations.append(
|
171 |
+
MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
|
172 |
+
)
|
173 |
+
|
174 |
+
return MedQAResponse(response=response, citations=citations)
|
medrag_multi_modal/assistant/schema.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class FigureAnnotation(BaseModel):
|
7 |
+
figure_id: str
|
8 |
+
figure_description: str
|
9 |
+
|
10 |
+
|
11 |
+
class FigureAnnotations(BaseModel):
|
12 |
+
annotations: list[FigureAnnotation]
|
13 |
+
|
14 |
+
|
15 |
+
class MedQAMCQResponse(BaseModel):
|
16 |
+
answer: str
|
17 |
+
explanation: str
|
18 |
+
|
19 |
+
|
20 |
+
class MedQACitation(BaseModel):
|
21 |
+
page_number: int
|
22 |
+
document_name: str
|
23 |
+
|
24 |
+
|
25 |
+
class MedQAResponse(BaseModel):
|
26 |
+
response: Union[str, MedQAMCQResponse]
|
27 |
+
citations: list[MedQACitation]
|
medrag_multi_modal/cli.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
import sys
|
5 |
+
|
6 |
+
|
7 |
+
def main():
|
8 |
+
parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI")
|
9 |
+
subparsers = parser.add_subparsers(dest="command", required=True)
|
10 |
+
|
11 |
+
# Run subcommand
|
12 |
+
run_parser = subparsers.add_parser("run", help="Run the Streamlit application")
|
13 |
+
run_parser.add_argument(
|
14 |
+
"--port", type=int, default=8501, help="Port to run Streamlit on"
|
15 |
+
)
|
16 |
+
|
17 |
+
# Evaluate subcommand
|
18 |
+
eval_parser = subparsers.add_parser("evaluate", help="Run evaluation tests")
|
19 |
+
eval_parser.add_argument(
|
20 |
+
"--test-file",
|
21 |
+
default=os.path.join("tests", "evals", "test_assistant_mmlu_anatomy.py"),
|
22 |
+
help="Path to test file",
|
23 |
+
)
|
24 |
+
eval_parser.add_argument(
|
25 |
+
"--test-case",
|
26 |
+
type=str,
|
27 |
+
help="Only run tests which match the given substring expression",
|
28 |
+
)
|
29 |
+
eval_parser.add_argument(
|
30 |
+
"--model-name",
|
31 |
+
type=str,
|
32 |
+
default="gemini-1.5-flash",
|
33 |
+
help="Model name to use for evaluation",
|
34 |
+
)
|
35 |
+
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
if args.command == "run":
|
39 |
+
subprocess.run(
|
40 |
+
[
|
41 |
+
sys.executable,
|
42 |
+
"-m",
|
43 |
+
"streamlit",
|
44 |
+
"run",
|
45 |
+
"app.py",
|
46 |
+
"--server.port",
|
47 |
+
str(args.port),
|
48 |
+
]
|
49 |
+
)
|
50 |
+
|
51 |
+
elif args.command == "evaluate":
|
52 |
+
test_file = (
|
53 |
+
args.test_file + "::" + args.test_case if args.test_case else args.test_file
|
54 |
+
)
|
55 |
+
cmd = [
|
56 |
+
sys.executable,
|
57 |
+
"-m",
|
58 |
+
"pytest",
|
59 |
+
"-s",
|
60 |
+
test_file,
|
61 |
+
"-v",
|
62 |
+
f"--model-name={args.model_name}",
|
63 |
+
]
|
64 |
+
subprocess.run(cmd)
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
main()
|
medrag_multi_modal/document_loader/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .image_loader import (
|
2 |
+
FitzPILImageLoader,
|
3 |
+
MarkerImageLoader,
|
4 |
+
PDF2ImageLoader,
|
5 |
+
PDFPlumberImageLoader,
|
6 |
+
PyMuPDFImageLoader,
|
7 |
+
)
|
8 |
+
from .text_loader import (
|
9 |
+
MarkerTextLoader,
|
10 |
+
PDFPlumberTextLoader,
|
11 |
+
PyMuPDF4LLMTextLoader,
|
12 |
+
PyPDF2TextLoader,
|
13 |
+
)
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
"PyMuPDF4LLMTextLoader",
|
17 |
+
"PyPDF2TextLoader",
|
18 |
+
"PDFPlumberTextLoader",
|
19 |
+
"MarkerTextLoader",
|
20 |
+
"PDF2ImageLoader",
|
21 |
+
"MarkerImageLoader",
|
22 |
+
"PDFPlumberImageLoader",
|
23 |
+
"PyMuPDFImageLoader",
|
24 |
+
"FitzPILImageLoader",
|
25 |
+
]
|
medrag_multi_modal/document_loader/image_loader/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .fitzpil_img_loader import FitzPILImageLoader
|
2 |
+
from .marker_img_loader import MarkerImageLoader
|
3 |
+
from .pdf2image_img_loader import PDF2ImageLoader
|
4 |
+
from .pdfplumber_img_loader import PDFPlumberImageLoader
|
5 |
+
from .pymupdf_img_loader import PyMuPDFImageLoader
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
"PDF2ImageLoader",
|
9 |
+
"MarkerImageLoader",
|
10 |
+
"PDFPlumberImageLoader",
|
11 |
+
"PyMuPDFImageLoader",
|
12 |
+
"FitzPILImageLoader",
|
13 |
+
]
|
medrag_multi_modal/document_loader/image_loader/base_img_loader.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
from abc import abstractmethod
|
4 |
+
from glob import glob
|
5 |
+
from typing import Dict, List, Optional
|
6 |
+
|
7 |
+
import huggingface_hub
|
8 |
+
import jsonlines
|
9 |
+
import rich
|
10 |
+
from datasets import (
|
11 |
+
Dataset,
|
12 |
+
Features,
|
13 |
+
Image,
|
14 |
+
Sequence,
|
15 |
+
Value,
|
16 |
+
concatenate_datasets,
|
17 |
+
load_dataset,
|
18 |
+
)
|
19 |
+
|
20 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
21 |
+
BaseTextLoader,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class BaseImageLoader(BaseTextLoader):
|
26 |
+
def __init__(self, url: str, document_name: str, document_file_path: str):
|
27 |
+
super().__init__(url, document_name, document_file_path)
|
28 |
+
|
29 |
+
@abstractmethod
|
30 |
+
async def extract_page_data(
|
31 |
+
self, page_idx: int, image_save_dir: str, **kwargs
|
32 |
+
) -> Dict[str, str]:
|
33 |
+
"""
|
34 |
+
Abstract method to process a single page of the PDF and extract the image data.
|
35 |
+
|
36 |
+
Overwrite this method in the subclass to provide the actual implementation and
|
37 |
+
processing logic for each page of the PDF using various PDF processing libraries.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
page_idx (int): The index of the page to process.
|
41 |
+
image_save_dir (str): The directory to save the extracted images.
|
42 |
+
**kwargs: Additional keyword arguments that may be used by underlying libraries.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Dict[str, str]: A dictionary containing the processed page data.
|
46 |
+
"""
|
47 |
+
pass
|
48 |
+
|
49 |
+
def save_as_dataset(
|
50 |
+
self,
|
51 |
+
start_page: int,
|
52 |
+
end_page: int,
|
53 |
+
image_save_dir: str,
|
54 |
+
dataset_repo_id: Optional[str] = None,
|
55 |
+
overwrite_dataset: bool = False,
|
56 |
+
):
|
57 |
+
features = Features(
|
58 |
+
{
|
59 |
+
"page_image": Image(decode=True),
|
60 |
+
"page_figure_images": Sequence(Image(decode=True)),
|
61 |
+
"document_name": Value(dtype="string"),
|
62 |
+
"page_idx": Value(dtype="int32"),
|
63 |
+
}
|
64 |
+
)
|
65 |
+
|
66 |
+
all_examples = []
|
67 |
+
for page_idx in range(start_page, end_page):
|
68 |
+
page_image_file_paths = glob(
|
69 |
+
os.path.join(image_save_dir, f"page{page_idx}*.png")
|
70 |
+
)
|
71 |
+
if len(page_image_file_paths) > 0:
|
72 |
+
page_image_path = page_image_file_paths[0]
|
73 |
+
figure_image_paths = [
|
74 |
+
image_file_path
|
75 |
+
for image_file_path in glob(
|
76 |
+
os.path.join(image_save_dir, f"page{page_idx}*_fig*.png")
|
77 |
+
)
|
78 |
+
]
|
79 |
+
|
80 |
+
example = {
|
81 |
+
"page_image": page_image_path,
|
82 |
+
"page_figure_images": figure_image_paths,
|
83 |
+
"document_name": self.document_name,
|
84 |
+
"page_idx": page_idx,
|
85 |
+
}
|
86 |
+
all_examples.append(example)
|
87 |
+
|
88 |
+
dataset = Dataset.from_list(all_examples, features=features)
|
89 |
+
|
90 |
+
if dataset_repo_id:
|
91 |
+
if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
|
92 |
+
if not overwrite_dataset:
|
93 |
+
dataset = concatenate_datasets(
|
94 |
+
[dataset, load_dataset(dataset_repo_id)["corpus"]]
|
95 |
+
)
|
96 |
+
|
97 |
+
dataset.push_to_hub(dataset_repo_id, split="corpus")
|
98 |
+
|
99 |
+
return dataset
|
100 |
+
|
101 |
+
def cleanup_image_dir(self, image_save_dir: str = "./images"):
|
102 |
+
for file in os.listdir(image_save_dir):
|
103 |
+
file_path = os.path.join(image_save_dir, file)
|
104 |
+
if os.path.isfile(file_path):
|
105 |
+
os.remove(file_path)
|
106 |
+
|
107 |
+
async def load_data(
|
108 |
+
self,
|
109 |
+
start_page: Optional[int] = None,
|
110 |
+
end_page: Optional[int] = None,
|
111 |
+
dataset_repo_id: Optional[str] = None,
|
112 |
+
overwrite_dataset: bool = False,
|
113 |
+
image_save_dir: str = "./images",
|
114 |
+
exclude_file_extensions: list[str] = [],
|
115 |
+
**kwargs,
|
116 |
+
) -> List[Dict[str, str]]:
|
117 |
+
"""
|
118 |
+
Asynchronously loads images from a PDF file specified by a URL or local file path.
|
119 |
+
The overrided processing abstract method then processes the images,
|
120 |
+
and optionally publishes it to a WandB artifact.
|
121 |
+
|
122 |
+
This function downloads a PDF from a given URL if it does not already exist locally,
|
123 |
+
reads the specified range of pages, scans each page's content to extract images, and
|
124 |
+
returns a list of Page objects containing the images and metadata.
|
125 |
+
|
126 |
+
It uses `PyPDF2` to calculate the number of pages in the PDF and the
|
127 |
+
overriden `extract_page_data` method provides the actual implementation to process
|
128 |
+
each page, extract the image content from the PDF, and convert it to png format.
|
129 |
+
It processes pages concurrently using `asyncio` for efficiency.
|
130 |
+
|
131 |
+
If a wandb_artifact_name is provided, the processed pages are published to a WandB artifact.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
start_page (Optional[int]): The starting page index (0-based) to process.
|
135 |
+
end_page (Optional[int]): The ending page index (0-based) to process.
|
136 |
+
dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
|
137 |
+
overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
|
138 |
+
image_save_dir (str): The directory to save the extracted images.
|
139 |
+
exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
|
140 |
+
**kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
Dataset: A HuggingFace dataset containing the processed pages.
|
144 |
+
|
145 |
+
Raises:
|
146 |
+
ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
|
147 |
+
"""
|
148 |
+
os.makedirs(image_save_dir, exist_ok=True)
|
149 |
+
start_page, end_page = self.get_page_indices(start_page, end_page)
|
150 |
+
pages = []
|
151 |
+
processed_pages_counter: int = 1
|
152 |
+
total_pages = end_page - start_page
|
153 |
+
|
154 |
+
async def process_page(page_idx):
|
155 |
+
nonlocal processed_pages_counter
|
156 |
+
page_data = await self.extract_page_data(page_idx, image_save_dir, **kwargs)
|
157 |
+
pages.append(page_data)
|
158 |
+
rich.print(
|
159 |
+
f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
|
160 |
+
)
|
161 |
+
processed_pages_counter += 1
|
162 |
+
|
163 |
+
tasks = [process_page(page_idx) for page_idx in range(start_page, end_page)]
|
164 |
+
for task in asyncio.as_completed(tasks):
|
165 |
+
await task
|
166 |
+
|
167 |
+
with jsonlines.open(
|
168 |
+
os.path.join(image_save_dir, "metadata.jsonl"), mode="w"
|
169 |
+
) as writer:
|
170 |
+
writer.write(pages)
|
171 |
+
|
172 |
+
for file in os.listdir(image_save_dir):
|
173 |
+
if file.endswith(tuple(exclude_file_extensions)):
|
174 |
+
os.remove(os.path.join(image_save_dir, file))
|
175 |
+
|
176 |
+
dataset = self.save_as_dataset(
|
177 |
+
start_page, end_page, image_save_dir, dataset_repo_id, overwrite_dataset
|
178 |
+
)
|
179 |
+
|
180 |
+
return dataset
|
medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
from typing import Any, Dict
|
4 |
+
|
5 |
+
import fitz
|
6 |
+
from pdf2image.pdf2image import convert_from_path
|
7 |
+
from PIL import Image, ImageOps, UnidentifiedImageError
|
8 |
+
|
9 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
10 |
+
BaseImageLoader,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class FitzPILImageLoader(BaseImageLoader):
|
15 |
+
"""
|
16 |
+
`FitzPILImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
|
17 |
+
loading of pages from a PDF file as images using the fitz and PIL libraries.
|
18 |
+
|
19 |
+
This class provides functionality to extract images from a PDF file using fitz and PIL libraries,
|
20 |
+
and optionally publish these images to a WandB artifact.
|
21 |
+
|
22 |
+
!!! example "Example Usage"
|
23 |
+
```python
|
24 |
+
import asyncio
|
25 |
+
|
26 |
+
from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
|
27 |
+
|
28 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
29 |
+
|
30 |
+
loader = FitzPILImageLoader(
|
31 |
+
url=URL,
|
32 |
+
document_name="Gray's Anatomy",
|
33 |
+
document_file_path="grays_anatomy.pdf",
|
34 |
+
)
|
35 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
36 |
+
```
|
37 |
+
|
38 |
+
Args:
|
39 |
+
url (str): The URL of the PDF document.
|
40 |
+
document_name (str): The name of the document.
|
41 |
+
document_file_path (str): The path to the PDF file.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, url: str, document_name: str, document_file_path: str):
|
45 |
+
super().__init__(url, document_name, document_file_path)
|
46 |
+
|
47 |
+
async def extract_page_data(
|
48 |
+
self, page_idx: int, image_save_dir: str, **kwargs
|
49 |
+
) -> Dict[str, Any]:
|
50 |
+
"""
|
51 |
+
Extracts a single page from the PDF as an image using fitz and PIL libraries.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
page_idx (int): The index of the page to process.
|
55 |
+
image_save_dir (str): The directory to save the extracted image.
|
56 |
+
**kwargs: Additional keyword arguments that may be used by fitz and PIL.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
Dict[str, Any]: A dictionary containing the processed page data.
|
60 |
+
The dictionary will have the following keys and values:
|
61 |
+
|
62 |
+
- "page_idx": (int) the index of the page.
|
63 |
+
- "document_name": (str) the name of the document.
|
64 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
65 |
+
- "file_url": (str) the URL of the PDF file.
|
66 |
+
- "image_file_paths": (list) the local file paths where the images are stored.
|
67 |
+
"""
|
68 |
+
image_file_paths = []
|
69 |
+
|
70 |
+
pdf_document = fitz.open(self.document_file_path)
|
71 |
+
page = pdf_document.load_page(page_idx)
|
72 |
+
|
73 |
+
images = page.get_images(full=True)
|
74 |
+
for img_idx, image in enumerate(images):
|
75 |
+
xref = image[0]
|
76 |
+
base_image = pdf_document.extract_image(xref)
|
77 |
+
image_bytes = base_image["image"]
|
78 |
+
image_ext = base_image["ext"]
|
79 |
+
|
80 |
+
try:
|
81 |
+
img = Image.open(io.BytesIO(image_bytes))
|
82 |
+
|
83 |
+
if img.mode in ["L"]:
|
84 |
+
# images in greyscale looks inverted, need to test on other PDFs
|
85 |
+
img = ImageOps.invert(img)
|
86 |
+
|
87 |
+
if img.mode == "CMYK":
|
88 |
+
img = img.convert("RGB")
|
89 |
+
|
90 |
+
if image_ext not in ["png", "jpg", "jpeg"]:
|
91 |
+
image_ext = "png"
|
92 |
+
image_file_name = f"page{page_idx}_fig{img_idx}.png"
|
93 |
+
image_file_path = os.path.join(image_save_dir, image_file_name)
|
94 |
+
|
95 |
+
img.save(image_file_path, format="PNG")
|
96 |
+
else:
|
97 |
+
image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}"
|
98 |
+
image_file_path = os.path.join(image_save_dir, image_file_name)
|
99 |
+
|
100 |
+
with open(image_file_path, "wb") as image_file:
|
101 |
+
image_file.write(image_bytes)
|
102 |
+
|
103 |
+
image_file_paths.append(image_file_path)
|
104 |
+
|
105 |
+
except (UnidentifiedImageError, OSError) as e:
|
106 |
+
print(
|
107 |
+
f"Skipping image at page {page_idx}, fig {img_idx} due to an error: {e}"
|
108 |
+
)
|
109 |
+
continue
|
110 |
+
|
111 |
+
pdf_document.close()
|
112 |
+
|
113 |
+
page_image = convert_from_path(
|
114 |
+
self.document_file_path,
|
115 |
+
first_page=page_idx + 1,
|
116 |
+
last_page=page_idx + 1,
|
117 |
+
**kwargs,
|
118 |
+
)[0]
|
119 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
120 |
+
|
121 |
+
return {
|
122 |
+
"page_idx": page_idx,
|
123 |
+
"document_name": self.document_name,
|
124 |
+
"file_path": self.document_file_path,
|
125 |
+
"file_url": self.url,
|
126 |
+
"image_file_paths": image_file_paths,
|
127 |
+
}
|
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, Coroutine, Dict, List
|
3 |
+
|
4 |
+
from marker.convert import convert_single_pdf
|
5 |
+
from marker.models import load_all_models
|
6 |
+
from pdf2image.pdf2image import convert_from_path
|
7 |
+
|
8 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
9 |
+
BaseImageLoader,
|
10 |
+
)
|
11 |
+
|
12 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
13 |
+
|
14 |
+
|
15 |
+
class MarkerImageLoader(BaseImageLoader):
|
16 |
+
"""
|
17 |
+
`MarkerImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
|
18 |
+
loading of pages from a PDF file as images using the marker library.
|
19 |
+
|
20 |
+
This class provides functionality to extract images from a PDF file using marker library,
|
21 |
+
and optionally publish these images to a WandB artifact.
|
22 |
+
|
23 |
+
!!! example "Example Usage"
|
24 |
+
```python
|
25 |
+
import asyncio
|
26 |
+
|
27 |
+
from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
|
28 |
+
|
29 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
30 |
+
|
31 |
+
loader = MarkerImageLoader(
|
32 |
+
url=URL,
|
33 |
+
document_name="Gray's Anatomy",
|
34 |
+
document_file_path="grays_anatomy.pdf",
|
35 |
+
)
|
36 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
37 |
+
```
|
38 |
+
|
39 |
+
Args:
|
40 |
+
url (str): The URL of the PDF document.
|
41 |
+
document_name (str): The name of the document.
|
42 |
+
document_file_path (str): The path to the PDF file.
|
43 |
+
save_page_image (bool): Whether to additionally save the image of the entire page.
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
url: str,
|
49 |
+
document_name: str,
|
50 |
+
document_file_path: str,
|
51 |
+
save_page_image: bool = False,
|
52 |
+
):
|
53 |
+
super().__init__(url, document_name, document_file_path)
|
54 |
+
self.save_page_image = save_page_image
|
55 |
+
self.model_lst = load_all_models()
|
56 |
+
|
57 |
+
async def extract_page_data(
|
58 |
+
self, page_idx: int, image_save_dir: str, **kwargs
|
59 |
+
) -> Dict[str, Any]:
|
60 |
+
"""
|
61 |
+
Extracts a single page from the PDF as an image using marker library.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
page_idx (int): The index of the page to process.
|
65 |
+
image_save_dir (str): The directory to save the extracted image.
|
66 |
+
**kwargs: Additional keyword arguments that may be used by marker.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
Dict[str, Any]: A dictionary containing the processed page data.
|
70 |
+
The dictionary will have the following keys and values:
|
71 |
+
|
72 |
+
- "page_idx": (int) the index of the page.
|
73 |
+
- "document_name": (str) the name of the document.
|
74 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
75 |
+
- "file_url": (str) the URL of the PDF file.
|
76 |
+
- "image_file_path": (str) the local file path where the image is stored.
|
77 |
+
"""
|
78 |
+
_, images, _ = convert_single_pdf(
|
79 |
+
self.document_file_path,
|
80 |
+
self.model_lst,
|
81 |
+
max_pages=1,
|
82 |
+
batch_multiplier=1,
|
83 |
+
start_page=page_idx,
|
84 |
+
ocr_all_pages=True,
|
85 |
+
**kwargs,
|
86 |
+
)
|
87 |
+
|
88 |
+
image_file_paths = []
|
89 |
+
for img_idx, (_, image) in enumerate(images.items()):
|
90 |
+
image_file_name = f"page{page_idx}_fig{img_idx}.png"
|
91 |
+
image_file_path = os.path.join(image_save_dir, image_file_name)
|
92 |
+
image.save(image_file_path, "png")
|
93 |
+
image_file_paths.append(image_file_path)
|
94 |
+
|
95 |
+
page_image = convert_from_path(
|
96 |
+
self.document_file_path,
|
97 |
+
first_page=page_idx + 1,
|
98 |
+
last_page=page_idx + 1,
|
99 |
+
**kwargs,
|
100 |
+
)[0]
|
101 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
102 |
+
|
103 |
+
return {
|
104 |
+
"page_idx": page_idx,
|
105 |
+
"document_name": self.document_name,
|
106 |
+
"file_path": self.document_file_path,
|
107 |
+
"file_url": self.url,
|
108 |
+
"image_file_paths": os.path.join(image_save_dir, "*.png"),
|
109 |
+
}
|
110 |
+
|
111 |
+
def load_data(
|
112 |
+
self,
|
113 |
+
start_page: int | None = None,
|
114 |
+
end_page: int | None = None,
|
115 |
+
wandb_artifact_name: str | None = None,
|
116 |
+
image_save_dir: str = "./images",
|
117 |
+
exclude_file_extensions: list[str] = [],
|
118 |
+
cleanup: bool = False,
|
119 |
+
**kwargs,
|
120 |
+
) -> Coroutine[Any, Any, List[Dict[str, str]]]:
|
121 |
+
start_page = start_page - 1 if start_page is not None else None
|
122 |
+
end_page = end_page - 1 if end_page is not None else None
|
123 |
+
return super().load_data(
|
124 |
+
start_page,
|
125 |
+
end_page,
|
126 |
+
wandb_artifact_name,
|
127 |
+
image_save_dir,
|
128 |
+
exclude_file_extensions,
|
129 |
+
cleanup,
|
130 |
+
**kwargs,
|
131 |
+
)
|
medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, Dict
|
3 |
+
|
4 |
+
from pdf2image.pdf2image import convert_from_path
|
5 |
+
|
6 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
7 |
+
BaseImageLoader,
|
8 |
+
)
|
9 |
+
|
10 |
+
|
11 |
+
class PDF2ImageLoader(BaseImageLoader):
|
12 |
+
"""
|
13 |
+
`PDF2ImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
|
14 |
+
loading of pages from a PDF file as images using the pdf2image library.
|
15 |
+
|
16 |
+
This class provides functionality to convert specific pages of a PDF document into images
|
17 |
+
and optionally publish these images to a WandB artifact.
|
18 |
+
It is like a snapshot image version of each of the pages from the PDF.
|
19 |
+
|
20 |
+
!!! example "Example Usage"
|
21 |
+
```python
|
22 |
+
import asyncio
|
23 |
+
|
24 |
+
from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
|
25 |
+
|
26 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
27 |
+
|
28 |
+
loader = PDF2ImageLoader(
|
29 |
+
url=URL,
|
30 |
+
document_name="Gray's Anatomy",
|
31 |
+
document_file_path="grays_anatomy.pdf",
|
32 |
+
)
|
33 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
34 |
+
```
|
35 |
+
|
36 |
+
Args:
|
37 |
+
url (str): The URL of the PDF document.
|
38 |
+
document_name (str): The name of the document.
|
39 |
+
document_file_path (str): The path to the PDF file.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, url: str, document_name: str, document_file_path: str):
|
43 |
+
super().__init__(url, document_name, document_file_path)
|
44 |
+
|
45 |
+
async def extract_page_data(
|
46 |
+
self, page_idx: int, image_save_dir: str, **kwargs
|
47 |
+
) -> Dict[str, Any]:
|
48 |
+
"""
|
49 |
+
Extracts a single page from the PDF as an image using pdf2image library.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
page_idx (int): The index of the page to process.
|
53 |
+
image_save_dir (str): The directory to save the extracted image.
|
54 |
+
**kwargs: Additional keyword arguments that may be used by pdf2image.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Dict[str, Any]: A dictionary containing the processed page data.
|
58 |
+
The dictionary will have the following keys and values:
|
59 |
+
|
60 |
+
- "page_idx": (int) the index of the page.
|
61 |
+
- "document_name": (str) the name of the document.
|
62 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
63 |
+
- "file_url": (str) the URL of the PDF file.
|
64 |
+
- "image_file_path": (str) the local file path where the image is stored.
|
65 |
+
"""
|
66 |
+
image = convert_from_path(
|
67 |
+
self.document_file_path,
|
68 |
+
first_page=page_idx + 1,
|
69 |
+
last_page=page_idx + 1,
|
70 |
+
**kwargs,
|
71 |
+
)[0]
|
72 |
+
|
73 |
+
image_file_name = f"page{page_idx}.png"
|
74 |
+
image_file_path = os.path.join(image_save_dir, image_file_name)
|
75 |
+
image.save(image_file_path)
|
76 |
+
|
77 |
+
return {
|
78 |
+
"page_idx": page_idx,
|
79 |
+
"document_name": self.document_name,
|
80 |
+
"file_path": self.document_file_path,
|
81 |
+
"file_url": self.url,
|
82 |
+
"image_file_path": image_file_path,
|
83 |
+
}
|
medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, Dict
|
3 |
+
|
4 |
+
import pdfplumber
|
5 |
+
from pdf2image.pdf2image import convert_from_path
|
6 |
+
|
7 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
8 |
+
BaseImageLoader,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
class PDFPlumberImageLoader(BaseImageLoader):
|
13 |
+
"""
|
14 |
+
`PDFPlumberImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
|
15 |
+
loading of pages from a PDF file as images using the pdfplumber library.
|
16 |
+
|
17 |
+
This class provides functionality to extract images from a PDF file using pdfplumber library,
|
18 |
+
and optionally publish these images to a WandB artifact.
|
19 |
+
|
20 |
+
!!! example "Example Usage"
|
21 |
+
```python
|
22 |
+
import asyncio
|
23 |
+
|
24 |
+
from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
|
25 |
+
|
26 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
27 |
+
|
28 |
+
loader = PDFPlumberImageLoader(
|
29 |
+
url=URL,
|
30 |
+
document_name="Gray's Anatomy",
|
31 |
+
document_file_path="grays_anatomy.pdf",
|
32 |
+
)
|
33 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
34 |
+
```
|
35 |
+
|
36 |
+
Args:
|
37 |
+
url (str): The URL of the PDF document.
|
38 |
+
document_name (str): The name of the document.
|
39 |
+
document_file_path (str): The path to the PDF file.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, url: str, document_name: str, document_file_path: str):
|
43 |
+
super().__init__(url, document_name, document_file_path)
|
44 |
+
|
45 |
+
async def extract_page_data(
|
46 |
+
self, page_idx: int, image_save_dir: str, **kwargs
|
47 |
+
) -> Dict[str, Any]:
|
48 |
+
"""
|
49 |
+
Extracts a single page from the PDF as an image using pdfplumber library.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
page_idx (int): The index of the page to process.
|
53 |
+
image_save_dir (str): The directory to save the extracted image.
|
54 |
+
**kwargs: Additional keyword arguments that may be used by pdfplumber.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Dict[str, Any]: A dictionary containing the processed page data.
|
58 |
+
The dictionary will have the following keys and values:
|
59 |
+
|
60 |
+
- "page_idx": (int) the index of the page.
|
61 |
+
- "document_name": (str) the name of the document.
|
62 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
63 |
+
- "file_url": (str) the URL of the PDF file.
|
64 |
+
- "image_file_path": (str) the local file path where the image is stored.
|
65 |
+
"""
|
66 |
+
with pdfplumber.open(self.document_file_path) as pdf:
|
67 |
+
page = pdf.pages[page_idx]
|
68 |
+
images = page.images
|
69 |
+
|
70 |
+
image_file_paths = []
|
71 |
+
for img_idx, image in enumerate(images):
|
72 |
+
extracted_image = page.crop(
|
73 |
+
(
|
74 |
+
image["x0"],
|
75 |
+
image["top"],
|
76 |
+
image["x1"],
|
77 |
+
image["bottom"],
|
78 |
+
)
|
79 |
+
).to_image(resolution=300)
|
80 |
+
|
81 |
+
image_file_name = f"page{page_idx}_fig{img_idx}.png"
|
82 |
+
image_file_path = os.path.join(image_save_dir, image_file_name)
|
83 |
+
|
84 |
+
extracted_image.save(image_file_path, "png")
|
85 |
+
image_file_paths.append(image_file_path)
|
86 |
+
|
87 |
+
page_image = convert_from_path(
|
88 |
+
self.document_file_path,
|
89 |
+
first_page=page_idx + 1,
|
90 |
+
last_page=page_idx + 1,
|
91 |
+
**kwargs,
|
92 |
+
)[0]
|
93 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
94 |
+
|
95 |
+
return {
|
96 |
+
"page_idx": page_idx,
|
97 |
+
"document_name": self.document_name,
|
98 |
+
"file_path": self.document_file_path,
|
99 |
+
"file_url": self.url,
|
100 |
+
"image_file_paths": image_file_paths,
|
101 |
+
}
|
medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
from typing import Any, Dict
|
4 |
+
|
5 |
+
import fitz
|
6 |
+
from pdf2image.pdf2image import convert_from_path
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
10 |
+
BaseImageLoader,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class PyMuPDFImageLoader(BaseImageLoader):
|
15 |
+
"""
|
16 |
+
`PyMuPDFImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
|
17 |
+
loading of pages from a PDF file as images using the pymupdf library.
|
18 |
+
|
19 |
+
This class provides functionality to extract images from a PDF file using pymupdf library,
|
20 |
+
and optionally publish these images to a WandB artifact.
|
21 |
+
|
22 |
+
!!! example "Example Usage"
|
23 |
+
```python
|
24 |
+
import asyncio
|
25 |
+
|
26 |
+
from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
|
27 |
+
|
28 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
29 |
+
|
30 |
+
loader = PyMuPDFImageLoader(
|
31 |
+
url=URL,
|
32 |
+
document_name="Gray's Anatomy",
|
33 |
+
document_file_path="grays_anatomy.pdf",
|
34 |
+
)
|
35 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
36 |
+
```
|
37 |
+
|
38 |
+
Args:
|
39 |
+
url (str): The URL of the PDF document.
|
40 |
+
document_name (str): The name of the document.
|
41 |
+
document_file_path (str): The path to the PDF file.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, url: str, document_name: str, document_file_path: str):
|
45 |
+
super().__init__(url, document_name, document_file_path)
|
46 |
+
|
47 |
+
async def extract_page_data(
|
48 |
+
self, page_idx: int, image_save_dir: str, **kwargs
|
49 |
+
) -> Dict[str, Any]:
|
50 |
+
"""
|
51 |
+
Extracts a single page from the PDF as an image using pymupdf library.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
page_idx (int): The index of the page to process.
|
55 |
+
image_save_dir (str): The directory to save the extracted image.
|
56 |
+
**kwargs: Additional keyword arguments that may be used by pymupdf.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
Dict[str, Any]: A dictionary containing the processed page data.
|
60 |
+
The dictionary will have the following keys and values:
|
61 |
+
|
62 |
+
- "page_idx": (int) the index of the page.
|
63 |
+
- "document_name": (str) the name of the document.
|
64 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
65 |
+
- "file_url": (str) the URL of the PDF file.
|
66 |
+
- "image_file_paths": (list) the local file paths where the images are stored.
|
67 |
+
"""
|
68 |
+
image_file_paths = []
|
69 |
+
|
70 |
+
pdf_document = fitz.open(self.document_file_path)
|
71 |
+
page = pdf_document[page_idx]
|
72 |
+
|
73 |
+
images = page.get_images(full=True)
|
74 |
+
for img_idx, image in enumerate(images):
|
75 |
+
xref = image[0]
|
76 |
+
base_image = pdf_document.extract_image(xref)
|
77 |
+
image_bytes = base_image["image"]
|
78 |
+
image_ext = base_image["ext"]
|
79 |
+
|
80 |
+
if image_ext == "jb2":
|
81 |
+
image_ext = "png"
|
82 |
+
elif image_ext == "jpx":
|
83 |
+
image_ext = "jpg"
|
84 |
+
|
85 |
+
image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}"
|
86 |
+
image_file_path = os.path.join(image_save_dir, image_file_name)
|
87 |
+
|
88 |
+
# For JBIG2 and JPEG2000, we need to convert the image
|
89 |
+
if base_image["ext"] in ["jb2", "jpx"]:
|
90 |
+
try:
|
91 |
+
pix = fitz.Pixmap(image_bytes)
|
92 |
+
pix.save(image_file_path)
|
93 |
+
except Exception as err_fitz:
|
94 |
+
print(f"Error processing image with fitz: {err_fitz}")
|
95 |
+
# Fallback to using PIL for image conversion
|
96 |
+
try:
|
97 |
+
img = Image.open(io.BytesIO(image_bytes))
|
98 |
+
img.save(image_file_path)
|
99 |
+
except Exception as err_pil:
|
100 |
+
print(f"Failed to process image with PIL: {err_pil}")
|
101 |
+
continue # Skip this image if both methods fail
|
102 |
+
else:
|
103 |
+
with open(image_file_path, "wb") as image_file:
|
104 |
+
image_file.write(image_bytes)
|
105 |
+
|
106 |
+
image_file_paths.append(image_file_path)
|
107 |
+
|
108 |
+
pdf_document.close()
|
109 |
+
|
110 |
+
page_image = convert_from_path(
|
111 |
+
self.document_file_path,
|
112 |
+
first_page=page_idx + 1,
|
113 |
+
last_page=page_idx + 1,
|
114 |
+
**kwargs,
|
115 |
+
)[0]
|
116 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
117 |
+
|
118 |
+
return {
|
119 |
+
"page_idx": page_idx,
|
120 |
+
"document_name": self.document_name,
|
121 |
+
"file_path": self.document_file_path,
|
122 |
+
"file_url": self.url,
|
123 |
+
"image_file_paths": image_file_paths,
|
124 |
+
}
|
medrag_multi_modal/document_loader/text_loader/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .marker_text_loader import MarkerTextLoader
|
2 |
+
from .pdfplumber_text_loader import PDFPlumberTextLoader
|
3 |
+
from .pymupdf4llm_text_loader import PyMuPDF4LLMTextLoader
|
4 |
+
from .pypdf2_text_loader import PyPDF2TextLoader
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
"PyMuPDF4LLMTextLoader",
|
8 |
+
"PyPDF2TextLoader",
|
9 |
+
"PDFPlumberTextLoader",
|
10 |
+
"MarkerTextLoader",
|
11 |
+
]
|
medrag_multi_modal/document_loader/text_loader/base_text_loader.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from typing import Any, Dict, Optional
|
5 |
+
|
6 |
+
import huggingface_hub
|
7 |
+
import PyPDF2
|
8 |
+
from datasets import Dataset, concatenate_datasets, load_dataset
|
9 |
+
from firerequests import FireRequests
|
10 |
+
from rich.progress import Progress
|
11 |
+
|
12 |
+
|
13 |
+
class BaseTextLoader(ABC):
|
14 |
+
"""
|
15 |
+
An abstract base class for loading text from a PDF file, processing it into markdown, and optionally publishing it to a Weave dataset.
|
16 |
+
|
17 |
+
This class handles the downloading of a PDF file from a given URL if it does not already exist locally.
|
18 |
+
Subclasses should implement the specific PDF reading, text extraction, and markdown conversion methods.
|
19 |
+
|
20 |
+
The processed pages are finally stored in a list of Page objects, which can be optionally published to a Weave dataset.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
url (str): The URL of the PDF file to download if not present locally.
|
24 |
+
document_name (str): The name of the document for metadata purposes.
|
25 |
+
document_file_path (str): The local file path where the PDF is stored or will be downloaded.
|
26 |
+
metadata (Optional[dict[str, any]]): Additional metadata to be added to each row of the dataset.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
url: str,
|
32 |
+
document_name: str,
|
33 |
+
document_file_path: str,
|
34 |
+
metadata: Optional[dict[str, Any]] = None,
|
35 |
+
):
|
36 |
+
self.url = url
|
37 |
+
self.document_name = document_name
|
38 |
+
self.document_file_path = document_file_path
|
39 |
+
self.metadata = metadata or {}
|
40 |
+
if not os.path.exists(self.document_file_path):
|
41 |
+
FireRequests().download(url, filenames=self.document_file_path)
|
42 |
+
with open(self.document_file_path, "rb") as file:
|
43 |
+
pdf_reader = PyPDF2.PdfReader(file)
|
44 |
+
self.page_count = len(pdf_reader.pages)
|
45 |
+
|
46 |
+
def get_page_indices(
|
47 |
+
self, start_page: Optional[int] = None, end_page: Optional[int] = None
|
48 |
+
) -> tuple[int, int]:
|
49 |
+
"""
|
50 |
+
Get the start and end page indices for processing.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
|
54 |
+
end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
tuple[int, int]: A tuple containing the start and end page indices.
|
58 |
+
"""
|
59 |
+
|
60 |
+
if start_page:
|
61 |
+
if start_page > self.page_count:
|
62 |
+
raise ValueError(
|
63 |
+
f"Start page {start_page} is greater than the total page count {self.page_count}"
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
start_page = 0
|
67 |
+
if end_page:
|
68 |
+
if end_page > self.page_count:
|
69 |
+
raise ValueError(
|
70 |
+
f"End page {end_page} is greater than the total page count {self.page_count}"
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
end_page = self.page_count - 1
|
74 |
+
return start_page, end_page
|
75 |
+
|
76 |
+
@abstractmethod
|
77 |
+
async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
|
78 |
+
"""
|
79 |
+
Abstract method to process a single page of the PDF and extract the text data.
|
80 |
+
|
81 |
+
Overwrite this method in the subclass to provide the actual implementation and
|
82 |
+
processing logic for each page of the PDF using various PDF processing libraries.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
page_idx (int): The index of the page to process.
|
86 |
+
**kwargs: Additional keyword arguments that may be used by underlying libraries.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Dict[str, str]: A dictionary containing the processed page data.
|
90 |
+
"""
|
91 |
+
pass
|
92 |
+
|
93 |
+
async def load_data(
|
94 |
+
self,
|
95 |
+
start_page: Optional[int] = None,
|
96 |
+
end_page: Optional[int] = None,
|
97 |
+
exclude_pages: Optional[list[int]] = None,
|
98 |
+
dataset_repo_id: Optional[str] = None,
|
99 |
+
overwrite_dataset: bool = False,
|
100 |
+
**kwargs,
|
101 |
+
) -> Dataset:
|
102 |
+
"""
|
103 |
+
Asynchronously loads text from a PDF file specified by a URL or local file path.
|
104 |
+
The overrided processing abstract method then processes the text into markdown format,
|
105 |
+
and optionally publishes it to a Weave dataset.
|
106 |
+
|
107 |
+
This function downloads a PDF from a given URL if it does not already exist locally,
|
108 |
+
reads the specified range of pages, converts each page's content to markdown, and
|
109 |
+
returns a list of Page objects containing the text and metadata.
|
110 |
+
|
111 |
+
It uses `PyPDF2` to calculate the number of pages in the PDF and the
|
112 |
+
overriden `extract_page_data` method provides the actual implementation to process
|
113 |
+
each page, extract the text from the PDF, and convert it to markdown.
|
114 |
+
It processes pages concurrently using `asyncio` for efficiency.
|
115 |
+
|
116 |
+
If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
|
120 |
+
end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
|
121 |
+
exclude_pages (Optional[list[int]]): The list of page indices to exclude from processing.
|
122 |
+
dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
|
123 |
+
overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
|
124 |
+
**kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
Dataset: A HuggingFace Dataset object containing the text and metadata for processed pages.
|
128 |
+
Each entry in the dataset will have the following keys and values:
|
129 |
+
|
130 |
+
- "text": (str) the processed page data in markdown format.
|
131 |
+
- "page_idx": (int) the index of the page.
|
132 |
+
- "document_name": (str) the name of the document.
|
133 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
134 |
+
- "file_url": (str) the URL of the PDF file.
|
135 |
+
- "loader_name": (str) the name of the loader class used to process the page.
|
136 |
+
|
137 |
+
Raises:
|
138 |
+
ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
|
139 |
+
"""
|
140 |
+
start_page, end_page = self.get_page_indices(start_page, end_page)
|
141 |
+
pages = []
|
142 |
+
processed_pages_counter: int = 1
|
143 |
+
total_pages = end_page - start_page
|
144 |
+
exclude_pages = exclude_pages or []
|
145 |
+
|
146 |
+
async def process_page(page_idx):
|
147 |
+
nonlocal processed_pages_counter
|
148 |
+
page_data = await self.extract_page_data(page_idx, **kwargs)
|
149 |
+
page_data["loader_name"] = self.__class__.__name__
|
150 |
+
for key, value in self.metadata.items():
|
151 |
+
if key not in page_data:
|
152 |
+
page_data[key] = value
|
153 |
+
pages.append(page_data)
|
154 |
+
progress.update(
|
155 |
+
task_id,
|
156 |
+
advance=1,
|
157 |
+
description=f"Loading page {page_idx} using {self.__class__.__name__}",
|
158 |
+
)
|
159 |
+
processed_pages_counter += 1
|
160 |
+
|
161 |
+
progress = Progress()
|
162 |
+
with progress:
|
163 |
+
task_id = progress.add_task("Starting...", total=total_pages)
|
164 |
+
tasks = [
|
165 |
+
process_page(page_idx)
|
166 |
+
for page_idx in range(start_page, end_page + 1)
|
167 |
+
if page_idx not in exclude_pages
|
168 |
+
]
|
169 |
+
for task in asyncio.as_completed(tasks):
|
170 |
+
await task
|
171 |
+
|
172 |
+
pages.sort(key=lambda x: x["page_idx"])
|
173 |
+
|
174 |
+
dataset = Dataset.from_list(pages)
|
175 |
+
if dataset_repo_id:
|
176 |
+
if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
|
177 |
+
print("Dataset already exists")
|
178 |
+
if not overwrite_dataset:
|
179 |
+
print("Not overwriting dataset")
|
180 |
+
dataset = concatenate_datasets(
|
181 |
+
[dataset, load_dataset(dataset_repo_id, split="corpus")]
|
182 |
+
)
|
183 |
+
dataset.push_to_hub(repo_id=dataset_repo_id, split="corpus", private=False)
|
184 |
+
|
185 |
+
return dataset
|
medrag_multi_modal/document_loader/text_loader/marker_text_loader.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
from marker.convert import convert_single_pdf
|
5 |
+
from marker.models import load_all_models
|
6 |
+
|
7 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
8 |
+
BaseTextLoader,
|
9 |
+
)
|
10 |
+
|
11 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
12 |
+
|
13 |
+
|
14 |
+
class MarkerTextLoader(BaseTextLoader):
|
15 |
+
"""
|
16 |
+
A concrete implementation of the BaseTextLoader for loading text from a PDF file
|
17 |
+
using `marker-pdf`, processing it into a structured text format, and optionally publishing
|
18 |
+
it to a Weave dataset.
|
19 |
+
|
20 |
+
This class extends the BaseTextLoader and implements the abstract methods to
|
21 |
+
load and process pages from a PDF file using marker-pdf, which is a pipeline of deep learning models.
|
22 |
+
|
23 |
+
This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
|
24 |
+
It uses marker-pdf to read the PDF and extract structured text from each page. The processed pages are stored
|
25 |
+
in a list of Page objects, which can be optionally published to a Weave dataset.
|
26 |
+
|
27 |
+
!!! example "Example Usage"
|
28 |
+
```python
|
29 |
+
import asyncio
|
30 |
+
|
31 |
+
from medrag_multi_modal.document_loader import MarkerTextLoader
|
32 |
+
|
33 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
34 |
+
|
35 |
+
loader = MarkerTextLoader(
|
36 |
+
url=URL,
|
37 |
+
document_name="Gray's Anatomy",
|
38 |
+
document_file_path="grays_anatomy.pdf",
|
39 |
+
)
|
40 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
41 |
+
```
|
42 |
+
|
43 |
+
Args:
|
44 |
+
url (str): The URL of the PDF file to download if not present locally.
|
45 |
+
document_name (str): The name of the document for metadata purposes.
|
46 |
+
document_file_path (str): The local file path where the PDF is stored or will be downloaded.
|
47 |
+
"""
|
48 |
+
|
49 |
+
async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
|
50 |
+
"""
|
51 |
+
Process a single page of the PDF and extract its structured text using marker-pdf.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
Dict[str, str]: A dictionary with the processed page data.
|
55 |
+
The dictionary will have the following keys and values:
|
56 |
+
|
57 |
+
- "text": (str) the extracted structured text from the page.
|
58 |
+
- "page_idx": (int) the index of the page.
|
59 |
+
- "document_name": (str) the name of the document.
|
60 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
61 |
+
- "file_url": (str) the URL of the PDF file.
|
62 |
+
- "meta": (dict) the metadata extracted from the page by marker-pdf.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
page_idx (int): The index of the page to process.
|
66 |
+
**kwargs: Additional keyword arguments to be passed to `marker.convert.convert_single_pdf`.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
Dict[str, str]: A dictionary containing the processed page data.
|
70 |
+
"""
|
71 |
+
model_lst = load_all_models()
|
72 |
+
|
73 |
+
text, _, _ = convert_single_pdf(
|
74 |
+
self.document_file_path,
|
75 |
+
model_lst,
|
76 |
+
max_pages=1,
|
77 |
+
batch_multiplier=1,
|
78 |
+
start_page=page_idx,
|
79 |
+
ocr_all_pages=True,
|
80 |
+
**kwargs,
|
81 |
+
)
|
82 |
+
|
83 |
+
return {
|
84 |
+
"text": text,
|
85 |
+
"page_idx": page_idx,
|
86 |
+
"document_name": self.document_name,
|
87 |
+
"file_path": self.document_file_path,
|
88 |
+
"file_url": self.url,
|
89 |
+
}
|
medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import pdfplumber
|
4 |
+
|
5 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
6 |
+
BaseTextLoader,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
class PDFPlumberTextLoader(BaseTextLoader):
|
11 |
+
"""
|
12 |
+
A concrete implementation of the BaseTextLoader for loading text from a PDF file
|
13 |
+
using `pdfplumber`, processing it into a simple text format, and optionally publishing
|
14 |
+
it to a Weave dataset.
|
15 |
+
|
16 |
+
This class extends the BaseTextLoader and implements the abstract methods to
|
17 |
+
load and process pages from a PDF file.
|
18 |
+
|
19 |
+
This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
|
20 |
+
It uses pdfplumber to read the PDF and extract text from each page. The processed pages are stored in a list
|
21 |
+
of Page objects, which can be optionally published to a Weave dataset.
|
22 |
+
|
23 |
+
!!! example "Example Usage"
|
24 |
+
```python
|
25 |
+
import asyncio
|
26 |
+
|
27 |
+
from medrag_multi_modal.document_loader import PDFPlumberTextLoader
|
28 |
+
|
29 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
30 |
+
|
31 |
+
loader = PDFPlumberTextLoader(
|
32 |
+
url=URL,
|
33 |
+
document_name="Gray's Anatomy",
|
34 |
+
document_file_path="grays_anatomy.pdf",
|
35 |
+
)
|
36 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
37 |
+
```
|
38 |
+
|
39 |
+
Args:
|
40 |
+
url (str): The URL of the PDF file to download if not present locally.
|
41 |
+
document_name (str): The name of the document for metadata purposes.
|
42 |
+
document_file_path (str): The local file path where the PDF is stored or will be downloaded.
|
43 |
+
"""
|
44 |
+
|
45 |
+
async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
|
46 |
+
"""
|
47 |
+
Process a single page of the PDF and extract its text using pdfplumber.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
Dict[str, str]: A dictionary with the processed page data.
|
51 |
+
The dictionary will have the following keys and values:
|
52 |
+
|
53 |
+
- "text": (str) the extracted text from the page.
|
54 |
+
- "page_idx": (int) the index of the page.
|
55 |
+
- "document_name": (str) the name of the document.
|
56 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
57 |
+
- "file_url": (str) the URL of the PDF file.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
page_idx (int): The index of the page to process.
|
61 |
+
**kwargs: Additional keyword arguments to be passed to `pdfplumber.Page.extract_text`.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Dict[str, str]: A dictionary containing the processed page data.
|
65 |
+
"""
|
66 |
+
with pdfplumber.open(self.document_file_path) as pdf:
|
67 |
+
page = pdf.pages[page_idx]
|
68 |
+
text = page.extract_text(**kwargs)
|
69 |
+
|
70 |
+
return {
|
71 |
+
"text": text,
|
72 |
+
"page_idx": page_idx,
|
73 |
+
"document_name": self.document_name,
|
74 |
+
"file_path": self.document_file_path,
|
75 |
+
"file_url": self.url,
|
76 |
+
}
|
medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import pymupdf4llm
|
4 |
+
|
5 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
6 |
+
BaseTextLoader,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
class PyMuPDF4LLMTextLoader(BaseTextLoader):
|
11 |
+
"""
|
12 |
+
A concrete implementation of the BaseTextLoader for loading text from a PDF file,
|
13 |
+
processing it into markdown using `pymupdf4llm`, and optionally publishing it to a Weave dataset.
|
14 |
+
|
15 |
+
This class extends the BaseTextLoader and implements the abstract methods to load and process pages from a PDF file.
|
16 |
+
|
17 |
+
This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
|
18 |
+
It uses PyPDF2 to read the PDF and pymupdf4llm to convert pages to markdown. The processed pages are stored in a list
|
19 |
+
of Page objects, which can be optionally published to a Weave dataset.
|
20 |
+
|
21 |
+
!!! example "Example Usage"
|
22 |
+
```python
|
23 |
+
import asyncio
|
24 |
+
|
25 |
+
from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader
|
26 |
+
|
27 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
28 |
+
|
29 |
+
loader = PyMuPDF4LLMTextLoader(
|
30 |
+
url=URL,
|
31 |
+
document_name="Gray's Anatomy",
|
32 |
+
document_file_path="grays_anatomy.pdf",
|
33 |
+
)
|
34 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
35 |
+
```
|
36 |
+
|
37 |
+
Args:
|
38 |
+
url (str): The URL of the PDF file to download if not present locally.
|
39 |
+
document_name (str): The name of the document for metadata purposes.
|
40 |
+
document_file_path (str): The local file path where the PDF is stored or will be downloaded.
|
41 |
+
"""
|
42 |
+
|
43 |
+
async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
|
44 |
+
"""
|
45 |
+
Process a single page of the PDF and convert it to markdown using `pymupdf4llm`.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
Dict[str, str]: A dictionary with the processed page data.
|
49 |
+
The dictionary will have the following keys and values:
|
50 |
+
|
51 |
+
- "text": (str) the processed page data in markdown format.
|
52 |
+
- "page_idx": (int) the index of the page.
|
53 |
+
- "document_name": (str) the name of the document.
|
54 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
55 |
+
- "file_url": (str) the URL of the PDF file.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
page_idx (int): The index of the page to process.
|
59 |
+
**kwargs: Additional keyword arguments to be passed to `pymupdf4llm.to_markdown`.
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
Dict[str, str]: A dictionary containing the processed page data.
|
63 |
+
"""
|
64 |
+
text = pymupdf4llm.to_markdown(
|
65 |
+
doc=self.document_file_path, pages=[page_idx], show_progress=False, **kwargs
|
66 |
+
)
|
67 |
+
return {
|
68 |
+
"text": text,
|
69 |
+
"page_idx": page_idx,
|
70 |
+
"document_name": self.document_name,
|
71 |
+
"file_path": self.document_file_path,
|
72 |
+
"file_url": self.url,
|
73 |
+
}
|
medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import PyPDF2
|
4 |
+
|
5 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
6 |
+
BaseTextLoader,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
class PyPDF2TextLoader(BaseTextLoader):
|
11 |
+
"""
|
12 |
+
A concrete implementation of the BaseTextLoader for loading text from a PDF file
|
13 |
+
using `PyPDF2`, processing it into a simple text format, and optionally publishing
|
14 |
+
it to a Weave dataset.
|
15 |
+
|
16 |
+
This class extends the BaseTextLoader and implements the abstract methods to
|
17 |
+
load and process pages from a PDF file.
|
18 |
+
|
19 |
+
This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
|
20 |
+
It uses PyPDF2 to read the PDF and extract text from each page. The processed pages are stored in a list
|
21 |
+
of Page objects, which can be optionally published to a Weave dataset.
|
22 |
+
|
23 |
+
!!! example "Example Usage"
|
24 |
+
```python
|
25 |
+
import asyncio
|
26 |
+
|
27 |
+
from medrag_multi_modal.document_loader import PyPDF2TextLoader
|
28 |
+
|
29 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
30 |
+
|
31 |
+
loader = PyPDF2TextLoader(
|
32 |
+
url=URL,
|
33 |
+
document_name="Gray's Anatomy",
|
34 |
+
document_file_path="grays_anatomy.pdf",
|
35 |
+
)
|
36 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
37 |
+
```
|
38 |
+
|
39 |
+
Args:
|
40 |
+
url (str): The URL of the PDF file to download if not present locally.
|
41 |
+
document_name (str): The name of the document for metadata purposes.
|
42 |
+
document_file_path (str): The local file path where the PDF is stored or will be downloaded.
|
43 |
+
"""
|
44 |
+
|
45 |
+
async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
|
46 |
+
"""
|
47 |
+
Process a single page of the PDF and extract its text using PyPDF2.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
Dict[str, str]: A dictionary with the processed page data.
|
51 |
+
The dictionary will have the following keys and values:
|
52 |
+
|
53 |
+
- "text": (str) the extracted text from the page.
|
54 |
+
- "page_idx": (int) the index of the page.
|
55 |
+
- "document_name": (str) the name of the document.
|
56 |
+
- "file_path": (str) the local file path where the PDF is stored.
|
57 |
+
- "file_url": (str) the URL of the PDF file.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
page_idx (int): The index of the page to process.
|
61 |
+
**kwargs: Additional keyword arguments to be passed to `PyPDF2.PdfReader.pages[0].extract_text`.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Dict[str, str]: A dictionary containing the processed page data.
|
65 |
+
"""
|
66 |
+
with open(self.document_file_path, "rb") as file:
|
67 |
+
pdf_reader = PyPDF2.PdfReader(file)
|
68 |
+
page = pdf_reader.pages[page_idx]
|
69 |
+
text = page.extract_text(**kwargs)
|
70 |
+
|
71 |
+
return {
|
72 |
+
"text": text,
|
73 |
+
"page_idx": page_idx,
|
74 |
+
"document_name": self.document_name,
|
75 |
+
"file_path": self.document_file_path,
|
76 |
+
"file_url": self.url,
|
77 |
+
}
|
medrag_multi_modal/metrics/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .mmlu import MMLUOptionAccuracy
|
2 |
+
|
3 |
+
__all__ = ["MMLUOptionAccuracy"]
|
medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (273 Bytes). View file
|
|
medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc
ADDED
Binary file (1.58 kB). View file
|
|
medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc
ADDED
Binary file (775 Bytes). View file
|
|
medrag_multi_modal/metrics/base.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import weave
|
5 |
+
|
6 |
+
|
7 |
+
class BaseAccuracyMetric(weave.Scorer):
|
8 |
+
"""
|
9 |
+
BaseAccuracyMetric is a class that extends the
|
10 |
+
[`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers)
|
11 |
+
to provide a comprehensive evaluation of accuracy metrics for a given set of score rows.
|
12 |
+
|
13 |
+
This class is designed to process a list of score rows, each containing a
|
14 |
+
'correct' key that indicates whether a particular prediction was correct.
|
15 |
+
The `summarize` method calculates various statistical measures and metrics
|
16 |
+
based on this data, including:
|
17 |
+
|
18 |
+
- True and false counts: The number of true and false predictions.
|
19 |
+
- True and false fractions: The proportion of true and false predictions.
|
20 |
+
- Standard error: The standard error of the mean for the true predictions.
|
21 |
+
- Precision: The ratio of true positive predictions to the total number of
|
22 |
+
positive predictions.
|
23 |
+
- Recall: The ratio of true positive predictions to the total number of
|
24 |
+
actual positives.
|
25 |
+
- F1 Score: The harmonic mean of precision and recall, providing a balance
|
26 |
+
between the two metrics.
|
27 |
+
|
28 |
+
The `summarize` method returns a dictionary containing these metrics,
|
29 |
+
allowing for a detailed analysis of the model's performance.
|
30 |
+
|
31 |
+
Methods:
|
32 |
+
summarize(score_rows: list) -> Optional[dict]:
|
33 |
+
Processes the input score rows to compute and return a dictionary
|
34 |
+
of accuracy metrics.
|
35 |
+
"""
|
36 |
+
@weave.op()
|
37 |
+
def summarize(self, score_rows: list) -> Optional[dict]:
|
38 |
+
"""
|
39 |
+
Summarizes the accuracy metrics from a list of score rows.
|
40 |
+
|
41 |
+
This method processes a list of score rows, each containing a 'correct' key
|
42 |
+
that indicates whether a particular prediction was correct. It calculates
|
43 |
+
various statistical measures and metrics based on this data, including:
|
44 |
+
|
45 |
+
- True and false counts: The number of true and false predictions.
|
46 |
+
- True and false fractions: The proportion of true and false predictions.
|
47 |
+
- Standard error: The standard error of the mean for the true predictions.
|
48 |
+
- Precision: The ratio of true positive predictions to the total number of
|
49 |
+
positive predictions.
|
50 |
+
- Recall: The ratio of true positive predictions to the total number of
|
51 |
+
actual positives.
|
52 |
+
- F1 Score: The harmonic mean of precision and recall, providing a balance
|
53 |
+
between the two metrics.
|
54 |
+
|
55 |
+
The method returns a dictionary containing these metrics, allowing for a
|
56 |
+
detailed analysis of the model's performance.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
score_rows (list): A list of dictionaries, each containing a 'correct'
|
60 |
+
key with a boolean value indicating the correctness of a prediction.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Optional[dict]: A dictionary containing the calculated accuracy metrics,
|
64 |
+
or None if the input list is empty.
|
65 |
+
"""
|
66 |
+
valid_data = [
|
67 |
+
x.get("correct") for x in score_rows if x.get("correct") is not None
|
68 |
+
]
|
69 |
+
count_true = list(valid_data).count(True)
|
70 |
+
int_data = [int(x) for x in valid_data]
|
71 |
+
|
72 |
+
sample_mean = np.mean(int_data) if int_data else 0
|
73 |
+
sample_variance = np.var(int_data) if int_data else 0
|
74 |
+
sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0
|
75 |
+
|
76 |
+
# Calculate precision, recall, and F1 score
|
77 |
+
true_positives = count_true
|
78 |
+
false_positives = len(valid_data) - count_true
|
79 |
+
false_negatives = len(score_rows) - len(valid_data)
|
80 |
+
|
81 |
+
precision = (
|
82 |
+
true_positives / (true_positives + false_positives)
|
83 |
+
if (true_positives + false_positives) > 0
|
84 |
+
else 0
|
85 |
+
)
|
86 |
+
recall = (
|
87 |
+
true_positives / (true_positives + false_negatives)
|
88 |
+
if (true_positives + false_negatives) > 0
|
89 |
+
else 0
|
90 |
+
)
|
91 |
+
f1_score = (
|
92 |
+
(2 * precision * recall) / (precision + recall)
|
93 |
+
if (precision + recall) > 0
|
94 |
+
else 0
|
95 |
+
)
|
96 |
+
|
97 |
+
return {
|
98 |
+
"correct": {
|
99 |
+
"true_count": count_true,
|
100 |
+
"false_count": len(score_rows) - count_true,
|
101 |
+
"true_fraction": float(sample_mean),
|
102 |
+
"false_fraction": 1.0 - float(sample_mean),
|
103 |
+
"stderr": float(sample_error),
|
104 |
+
"precision": precision,
|
105 |
+
"recall": recall,
|
106 |
+
"f1_score": f1_score,
|
107 |
+
}
|
108 |
+
}
|
medrag_multi_modal/metrics/mmlu.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import weave
|
2 |
+
|
3 |
+
from medrag_multi_modal.assistant.schema import MedQAResponse
|
4 |
+
from medrag_multi_modal.metrics.base import BaseAccuracyMetric
|
5 |
+
|
6 |
+
|
7 |
+
class MMLUOptionAccuracy(BaseAccuracyMetric):
|
8 |
+
"""
|
9 |
+
MMLUOptionAccuracy is a metric class that inherits from `BaseAccuracyMetric`.
|
10 |
+
|
11 |
+
This class is designed to evaluate the accuracy of a multiple-choice question
|
12 |
+
response by comparing the provided answer with the correct answer from the
|
13 |
+
given options. It uses the MedQAResponse schema to extract the response
|
14 |
+
and checks if it matches the correct answer.
|
15 |
+
|
16 |
+
Methods:
|
17 |
+
--------
|
18 |
+
score(output: MedQAResponse, options: list[str], answer: str) -> dict:
|
19 |
+
Compares the provided answer with the correct answer and returns a
|
20 |
+
dictionary indicating whether the answer is correct.
|
21 |
+
"""
|
22 |
+
@weave.op()
|
23 |
+
def score(self, output: MedQAResponse, options: list[str], answer: str):
|
24 |
+
return {"correct": options[answer] == output.response.answer}
|
medrag_multi_modal/retrieval/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .colpali_retrieval import CalPaliRetriever
|
2 |
+
|
3 |
+
__all__ = ["CalPaliRetriever"]
|
medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (276 Bytes). View file
|
|
medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc
ADDED
Binary file (9.94 kB). View file
|
|
medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc
ADDED
Binary file (1.25 kB). View file
|
|
medrag_multi_modal/retrieval/colpali_retrieval.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import TYPE_CHECKING, Any, Optional
|
3 |
+
|
4 |
+
import weave
|
5 |
+
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
from byaldi import RAGMultiModalModel
|
8 |
+
|
9 |
+
import wandb
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from medrag_multi_modal.utils import get_wandb_artifact
|
13 |
+
|
14 |
+
|
15 |
+
class CalPaliRetriever(weave.Model):
|
16 |
+
"""
|
17 |
+
CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali.
|
18 |
+
|
19 |
+
This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
|
20 |
+
It can be initialized with a pre-trained model or from a specified W&B artifact. The class
|
21 |
+
also provides methods to index new data and to predict/retrieve documents based on a query.
|
22 |
+
|
23 |
+
Attributes:
|
24 |
+
model_name (str): The name of the model to be used for retrieval.
|
25 |
+
"""
|
26 |
+
|
27 |
+
model_name: str
|
28 |
+
_docs_retrieval_model: Optional["RAGMultiModalModel"] = None
|
29 |
+
_metadata: Optional[dict] = None
|
30 |
+
_data_artifact_dir: Optional[str] = None
|
31 |
+
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
model_name: str = "vidore/colpali-v1.2",
|
35 |
+
docs_retrieval_model: Optional["RAGMultiModalModel"] = None,
|
36 |
+
data_artifact_dir: Optional[str] = None,
|
37 |
+
metadata_dataset_name: Optional[str] = None,
|
38 |
+
):
|
39 |
+
super().__init__(model_name=model_name)
|
40 |
+
from byaldi import RAGMultiModalModel
|
41 |
+
|
42 |
+
self._docs_retrieval_model = (
|
43 |
+
docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
|
44 |
+
)
|
45 |
+
self._data_artifact_dir = data_artifact_dir
|
46 |
+
self._metadata = (
|
47 |
+
[dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
|
48 |
+
if metadata_dataset_name
|
49 |
+
else None
|
50 |
+
)
|
51 |
+
|
52 |
+
def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
|
53 |
+
"""
|
54 |
+
Indexes a dataset of documents and saves the index as a Weave artifact.
|
55 |
+
|
56 |
+
This method retrieves a dataset of documents from a Weave artifact using the provided
|
57 |
+
data artifact name. It then indexes the documents using the document retrieval model
|
58 |
+
and assigns the specified index name. The index is stored locally without storing the
|
59 |
+
collection with the index and overwrites any existing index with the same name.
|
60 |
+
|
61 |
+
If a Weave run is active, the method creates a new Weave artifact with the specified
|
62 |
+
index name and type "colpali-index". It adds the local index directory to the artifact
|
63 |
+
and saves it to Weave, including metadata with the provided Weave dataset name.
|
64 |
+
|
65 |
+
!!! example "Indexing Data"
|
66 |
+
First you need to install `Byaldi` library by Answer.ai.
|
67 |
+
|
68 |
+
```bash
|
69 |
+
uv pip install Byaldi>=0.0.5
|
70 |
+
```
|
71 |
+
|
72 |
+
Next, you can index the data by running the following code:
|
73 |
+
|
74 |
+
```python
|
75 |
+
import wandb
|
76 |
+
from medrag_multi_modal.retrieval import CalPaliRetriever
|
77 |
+
|
78 |
+
wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
|
79 |
+
retriever = CalPaliRetriever()
|
80 |
+
retriever.index(
|
81 |
+
data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
|
82 |
+
weave_dataset_name="grays-anatomy-images:v0",
|
83 |
+
index_name="grays-anatomy",
|
84 |
+
)
|
85 |
+
```
|
86 |
+
|
87 |
+
??? note "Optional Speedup using Flash Attention"
|
88 |
+
If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
|
89 |
+
installing the `flash-attn` package.
|
90 |
+
|
91 |
+
```bash
|
92 |
+
uv pip install flash-attn --no-build-isolation
|
93 |
+
```
|
94 |
+
|
95 |
+
Args:
|
96 |
+
data_artifact_name (str): The name of the Weave artifact containing the dataset.
|
97 |
+
weave_dataset_name (str): The name of the Weave dataset to include in the artifact metadata.
|
98 |
+
index_name (str): The name to assign to the created index.
|
99 |
+
"""
|
100 |
+
data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
|
101 |
+
self._docs_retrieval_model.index(
|
102 |
+
input_path=data_artifact_dir,
|
103 |
+
index_name=index_name,
|
104 |
+
store_collection_with_index=False,
|
105 |
+
overwrite=True,
|
106 |
+
)
|
107 |
+
if wandb.run:
|
108 |
+
artifact = wandb.Artifact(
|
109 |
+
name=index_name,
|
110 |
+
type="colpali-index",
|
111 |
+
metadata={"weave_dataset_name": weave_dataset_name},
|
112 |
+
)
|
113 |
+
artifact.add_dir(
|
114 |
+
local_path=os.path.join(".byaldi", index_name), name="index"
|
115 |
+
)
|
116 |
+
artifact.save()
|
117 |
+
|
118 |
+
@classmethod
|
119 |
+
def from_wandb_artifact(
|
120 |
+
cls,
|
121 |
+
index_artifact_name: str,
|
122 |
+
metadata_dataset_name: str,
|
123 |
+
data_artifact_name: str,
|
124 |
+
):
|
125 |
+
"""
|
126 |
+
Creates an instance of the class from Weights & Biases (wandb) artifacts.
|
127 |
+
|
128 |
+
This method retrieves the necessary artifacts from wandb to initialize the
|
129 |
+
ColPaliRetriever. It fetches the index artifact directory and the data artifact
|
130 |
+
directory using the provided artifact names. It then loads the document retrieval
|
131 |
+
model from the index path within the index artifact directory. Finally, it returns
|
132 |
+
an instance of the class initialized with the retrieved document retrieval model,
|
133 |
+
metadata dataset name, and data artifact directory.
|
134 |
+
|
135 |
+
!!! example "Retrieving Documents"
|
136 |
+
First you need to install `Byaldi` library by Answer.ai.
|
137 |
+
|
138 |
+
```bash
|
139 |
+
uv pip install Byaldi>=0.0.5
|
140 |
+
```
|
141 |
+
|
142 |
+
Next, you can retrieve the documents by running the following code:
|
143 |
+
|
144 |
+
```python
|
145 |
+
import weave
|
146 |
+
|
147 |
+
import wandb
|
148 |
+
from medrag_multi_modal.retrieval import CalPaliRetriever
|
149 |
+
|
150 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
151 |
+
retriever = CalPaliRetriever.from_wandb_artifact(
|
152 |
+
index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
|
153 |
+
metadata_dataset_name="grays-anatomy-images:v0",
|
154 |
+
data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
|
155 |
+
)
|
156 |
+
```
|
157 |
+
|
158 |
+
??? note "Optional Speedup using Flash Attention"
|
159 |
+
If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
|
160 |
+
installing the `flash-attn` package.
|
161 |
+
|
162 |
+
```bash
|
163 |
+
uv pip install flash-attn --no-build-isolation
|
164 |
+
```
|
165 |
+
|
166 |
+
Args:
|
167 |
+
index_artifact_name (str): The name of the wandb artifact containing the index.
|
168 |
+
metadata_dataset_name (str): The name of the dataset containing metadata.
|
169 |
+
data_artifact_name (str): The name of the wandb artifact containing the data.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
An instance of the class initialized with the retrieved document retrieval model,
|
173 |
+
metadata dataset name, and data artifact directory.
|
174 |
+
"""
|
175 |
+
from byaldi import RAGMultiModalModel
|
176 |
+
|
177 |
+
index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
|
178 |
+
data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
|
179 |
+
docs_retrieval_model = RAGMultiModalModel.from_index(
|
180 |
+
index_path=os.path.join(index_artifact_dir, "index")
|
181 |
+
)
|
182 |
+
return cls(
|
183 |
+
docs_retrieval_model=docs_retrieval_model,
|
184 |
+
metadata_dataset_name=metadata_dataset_name,
|
185 |
+
data_artifact_dir=data_artifact_dir,
|
186 |
+
)
|
187 |
+
|
188 |
+
@weave.op()
|
189 |
+
def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
|
190 |
+
"""
|
191 |
+
Predicts and retrieves the top-k most relevant documents/images for a given query
|
192 |
+
using ColPali.
|
193 |
+
|
194 |
+
This function uses the document retrieval model to search for the most relevant
|
195 |
+
documents based on the provided query. It returns a list of dictionaries, each
|
196 |
+
containing the document image, document ID, and the relevance score.
|
197 |
+
|
198 |
+
!!! example "Retrieving Documents"
|
199 |
+
First you need to install `Byaldi` library by Answer.ai.
|
200 |
+
|
201 |
+
```bash
|
202 |
+
uv pip install Byaldi>=0.0.5
|
203 |
+
```
|
204 |
+
|
205 |
+
Next, you can retrieve the documents by running the following code:
|
206 |
+
|
207 |
+
```python
|
208 |
+
import weave
|
209 |
+
|
210 |
+
import wandb
|
211 |
+
from medrag_multi_modal.retrieval import CalPaliRetriever
|
212 |
+
|
213 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
214 |
+
retriever = CalPaliRetriever.from_wandb_artifact(
|
215 |
+
index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
|
216 |
+
metadata_dataset_name="grays-anatomy-images:v0",
|
217 |
+
data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
|
218 |
+
)
|
219 |
+
retriever.predict(
|
220 |
+
query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
|
221 |
+
top_k=3,
|
222 |
+
)
|
223 |
+
```
|
224 |
+
|
225 |
+
??? note "Optional Speedup using Flash Attention"
|
226 |
+
If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
|
227 |
+
installing the `flash-attn` package.
|
228 |
+
|
229 |
+
```bash
|
230 |
+
uv pip install flash-attn --no-build-isolation
|
231 |
+
```
|
232 |
+
|
233 |
+
Args:
|
234 |
+
query (str): The search query string.
|
235 |
+
top_k (int, optional): The number of top results to retrieve. Defaults to 10.
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
|
239 |
+
- "doc_image" (PIL.Image.Image): The image of the document.
|
240 |
+
- "doc_id" (str): The ID of the document.
|
241 |
+
- "score" (float): The relevance score of the document.
|
242 |
+
"""
|
243 |
+
results = self._docs_retrieval_model.search(query=query, k=top_k)
|
244 |
+
retrieved_results = []
|
245 |
+
for result in results:
|
246 |
+
retrieved_results.append(
|
247 |
+
{
|
248 |
+
"doc_image": Image.open(
|
249 |
+
os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png")
|
250 |
+
),
|
251 |
+
"doc_id": result["doc_id"],
|
252 |
+
"score": result["score"],
|
253 |
+
}
|
254 |
+
)
|
255 |
+
return retrieved_results
|
medrag_multi_modal/retrieval/common.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
4 |
+
class SimilarityMetric(Enum):
|
5 |
+
COSINE = "cosine"
|
6 |
+
EUCLIDEAN = "euclidean"
|
7 |
+
|
8 |
+
|
9 |
+
def mean_pooling(token_embeddings, mask):
|
10 |
+
token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
|
11 |
+
sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
|
12 |
+
return sentence_embeddings
|
13 |
+
|
14 |
+
|
15 |
+
def argsort_scores(scores: list[float], descending: bool = False):
|
16 |
+
return [
|
17 |
+
{"item": item, "original_index": idx}
|
18 |
+
for idx, item in sorted(
|
19 |
+
list(enumerate(scores)), key=lambda x: x[1], reverse=descending
|
20 |
+
)
|
21 |
+
]
|
medrag_multi_modal/retrieval/text_retrieval/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .bm25s_retrieval import BM25sRetriever
|
2 |
+
from .contriever_retrieval import ContrieverRetriever
|
3 |
+
from .medcpt_retrieval import MedCPTRetriever
|
4 |
+
from .nv_embed_2 import NVEmbed2Retriever
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
"BM25sRetriever",
|
8 |
+
"ContrieverRetriever",
|
9 |
+
"MedCPTRetriever",
|
10 |
+
"NVEmbed2Retriever",
|
11 |
+
]
|
medrag_multi_modal/retrieval/text_retrieval/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (478 Bytes). View file
|
|