geekyrakshit commited on
Commit
39b7b6a
1 Parent(s): 63b409f

add: files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +7 -7
  2. app.py +114 -0
  3. medrag_multi_modal/__init__.py +0 -0
  4. medrag_multi_modal/__pycache__/__init__.cpython-310.pyc +0 -0
  5. medrag_multi_modal/__pycache__/__init__.cpython-39.pyc +0 -0
  6. medrag_multi_modal/__pycache__/cli.cpython-310.pyc +0 -0
  7. medrag_multi_modal/__pycache__/utils.cpython-310.pyc +0 -0
  8. medrag_multi_modal/__pycache__/utils.cpython-39.pyc +0 -0
  9. medrag_multi_modal/assistant/__init__.py +5 -0
  10. medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc +0 -0
  11. medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc +0 -0
  12. medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc +0 -0
  13. medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc +0 -0
  14. medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc +0 -0
  15. medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc +0 -0
  16. medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc +0 -0
  17. medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc +0 -0
  18. medrag_multi_modal/assistant/figure_annotation.py +147 -0
  19. medrag_multi_modal/assistant/llm_client.py +245 -0
  20. medrag_multi_modal/assistant/medqa_assistant.py +174 -0
  21. medrag_multi_modal/assistant/schema.py +27 -0
  22. medrag_multi_modal/cli.py +68 -0
  23. medrag_multi_modal/document_loader/__init__.py +25 -0
  24. medrag_multi_modal/document_loader/image_loader/__init__.py +13 -0
  25. medrag_multi_modal/document_loader/image_loader/base_img_loader.py +180 -0
  26. medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py +127 -0
  27. medrag_multi_modal/document_loader/image_loader/marker_img_loader.py +131 -0
  28. medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py +83 -0
  29. medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py +101 -0
  30. medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py +124 -0
  31. medrag_multi_modal/document_loader/text_loader/__init__.py +11 -0
  32. medrag_multi_modal/document_loader/text_loader/base_text_loader.py +185 -0
  33. medrag_multi_modal/document_loader/text_loader/marker_text_loader.py +89 -0
  34. medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py +76 -0
  35. medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py +73 -0
  36. medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py +77 -0
  37. medrag_multi_modal/metrics/__init__.py +3 -0
  38. medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
  39. medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc +0 -0
  40. medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc +0 -0
  41. medrag_multi_modal/metrics/base.py +108 -0
  42. medrag_multi_modal/metrics/mmlu.py +24 -0
  43. medrag_multi_modal/retrieval/__init__.py +3 -0
  44. medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc +0 -0
  45. medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc +0 -0
  46. medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc +0 -0
  47. medrag_multi_modal/retrieval/colpali_retrieval.py +255 -0
  48. medrag_multi_modal/retrieval/common.py +21 -0
  49. medrag_multi_modal/retrieval/text_retrieval/__init__.py +11 -0
  50. medrag_multi_modal/retrieval/text_retrieval/__pycache__/__init__.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Medrag Multi Modal
3
- emoji: 🏆
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: streamlit
7
- sdk_version: 1.40.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: Multi-modal assistant for medical professionals
11
  ---
 
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: MedRAG Multi-Modal
3
+ emoji: 🩺
4
+ colorFrom: blue
5
+ colorTo: pink
6
  sdk: streamlit
7
+ sdk_version: "1.39.0"
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
+ # MedRAG Multi-Modal
12
 
13
+ Multi-modal RAG for medical docmain.
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
4
+ from medrag_multi_modal.retrieval.text_retrieval import (
5
+ BM25sRetriever,
6
+ ContrieverRetriever,
7
+ MedCPTRetriever,
8
+ NVEmbed2Retriever,
9
+ )
10
+
11
+ # Define constants
12
+ ALL_AVAILABLE_MODELS = [
13
+ "gemini-1.5-flash-latest",
14
+ "gemini-1.5-pro-latest",
15
+ "gpt-4o",
16
+ "gpt-4o-mini",
17
+ ]
18
+
19
+ # Sidebar for configuration settings
20
+ st.sidebar.title("Configuration Settings")
21
+ project_name = st.sidebar.text_input(
22
+ label="Project Name",
23
+ value="ml-colabs/medrag-multi-modal",
24
+ placeholder="wandb project name",
25
+ help="format: wandb_username/wandb_project_name",
26
+ )
27
+ chunk_dataset_id = st.sidebar.selectbox(
28
+ label="Chunk Dataset ID",
29
+ options=["ashwiniai/medrag-text-corpus-chunks"],
30
+ )
31
+ llm_model = st.sidebar.selectbox(
32
+ label="LLM Model",
33
+ options=ALL_AVAILABLE_MODELS,
34
+ )
35
+ top_k_chunks_for_query = st.sidebar.slider(
36
+ label="Top K Chunks for Query",
37
+ min_value=1,
38
+ max_value=20,
39
+ value=5,
40
+ )
41
+ top_k_chunks_for_options = st.sidebar.slider(
42
+ label="Top K Chunks for Options",
43
+ min_value=1,
44
+ max_value=20,
45
+ value=3,
46
+ )
47
+ rely_only_on_context = st.sidebar.checkbox(
48
+ label="Rely Only on Context",
49
+ value=False,
50
+ )
51
+ retriever_type = st.sidebar.selectbox(
52
+ label="Retriever Type",
53
+ options=[
54
+ "",
55
+ "BM25S",
56
+ "Contriever",
57
+ "MedCPT",
58
+ "NV-Embed-v2",
59
+ ],
60
+ )
61
+
62
+ if retriever_type != "":
63
+
64
+ llm_model = LLMClient(model_name=llm_model)
65
+
66
+ retriever = None
67
+
68
+ if retriever_type == "BM25S":
69
+ retriever = BM25sRetriever.from_index(
70
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
71
+ )
72
+ elif retriever_type == "Contriever":
73
+ retriever = ContrieverRetriever.from_index(
74
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
75
+ chunk_dataset_id=chunk_dataset_id,
76
+ )
77
+ elif retriever_type == "MedCPT":
78
+ retriever = MedCPTRetriever.from_index(
79
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
80
+ chunk_dataset_id=chunk_dataset_id,
81
+ )
82
+ elif retriever_type == "NV-Embed-v2":
83
+ retriever = NVEmbed2Retriever.from_index(
84
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
85
+ chunk_dataset_id=chunk_dataset_id,
86
+ )
87
+
88
+ medqa_assistant = MedQAAssistant(
89
+ llm_client=llm_model,
90
+ retriever=retriever,
91
+ top_k_chunks_for_query=top_k_chunks_for_query,
92
+ top_k_chunks_for_options=top_k_chunks_for_options,
93
+ )
94
+
95
+ with st.chat_message("assistant"):
96
+ st.markdown(
97
+ """
98
+ Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
99
+ I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.
100
+
101
+ **Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
102
+ Please consult a medical professional for any medical advice.
103
+
104
+ In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
105
+ """,
106
+ unsafe_allow_html=True,
107
+ )
108
+ query = st.chat_input("Enter your question here")
109
+ if query:
110
+ with st.chat_message("user"):
111
+ st.markdown(query)
112
+ response = medqa_assistant.predict(query=query)
113
+ with st.chat_message("assistant"):
114
+ st.markdown(response.response)
medrag_multi_modal/__init__.py ADDED
File without changes
medrag_multi_modal/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes). View file
 
medrag_multi_modal/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (184 Bytes). View file
 
medrag_multi_modal/__pycache__/cli.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
medrag_multi_modal/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.49 kB). View file
 
medrag_multi_modal/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.47 kB). View file
 
medrag_multi_modal/assistant/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .figure_annotation import FigureAnnotatorFromPageImage
2
+ from .llm_client import ClientType, LLMClient
3
+ from .medqa_assistant import MedQAAssistant
4
+
5
+ __all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (426 Bytes). View file
 
medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (424 Bytes). View file
 
medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc ADDED
Binary file (6.68 kB). View file
 
medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc ADDED
Binary file (6.67 kB). View file
 
medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc ADDED
Binary file (7.3 kB). View file
 
medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc ADDED
Binary file (7.06 kB). View file
 
medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc ADDED
Binary file (7.45 kB). View file
 
medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc ADDED
Binary file (1.27 kB). View file
 
medrag_multi_modal/assistant/figure_annotation.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ from typing import Optional, Union
4
+
5
+ import cv2
6
+ import weave
7
+ from PIL import Image
8
+
9
+ from medrag_multi_modal.assistant.llm_client import LLMClient
10
+ from medrag_multi_modal.assistant.schema import FigureAnnotations
11
+ from medrag_multi_modal.utils import get_wandb_artifact, read_jsonl_file
12
+
13
+
14
+ class FigureAnnotatorFromPageImage(weave.Model):
15
+ """
16
+ `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
17
+ figures from a page image of a scientific textbook.
18
+
19
+ !!! example "Example Usage"
20
+ ```python
21
+ import weave
22
+ from dotenv import load_dotenv
23
+
24
+ from medrag_multi_modal.assistant import (
25
+ FigureAnnotatorFromPageImage, LLMClient
26
+ )
27
+
28
+ load_dotenv()
29
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
30
+ figure_annotator = FigureAnnotatorFromPageImage(
31
+ figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
32
+ structured_output_llm_client=LLMClient(model_name="gpt-4o"),
33
+ image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
34
+ )
35
+ annotations = figure_annotator.predict(page_idx=34)
36
+ ```
37
+
38
+ Args:
39
+ figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
40
+ from the page image.
41
+ structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
42
+ annotations into a structured format.
43
+ image_artifact_address (Optional[str]): The address of the image artifact containing the
44
+ page images.
45
+ """
46
+
47
+ figure_extraction_llm_client: LLMClient
48
+ structured_output_llm_client: LLMClient
49
+ _artifact_dir: str
50
+
51
+ def __init__(
52
+ self,
53
+ figure_extraction_llm_client: LLMClient,
54
+ structured_output_llm_client: LLMClient,
55
+ image_artifact_address: Optional[str] = None,
56
+ ):
57
+ super().__init__(
58
+ figure_extraction_llm_client=figure_extraction_llm_client,
59
+ structured_output_llm_client=structured_output_llm_client,
60
+ )
61
+ self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
62
+
63
+ @weave.op()
64
+ def annotate_figures(
65
+ self, page_image: Image.Image
66
+ ) -> dict[str, Union[Image.Image, str]]:
67
+ annotation = self.figure_extraction_llm_client.predict(
68
+ system_prompt="""
69
+ You are an expert in the domain of scientific textbooks, especially medical texts.
70
+ You are presented with a page from a scientific textbook from the domain of biology, specifically anatomy.
71
+ You are to first identify all the figures in the page image, which could be images or biological diagrams, charts, graphs, etc.
72
+ Then you are to identify the figure IDs associated with each figure in the page image.
73
+ Then, you are to extract only the exact figure descriptions from the page image.
74
+ You need to output the figure IDs and figure descriptions only, in a structured manner as a JSON object.
75
+
76
+ Here are some clues you need to follow:
77
+ 1. Figure IDs are unique identifiers for each figure in the page image.
78
+ 2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure.
79
+ 3. Figure IDs are in the form "Fig X.Y" where X and Y are integers. For example, 1.1, 1.2, 1.3, etc.
80
+ 4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID.
81
+ 5. The text in the page image is written in English and is present in a two-column format.
82
+ 6. There is a clear distinction between the figure caption and the regular text in the page image in the form of extra white space.
83
+ You are to carefully identify all the figures in the page image.
84
+ 7. There might be multiple figures or even no figures present in the page image. Sometimes the figures can be present side-by-side
85
+ or one above the other.
86
+ 8. The figures may or may not have a distinct border against a white background.
87
+ 10. You are not supposed to alter the figure description in any way present in the page image and you are to extract it as is.
88
+ """,
89
+ user_prompt=[page_image],
90
+ )
91
+ return {"page_image": page_image, "annotations": annotation}
92
+
93
+ @weave.op
94
+ def extract_structured_output(self, annotations: str) -> FigureAnnotations:
95
+ return self.structured_output_llm_client.predict(
96
+ system_prompt="You are suppossed to extract a list of figure annotations consisting of figure IDs and corresponding figure descriptions.",
97
+ user_prompt=[annotations],
98
+ schema=FigureAnnotations,
99
+ )
100
+
101
+ @weave.op()
102
+ def predict(self, page_idx: int) -> dict[int, list[FigureAnnotations]]:
103
+ """
104
+ Predicts figure annotations for a specific page in a document.
105
+
106
+ This function retrieves the artifact directory from the given image artifact address,
107
+ reads the metadata from the 'metadata.jsonl' file, and iterates through the metadata
108
+ to find the specified page index. If the page index matches, it reads the page image
109
+ and associated figure images, and then uses the `annotate_figures` method to extract
110
+ figure annotations from the page image. The extracted annotations are then structured
111
+ using the `extract_structured_output` method and returned as a dictionary.
112
+
113
+ Args:
114
+ page_idx (int): The index of the page to annotate.
115
+ image_artifact_address (str): The address of the image artifact containing the
116
+ page images.
117
+
118
+ Returns:
119
+ dict: A dictionary containing the page index as the key and the extracted figure
120
+ annotations as the value.
121
+ """
122
+
123
+ metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
124
+ annotations = {}
125
+ for item in metadata:
126
+ if item["page_idx"] == page_idx:
127
+ page_image_file = os.path.join(
128
+ self._artifact_dir, f"page{item['page_idx']}.png"
129
+ )
130
+ figure_image_files = glob(
131
+ os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
132
+ )
133
+ if len(figure_image_files) > 0:
134
+ page_image = cv2.imread(page_image_file)
135
+ page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
136
+ page_image = Image.fromarray(page_image)
137
+ figure_extracted_annotations = self.annotate_figures(
138
+ page_image=page_image
139
+ )
140
+ figure_extracted_annotations = self.extract_structured_output(
141
+ figure_extracted_annotations["annotations"]
142
+ ).model_dump()
143
+ annotations[item["page_idx"]] = figure_extracted_annotations[
144
+ "annotations"
145
+ ]
146
+ break
147
+ return annotations
medrag_multi_modal/assistant/llm_client.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from enum import Enum
4
+ from typing import Any, Optional, Union
5
+
6
+ import instructor
7
+ import weave
8
+ from PIL import Image
9
+
10
+ from ..utils import base64_encode_image
11
+
12
+
13
+ class ClientType(str, Enum):
14
+ GEMINI = "gemini"
15
+ MISTRAL = "mistral"
16
+ OPENAI = "openai"
17
+
18
+
19
+ GOOGLE_MODELS = [
20
+ "gemini-1.0-pro-latest",
21
+ "gemini-1.0-pro",
22
+ "gemini-pro",
23
+ "gemini-1.0-pro-001",
24
+ "gemini-1.0-pro-vision-latest",
25
+ "gemini-pro-vision",
26
+ "gemini-1.5-pro-latest",
27
+ "gemini-1.5-pro-001",
28
+ "gemini-1.5-pro-002",
29
+ "gemini-1.5-pro",
30
+ "gemini-1.5-pro-exp-0801",
31
+ "gemini-1.5-pro-exp-0827",
32
+ "gemini-1.5-flash-latest",
33
+ "gemini-1.5-flash-001",
34
+ "gemini-1.5-flash-001-tuning",
35
+ "gemini-1.5-flash",
36
+ "gemini-1.5-flash-exp-0827",
37
+ "gemini-1.5-flash-002",
38
+ "gemini-1.5-flash-8b",
39
+ "gemini-1.5-flash-8b-001",
40
+ "gemini-1.5-flash-8b-latest",
41
+ "gemini-1.5-flash-8b-exp-0827",
42
+ "gemini-1.5-flash-8b-exp-0924",
43
+ ]
44
+
45
+ MISTRAL_MODELS = [
46
+ "ministral-3b-latest",
47
+ "ministral-8b-latest",
48
+ "mistral-large-latest",
49
+ "mistral-small-latest",
50
+ "codestral-latest",
51
+ "pixtral-12b-2409",
52
+ "open-mistral-nemo",
53
+ "open-codestral-mamba",
54
+ "open-mistral-7b",
55
+ "open-mixtral-8x7b",
56
+ "open-mixtral-8x22b",
57
+ ]
58
+
59
+ OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]
60
+
61
+
62
+ class LLMClient(weave.Model):
63
+ """
64
+ LLMClient is a class that interfaces with different large language model (LLM) providers
65
+ such as Google Gemini, Mistral, and OpenAI. It abstracts the complexity of interacting with
66
+ these different APIs and provides a unified interface for making predictions.
67
+
68
+ Args:
69
+ model_name (str): The name of the model to be used for predictions.
70
+ client_type (Optional[ClientType]): The type of client (e.g., GEMINI, MISTRAL, OPENAI).
71
+ If not provided, it is inferred from the model_name.
72
+ """
73
+
74
+ model_name: str
75
+ client_type: Optional[ClientType]
76
+
77
+ def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
78
+ if client_type is None:
79
+ if model_name in GOOGLE_MODELS:
80
+ client_type = ClientType.GEMINI
81
+ elif model_name in MISTRAL_MODELS:
82
+ client_type = ClientType.MISTRAL
83
+ elif model_name in OPENAI_MODELS:
84
+ client_type = ClientType.OPENAI
85
+ else:
86
+ raise ValueError(f"Invalid model name: {model_name}")
87
+ super().__init__(model_name=model_name, client_type=client_type)
88
+
89
+ @weave.op()
90
+ def execute_gemini_sdk(
91
+ self,
92
+ user_prompt: Union[str, list[str]],
93
+ system_prompt: Optional[Union[str, list[str]]] = None,
94
+ schema: Optional[Any] = None,
95
+ ) -> Union[str, Any]:
96
+ import google.generativeai as genai
97
+ from google.generativeai.types import HarmBlockThreshold, HarmCategory
98
+
99
+ system_prompt = (
100
+ [system_prompt] if isinstance(system_prompt, str) else system_prompt
101
+ )
102
+ user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
103
+
104
+ genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
105
+ model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt)
106
+ generation_config = (
107
+ None
108
+ if schema is None
109
+ else genai.GenerationConfig(
110
+ response_mime_type="application/json", response_schema=schema
111
+ )
112
+ )
113
+ response = model.generate_content(
114
+ user_prompt,
115
+ generation_config=generation_config,
116
+ # This is necessary in order to answer questions about anatomy, sexual diseases,
117
+ # medical devices, medicines, etc.
118
+ safety_settings={
119
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
120
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
121
+ },
122
+ )
123
+ return response.text if schema is None else json.loads(response.text)
124
+
125
+ @weave.op()
126
+ def execute_mistral_sdk(
127
+ self,
128
+ user_prompt: Union[str, list[str]],
129
+ system_prompt: Optional[Union[str, list[str]]] = None,
130
+ schema: Optional[Any] = None,
131
+ ) -> Union[str, Any]:
132
+ from mistralai import Mistral
133
+
134
+ system_prompt = (
135
+ [system_prompt] if isinstance(system_prompt, str) else system_prompt
136
+ )
137
+ user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
138
+ system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
139
+ user_messages = []
140
+ for prompt in user_prompt:
141
+ if isinstance(prompt, Image.Image):
142
+ user_messages.append(
143
+ {
144
+ "type": "image_url",
145
+ "image_url": base64_encode_image(prompt, "image/png"),
146
+ }
147
+ )
148
+ else:
149
+ user_messages.append({"type": "text", "text": prompt})
150
+ messages = [
151
+ {"role": "system", "content": system_messages},
152
+ {"role": "user", "content": user_messages},
153
+ ]
154
+
155
+ client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
156
+ client = instructor.from_mistral(client) if schema is not None else client
157
+
158
+ if schema is None:
159
+ raise NotImplementedError(
160
+ "Mistral does not support structured output using a schema"
161
+ )
162
+ else:
163
+ response = client.chat.complete(model=self.model_name, messages=messages)
164
+ return response.choices[0].message.content
165
+
166
+ @weave.op()
167
+ def execute_openai_sdk(
168
+ self,
169
+ user_prompt: Union[str, list[str]],
170
+ system_prompt: Optional[Union[str, list[str]]] = None,
171
+ schema: Optional[Any] = None,
172
+ ) -> Union[str, Any]:
173
+ from openai import OpenAI
174
+
175
+ system_prompt = (
176
+ [system_prompt] if isinstance(system_prompt, str) else system_prompt
177
+ )
178
+ user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
179
+
180
+ system_messages = [
181
+ {"role": "system", "content": prompt} for prompt in system_prompt
182
+ ]
183
+ user_messages = []
184
+ for prompt in user_prompt:
185
+ if isinstance(prompt, Image.Image):
186
+ user_messages.append(
187
+ {
188
+ "type": "image_url",
189
+ "image_url": {
190
+ "url": base64_encode_image(prompt, "image/png"),
191
+ },
192
+ },
193
+ )
194
+ else:
195
+ user_messages.append({"type": "text", "text": prompt})
196
+ messages = system_messages + [{"role": "user", "content": user_messages}]
197
+
198
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
199
+
200
+ if schema is None:
201
+ completion = client.chat.completions.create(
202
+ model=self.model_name, messages=messages
203
+ )
204
+ return completion.choices[0].message.content
205
+
206
+ completion = weave.op()(client.beta.chat.completions.parse)(
207
+ model=self.model_name, messages=messages, response_format=schema
208
+ )
209
+ return completion.choices[0].message.parsed
210
+
211
+ @weave.op()
212
+ def predict(
213
+ self,
214
+ user_prompt: Union[str, list[str]],
215
+ system_prompt: Optional[Union[str, list[str]]] = None,
216
+ schema: Optional[Any] = None,
217
+ ) -> Union[str, Any]:
218
+ """
219
+ Predicts the response from a language model based on the provided prompts and schema.
220
+
221
+ This function determines the client type and calls the appropriate SDK execution function
222
+ to get the response from the language model. It supports multiple client types including
223
+ GEMINI, MISTRAL, and OPENAI. Depending on the client type, it calls the corresponding
224
+ execution function with the provided user and system prompts, and an optional schema.
225
+
226
+ Args:
227
+ user_prompt (Union[str, list[str]]): The user prompt(s) to be sent to the language model.
228
+ system_prompt (Optional[Union[str, list[str]]]): The system prompt(s) to be sent to the language model.
229
+ schema (Optional[Any]): The schema to be used for parsing the response, if applicable.
230
+
231
+ Returns:
232
+ Union[str, Any]: The response from the language model, which could be a string or any other type
233
+ depending on the schema provided.
234
+
235
+ Raises:
236
+ ValueError: If the client type is invalid.
237
+ """
238
+ if self.client_type == ClientType.GEMINI:
239
+ return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
240
+ elif self.client_type == ClientType.MISTRAL:
241
+ return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
242
+ elif self.client_type == ClientType.OPENAI:
243
+ return self.execute_openai_sdk(user_prompt, system_prompt, schema)
244
+ else:
245
+ raise ValueError(f"Invalid client type: {self.client_type}")
medrag_multi_modal/assistant/medqa_assistant.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import weave
4
+
5
+ from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
6
+ from medrag_multi_modal.assistant.llm_client import LLMClient
7
+ from medrag_multi_modal.assistant.schema import (
8
+ MedQACitation,
9
+ MedQAMCQResponse,
10
+ MedQAResponse,
11
+ )
12
+ from medrag_multi_modal.retrieval.common import SimilarityMetric
13
+ from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
14
+
15
+
16
+ class MedQAAssistant(weave.Model):
17
+ """
18
+ `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
19
+ language model client, a retriever model, and a figure annotator.
20
+
21
+ !!! example "Usage Example"
22
+ ```python
23
+ import weave
24
+ from dotenv import load_dotenv
25
+
26
+ from medrag_multi_modal.assistant import (
27
+ FigureAnnotatorFromPageImage,
28
+ LLMClient,
29
+ MedQAAssistant,
30
+ )
31
+ from medrag_multi_modal.retrieval import MedCPTRetriever
32
+
33
+ load_dotenv()
34
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
35
+
36
+ llm_client = LLMClient(model_name="gemini-1.5-flash")
37
+
38
+ retriever=MedCPTRetriever.from_wandb_artifact(
39
+ chunk_dataset_name="grays-anatomy-chunks:v0",
40
+ index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
41
+ )
42
+
43
+ figure_annotator=FigureAnnotatorFromPageImage(
44
+ figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
45
+ structured_output_llm_client=LLMClient(model_name="gpt-4o"),
46
+ image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
47
+ )
48
+ medqa_assistant = MedQAAssistant(
49
+ llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
50
+ )
51
+ medqa_assistant.predict(query="What is ribosome?")
52
+ ```
53
+
54
+ Args:
55
+ llm_client (LLMClient): The language model client used to generate responses.
56
+ retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
57
+ figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
58
+ top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query.
59
+ top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options.
60
+ retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
61
+ """
62
+
63
+ llm_client: LLMClient
64
+ retriever: weave.Model
65
+ figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
66
+ top_k_chunks_for_query: int = 2
67
+ top_k_chunks_for_options: int = 2
68
+ rely_only_on_context: bool = True
69
+ retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
70
+
71
+ @weave.op()
72
+ def retrieve_chunks_for_query(self, query: str) -> list[dict]:
73
+ retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
74
+ if not isinstance(self.retriever, BM25sRetriever):
75
+ retriever_kwargs["metric"] = self.retrieval_similarity_metric
76
+ return self.retriever.predict(query, **retriever_kwargs)
77
+
78
+ @weave.op()
79
+ def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
80
+ retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
81
+ if not isinstance(self.retriever, BM25sRetriever):
82
+ retriever_kwargs["metric"] = self.retrieval_similarity_metric
83
+ retrieved_chunks = []
84
+ for option in options:
85
+ retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
86
+ return retrieved_chunks
87
+
88
+ @weave.op()
89
+ def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
90
+ """
91
+ Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
92
+ from a medical document and using a language model to generate the final response.
93
+
94
+ This function performs the following steps:
95
+ 1. Retrieves relevant text chunks from the medical document based on the query and any provided options
96
+ using the retriever model.
97
+ 2. Extracts the text and page indices from the retrieved chunks.
98
+ 3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
99
+ 4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
100
+ and figure descriptions.
101
+ 5. Uses the language model client to generate a response based on the constructed prompts, either choosing
102
+ from provided options or generating a free-form response.
103
+ 6. Returns the generated response, which includes the answer and explanation if options were provided.
104
+
105
+ The function can operate in two modes:
106
+ - Multiple choice: When options are provided, it selects the best answer from the options and explains the choice
107
+ - Free response: When no options are provided, it generates a comprehensive response based on the context
108
+
109
+ Args:
110
+ query (str): The medical query to be answered.
111
+ options (Optional[list[str]]): The list of options to choose from.
112
+ rely_only_on_context (bool): Whether to rely only on the context provided or not during response generation.
113
+
114
+ Returns:
115
+ MedQAResponse: The generated response to the query, including source information.
116
+ """
117
+ retrieved_chunks = self.retrieve_chunks_for_query(query)
118
+ options = options or []
119
+ retrieved_chunks += self.retrieve_chunks_for_options(options)
120
+
121
+ retrieved_chunk_texts = []
122
+ page_indices = set()
123
+ for chunk in retrieved_chunks:
124
+ retrieved_chunk_texts.append(chunk["text"])
125
+ page_indices.add(int(chunk["page_idx"]))
126
+
127
+ figure_descriptions = []
128
+ if self.figure_annotator is not None:
129
+ for page_idx in page_indices:
130
+ figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
131
+ page_idx
132
+ ]
133
+ figure_descriptions += [
134
+ item["figure_description"] for item in figure_annotations
135
+ ]
136
+
137
+ system_prompt = """You are an expert in medical science. You are given a question
138
+ and a list of excerpts from various medical documents.
139
+ """
140
+ query = f"""# Question
141
+ {query}
142
+ """
143
+
144
+ if len(options) > 0:
145
+ system_prompt += """\nYou are also given a list of options to choose your answer from.
146
+ You are supposed to choose the best possible option based on the context provided. You should also
147
+ explain your answer to justify why you chose that option.
148
+ """
149
+ query += "## Options\n"
150
+ for option in options:
151
+ query += f"- {option}\n"
152
+ else:
153
+ system_prompt += "\nYou are supposed to answer the question based on the context provided."
154
+
155
+ if self.rely_only_on_context:
156
+ system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
157
+ You are not allowed to use any external knowledge to answer the question.
158
+ """
159
+
160
+ response = self.llm_client.predict(
161
+ system_prompt=system_prompt,
162
+ user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
163
+ schema=MedQAMCQResponse if len(options) > 0 else None,
164
+ )
165
+
166
+ # TODO: Add figure citations
167
+ # TODO: Add source document name from retrieved chunks as citations
168
+ citations = []
169
+ for page_idx in page_indices:
170
+ citations.append(
171
+ MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
172
+ )
173
+
174
+ return MedQAResponse(response=response, citations=citations)
medrag_multi_modal/assistant/schema.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class FigureAnnotation(BaseModel):
7
+ figure_id: str
8
+ figure_description: str
9
+
10
+
11
+ class FigureAnnotations(BaseModel):
12
+ annotations: list[FigureAnnotation]
13
+
14
+
15
+ class MedQAMCQResponse(BaseModel):
16
+ answer: str
17
+ explanation: str
18
+
19
+
20
+ class MedQACitation(BaseModel):
21
+ page_number: int
22
+ document_name: str
23
+
24
+
25
+ class MedQAResponse(BaseModel):
26
+ response: Union[str, MedQAMCQResponse]
27
+ citations: list[MedQACitation]
medrag_multi_modal/cli.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import subprocess
4
+ import sys
5
+
6
+
7
+ def main():
8
+ parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI")
9
+ subparsers = parser.add_subparsers(dest="command", required=True)
10
+
11
+ # Run subcommand
12
+ run_parser = subparsers.add_parser("run", help="Run the Streamlit application")
13
+ run_parser.add_argument(
14
+ "--port", type=int, default=8501, help="Port to run Streamlit on"
15
+ )
16
+
17
+ # Evaluate subcommand
18
+ eval_parser = subparsers.add_parser("evaluate", help="Run evaluation tests")
19
+ eval_parser.add_argument(
20
+ "--test-file",
21
+ default=os.path.join("tests", "evals", "test_assistant_mmlu_anatomy.py"),
22
+ help="Path to test file",
23
+ )
24
+ eval_parser.add_argument(
25
+ "--test-case",
26
+ type=str,
27
+ help="Only run tests which match the given substring expression",
28
+ )
29
+ eval_parser.add_argument(
30
+ "--model-name",
31
+ type=str,
32
+ default="gemini-1.5-flash",
33
+ help="Model name to use for evaluation",
34
+ )
35
+
36
+ args = parser.parse_args()
37
+
38
+ if args.command == "run":
39
+ subprocess.run(
40
+ [
41
+ sys.executable,
42
+ "-m",
43
+ "streamlit",
44
+ "run",
45
+ "app.py",
46
+ "--server.port",
47
+ str(args.port),
48
+ ]
49
+ )
50
+
51
+ elif args.command == "evaluate":
52
+ test_file = (
53
+ args.test_file + "::" + args.test_case if args.test_case else args.test_file
54
+ )
55
+ cmd = [
56
+ sys.executable,
57
+ "-m",
58
+ "pytest",
59
+ "-s",
60
+ test_file,
61
+ "-v",
62
+ f"--model-name={args.model_name}",
63
+ ]
64
+ subprocess.run(cmd)
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
medrag_multi_modal/document_loader/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .image_loader import (
2
+ FitzPILImageLoader,
3
+ MarkerImageLoader,
4
+ PDF2ImageLoader,
5
+ PDFPlumberImageLoader,
6
+ PyMuPDFImageLoader,
7
+ )
8
+ from .text_loader import (
9
+ MarkerTextLoader,
10
+ PDFPlumberTextLoader,
11
+ PyMuPDF4LLMTextLoader,
12
+ PyPDF2TextLoader,
13
+ )
14
+
15
+ __all__ = [
16
+ "PyMuPDF4LLMTextLoader",
17
+ "PyPDF2TextLoader",
18
+ "PDFPlumberTextLoader",
19
+ "MarkerTextLoader",
20
+ "PDF2ImageLoader",
21
+ "MarkerImageLoader",
22
+ "PDFPlumberImageLoader",
23
+ "PyMuPDFImageLoader",
24
+ "FitzPILImageLoader",
25
+ ]
medrag_multi_modal/document_loader/image_loader/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .fitzpil_img_loader import FitzPILImageLoader
2
+ from .marker_img_loader import MarkerImageLoader
3
+ from .pdf2image_img_loader import PDF2ImageLoader
4
+ from .pdfplumber_img_loader import PDFPlumberImageLoader
5
+ from .pymupdf_img_loader import PyMuPDFImageLoader
6
+
7
+ __all__ = [
8
+ "PDF2ImageLoader",
9
+ "MarkerImageLoader",
10
+ "PDFPlumberImageLoader",
11
+ "PyMuPDFImageLoader",
12
+ "FitzPILImageLoader",
13
+ ]
medrag_multi_modal/document_loader/image_loader/base_img_loader.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ from abc import abstractmethod
4
+ from glob import glob
5
+ from typing import Dict, List, Optional
6
+
7
+ import huggingface_hub
8
+ import jsonlines
9
+ import rich
10
+ from datasets import (
11
+ Dataset,
12
+ Features,
13
+ Image,
14
+ Sequence,
15
+ Value,
16
+ concatenate_datasets,
17
+ load_dataset,
18
+ )
19
+
20
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
21
+ BaseTextLoader,
22
+ )
23
+
24
+
25
+ class BaseImageLoader(BaseTextLoader):
26
+ def __init__(self, url: str, document_name: str, document_file_path: str):
27
+ super().__init__(url, document_name, document_file_path)
28
+
29
+ @abstractmethod
30
+ async def extract_page_data(
31
+ self, page_idx: int, image_save_dir: str, **kwargs
32
+ ) -> Dict[str, str]:
33
+ """
34
+ Abstract method to process a single page of the PDF and extract the image data.
35
+
36
+ Overwrite this method in the subclass to provide the actual implementation and
37
+ processing logic for each page of the PDF using various PDF processing libraries.
38
+
39
+ Args:
40
+ page_idx (int): The index of the page to process.
41
+ image_save_dir (str): The directory to save the extracted images.
42
+ **kwargs: Additional keyword arguments that may be used by underlying libraries.
43
+
44
+ Returns:
45
+ Dict[str, str]: A dictionary containing the processed page data.
46
+ """
47
+ pass
48
+
49
+ def save_as_dataset(
50
+ self,
51
+ start_page: int,
52
+ end_page: int,
53
+ image_save_dir: str,
54
+ dataset_repo_id: Optional[str] = None,
55
+ overwrite_dataset: bool = False,
56
+ ):
57
+ features = Features(
58
+ {
59
+ "page_image": Image(decode=True),
60
+ "page_figure_images": Sequence(Image(decode=True)),
61
+ "document_name": Value(dtype="string"),
62
+ "page_idx": Value(dtype="int32"),
63
+ }
64
+ )
65
+
66
+ all_examples = []
67
+ for page_idx in range(start_page, end_page):
68
+ page_image_file_paths = glob(
69
+ os.path.join(image_save_dir, f"page{page_idx}*.png")
70
+ )
71
+ if len(page_image_file_paths) > 0:
72
+ page_image_path = page_image_file_paths[0]
73
+ figure_image_paths = [
74
+ image_file_path
75
+ for image_file_path in glob(
76
+ os.path.join(image_save_dir, f"page{page_idx}*_fig*.png")
77
+ )
78
+ ]
79
+
80
+ example = {
81
+ "page_image": page_image_path,
82
+ "page_figure_images": figure_image_paths,
83
+ "document_name": self.document_name,
84
+ "page_idx": page_idx,
85
+ }
86
+ all_examples.append(example)
87
+
88
+ dataset = Dataset.from_list(all_examples, features=features)
89
+
90
+ if dataset_repo_id:
91
+ if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
92
+ if not overwrite_dataset:
93
+ dataset = concatenate_datasets(
94
+ [dataset, load_dataset(dataset_repo_id)["corpus"]]
95
+ )
96
+
97
+ dataset.push_to_hub(dataset_repo_id, split="corpus")
98
+
99
+ return dataset
100
+
101
+ def cleanup_image_dir(self, image_save_dir: str = "./images"):
102
+ for file in os.listdir(image_save_dir):
103
+ file_path = os.path.join(image_save_dir, file)
104
+ if os.path.isfile(file_path):
105
+ os.remove(file_path)
106
+
107
+ async def load_data(
108
+ self,
109
+ start_page: Optional[int] = None,
110
+ end_page: Optional[int] = None,
111
+ dataset_repo_id: Optional[str] = None,
112
+ overwrite_dataset: bool = False,
113
+ image_save_dir: str = "./images",
114
+ exclude_file_extensions: list[str] = [],
115
+ **kwargs,
116
+ ) -> List[Dict[str, str]]:
117
+ """
118
+ Asynchronously loads images from a PDF file specified by a URL or local file path.
119
+ The overrided processing abstract method then processes the images,
120
+ and optionally publishes it to a WandB artifact.
121
+
122
+ This function downloads a PDF from a given URL if it does not already exist locally,
123
+ reads the specified range of pages, scans each page's content to extract images, and
124
+ returns a list of Page objects containing the images and metadata.
125
+
126
+ It uses `PyPDF2` to calculate the number of pages in the PDF and the
127
+ overriden `extract_page_data` method provides the actual implementation to process
128
+ each page, extract the image content from the PDF, and convert it to png format.
129
+ It processes pages concurrently using `asyncio` for efficiency.
130
+
131
+ If a wandb_artifact_name is provided, the processed pages are published to a WandB artifact.
132
+
133
+ Args:
134
+ start_page (Optional[int]): The starting page index (0-based) to process.
135
+ end_page (Optional[int]): The ending page index (0-based) to process.
136
+ dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
137
+ overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
138
+ image_save_dir (str): The directory to save the extracted images.
139
+ exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
140
+ **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
141
+
142
+ Returns:
143
+ Dataset: A HuggingFace dataset containing the processed pages.
144
+
145
+ Raises:
146
+ ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
147
+ """
148
+ os.makedirs(image_save_dir, exist_ok=True)
149
+ start_page, end_page = self.get_page_indices(start_page, end_page)
150
+ pages = []
151
+ processed_pages_counter: int = 1
152
+ total_pages = end_page - start_page
153
+
154
+ async def process_page(page_idx):
155
+ nonlocal processed_pages_counter
156
+ page_data = await self.extract_page_data(page_idx, image_save_dir, **kwargs)
157
+ pages.append(page_data)
158
+ rich.print(
159
+ f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
160
+ )
161
+ processed_pages_counter += 1
162
+
163
+ tasks = [process_page(page_idx) for page_idx in range(start_page, end_page)]
164
+ for task in asyncio.as_completed(tasks):
165
+ await task
166
+
167
+ with jsonlines.open(
168
+ os.path.join(image_save_dir, "metadata.jsonl"), mode="w"
169
+ ) as writer:
170
+ writer.write(pages)
171
+
172
+ for file in os.listdir(image_save_dir):
173
+ if file.endswith(tuple(exclude_file_extensions)):
174
+ os.remove(os.path.join(image_save_dir, file))
175
+
176
+ dataset = self.save_as_dataset(
177
+ start_page, end_page, image_save_dir, dataset_repo_id, overwrite_dataset
178
+ )
179
+
180
+ return dataset
medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from typing import Any, Dict
4
+
5
+ import fitz
6
+ from pdf2image.pdf2image import convert_from_path
7
+ from PIL import Image, ImageOps, UnidentifiedImageError
8
+
9
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
10
+ BaseImageLoader,
11
+ )
12
+
13
+
14
+ class FitzPILImageLoader(BaseImageLoader):
15
+ """
16
+ `FitzPILImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
17
+ loading of pages from a PDF file as images using the fitz and PIL libraries.
18
+
19
+ This class provides functionality to extract images from a PDF file using fitz and PIL libraries,
20
+ and optionally publish these images to a WandB artifact.
21
+
22
+ !!! example "Example Usage"
23
+ ```python
24
+ import asyncio
25
+
26
+ from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
27
+
28
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
29
+
30
+ loader = FitzPILImageLoader(
31
+ url=URL,
32
+ document_name="Gray's Anatomy",
33
+ document_file_path="grays_anatomy.pdf",
34
+ )
35
+ dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
36
+ ```
37
+
38
+ Args:
39
+ url (str): The URL of the PDF document.
40
+ document_name (str): The name of the document.
41
+ document_file_path (str): The path to the PDF file.
42
+ """
43
+
44
+ def __init__(self, url: str, document_name: str, document_file_path: str):
45
+ super().__init__(url, document_name, document_file_path)
46
+
47
+ async def extract_page_data(
48
+ self, page_idx: int, image_save_dir: str, **kwargs
49
+ ) -> Dict[str, Any]:
50
+ """
51
+ Extracts a single page from the PDF as an image using fitz and PIL libraries.
52
+
53
+ Args:
54
+ page_idx (int): The index of the page to process.
55
+ image_save_dir (str): The directory to save the extracted image.
56
+ **kwargs: Additional keyword arguments that may be used by fitz and PIL.
57
+
58
+ Returns:
59
+ Dict[str, Any]: A dictionary containing the processed page data.
60
+ The dictionary will have the following keys and values:
61
+
62
+ - "page_idx": (int) the index of the page.
63
+ - "document_name": (str) the name of the document.
64
+ - "file_path": (str) the local file path where the PDF is stored.
65
+ - "file_url": (str) the URL of the PDF file.
66
+ - "image_file_paths": (list) the local file paths where the images are stored.
67
+ """
68
+ image_file_paths = []
69
+
70
+ pdf_document = fitz.open(self.document_file_path)
71
+ page = pdf_document.load_page(page_idx)
72
+
73
+ images = page.get_images(full=True)
74
+ for img_idx, image in enumerate(images):
75
+ xref = image[0]
76
+ base_image = pdf_document.extract_image(xref)
77
+ image_bytes = base_image["image"]
78
+ image_ext = base_image["ext"]
79
+
80
+ try:
81
+ img = Image.open(io.BytesIO(image_bytes))
82
+
83
+ if img.mode in ["L"]:
84
+ # images in greyscale looks inverted, need to test on other PDFs
85
+ img = ImageOps.invert(img)
86
+
87
+ if img.mode == "CMYK":
88
+ img = img.convert("RGB")
89
+
90
+ if image_ext not in ["png", "jpg", "jpeg"]:
91
+ image_ext = "png"
92
+ image_file_name = f"page{page_idx}_fig{img_idx}.png"
93
+ image_file_path = os.path.join(image_save_dir, image_file_name)
94
+
95
+ img.save(image_file_path, format="PNG")
96
+ else:
97
+ image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}"
98
+ image_file_path = os.path.join(image_save_dir, image_file_name)
99
+
100
+ with open(image_file_path, "wb") as image_file:
101
+ image_file.write(image_bytes)
102
+
103
+ image_file_paths.append(image_file_path)
104
+
105
+ except (UnidentifiedImageError, OSError) as e:
106
+ print(
107
+ f"Skipping image at page {page_idx}, fig {img_idx} due to an error: {e}"
108
+ )
109
+ continue
110
+
111
+ pdf_document.close()
112
+
113
+ page_image = convert_from_path(
114
+ self.document_file_path,
115
+ first_page=page_idx + 1,
116
+ last_page=page_idx + 1,
117
+ **kwargs,
118
+ )[0]
119
+ page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
120
+
121
+ return {
122
+ "page_idx": page_idx,
123
+ "document_name": self.document_name,
124
+ "file_path": self.document_file_path,
125
+ "file_url": self.url,
126
+ "image_file_paths": image_file_paths,
127
+ }
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Coroutine, Dict, List
3
+
4
+ from marker.convert import convert_single_pdf
5
+ from marker.models import load_all_models
6
+ from pdf2image.pdf2image import convert_from_path
7
+
8
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
9
+ BaseImageLoader,
10
+ )
11
+
12
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
13
+
14
+
15
+ class MarkerImageLoader(BaseImageLoader):
16
+ """
17
+ `MarkerImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
18
+ loading of pages from a PDF file as images using the marker library.
19
+
20
+ This class provides functionality to extract images from a PDF file using marker library,
21
+ and optionally publish these images to a WandB artifact.
22
+
23
+ !!! example "Example Usage"
24
+ ```python
25
+ import asyncio
26
+
27
+ from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
28
+
29
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
30
+
31
+ loader = MarkerImageLoader(
32
+ url=URL,
33
+ document_name="Gray's Anatomy",
34
+ document_file_path="grays_anatomy.pdf",
35
+ )
36
+ dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
37
+ ```
38
+
39
+ Args:
40
+ url (str): The URL of the PDF document.
41
+ document_name (str): The name of the document.
42
+ document_file_path (str): The path to the PDF file.
43
+ save_page_image (bool): Whether to additionally save the image of the entire page.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ url: str,
49
+ document_name: str,
50
+ document_file_path: str,
51
+ save_page_image: bool = False,
52
+ ):
53
+ super().__init__(url, document_name, document_file_path)
54
+ self.save_page_image = save_page_image
55
+ self.model_lst = load_all_models()
56
+
57
+ async def extract_page_data(
58
+ self, page_idx: int, image_save_dir: str, **kwargs
59
+ ) -> Dict[str, Any]:
60
+ """
61
+ Extracts a single page from the PDF as an image using marker library.
62
+
63
+ Args:
64
+ page_idx (int): The index of the page to process.
65
+ image_save_dir (str): The directory to save the extracted image.
66
+ **kwargs: Additional keyword arguments that may be used by marker.
67
+
68
+ Returns:
69
+ Dict[str, Any]: A dictionary containing the processed page data.
70
+ The dictionary will have the following keys and values:
71
+
72
+ - "page_idx": (int) the index of the page.
73
+ - "document_name": (str) the name of the document.
74
+ - "file_path": (str) the local file path where the PDF is stored.
75
+ - "file_url": (str) the URL of the PDF file.
76
+ - "image_file_path": (str) the local file path where the image is stored.
77
+ """
78
+ _, images, _ = convert_single_pdf(
79
+ self.document_file_path,
80
+ self.model_lst,
81
+ max_pages=1,
82
+ batch_multiplier=1,
83
+ start_page=page_idx,
84
+ ocr_all_pages=True,
85
+ **kwargs,
86
+ )
87
+
88
+ image_file_paths = []
89
+ for img_idx, (_, image) in enumerate(images.items()):
90
+ image_file_name = f"page{page_idx}_fig{img_idx}.png"
91
+ image_file_path = os.path.join(image_save_dir, image_file_name)
92
+ image.save(image_file_path, "png")
93
+ image_file_paths.append(image_file_path)
94
+
95
+ page_image = convert_from_path(
96
+ self.document_file_path,
97
+ first_page=page_idx + 1,
98
+ last_page=page_idx + 1,
99
+ **kwargs,
100
+ )[0]
101
+ page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
102
+
103
+ return {
104
+ "page_idx": page_idx,
105
+ "document_name": self.document_name,
106
+ "file_path": self.document_file_path,
107
+ "file_url": self.url,
108
+ "image_file_paths": os.path.join(image_save_dir, "*.png"),
109
+ }
110
+
111
+ def load_data(
112
+ self,
113
+ start_page: int | None = None,
114
+ end_page: int | None = None,
115
+ wandb_artifact_name: str | None = None,
116
+ image_save_dir: str = "./images",
117
+ exclude_file_extensions: list[str] = [],
118
+ cleanup: bool = False,
119
+ **kwargs,
120
+ ) -> Coroutine[Any, Any, List[Dict[str, str]]]:
121
+ start_page = start_page - 1 if start_page is not None else None
122
+ end_page = end_page - 1 if end_page is not None else None
123
+ return super().load_data(
124
+ start_page,
125
+ end_page,
126
+ wandb_artifact_name,
127
+ image_save_dir,
128
+ exclude_file_extensions,
129
+ cleanup,
130
+ **kwargs,
131
+ )
medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ from typing import Any, Dict
+
+ from pdf2image.pdf2image import convert_from_path
+
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+     BaseImageLoader,
+ )
+
+
+ class PDF2ImageLoader(BaseImageLoader):
+     """
+     `PDF2ImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+     loading of pages from a PDF file as images using the pdf2image library.
+
+     This class provides functionality to convert specific pages of a PDF document into images
+     and optionally publish these images to a WandB artifact.
+     It is like a snapshot image version of each of the pages from the PDF.
+
+     !!! example "Example Usage"
+         ```python
+         import asyncio
+
+         from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
+
+         URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+         loader = PDF2ImageLoader(
+             url=URL,
+             document_name="Gray's Anatomy",
+             document_file_path="grays_anatomy.pdf",
+         )
+         dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+         ```
+
+     Args:
+         url (str): The URL of the PDF document.
+         document_name (str): The name of the document.
+         document_file_path (str): The path to the PDF file.
+     """
+
+     def __init__(self, url: str, document_name: str, document_file_path: str):
+         super().__init__(url, document_name, document_file_path)
+
+     async def extract_page_data(
+         self, page_idx: int, image_save_dir: str, **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Extracts a single page from the PDF as an image using the pdf2image library.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             image_save_dir (str): The directory to save the extracted image.
+             **kwargs: Additional keyword arguments that may be used by pdf2image.
+
+         Returns:
+             Dict[str, Any]: A dictionary containing the processed page data.
+             The dictionary will have the following keys and values:
+
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+             - "image_file_path": (str) the local file path where the image is stored.
+         """
+         image = convert_from_path(
+             self.document_file_path,
+             first_page=page_idx + 1,
+             last_page=page_idx + 1,
+             **kwargs,
+         )[0]
+
+         image_file_name = f"page{page_idx}.png"
+         image_file_path = os.path.join(image_save_dir, image_file_name)
+         image.save(image_file_path)
+
+         return {
+             "page_idx": page_idx,
+             "document_name": self.document_name,
+             "file_path": self.document_file_path,
+             "file_url": self.url,
+             "image_file_path": image_file_path,
+         }
medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ from typing import Any, Dict
+
+ import pdfplumber
+ from pdf2image.pdf2image import convert_from_path
+
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+     BaseImageLoader,
+ )
+
+
+ class PDFPlumberImageLoader(BaseImageLoader):
+     """
+     `PDFPlumberImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+     loading of pages from a PDF file as images using the pdfplumber library.
+
+     This class provides functionality to extract images from a PDF file using the pdfplumber library,
+     and optionally publish these images to a WandB artifact.
+
+     !!! example "Example Usage"
+         ```python
+         import asyncio
+
+         from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
+
+         URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+         loader = PDFPlumberImageLoader(
+             url=URL,
+             document_name="Gray's Anatomy",
+             document_file_path="grays_anatomy.pdf",
+         )
+         dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+         ```
+
+     Args:
+         url (str): The URL of the PDF document.
+         document_name (str): The name of the document.
+         document_file_path (str): The path to the PDF file.
+     """
+
+     def __init__(self, url: str, document_name: str, document_file_path: str):
+         super().__init__(url, document_name, document_file_path)
+
+     async def extract_page_data(
+         self, page_idx: int, image_save_dir: str, **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Extracts a single page from the PDF as an image using the pdfplumber library.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             image_save_dir (str): The directory to save the extracted image.
+             **kwargs: Additional keyword arguments that may be used by pdfplumber.
+
+         Returns:
+             Dict[str, Any]: A dictionary containing the processed page data.
+             The dictionary will have the following keys and values:
+
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+             - "image_file_paths": (list) the local file paths where the images are stored.
+         """
+         with pdfplumber.open(self.document_file_path) as pdf:
+             page = pdf.pages[page_idx]
+             images = page.images
+
+             image_file_paths = []
+             for img_idx, image in enumerate(images):
+                 extracted_image = page.crop(
+                     (
+                         image["x0"],
+                         image["top"],
+                         image["x1"],
+                         image["bottom"],
+                     )
+                 ).to_image(resolution=300)
+
+                 image_file_name = f"page{page_idx}_fig{img_idx}.png"
+                 image_file_path = os.path.join(image_save_dir, image_file_name)
+
+                 extracted_image.save(image_file_path, "png")
+                 image_file_paths.append(image_file_path)
+
+         page_image = convert_from_path(
+             self.document_file_path,
+             first_page=page_idx + 1,
+             last_page=page_idx + 1,
+             **kwargs,
+         )[0]
+         page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
+         return {
+             "page_idx": page_idx,
+             "document_name": self.document_name,
+             "file_path": self.document_file_path,
+             "file_url": self.url,
+             "image_file_paths": image_file_paths,
+         }
medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py ADDED
@@ -0,0 +1,124 @@
+ import io
+ import os
+ from typing import Any, Dict
+
+ import fitz
+ from pdf2image.pdf2image import convert_from_path
+ from PIL import Image
+
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+     BaseImageLoader,
+ )
+
+
+ class PyMuPDFImageLoader(BaseImageLoader):
+     """
+     `PyMuPDFImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+     loading of pages from a PDF file as images using the pymupdf library.
+
+     This class provides functionality to extract images from a PDF file using the pymupdf library,
+     and optionally publish these images to a WandB artifact.
+
+     !!! example "Example Usage"
+         ```python
+         import asyncio
+
+         from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
+
+         URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+         loader = PyMuPDFImageLoader(
+             url=URL,
+             document_name="Gray's Anatomy",
+             document_file_path="grays_anatomy.pdf",
+         )
+         dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+         ```
+
+     Args:
+         url (str): The URL of the PDF document.
+         document_name (str): The name of the document.
+         document_file_path (str): The path to the PDF file.
+     """
+
+     def __init__(self, url: str, document_name: str, document_file_path: str):
+         super().__init__(url, document_name, document_file_path)
+
+     async def extract_page_data(
+         self, page_idx: int, image_save_dir: str, **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Extracts a single page from the PDF as an image using the pymupdf library.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             image_save_dir (str): The directory to save the extracted image.
+             **kwargs: Additional keyword arguments that may be used by pymupdf.
+
+         Returns:
+             Dict[str, Any]: A dictionary containing the processed page data.
+             The dictionary will have the following keys and values:
+
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+             - "image_file_paths": (list) the local file paths where the images are stored.
+         """
+         image_file_paths = []
+
+         pdf_document = fitz.open(self.document_file_path)
+         page = pdf_document[page_idx]
+
+         images = page.get_images(full=True)
+         for img_idx, image in enumerate(images):
+             xref = image[0]
+             base_image = pdf_document.extract_image(xref)
+             image_bytes = base_image["image"]
+             image_ext = base_image["ext"]
+
+             if image_ext == "jb2":
+                 image_ext = "png"
+             elif image_ext == "jpx":
+                 image_ext = "jpg"
+
+             image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}"
+             image_file_path = os.path.join(image_save_dir, image_file_name)
+
+             # For JBIG2 and JPEG2000, we need to convert the image
+             if base_image["ext"] in ["jb2", "jpx"]:
+                 try:
+                     pix = fitz.Pixmap(image_bytes)
+                     pix.save(image_file_path)
+                 except Exception as err_fitz:
+                     print(f"Error processing image with fitz: {err_fitz}")
+                     # Fallback to using PIL for image conversion
+                     try:
+                         img = Image.open(io.BytesIO(image_bytes))
+                         img.save(image_file_path)
+                     except Exception as err_pil:
+                         print(f"Failed to process image with PIL: {err_pil}")
+                         continue  # Skip this image if both methods fail
+             else:
+                 with open(image_file_path, "wb") as image_file:
+                     image_file.write(image_bytes)
+
+             image_file_paths.append(image_file_path)
+
+         pdf_document.close()
+
+         page_image = convert_from_path(
+             self.document_file_path,
+             first_page=page_idx + 1,
+             last_page=page_idx + 1,
+             **kwargs,
+         )[0]
+         page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
+         return {
+             "page_idx": page_idx,
+             "document_name": self.document_name,
+             "file_path": self.document_file_path,
+             "file_url": self.url,
+             "image_file_paths": image_file_paths,
+         }
medrag_multi_modal/document_loader/text_loader/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .marker_text_loader import MarkerTextLoader
+ from .pdfplumber_text_loader import PDFPlumberTextLoader
+ from .pymupdf4llm_text_loader import PyMuPDF4LLMTextLoader
+ from .pypdf2_text_loader import PyPDF2TextLoader
+
+ __all__ = [
+     "PyMuPDF4LLMTextLoader",
+     "PyPDF2TextLoader",
+     "PDFPlumberTextLoader",
+     "MarkerTextLoader",
+ ]
medrag_multi_modal/document_loader/text_loader/base_text_loader.py ADDED
@@ -0,0 +1,185 @@
+ import asyncio
+ import os
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+
+ import huggingface_hub
+ import PyPDF2
+ from datasets import Dataset, concatenate_datasets, load_dataset
+ from firerequests import FireRequests
+ from rich.progress import Progress
+
+
+ class BaseTextLoader(ABC):
+     """
+     An abstract base class for loading text from a PDF file, processing it into markdown, and optionally publishing it to a HuggingFace dataset.
+
+     This class handles the downloading of a PDF file from a given URL if it does not already exist locally.
+     Subclasses should implement the specific PDF reading, text extraction, and markdown conversion methods.
+
+     The processed pages are finally collected into a HuggingFace `Dataset`, which can optionally be pushed to the HuggingFace Hub.
+
+     Args:
+         url (str): The URL of the PDF file to download if not present locally.
+         document_name (str): The name of the document for metadata purposes.
+         document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+         metadata (Optional[dict[str, any]]): Additional metadata to be added to each row of the dataset.
+     """
+
+     def __init__(
+         self,
+         url: str,
+         document_name: str,
+         document_file_path: str,
+         metadata: Optional[dict[str, Any]] = None,
+     ):
+         self.url = url
+         self.document_name = document_name
+         self.document_file_path = document_file_path
+         self.metadata = metadata or {}
+         if not os.path.exists(self.document_file_path):
+             FireRequests().download(url, filenames=self.document_file_path)
+         with open(self.document_file_path, "rb") as file:
+             pdf_reader = PyPDF2.PdfReader(file)
+             self.page_count = len(pdf_reader.pages)
+
+     def get_page_indices(
+         self, start_page: Optional[int] = None, end_page: Optional[int] = None
+     ) -> tuple[int, int]:
+         """
+         Get the start and end page indices for processing.
+
+         Args:
+             start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
+             end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
+
+         Returns:
+             tuple[int, int]: A tuple containing the start and end page indices.
+         """
+
+         if start_page:
+             if start_page > self.page_count:
+                 raise ValueError(
+                     f"Start page {start_page} is greater than the total page count {self.page_count}"
+                 )
+         else:
+             start_page = 0
+         if end_page:
+             if end_page > self.page_count:
+                 raise ValueError(
+                     f"End page {end_page} is greater than the total page count {self.page_count}"
+                 )
+         else:
+             end_page = self.page_count - 1
+         return start_page, end_page
+
+     @abstractmethod
+     async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+         """
+         Abstract method to process a single page of the PDF and extract the text data.
+
+         Override this method in a subclass to provide the actual implementation and
+         processing logic for each page of the PDF using various PDF processing libraries.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             **kwargs: Additional keyword arguments that may be used by underlying libraries.
+
+         Returns:
+             Dict[str, str]: A dictionary containing the processed page data.
+         """
+         pass
+
+     async def load_data(
+         self,
+         start_page: Optional[int] = None,
+         end_page: Optional[int] = None,
+         exclude_pages: Optional[list[int]] = None,
+         dataset_repo_id: Optional[str] = None,
+         overwrite_dataset: bool = False,
+         **kwargs,
+     ) -> Dataset:
+         """
+         Asynchronously loads text from a PDF file specified by a URL or local file path.
+         The overridden `extract_page_data` method then processes the text into markdown format,
+         and the result can optionally be published to a HuggingFace dataset.
+
+         This function downloads a PDF from a given URL if it does not already exist locally,
+         reads the specified range of pages, converts each page's content to markdown, and
+         returns a HuggingFace Dataset containing the text and metadata.
+
+         It uses `PyPDF2` to calculate the number of pages in the PDF and the
+         overridden `extract_page_data` method provides the actual implementation to process
+         each page, extract the text from the PDF, and convert it to markdown.
+         It processes pages concurrently using `asyncio` for efficiency.
+
+         If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset.
+
+         Args:
+             start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
+             end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
+             exclude_pages (Optional[list[int]]): The list of page indices to exclude from processing.
+             dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
+             overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
+             **kwargs: Additional keyword arguments that will be passed to the extract_page_data method and the underlying library.
+
+         Returns:
+             Dataset: A HuggingFace Dataset object containing the text and metadata for processed pages.
+             Each entry in the dataset will have the following keys and values:
+
+             - "text": (str) the processed page data in markdown format.
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+             - "loader_name": (str) the name of the loader class used to process the page.
+
+         Raises:
+             ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
+         """
+         start_page, end_page = self.get_page_indices(start_page, end_page)
+         pages = []
+         processed_pages_counter: int = 1
+         total_pages = end_page - start_page
+         exclude_pages = exclude_pages or []
+
+         async def process_page(page_idx):
+             nonlocal processed_pages_counter
+             page_data = await self.extract_page_data(page_idx, **kwargs)
+             page_data["loader_name"] = self.__class__.__name__
+             for key, value in self.metadata.items():
+                 if key not in page_data:
+                     page_data[key] = value
+             pages.append(page_data)
+             progress.update(
+                 task_id,
+                 advance=1,
+                 description=f"Loading page {page_idx} using {self.__class__.__name__}",
+             )
+             processed_pages_counter += 1
+
+         progress = Progress()
+         with progress:
+             task_id = progress.add_task("Starting...", total=total_pages)
+             tasks = [
+                 process_page(page_idx)
+                 for page_idx in range(start_page, end_page + 1)
+                 if page_idx not in exclude_pages
+             ]
+             for task in asyncio.as_completed(tasks):
+                 await task
+
+         pages.sort(key=lambda x: x["page_idx"])
+
+         dataset = Dataset.from_list(pages)
+         if dataset_repo_id:
+             if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
+                 print("Dataset already exists")
+                 if not overwrite_dataset:
+                     print("Not overwriting dataset")
+                     dataset = concatenate_datasets(
+                         [dataset, load_dataset(dataset_repo_id, split="corpus")]
+                     )
+             dataset.push_to_hub(repo_id=dataset_repo_id, split="corpus", private=False)
+
+         return dataset
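A minimal usage sketch for the flow described above (not part of the commit; any concrete subclass works, and the page range is a placeholder):

```python
import asyncio

# PyPDF2TextLoader is one of the concrete subclasses added later in this commit.
from medrag_multi_modal.document_loader import PyPDF2TextLoader

loader = PyPDF2TextLoader(
    url="https://example.com/document.pdf",  # hypothetical URL
    document_name="Example Document",
    document_file_path="document.pdf",
)
# Pages are processed concurrently; pushing to the Hub only happens when dataset_repo_id is set.
dataset = asyncio.run(
    loader.load_data(start_page=0, end_page=9, exclude_pages=[3], dataset_repo_id=None)
)
print(dataset[0]["page_idx"], dataset[0]["loader_name"], dataset[0]["text"][:200])
```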
medrag_multi_modal/document_loader/text_loader/marker_text_loader.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ from typing import Dict
+
+ from marker.convert import convert_single_pdf
+ from marker.models import load_all_models
+
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+     BaseTextLoader,
+ )
+
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+
+ class MarkerTextLoader(BaseTextLoader):
+     """
+     A concrete implementation of the BaseTextLoader for loading text from a PDF file
+     using `marker-pdf`, processing it into a structured text format, and optionally publishing
+     it to a HuggingFace dataset.
+
+     This class extends the BaseTextLoader and implements the abstract methods to
+     load and process pages from a PDF file using marker-pdf, which is a pipeline of deep learning models.
+
+     This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+     It uses marker-pdf to read the PDF and extract structured text from each page. The processed pages are
+     collected into a HuggingFace dataset, which can optionally be pushed to the HuggingFace Hub.
+
+     !!! example "Example Usage"
+         ```python
+         import asyncio
+
+         from medrag_multi_modal.document_loader import MarkerTextLoader
+
+         URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+         loader = MarkerTextLoader(
+             url=URL,
+             document_name="Gray's Anatomy",
+             document_file_path="grays_anatomy.pdf",
+         )
+         dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+         ```
+
+     Args:
+         url (str): The URL of the PDF file to download if not present locally.
+         document_name (str): The name of the document for metadata purposes.
+         document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+     """
+
+     async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+         """
+         Process a single page of the PDF and extract its structured text using marker-pdf.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             **kwargs: Additional keyword arguments to be passed to `marker.convert.convert_single_pdf`.
+
+         Returns:
+             Dict[str, str]: A dictionary containing the processed page data.
+             The dictionary will have the following keys and values:
+
+             - "text": (str) the extracted structured text from the page.
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+         """
+         model_lst = load_all_models()
+
+         text, _, _ = convert_single_pdf(
+             self.document_file_path,
+             model_lst,
+             max_pages=1,
+             batch_multiplier=1,
+             start_page=page_idx,
+             ocr_all_pages=True,
+             **kwargs,
+         )
+
+         return {
+             "text": text,
+             "page_idx": page_idx,
+             "document_name": self.document_name,
+             "file_path": self.document_file_path,
+             "file_url": self.url,
+         }
medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py ADDED
@@ -0,0 +1,76 @@
+ from typing import Dict
+
+ import pdfplumber
+
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+     BaseTextLoader,
+ )
+
+
+ class PDFPlumberTextLoader(BaseTextLoader):
+     """
+     A concrete implementation of the BaseTextLoader for loading text from a PDF file
+     using `pdfplumber`, processing it into a simple text format, and optionally publishing
+     it to a HuggingFace dataset.
+
+     This class extends the BaseTextLoader and implements the abstract methods to
+     load and process pages from a PDF file.
+
+     This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+     It uses pdfplumber to read the PDF and extract text from each page. The processed pages are collected
+     into a HuggingFace dataset, which can optionally be pushed to the HuggingFace Hub.
+
+     !!! example "Example Usage"
+         ```python
+         import asyncio
+
+         from medrag_multi_modal.document_loader import PDFPlumberTextLoader
+
+         URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+         loader = PDFPlumberTextLoader(
+             url=URL,
+             document_name="Gray's Anatomy",
+             document_file_path="grays_anatomy.pdf",
+         )
+         dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+         ```
+
+     Args:
+         url (str): The URL of the PDF file to download if not present locally.
+         document_name (str): The name of the document for metadata purposes.
+         document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+     """
+
+     async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+         """
+         Process a single page of the PDF and extract its text using pdfplumber.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             **kwargs: Additional keyword arguments to be passed to `pdfplumber.Page.extract_text`.
+
+         Returns:
+             Dict[str, str]: A dictionary containing the processed page data.
+             The dictionary will have the following keys and values:
+
+             - "text": (str) the extracted text from the page.
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+         """
+         with pdfplumber.open(self.document_file_path) as pdf:
+             page = pdf.pages[page_idx]
+             text = page.extract_text(**kwargs)
+
+         return {
+             "text": text,
+             "page_idx": page_idx,
+             "document_name": self.document_name,
+             "file_path": self.document_file_path,
+             "file_url": self.url,
+         }
medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py ADDED
@@ -0,0 +1,73 @@
+ from typing import Dict
+
+ import pymupdf4llm
+
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+     BaseTextLoader,
+ )
+
+
+ class PyMuPDF4LLMTextLoader(BaseTextLoader):
+     """
+     A concrete implementation of the BaseTextLoader for loading text from a PDF file,
+     processing it into markdown using `pymupdf4llm`, and optionally publishing it to a HuggingFace dataset.
+
+     This class extends the BaseTextLoader and implements the abstract methods to load and process pages from a PDF file.
+
+     This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+     It uses PyPDF2 to read the PDF and pymupdf4llm to convert pages to markdown. The processed pages are
+     collected into a HuggingFace dataset, which can optionally be pushed to the HuggingFace Hub.
+
+     !!! example "Example Usage"
+         ```python
+         import asyncio
+
+         from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader
+
+         URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+         loader = PyMuPDF4LLMTextLoader(
+             url=URL,
+             document_name="Gray's Anatomy",
+             document_file_path="grays_anatomy.pdf",
+         )
+         dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+         ```
+
+     Args:
+         url (str): The URL of the PDF file to download if not present locally.
+         document_name (str): The name of the document for metadata purposes.
+         document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+     """
+
+     async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+         """
+         Process a single page of the PDF and convert it to markdown using `pymupdf4llm`.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             **kwargs: Additional keyword arguments to be passed to `pymupdf4llm.to_markdown`.
+
+         Returns:
+             Dict[str, str]: A dictionary containing the processed page data.
+             The dictionary will have the following keys and values:
+
+             - "text": (str) the processed page data in markdown format.
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+         """
+         text = pymupdf4llm.to_markdown(
+             doc=self.document_file_path, pages=[page_idx], show_progress=False, **kwargs
+         )
+         return {
+             "text": text,
+             "page_idx": page_idx,
+             "document_name": self.document_name,
+             "file_path": self.document_file_path,
+             "file_url": self.url,
+         }
medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py ADDED
@@ -0,0 +1,77 @@
+ from typing import Dict
+
+ import PyPDF2
+
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+     BaseTextLoader,
+ )
+
+
+ class PyPDF2TextLoader(BaseTextLoader):
+     """
+     A concrete implementation of the BaseTextLoader for loading text from a PDF file
+     using `PyPDF2`, processing it into a simple text format, and optionally publishing
+     it to a HuggingFace dataset.
+
+     This class extends the BaseTextLoader and implements the abstract methods to
+     load and process pages from a PDF file.
+
+     This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+     It uses PyPDF2 to read the PDF and extract text from each page. The processed pages are collected
+     into a HuggingFace dataset, which can optionally be pushed to the HuggingFace Hub.
+
+     !!! example "Example Usage"
+         ```python
+         import asyncio
+
+         from medrag_multi_modal.document_loader import PyPDF2TextLoader
+
+         URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+         loader = PyPDF2TextLoader(
+             url=URL,
+             document_name="Gray's Anatomy",
+             document_file_path="grays_anatomy.pdf",
+         )
+         dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+         ```
+
+     Args:
+         url (str): The URL of the PDF file to download if not present locally.
+         document_name (str): The name of the document for metadata purposes.
+         document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+     """
+
+     async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+         """
+         Process a single page of the PDF and extract its text using PyPDF2.
+
+         Args:
+             page_idx (int): The index of the page to process.
+             **kwargs: Additional keyword arguments to be passed to `PyPDF2.PdfReader.pages[0].extract_text`.
+
+         Returns:
+             Dict[str, str]: A dictionary containing the processed page data.
+             The dictionary will have the following keys and values:
+
+             - "text": (str) the extracted text from the page.
+             - "page_idx": (int) the index of the page.
+             - "document_name": (str) the name of the document.
+             - "file_path": (str) the local file path where the PDF is stored.
+             - "file_url": (str) the URL of the PDF file.
+         """
+         with open(self.document_file_path, "rb") as file:
+             pdf_reader = PyPDF2.PdfReader(file)
+             page = pdf_reader.pages[page_idx]
+             text = page.extract_text(**kwargs)
+
+         return {
+             "text": text,
+             "page_idx": page_idx,
+             "document_name": self.document_name,
+             "file_path": self.document_file_path,
+             "file_url": self.url,
+         }
medrag_multi_modal/metrics/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .mmlu import MMLUOptionAccuracy
+
+ __all__ = ["MMLUOptionAccuracy"]
medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (273 Bytes).
medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.58 kB).
medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc ADDED
Binary file (775 Bytes).
medrag_multi_modal/metrics/base.py ADDED
@@ -0,0 +1,108 @@
+ from typing import Optional
+
+ import numpy as np
+ import weave
+
+
+ class BaseAccuracyMetric(weave.Scorer):
+     """
+     BaseAccuracyMetric is a class that extends the
+     [`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers)
+     to provide a comprehensive evaluation of accuracy metrics for a given set of score rows.
+
+     This class is designed to process a list of score rows, each containing a
+     'correct' key that indicates whether a particular prediction was correct.
+     The `summarize` method calculates various statistical measures and metrics
+     based on this data, including:
+
+     - True and false counts: The number of true and false predictions.
+     - True and false fractions: The proportion of true and false predictions.
+     - Standard error: The standard error of the mean for the true predictions.
+     - Precision: The ratio of true positive predictions to the total number of
+       positive predictions.
+     - Recall: The ratio of true positive predictions to the total number of
+       actual positives.
+     - F1 Score: The harmonic mean of precision and recall, providing a balance
+       between the two metrics.
+
+     The `summarize` method returns a dictionary containing these metrics,
+     allowing for a detailed analysis of the model's performance.
+
+     Methods:
+         summarize(score_rows: list) -> Optional[dict]:
+             Processes the input score rows to compute and return a dictionary
+             of accuracy metrics.
+     """
+
+     @weave.op()
+     def summarize(self, score_rows: list) -> Optional[dict]:
+         """
+         Summarizes the accuracy metrics from a list of score rows.
+
+         This method processes a list of score rows, each containing a 'correct' key
+         that indicates whether a particular prediction was correct. It calculates
+         various statistical measures and metrics based on this data, including:
+
+         - True and false counts: The number of true and false predictions.
+         - True and false fractions: The proportion of true and false predictions.
+         - Standard error: The standard error of the mean for the true predictions.
+         - Precision: The ratio of true positive predictions to the total number of
+           positive predictions.
+         - Recall: The ratio of true positive predictions to the total number of
+           actual positives.
+         - F1 Score: The harmonic mean of precision and recall, providing a balance
+           between the two metrics.
+
+         The method returns a dictionary containing these metrics, allowing for a
+         detailed analysis of the model's performance.
+
+         Args:
+             score_rows (list): A list of dictionaries, each containing a 'correct'
+                 key with a boolean value indicating the correctness of a prediction.
+
+         Returns:
+             Optional[dict]: A dictionary containing the calculated accuracy metrics,
+                 or None if the input list is empty.
+         """
+         valid_data = [
+             x.get("correct") for x in score_rows if x.get("correct") is not None
+         ]
+         count_true = list(valid_data).count(True)
+         int_data = [int(x) for x in valid_data]
+
+         sample_mean = np.mean(int_data) if int_data else 0
+         sample_variance = np.var(int_data) if int_data else 0
+         sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0
+
+         # Calculate precision, recall, and F1 score
+         true_positives = count_true
+         false_positives = len(valid_data) - count_true
+         false_negatives = len(score_rows) - len(valid_data)
+
+         precision = (
+             true_positives / (true_positives + false_positives)
+             if (true_positives + false_positives) > 0
+             else 0
+         )
+         recall = (
+             true_positives / (true_positives + false_negatives)
+             if (true_positives + false_negatives) > 0
+             else 0
+         )
+         f1_score = (
+             (2 * precision * recall) / (precision + recall)
+             if (precision + recall) > 0
+             else 0
+         )
+
+         return {
+             "correct": {
+                 "true_count": count_true,
+                 "false_count": len(score_rows) - count_true,
+                 "true_fraction": float(sample_mean),
+                 "false_fraction": 1.0 - float(sample_mean),
+                 "stderr": float(sample_error),
+                 "precision": precision,
+                 "recall": recall,
+                 "f1_score": f1_score,
+             }
+         }
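A small sketch of how `summarize` behaves on a list of score rows (illustrative only, not part of the commit; it assumes the scorer is called directly rather than through a `weave.Evaluation`):

```python
from medrag_multi_modal.metrics.base import BaseAccuracyMetric

metric = BaseAccuracyMetric()
# Three scored predictions: two correct, one incorrect.
score_rows = [{"correct": True}, {"correct": True}, {"correct": False}]
summary = metric.summarize(score_rows)
# Expected shape: true_count=2, false_count=1, true_fraction~0.67,
# plus stderr, precision, recall, and f1_score computed as above.
print(summary["correct"]["true_count"], summary["correct"]["true_fraction"])
```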
medrag_multi_modal/metrics/mmlu.py ADDED
@@ -0,0 +1,24 @@
+ import weave
+
+ from medrag_multi_modal.assistant.schema import MedQAResponse
+ from medrag_multi_modal.metrics.base import BaseAccuracyMetric
+
+
+ class MMLUOptionAccuracy(BaseAccuracyMetric):
+     """
+     MMLUOptionAccuracy is a metric class that inherits from `BaseAccuracyMetric`.
+
+     This class is designed to evaluate the accuracy of a multiple-choice question
+     response by comparing the provided answer with the correct answer from the
+     given options. It uses the MedQAResponse schema to extract the response
+     and checks if it matches the correct answer.
+
+     Methods:
+     --------
+     score(output: MedQAResponse, options: list[str], answer: str) -> dict:
+         Compares the provided answer with the correct answer and returns a
+         dictionary indicating whether the answer is correct.
+     """
+
+     @weave.op()
+     def score(self, output: MedQAResponse, options: list[str], answer: str):
+         return {"correct": options[answer] == output.response.answer}
medrag_multi_modal/retrieval/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .colpali_retrieval import CalPaliRetriever
+
+ __all__ = ["CalPaliRetriever"]
medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (276 Bytes).
medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc ADDED
Binary file (9.94 kB).
medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc ADDED
Binary file (1.25 kB).
medrag_multi_modal/retrieval/colpali_retrieval.py ADDED
@@ -0,0 +1,255 @@
+ import os
+ from typing import TYPE_CHECKING, Any, Optional
+
+ import weave
+
+ if TYPE_CHECKING:
+     from byaldi import RAGMultiModalModel
+
+ import wandb
+ from PIL import Image
+
+ from medrag_multi_modal.utils import get_wandb_artifact
+
+
+ class CalPaliRetriever(weave.Model):
+     """
+     CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali.
+
+     This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
+     It can be initialized with a pre-trained model or from a specified W&B artifact. The class
+     also provides methods to index new data and to predict/retrieve documents based on a query.
+
+     Attributes:
+         model_name (str): The name of the model to be used for retrieval.
+     """
+
+     model_name: str
+     _docs_retrieval_model: Optional["RAGMultiModalModel"] = None
+     _metadata: Optional[dict] = None
+     _data_artifact_dir: Optional[str] = None
+
+     def __init__(
+         self,
+         model_name: str = "vidore/colpali-v1.2",
+         docs_retrieval_model: Optional["RAGMultiModalModel"] = None,
+         data_artifact_dir: Optional[str] = None,
+         metadata_dataset_name: Optional[str] = None,
+     ):
+         super().__init__(model_name=model_name)
+         from byaldi import RAGMultiModalModel
+
+         self._docs_retrieval_model = (
+             docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
+         )
+         self._data_artifact_dir = data_artifact_dir
+         self._metadata = (
+             [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
+             if metadata_dataset_name
+             else None
+         )
+
+     def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
+         """
+         Indexes a dataset of documents and saves the index as a wandb artifact.
+
+         This method retrieves a dataset of documents from a wandb artifact using the provided
+         data artifact name. It then indexes the documents using the document retrieval model
+         and assigns the specified index name. The index is stored locally without storing the
+         collection with the index and overwrites any existing index with the same name.
+
+         If a wandb run is active, the method creates a new wandb artifact with the specified
+         index name and type "colpali-index". It adds the local index directory to the artifact
+         and saves it to wandb, including metadata with the provided Weave dataset name.
+
+         !!! example "Indexing Data"
+             First, you need to install the `Byaldi` library by Answer.ai.
+
+             ```bash
+             uv pip install Byaldi>=0.0.5
+             ```
+
+             Next, you can index the data by running the following code:
+
+             ```python
+             import wandb
+             from medrag_multi_modal.retrieval import CalPaliRetriever
+
+             wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
+             retriever = CalPaliRetriever()
+             retriever.index(
+                 data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+                 weave_dataset_name="grays-anatomy-images:v0",
+                 index_name="grays-anatomy",
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             data_artifact_name (str): The name of the wandb artifact containing the dataset.
+             weave_dataset_name (str): The name of the Weave dataset to include in the artifact metadata.
+             index_name (str): The name to assign to the created index.
+         """
+         data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
+         self._docs_retrieval_model.index(
+             input_path=data_artifact_dir,
+             index_name=index_name,
+             store_collection_with_index=False,
+             overwrite=True,
+         )
+         if wandb.run:
+             artifact = wandb.Artifact(
+                 name=index_name,
+                 type="colpali-index",
+                 metadata={"weave_dataset_name": weave_dataset_name},
+             )
+             artifact.add_dir(
+                 local_path=os.path.join(".byaldi", index_name), name="index"
+             )
+             artifact.save()
+
+     @classmethod
+     def from_wandb_artifact(
+         cls,
+         index_artifact_name: str,
+         metadata_dataset_name: str,
+         data_artifact_name: str,
+     ):
+         """
+         Creates an instance of the class from Weights & Biases (wandb) artifacts.
+
+         This method retrieves the necessary artifacts from wandb to initialize the
+         ColPaliRetriever. It fetches the index artifact directory and the data artifact
+         directory using the provided artifact names. It then loads the document retrieval
+         model from the index path within the index artifact directory. Finally, it returns
+         an instance of the class initialized with the retrieved document retrieval model,
+         metadata dataset name, and data artifact directory.
+
+         !!! example "Retrieving Documents"
+             First, you need to install the `Byaldi` library by Answer.ai.
+
+             ```bash
+             uv pip install Byaldi>=0.0.5
+             ```
+
+             Next, you can retrieve the documents by running the following code:
+
+             ```python
+             import weave
+
+             import wandb
+             from medrag_multi_modal.retrieval import CalPaliRetriever
+
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = CalPaliRetriever.from_wandb_artifact(
+                 index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
+                 metadata_dataset_name="grays-anatomy-images:v0",
+                 data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             index_artifact_name (str): The name of the wandb artifact containing the index.
+             metadata_dataset_name (str): The name of the dataset containing metadata.
+             data_artifact_name (str): The name of the wandb artifact containing the data.
+
+         Returns:
+             An instance of the class initialized with the retrieved document retrieval model,
+             metadata dataset name, and data artifact directory.
+         """
+         from byaldi import RAGMultiModalModel
+
+         index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
+         data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
+         docs_retrieval_model = RAGMultiModalModel.from_index(
+             index_path=os.path.join(index_artifact_dir, "index")
+         )
+         return cls(
+             docs_retrieval_model=docs_retrieval_model,
+             metadata_dataset_name=metadata_dataset_name,
+             data_artifact_dir=data_artifact_dir,
+         )
+
+     @weave.op()
+     def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
+         """
+         Predicts and retrieves the top-k most relevant documents/images for a given query
+         using ColPali.
+
+         This function uses the document retrieval model to search for the most relevant
+         documents based on the provided query. It returns a list of dictionaries, each
+         containing the document image, document ID, and the relevance score.
+
+         !!! example "Retrieving Documents"
+             First, you need to install the `Byaldi` library by Answer.ai.
+
+             ```bash
+             uv pip install Byaldi>=0.0.5
+             ```
+
+             Next, you can retrieve the documents by running the following code:
+
+             ```python
+             import weave
+
+             import wandb
+             from medrag_multi_modal.retrieval import CalPaliRetriever
+
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = CalPaliRetriever.from_wandb_artifact(
+                 index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
+                 metadata_dataset_name="grays-anatomy-images:v0",
+                 data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+             )
+             retriever.predict(
+                 query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
+                 top_k=3,
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             query (str): The search query string.
+             top_k (int, optional): The number of top results to retrieve. Defaults to 3.
+
+         Returns:
+             list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
+                 - "doc_image" (PIL.Image.Image): The image of the document.
+                 - "doc_id" (str): The ID of the document.
+                 - "score" (float): The relevance score of the document.
+         """
+         results = self._docs_retrieval_model.search(query=query, k=top_k)
+         retrieved_results = []
+         for result in results:
+             retrieved_results.append(
+                 {
+                     "doc_image": Image.open(
+                         os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png")
+                     ),
+                     "doc_id": result["doc_id"],
+                     "score": result["score"],
+                 }
+             )
+         return retrieved_results
medrag_multi_modal/retrieval/common.py ADDED
@@ -0,0 +1,21 @@
+ from enum import Enum
+
+
+ class SimilarityMetric(Enum):
+     COSINE = "cosine"
+     EUCLIDEAN = "euclidean"
+
+
+ def mean_pooling(token_embeddings, mask):
+     token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
+     sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
+     return sentence_embeddings
+
+
+ def argsort_scores(scores: list[float], descending: bool = False):
+     return [
+         {"item": item, "original_index": idx}
+         for idx, item in sorted(
+             list(enumerate(scores)), key=lambda x: x[1], reverse=descending
+         )
+     ]
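A quick usage sketch for these helpers (illustrative only, not part of the commit; it assumes the embedding inputs are torch tensors, which is what `masked_fill` implies):

```python
import torch

from medrag_multi_modal.retrieval.common import argsort_scores, mean_pooling

# argsort_scores keeps track of each score's original position while sorting.
print(argsort_scores([0.2, 0.9, 0.5], descending=True))
# -> [{'item': 0.9, 'original_index': 1}, {'item': 0.5, 'original_index': 2}, {'item': 0.2, 'original_index': 0}]

# mean_pooling averages token embeddings over the unmasked positions only.
token_embeddings = torch.randn(2, 4, 8)            # (batch, tokens, dim)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # attention mask
sentence_embeddings = mean_pooling(token_embeddings, mask)  # shape (2, 8)
```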
medrag_multi_modal/retrieval/text_retrieval/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .bm25s_retrieval import BM25sRetriever
+ from .contriever_retrieval import ContrieverRetriever
+ from .medcpt_retrieval import MedCPTRetriever
+ from .nv_embed_2 import NVEmbed2Retriever
+
+ __all__ = [
+     "BM25sRetriever",
+     "ContrieverRetriever",
+     "MedCPTRetriever",
+     "NVEmbed2Retriever",
+ ]
medrag_multi_modal/retrieval/text_retrieval/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (478 Bytes).