anand004 committed
Commit 3f98f11
1 Parent(s): 243e843

Create app.py

Files changed (1)
  app.py +434 -0
app.py ADDED
@@ -0,0 +1,434 @@
import gradio as gr
from unstructured.partition.pdf import partition_pdf
import pymupdf
from PIL import Image
import numpy as np
import io
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
import gc
import torch
import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import base64
from langchain_community.llms import HuggingFaceEndpoint
from langchain import PromptTemplate
import spaces

# Load the LLaVA-NeXT processor and vision model only when a GPU is available.
if torch.cuda.is_available():
    processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
    vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        load_in_4bit=True,
    )


def image_to_bytes(image):
    # Encode a PIL image as a base64 PNG string for storage in Chroma metadata.
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format="PNG")
    return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")


@spaces.GPU
def get_image_descriptions(images):
    # Generate a one-sentence caption for each extracted image with LLaVA-NeXT.
    torch.cuda.empty_cache()
    gc.collect()

    descriptions = []
    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"

    for img in images:
        inputs = processor(prompt, img, return_tensors="pt").to("cuda:0")
        output = vision_model.generate(**inputs, max_new_tokens=100)
        descriptions.append(processor.decode(output[0], skip_special_tokens=True))
    return descriptions


CSS = """
#table_col {background-color: rgb(33, 41, 54);}
"""


def extract_pdfs(docs, doc_collection):
    if docs:
        doc_collection = []
        doc_collection.extend(docs)
    return (
        doc_collection,
        gr.Tabs(selected=1),
        pd.DataFrame([i.split("/")[-1] for i in list(docs)], columns=["Filename"]),
    )


def extract_images(docs):
    images = []
    for doc_path in docs:
        doc = pymupdf.open(doc_path)  # open a document

        for page_index in range(len(doc)):  # iterate over pdf pages
            page = doc[page_index]  # get the page
            image_list = page.get_images()

            for image_index, img in enumerate(
                image_list, start=1
            ):  # enumerate the image list
                xref = img[0]  # get the XREF of the image
                pix = pymupdf.Pixmap(doc, xref)  # create a Pixmap

                if pix.n - pix.alpha > 3:  # CMYK: convert to RGB first
                    pix = pymupdf.Pixmap(pymupdf.csRGB, pix)

                images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
    return images


# def get_vectordb(text, images, tables):
def get_vectordb(text, images):
    client = chromadb.EphemeralClient()
    loader = ImageLoader()
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="multi-qa-mpnet-base-dot-v1"
    )
    if "text_db" in [i.name for i in client.list_collections()]:
        client.delete_collection("text_db")
    if "image_db" in [i.name for i in client.list_collections()]:
        client.delete_collection("image_db")
    text_collection = client.get_or_create_collection(
        name="text_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
    )
    image_collection = client.get_or_create_collection(
        name="image_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
        metadata={"hnsw:space": "cosine"},
    )

    image_descriptions = get_image_descriptions(images)
    # One metadata dict per image, so the metadatas list matches the number of ids.
    image_dict = [{"image": image_to_bytes(img)} for img in images]

    image_collection.add(
        ids=[str(i) for i in range(len(images))],
        documents=image_descriptions,
        metadatas=image_dict,
    )

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=10,
    )

    docs = splitter.create_documents([text])
    doc_texts = [i.page_content for i in docs]
    text_collection.add(
        ids=[str(i) for i in range(len(doc_texts))], documents=doc_texts
    )
    return client


