# Captured from a Hugging Face Space page ("Spaces:"); the Space status shown at
# capture time was "Runtime error".
import base64
import gc
import io
import os

# NOTE(review): `spaces` is kept ahead of the torch/transformers imports; on ZeroGPU
# hardware it is expected to be imported before CUDA is initialised — confirm.
import spaces
import chromadb
import gradio as gr
import numpy as np
import pandas as pd
import pymupdf
import torch
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from langchain import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFaceEndpoint
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import (
    BitsAndBytesConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
)
from unstructured.partition.pdf import partition_pdf
# LLaVA-NeXT (Mistral-7B) vision-language model used to caption the images extracted
# from the PDFs. It is only loaded when a CUDA device is present; otherwise both names
# are bound to None so that later code can detect the missing model instead of failing
# with a NameError.
VISION_MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"

if torch.cuda.is_available():
    processor = LlavaNextProcessor.from_pretrained(VISION_MODEL_ID)
    vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        VISION_MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        # Passing `load_in_4bit=True` directly to from_pretrained is deprecated;
        # 4-bit loading is configured through a quantization config instead.
        # Computing in fp16 matches torch_dtype above.
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
    )
else:
    processor = None
    vision_model = None
def image_to_bytes(image):
    """Encode an image as PNG and return the PNG data as a base64 text string.

    Despite the name, the result is a ``str`` (base64 of the PNG bytes), which
    makes it storable as a Chroma metadata value.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
def get_image_descriptions(images):
    """Caption each image with the module-level LLaVA-NeXT ``processor``/``vision_model``.

    Args:
        images: Sequence of PIL images.

    Returns:
        A list with one caption string per image, in input order.

    Raises:
        RuntimeError: If images are given but no CUDA device is available (the
            vision model is only loaded when CUDA is present).
    """
    if not images:
        return []
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Generating image descriptions requires a CUDA GPU; the vision model "
            "is not loaded on this machine."
        )

    # Free cached GPU/host memory before running the large vision model.
    torch.cuda.empty_cache()
    gc.collect()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    descriptions = []
    for img in images:
        # Keyword arguments avoid depending on the processor's positional order.
        inputs = processor(text=prompt, images=img, return_tensors="pt").to("cuda:0")
        output = vision_model.generate(**inputs, max_new_tokens=100)
        # generate() returns the prompt tokens followed by the new tokens; decode only
        # the new ones so the "[INST] ... [/INST]" text does not end up in the caption.
        new_tokens = output[0][inputs["input_ids"].shape[-1]:]
        descriptions.append(
            processor.decode(new_tokens, skip_special_tokens=True).strip()
        )
    return descriptions
CSS = """ | |
#table_col {background-color: rgb(33, 41, 54);} | |
""" | |
def extract_pdfs(docs, doc_collection):
    """Record the uploaded PDFs and switch the UI to the "Extract Data" tab.

    Args:
        docs: File paths from the upload widget, or None/empty when nothing was
            uploaded.
        doc_collection: Paths selected previously (session state).

    Returns:
        Tuple of (selected paths, tab-switch update, one-column DataFrame of the
        selected file names).
    """
    # A new upload replaces the previous selection. With no upload, keep the earlier
    # selection instead of crashing on list(None).
    if docs:
        doc_collection = list(docs)
    doc_collection = doc_collection or []
    filenames = [os.path.basename(path) for path in doc_collection]
    return (
        doc_collection,
        gr.Tabs(selected=1),
        pd.DataFrame(filenames, columns=["Filename"]),
    )
def extract_images(docs):
    """Extract the embedded raster images from the given PDF files.

    Each image object (XREF) is extracted once per document, even when several pages
    reference it. CMYK images are converted to RGB, and any alpha channel is dropped
    because the images are re-encoded as JPEG.

    Args:
        docs: Iterable of PDF file paths.

    Returns:
        List of PIL images, in document/page order.
    """
    images = []
    for doc_path in docs:
        # The context manager closes the document (and its file handle) when done.
        with pymupdf.open(doc_path) as doc:
            seen_xrefs = set()
            for page in doc:
                for img in page.get_images():
                    xref = img[0]  # XREF of the image object
                    if xref in seen_xrefs:
                        continue  # same image already extracted from another page
                    seen_xrefs.add(xref)
                    pix = pymupdf.Pixmap(doc, xref)
                    if pix.n - pix.alpha > 3:  # CMYK: convert to RGB first
                        pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                    if pix.alpha:  # JPEG cannot store an alpha channel
                        pix = pymupdf.Pixmap(pix, 0)
                    images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
    return images
def get_vectordb(text, images):
    """Build fresh Chroma collections for the extracted text and images.

    Args:
        text: Plain text extracted from the PDFs; it is split into chunks of up to
            500 characters (overlap 10) and stored in ``text_db``.
        images: PIL images extracted from the PDFs. Each is captioned with
            ``get_image_descriptions`` and stored in ``image_db`` with the caption as
            the document and the base64 PNG under the ``"image"`` metadata key.

    Returns:
        The ``chromadb`` client holding the ``text_db`` and ``image_db`` collections.
    """
    client = chromadb.EphemeralClient()
    loader = ImageLoader()

    # Drop collections left over from a previous extraction. Depending on the chromadb
    # version, list_collections() yields Collection objects or plain names.
    existing = {getattr(c, "name", c) for c in client.list_collections()}
    for name in ("text_db", "image_db"):
        if name in existing:
            client.delete_collection(name)

    # Reuse the module-level embedding function rather than loading the
    # SentenceTransformer model again on every extraction.
    text_collection = client.get_or_create_collection(
        name="text_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
    )
    image_collection = client.get_or_create_collection(
        name="image_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
        metadata={"hnsw:space": "cosine"},
    )

    if images:
        image_descriptions = get_image_descriptions(images)
        image_collection.add(
            ids=[str(i) for i in range(len(images))],
            documents=image_descriptions,
            # One metadata dict per image; ids, documents and metadatas must align.
            metadatas=[{"image": image_to_bytes(img)} for img in images],
        )

    if text:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=10,
        )
        doc_texts = [d.page_content for d in splitter.create_documents([text])]
        text_collection.add(
            ids=[str(i) for i in range(len(doc_texts))],
            documents=doc_texts,
        )

    return client
def extract_data_from_pdfs(docs, session, progress=gr.Progress()):
    """Extract images and text from the selected PDFs and index them in Chroma.

    Args:
        docs: List of PDF file paths selected by the user.
        session: Per-session state dict; ``session["processed"]`` is set to True.
        progress: Gradio progress tracker.

    Returns:
        Tuple of (vector DB client, session, visibility update for the sample panel,
        preview of the extracted text, first two extracted images, status HTML).

    Raises:
        gr.Error: If no documents were selected.
    """
    if not docs:
        raise gr.Error("No documents to process")

    progress(0, "Extracting Images")
    images = extract_images(docs)

    progress(0.25, "Extracting Text")
    all_elements = []
    for doc in docs:
        all_elements.extend(
            partition_pdf(
                filename=doc,
                strategy="hi_res",
                infer_table_structure=True,
                model_name="yolox",
            )
        )

    text_parts = []
    for element in all_elements:
        meta = element.to_dict()
        element_type = meta["type"].lower()
        # Tables and figure captions are currently not indexed.
        # TODO: index tables via meta["metadata"]["text_as_html"].
        if element_type in ("table", "figurecaption"):
            continue
        if element_type in ("listitem", "title"):
            text_parts.append("\n\n" + meta["text"] + "\n")
        else:
            # Terminate each element so consecutive paragraphs are not glued together.
            text_parts.append(meta["text"] + "\n")
    all_text = "".join(text_parts)

    # Image captions are generated inside get_vectordb; calling the vision model here
    # as well would caption every image twice.
    progress(0.5, "Generating image descriptions and inserting data into vector database")
    vectordb = get_vectordb(all_text, images)
    progress(1, "Completed")

    session["processed"] = True
    preview = all_text if len(all_text) <= 2000 else all_text[:2000] + "..."
    return (
        vectordb,
        session,
        gr.Column(visible=True),
        preview,
        images[:2],
        "<h1 style='text-align: center'>Completed</h1>",
    )
# Shared SentenceTransformer embedding function ("multi-qa-mpnet-base-dot-v1"),
# used by conversation() when opening the text_db and image_db collections.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="multi-qa-mpnet-base-dot-v1",
)
def conversation(vectordb_client, msg, num_context, img_context, history):
    """Answer ``msg`` with the LLM, using retrieved text chunks and image captions.

    Args:
        vectordb_client: Chroma client holding the ``text_db`` and ``image_db``
            collections.
        msg: The user's question.
        num_context: Number of text chunks to retrieve.
        img_context: Number of images to retrieve.
        history: Chat history as a list of (user, bot) tuples, or None.

    Yields:
        Once: (updated history, retrieved text chunks, retrieved PIL images).
    """
    text_collection = vectordb_client.get_collection(
        "text_db", embedding_function=sentence_transformer_ef
    )
    image_collection = vectordb_client.get_collection(
        "image_db", embedding_function=sentence_transformer_ef
    )

    # Slider values may arrive as floats; Chroma expects an integer n_results.
    results = text_collection.query(
        query_texts=[msg], include=["documents"], n_results=int(num_context)
    )["documents"][0]
    similar_images = image_collection.query(
        query_texts=[msg],
        include=["metadatas", "documents"],
        n_results=int(img_context),
    )

    # The images are stored as base64 PNG strings in the "image" metadata field.
    img_links = [meta["image"] for meta in similar_images["metadatas"][0]]
    images_and_locs = [
        Image.open(io.BytesIO(base64.b64decode(b64))) for b64 in img_links
    ]
    if img_links:
        img_desc = "\n".join(similar_images["documents"][0])
    else:
        img_desc = "No Images Are Provided"

    template = (
        "\n"
        "Context:\n"
        "{context}\n"
        "Included Images:\n"
        "{images}\n"
        "Question:\n"
        "{question}\n"
        "Answer:\n"
    )
    prompt = PromptTemplate(
        template=template, input_variables=["context", "images", "question"]
    )
    context = "\n\n".join(results)
    response = llm.invoke(
        prompt.format(context=context, question=msg, images=img_desc)
    )
    yield (history or []) + [(msg, response)], results, images_and_locs
def check_validity_and_llm(session_states):
    """Switch to the Chat tab, but only after data has been extracted.

    Raises:
        gr.Error: If ``session_states["processed"]`` has not been set.
    """
    if not session_states.get("processed", False):
        raise gr.Error("Please extract data first")
    return gr.Tabs(selected=2)
def get_stats(vectordb):
    """Return a ``(summary, "", "")`` tuple describing the stored chunks.

    ``vectordb`` must expose a Chroma-style ``get()`` that returns a mapping with an
    ``"ids"`` list (e.g. a chromadb Collection).
    """
    eles = vectordb.get()
    # get() returns a dict of parallel lists; len(eles) would count its keys, not the
    # stored chunks, so count the ids instead.
    # NOTE(review): "HIII" looks like a leftover debug placeholder.
    text_data = [f"Chunks: {len(eles['ids'])}", "HIII"]
    return "\n".join(text_data), "", ""
# Text-generation LLM (Mixtral-8x7B-Instruct on a Hugging Face inference endpoint)
# that conversation() uses to produce the final answer.
llm = HuggingFaceEndpoint(
    max_new_tokens=800,
    temperature=0.4,
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
)
with gr.Blocks(css=CSS) as demo:
    # Per-session state shared between the event handlers below.
    vectordb = gr.State()  # Chroma client produced by extract_data_from_pdfs
    doc_collection = gr.State(value=[])  # selected PDF paths
    session_states = gr.State(value={})  # flags such as "processed"
    references = gr.State(value=[])  # text chunks retrieved for the last answer

    gr.Markdown(
        "<h2><center>Multimodal PDF Chatbot</center></h2>\n"
        "<h3><center><b>Interact With Your PDF Documents</b></center></h3>"
    )
    gr.Markdown(
        "<center><h3><b>Note: </b> This application leverages advanced "
        "Retrieval-Augmented Generation (RAG) techniques to provide context-aware "
        "responses from your PDF documents</h3></center><br>\n"
        "<center>Utilizing multimodal capabilities, this chatbot can interpret and "
        "answer queries based on both textual and visual information within your "
        "PDFs.</center>"
    )
    gr.Markdown(
        "\n<center><b>Warning: </b> Extracting text and images from your document "
        "and generating embeddings may take some time due to the use of OCR and "
        "multimodal LLMs for image description</center>\n"
    )

    with gr.Tabs() as tabs:
        with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
            with gr.Row():
                with gr.Column():
                    documents = gr.File(
                        file_count="multiple",
                        # Extensions need a leading dot for Gradio's file filter.
                        file_types=[".pdf"],
                        interactive=True,
                        label="Upload your PDF file/s",
                    )
                    pdf_btn = gr.Button(value="Next", elem_id="button1")

        with gr.TabItem("Extract Data", id=1) as preprocess:
            with gr.Row():
                with gr.Column():
                    back_p1 = gr.Button(value="Back")
                with gr.Column():
                    embed = gr.Button(value="Extract Data")
                with gr.Column():
                    next_p1 = gr.Button(value="Next")
            with gr.Row() as row:
                with gr.Column():
                    selected = gr.Dataframe(
                        interactive=False,
                        col_count=(1, "fixed"),
                        headers=["Selected Files"],
                    )
                with gr.Column(variant="panel"):
                    prog = gr.HTML(
                        value=(
                            "<h1 style='text-align: center'>Click the 'Extract' "
                            "button to extract data from PDFs</h1>"
                        )
                    )
                    with gr.Accordion("See Parts of Extracted Data", open=False):
                        with gr.Column(visible=True) as sample_data:
                            with gr.Row():
                                with gr.Column():
                                    ext_text = gr.Textbox(
                                        label="Sample Extracted Text", lines=15
                                    )
                                with gr.Column():
                                    images = gr.Gallery(
                                        label="Sample Extracted Images",
                                        columns=1,
                                        rows=2,
                                    )

        with gr.TabItem("Chat", id=2) as chat_tab:
            with gr.Column():
                choice = gr.Radio(
                    ["chromaDB"],
                    value="chromaDB",
                    label="Vector Database",
                    interactive=True,
                )
                num_context = gr.Slider(
                    label="Number of text context elements",
                    minimum=1,
                    maximum=20,
                    step=1,
                    interactive=True,
                    value=3,
                )
                img_context = gr.Slider(
                    label="Number of image context elements",
                    minimum=1,
                    maximum=10,
                    step=1,
                    interactive=True,
                    value=2,
                )
            with gr.Row():
                with gr.Column():
                    # The first positional argument of gr.Gallery is its value, so the
                    # title must be passed as label=.
                    ret_images = gr.Gallery(label="Similar Images", columns=1, rows=2)
                with gr.Column():
                    chatbot = gr.Chatbot(height=400)
                    with gr.Accordion("Text References", open=False):
                        # NOTE(review): gen_refs is not attached to any event or render
                        # hook, so this accordion currently stays empty.
                        def gen_refs(refs):
                            n = len(refs)
                            for i in range(n):
                                gr.Textbox(label=f"Ref-{i+1}", value=refs[i], lines=3)

                    with gr.Row():
                        msg = gr.Textbox(
                            placeholder="Type your question here (e.g. 'What is this document about?')",
                            interactive=True,
                            container=True,
                        )
                    with gr.Row():
                        submit_btn = gr.Button("Submit message")
                        clear_btn = gr.ClearButton(
                            [msg, chatbot], value="Clear conversation"
                        )

    pdf_btn.click(
        fn=extract_pdfs,
        inputs=[documents, doc_collection],
        outputs=[doc_collection, tabs, selected],
    )
    embed.click(
        extract_data_from_pdfs,
        inputs=[doc_collection, session_states],
        outputs=[
            vectordb,
            session_states,
            sample_data,
            ext_text,
            images,
            prog,
        ],
    )
    submit_btn.click(
        conversation,
        [vectordb, msg, num_context, img_context, chatbot],
        [chatbot, references, ret_images],
    )
    back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
    next_p1.click(check_validity_and_llm, session_states, tabs)

if __name__ == "__main__":
    demo.launch(share=True)