anand004's picture
Update app.py
c0f02eb verified
raw
history blame
13.4 kB
import base64
import chromadb
import gc
import gradio as gr
import io
import numpy as np
import ocrmypdf
import os
import pandas as pd
import pymupdf
import spaces
import torch
from PIL import Image
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader
from gradio.themes.utils import sizes
from langchain import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFaceEndpoint
from pdfminer.high_level import extract_text
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
from utils import *
# Load the LLaVA-NeXT captioning model once at import time, only when a GPU is visible.
# NOTE(review): without CUDA, `processor` and `vision_model` are never bound, so any call
# to get_image_description fails with NameError (get_vectordb catches errors from it and
# stores a placeholder description instead).
if torch.cuda.is_available():
    from transformers import BitsAndBytesConfig

    processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
    vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        # 4-bit bitsandbytes quantisation. Passing `load_in_4bit=True` directly to
        # from_pretrained is deprecated; transformers builds exactly this config from it.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
@spaces.GPU()
def get_image_description(image):
    """Generate a one-sentence caption for *image* with the LLaVA-NeXT model.

    Args:
        image: passed straight to ``processor`` together with the prompt
            (presumably a PIL image produced by ``extract_images`` — TODO confirm).

    Returns:
        A one-element list holding the caption text. It stays a list because
        callers index ``[0]``.
    """
    # Drop unreachable Python objects first so empty_cache() can actually release
    # the GPU blocks they were holding.
    gc.collect()
    torch.cuda.empty_cache()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
    output = vision_model.generate(**inputs, max_new_tokens=100)

    # generate() returns the prompt tokens followed by the continuation. Decode only
    # the new tokens, otherwise the prompt text ends up in the stored description
    # (and in the embeddings / LLM context built from it).
    prompt_len = inputs["input_ids"].shape[-1]
    description = processor.decode(
        output[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return [description]
# Extra stylesheet handed to gr.Blocks(css=...).
# NOTE(review): no component in this file sets elem_id="table_col", so this rule may
# currently match nothing — confirm before relying on it.
CSS = "\n#table_col {background-color: rgb(33, 41, 54);}\n"
def get_vectordb(text, images):
    """Index the extracted text chunks and image captions in fresh chromadb collections.

    Args:
        text: all extracted text; split into ~500-character chunks with 10
            characters of overlap before indexing.
        images: extracted images; each one is captioned with
            ``get_image_description`` and its bytes (from ``image_to_bytes``) are
            stored under the ``"image"`` metadata key.

    Returns:
        The chromadb client that holds the ``"text_db"`` and ``"image_db"``
        collections.
    """
    client = chromadb.EphemeralClient()
    loader = ImageLoader()
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="multi-qa-mpnet-base-dot-v1"
    )

    # Start from empty collections. Depending on the chromadb version,
    # list_collections() yields Collection objects or plain names; accept both.
    existing = {getattr(col, "name", col) for col in client.list_collections()}
    for name in ("text_db", "image_db"):
        if name in existing:
            client.delete_collection(name)

    text_collection = client.get_or_create_collection(
        name="text_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
    )
    image_collection = client.get_or_create_collection(
        name="image_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
        metadata={"hnsw:space": "cosine"},
    )

    # Captioning is best effort: a failure for one image must not abort indexing,
    # but it should at least be visible in the logs.
    descs = []
    for image in images:
        try:
            descs.append(get_image_description(image)[0])
        except Exception as err:
            print(f"Image description failed: {err!r}")
            descs.append("Could not generate image description due to some error")

    if images:
        image_collection.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=[{"image": image_to_bytes(img)} for img in images],
        )

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=10,
    )
    # Whitespace-only text (e.g. a scan OCR could not read) can produce no chunks,
    # and chromadb rejects add() with an empty id list — hence the second check.
    doc_texts = (
        [chunk.page_content for chunk in splitter.create_documents([text])]
        if text
        else []
    )
    if doc_texts:
        text_collection.add(
            ids=[str(i) for i in range(len(doc_texts))], documents=doc_texts
        )
    return client
