geekyrakshit committed
Commit 7934a8e • 1 Parent: e4a917d
update: fix bug in LLMClient + add FigureAnnotator
.gitignore
CHANGED
@@ -17,6 +17,7 @@ wandb/
 .byaldi/
 cursor_prompt.txt
 test.py
+test.ipynb
 uv.lock
 grays-anatomy-bm25s/
 prompt**.txt
medrag_multi_modal/assistant/__init__.py
CHANGED
@@ -1,4 +1,5 @@
-from .llm_client import LLMClient
+from .figure_annotation import FigureAnnotator
+from .llm_client import ClientType, LLMClient
 from .medqa_assistant import MedQAAssistant
 
-__all__ = ["LLMClient", "MedQAAssistant"]
+__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotator"]
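With these exports in place, downstream code can import everything from the package root. A minimal wiring sketch, not part of this commit: the model name is illustrative, and the LLMClient field names (model_name, client_type) are inferred from the diff below, so treat them as assumptions.

    from medrag_multi_modal.assistant import ClientType, FigureAnnotator, LLMClient

    # Hypothetical setup; field names inferred from self.model_name and
    # ClientType usage in llm_client.py, model name is illustrative only.
    llm_client = LLMClient(model_name="gemini-1.5-flash", client_type=ClientType.GEMINI)
    annotator = FigureAnnotator(llm_client=llm_client)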
medrag_multi_modal/assistant/figure_annotation.py
ADDED
@@ -0,0 +1,53 @@
+import os
+from typing import Union
+
+import cv2
+import weave
+from PIL import Image
+from rich.progress import track
+
+from ..utils import get_wandb_artifact, read_jsonl_file
+from .llm_client import LLMClient
+
+
+class FigureAnnotator(weave.Model):
+    llm_client: LLMClient
+
+    @weave.op()
+    def annotate_figures(
+        self, page_image: Image.Image
+    ) -> dict[str, Union[Image.Image, str]]:
+        annotation = self.llm_client.predict(
+            system_prompt="""
+You are an expert in the domain of scientific textbooks, especially medical texts.
+You are presented with a page from a scientific textbook.
+You are to first identify the number of figures in the image.
+Then you are to identify the figure IDs associated with each figure in the image.
+Then, you are to extract the exact figure descriptions from the image.
+
+Here are some clues you need to follow:
+1. Figure IDs are unique identifiers for each figure in the image.
+2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure.
+3. Figure IDs are in the form "Fig X.Y" where X and Y are integers. For example, 1.1, 1.2, 1.3, etc.
+4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID.
+5. The text in the image is written in English and is present in a two-column format.
+6. There is a clear distinction between the figure caption and the regular text in the image in the form of extra white space.
+7. There might be multiple figures present in the image.
+            """,
+            user_prompt=[page_image],
+        )
+        return {"page_image": page_image, "annotations": annotation}
+
+    @weave.op()
+    def predict(self, image_artifact_address: str):
+        artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
+        metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
+        annotations = []
+        for item in track(metadata, description="Annotating images:"):
+            page_image = cv2.imread(
+                os.path.join(artifact_dir, f"page{item['page_idx']}.png")
+            )
+            page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
+            page_image = Image.fromarray(page_image)
+            annotations.append(self.annotate_figures(page_image=page_image))
+        return annotations
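To exercise the annotator on a single page outside of predict, the same BGR-to-RGB conversion can be replicated by hand. A rough sketch, assuming a local page0.png and reusing the hypothetical annotator from the sketch above:

    import cv2
    from PIL import Image

    # cv2.imread returns a BGR ndarray; PIL (and the LLM client) expect RGB,
    # hence the cvtColor call before wrapping the array in a PIL Image.
    bgr = cv2.imread("page0.png")  # hypothetical local file
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    page_image = Image.fromarray(rgb)

    result = annotator.annotate_figures(page_image=page_image)
    print(result["annotations"])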
medrag_multi_modal/assistant/llm_client.py
CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
 from ..utils import base64_encode_image
 
 
-class ClientType(Enum):
+class ClientType(str, Enum):
     GEMINI = "gemini"
     MISTRAL = "mistral"
 
@@ -80,7 +80,7 @@ class LLMClient(weave.Model):
         ]
 
         client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
-        client = instructor.from_mistral(client)
+        client = instructor.from_mistral(client) if schema is not None else client
 
         response = (
             client.chat.complete(model=self.model_name, messages=messages)
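Two fixes here: the Mistral client is now only wrapped by instructor when a structured-output schema is actually requested (schema appears to be an optional argument of predict), and ClientType gains a str mixin. The mixin matters because plain Enum members never compare equal to their raw string values, which breaks string-based configs and serialization. A standalone sketch of the difference:

    from enum import Enum

    class PlainType(Enum):
        GEMINI = "gemini"

    class StrType(str, Enum):
        GEMINI = "gemini"

    print(PlainType.GEMINI == "gemini")  # False: plain Enum members are not strings
    print(StrType.GEMINI == "gemini")    # True: the str mixin makes the member a string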
medrag_multi_modal/utils.py
CHANGED
@@ -1,6 +1,7 @@
 import base64
 import io
 
+import jsonlines
 import torch
 from PIL import Image
 
@@ -36,8 +37,17 @@ def get_torch_backend():
 
 
 def base64_encode_image(image: Image.Image, mimetype: str) -> str:
+    image.load()
+    if image.mode not in ("RGB", "RGBA"):
+        image = image.convert("RGB")
     byte_arr = io.BytesIO()
     image.save(byte_arr, format="PNG")
     encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
     encoded_string = f"data:{mimetype};base64,{encoded_string}"
     return str(encoded_string)
+
+
+def read_jsonl_file(file_path: str) -> list[dict[str, any]]:
+    with jsonlines.open(file_path) as reader:
+        data = [obj for obj in reader]
+    return data
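base64_encode_image now normalizes palette or grayscale pages to RGB before PNG encoding, and read_jsonl_file collects every record in the file, which is what FigureAnnotator.predict relies on when it iterates over the metadata. A small round-trip sketch with a hypothetical metadata file:

    import jsonlines
    from medrag_multi_modal.utils import read_jsonl_file

    # Hypothetical metadata matching what FigureAnnotator.predict expects.
    with jsonlines.open("metadata.jsonl", mode="w") as writer:
        writer.write({"page_idx": 0})
        writer.write({"page_idx": 1})

    metadata = read_jsonl_file("metadata.jsonl")
    print(metadata)  # [{'page_idx': 0}, {'page_idx': 1}]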
pyproject.toml
CHANGED
@@ -42,6 +42,7 @@ dependencies = [
     "mistralai>=1.1.0",
     "instructor>=1.6.3",
     "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
 ]
 
 [project.optional-dependencies]
@@ -69,6 +70,7 @@ core = [
     "mistralai>=1.1.0",
     "instructor>=1.6.3",
     "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
 ]
 
 dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]