Commit 0ad705b (parent: e00a35e) by AFischer1985: Update run.py

run.py CHANGED
@@ -2,50 +2,70 @@
 # Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
 # Author: Andreas Fischer
 # Date: October 10th, 2024
-# Last update: October
+# Last update: October 14th, 2024
 ##########################################################################################
 
 import os
-
-from datetime import datetime
-from chromadb import Documents, EmbeddingFunction, Embeddings
-from chromadb.utils import embedding_functions
-from transformers import AutoTokenizer, AutoModel
+
 import torch
+from transformers import AutoTokenizer, AutoModel # chromaDB
+from datetime import datetime, date #add_doc,
+import chromadb #chromaDB
+from chromadb import Documents, EmbeddingFunction, Embeddings #chromaDB
+from chromadb.utils import embedding_functions #chromaDB
+import ocrmypdf #convertPDF
+from pypdf import PdfReader #convertPDF
+import re #format_prompt
+import gradio as gr # multimodal_response
+from huggingface_hub import InferenceClient #multimodal_response
+
+
+#---------------------------------------------------
+# Specify models for text generation and embeddings
+#---------------------------------------------------
+
+myModel="mistralai/Mixtral-8x7b-instruct-v0.1"
+#mod="mistralai/Mixtral-8x7b-instruct-v0.1"
+#tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
+#cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
+#cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
+#res=tok.apply_chat_template(cha)
+#print(tok.decode(res))
+
 jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
 #jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
-device='cuda' if torch.cuda.is_available() else 'cpu'
-#device='cpu' #'cuda' if torch.cuda.is_available() else 'cpu'
+device='cuda:0' if torch.cuda.is_available() else 'cpu'
 jina.to(device) #cuda:0
 print(device)
 
+
+#-----------------
+# ChromaDB-client
+#-----------------
+
 class JinaEmbeddingFunction(EmbeddingFunction):
     def __call__(self, input: Documents) -> Embeddings:
         embeddings = jina.encode(input) #max_length=2048
         return(embeddings.tolist())
 
-dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
+dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
 onPrem = True if(os.path.exists(dbPath)) else False
-if(onPrem==False): dbPath="/home/user/app/db"
+if(onPrem==False): dbPath="/home/user/app/db/"
 
-#onPrem=True # uncomment to override automatic detection
 print(dbPath)
-path=dbPath
-client = chromadb.PersistentClient(path=path)
+client = chromadb.PersistentClient(path=dbPath)
 print(client.heartbeat())
 print(client.get_version())
 print(client.list_collections())
+
 jina_ef=JinaEmbeddingFunction()
 embeddingModel=jina_ef
+databases=[(date.today(),"0")] # start a list of databases
 
-myModel="mistralai/Mixtral-8x7b-instruct-v0.1"
-#mod="mistralai/Mixtral-8x7b-instruct-v0.1"
-#tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
-#cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
-#cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
-#res=tok.apply_chat_template(cha)
-#print(tok.decode(res))
 
+#---------------------------------------------------------------------
+# Function for formatting single message according to prompt template
+#---------------------------------------------------------------------
 
 def format_prompt0(message, history):
     prompt = "<s>"
@@ -56,6 +76,10 @@ def format_prompt0(message, history):
     return prompt
 
 
+#-------------------------------------------------------------------------
+# Function for formatting multiturn-dialogue according to prompt template
+#-------------------------------------------------------------------------
+
 def format_prompt(message, history, system=None, RAGAddon=None, system2=None, zeichenlimit=None, historylimit=4, removeHTML=False):
     if zeichenlimit is None: zeichenlimit=1000000000 # :-)
     startOfString="<s>" #<s> [INST] U1 [/INST] A1</s> [INST] U2 [/INST] A2</s>
@@ -71,8 +95,8 @@ def format_prompt(message, history, system=None, RAGAddon=None, system2=None, ze
     for user_message, bot_response in history[-historylimit:]:
         if user_message is None: user_message = ""
         if bot_response is None: bot_response = ""
-
-        if removeHTML==True: bot_response = re.sub("<(.*?)>","\n", bot_response)
+        bot_response = re.sub("\n\n<details>((.|\n)*?)</details>","", bot_response) # remove RAG-components
+        if removeHTML==True: bot_response = re.sub("<(.*?)>","\n", bot_response) # remove HTML-components in general (may cause bugs with markdown-rendering)
         if user_message is not None: prompt += template1.format(message=user_message[:zeichenlimit])
         if bot_response is not None: prompt += template2.format(response=bot_response[:zeichenlimit])
     if message is not None: prompt += template1.format(message=message[:zeichenlimit])
@@ -81,8 +105,10 @@ def format_prompt(message, history, system=None, RAGAddon=None, system2=None, ze
     return startOfString+prompt
 
 
-
-
+#--------------------------------------------
+# Function for converting pdf-files to text
+#--------------------------------------------
+
 def convertPDF(pdf_file, allow_ocr=False):
     reader = PdfReader(pdf_file)
     full_text = ""
@@ -100,7 +126,7 @@ def convertPDF(pdf_file, allow_ocr=False):
         return full_text.strip(), page_count, page_list
     # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)
-    # If there are images and not much content, perform OCR on the document
+    # If there are images and not much content, you may want to perform OCR on the document
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
@@ -124,16 +150,24 @@ def convertPDF(pdf_file, allow_ocr=False):
     }
     return page_list, full_text, metadata
 
+
+#------------------------------------------
+# Function for splitting text with overlap
+#------------------------------------------
+
 def split_with_overlap(text, chunk_size=3500, overlap=700):
     chunks=[]
     step=max(1,chunk_size-overlap)
     for i in range(0,len(text),step):
         end=min(i+chunk_size,len(text))
-        #chunk = text[i:i+chunk_size]
         chunks.append(text[i:end])
     return chunks
 
 
+#----------------------------------------------------------------------
+# Function for adding docs to ChromaDB and/or returning the collection
+#----------------------------------------------------------------------
+
 def add_doc(path, session):
     print("def add_doc!")
     print(path)
@@ -148,9 +182,8 @@ def add_doc(path, session):
         anhang=True
     else:
         gr.Info("No PDF attached - answer based on DB_"+str(session)+".")
-    client = chromadb.PersistentClient(path=
+    client = chromadb.PersistentClient(path=dbPath)
     print(str(client.list_collections()))
-    #global collection
     print(str(session))
     dbName="DB_"+str(session)
     if(not "name="+dbName in str(client.list_collections())):
@@ -184,15 +217,14 @@ def add_doc(path, session):
     print(now-then) # too many GB for sentences (GPU); 0:00:10.375087 for chunks
     return(collection)
 
-
 #split_with_overlap("test me if you can",2,1)
-from datetime import date
-databases=[(date.today(),"0")] # list of all databases
 
-
-
-
-def multimodalResponse(message, history, dropdown, hfToken, request: gr.Request):
+
+#---------------------------------------------------------------
+# Function for responding to user queries and potential addenda
+#---------------------------------------------------------------
+
+def multimodal_response(message, history, dropdown, hfToken, request: gr.Request):
     print("def multimodal response!")
     if(hfToken.startswith("hf_")): # use HF-hub with custom token if token is provided
         inferenceClient = InferenceClient(model=myModel, token=hfToken)
@@ -213,7 +245,7 @@ def multimodalResponse(message, history, dropdown, hfToken, request: gr.Request)
         collection=add_doc(message["files"][0], session)
     else: # otherwise, you still want to get the collection with the session-based db
         collection=add_doc(message["text"], session)
-    client = chromadb.PersistentClient(path=
+    client = chromadb.PersistentClient(path=dbPath)
    print(str(client.list_collections()))
    x=collection.get(include=[])["ids"]
    context=collection.query(query_texts=[query], n_results=1)
@@ -238,15 +270,27 @@ def multimodalResponse(message, history, dropdown, hfToken, request: gr.Request)
     #formatted_prompt = format_prompt0(system+"\n"+query, history)
     formatted_prompt = format_prompt(query, history, system=system)
     print(formatted_prompt)
-    stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
-    for response in stream:
-        output += response.token.text
+    try:
+        stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+        for response in stream:
+            output += response.token.text
+            yield output
+    except Exception as e:
+        output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
+        if(len(context)>0):
+            output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
         yield output
-
+        print(str(e))
+    if(len(context)>0):
+        output = output+"\n\n<br><details open><summary><strong>Quellen</strong></summary><br><ul>"+"".join(["<li>" + c + "</li>" for c in context])+"</ul></details>"
     yield output
 
-i=gr.ChatInterface(multimodalResponse,
+
+#------------------------------
+# Launch Gradio-ChatInterface
+#------------------------------
+
+i=gr.ChatInterface(multimodal_response,
     title="Frag dein PDF",
     multimodal=True,
     additional_inputs=[
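Note on the ChromaDB wiring above: JinaEmbeddingFunction adapts the locally loaded jina model to ChromaDB's EmbeddingFunction interface, so collections embed documents via jina.encode(). A minimal usage sketch building on the definitions above (the collection name and texts are invented for illustration; this is not part of the commit):

    collection = client.get_or_create_collection(
        name="DB_demo",                     # invented name
        embedding_function=jina_ef)         # embeds via jina.encode()
    collection.add(documents=["erster Text", "zweiter Text"], ids=["doc1", "doc2"])
    print(collection.query(query_texts=["Beispiel"], n_results=1)["documents"])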
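The template strings that format_prompt applies (new lines 86-94) fall outside the shown hunks. Judging from the startOfString comment, the target is Mixtral's instruct format; a rough sketch of the assembled result, with template1 and template2 assumed for illustration:

    template1 = " [INST] {message} [/INST]"   # user turn (assumption)
    template2 = " {response}</s>"             # assistant turn (assumption)
    prompt = ""
    for user_message, bot_response in [("U1", "A1")]:    # history
        prompt += template1.format(message=user_message)
        prompt += template2.format(response=bot_response)
    prompt += template1.format(message="U2")             # new message
    print("<s>" + prompt)   # <s> [INST] U1 [/INST] A1</s> [INST] U2 [/INST]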
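convertPDF relies on pypdf for text extraction and on ocrmypdf for the optional OCR fallback; only its first and last lines appear in the hunks. A condensed sketch of the library calls involved (file names are placeholders):

    from pypdf import PdfReader
    import ocrmypdf

    reader = PdfReader("input.pdf")   # placeholder file
    full_text = "\n".join((page.extract_text() or "") for page in reader.pages)
    image_count = sum(len(page.images) for page in reader.pages)
    if image_count > 0 and len(full_text) < 1000:   # image-heavy PDF with little text
        ocrmypdf.ocr("input.pdf", "input_ocr.pdf", force_ocr=True)   # write a searchable copy
        full_text = "\n".join((p.extract_text() or "") for p in PdfReader("input_ocr.pdf").pages)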
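The commented self-test #split_with_overlap("test me if you can",2,1) hints at the intended behavior of split_with_overlap; with a slightly larger window the sliding-window logic is easier to see:

    print(split_with_overlap("test me if you can", chunk_size=7, overlap=3))
    # step = max(1, 7-3) = 4, so chunks start at offsets 0, 4, 8, 12, 16:
    # ['test me', ' me if ', 'if you ', 'ou can', 'an']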
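Side note on add_doc: the check if(not "name="+dbName in str(client.list_collections())) tests for an existing collection by matching against the string representation of the collection list, which is fragile. ChromaDB's get_or_create_collection covers the same case in one call; an alternative sketch (an editorial suggestion, not the committed code), given the client and jina_ef defined above:

    dbName = "DB_" + str(session)
    collection = client.get_or_create_collection(name=dbName, embedding_function=jina_ef)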
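The new try/except block streams tokens from the Hugging Face Inference API and falls back to the retrieved sources when generation fails (for instance without a valid token). The call shape, reduced to its essentials (the prompt and max_new_tokens are invented for illustration):

    from huggingface_hub import InferenceClient

    client = InferenceClient(model="mistralai/Mixtral-8x7b-instruct-v0.1")   # token=... optional
    stream = client.text_generation("<s> [INST] Hi [/INST]", max_new_tokens=32,
                                    stream=True, details=True, return_full_text=False)
    for response in stream:
        print(response.token.text, end="")   # with details=True, each item exposes .token.text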
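The last hunk is cut off inside the gr.ChatInterface(...) call, so the remaining arguments (the additional_inputs list and anything after it) are not visible here. For orientation, a generic multimodal ChatInterface looks like this (a minimal sketch, not the Space's actual configuration):

    import gradio as gr

    def respond(message, history):
        return message["text"]   # multimodal messages arrive as dicts with "text" and "files"

    gr.ChatInterface(respond, title="Frag dein PDF", multimodal=True).launch()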