Spaces:
Runtime error
Runtime error
bug fixes and improvement
Browse files- app.py +96 -17
- requirements.txt +4 -2
- utils.py +7 -9
app.py
CHANGED
@@ -8,11 +8,14 @@ import ocrmypdf
|
|
8 |
import os
|
9 |
import pandas as pd
|
10 |
import pymupdf
|
|
|
11 |
import spaces
|
12 |
import torch
|
13 |
from PIL import Image
|
14 |
from chromadb.utils import embedding_functions
|
15 |
from chromadb.utils.data_loaders import ImageLoader
|
|
|
|
|
16 |
from gradio.themes.utils import sizes
|
17 |
from langchain import PromptTemplate
|
18 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
@@ -22,6 +25,29 @@ from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
|
|
22 |
from utils import *
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
if torch.cuda.is_available():
|
26 |
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
27 |
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
|
@@ -75,7 +101,6 @@ def get_vectordb(text, images):
|
|
75 |
metadata={"hnsw:space": "cosine"},
|
76 |
)
|
77 |
descs = []
|
78 |
-
print(descs)
|
79 |
for image in images:
|
80 |
try:
|
81 |
descs.append(get_image_description(image)[0])
|
@@ -97,7 +122,9 @@ def get_vectordb(text, images):
|
|
97 |
chunk_overlap=10,
|
98 |
)
|
99 |
|
100 |
-
if len(text)
|
|
|
|
|
101 |
docs = splitter.create_documents([text])
|
102 |
doc_texts = [i.page_content for i in docs]
|
103 |
text_collection.add(
|
@@ -106,7 +133,16 @@ def get_vectordb(text, images):
|
|
106 |
return client
|
107 |
|
108 |
|
109 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
if len(docs) == 0:
|
111 |
raise gr.Error("No documents to process")
|
112 |
progress(0, "Extracting Images")
|
@@ -115,18 +151,20 @@ def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()
|
|
115 |
|
116 |
progress(0.25, "Extracting Text")
|
117 |
|
118 |
-
strategy = "hi_res"
|
119 |
-
model_name = "yolox"
|
120 |
-
all_elements = []
|
121 |
all_text = ""
|
122 |
|
123 |
images = []
|
124 |
for doc in docs:
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
128 |
if include_images == "Include Images":
|
129 |
-
images.extend(extract_images([
|
130 |
|
131 |
progress(
|
132 |
0.6, "Generating image descriptions and inserting everything into vectorDB"
|
@@ -153,20 +191,28 @@ sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFuncti
|
|
153 |
|
154 |
|
155 |
def conversation(
|
156 |
-
vectordb_client,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
):
|
158 |
if hf_token.strip() != "" and model_path.strip() != "":
|
159 |
llm = HuggingFaceEndpoint(
|
160 |
repo_id=model_path,
|
161 |
-
temperature=
|
162 |
-
max_new_tokens=
|
163 |
huggingfacehub_api_token=hf_token,
|
164 |
)
|
165 |
else:
|
166 |
llm = HuggingFaceEndpoint(
|
167 |
repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
|
168 |
-
temperature=
|
169 |
-
max_new_tokens=
|
170 |
huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
|
171 |
)
|
172 |
|
@@ -273,6 +319,12 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
|
|
273 |
label="Include/ Exclude Images",
|
274 |
interactive=True,
|
275 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
with gr.Row(equal_height=True, variant="panel") as row:
|
278 |
selected = gr.Dataframe(
|
@@ -327,6 +379,23 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
|
|
327 |
interactive=True,
|
328 |
value=2,
|
329 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
with gr.Row():
|
331 |
with gr.Column():
|
332 |
ret_images = gr.Gallery("Similar Images", columns=1, rows=2)
|
@@ -361,7 +430,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
|
|
361 |
)
|
362 |
embed.click(
|
363 |
extract_data_from_pdfs,
|
364 |
-
inputs=[doc_collection, session_states, include_images],
|
365 |
outputs=[
|
366 |
vectordb,
|
367 |
session_states,
|
@@ -374,7 +443,17 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
|
|
374 |
|
375 |
submit_btn.click(
|
376 |
conversation,
|
377 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
[chatbot, references, ret_images],
|
379 |
)
|
380 |
|
|
|
8 |
import os
|
9 |
import pandas as pd
|
10 |
import pymupdf
|
11 |
+
from pypdf import PdfReader
|
12 |
import spaces
|
13 |
import torch
|
14 |
from PIL import Image
|
15 |
from chromadb.utils import embedding_functions
|
16 |
from chromadb.utils.data_loaders import ImageLoader
|
17 |
+
from doctr.io import DocumentFile
|
18 |
+
from doctr.models import ocr_predictor
|
19 |
from gradio.themes.utils import sizes
|
20 |
from langchain import PromptTemplate
|
21 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
25 |
from utils import *
|
26 |
|
27 |
|
28 |
+
def result_to_text(result, as_text=False) -> str or list:
|
29 |
+
full_doc = []
|
30 |
+
for _, page in enumerate(result.pages, start=1):
|
31 |
+
text = ""
|
32 |
+
for block in page.blocks:
|
33 |
+
text += "\n\t"
|
34 |
+
for line in block.lines:
|
35 |
+
for word in line.words:
|
36 |
+
text += word.value + " "
|
37 |
+
|
38 |
+
full_doc.append(clean_text(text) + "\n\n")
|
39 |
+
|
40 |
+
return "\n".join(full_doc) if as_text else full_doc
|
41 |
+
|
42 |
+
|
43 |
+
ocr_model = ocr_predictor(
|
44 |
+
"db_resnet50",
|
45 |
+
"crnn_mobilenet_v3_large",
|
46 |
+
pretrained=True,
|
47 |
+
assume_straight_pages=True,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
if torch.cuda.is_available():
|
52 |
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
53 |
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
|
|
|
101 |
metadata={"hnsw:space": "cosine"},
|
102 |
)
|
103 |
descs = []
|
|
|
104 |
for image in images:
|
105 |
try:
|
106 |
descs.append(get_image_description(image)[0])
|
|
|
122 |
chunk_overlap=10,
|
123 |
)
|
124 |
|
125 |
+
if len(text.replace(" ", "").replace("\n", "")) == 0:
|
126 |
+
gr.Error("No text found in documents")
|
127 |
+
else:
|
128 |
docs = splitter.create_documents([text])
|
129 |
doc_texts = [i.page_content for i in docs]
|
130 |
text_collection.add(
|
|
|
133 |
return client
|
134 |
|
135 |
|
136 |
+
def extract_only_text(reader):
|
137 |
+
text = ""
|
138 |
+
for _, page in enumerate(reader.pages):
|
139 |
+
text = page.extract_text()
|
140 |
+
return text.strip()
|
141 |
+
|
142 |
+
|
143 |
+
def extract_data_from_pdfs(
|
144 |
+
docs, session, include_images, do_ocr, progress=gr.Progress()
|
145 |
+
):
|
146 |
if len(docs) == 0:
|
147 |
raise gr.Error("No documents to process")
|
148 |
progress(0, "Extracting Images")
|
|
|
151 |
|
152 |
progress(0.25, "Extracting Text")
|
153 |
|
|
|
|
|
|
|
154 |
all_text = ""
|
155 |
|
156 |
images = []
|
157 |
for doc in docs:
|
158 |
+
if do_ocr == "Get Text With OCR":
|
159 |
+
pdf_doc = DocumentFile.from_pdf(doc)
|
160 |
+
result = ocr_model(pdf_doc)
|
161 |
+
all_text += result_to_text(result, as_text=True) + "\n\n"
|
162 |
+
else:
|
163 |
+
reader = PdfReader(doc)
|
164 |
+
all_text += extract_only_text(reader) + "\n\n"
|
165 |
+
|
166 |
if include_images == "Include Images":
|
167 |
+
images.extend(extract_images([doc]))
|
168 |
|
169 |
progress(
|
170 |
0.6, "Generating image descriptions and inserting everything into vectorDB"
|
|
|
191 |
|
192 |
|
193 |
def conversation(
|
194 |
+
vectordb_client,
|
195 |
+
msg,
|
196 |
+
num_context,
|
197 |
+
img_context,
|
198 |
+
history,
|
199 |
+
temperature,
|
200 |
+
max_new_tokens,
|
201 |
+
hf_token,
|
202 |
+
model_path,
|
203 |
):
|
204 |
if hf_token.strip() != "" and model_path.strip() != "":
|
205 |
llm = HuggingFaceEndpoint(
|
206 |
repo_id=model_path,
|
207 |
+
temperature=temperature,
|
208 |
+
max_new_tokens=max_new_tokens,
|
209 |
huggingfacehub_api_token=hf_token,
|
210 |
)
|
211 |
else:
|
212 |
llm = HuggingFaceEndpoint(
|
213 |
repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
|
214 |
+
temperature=temperature,
|
215 |
+
max_new_tokens=max_new_tokens,
|
216 |
huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
|
217 |
)
|
218 |
|
|
|
319 |
label="Include/ Exclude Images",
|
320 |
interactive=True,
|
321 |
)
|
322 |
+
do_ocr = gr.Radio(
|
323 |
+
["Get Text With OCR", "Get Available Text Only"],
|
324 |
+
value="Get Text With OCR",
|
325 |
+
label="OCR/ No OCR",
|
326 |
+
interactive=True,
|
327 |
+
)
|
328 |
|
329 |
with gr.Row(equal_height=True, variant="panel") as row:
|
330 |
selected = gr.Dataframe(
|
|
|
379 |
interactive=True,
|
380 |
value=2,
|
381 |
)
|
382 |
+
with gr.Row(variant="panel", equal_height=True):
|
383 |
+
temp = gr.Slider(
|
384 |
+
label="Temperature",
|
385 |
+
minimum=0.1,
|
386 |
+
maximum=1,
|
387 |
+
step=0.1,
|
388 |
+
interactive=True,
|
389 |
+
value=0.4,
|
390 |
+
)
|
391 |
+
max_tokens = gr.Slider(
|
392 |
+
label="Max Tokens",
|
393 |
+
minimum=10,
|
394 |
+
maximum=2000,
|
395 |
+
step=10,
|
396 |
+
interactive=True,
|
397 |
+
value=500,
|
398 |
+
)
|
399 |
with gr.Row():
|
400 |
with gr.Column():
|
401 |
ret_images = gr.Gallery("Similar Images", columns=1, rows=2)
|
|
|
430 |
)
|
431 |
embed.click(
|
432 |
extract_data_from_pdfs,
|
433 |
+
inputs=[doc_collection, session_states, include_images, do_ocr],
|
434 |
outputs=[
|
435 |
vectordb,
|
436 |
session_states,
|
|
|
443 |
|
444 |
submit_btn.click(
|
445 |
conversation,
|
446 |
+
[
|
447 |
+
vectordb,
|
448 |
+
msg,
|
449 |
+
num_context,
|
450 |
+
img_context,
|
451 |
+
chatbot,
|
452 |
+
temp,
|
453 |
+
max_tokens,
|
454 |
+
hf_token,
|
455 |
+
model_path,
|
456 |
+
],
|
457 |
[chatbot, references, ret_images],
|
458 |
)
|
459 |
|
requirements.txt
CHANGED
@@ -7,8 +7,10 @@ pandas==2.2.2
|
|
7 |
Pillow==10.3.0
|
8 |
pymupdf==1.24.5
|
9 |
sentence_transformers==3.0.1
|
10 |
-
unstructured[all-docs]
|
11 |
accelerate
|
12 |
bitsandbytes
|
13 |
easyocr
|
14 |
-
ocrmypdf
|
|
|
|
|
|
|
|
7 |
Pillow==10.3.0
|
8 |
pymupdf==1.24.5
|
9 |
sentence_transformers==3.0.1
|
|
|
10 |
accelerate
|
11 |
bitsandbytes
|
12 |
easyocr
|
13 |
+
ocrmypdf
|
14 |
+
tf2onnx
|
15 |
+
clean-text[gpl]
|
16 |
+
python-doctr[torch]
|
utils.py
CHANGED
@@ -27,19 +27,17 @@ def extract_pdfs(docs, doc_collection):
|
|
27 |
def extract_images(docs):
|
28 |
images = []
|
29 |
for doc_path in docs:
|
30 |
-
doc = pymupdf.open(doc_path)
|
31 |
|
32 |
-
for page_index in range(len(doc)):
|
33 |
-
page = doc[page_index]
|
34 |
image_list = page.get_images()
|
35 |
|
36 |
-
for
|
37 |
-
|
38 |
-
|
39 |
-
xref = img[0] # get the XREF of the image
|
40 |
-
pix = pymupdf.Pixmap(doc, xref) # create a Pixmap
|
41 |
|
42 |
-
if pix.n - pix.alpha > 3:
|
43 |
pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
|
44 |
|
45 |
images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
|
|
|
27 |
def extract_images(docs):
|
28 |
images = []
|
29 |
for doc_path in docs:
|
30 |
+
doc = pymupdf.open(doc_path)
|
31 |
|
32 |
+
for page_index in range(len(doc)):
|
33 |
+
page = doc[page_index]
|
34 |
image_list = page.get_images()
|
35 |
|
36 |
+
for _, img in enumerate(image_list, start=1):
|
37 |
+
xref = img[0]
|
38 |
+
pix = pymupdf.Pixmap(doc, xref)
|
|
|
|
|
39 |
|
40 |
+
if pix.n - pix.alpha > 3:
|
41 |
pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
|
42 |
|
43 |
images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
|