def extract_data_from_pdfs(docs, session, progress=gr.Progress()):
    if len(docs) == 0:
        raise gr.Error("No documents to process")
    progress(0, "Extracting Images")

    images = extract_images(docs)

    progress(0.25, "Extracting Text")

    strategy = "hi_res"
    model_name = "yolox"
    all_elements = []

    for doc in docs:
        elements = partition_pdf(
            filename=doc,
            strategy=strategy,
            infer_table_structure=True,
            model_name=model_name,
        )

        all_elements.extend(elements)

    all_text = ""

    # tables = []

    for i in all_elements:
        meta = i.to_dict()
        if meta["type"].lower() not in ["table", "figurecaption"]:
            if meta["type"].lower() in ["listitem", "title"]:
                all_text += "\n\n" + meta["text"] + "\n"
            else:
                all_text += meta["text"]
        elif meta["type"] == "Table":
            continue
            # tables.append(meta["metadata"]["text_as_html"])

    # html = "<br>".join(tables)
    # display = "<h3>Sample Tables</h3>" + "<br>".join(tables[:2])
    # html = gr.HTML(html)
    # vectordb = get_vectordb(all_text, images, tables)

    progress(0.5, "Generating image descriptions")
    # Image descriptions are generated inside get_vectordb; the separate
    # image_descriptions output below is currently disabled, so skip the
    # duplicate (and expensive) LLaVA pass here.
    # image_descriptions = "\n".join(get_image_descriptions(images))

    progress(0.75, "Inserting data into vector database")
    vectordb = get_vectordb(all_text, images)

    progress(1, "Completed")
    session["processed"] = True
    return (
        vectordb,
        session,
        gr.Column(visible=True),
        all_text[:2000] + "...",
        # display,
        images[:2],
        "<h1 style='text-align: center'>Completed</h1>",
        # image_descriptions
    )


sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="multi-qa-mpnet-base-dot-v1"
)


def conversation(vectordb_client, msg, num_context, img_context, history):
    text_collection = vectordb_client.get_collection(
        "text_db", embedding_function=sentence_transformer_ef
    )
    image_collection = vectordb_client.get_collection(
        "image_db", embedding_function=sentence_transformer_ef
    )

    results = text_collection.query(
        query_texts=[msg], include=["documents"], n_results=num_context
    )["documents"][0]

    similar_images = image_collection.query(
        query_texts=[msg],
        include=["metadatas", "distances", "documents"],
        n_results=img_context,
    )
    img_links = [i["image"] for i in similar_images["metadatas"][0]]

    images_and_locs = [
        Image.open(io.BytesIO(base64.b64decode(img))) for img in img_links
    ]
    img_desc = "\n".join(similar_images["documents"][0])
    if len(img_links) == 0:
        img_desc = "No Images Are Provided"

    template = """
    Context:
    {context}

    Included Images:
    {images}

    Question:
    {question}

    Answer:

    """
    prompt = PromptTemplate(
        template=template, input_variables=["context", "images", "question"]
    )
    context = "\n\n".join(results)
    response = llm(prompt.format(context=context, question=msg, images=img_desc))
    return history + [(msg, response)], context, images_and_locs


def check_validity_and_llm(session_states):
    if session_states.get("processed", False):
        return gr.Tabs(selected=2)
    raise gr.Error("Please extract data first")


def get_stats(vectordb):
    eles = vectordb.get()
    text_data = [f"Chunks: {len(eles)}"]
    return "\n".join(text_data), "", ""


llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    temperature=0.4,
    max_new_tokens=800,
)