def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()):
    """OCR the selected PDFs, gather their text (and optionally images) and index it.

    Args:
        docs: the selected PDFs; each is passed directly as the input of
            ``ocrmypdf.ocr`` (presumably file paths collected by ``extract_pdfs``
            — TODO confirm).
        session: per-session state dict; ``"processed"`` is set to True on success.
        include_images: value of the Include/Exclude radio; images are extracted
            only when it equals ``"Include Images"``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A 6-tuple matching the ``embed.click`` outputs: the chromadb client, the
        updated session, a visibility update for the sample-data column, a text
        preview (at most 2000 characters plus "..." when truncated), up to two
        sample images, and the status HTML.

    Raises:
        gr.Error: if no documents were selected.
    """
    import tempfile

    if not docs:
        raise gr.Error("No documents to process")
    progress(0, "Extracting Images")
    progress(0.25, "Extracting Text")

    want_images = include_images == "Include Images"
    text_parts = []
    images = []
    # Each call gets its own scratch directory. A fixed "ocr.pdf" in the working
    # directory is shared by every concurrent session and can be overwritten by
    # another user's extraction mid-way.
    with tempfile.TemporaryDirectory() as work_dir:
        for idx, doc in enumerate(docs):
            ocr_path = os.path.join(work_dir, f"ocr_{idx}.pdf")
            ocrmypdf.ocr(doc, ocr_path, deskew=True, skip_text=True)
            text_parts.append(clean_text(extract_text(ocr_path)) + "\n\n")
            if want_images:
                images.extend(extract_images([ocr_path]))
        all_text = "".join(text_parts)
        progress(
            0.6, "Generating image descriptions and inserting everything into vectorDB"
        )
        # Indexed while the OCR'd files still exist, in case the objects returned by
        # extract_images are lazily backed by them — TODO confirm.
        vectordb = get_vectordb(all_text, images)
    progress(1, "Completed")
    session["processed"] = True

    preview = all_text if len(all_text) <= 2000 else all_text[:2000] + "..."
    return (
        vectordb,
        session,
        gr.Column(visible=True),  # target output `sample_data` is a gr.Column
        preview,
        images[:2],
        "<h1 style='text-align: center'>Completed</h1>",
    )
# Query-side embedding function. It uses the same model name as get_vectordb so that
# query vectors are comparable with the stored ones.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="multi-qa-mpnet-base-dot-v1"
)


def conversation(
    vectordb_client, msg, num_context, img_context, history, hf_token, model_path
):
    """Answer *msg* from text chunks and image captions retrieved from chromadb.

    Args:
        vectordb_client: chromadb client built by ``get_vectordb`` (None until data
            has been extracted in this session).
        msg: the user's question.
        num_context: number of text chunks to retrieve.
        img_context: number of images to retrieve.
        history: chat history as a list of ``(user, bot)`` pairs.
        hf_token: optional HuggingFace token supplied by the user.
        model_path: optional model repo id; used only together with *hf_token*.

    Returns:
        ``(history + [(msg, response)], retrieved_text_chunks, retrieved_images)``;
        the images are PIL images decoded from the stored base64 bytes.

    Raises:
        gr.Error: if no vector database has been built yet.
    """
    if vectordb_client is None:
        raise gr.Error("Please extract data first")

    hf_token = (hf_token or "").strip()
    model_path = (model_path or "").strip()
    if hf_token and model_path:
        repo_id, api_token = model_path, hf_token
    else:
        # Default model with the Space's own token. Use None (not the string
        # "None") when the variable is unset so the endpoint does not try to
        # authenticate with a bogus token.
        repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
        api_token = os.getenv("P_HF_TOKEN")
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        temperature=0.4,
        max_new_tokens=800,
        huggingfacehub_api_token=api_token,
    )

    text_collection = vectordb_client.get_collection(
        "text_db", embedding_function=sentence_transformer_ef
    )
    image_collection = vectordb_client.get_collection(
        "image_db", embedding_function=sentence_transformer_ef
    )

    # Never request more results than a collection holds (chromadb raises or warns
    # depending on version), and pass an int because slider values may be floats.
    n_text = min(int(num_context), text_collection.count())
    results = []
    if n_text > 0:
        results = text_collection.query(
            query_texts=[msg], include=["documents"], n_results=n_text
        )["documents"][0]

    n_images = min(int(img_context), image_collection.count())
    retrieved_images = []
    img_desc = "No Images Are Provided"
    if n_images > 0:
        similar_images = image_collection.query(
            query_texts=[msg],
            include=["metadatas", "documents"],
            n_results=n_images,
        )
        retrieved_images = [
            Image.open(io.BytesIO(base64.b64decode(meta["image"])))
            for meta in similar_images["metadatas"][0]
        ]
        if retrieved_images:
            img_desc = "\n".join(similar_images["documents"][0])

    template = (
        "\n"
        "Context:\n"
        "{context}\n"
        "Included Images:\n"
        "{images}\n"
        "Question:\n"
        "{question}\n"
        "Answer:\n"
    )
    # The template uses all three variables, so all three are declared.
    prompt = PromptTemplate(
        template=template, input_variables=["context", "images", "question"]
    )
    context = "\n\n".join(results)
    # invoke() is the Runnable entry point; calling the LLM object directly is deprecated.
    response = llm.invoke(prompt.format(context=context, question=msg, images=img_desc))
    return list(history or []) + [(msg, response)], results, retrieved_images
