# chat_with_pdf / app.py
# Author: AamirAli123
# Commit: c6c42c2 ("Update app.py", verified)
import gradio as gr
import os
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
# from doctr.models import ocr_predictor
# from doctr.io import DocumentFile
from pathlib import Path
import chromadb
# Later Packages
from getpass import getpass
import weasyprint
import matplotlib.pyplot as plt
from langchain.document_loaders import PyPDFDirectoryLoader
# Pull variables from a local .env file (if one exists) into os.environ.
load_dotenv()
# OCR model (doctr) is currently disabled.
# model = ocr_predictor(pretrained = True)

# Credentials taken from the environment. The HF token is not passed
# explicitly anywhere below — presumably HuggingFaceHub reads it from the
# environment itself (TODO confirm).
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# NOTE(review): this reads "OPEN_API_KEY" (not "OPENAI_API_KEY") and the value
# is not used anywhere in this file.
openai_key = os.getenv("OPEN_API_KEY")
# default_persist_directory = './chroma_HF/'

# Hugging Face Hub repo ids offered in the UI. The LLM radio group is created
# with type="index", so the selected index maps back into this list.
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Display labels: each repo id with its "organisation/" prefix removed.
list_llm_simple = list(map(os.path.basename, list_llm))
# Extract text data from a doctr OCR response.
def extract_value_from_response(response):
    """Concatenate every recognised word of a doctr-style OCR result.

    Walks ``response.pages -> page.blocks -> block.lines -> line.words`` in
    order and emits each ``word.value`` prefixed by a single space. A
    non-empty result therefore starts with a space; an empty response gives "".

    Args:
        response: OCR result exposing the nested ``pages`` / ``blocks`` /
            ``lines`` / ``words`` attributes used below.

    Returns:
        str: the space-prefixed concatenation of all word values.
    """
    # str.join builds the result in one linear pass; the previous repeated
    # `value += ...` inside four nested loops can degrade to quadratic time.
    return "".join(
        " " + word.value
        for page in response.pages
        for block in page.blocks
        for line in block.lines
        for word in line.words
    )
# Create a PDF from a URL.
def create_pdf_from_url(url, output_dir="pdfDir", filename="url_pdf.pdf"):
    """Render the web page at *url* to a PDF file and return its path.

    Args:
        url: absolute http(s) URL of the page to render with WeasyPrint.
        output_dir: directory to write into; created if it does not exist.
        filename: output file name inside *output_dir*. An existing file with
            this name is overwritten.

    Returns:
        str: path of the written PDF.

    Raises:
        ValueError: if *url* is not an absolute http(s) URL. The URL comes
            from a user-facing textbox, so other schemes (e.g. ``file://``)
            and bare paths are refused rather than rendering files that live
            on the server.
    """
    from urllib.parse import urlparse

    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(f"Only absolute http(s) URLs are supported, got: {url!r}")

    # write_pdf() with no target returns the PDF as bytes.
    pdf_bytes = weasyprint.HTML(url=url).write_pdf()
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, filename)
    with open(file_path, "wb") as f:
        f.write(pdf_bytes)
    return file_path
