Spaces: AFischer1985 (Running)

AFischer1985 committed · Commit b7f29b3
Parent(s): a380202
Update run.py

run.py CHANGED

@@ -1,70 +1,51 @@
-
- # Title: Gradio Interface to LLM-chatbot with RAG-functionality and ChromaDB
  # Author: Andreas Fischer
- # Date:
- # Last update:
  ##########################################################################################

-
- # Chroma-DB
- #-----------
  import os
  import chromadb
-
-
-
  print(dbPath)
- #client = chromadb.Client()
  path=dbPath
  client = chromadb.PersistentClient(path=path)
  print(client.heartbeat())
  print(client.get_version())
  print(client.list_collections())
-
-
- sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
- #instructor_ef = embedding_functions.InstructorEmbeddingFunction(model_name="hkunlp/instructor-large", device="cuda")
- print(str(client.list_collections()))
-
- global collection
- if("name=ChromaDB1" in str(client.list_collections())):
-     print("ChromaDB1 found!")
-     collection = client.get_collection(name="ChromaDB1", embedding_function=sentence_transformer_ef)
- else:
-     print("ChromaDB1 created!")
-     collection = client.create_collection(
-         "ChromaDB1",
-         embedding_function=sentence_transformer_ef,
-         metadata={"hnsw:space": "cosine"})
-
- collection.add(
-     documents=["The meaning of life is to love.", "This is a sentence", "This is a sentence too"],
-     metadatas=[{"source": "notion"}, {"source": "google-docs"}, {"source": "google-docs"}],
-     ids=["doc1", "doc2", "doc3"],
- )
-
- print("Database ready!")
- print(collection.count())


- # Model
- #-------
-
  from huggingface_hub import InferenceClient
  import gradio as gr
-
-
      "mistralai/Mixtral-8x7B-Instruct-v0.1"
      #"mistralai/Mistral-7B-Instruct-v0.1"
  )
-
-
- # Gradio-GUI
- #------------
-
- import gradio as gr
- import json
-
  def format_prompt(message, history):
      prompt = "<s>"
      #for user_prompt, bot_response in history:
@@ -73,45 +54,161 @@ def format_prompt(message, history):
      prompt += f"[INST] {message} [/INST]"
      return prompt

- […removed lines not shown…]
          do_sample=True,
          seed=42,
- […removed lines not shown…]
-     combination = [' '.join(triplets) for triplets in combination]
-     print(combination)
-     if(len(results)>1):
-         addon=" Bitte berücksichtige bei deiner Antwort ggf. folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Beantworte die Frage knapp und präzise. Ignoriere unpassende Datenbank-Auszüge OHNE sie zu kommentieren, zu erwähnen oder aufzulisten:\n"+"\n".join(results)
-         system="Du bist ein KI-basiertes Assistenzsystem."+addon+"\n\nUser-Anliegen:"
-     #body={"prompt":system+"### Instruktion:\n"+message+"\n\n### Antwort:","max_tokens":500, "echo":"False","stream":"True"} #e.g. SauerkrautLM
-     formatted_prompt = format_prompt(system+"\n"+prompt, history)
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-     for response in stream:
-         output += response.token.text
-         yield output
-     output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
      yield output

- gr.ChatInterface(response, chatbot=gr.Chatbot(render_markdown=True),title="German RAG-Interface to the Hugging Face Hub").queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864)
- print("Interface up and running!")
1 |
+
###########################################################################################
|
2 |
+
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-funcionality and ChromaDB
|
3 |
# Author: Andreas Fischer
|
4 |
+
# Date: October 10th, 2024
|
5 |
+
# Last update: October 10th, 2024
|
6 |
##########################################################################################
|
7 |
|
|
|
|
|
|
|
  import os
  import chromadb
+ from datetime import datetime
+ from chromadb import Documents, EmbeddingFunction, Embeddings
+ from chromadb.utils import embedding_functions
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
+ #jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
+ device='cuda' if torch.cuda.is_available() else 'cpu'
+ #device='cpu' #'cuda' if torch.cuda.is_available() else 'cpu'
+ jina.to(device) #cuda:0
+ print(device)
+
+ class JinaEmbeddingFunction(EmbeddingFunction):
+     def __call__(self, input: Documents) -> Embeddings:
+         embeddings = jina.encode(input) #max_length=2048
+         return(embeddings.tolist())
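The custom embedding function just wraps the Jina model's encode() call, so it can be tried in isolation before handing it to ChromaDB; a minimal sketch (the sample sentences are made up, and the 768-dimensional output assumes the base-size model):

ef = JinaEmbeddingFunction()
vectors = ef(["Das ist ein Testsatz.", "This is a test sentence."])
print(len(vectors), len(vectors[0]))  # expected: 2 texts, one 768-dimensional vector each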
+
+ dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
+ onPrem = True if(os.path.exists(dbPath)) else False
+ if(onPrem==False): dbPath="/home/user/app/db"
+
+ #onPrem=True # uncomment to override automatic detection
  print(dbPath)
  path=dbPath
  client = chromadb.PersistentClient(path=path)
  print(client.heartbeat())
  print(client.get_version())
  print(client.list_collections())
+ jina_ef=JinaEmbeddingFunction()
+ embeddingModel=jina_ef

  from huggingface_hub import InferenceClient
  import gradio as gr
+ import json
+ inferenceClient = InferenceClient(
      "mistralai/Mixtral-8x7B-Instruct-v0.1"
      #"mistralai/Mistral-7B-Instruct-v0.1"
  )
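This InferenceClient is used further down in streaming mode; a minimal sketch of that call pattern (the prompt text is only an illustration):

for chunk in inferenceClient.text_generation("<s>[INST] Say hello [/INST]",
                                             max_new_tokens=32, stream=True, details=True):
    print(chunk.token.text, end="")  # tokens arrive one by one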
  def format_prompt(message, history):
      prompt = "<s>"
      #for user_prompt, bot_response in history:
      prompt += f"[INST] {message} [/INST]"
      return prompt

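With the history loop commented out, format_prompt only wraps the current message in Mixtral's instruction tags, for example:

print(format_prompt("What is in the uploaded PDF?", []))
# <s>[INST] What is in the uploaded PDF? [/INST]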
+
+ from pypdf import PdfReader
+ import ocrmypdf
+ def convertPDF(pdf_file, allow_ocr=False):
+     reader = PdfReader(pdf_file)
+     full_text = ""
+     page_list = []
+     def extract_text_from_pdf(reader):
+         full_text = ""
+         page_list = []
+         page_count = 1
+         for idx, page in enumerate(reader.pages):
+             text = page.extract_text()
+             if len(text) > 0:
+                 page_list.append(text)
+                 #full_text += f"---- Page {idx} ----\n" + text + "\n\n"
+                 page_count += 1
+         return full_text.strip(), page_count, page_list
+     # Check if there are any images
+     image_count = sum(len(page.images) for page in reader.pages)
+     # If there are images and not much content, perform OCR on the document
+     if allow_ocr:
+         print(f"{image_count} Images")
+         if image_count > 0 and len(full_text) < 1000:
+             out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
+             ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
+             reader = PdfReader(out_pdf_file)
+     # Extract text:
+     full_text, page_count, page_list = extract_text_from_pdf(reader)
+     l = len(page_list)
+     print(f"{l} Pages")
+     # Extract metadata
+     metadata = {
+         "author": reader.metadata.author,
+         "creator": reader.metadata.creator,
+         "producer": reader.metadata.producer,
+         "subject": reader.metadata.subject,
+         "title": reader.metadata.title,
+         "image_count": image_count,
+         "page_count": page_count,
+         "char_count": len(full_text),
+     }
+     return page_list, full_text, metadata
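A short usage sketch for convertPDF; the file path is a placeholder:

pages, full_text, meta = convertPDF("/path/to/some_document.pdf")
print(meta["page_count"], "pages,", meta["image_count"], "images")
print(pages[0][:200])  # first 200 characters of the first extracted page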
+
+ def split_with_overlap(text,chunk_size=3500, overlap=700):
+     chunks=[]
+     step=max(1,chunk_size-overlap)
+     for i in range(0,len(text),step):
+         end=min(i+chunk_size,len(text))
+         #chunk = text[i:i+chunk_size]
+         chunks.append(text[i:end])
+     return chunks
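For example, with a chunk size of 4 and an overlap of 2 the window advances two characters at a time:

print(split_with_overlap("abcdefghij", chunk_size=4, overlap=2))
# ['abcd', 'cdef', 'efgh', 'ghij', 'ij']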
+
+ def add_doc(path):
+     print("def add_doc!")
+     print(path)
+     if(str.lower(path).endswith(".pdf")):
+         doc=convertPDF(path)
+         doc="\n\n".join(doc[0])
+         gr.Info("PDF uploaded, start Indexing!")
+     else:
+         gr.Info("Error: Only pdfs are accepted!")
+     client = chromadb.PersistentClient(path="output/general_knowledge")
+     print(str(client.list_collections()))
+     #global collection
+     dbName="test"
+     if("name="+dbName in str(client.list_collections())):
+         client.delete_collection(name=dbName)
+     collection = client.create_collection(
+         dbName,
+         embedding_function=embeddingModel,
+         metadata={"hnsw:space": "cosine"})
+     corpus=split_with_overlap(doc,3500,700)
+     print(len(corpus))
+     then = datetime.now()
+     x=collection.get(include=[])["ids"]
+     print(len(x))
+     if(len(x)==0):
+         chunkSize=40000
+         for i in range(round(len(corpus)/chunkSize+0.5)): #0 is first batch, 3 is last (incomplete) batch given 133497 texts
+             print("embed batch "+str(i)+" of "+str(round(len(corpus)/chunkSize+0.5)))
+             ids=list(range(i*chunkSize,(i*chunkSize+chunkSize)))
+             batch=corpus[i*chunkSize:(i*chunkSize+chunkSize)]
+             textIDs=[str(id) for id in ids[0:len(batch)]]
+             ids=[str(id+len(x)+1) for id in ids[0:len(batch)]] # id refers to chromadb-unique ID
+             collection.add(documents=batch, ids=ids,
+                 metadatas=[{"date": str("2024-10-10")} for b in batch]) #"textID":textIDs, "id":ids,
+             print("finished batch "+str(i)+" of "+str(round(len(corpus)/40000+0.5)))
+     now = datetime.now()
+     gr.Info(f"Indexing complete!")
+     print(now-then) # too much GPU memory for sentence-wise embeddings; about 0:00:10.375087 for chunks
+     return(collection)
+
|
151 |
+
#split_with_overlap("test me if you can",2,1)
|
152 |
+
|
153 |
+
import gradio as gr
|
154 |
+
import re
|
155 |
+
def multimodalResponse(message,history,headerPattern,sentenceWiseSplitting):
|
156 |
+
print("def multimodal response!")
|
157 |
+
length=str(len(history))
|
158 |
+
query=message["text"]
|
159 |
+
if(len(message["files"])>0): # is there at least one file attached?
|
160 |
+
collection=add_doc(message["files"][0])
|
161 |
+
client = chromadb.PersistentClient(path="output/general_knowledge")
|
162 |
+
print(str(client.list_collections()))
|
163 |
+
x=collection.get(include=[])["ids"]
|
164 |
+
context=collection.query(query_texts=[query], n_results=1)
|
165 |
+
print(str(context))
|
166 |
+
#context=["<context "+str(i+1)+">\n"+c+"\n</context "+str(i+1)+">" for i, c in enumerate(retrievedTexts)]
|
167 |
+
#context="\n\n".join(context)
|
168 |
+
#return context
|
169 |
+
if temperature < 1e-2: temperature = 1e-2
|
170 |
+
top_p = float(top_p)
|
171 |
+
generate_kwargs = dict(
|
172 |
+
temperature=float(0.9),
|
173 |
+
max_new_tokens=5000,
|
174 |
+
top_p=0.95,
|
175 |
+
repetition_penalty=1.0,
|
176 |
do_sample=True,
|
177 |
seed=42,
|
178 |
+
)
|
179 |
+
system="Given the following conversation, relevant context, and a follow up question, "+\
|
180 |
+
"reply with an answer to the current question the user is asking. "+\
|
181 |
+
"Return only your response to the question given the above information "+\
|
182 |
+
"following the users instructions as needed.\n\nContext:"+\
|
183 |
+
str(context)
|
184 |
+
print(system)
|
185 |
+
formatted_prompt = format_prompt(system+"\n"+prompt, history)
|
186 |
+
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
|
187 |
+
output = ""
|
188 |
+
for response in stream:
|
189 |
+
output += response.token.text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
          yield output
+     #output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
+     yield output
+
+ i=gr.ChatInterface(multimodalResponse,
+     title="pdfChatbot",
+     multimodal=True,
+     additional_inputs=[
+         gr.Dropdown(
+             info="select retrieval version",
+             choices=["1","2","3"],
+             value=["1"],
+             label="Retrieval Version")])
+ i.launch() #allowed_paths=["."])
+
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
|
|
|
|