anand004 committed
Commit e014b81
1 Parent(s): 65aad38

bug fixes and improvement

Files changed (3)
  1. app.py +96 -17
  2. requirements.txt +4 -2
  3. utils.py +7 -9
app.py CHANGED
@@ -8,11 +8,14 @@ import ocrmypdf
  import os
  import pandas as pd
  import pymupdf
+ from pypdf import PdfReader
  import spaces
  import torch
  from PIL import Image
  from chromadb.utils import embedding_functions
  from chromadb.utils.data_loaders import ImageLoader
+ from doctr.io import DocumentFile
+ from doctr.models import ocr_predictor
  from gradio.themes.utils import sizes
  from langchain import PromptTemplate
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -22,6 +25,29 @@ from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
  from utils import *
  
  
+ def result_to_text(result, as_text=False) -> str or list:
+     full_doc = []
+     for _, page in enumerate(result.pages, start=1):
+         text = ""
+         for block in page.blocks:
+             text += "\n\t"
+             for line in block.lines:
+                 for word in line.words:
+                     text += word.value + " "
+ 
+         full_doc.append(clean_text(text) + "\n\n")
+ 
+     return "\n".join(full_doc) if as_text else full_doc
+ 
+ 
+ ocr_model = ocr_predictor(
+     "db_resnet50",
+     "crnn_mobilenet_v3_large",
+     pretrained=True,
+     assume_straight_pages=True,
+ )
+ 
+ 
  if torch.cuda.is_available():
      processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
      vision_model = LlavaNextForConditionalGeneration.from_pretrained(
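For context, the new docTR helpers above replace the previous ocrmypdf pass: ocr_predictor chains a db_resnet50 text detector with a crnn_mobilenet_v3_large recognizer, and result_to_text flattens the resulting pages → blocks → lines → words hierarchy into plain text. A minimal standalone sketch of how these pieces compose (sample.pdf is a placeholder filename, not from the commit):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    # Detection (db_resnet50) finds text regions; recognition
    # (crnn_mobilenet_v3_large) transcribes each cropped word.
    model = ocr_predictor(
        "db_resnet50",
        "crnn_mobilenet_v3_large",
        pretrained=True,
        assume_straight_pages=True,  # skip rotation handling for speed
    )

    pages = DocumentFile.from_pdf("sample.pdf")  # renders each page to an image
    result = model(pages)
    text = result_to_text(result, as_text=True)  # helper defined in the hunk above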
@@ -75,7 +101,6 @@ def get_vectordb(text, images):
          metadata={"hnsw:space": "cosine"},
      )
      descs = []
-     print(descs)
      for image in images:
          try:
              descs.append(get_image_description(image)[0])
@@ -97,7 +122,9 @@ def get_vectordb(text, images):
          chunk_overlap=10,
      )
  
-     if len(text) > 0:
+     if len(text.replace(" ", "").replace("\n", "")) == 0:
+         gr.Error("No text found in documents")
+     else:
          docs = splitter.create_documents([text])
          doc_texts = [i.page_content for i in docs]
          text_collection.add(
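One caveat on the hunk above: gr.Error is an exception class in Gradio, so the bare gr.Error(...) call constructs the error without raising it, and nothing is shown to the user. If surfacing the message is the intent, a hedged variant (not what the commit does) would be:

    if len(text.replace(" ", "").replace("\n", "")) == 0:
        # raise (rather than just construct) so Gradio displays the message
        raise gr.Error("No text found in documents")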
@@ -106,7 +133,16 @@ def get_vectordb(text, images):
      return client
  
  
- def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()):
+ def extract_only_text(reader):
+     text = ""
+     for _, page in enumerate(reader.pages):
+         text = page.extract_text()
+     return text.strip()
+ 
+ 
+ def extract_data_from_pdfs(
+     docs, session, include_images, do_ocr, progress=gr.Progress()
+ ):
      if len(docs) == 0:
          raise gr.Error("No documents to process")
      progress(0, "Extracting Images")
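A side note on extract_only_text as committed: text is reassigned on every iteration, so only the last page's text survives the loop. If accumulating across pages is the intent, a corrected sketch (my assumption, not part of the commit) looks like:

    def extract_only_text(reader):
        text = ""
        for page in reader.pages:
            # += accumulates; extract_text() can return None for image-only pages
            text += (page.extract_text() or "") + "\n"
        return text.strip()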
@@ -115,18 +151,20 @@ def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()
  
      progress(0.25, "Extracting Text")
  
-     strategy = "hi_res"
-     model_name = "yolox"
-     all_elements = []
      all_text = ""
  
      images = []
      for doc in docs:
-         ocrmypdf.ocr(doc, "ocr.pdf", deskew=True, force_ocr=True)
-         text = extract_text("ocr.pdf")
-         all_text += clean_text(text) + "\n\n"
+         if do_ocr == "Get Text With OCR":
+             pdf_doc = DocumentFile.from_pdf(doc)
+             result = ocr_model(pdf_doc)
+             all_text += result_to_text(result, as_text=True) + "\n\n"
+         else:
+             reader = PdfReader(doc)
+             all_text += extract_only_text(reader) + "\n\n"
+ 
          if include_images == "Include Images":
-             images.extend(extract_images(["ocr.pdf"]))
+             images.extend(extract_images([doc]))
  
      progress(
          0.6, "Generating image descriptions and inserting everything into vectorDB"
@@ -153,20 +191,28 @@ sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFuncti
  
  
  def conversation(
-     vectordb_client, msg, num_context, img_context, history, hf_token, model_path
+     vectordb_client,
+     msg,
+     num_context,
+     img_context,
+     history,
+     temperature,
+     max_new_tokens,
+     hf_token,
+     model_path,
  ):
      if hf_token.strip() != "" and model_path.strip() != "":
          llm = HuggingFaceEndpoint(
              repo_id=model_path,
-             temperature=0.4,
-             max_new_tokens=800,
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
              huggingfacehub_api_token=hf_token,
          )
      else:
          llm = HuggingFaceEndpoint(
              repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
-             temperature=0.4,
-             max_new_tokens=800,
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
              huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
          )
  
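The hardcoded temperature=0.4 and max_new_tokens=800 become caller-supplied values here, fed by the sliders added further down. A minimal sketch of the endpoint call in isolation, assuming the langchain_community import path and a placeholder token:

    from langchain_community.llms import HuggingFaceEndpoint

    llm = HuggingFaceEndpoint(
        repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
        temperature=0.4,       # value the Temperature slider would supply
        max_new_tokens=500,    # value the Max Tokens slider would supply
        huggingfacehub_api_token="hf_...",  # placeholder
    )
    print(llm.invoke("One-sentence summary of retrieval-augmented generation?"))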
@@ -273,6 +319,12 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
                  label="Include/ Exclude Images",
                  interactive=True,
              )
+             do_ocr = gr.Radio(
+                 ["Get Text With OCR", "Get Available Text Only"],
+                 value="Get Text With OCR",
+                 label="OCR/ No OCR",
+                 interactive=True,
+             )
  
      with gr.Row(equal_height=True, variant="panel") as row:
          selected = gr.Dataframe(
@@ -327,6 +379,23 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
                  interactive=True,
                  value=2,
              )
+             with gr.Row(variant="panel", equal_height=True):
+                 temp = gr.Slider(
+                     label="Temperature",
+                     minimum=0.1,
+                     maximum=1,
+                     step=0.1,
+                     interactive=True,
+                     value=0.4,
+                 )
+                 max_tokens = gr.Slider(
+                     label="Max Tokens",
+                     minimum=10,
+                     maximum=2000,
+                     step=10,
+                     interactive=True,
+                     value=500,
+                 )
      with gr.Row():
          with gr.Column():
              ret_images = gr.Gallery("Similar Images", columns=1, rows=2)
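These sliders only take effect because they are appended to the submit_btn.click input list in the final hunk: Gradio passes each input component's current value to the callback positionally, so the list order must mirror conversation()'s parameter order. A toy illustration of the pattern (hypothetical names):

    import gradio as gr

    def echo(text, temperature):  # parameters map positionally to inputs=[...]
        return f"{text} (temperature={temperature})"

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Message")
        temp = gr.Slider(0.1, 1.0, value=0.4, label="Temperature")
        out = gr.Textbox(label="Reply")
        gr.Button("Run").click(echo, inputs=[box, temp], outputs=[out])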
@@ -361,7 +430,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
      )
      embed.click(
          extract_data_from_pdfs,
-         inputs=[doc_collection, session_states, include_images],
+         inputs=[doc_collection, session_states, include_images, do_ocr],
          outputs=[
              vectordb,
              session_states,
@@ -374,7 +443,17 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
  
      submit_btn.click(
          conversation,
-         [vectordb, msg, num_context, img_context, chatbot, hf_token, model_path],
+         [
+             vectordb,
+             msg,
+             num_context,
+             img_context,
+             chatbot,
+             temp,
+             max_tokens,
+             hf_token,
+             model_path,
+         ],
          [chatbot, references, ret_images],
      )
  
 
requirements.txt CHANGED
@@ -7,8 +7,10 @@ pandas==2.2.2
  Pillow==10.3.0
  pymupdf==1.24.5
  sentence_transformers==3.0.1
- unstructured[all-docs]
  accelerate
  bitsandbytes
  easyocr
- ocrmypdf
+ ocrmypdf
+ tf2onnx
+ clean-text[gpl]
+ python-doctr[torch]
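The dependency changes mirror the code: unstructured[all-docs] is dropped, python-doctr[torch] selects docTR's PyTorch backend (matching the torch usage in app.py), and clean-text[gpl] presumably backs the clean_text helper. A quick post-install sanity check (a sketch; the package-to-module mapping is the usual one):

    import torch
    from doctr.models import ocr_predictor
    from cleantext import clean  # clean-text[gpl] installs the cleantext module

    print("CUDA available:", torch.cuda.is_available())
    model = ocr_predictor("db_resnet50", "crnn_mobilenet_v3_large", pretrained=True)
    print("Cleaned:", clean("  Héllo   Wörld  ", lower=True))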
utils.py CHANGED
@@ -27,19 +27,17 @@ def extract_pdfs(docs, doc_collection):
  def extract_images(docs):
      images = []
      for doc_path in docs:
-         doc = pymupdf.open(doc_path)  # open a document
+         doc = pymupdf.open(doc_path)
  
-         for page_index in range(len(doc)):  # iterate over pdf pages
-             page = doc[page_index]  # get the page
+         for page_index in range(len(doc)):
+             page = doc[page_index]
              image_list = page.get_images()
  
-             for image_index, img in enumerate(
-                 image_list, start=1
-             ):  # enumerate the image list
-                 xref = img[0]  # get the XREF of the image
-                 pix = pymupdf.Pixmap(doc, xref)  # create a Pixmap
+             for _, img in enumerate(image_list, start=1):
+                 xref = img[0]
+                 pix = pymupdf.Pixmap(doc, xref)
  
-                 if pix.n - pix.alpha > 3:  # CMYK: convert to RGB first
+                 if pix.n - pix.alpha > 3:
                      pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
  
                  images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
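On the utils.py cleanup: the dropped comments were accurate, so worth restating. pix.n is the pixmap's channel count and pix.alpha is 1 when an alpha channel is present, so pix.n - pix.alpha > 3 detects CMYK-like pixmaps that need conversion to RGB before JPEG encoding. A self-contained sketch of the same loop (input.pdf is a placeholder):

    import io

    import pymupdf
    from PIL import Image

    def pdf_images(path):
        images = []
        doc = pymupdf.open(path)
        for page in doc:  # a Document is iterable over its pages
            for img in page.get_images():
                xref = img[0]  # cross-reference id of the embedded image
                pix = pymupdf.Pixmap(doc, xref)
                if pix.n - pix.alpha > 3:  # CMYK etc.: convert to RGB first
                    pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
        return images

    images = pdf_images("input.pdf")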