Frag-dein-PDF / run.py
AFischer1985's picture
Update run.py
95ef3f6 verified
raw
history blame
7.57 kB
###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 10th, 2024
##########################################################################################
import os
import chromadb
from datetime import datetime
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions
from transformers import AutoTokenizer, AutoModel
import torch
# Load the German Jina v2 embedding model in bfloat16; trust_remote_code lets
# the hub repository supply its own model code.
jina = AutoModel.from_pretrained(
    'jinaai/jina-embeddings-v2-base-de',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")  # optional local copy
# Use the GPU when one is available; replace with device = 'cpu' to force CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
jina.to(device)  # 'cuda' selects the default GPU (cuda:0)
print(device)
class JinaEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function that delegates to the module-level `jina` model."""

    def __call__(self, input: Documents) -> Embeddings:
        """Encode *input* with ``jina.encode`` and return the result via ``.tolist()``."""
        # jina.encode could also be given max_length=2048 here.
        vectors = jina.encode(input)
        return vectors.tolist()
# Prefer the developer's local Chroma database when it exists; otherwise fall
# back to the path used in the hosted deployment.
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
onPrem = os.path.exists(dbPath)
if not onPrem:
    dbPath = "/home/user/app/db"
# onPrem = True  # uncomment to override automatic detection
# NOTE(review): overriding onPrem at this point leaves dbPath unchanged,
# because the path has already been chosen above.
print(dbPath)

# Open the persistent Chroma store and log a few diagnostics.
path = dbPath
client = chromadb.PersistentClient(path=path)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

# Embedding function instance; embeddingModel is the name passed to
# create_collection when documents are indexed.
jina_ef = JinaEmbeddingFunction()
embeddingModel = jina_ef
from huggingface_hub import InferenceClient
import gradio as gr
import json
# Hosted LLM used to generate answers.
# Alternative model: "mistralai/Mistral-7B-Instruct-v0.1"
inferenceClient = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(message, history):
    """Wrap *message* in ``[INST] ... [/INST]`` tags, prefixed with ``<s>``.

    Args:
        message: The full user turn (instructions plus question).
        history: Chat history; accepted for the callback interface but not
            used, so earlier turns are not included in the prompt.

    Returns:
        The prompt string for the instruction-tuned model.
    """
    return f"<s>[INST] {message} [/INST]"
from pypdf import PdfReader
import ocrmypdf
def _extract_pages(reader):
    """Return the extracted text of every page of *reader* that yields non-empty text."""
    pages = []
    for page in reader.pages:
        text = page.extract_text()
        if text:  # skip pages without a text layer (empty or None)
            pages.append(text)
    return pages


def convertPDF(pdf_file, allow_ocr=False):
    """Extract per-page text and basic document metadata from a PDF.

    Args:
        pdf_file: Path of the PDF to read.
        allow_ocr: If True and the PDF contains images but yields fewer than
            1000 characters of text (e.g. a scan), run ocrmypdf on it, write
            "<name>_ocr.pdf" next to the input, and extract text from that copy.

    Returns:
        A tuple ``(page_list, full_text, metadata)``:
        ``page_list`` holds the text of each page with extractable text,
        ``full_text`` is those pages joined by blank lines, and ``metadata``
        is a dict with author/creator/producer/subject/title (None when the
        PDF has no document-information dictionary), ``image_count`` (images
        in the original PDF), ``page_count`` (total pages read) and
        ``char_count`` (length of ``full_text``).
    """
    reader = PdfReader(pdf_file)
    page_list = _extract_pages(reader)
    full_text = "\n\n".join(page_list)

    # Images in the original file: reported in the metadata and used by the
    # OCR heuristic below.
    image_count = sum(len(page.images) for page in reader.pages)

    if allow_ocr:
        print(f"{image_count} Images")
        # Images present but hardly any text: probably scanned pages, so OCR.
        # The text must be extracted *before* this check for it to mean anything.
        if image_count > 0 and len(full_text) < 1000:
            root, _ext = os.path.splitext(pdf_file)  # handles ".PDF" and dots in dir names
            out_pdf_file = root + "_ocr.pdf"
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            page_list = _extract_pages(reader)
            full_text = "\n\n".join(page_list)

    print(f"{len(page_list)} Pages")

    # reader.metadata may be None when the PDF carries no document-info dictionary.
    info = reader.metadata
    metadata = {
        "author": info.author if info else None,
        "creator": info.creator if info else None,
        "producer": info.producer if info else None,
        "subject": info.subject if info else None,
        "title": info.title if info else None,
        "image_count": image_count,
        "page_count": len(reader.pages),
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata
def split_with_overlap(text, chunk_size=3500, overlap=700):
    """Split *text* into overlapping chunks of at most *chunk_size* characters.

    Consecutive chunks start ``chunk_size - overlap`` characters apart (at
    least 1), so neighbouring chunks share up to *overlap* characters.
    Splitting stops as soon as a chunk reaches the end of *text*. Without
    that stop, a trailing chunk lying entirely inside its predecessor would
    also be emitted, which duplicates content in the index.

    Args:
        text: The string to split.
        chunk_size: Maximum length of each chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks.

    Returns:
        List of chunks; empty when *text* is empty.

    Raises:
        ValueError: If *chunk_size* is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    step = max(1, chunk_size - overlap)
    chunks = []
    for start in range(0, len(text), step):
        end = start + chunk_size
        chunks.append(text[start:end])  # slicing clips at len(text)
        if end >= len(text):
            break  # this chunk already covers the tail of the text
    return chunks
def add_doc(path):
    """Index a PDF into a freshly (re)created Chroma collection named "test".

    The PDF's page texts are joined with blank lines and split into
    overlapping chunks (3500 characters, 700 overlap). The chunks are added
    to the collection in batches of up to 40000. Any existing "test"
    collection in the store at "output/general_knowledge" is deleted first,
    so only the most recently indexed PDF is searchable.

    Args:
        path: Path of the uploaded file; must end in ".pdf" (case-insensitive).

    Returns:
        The newly created chromadb collection.

    Raises:
        gr.Error: If *path* does not name a PDF.
    """
    print("def add_doc!")
    print(path)
    if not str(path).lower().endswith(".pdf"):
        # Without a PDF there is nothing to index; continuing would fail on
        # the undefined document text.
        raise gr.Error("Only pdfs are accepted!")
    doc = "\n\n".join(convertPDF(path)[0])
    gr.Info("PDF uploaded, start Indexing!")

    client = chromadb.PersistentClient(path="output/general_knowledge")
    dbName = "test"
    # Depending on the chromadb version, list_collections() may yield
    # Collection objects or plain names; compare names rather than matching
    # against the objects' string representation.
    existing = {getattr(c, "name", c) for c in client.list_collections()}
    print(existing)
    if dbName in existing:
        client.delete_collection(name=dbName)
    collection = client.create_collection(
        dbName,
        embedding_function=embeddingModel,
        metadata={"hnsw:space": "cosine"},
    )

    corpus = split_with_overlap(doc, 3500, 700)
    print(len(corpus))
    then = datetime.now()
    batchSize = 40000
    # Ceiling division; round(n / size + 0.5) uses banker's rounding and can
    # add an extra, empty batch.
    nBatches = -(-len(corpus) // batchSize)
    for batchNo, start in enumerate(range(0, len(corpus), batchSize)):
        print("embed batch " + str(batchNo) + " of " + str(nBatches))
        batch = corpus[start:start + batchSize]
        # 1-based IDs; the collection was just created, so they cannot collide.
        ids = [str(start + k + 1) for k in range(len(batch))]
        collection.add(
            documents=batch,
            ids=ids,
            # NOTE(review): fixed date string; presumably meant to be the
            # indexing date — confirm.
            metadatas=[{"date": "2024-10-10"} for _ in batch],
        )
        print("finished batch " + str(batchNo) + " of " + str(nBatches))
    now = datetime.now()
    gr.Info("Indexing complete!")
    # Observed: sentence-level splitting used too much GPU memory; chunk-level
    # indexing took about 0:00:10.4.
    print(now - then)
    return collection
#split_with_overlap("test me if you can",2,1)
import gradio as gr
import re
def multimodalResponse(message, history, dropdown):
    """Answer a question about the uploaded PDF, streaming the model's reply.

    If a file is attached, it is indexed first (replacing the previous
    document). Otherwise the collection built from the last upload is reused.
    The best-matching chunk is inserted into the prompt as context.

    Args:
        message: Gradio multimodal message with keys "text" (the question) and
            "files" (uploaded files; the first one is passed to add_doc).
        history: Chat history, passed on to format_prompt.
        dropdown: Selected "Retrieval Version"; currently unused.

    Yields:
        The accumulated answer text after each generated token.

    Raises:
        gr.Error: If no file is attached and no PDF has been indexed yet.
    """
    print("def multimodal response!")
    query = message["text"]
    if len(message["files"]) > 0:
        # A new upload replaces the previously indexed document.
        collection = add_doc(message["files"][0])
    else:
        # Follow-up question: reuse the collection from the last upload rather
        # than treating the question text as a file path.
        store = chromadb.PersistentClient(path="output/general_knowledge")
        names = {getattr(c, "name", c) for c in store.list_collections()}
        print(names)
        if "test" not in names:
            raise gr.Error("Please upload a PDF first.")
        collection = store.get_collection(name="test", embedding_function=embeddingModel)

    results = collection.query(query_texts=[query], n_results=1)
    print(str(results))
    # query() returns one list of documents per query text; only one was sent.
    retrieved = (results.get("documents") or [[]])[0]
    context = "\n\n".join(
        f"<context {n}>\n{text}\n</context {n}>"
        for n, text in enumerate(retrieved, start=1)
    )

    generate_kwargs = dict(
        temperature=float(0.9),
        max_new_tokens=5000,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    system = (
        "Given the following conversation, relevant context, and a follow up question, "
        "reply with an answer to the current question the user is asking. "
        "Return only your response to the question given the above information "
        "following the users instructions as needed.\n\nContext:\n"
        + context
    )
    print(system)
    formatted_prompt = format_prompt(system + "\n" + query, history)
    # Generation must go through the Hugging Face inference client; the Chroma
    # client has no text_generation().
    stream = inferenceClient.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    yield output
# Chat UI: multimodal input (text plus file upload) handled by multimodalResponse;
# the dropdown value is passed to it as the third argument.
i = gr.ChatInterface(
    multimodalResponse,
    title="Frag dein PDF",
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(
            info="select retrieval version",
            choices=["1", "2", "3"],
            # Single-select dropdown: the default must be one of the choices,
            # not a list containing it.
            value="1",
            label="Retrieval Version",
        )
    ],
)
i.launch()  # allowed_paths=["."]