def check_validity_and_llm(session_states):
    """Switch to the Chat tab once data extraction has completed.

    Args:
        session_states: per-session dict; ``extract_data_from_pdfs`` sets its
            ``"processed"`` key to True on success.

    Returns:
        ``gr.Tabs(selected=2)``, which selects the Chat tab.

    Raises:
        gr.Error: if extraction has not completed for this session.
    """
    if not (session_states or {}).get("processed"):
        raise gr.Error("Please extract data first")
    return gr.Tabs(selected=2)
def get_stats(vectordb):
    """Summarise how many records the vector database holds.

    Args:
        vectordb: chromadb client as returned by ``get_vectordb``; may be None
            before any extraction has run.

    Returns:
        A 3-tuple of strings: a newline-separated summary (text chunk count and
        image count) followed by two empty strings.
        NOTE(review): the two empty strings preserve the original return shape;
        this function is not wired to any event in this file.
    """

    def _count(name):
        # Records in collection *name*, or 0 when there is no client/collection yet.
        if vectordb is None:
            return 0
        try:
            return vectordb.get_collection(name).count()
        except Exception:  # exception type for a missing collection varies by chromadb version
            return 0

    text_data = [f"Chunks: {_count('text_db')}", f"Images: {_count('image_db')}"]
    return "\n".join(text_data), "", ""
