anand004 committed on
Commit e70cddd
1 Parent(s): 7afeb0e

bug fixes, faster ocr and restructure

Files changed (3):
  1. app.py +125 -158
  2. requirements.txt +4 -1
  3. utils.py +53 -0
app.py CHANGED
@@ -1,23 +1,23 @@
  import gradio as gr
- from unstructured.partition.pdf import partition_pdf
- import pymupdf
- from PIL import Image
- import numpy as np
  import io
  import pandas as pd
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- import gc
  import torch
- import chromadb
- from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
- from chromadb.utils.data_loaders import ImageLoader
- from sentence_transformers import SentenceTransformer
  from chromadb.utils import embedding_functions
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
- import base64
- from langchain_community.llms import HuggingFaceEndpoint
  from langchain import PromptTemplate
- import spaces

  if torch.cuda.is_available():
      processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
@@ -29,24 +29,17 @@ if torch.cuda.is_available():
      )


- def image_to_bytes(image):
-     img_byte_arr = io.BytesIO()
-     image.save(img_byte_arr, format="PNG")
-     return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
-
-
- @spaces.GPU(duration=60*4)
- def get_image_descriptions(images):
      torch.cuda.empty_cache()
      gc.collect()

      descriptions = []
      prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"

-     for img in images:
-         inputs = processor(prompt, img, return_tensors="pt").to("cuda:0")
-         output = vision_model.generate(**inputs, max_new_tokens=100)
-         descriptions.append(processor.decode(output[0], skip_special_tokens=True))
      return descriptions

@@ -55,39 +48,6 @@ CSS = """
  """


- def extract_pdfs(docs, doc_collection):
-     if docs:
-         doc_collection = []
-         doc_collection.extend(docs)
-     return (
-         doc_collection,
-         gr.Tabs(selected=1),
-         pd.DataFrame([i.split("/")[-1] for i in list(docs)], columns=["Filename"]),
-     )
-
-
- def extract_images(docs):
-     images = []
-     for doc_path in docs:
-         doc = pymupdf.open(doc_path)  # open a document
-
-         for page_index in range(len(doc)):  # iterate over pdf pages
-             page = doc[page_index]  # get the page
-             image_list = page.get_images()
-
-             for image_index, img in enumerate(
-                 image_list, start=1
-             ):  # enumerate the image list
-                 xref = img[0]  # get the XREF of the image
-                 pix = pymupdf.Pixmap(doc, xref)  # create a Pixmap
-
-                 if pix.n - pix.alpha > 3:  # CMYK: convert to RGB first
-                     pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
-
-                 images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
-     return images
-
-
  # def get_vectordb(text, images, tables):
  def get_vectordb(text, images):
      client = chromadb.EphemeralClient()
@@ -99,7 +59,7 @@ def get_vectordb(text, images):
          client.delete_collection("text_db")
      if "image_db" in [i.name for i in client.list_collections()]:
          client.delete_collection("image_db")
-
      text_collection = client.get_or_create_collection(
          name="text_db",
          embedding_function=sentence_transformer_ef,
@@ -111,14 +71,21 @@ def get_vectordb(text, images):
          data_loader=loader,
          metadata={"hnsw:space": "cosine"},
      )
-
-     image_descriptions = get_image_descriptions(images)
-     image_dict = [{"image": image_to_bytes(img) for img in images}]
-
-     if len(images)>0:
          image_collection.add(
              ids=[str(i) for i in range(len(images))],
-             documents=image_descriptions,
              metadatas=image_dict,
          )

@@ -127,7 +94,7 @@ def get_vectordb(text, images):
          chunk_overlap=10,
      )

-     if len(text)>0:
          docs = splitter.create_documents([text])
          doc_texts = [i.page_content for i in docs]
          text_collection.add(
@@ -136,54 +103,31 @@ def get_vectordb(text, images):
      return client


- def extract_data_from_pdfs(docs, session, progress=gr.Progress()):
      if len(docs) == 0:
          raise gr.Error("No documents to process")
      progress(0, "Extracting Images")

-     images = extract_images(docs)

      progress(0.25, "Extracting Text")

      strategy = "hi_res"
      model_name = "yolox"
      all_elements = []
-
-     for doc in docs:
-         elements = partition_pdf(
-             filename=doc,
-             strategy=strategy,
-             infer_table_structure=True,
-             model_name=model_name,
-         )
-
-         all_elements.extend(elements)
-
      all_text = ""

-     # tables = []
-
-     prev = None
-     for i in all_elements:
-         meta = i.to_dict()
-         if meta["type"].lower() not in ["table", "figurecaption"]:
-             if meta["type"].lower() in ["listitem", "title"]:
-                 all_text += "\n\n" + meta["text"] + "\n"
-             else:
-                 all_text += meta["text"]
-         elif meta["type"] == "Table":
-             continue
-             # tables.append(meta["metadata"]["text_as_html"])
-
-     # html = "<br>".join(tables)
-     # display = "<h3>Sample Tables</h3>" + "<br>".join(tables[:2])
-     # html = gr.HTML(html)
-     # vectordb = get_vectordb(all_text, images, tables)
-
-     progress(0.5, "Generating image descriptions")
-     image_descriptions = "\n".join(get_image_descriptions(images))
-
-     progress(0.75, "Inserting data into vector database")
      vectordb = get_vectordb(all_text, images)

      progress(1, "Completed")
@@ -205,7 +149,23 @@ sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFuncti
  )


- def conversation(vectordb_client, msg, num_context, img_context, history):

      text_collection = vectordb_client.get_collection(
          "text_db", embedding_function=sentence_transformer_ef
@@ -217,8 +177,6 @@ def conversation(vectordb_client, msg, num_context, img_context, history):
      results = text_collection.query(
          query_texts=[msg], include=["documents"], n_results=num_context
      )["documents"][0]
-     # print(results)
-     # print("R"*100)
      similar_images = image_collection.query(
          query_texts=[msg],
          include=["metadatas", "distances", "documents"],
@@ -266,19 +224,12 @@ def get_stats(vectordb):
      return "\n".join(text_data), "", ""


- llm = HuggingFaceEndpoint(
-     repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
-     temperature=0.4,
-     max_new_tokens=800,
- )
-
- with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-
      vectordb = gr.State()
      doc_collection = gr.State(value=[])
      session_states = gr.State(value={})
      references = gr.State(value=[])
-
      gr.Markdown(
          """<h2><center>Multimodal PDF Chatbot</center></h2>
          <h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
@@ -312,18 +263,23 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                      embed = gr.Button(value="Extract Data")
                  with gr.Column():
                      next_p1 = gr.Button(value="Next")

-             with gr.Row() as row:
-                 with gr.Column():
-                     selected = gr.Dataframe(
-                         interactive=False,
-                         col_count=(1, "fixed"),
-                         headers=["Selected Files"],
-                     )
-                 with gr.Column(variant="panel"):
-                     prog = gr.HTML(
-                         value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs<h1>"
-                     )

              with gr.Accordion("See Parts of Extracted Data", open=False):
                  with gr.Column(visible=True) as sample_data:
@@ -337,32 +293,37 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                          label="Sample Extracted Images", columns=1, rows=2
                      )

-
-
          with gr.TabItem("Chat", id=2) as chat_tab:
-             with gr.Column():
-                 choice = gr.Radio(
-                     ["chromaDB"],
-                     value="chromaDB",
-                     label="Vector Database",
-                     interactive=True,
-                 )
-                 num_context = gr.Slider(
-                     label="Number of text context elements",
-                     minimum=1,
-                     maximum=20,
-                     step=1,
-                     interactive=True,
-                     value=3,
-                 )
-                 img_context = gr.Slider(
-                     label="Number of image context elements",
-                     minimum=1,
-                     maximum=10,
-                     step=1,
-                     interactive=True,
-                     value=2,
-                 )

              with gr.Row():
                  with gr.Column():
                      ret_images = gr.Gallery("Similar Images", columns=1, rows=2)
@@ -370,14 +331,15 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                  chatbot = gr.Chatbot(height=400)
                  with gr.Accordion("Text References", open=False):
                      # text_context = gr.Row()
-
                      @gr.render(inputs=references)
                      def gen_refs(references):
                          # print(references)
                          n = len(references)
                          for i in range(n):
-                             gr.Textbox(label=f"Reference-{i+1}", value=references[i], lines=3)
-

          with gr.Row():
              msg = gr.Textbox(
@@ -396,7 +358,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
      )
      embed.click(
          extract_data_from_pdfs,
-         inputs=[doc_collection, session_states],
          outputs=[
              vectordb,
              session_states,
@@ -409,13 +371,18 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:

      submit_btn.click(
          conversation,
-         [vectordb, msg, num_context, img_context, chatbot],
-         [chatbot,references ,ret_images],
      )


      back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)

      next_p1.click(check_validity_and_llm, session_states, tabs)
  if __name__ == "__main__":
-     demo.launch()

+ import base64
+ import chromadb
+ import gc
  import gradio as gr
  import io
+ import numpy as np
+ import ocrmypdf
+ import os
  import pandas as pd
+ import pymupdf
  import torch
+ from PIL import Image
  from chromadb.utils import embedding_functions
+ from chromadb.utils.data_loaders import ImageLoader
  from langchain import PromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.llms import HuggingFaceEndpoint
+ from pdfminer.high_level import extract_text
+ from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
+ from utils import *

  if torch.cuda.is_available():
      processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
      )


+ @spaces.GPU()
+ def get_image_description(image):
      torch.cuda.empty_cache()
      gc.collect()

      descriptions = []
      prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"

+     inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
+     output = vision_model.generate(**inputs, max_new_tokens=100)
+     descriptions.append(processor.decode(output[0], skip_special_tokens=True))
      return descriptions

  """


  # def get_vectordb(text, images, tables):
  def get_vectordb(text, images):
      client = chromadb.EphemeralClient()

          client.delete_collection("text_db")
      if "image_db" in [i.name for i in client.list_collections()]:
          client.delete_collection("image_db")
+
      text_collection = client.get_or_create_collection(
          name="text_db",
          embedding_function=sentence_transformer_ef,

          data_loader=loader,
          metadata={"hnsw:space": "cosine"},
      )
+     descs = []
+     print(descs)
+     for image in images:
+         try:
+             descs.append(get_image_description(image)[0])
+         except:
+             descs.append("Could not generate image description due to some error")
+
+     # image_descriptions = get_image_descriptions(images)
+     image_dict = [{"image": image_to_bytes(img)} for img in images]
+
+     if len(images) > 0:
          image_collection.add(
              ids=[str(i) for i in range(len(images))],
+             documents=descs,
              metadatas=image_dict,
          )

          chunk_overlap=10,
      )

+     if len(text) > 0:
          docs = splitter.create_documents([text])
          doc_texts = [i.page_content for i in docs]
          text_collection.add(

      return client


+ def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()):
      if len(docs) == 0:
          raise gr.Error("No documents to process")
      progress(0, "Extracting Images")

+     # images = extract_images(docs)

      progress(0.25, "Extracting Text")

      strategy = "hi_res"
      model_name = "yolox"
      all_elements = []
      all_text = ""

+     images = []
+     for doc in docs:
+         ocrmypdf.ocr(doc, "ocr.pdf", deskew=True, skip_text=True)
+         text = extract_text("ocr.pdf")
+         all_text += clean_text(text) + "\n\n"
+         if include_images == "Include Images":
+             images.extend(extract_images(["ocr.pdf"]))
+
+     progress(
+         0.6, "Generating image descriptions and inserting everything into vectorDB"
+     )
      vectordb = get_vectordb(all_text, images)

      progress(1, "Completed")
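
The block above is the "faster ocr" path from the commit message: instead of unstructured's hi_res/yolox partitioning, each PDF is run through OCRmyPDF (which deskews pages and only OCRs those without an existing text layer) and the result is read back with pdfminer. A minimal standalone sketch of that flow; the helper name and sample file paths are illustrative assumptions, not part of the commit:

    import ocrmypdf
    from pdfminer.high_level import extract_text


    def pdf_to_clean_text(pdf_path, ocr_path="ocr.pdf"):
        # Deskew scanned pages and add a text layer only where one is missing
        # (skip_text=True leaves pages that already contain text untouched).
        ocrmypdf.ocr(pdf_path, ocr_path, deskew=True, skip_text=True)
        # Read the text layer back out and collapse runs of whitespace.
        text = extract_text(ocr_path)
        return " ".join(text.split())


    # print(pdf_to_clean_text("sample.pdf"))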
 
      )


+ def conversation(
+     vectordb_client, msg, num_context, img_context, history, hf_token, model_path
+ ):
+     if hf_token.strip() != "" and model_path.strip() != "":
+         llm = HuggingFaceEndpoint(
+             repo_id=model_path,
+             temperature=0.4,
+             max_new_tokens=800,
+             huggingfacehub_api_token=hf_token,
+         )
+     else:
+         llm = HuggingFaceEndpoint(
+             repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+             temperature=0.4,
+             max_new_tokens=800,
+             huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
+         )

      text_collection = vectordb_client.get_collection(
          "text_db", embedding_function=sentence_transformer_ef

      results = text_collection.query(
          query_texts=[msg], include=["documents"], n_results=num_context
      )["documents"][0]
      similar_images = image_collection.query(
          query_texts=[msg],
          include=["metadatas", "distances", "documents"],
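
conversation() now builds its LLM client per call, so a HuggingFace token and model repo entered in the Chat tab override the Space's default endpoint (Meta-Llama-3-8B-Instruct, authenticated with the P_HF_TOKEN secret). A rough sketch of that selection logic factored into a helper; build_llm and the example prompt are hypothetical, while the endpoint parameters mirror the diff:

    import os

    from langchain_community.llms import HuggingFaceEndpoint


    def build_llm(hf_token="", model_path=""):
        # Prefer the caller's own endpoint when both a token and a repo id are given.
        if hf_token.strip() and model_path.strip():
            return HuggingFaceEndpoint(
                repo_id=model_path,
                temperature=0.4,
                max_new_tokens=800,
                huggingfacehub_api_token=hf_token,
            )
        # Otherwise fall back to the default model and the Space's stored token.
        return HuggingFaceEndpoint(
            repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
            temperature=0.4,
            max_new_tokens=800,
            huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
        )


    # llm = build_llm()
    # print(llm.invoke("Answer from the retrieved context: ..."))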
 
      return "\n".join(text_data), "", ""


+ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
      vectordb = gr.State()
      doc_collection = gr.State(value=[])
      session_states = gr.State(value={})
      references = gr.State(value=[])
+
      gr.Markdown(
          """<h2><center>Multimodal PDF Chatbot</center></h2>
          <h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
                      embed = gr.Button(value="Extract Data")
                  with gr.Column():
                      next_p1 = gr.Button(value="Next")
+             with gr.Row():
+                 include_images = gr.Radio(
+                     ["Include Images", "Exclude Images"],
+                     value="Include Images",
+                     label="Include/ Exclude Images",
+                     interactive=True,
+                 )

+             with gr.Row(equal_height=True, variant="panel") as row:
+                 selected = gr.Dataframe(
+                     interactive=False,
+                     col_count=(1, "fixed"),
+                     headers=["Selected Files"],
+                 )
+                 prog = gr.HTML(
+                     value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs<h1>"
+                 )

              with gr.Accordion("See Parts of Extracted Data", open=False):
                  with gr.Column(visible=True) as sample_data:
                          label="Sample Extracted Images", columns=1, rows=2
                      )

          with gr.TabItem("Chat", id=2) as chat_tab:
+             with gr.Accordion("Config (Advanced) (Optional)", open=False):
+                 with gr.Row(variant="panel", equal_height=True):
+                     choice = gr.Radio(
+                         ["chromaDB"],
+                         value="chromaDB",
+                         label="Vector Database",
+                         interactive=True,
+                     )
+                     with gr.Accordion("Use your own model (optional)", open=False):
+                         hf_token = gr.Textbox(
+                             label="HuggingFace Token", interactive=True
+                         )
+                         model_path = gr.Textbox(label="Model Path", interactive=True)
+                 with gr.Row(variant="panel", equal_height=True):
+                     num_context = gr.Slider(
+                         label="Number of text context elements",
+                         minimum=1,
+                         maximum=20,
+                         step=1,
+                         interactive=True,
+                         value=3,
+                     )
+                     img_context = gr.Slider(
+                         label="Number of image context elements",
+                         minimum=1,
+                         maximum=10,
+                         step=1,
+                         interactive=True,
+                         value=2,
+                     )
              with gr.Row():
                  with gr.Column():
                      ret_images = gr.Gallery("Similar Images", columns=1, rows=2)

                  chatbot = gr.Chatbot(height=400)
                  with gr.Accordion("Text References", open=False):
                      # text_context = gr.Row()
+
                      @gr.render(inputs=references)
                      def gen_refs(references):
                          # print(references)
                          n = len(references)
                          for i in range(n):
+                             gr.Textbox(
+                                 label=f"Reference-{i+1}", value=references[i], lines=3
+                             )

          with gr.Row():
              msg = gr.Textbox(

      )
      embed.click(
          extract_data_from_pdfs,
+         inputs=[doc_collection, session_states, include_images],
          outputs=[
              vectordb,
              session_states,

      submit_btn.click(
          conversation,
+         [vectordb, msg, num_context, img_context, chatbot, hf_token, model_path],
+         [chatbot, references, ret_images],
      )

+     msg.submit(
+         conversation,
+         [vectordb, msg, num_context, img_context, chatbot, hf_token, model_path],
+         [chatbot, references, ret_images],
+     )

      back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)

      next_p1.click(check_validity_and_llm, session_states, tabs)
  if __name__ == "__main__":
+     demo.launch()
requirements.txt CHANGED
@@ -1,6 +1,7 @@
  chromadb==0.5.3
  langchain==0.2.5
  langchain_community==0.2.5
+ langchain-huggingface
  numpy<2.0.0
  pandas==2.2.2
  Pillow==10.3.0
@@ -8,4 +9,6 @@ pymupdf==1.24.5
  sentence_transformers==3.0.1
  unstructured[all-docs]
  accelerate
- bitsandbytes
+ bitsandbytes
+ easyocr
+ ocrmypdf
utils.py ADDED
@@ -0,0 +1,53 @@
+ import pymupdf
+ from PIL import Image
+ import io
+ import gradio as gr
+ import pandas as pd
+
+
+ def image_to_bytes(image):
+     img_byte_arr = io.BytesIO()
+     image.save(img_byte_arr, format="PNG")
+     return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
+
+
+ def extract_pdfs(docs, doc_collection):
+     if docs:
+         doc_collection = []
+         doc_collection.extend(docs)
+     return (
+         doc_collection,
+         gr.Tabs(selected=1),
+         pd.DataFrame([i.split("/")[-1] for i in list(docs)], columns=["Filename"]),
+     )
+
+
+ def extract_images(docs):
+     images = []
+     for doc_path in docs:
+         doc = pymupdf.open(doc_path)  # open a document
+
+         for page_index in range(len(doc)):  # iterate over pdf pages
+             page = doc[page_index]  # get the page
+             image_list = page.get_images()
+
+             for image_index, img in enumerate(
+                 image_list, start=1
+             ):  # enumerate the image list
+                 xref = img[0]  # get the XREF of the image
+                 pix = pymupdf.Pixmap(doc, xref)  # create a Pixmap
+
+                 if pix.n - pix.alpha > 3:  # CMYK: convert to RGB first
+                     pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
+
+                 images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
+     return images
+
+
+ def clean_text(text):
+     text = text.strip()
+     cleaned_text = text.replace("\n", " ")
+     cleaned_text = cleaned_text.replace("\t", " ")
+     cleaned_text = cleaned_text.replace("  ", " ")
+     cleaned_text = cleaned_text.strip()
+     return cleaned_text
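
The new utils.py collects the CPU-only helpers: extract_pdfs() tracks the uploaded file list, extract_images() pulls embedded images out of a PDF with PyMuPDF, image_to_bytes() base64-encodes a PIL image for storage as Chroma metadata (it relies on the standard-library base64 module), and clean_text() collapses whitespace in extracted text. A small hypothetical driver showing how they compose, assuming an OCRed "ocr.pdf" is already on disk:

    import base64
    import io

    from PIL import Image

    from utils import clean_text, extract_images, image_to_bytes

    images = extract_images(["ocr.pdf"])  # PIL images embedded in the PDF
    payload = [{"image": image_to_bytes(img)} for img in images]

    # Round-trip one entry to confirm the base64 payload decodes back to an image.
    if payload:
        raw = base64.b64decode(payload[0]["image"])
        Image.open(io.BytesIO(raw)).verify()

    print(clean_text("  Line one\nLine two\t"))  # -> "Line one Line two"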