# src/pdfchatbot.py
import yaml
import fitz
import torch
import gradio as gr
from PIL import Image
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import spaces
from langchain_text_splitters import RecursiveCharacterTextSplitter
class PDFChatBot:
def __init__(self, config_path="config.yaml"):
"""
Initialize the PDFChatBot instance.
Parameters:
            config_path (str): Path to the configuration file (default is "config.yaml"; currently unused).
"""
self.processed = False
self.page = 0
self.chat_history = []
# Initialize other attributes to None
self.prompt = None
self.documents = None
self.embeddings = None
self.vectordb = None
self.tokenizer = None
self.model = None
self.pipeline = None
self.chain = None
        self.chunk_size = 2048
        self.overlap_percentage = 50
        self.max_chunks_in_context = 2
        self.current_context = None
        self.model_temperature = 0.5
        self.format_separator = """\n\n--\n\n"""
        self.pipe = None
def load_embeddings(self):
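        """
        Load the sentence-transformers embedding model used to vectorize document chunks.
        """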
self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
print("Embedding model loaded")
def load_vectordb(self):
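        """
        Split the loaded documents into overlapping chunks and build an in-memory Chroma vector store.
        The character overlap is derived from overlap_percentage, e.g. 50% of a 2048-character chunk is 1024 characters.
        """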
overlap = int((self.overlap_percentage/100) * self.chunk_size)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=overlap,
length_function=len,
add_start_index=True,
)
docs = text_splitter.split_documents(self.documents)
self.vectordb = Chroma.from_documents(docs, self.embeddings)
print("Vector store created")
@spaces.GPU(duration=120)
def load_tokenizer(self):
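        """
        Load the tokenizer for the Llama-3 model. Not called in the current flow:
        the pipeline created in create_organic_pipeline loads its own tokenizer.
        """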
self.tokenizer = AutoTokenizer.from_pretrained("gradientai/Llama-3-8B-Instruct-Gradient-1048k")
@spaces.GPU(duration=120)
def create_organic_pipeline(self):
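        """
        Create the Hugging Face text-generation pipeline for the Llama-3 model on the GPU.
        """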
self.pipe = pipeline(
"text-generation",
model="gradientai/Llama-3-8B-Instruct-Gradient-1048k",
model_kwargs={"torch_dtype": torch.bfloat16},
device="cuda",
)
print("Model pipeline loaded")
def get_organic_context(self, query):
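        """
        Retrieve the most relevant chunks for the query from the vector store and join them
        with the format separator into a single context string stored on the instance.
        Parameters:
            query (str): The user's question.
        """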
documents = self.vectordb.similarity_search_with_relevance_scores(query, k=self.max_chunks_in_context)
        context = self.format_separator.join([doc.page_content for doc, score in documents])
self.current_context = context
print("Context Ready")
print(self.current_context)
@spaces.GPU(duration=120)
def create_organic_response(self, history, query):
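        """
        Build a chat prompt from the retrieved context and the query, then generate an answer.
        Parameters:
            history: Unused; kept for interface compatibility.
            query (str): The user's question.
        Returns:
            str: The generated answer with the prompt prefix stripped.
        """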
self.get_organic_context(query)
"""
pipe = pipeline(
"text-generation",
model="gradientai/Llama-3-8B-Instruct-Gradient-1048k",
model_kwargs={"torch_dtype": torch.bfloat16},
device="cuda",
)
"""
messages = [
{"role": "system", "content": "From the the contained given below, answer the question of user \n " + self.current_context},
{"role": "user", "content": query},
]
prompt = self.pipe.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
        outputs = self.pipe(
            prompt,
            max_new_tokens=65536,
            do_sample=True,
            # Use the temperature configured from the UI instead of a hardcoded value.
            temperature=self.model_temperature,
            top_p=0.9,
        )
print(outputs)
return outputs[0]["generated_text"][len(prompt):]
def process_file(self, file):
"""
Process the uploaded PDF file and initialize necessary components: Tokenizer, VectorDB and LLM.
Parameters:
file (FileStorage): The uploaded PDF file.
"""
self.documents = PyPDFLoader(file.name).load()
self.load_embeddings()
self.load_vectordb()
self.create_organic_pipeline()
@spaces.GPU(duration=120)
def generate_response(self, history, query, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context):
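        """
        Apply the UI settings, process the PDF on first use, and append the model's
        answer to the last chat-history entry.
        Parameters:
            history (list): List of [user, bot] message pairs.
            query (str): The user's question.
            file: The uploaded PDF file.
            chunk_size (int): Chunk size in characters for document splitting.
            chunk_overlap_percentage (int): Chunk overlap as a percentage of chunk_size.
            model_temperature (float): Sampling temperature for generation.
            max_chunks_in_context (int): Number of retrieved chunks placed in the prompt.
        Returns:
            tuple: (updated chat history, empty string to clear the query box).
        Note: the chunking settings only take effect the first time the file is processed.
        """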
self.chunk_size = chunk_size
self.overlap_percentage = chunk_overlap_percentage
        self.model_temperature = model_temperature
self.max_chunks_in_context = max_chunks_in_context
if not query:
raise gr.Error(message='Submit a question')
if not file:
raise gr.Error(message='Upload a PDF')
if not self.processed:
self.process_file(file)
self.processed = True
        result = self.create_organic_response(history="", query=query)
        # Without a generator (yield), a char-by-char loop cannot stream; append the full answer at once.
        history[-1][-1] += result
        return history, ""
def render_file(self, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context):
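        """
        Store the UI settings and render the current page of the uploaded PDF as an image.
        Parameters:
            file: The uploaded PDF file.
            chunk_size (int): Chunk size in characters for document splitting.
            chunk_overlap_percentage (int): Chunk overlap as a percentage of chunk_size.
            model_temperature (float): Sampling temperature for generation.
            max_chunks_in_context (int): Number of retrieved chunks placed in the prompt.
        Returns:
            PIL.Image.Image: The rendered page.
        """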
doc = fitz.open(file.name)
page = doc[self.page]
self.chunk_size = chunk_size
self.overlap_percentage = chunk_overlap_percentage
        self.model_temperature = model_temperature
self.max_chunks_in_context = max_chunks_in_context
        # Render the page at 300 DPI (PDF points are 72 per inch).
        pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
return image
def add_text(self, history, text):
"""
Add user-entered text to the chat history.
Parameters:
            history (list): List of [user, bot] message pairs.
text (str): User-entered text.
Returns:
list: Updated chat history.
"""
if not text:
raise gr.Error('Enter text')
        # Append a mutable [user, bot] pair so generate_response can extend the reply in place;
        # a tuple here would raise TypeError on item assignment.
        history.append([text, ''])
return history
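# --- Usage sketch (not part of the original file) ---
# A minimal example of how this class might be wired into a Gradio Blocks UI.
# The component names, slider ranges, and event wiring below are assumptions;
# the actual app that ships with this Space may differ.
if __name__ == "__main__":
    bot = PDFChatBot()
    with gr.Blocks() as demo:
        with gr.Row():
            chat = gr.Chatbot()
            page_image = gr.Image(label="PDF page")
        query = gr.Textbox(placeholder="Ask a question about the PDF")
        pdf = gr.File(label="Upload PDF", file_types=[".pdf"])
        chunk_size = gr.Slider(256, 4096, value=2048, label="Chunk size")
        overlap = gr.Slider(0, 90, value=50, label="Chunk overlap (%)")
        temperature = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        max_chunks = gr.Slider(1, 8, value=2, label="Max chunks in context")
        # Render the first page whenever a PDF is uploaded.
        pdf.upload(
            bot.render_file,
            inputs=[pdf, chunk_size, overlap, temperature, max_chunks],
            outputs=page_image,
        )
        # Add the user's question to the history, then generate the answer.
        query.submit(bot.add_text, inputs=[chat, query], outputs=chat).then(
            bot.generate_response,
            inputs=[chat, query, pdf, chunk_size, overlap, temperature, max_chunks],
            outputs=[chat, query],
        )
    demo.launch()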