with gr.Blocks(css=CSS) as demo:

    vectordb = gr.State()
    doc_collection = gr.State(value=[])
    session_states = gr.State(value={})
    gr.Markdown(
        """<h2><center>Multimodal PDF Chatbot</center></h2>
        <h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
    )
    gr.Markdown(
        """<center><h3><b>Note:</b> This application leverages advanced Retrieval-Augmented Generation (RAG) techniques to provide context-aware responses from your PDF documents.</h3></center><br>
        <center>Utilizing multimodal capabilities, this chatbot can interpret and answer queries based on both textual and visual information within your PDFs.</center>"""
    )
    gr.Markdown(
        """
        <center><b>Warning:</b> Extracting text and images from your document and generating embeddings may take some time due to the use of OCR and multimodal LLMs for image description.</center>
        """
    )
    with gr.Tabs() as tabs:
        with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
            with gr.Row():
                with gr.Column():
                    documents = gr.File(
                        file_count="multiple",
                        file_types=["pdf"],
                        interactive=True,
                        label="Upload your PDF file/s",
                    )
                    pdf_btn = gr.Button(value="Next", elem_id="button1")

        with gr.TabItem("Extract Data", id=1) as preprocess:
            with gr.Row():
                with gr.Column():
                    back_p1 = gr.Button(value="Back")
                with gr.Column():
                    embed = gr.Button(value="Extract Data")
                with gr.Column():
                    next_p1 = gr.Button(value="Next")

            with gr.Row() as row:
                with gr.Column():
                    selected = gr.Dataframe(
                        interactive=False,
                        col_count=(1, "fixed"),
                        headers=["Selected Files"],
                    )
                with gr.Column(variant="panel"):
                    prog = gr.HTML(
                        value="<h1 style='text-align: center'>Click the 'Extract Data' button to extract data from PDFs</h1>"
                    )

            with gr.Accordion("See Parts of Extracted Data", open=False):
                with gr.Column(visible=True) as sample_data:
                    with gr.Row():
                        with gr.Column():
                            ext_text = gr.Textbox(
                                label="Sample Extracted Text", lines=15
                            )
                        with gr.Column():
                            images = gr.Gallery(
                                label="Sample Extracted Images", columns=1, rows=2
                            )

            # with gr.Row():
            #     image_desc = gr.Textbox(label="Image Descriptions", interactive=False)
            # with gr.Row(variant="panel"):
            #     ext_tables = gr.HTML("<h3>Sample Tables</h3>", label="Extracted Tables")

        # with gr.TabItem("Embeddings", id=3) as embed_tab:
        #     with gr.Row():
        #         with gr.Column():
        #             back_p2 = gr.Button(value="Back")
        #         with gr.Column():
        #             view_stats = gr.Button(value="View Stats")
        #         with gr.Column():
        #             next_p2 = gr.Button(value="Next")

        #     with gr.Row():
        #         with gr.Column():
        #             text_stats = gr.Textbox(label="Text Stats", interactive=False)
        #         with gr.Column():
        #             table_stats = gr.Textbox(label="Table Stats", interactive=False)
        #         with gr.Column():
        #             image_stats = gr.Textbox(label="Image Stats", interactive=False)

        with gr.TabItem("Chat", id=2) as chat_tab:
            with gr.Column():
                choice = gr.Radio(
                    ["chromaDB"],
                    value="chromaDB",
                    label="Vector Database",
                    interactive=True,
                )
                num_context = gr.Slider(
                    label="Number of text context elements",
                    minimum=1,
                    maximum=20,
                    step=1,
                    interactive=True,
                    value=3,
                )
                img_context = gr.Slider(
                    label="Number of image context elements",
                    minimum=1,
                    maximum=10,
                    step=1,
                    interactive=True,
                    value=2,
                )
            with gr.Row():
                with gr.Column():
                    ret_images = gr.Gallery(label="Similar Images", columns=1, rows=2)
                with gr.Column():
                    chatbot = gr.Chatbot(height=400)
            with gr.Accordion("Text References", open=False):
                with gr.Row():
                    text_context = gr.Textbox(interactive=False, lines=10)

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your question here (e.g. 'What is this document about?')",
                    interactive=True,
                    container=True,
                )
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

    pdf_btn.click(
        fn=extract_pdfs,
        inputs=[documents, doc_collection],
        outputs=[doc_collection, tabs, selected],
    )
    embed.click(
        extract_data_from_pdfs,
        inputs=[doc_collection, session_states],
        outputs=[
            vectordb,
            session_states,
            sample_data,
            ext_text,
            # ext_tables,
            images,
            prog,
            # image_desc
        ],
    )

    submit_btn.click(
        conversation,
        [vectordb, msg, num_context, img_context, chatbot],
        [chatbot, text_context, ret_images],
    )

    # view_stats.click(
    #     get_stats, [vectordb], outputs=[text_stats, table_stats, image_stats]
    # )

    # Page Navigation

    back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)

    next_p1.click(check_validity_and_llm, session_states, tabs)

if __name__ == "__main__":
    demo.launch(share=True)