Spaces:
Running
Running
###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 22nd, 2024
##########################################################################################
import os | |
import torch | |
from transformers import AutoTokenizer, AutoModel # chromaDB | |
from datetime import datetime, date #add_doc, | |
import chromadb #chromaDB | |
from chromadb import Documents, EmbeddingFunction, Embeddings #chromaDB | |
from chromadb.utils import embedding_functions #chromaDB | |
import ocrmypdf #convertPDF | |
from pypdf import PdfReader #convertPDF | |
import re #format_prompt | |
import gradio as gr # multimodal_response | |
from huggingface_hub import InferenceClient #multimodal_response | |
#--------------------------------------------------- | |
# Specify models for text generation and embeddings | |
#--------------------------------------------------- | |
# Text-generation model, queried through huggingface_hub's InferenceClient.
myModel = "mistralai/Mixtral-8x7b-instruct-v0.1"

# Run the embedding model on the first GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# German ("-de") Jina embedding model. trust_remote_code loads the model's own
# code, which presumably provides the .encode() used by the embedding function.
jina = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-de",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
jina.to(device)
print(device)
#----------------- | |
# ChromaDB-client | |
#----------------- | |
class JinaEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function backed by the module-level ``jina`` model."""

    def __call__(self, input: Documents) -> Embeddings:
        # Encode the batch of documents and hand Chroma plain Python lists.
        vectors = jina.encode(input)  # max_length=2048
        return vectors.tolist()
# Prefer the developer's local DB directory; otherwise fall back to the path
# used by the deployed app (presumably the Hugging Face Space container).
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
onPrem = os.path.exists(dbPath)
if not onPrem:
    dbPath = "/home/user/app/db/"
print(dbPath)

client = chromadb.PersistentClient(path=dbPath)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

# A single embedding-function instance shared by all collections.
embeddingModel = jina_ef = JinaEmbeddingFunction()
# (date, session_hash) entries, seeded with the placeholder session "0".
databases = [(date.today(), "0")]
#--------------------------------------------------------------------- | |
# Function for formatting single message according to prompt template | |
#--------------------------------------------------------------------- | |
def format_prompt0(message, history):
    """Wrap a single message in the Mixtral instruct template.

    ``history`` is accepted for signature compatibility but ignored.
    """
    return f"<s>[INST] {message} [/INST]"
#------------------------------------------------------------------------- | |
# Function for formatting multiturn-dialogue according to prompt template | |
#------------------------------------------------------------------------- | |
def format_prompt(message, history, system=None, RAGAddon=None, system2=None, zeichenlimit=None, historylimit=4, removeHTML=False):
    """Build a Mixtral-instruct prompt for a multi-turn dialogue.

    Args:
        message: Current user message; omitted when None.
        history: Sequence of (user_message, bot_response) pairs, or None.
        system: Optional system text, emitted as a leading closed [INST] turn.
        RAGAddon: Optional text appended to ``system``; if ``system`` is None
            it is treated as an empty string.
        system2: Optional raw text appended after the current message.
        zeichenlimit: Maximum number of characters kept from each message
            (None = unlimited).
        historylimit: Number of most recent history turns to include;
            0 or a negative value includes none, None includes all.
        removeHTML: If True, replace every HTML tag in bot responses with a
            newline (may interfere with markdown rendering).

    Returns:
        The prompt string, starting with "<s>".
    """
    template0 = " [INST] {system} [/INST] </s>"  # system turn, closed like a finished exchange
    template1 = " [INST] {message} [/INST]"
    template2 = " {response}</s>"
    if RAGAddon is not None:
        # Treat a missing system prompt as empty so the addon can still be attached.
        system = (system or "") + RAGAddon
    prompt = ""
    if system is not None:
        prompt += template0.format(system=system)
    if history:
        if historylimit is None:
            recent = history
        elif historylimit > 0:
            recent = history[-historylimit:]
        else:
            recent = []  # history[-0:] would return the *entire* history
        for user_message, bot_response in recent:
            if user_message is None:
                user_message = ""
            if bot_response is None:
                bot_response = ""
            # Strip the appended sources block ("\n\n<br><details open>...</details>",
            # or a plain "\n\n<details>...</details>") so retrieved context is
            # not fed back into the prompt.
            bot_response = re.sub(r"\n\n(?:<br>)?<details\b[^>]*>.*?</details>", "", bot_response, flags=re.DOTALL)
            if removeHTML:
                bot_response = re.sub("<(.*?)>", "\n", bot_response)
            prompt += template1.format(message=user_message[:zeichenlimit])
            prompt += template2.format(response=bot_response[:zeichenlimit])
    if message is not None:
        prompt += template1.format(message=message[:zeichenlimit])
    if system2 is not None:
        prompt += system2
    return "<s>" + prompt
#-------------------------------------------- | |
# Function for converting pdf-files to text | |
#-------------------------------------------- | |
def convertPDF(pdf_file, allow_ocr=False):
    """Extract the text of a PDF page by page, optionally OCR-ing scanned files.

    Args:
        pdf_file: Path of the PDF to read.
        allow_ocr: When True and the PDF contains images but yields fewer than
            1000 characters of text, run ocrmypdf (force_ocr=True) into a
            sibling "<name>_ocr.pdf" file and extract the text from that.

    Returns:
        Tuple (page_list, full_text, metadata):
        - page_list: text of every page that produced non-empty text;
        - full_text: those page texts joined by blank lines;
        - metadata: document-info fields (author, creator, producer, subject,
          title; None when missing) plus image_count, page_count (total pages
          of the PDF that was read) and char_count (length of full_text).
    """
    def extract_text_from_pdf(reader):
        # Pages without extractable text (e.g. pure scans) are skipped.
        page_list = []
        for page in reader.pages:
            text = page.extract_text()
            if text:
                page_list.append(text)
        return "\n\n".join(page_list), page_list

    reader = PdfReader(pdf_file)
    # Images are counted on the original file, before any OCR pass.
    image_count = sum(len(page.images) for page in reader.pages)
    # Extract before the OCR decision so the "little text" check sees the
    # document's real content.
    full_text, page_list = extract_text_from_pdf(reader)
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            # splitext also handles ".PDF"; str.replace(".pdf", ...) would
            # leave such a path unchanged (output == input).
            out_pdf_file = os.path.splitext(pdf_file)[0] + "_ocr.pdf"
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            full_text, page_list = extract_text_from_pdf(reader)
    print(f"{len(page_list)} Pages")
    # reader.metadata is None when the PDF has no document-info dictionary;
    # getattr(None, field, None) then simply yields None for each field.
    info = reader.metadata
    metadata = {
        field: getattr(info, field, None)
        for field in ("author", "creator", "producer", "subject", "title")
    }
    metadata.update({
        "image_count": image_count,
        "page_count": len(reader.pages),
        "char_count": len(full_text),
    })
    return page_list, full_text, metadata
#------------------------------------------ | |
# Function for splitting text with overlap | |
#------------------------------------------ | |
def split_with_overlap0(text, chunk_size=3500, overlap=700):
    """Split text into fixed-size character chunks that overlap.

    Each chunk has at most ``chunk_size`` characters and starts
    ``chunk_size - overlap`` characters (at least 1) after the previous one.
    Splitting stops once a chunk reaches the end of the text, because every
    later window would only be a sub-slice of that final chunk.

    Returns:
        List of chunk strings; empty for empty ``text``.
    """
    chunks = []
    step = max(1, chunk_size - overlap)
    for start in range(0, len(text), step):
        end = min(start + chunk_size, len(text))
        chunks.append(text[start:end])
        if end >= len(text):
            break
    return chunks
import re | |
def split_with_overlap(text, chunk_size=3500, overlap=700, pattern=r'([.!;?][ \n\r]|[\n\r]{2,})', variant=1, verbose=False):
    """Split text into overlapping chunks whose boundaries follow regex matches.

    A window of at most ``chunk_size`` characters slides over the text. Inside
    each window, matches of ``pattern`` (by default sentence-ending
    punctuation followed by whitespace, or a blank line) are used as chunk
    boundaries; a boundary keeps the first character of the match.

    variant=1 (default): the chunk end moves to the last match that lies
        outside the left overlap region, and every chunk after the first
        starts right after the first match in its window.
    variant=2: the end is chosen as in variant 1, but the start is only moved
        when the first match lies inside the left overlap region.
    variant=3: like variant 2, but the end only moves to matches that also
        lie in the right overlap region (which can cut sentences).

    The window normally advances by ``chunk_size - overlap`` characters; the
    step shrinks when a chosen boundary lies before the right overlap region.
    A window is only emitted when its end extends past the previous chunk's
    end, and any text after the last end is appended as a final chunk.

    Args:
        text: Text to split.
        chunk_size: Maximum window length in characters.
        overlap: Size of the overlap regions (clamped to ``chunk_size``).
        pattern: Regex (string or compiled) marking possible boundaries.
        variant: Boundary strategy 1, 2 or 3 (see above).
        verbose: Print debugging information while splitting.

    Returns:
        List of chunk strings; an empty list for empty ``text``.
    """
    chunks = []
    overlap = min(overlap, chunk_size)  # overlap cannot exceed the chunk size
    regex = re.compile(pattern)
    # end is pre-initialised so empty text does not leave it unbound below.
    i, lastEnd, end = 0, 0, 0
    while i < len(text):
        if verbose:
            print("i=" + str(i))
        end = min(i + chunk_size, len(text))
        matches = list(regex.finditer(text[i:end]))
        pattern_match = matches[0] if matches else None  # first match in the window, if any
        step = max(1, chunk_size - overlap)  # default step
        if pattern_match:
            for s in (m.start() for m in matches):
                # Boundary outside the left overlap (variants 1/2), or also
                # inside the right overlap (variant 3).
                if (variant <= 2 and s >= overlap) or (variant == 3 and s >= overlap and s > (chunk_size - overlap)):
                    end = s + i + 1  # keep the punctuation character in the chunk
                    if verbose:
                        print("***move end:" + str(end) + "; step=" + str(step))
                    if s < (chunk_size - overlap):
                        # Boundary before the right overlap: don't jump past it.
                        step = min(step, max(1, s - overlap))
            if (variant == 1 and i > 0) or (variant >= 2 and pattern_match.start() < overlap and i > 0):
                # Drop the text before the first boundary (an incomplete sentence).
                i = i + pattern_match.start() + 1
        if verbose:
            print("i=" + str(i) + "; end=" + str(end) + "; step=" + str(step) + "; len=" + str(len(text)) + "; match=" + str(pattern_match) + "; text=" + text[i:end] + "; rest=" + text[end:])
        if end > lastEnd:
            # Only emit the window if it reaches beyond the previous chunk.
            chunks.append(text[i:end])
            lastEnd = end
            if verbose:
                print("Text at position " + str(i) + ": " + text[i:end])
        i += step
    if len(text[end:]) > 0:
        chunks.append(text[end:])  # append any remaining tail
    return chunks
# Regex building blocks for German-aware sentence splitting. Each *Chars group
# is a negative lookbehind that stops a period after a listed abbreviation
# (whose width, including any leading separator, is the group's name) from
# counting as a sentence end. Raw strings avoid the invalid "\(" escape
# sequence (a SyntaxWarning on Python 3.12+); the compiled regexes are unchanged.
fiveChars = r"(?<![ \n\(]bspw|[ \n]inkl)"
fourChars = r"(?<![ \n\(]sog|[ \n]Mio|[ \n]Mrd|[ \n]Tsd|[ \n]Tel)"
threeChars = r"(?<!www|bzw|etc|ggf|[ \n\(]al|[ \n\(]St|[ \n\(]dh|[ \n\(]va|[ \n\(]ca|[ \n\(]Dr|[ \n\(]Hr|[ \n\(]Fr|[0-9]ff)"
# Single-letter abbreviations such as " z." preceded by space/newline/parenthesis.
twoChars = r"(?<![ \n\(][A-Za-zΆ-Ωά-ωäöüß])"
# No sentence end directly after a digit or another period.
oneChars = r"(?<![0-9.])"
# Sentence end: ., ? or ! after four non-period characters, not directly
# followed by a letter, digit, punctuation mark or quote.
sentenceRegex = r"(?<=[^.]{4})" + fiveChars + fourChars + threeChars + twoChars + oneChars + "[.?!](?![A-Za-zΆ-Ωά-ωäöüß0-9.!?'\"])"
# Section break: a blank line (optionally containing spaces).
sectionRegex = r"\n[ ]*\n[\n ]*"
splitRegex = "(" + sentenceRegex + "|" + sectionRegex + ")"
#--------------------------------------------------------------- | |
# Function for adding docs to ChromaDB and/or returning the collection
#--------------------------------------------------------------- | |
def add_doc(path, session):
    """Return the session's ChromaDB collection, indexing a PDF into it first.

    Args:
        path: Path of an uploaded file. Only an existing file whose name ends
            in ".pdf" (case-insensitive) is indexed; any other value (e.g. "")
            just retrieves the collection.
        session: Session identifier; the collection is named "DB_<session>".

    Returns:
        The chromadb collection "DB_<session>", created on first use with
        cosine distance and the Jina embedding function. Every uploaded PDF
        is appended to it.
    """
    print("def add_doc!")
    print(path)
    dbName = "DB_" + str(session)
    doc = None  # text to index; stays None when no PDF was attached
    if path.lower().endswith(".pdf") and os.path.exists(path):
        pages = convertPDF(path)[0]
        if len(pages) > 5 and "cuda" not in device:
            # Embedding is slow without a GPU, so only an excerpt is indexed.
            doc = "\n\n".join(pages[:5])
            gr.Info("PDF uploaded to " + dbName + ", start Indexing excerpt (first 5 pages on CPU setups)!")
        else:
            doc = "\n\n".join(pages)
            gr.Info("PDF uploaded to " + dbName + ", start Indexing!")
    else:
        gr.Info("No PDF attached - answer based on " + dbName + ".")
    client = chromadb.PersistentClient(path=dbPath)
    print(str(client.list_collections()))
    print(str(session))
    # get_or_create_collection replaces a substring test on
    # str(client.list_collections()), whose output format differs between
    # chromadb versions; a false negative there led to create_collection()
    # being called for an already existing name.
    collection = client.get_or_create_collection(
        name=dbName,
        embedding_function=embeddingModel,
        metadata={"hnsw:space": "cosine"})
    if doc is not None:
        corpus = split_with_overlap(doc, 3500, 700, pattern=splitRegex)
        print("Length of corpus: " + str(len(corpus)))
        print("Corpus:" + str(corpus))
        then = datetime.now()
        existing = len(collection.get(include=[])["ids"])
        print(existing)
        batchSize = 40000
        # Ceiling division. round(n / size + 0.5) uses banker's rounding and
        # produced an extra, empty batch when n was an odd multiple of size
        # (an empty add is presumably rejected by chromadb).
        nBatches = (len(corpus) + batchSize - 1) // batchSize
        for b in range(nBatches):
            start = b * batchSize
            batch = corpus[start:start + batchSize]
            print("embed batch " + str(b) + " of " + str(nBatches))
            # Offset by the number of stored chunks so ids stay unique when
            # further documents are added to the same session collection.
            ids = [str(start + k + existing + 1) for k in range(len(batch))]
            collection.add(documents=batch, ids=ids,
                           metadatas=[{"date": "2024-10-10"} for _ in batch])
            print("finished batch " + str(b) + " of " + str(nBatches))
        now = datetime.now()
        gr.Info("Indexing complete!")
        print(now - then)  # indexing duration
    return collection
#-------------------------------------------------------- | |
# Function for responding to user queries, including optional attachments
#-------------------------------------------------------- | |
def multimodal_response(message, history, dropdown, hfToken, request: gr.Request):
    """Stream an answer to the user's message, grounded in the session's document DB.

    Args:
        message: Multimodal chat message dict with keys "text" and "files".
        history: Previous (user, bot) turns as supplied by gr.ChatInterface.
        dropdown: Value of the "Variante" dropdown (currently unused).
        hfToken: Optional Hugging Face token; used only if it starts with "hf_".
        request: Gradio request; its session hash selects the session DB.

    Yields:
        The progressively generated answer text. If context was retrieved,
        the final yield appends an HTML <details> block listing the sources.
    """
    print("def multimodal response!")
    # Use the caller's HF token if one was provided, otherwise the default client.
    if hfToken and hfToken.startswith("hf_"):
        inferenceClient = InferenceClient(model=myModel, token=hfToken)
    else:
        inferenceClient = InferenceClient(myModel)
    global databases
    session = request.session_hash if request else "0"
    print(databases)
    if databases[-1][1] != session:
        databases.append((date.today(), session))
    query = message["text"]
    files = message["files"]
    # Only an actually attached file may be indexed. The typed query must never
    # be used as a path: text naming an existing server-side .pdf would
    # otherwise get that file indexed.
    collection = add_doc(files[0] if len(files) > 0 else "", session)
    # Skip retrieval while nothing has been indexed for this session yet.
    if collection.count() > 0:
        documents = collection.query(query_texts=[query], n_results=1)["documents"][0]
    else:
        documents = []
    context = ["<Kontext " + str(k) + "> " + str(c) + "</Kontext " + str(k) + ">" for k, c in enumerate(documents)]
    gr.Info("Kontext:\n" + str(context))
    generate_kwargs = dict(
        temperature=float(0.9),
        max_new_tokens=5000,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    system = "Mit Blick auf das folgende Gespräch und den relevanten Kontext, antworte auf die aktuelle Frage des Nutzers. " + \
        "Antworte ausschließlich auf Basis der Informationen im Kontext.\n\nKontext:\n\n" + \
        str("\n\n".join(context))
    print(system)
    formatted_prompt = format_prompt(query, history, system=system)
    print(formatted_prompt)
    output = ""
    try:
        stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        # UI boundary: show a fallback message instead of crashing the chat.
        output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
        if len(context) > 0:
            output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
        yield output
        print(str(e))
    if len(context) > 0:
        output = output + "\n\n<br><details open><summary><strong>Quellen</strong></summary><br><ul>" + "".join(["<li>" + c + "</li>" for c in context]) + "</ul></details>"
        yield output
#------------------------------ | |
# Launch Gradio-ChatInterface | |
#------------------------------ | |
# Extra inputs below the chat box; Gradio passes their values to
# multimodal_response after (message, history), i.e. as `dropdown`, `hfToken`.
variantDropdown = gr.Dropdown(
    info="Wähle eine Variante",
    choices=["1", "2", "3"],
    value="1",
    label="Variante",
)
tokenBox = gr.Textbox(value="", label="HF_token")

i = gr.ChatInterface(
    multimodal_response,
    title="Frag dein PDF",
    multimodal=True,
    additional_inputs=[variantDropdown, tokenBox],
)
i.launch()  # allowed_paths=["."]