# Load PDF documents and split them into chunks.
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load every PDF in *list_file_path* and split its pages into chunks.

    Any number of files is supported; the pages of all files are combined,
    in input order, before splitting.

    Args:
        list_file_path: paths of the PDF files to load.
        chunk_size: maximum chunk size handed to the text splitter.
        chunk_overlap: overlap between consecutive chunks.

    Returns:
        list: the LangChain documents produced by the splitter.
    """
    # NOTE(review): there is no OCR fallback, so image-only PDFs presumably
    # yield few or no chunks.
    pdf_loaders = list(map(PyPDFLoader, list_file_path))
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(all_pages)
# Create an in-memory vector database.
def create_db(splits, collection_name):
    """Embed *splits* and store them in a Chroma collection.

    The store uses a chromadb ``EphemeralClient`` (in-memory, not persisted to
    disk) and the default ``HuggingFaceEmbeddings`` model.

    Args:
        splits: LangChain documents to index.
        collection_name: name of the Chroma collection to write into.

    Returns:
        The LangChain ``Chroma`` vector store.
    """
    embedder = HuggingFaceEmbeddings()
    return Chroma.from_documents(
        documents=splits,
        embedding=embedder,
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )
# Load a vector database.
def load_db():
    """Return a Chroma vector store backed by default HuggingFace embeddings.

    No client, collection name or persist directory is passed, so Chroma's own
    defaults apply (TODO confirm which collection that opens). Not called in
    the visible code.
    """
    return Chroma(embedding_function=HuggingFaceEmbeddings())
# Initialize the LangChain conversational retrieval chain.
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a ConversationalRetrievalChain over *vector_db* using a HF Hub model.

    Args:
        llm_model: Hugging Face Hub repo id (one of ``list_llm``).
        temperature: sampling temperature forwarded in ``model_kwargs``.
        max_tokens: forwarded as ``max_new_tokens`` (ignored for TinyLlama,
            which is always capped at 250).
        top_k: top-k sampling value forwarded in ``model_kwargs``.
        vector_db: LangChain vector store used as the retriever.
        progress: Gradio progress callback.

    Returns:
        A ``ConversationalRetrievalChain`` ("stuff" chain type) with a
        ``ConversationBufferMemory`` keyed on ``chat_history`` that also
        returns the source documents.
    """
    # NOTE(review): no tokenizer is loaded here; this first message is cosmetic.
    progress(0.1, desc="Initializing HF tokenizer...")
    # HuggingFaceHub talks to hosted HF inference endpoints.
    progress(0.5, desc="Initializing HF Hub...")

    # Per-model tweaks layered on top of the common generation settings.
    # Extra options go through model_kwargs; see
    # https://github.com/langchain-ai/langchain/issues/6080 for the related
    # trust_remote_code caveat.
    per_model_overrides = {
        "mistralai/Mixtral-8x7B-Instruct-v0.1": {"load_in_8bit": True},
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0": {"max_new_tokens": 250},
    }
    generation_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
    generation_kwargs.update(per_model_overrides.get(llm_model, {}))
    hub_llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=generation_kwargs)

    progress(0.75, desc="Defining buffer memory...")
    # The chain returns several outputs (answer + source documents), so the
    # memory is told which one to store.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )

    progress(0.8, desc="Defining retrieval chain...")
    doc_retriever = vector_db.as_retriever()
    chain = ConversationalRetrievalChain.from_llm(
        hub_llm,
        retriever=doc_retriever,
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return chain
def _sanitize_collection_name(raw_name):
    """Turn *raw_name* (a file stem) into a Chroma-compatible collection name.

    Characters outside ``[A-Za-z0-9_-]`` (spaces included) become "-". The
    result is truncated to 50 characters and padded with "Z" to at least 3.
    A non-alphanumeric first or last character is replaced by "A" or "Z".
    """
    import re

    name = re.sub(r"[^A-Za-z0-9_-]", "-", raw_name)[:50]
    name = name.ljust(3, "Z")
    # Strings are immutable, so rebuild instead of assigning name[0] / name[-1].
    if not name[0].isalnum():
        name = "A" + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + "Z"
    return name


# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, vector_db, url, progress = gr.Progress()):
    """Build an in-memory vector store from uploaded PDFs or from a URL.

    A non-blank *url* takes precedence: the page is rendered to a PDF and only
    that file is indexed. Otherwise every non-None uploaded file is indexed.

    Args:
        list_file_obj: items from ``gr.Files``. Each is either a path string
            or an object whose ``.name`` is the path, depending on the Gradio
            version. May be None when nothing was uploaded.
        chunk_size: maximum chunk size for the text splitter.
        chunk_overlap: overlap between consecutive chunks.
        vector_db: the current store, returned unchanged when there is
            nothing to index.
        url: optional web page URL from the textbox.
        progress: Gradio progress callback.

    Returns:
        tuple: (vector store, collection name, Textbox update, status message).
        The Textbox update clears the URL box once indexing succeeds.
    """
    if url and url.strip():
        list_file_path = [create_pdf_from_url(url.strip())]
    else:
        list_file_path = [
            getattr(file_obj, "name", file_obj)
            for file_obj in (list_file_obj or [])
            if file_obj is not None
        ]
    if not list_file_path:
        return vector_db, None, gr.update(), "Please upload a PDF or enter a URL first."

    progress(0.1, desc="Creating collection name...")
    # The collection is named after the first document.
    collection_name = _sanitize_collection_name(Path(list_file_path[0]).stem)
    print('Collection name: ', collection_name)

    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    if not doc_splits:
        # E.g. scanned/image-only PDFs; avoid handing Chroma an empty batch.
        return vector_db, collection_name, gr.update(value=""), "No extractable text found in the document(s)."

    progress(0.7, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, gr.update(value = ""), "Complete!"
def re_initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Build a fresh QA chain for the LLM selected in the radio group.

    Args:
        llm_option: index into ``list_llm`` (the radio uses ``type="index"``),
            or None when nothing is selected.
        llm_temperature: sampling temperature.
        max_tokens: maximum number of new tokens.
        top_k: top-k sampling value.
        vector_db: vector store built from the submitted documents, or None
            if no document has been submitted yet.

    Returns:
        The chain built by ``initialize_llmchain``.

    Raises:
        gr.Error: when no model is selected or no document has been indexed.
            Gradio shows this message in the UI, instead of an opaque
            AttributeError raised from inside the chain builder.
    """
    if llm_option is None:
        raise gr.Error("Please select an LLM model.")
    if vector_db is None:
        raise gr.Error("Please submit a document before selecting an LLM model.")
    llm_name = list_llm[llm_option]
    print("llm_name: ",llm_name)
    return initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
def format_chat_history(message, chat_history):
    """Flatten (user, bot) turn pairs into "User: ..." / "Assistant: ..." lines.

    Args:
        message: the current user message (not used in the output).
        chat_history: iterable of 2-item (user, bot) pairs.

    Returns:
        list[str]: two lines per turn, user line first, in turn order.
    """
    transcript = []
    for user_turn, bot_turn in chat_history:
        transcript.extend((f"User: {user_turn}", f"Assistant: {bot_turn}"))
    return transcript
def conversation(qa_chain, message, history, llm_option):
    """Answer *message* with the QA chain and append the turn to *history*.

    Args:
        qa_chain: chain from ``initialize_llmchain``, or None while no LLM
            has been selected.
        message: the user's question from the textbox.
        history: Chatbot history as (user, bot) pairs; may be None.
        llm_option: selected LLM index. Unused, but kept because the UI
            events pass it.

    Returns:
        tuple: (qa_chain, Textbox update clearing the input, new history list).
    """
    history = list(history or [])
    if not message or not message.strip():
        # Nothing to ask: leave the history untouched and don't call the LLM.
        return qa_chain, gr.update(value=""), history
    if qa_chain is None:
        notice = "Please submit a document and select an LLM model first."
        return qa_chain, gr.update(value=""), history + [(message, notice)]

    formatted_chat_history = format_chat_history(message, history)
    # Generate response using QA chain
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # The raw completion may include the prompt text up to "Helpful Answer:"
    # (presumably echoed by the endpoint); keep only what follows the last marker.
    marker = "Helpful Answer:"
    if marker in response_answer:
        response_answer = response_answer.split(marker)[-1].strip()
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value = ""), new_history
def upload_file(file_obj):
    """Return the filesystem path of each uploaded file.

    Args:
        file_obj: iterable of uploads. Each item is an object whose ``.name``
            is the path, or a plain path string, depending on the Gradio
            version. None is treated as "no uploads".

    Returns:
        list: one path per upload, in input order.
    """
    # The previous loop read `file_obj.name` (the container) on every
    # iteration instead of each item's own name.
    return [getattr(item, "name", item) for item in (file_obj or [])]
def demo():
    """Build the PDF-chat Gradio interface and launch it.

    Layout: the left column holds PDF upload or URL entry, LLM choice and the
    "Submit File" button. The right column holds the chatbot with its
    generation settings. NOTE(review): the original indentation was lost, so
    this container nesting is reconstructed; confirm against the deployed app.

    Launches with ``share=True`` (public Gradio link) and ``debug=True``.
    """
    with gr.Blocks(theme = "base") as demo:
        # Per-session state passed between the event handlers below.
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.Markdown(
'''
<div style="text-align:center;">
<span style="font-size:3em; font-weight:bold;">PDF Document Chatbot</span>
</div>
''')
        with gr.Row():
            with gr.Row():
                with gr.Column():
                    document = gr.Files(file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
                    with gr.Row():
                        gr.Markdown(
'''
<div style="text-align:center;">
<span style="font-size:2em; font-weight:bold;">OR</span>
</div>
''')
                    with gr.Row():
                        url = gr.Textbox(placeholder = "Enter your URL Here")
                    # Hidden controls: Chroma is the only store, and splitter
                    # settings stay at their defaults.
                    with gr.Row():
                        db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database", visible = False)
                    with gr.Accordion("Advanced options - Document text splitter", open=False, visible = False):
                        with gr.Row():
                            slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True, visible = False)
                        with gr.Row():
                            slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True, visible = False)
                    llm_btn = gr.Radio(list_llm_simple, label = "LLM models", type = "index", info = "Choose your LLM model")
                    db_progres = gr.Textbox(label="Vector database initialization", value="None")
                    with gr.Row():
                        submit_file = gr.Button("Submit File")
            with gr.Row():
                with gr.Column():
                    chatbot = gr.Chatbot()
                    msg = gr.Textbox(placeholder = "Type Your Message")
                    with gr.Accordion("Advanced options - LLM model", open = False):
                        with gr.Row():
                            slider_temperature = gr.Slider(minimum = 0.0, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                        with gr.Row():
                            slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                        with gr.Row():
                            slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
                    with gr.Row():
                        submit_btn = gr.Button("Submit")

        # Preprocessing events
        db_event = submit_file.click(
            initialize_database,
            inputs=[document, slider_chunk_size, slider_chunk_overlap, vector_db, url],
            outputs=[vector_db, collection_name, url, db_progres],
        )

        def _refresh_chain(llm_option, temperature, max_tokens, top_k, db, current_chain):
            """Rebuild the QA chain so it retrieves from the newly indexed store.

            Without this, a chain created before a new upload keeps querying
            the previous vector store. Keeps *current_chain* when no LLM is
            selected yet or no store exists.
            """
            if llm_option is None or db is None:
                return current_chain
            return re_initialize_LLM(llm_option, temperature, max_tokens, top_k, db)

        db_event.then(
            _refresh_chain,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, qa_chain],
            outputs=[qa_chain],
        )
        llm_btn.change(
            re_initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain],
        )
        # Enter in the textbox and the "Submit" button run the same handler.
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot, llm_btn],
            outputs=[qa_chain, msg, chatbot],
            queue=False,
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot, llm_btn],
            outputs=[qa_chain, msg, chatbot],
            queue=False,
        )
    demo.queue().launch(share = True, debug = True)
# Build and launch the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    demo()