with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
    # Per-session state shared by the three tabs.
    vectordb = gr.State()  # chromadb client returned by extract_data_from_pdfs
    doc_collection = gr.State(value=[])  # filled by extract_pdfs from the upload tab
    session_states = gr.State(value={})  # gets "processed": True after extraction
    references = gr.State(value=[])  # text chunks retrieved for the latest answer

    gr.Markdown(
        """<h2><center>Multimodal PDF Chatbot</center></h2>
<h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
    )
    gr.Markdown(
        """<center><h3><b>Note: </b> This application leverages advanced Retrieval-Augmented Generation (RAG) techniques to provide context-aware responses from your PDF documents</h3></center><br>
<center>Utilizing multimodal capabilities, this chatbot can interpret and answer queries based on both textual and visual information within your PDFs.</center>"""
    )
    gr.Markdown(
        """
<center><b>Warning: </b> Extracting text and images from your document and generating embeddings may take some time due to the use of OCR and multimodal LLMs for image description</center>
"""
    )

    with gr.Tabs() as tabs:
        with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
            with gr.Row():
                with gr.Column():
                    documents = gr.File(
                        file_count="multiple",
                        file_types=["pdf"],
                        interactive=True,
                        label="Upload your PDF file/s",
                    )
                    pdf_btn = gr.Button(value="Next", elem_id="button1")

        with gr.TabItem("Extract Data", id=1) as preprocess:
            with gr.Row():
                with gr.Column():
                    back_p1 = gr.Button(value="Back")
                with gr.Column():
                    embed = gr.Button(value="Extract Data")
                with gr.Column():
                    next_p1 = gr.Button(value="Next")
            with gr.Row():
                include_images = gr.Radio(
                    ["Include Images", "Exclude Images"],
                    value="Include Images",
                    label="Include/ Exclude Images",
                    interactive=True,
                )
            with gr.Row(equal_height=True, variant="panel") as row:
                selected = gr.Dataframe(
                    interactive=False,
                    col_count=(1, "fixed"),
                    headers=["Selected Files"],
                )
                prog = gr.HTML(
                    value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs</h1>"
                )
            with gr.Accordion("See Parts of Extracted Data", open=False):
                with gr.Column(visible=True) as sample_data:
                    with gr.Row():
                        with gr.Column():
                            ext_text = gr.Textbox(
                                label="Sample Extracted Text", lines=15
                            )
                        with gr.Column():
                            images = gr.Gallery(
                                label="Sample Extracted Images", columns=1, rows=2
                            )

        with gr.TabItem("Chat", id=2) as chat_tab:
            with gr.Accordion("Config (Advanced) (Optional)", open=False):
                with gr.Row(variant="panel", equal_height=True):
                    choice = gr.Radio(
                        ["chromaDB"],
                        value="chromaDB",
                        label="Vector Database",
                        interactive=True,
                    )
                    with gr.Accordion("Use your own model (optional)", open=False):
                        # Masked so a user's token is not shown in plain text.
                        hf_token = gr.Textbox(
                            label="HuggingFace Token", interactive=True, type="password"
                        )
                        model_path = gr.Textbox(label="Model Path", interactive=True)
                with gr.Row(variant="panel", equal_height=True):
                    num_context = gr.Slider(
                        label="Number of text context elements",
                        minimum=1,
                        maximum=20,
                        step=1,
                        interactive=True,
                        value=3,
                    )
                    img_context = gr.Slider(
                        label="Number of image context elements",
                        minimum=1,
                        maximum=10,
                        step=1,
                        interactive=True,
                        value=2,
                    )
            with gr.Row():
                with gr.Column():
                    # gr.Gallery's first positional parameter is `value`, so the
                    # title has to be passed as `label`.
                    ret_images = gr.Gallery(label="Similar Images", columns=1, rows=2)
                with gr.Column():
                    chatbot = gr.Chatbot(height=400)
                    with gr.Accordion("Text References", open=False):
                        # Re-rendered whenever `references` changes: one textbox
                        # per retrieved chunk.
                        @gr.render(inputs=references)
                        def gen_refs(references):
                            for i, ref in enumerate(references or []):
                                gr.Textbox(
                                    label=f"Reference-{i+1}", value=ref, lines=3
                                )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your question here (e.g. 'What is this document about?')",
                    interactive=True,
                    container=True,
                )
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

    # Event wiring (must happen inside the Blocks context).
    pdf_btn.click(
        fn=extract_pdfs,
        inputs=[documents, doc_collection],
        outputs=[doc_collection, tabs, selected],
    )
    embed.click(
        extract_data_from_pdfs,
        inputs=[doc_collection, session_states, include_images],
        outputs=[
            vectordb,
            session_states,
            sample_data,
            ext_text,
            images,
            prog,
        ],
    )
    # The button and pressing Enter in the textbox run the same chat step.
    chat_inputs = [vectordb, msg, num_context, img_context, chatbot, hf_token, model_path]
    chat_outputs = [chatbot, references, ret_images]
    submit_btn.click(conversation, chat_inputs, chat_outputs)
    msg.submit(conversation, chat_inputs, chat_outputs)
    back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
    next_p1.click(check_validity_and_llm, session_states, tabs)


if __name__ == "__main__":
    demo.launch()