Frag-dein-PDF / run.py
AFischer1985's picture
Update run.py
ed60378 verified
raw
history blame
16.2 kB
###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 22nd, 2024
##########################################################################################
import os
import torch
from transformers import AutoTokenizer, AutoModel # chromaDB
from datetime import datetime, date #add_doc,
import chromadb #chromaDB
from chromadb import Documents, EmbeddingFunction, Embeddings #chromaDB
from chromadb.utils import embedding_functions #chromaDB
import ocrmypdf #convertPDF
from pypdf import PdfReader #convertPDF
import re #format_prompt
import gradio as gr # multimodal_response
from huggingface_hub import InferenceClient #multimodal_response
#---------------------------------------------------
# Specify models for text generation and embeddings
#---------------------------------------------------
# Instruction-tuned LLM queried through huggingface_hub's InferenceClient (see multimodal_response).
myModel = "mistralai/Mixtral-8x7b-instruct-v0.1"
# Pick the first GPU when one is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
# Embedding model used by JinaEmbeddingFunction for ChromaDB (weights loaded in bfloat16).
jina = AutoModel.from_pretrained(
    'jinaai/jina-embeddings-v2-base-de',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
jina.to(device)
print(device)
#-----------------
# ChromaDB-client
#-----------------
class JinaEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function that embeds documents with the module-level `jina` model."""

    def __call__(self, input: Documents) -> Embeddings:
        # The parameter keeps the name `input` used by chromadb's EmbeddingFunction interface.
        vectors = jina.encode(input)  # a max_length (e.g. 2048) could be passed here
        return vectors.tolist()
# Use the developer's local database directory if it exists, otherwise the path inside the deployed app.
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
onPrem = os.path.exists(dbPath)
if not onPrem:
    dbPath = "/home/user/app/db/"
print(dbPath)
client = chromadb.PersistentClient(path=dbPath)
# Startup diagnostics: heartbeat, chromadb version and the collections that already exist.
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())
# One shared embedding-function instance, passed to every collection.
embeddingModel = jina_ef = JinaEmbeddingFunction()
# (date, session) entries of sessions seen so far; "0" is the fallback session id.
databases = [(date.today(), "0")]
#---------------------------------------------------------------------
# Function for formatting single message according to prompt template
#---------------------------------------------------------------------
def format_prompt0(message, history):
    """Wrap a single message in the Mistral instruction template.

    `history` is ignored (kept for signature compatibility with format_prompt);
    the result is "<s>[INST] {message} [/INST]".
    """
    template = "<s>[INST] {} [/INST]"
    return template.format(message)
#-------------------------------------------------------------------------
# Function for formatting multiturn-dialogue according to prompt template
#-------------------------------------------------------------------------
def format_prompt(message, history, system=None, RAGAddon=None, system2=None, zeichenlimit=None,historylimit=4, removeHTML=False):
    """Build a Mistral/Mixtral-style multi-turn prompt string.

    Parameters
    ----------
    message : str or None
        Current user message, appended as the last [INST] block.
    history : list of (user_message, bot_response) pairs, or None
        Previous turns; only the last `historylimit` pairs are used.
    system : str or None
        System instruction, rendered as its own leading [INST] block.
    RAGAddon : str or None
        Extra (retrieval) context appended to `system`; also used when `system` is None.
    system2 : str or None
        Raw text appended after the current message.
    zeichenlimit : int or None
        Maximum number of characters kept per message/response (None = unlimited).
    historylimit : int
        Number of most recent turns to include; 0 or less includes no history.
    removeHTML : bool
        If true, every HTML tag in bot responses is replaced by a newline.

    Returns
    -------
    str
        The assembled prompt, starting with "<s>".
    """
    template0 = " [INST] {system} [/INST] </s>"  # alternative: " [INST]{system}\n [/INST] </s>"
    template1 = " [INST] {message} [/INST]"
    template2 = " {response}</s>"
    # Sources block appended to answers by multimodal_response ("\n\n<br><details open>...</details>");
    # the plain "\n\n<details>...</details>" form is matched as well.
    rag_details = re.compile(r"\n\n(?:<br>)?<details[^>]*>.*?</details>", re.DOTALL)
    parts = []
    if RAGAddon is not None:
        system = (system or "") + RAGAddon
    if system is not None:
        parts.append(template0.format(system=system))
    if history is not None and historylimit > 0:  # history[-0:] would be the whole history
        for user_message, bot_response in history[-historylimit:]:
            if user_message is None:
                user_message = ""
            if bot_response is None:
                bot_response = ""
            # Do not feed previously retrieved context back to the model.
            bot_response = rag_details.sub("", bot_response)
            if removeHTML:
                # Removes HTML components in general (may interfere with markdown rendering).
                bot_response = re.sub("<(.*?)>", "\n", bot_response)
            # Slicing with zeichenlimit=None keeps the full text.
            parts.append(template1.format(message=user_message[:zeichenlimit]))
            parts.append(template2.format(response=bot_response[:zeichenlimit]))
    if message is not None:
        parts.append(template1.format(message=message[:zeichenlimit]))
    if system2 is not None:
        parts.append(system2)
    return "<s>" + "".join(parts)
#--------------------------------------------
# Function for converting pdf-files to text
#--------------------------------------------
def convertPDF(pdf_file, allow_ocr=False):
    """Extract the text of a PDF file page by page.

    Parameters
    ----------
    pdf_file : str
        Path to the PDF file.
    allow_ocr : bool, default False
        If True and the PDF contains images but fewer than 1000 characters of
        extractable text, OCR it with ocrmypdf into "<name>_ocr.pdf" next to the
        input and extract the text from that file instead.

    Returns
    -------
    page_list : list of str
        Text of every page that yielded non-empty text (empty pages are skipped).
    full_text : str
        The texts of page_list joined by blank lines, stripped.
    metadata : dict
        Document-information fields (author, creator, producer, subject, title;
        None when the PDF has no document-information dictionary) plus
        image_count, page_count (total number of pages) and char_count
        (length of full_text).
    """
    def extract_text_from_pdf(reader):
        # Returns (full_text, page_count, page_list) for the given reader.
        page_list = []
        for page in reader.pages:
            text = page.extract_text()
            if text:  # skip pages without extractable text (e.g. pure scans)
                page_list.append(text)
        return "\n\n".join(page_list).strip(), len(reader.pages), page_list

    reader = PdfReader(pdf_file)
    # Images hint at scanned content that may need OCR.
    image_count = sum(len(page.images) for page in reader.pages)
    # Extract first, so that the OCR decision below is based on the actual amount of text.
    full_text, page_count, page_list = extract_text_from_pdf(reader)
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            # splitext only replaces the extension (also ".PDF"), unlike str.replace(".pdf", ...).
            out_pdf_file = os.path.splitext(pdf_file)[0] + "_ocr.pdf"
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            full_text, page_count, page_list = extract_text_from_pdf(reader)
    print(f"{len(page_list)} Pages")
    info = reader.metadata  # pypdf returns None when there is no document-information dictionary
    has_info = info is not None
    metadata = {
        "author": info.author if has_info else None,
        "creator": info.creator if has_info else None,
        "producer": info.producer if has_info else None,
        "subject": info.subject if has_info else None,
        "title": info.title if has_info else None,
        "image_count": image_count,
        "page_count": page_count,
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata
#------------------------------------------
# Function for splitting text with overlap
#------------------------------------------
def split_with_overlap0(text, chunk_size=3500, overlap=700):
    """Split text into fixed-size character windows.

    A new window starts every max(1, chunk_size - overlap) characters and is at
    most chunk_size characters long, so neighbouring windows share characters
    whenever overlap > 0; the windows at the end of the text may be shorter.
    """
    stride = max(1, chunk_size - overlap)
    return [text[start:min(start + chunk_size, len(text))]
            for start in range(0, len(text), stride)]
import re
def split_with_overlap(text, chunk_size=3500, overlap=700, pattern=r'([.!;?][ \n\r]|[\n\r]{2,})', variant=1, verbose=False):
    """Split text into overlapping chunks, preferably at regex matches (e.g. sentence ends).

    The text is scanned in windows text[i:i+chunk_size]. Within a window, every
    match of `pattern` starting at offset s qualifies as a chunk end if
    s >= overlap (variant 1 and 2) or additionally s > chunk_size - overlap
    (variant 3); the chunk end is set just after the first character of the last
    qualifying match, and the step to the next window is shortened to
    max(1, s - overlap) for the first qualifying match outside the right overlap.
    For i > 0 the chunk start moves just past the first character of the first
    match (variant 1 always; variants 2 and 3 only if that match lies in the left
    overlap), so chunks tend to start at complete sentences. A chunk is only
    stored if its end lies beyond the previous chunk's end; any text after the
    last end is appended as a final chunk.

    Parameters
    ----------
    text : str
        Text to split; empty text yields [].
    chunk_size : int
        Maximum window length in characters.
    overlap : int
        Overlap region in characters (clamped to chunk_size).
    pattern : str
        Regex marking allowed split points. Default: '([.!;?][ \\n\\r]|[\\n\\r]{2,})'.
    variant : int
        1, 2 or 3, see above.
    verbose : bool
        Print debugging information.

    Returns
    -------
    list of str
    """
    chunks = []
    overlap = min(overlap, chunk_size)  # the overlap cannot exceed the chunk size
    i, lastEnd = 0, 0
    end = 0  # defined up front so that empty text does not hit an unbound name after the loop
    while i < len(text):
        if verbose:
            print("i="+str(i))
        end = min(i + chunk_size, len(text))
        matches = list(re.finditer(pattern, text[i:end]))  # all matches in the current window
        pattern_match = matches[0] if matches else None  # first match (if any)
        step = max(1, chunk_size - overlap)  # normally a step is chunk_size - overlap
        if pattern_match:  # at least one split point was found
            for s in (m.start() for m in matches):
                # Qualifies if it is not in the left overlap (variants 1/2), or is additionally
                # in the right overlap (variant 3) - the latter may produce incomplete sentences.
                if (variant <= 2 and s >= overlap) or (variant == 3 and s >= overlap and s > (chunk_size - overlap)):
                    end = s + i + 1  # end just after the first character of the match, in text coordinates
                    if verbose:
                        print("***move end:"+str(end)+"; step="+str(step))
                    if s < (chunk_size - overlap):
                        # Jump at most to the split point (only needed if end is not in the overlap).
                        step = min(step, max(1, s - overlap))
            if (variant == 1 and i > 0) or (variant >= 2 and pattern_match.start() < overlap and i > 0):
                # Drop the text before the first split point (a partial sentence).
                i = i + pattern_match.start() + 1
        if verbose:
            print("i="+str(i)+"; end="+str(end)+"; step="+str(step)+"; len="+str(len(text))+"; match="+str(pattern_match)+"; text="+text[i:end]+"; rest="+text[end:])
        if end > lastEnd:
            # Only store the chunk if its end moved (and it does not merely re-cut a known sentence).
            chunks.append(text[i:end])
            lastEnd = end
            if verbose:
                print("Text at position "+str(i)+": "+text[i:end])
        i += step
    if len(text[end:]) > 0:
        chunks.append(text[end:])  # append any remaining text at the end
    return chunks
# Regexes used to chunk documents: sentence ends that are not German abbreviations,
# and blank-line section breaks. Each negative look-behind only groups alternatives
# of equal width (5, 4, 3, 2, 1 characters), as Python's re requires fixed-width look-behinds.
# Raw strings avoid the invalid-escape warnings of "\(" in normal string literals.
fiveChars = r"(?<![ \n\(]bspw|[ \n]inkl)"
fourChars = r"(?<![ \n\(]sog|[ \n]Mio|[ \n]Mrd|[ \n]Tsd|[ \n]Tel)"
threeChars = r"(?<!www|bzw|etc|ggf|[ \n\(]al|[ \n\(]St|[ \n\(]dh|[ \n\(]va|[ \n\(]ca|[ \n\(]Dr|[ \n\(]Hr|[ \n\(]Fr|[0-9]ff)"
twoChars = r"(?<![ \n\(][A-Za-zΆ-Ωά-ωäöüß])"  # single-letter abbreviations such as "z." or "S."
oneChars = r"(?<![0-9.])"  # no numbers ("3.") and no ellipses
sentenceRegex = r"(?<=[^.]{4})" + fiveChars + fourChars + threeChars + twoChars + oneChars + r"[.?!](?![A-Za-zΆ-Ωά-ωäöüß0-9.!?'\"])"
sectionRegex = r"\n[ ]*\n[\n ]*"
splitRegex = "(" + sentenceRegex + "|" + sectionRegex + ")"
#---------------------------------------------------------------
# Function for adding docs to ChromaDB and/or return collection
#---------------------------------------------------------------
def add_doc(path, session):
    """Return the session's ChromaDB collection, indexing a PDF into it first if one is given.

    Parameters
    ----------
    path : str
        Path of an uploaded file. Only an existing file ending in ".pdf"
        (case-insensitive) is indexed; any other value (e.g. "") just retrieves
        or creates the collection.
    session : str
        Session identifier; the collection is named "DB_<session>".

    Returns
    -------
    The chromadb collection "DB_<session>" (created with cosine distance if missing).

    Notes
    -----
    On setups without CUDA only the first 5 text pages of a PDF are indexed.
    New chunks get numeric string ids continuing after the highest existing
    numeric id, so several PDFs uploaded in one session are all indexed.
    """
    print("def add_doc!")
    print(path)
    doc = ""
    anhang = False  # True when a PDF attachment was found and converted
    if isinstance(path, str) and path.lower().endswith(".pdf") and os.path.isfile(path):
        pages = convertPDF(path)[0]
        if len(pages) > 5 and "cuda" not in device:
            pages = pages[0:5]  # embedding on CPU is slow: index an excerpt only
            gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing excerpt (first 5 pages on CPU setups)!")
        else:
            gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing!")
        doc = "\n\n".join(pages)
        anhang = True
    else:
        gr.Info("No PDF attached - answer based on DB_"+str(session)+".")
    client = chromadb.PersistentClient(path=dbPath)
    print(str(client.list_collections()))
    print(str(session))
    dbName = "DB_"+str(session)
    # get_or_create_collection instead of searching the repr of list_collections(),
    # whose format differs between chromadb versions.
    collection = client.get_or_create_collection(
        name=dbName,
        embedding_function=embeddingModel,
        metadata={"hnsw:space": "cosine"})
    if anhang and not doc.strip():
        gr.Info("No extractable text found in the PDF - answer based on DB_"+str(session)+".")
    elif anhang:
        corpus = split_with_overlap(doc, 3500, 700, pattern=splitRegex)
        print("Length of corpus: "+str(len(corpus)))
        print("Corpus:"+str(corpus))
        then = datetime.now()
        existingIDs = collection.get(include=[])["ids"]
        print(len(existingIDs))
        # Continue numbering after the highest numeric id so new chunks never collide with earlier uploads.
        firstID = max((int(x) for x in existingIDs if x.isdigit()), default=0) + 1
        batchSize = 40000  # number of chunks embedded per collection.add call
        # Ceiling division; round(n/size+0.5) uses banker's rounding and can add an empty batch.
        nBatches = (len(corpus) + batchSize - 1) // batchSize
        indexDate = date.today().isoformat()
        for b, start in enumerate(range(0, len(corpus), batchSize)):
            print("embed batch "+str(b)+" of "+str(nBatches))
            batch = corpus[start:start+batchSize]
            ids = [str(firstID+start+k) for k in range(len(batch))]  # chromadb-unique ids
            collection.add(documents=batch, ids=ids,
                           metadatas=[{"date": indexDate} for _ in batch])
            print("finished batch "+str(b)+" of "+str(nBatches))
        now = datetime.now()
        gr.Info("Indexing complete!")
        print(now-then)
    return collection
#--------------------------------------------------------
# Function for response to user queries and pot. addenda
#--------------------------------------------------------
def multimodal_response(message, history, dropdown, hfToken, request: gr.Request):
    """Stream an answer to the user's message based on the session's ChromaDB collection.

    Parameters
    ----------
    message : dict
        Multimodal chat input with keys "text" (the question) and "files"
        (presumably uploaded file paths); the first file, if any, is passed to add_doc.
    history : list
        Previous chat turns, passed to format_prompt.
    dropdown : str
        Value of the "Variante" dropdown (currently unused).
    hfToken : str
        Optional Hugging Face token; used for the InferenceClient if it starts with "hf_".
    request : gr.Request
        Injected by Gradio; its session_hash selects the per-session collection.

    Yields
    ------
    str
        The growing answer text, finally followed by an HTML list of the retrieved sources.
    """
    print("def multimodal response!")
    hfToken = (hfToken or "").strip()  # tolerate an empty field or pasted whitespace
    if hfToken.startswith("hf_"):  # use HF-hub with custom token if token is provided
        inferenceClient = InferenceClient(model=myModel, token=hfToken)
    else:
        inferenceClient = InferenceClient(myModel)
    global databases
    session = request.session_hash if request else "0"
    print(databases)
    if databases[-1][1] != session:
        databases.append((date.today(), session))
    query = message["text"]
    if len(message["files"]) > 0:
        # Index the first attached file (if it is a PDF) into the session collection.
        collection = add_doc(message["files"][0], session)
    else:
        # Only fetch the session collection. The user's text must not be used as a path:
        # a typed server path ending in ".pdf" would otherwise be indexed and quoted back.
        collection = add_doc("", session)
    # Nothing to retrieve from an empty collection.
    if collection.count() > 0:
        documents = collection.query(query_texts=[query], n_results=1)["documents"][0]
    else:
        documents = []
    context = ["<Kontext "+str(i)+"> "+str(c)+"</Kontext "+str(i)+">" for i, c in enumerate(documents)]
    gr.Info("Kontext:\n"+str(context))
    generate_kwargs = dict(
        temperature=float(0.9),
        max_new_tokens=5000,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    # German system prompt; in English: "Given the following conversation and relevant context,
    # reply to the user's current question. Answer solely based on the information in the context."
    system = ("Mit Blick auf das folgende Gespräch und den relevanten Kontext, antworte auf die aktuelle Frage des Nutzers. "
              "Antworte ausschließlich auf Basis der Informationen im Kontext.\n\nKontext:\n\n"
              + str("\n\n".join(context)))
    print(system)
    formatted_prompt = format_prompt(query, history, system=system)
    print(formatted_prompt)
    output = ""
    try:
        stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        # Typically caused by a missing/invalid token or rate limits; the retrieved sources are still shown below.
        output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
        if len(context) > 0:
            output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
        yield output
        print(str(e))
    if len(context) > 0:
        output = output+"\n\n<br><details open><summary><strong>Quellen</strong></summary><br><ul>"+ "".join(["<li>" + c + "</li>" for c in context])+"</ul></details>"
        yield output
#------------------------------
# Launch Gradio-ChatInterface
#------------------------------
# Extra inputs below the chat; Gradio passes their values to multimodal_response
# after message and history, i.e. as `dropdown` and `hfToken`.
variantDropdown = gr.Dropdown(
    label="Variante",
    info="Wähle eine Variante",
    choices=["1", "2", "3"],
    value="1",
)
tokenBox = gr.Textbox(label="HF_token", value="")
i = gr.ChatInterface(
    multimodal_response,
    title="Frag dein PDF",
    multimodal=True,
    additional_inputs=[variantDropdown, tokenBox],
)
i.launch()  # e.g. i.launch(allowed_paths=["."]) to serve local files