AFischer1985 committed on
Commit 0ad705b
1 Parent(s): e00a35e

Update run.py

Files changed (1)
  1. run.py +85 -41
run.py CHANGED
@@ -2,50 +2,70 @@
  # Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
  # Author: Andreas Fischer
  # Date: October 10th, 2024
- # Last update: October 12th, 2024
  ##########################################################################################
  
  import os
- import chromadb
- from datetime import datetime
- from chromadb import Documents, EmbeddingFunction, Embeddings
- from chromadb.utils import embedding_functions
- from transformers import AutoTokenizer, AutoModel
  import torch
  jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
  #jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
- device='cuda' if torch.cuda.is_available() else 'cpu'
- #device='cpu' #'cuda' if torch.cuda.is_available() else 'cpu'
  jina.to(device) #cuda:0
  print(device)
  
  class JinaEmbeddingFunction(EmbeddingFunction):
      def __call__(self, input: Documents) -> Embeddings:
          embeddings = jina.encode(input) #max_length=2048
          return(embeddings.tolist())
  
- dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
  onPrem = True if(os.path.exists(dbPath)) else False
- if(onPrem==False): dbPath="/home/user/app/db"
  
- #onPrem=True # uncomment to override automatic detection
  print(dbPath)
- path=dbPath
- client = chromadb.PersistentClient(path=path)
  print(client.heartbeat())
  print(client.get_version())
  print(client.list_collections())
  jina_ef=JinaEmbeddingFunction()
  embeddingModel=jina_ef
  
- myModel="mistralai/Mixtral-8x7b-instruct-v0.1"
- #mod="mistralai/Mixtral-8x7b-instruct-v0.1"
- #tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
- #cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
- #cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
- #res=tok.apply_chat_template(cha)
- #print(tok.decode(res))
  
  def format_prompt0(message, history):
      prompt = "<s>"
@@ -56,6 +76,10 @@ def format_prompt0(message, history):
      return prompt
  
  def format_prompt(message, history, system=None, RAGAddon=None, system2=None, zeichenlimit=None,historylimit=4, removeHTML=False):
      if zeichenlimit is None: zeichenlimit=1000000000 # :-)
      startOfString="<s>" #<s> [INST] U1 [/INST] A1</s> [INST] U2 [/INST] A2</s>
@@ -71,8 +95,8 @@ def format_prompt(message, history, system=None, RAGAddon=None, system2=None, ze
      for user_message, bot_response in history[-historylimit:]:
          if user_message is None: user_message = ""
          if bot_response is None: bot_response = ""
-         #bot_response = re.sub("\n\n<details>((.|\n)*?)</details>","", bot_response) # remove RAG-components
-         if removeHTML==True: bot_response = re.sub("<(.*?)>","\n", bot_response) # remove HTML-components in general (may cause bugs with markdown-rendering)
          if user_message is not None: prompt += template1.format(message=user_message[:zeichenlimit])
          if bot_response is not None: prompt += template2.format(response=bot_response[:zeichenlimit])
      if message is not None: prompt += template1.format(message=message[:zeichenlimit])
@@ -81,8 +105,10 @@ def format_prompt(message, history, system=None, RAGAddon=None, system2=None, ze
      return startOfString+prompt
  
- from pypdf import PdfReader
- import ocrmypdf
  def convertPDF(pdf_file, allow_ocr=False):
      reader = PdfReader(pdf_file)
      full_text = ""
@@ -100,7 +126,7 @@ def convertPDF(pdf_file, allow_ocr=False):
          return full_text.strip(), page_count, page_list
      # Check if there are any images
      image_count = sum(len(page.images) for page in reader.pages)
-     # If there are images and not much content, perform OCR on the document
      if allow_ocr:
          print(f"{image_count} Images")
          if image_count > 0 and len(full_text) < 1000:
@@ -124,16 +150,24 @@ def convertPDF(pdf_file, allow_ocr=False):
      }
      return page_list, full_text, metadata
  
  def split_with_overlap(text,chunk_size=3500, overlap=700):
      chunks=[]
      step=max(1,chunk_size-overlap)
      for i in range(0,len(text),step):
          end=min(i+chunk_size,len(text))
-         #chunk = text[i:i+chunk_size]
          chunks.append(text[i:end])
      return chunks
  
  def add_doc(path, session):
      print("def add_doc!")
      print(path)
@@ -148,9 +182,8 @@ def add_doc(path, session):
          anhang=True
      else:
          gr.Info("No PDF attached - answer based on DB_"+str(session)+".")
-     client = chromadb.PersistentClient(path="output/general_knowledge")
      print(str(client.list_collections()))
-     #global collection
      print(str(session))
      dbName="DB_"+str(session)
      if(not "name="+dbName in str(client.list_collections())):
@@ -184,15 +217,14 @@ def add_doc(path, session):
      print(now-then) # too many GB for sentences (GPU); about 0:00:10.375087 for chunks
      return(collection)
  
- 
  #split_with_overlap("test me if you can",2,1)
- from datetime import date
- databases=[(date.today(),"0")] # list of all databases
  
- from huggingface_hub import InferenceClient
- import gradio as gr
- import re
- def multimodalResponse(message, history, dropdown, hfToken, request: gr.Request):
      print("def multimodal response!")
      if(hfToken.startswith("hf_")): # use HF-hub with custom token if token is provided
          inferenceClient = InferenceClient(model=myModel, token=hfToken)
@@ -213,7 +245,7 @@ def multimodalResponse(message, history, dropdown, hfToken, request: gr.Request)
          collection=add_doc(message["files"][0], session)
      else: # otherwise, you still want to get the collection with the session-based db
          collection=add_doc(message["text"], session)
-     client = chromadb.PersistentClient(path="output/general_knowledge")
      print(str(client.list_collections()))
      x=collection.get(include=[])["ids"]
      context=collection.query(query_texts=[query], n_results=1)
@@ -238,15 +270,27 @@ def multimodalResponse(message, history, dropdown, hfToken, request: gr.Request)
      #formatted_prompt = format_prompt0(system+"\n"+query, history)
      formatted_prompt = format_prompt(query, history,system=system)
      print(formatted_prompt)
-     stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
      output = ""
-     for response in stream:
-         output += response.token.text
          yield output
-     #output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br>"+str(context)+"</details>"
      yield output
  
- i=gr.ChatInterface(multimodalResponse,
      title="Frag dein PDF",
      multimodal=True,
      additional_inputs=[
 
run.py (resulting version):

  # Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
  # Author: Andreas Fischer
  # Date: October 10th, 2024
+ # Last update: October 14th, 2024
  ##########################################################################################
  
  import os
+ 
  import torch
+ from transformers import AutoTokenizer, AutoModel # embedding model (chromaDB)
+ from datetime import datetime, date # add_doc
+ import chromadb # chromaDB
+ from chromadb import Documents, EmbeddingFunction, Embeddings # chromaDB
+ from chromadb.utils import embedding_functions # chromaDB
+ import ocrmypdf # convertPDF
+ from pypdf import PdfReader # convertPDF
+ import re # format_prompt
+ import gradio as gr # multimodal_response
+ from huggingface_hub import InferenceClient # multimodal_response
+ 
+ 
+ #---------------------------------------------------
+ # Specify models for text generation and embeddings
+ #---------------------------------------------------
+ 
+ myModel="mistralai/Mixtral-8x7b-instruct-v0.1"
+ #mod="mistralai/Mixtral-8x7b-instruct-v0.1"
+ #tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
+ #cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
+ #cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
+ #res=tok.apply_chat_template(cha)
+ #print(tok.decode(res))
+ 
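Note: the commented-out block above probes Mixtral's chat template. A minimal sketch of that probe (assuming access to the gated Mixtral repo; passing tokenize=False returns the rendered template as a string, so the decode step is unnecessary):

```python
from transformers import AutoTokenizer

# Assumes the gated repo is accessible; pass token="hf_..." if needed.
tok = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
cha = [{"role": "user", "content": "U1"}, {"role": "assistant", "content": "A1"},
       {"role": "user", "content": "U2"}, {"role": "assistant", "content": "A2"}]
print(tok.apply_chat_template(cha, tokenize=False))
# Expected shape: <s>[INST] U1 [/INST]A1</s>[INST] U2 [/INST]A2</s>
```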
  jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
  #jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
+ device='cuda:0' if torch.cuda.is_available() else 'cpu'
  jina.to(device) #cuda:0
  print(device)
  
+ 
+ #-----------------
+ # ChromaDB-client
+ #-----------------
+ 
  class JinaEmbeddingFunction(EmbeddingFunction):
      def __call__(self, input: Documents) -> Embeddings:
          embeddings = jina.encode(input) #max_length=2048
          return(embeddings.tolist())
  
+ dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
  onPrem = True if(os.path.exists(dbPath)) else False
+ if(onPrem==False): dbPath="/home/user/app/db/"
  
  print(dbPath)
+ client = chromadb.PersistentClient(path=dbPath)
  print(client.heartbeat())
  print(client.get_version())
  print(client.list_collections())
+ 
  jina_ef=JinaEmbeddingFunction()
  embeddingModel=jina_ef
+ databases=[(date.today(),"0")] # start a list of databases
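A quick smoke test for the embedding function (illustrative, not part of the commit): ChromaDB calls it with a list of documents and expects one vector per document.

```python
# Illustrative check: jina-embeddings-v2-base-de produces 768-dimensional vectors.
vecs = jina_ef(["Das ist ein Test.", "Noch ein Satz."])
print(len(vecs), len(vecs[0]))  # -> 2 768
```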
  
+ #---------------------------------------------------------------------
+ # Function for formatting single message according to prompt template
+ #---------------------------------------------------------------------
+ 
  def format_prompt0(message, history):
      prompt = "<s>"
  
      return prompt
  
+ #-------------------------------------------------------------------------
+ # Function for formatting multiturn-dialogue according to prompt template
+ #-------------------------------------------------------------------------
+ 
  def format_prompt(message, history, system=None, RAGAddon=None, system2=None, zeichenlimit=None,historylimit=4, removeHTML=False):
      if zeichenlimit is None: zeichenlimit=1000000000 # :-)
      startOfString="<s>" #<s> [INST] U1 [/INST] A1</s> [INST] U2 [/INST] A2</s>
  
      for user_message, bot_response in history[-historylimit:]:
          if user_message is None: user_message = ""
          if bot_response is None: bot_response = ""
+         bot_response = re.sub("\n\n<details>((.|\n)*?)</details>","", bot_response) # remove RAG-components
+         if removeHTML==True: bot_response = re.sub("<(.*?)>","\n", bot_response) # remove HTML-components in general (may cause bugs with markdown-rendering)
          if user_message is not None: prompt += template1.format(message=user_message[:zeichenlimit])
          if bot_response is not None: prompt += template2.format(response=bot_response[:zeichenlimit])
      if message is not None: prompt += template1.format(message=message[:zeichenlimit])
  
      return startOfString+prompt
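template1 and template2 are assigned in lines this hunk elides; the inline comment documents the intended Mixtral layout. A hedged reconstruction of what a one-turn history plus a new message should produce (the template strings are assumptions based on that comment):

```python
# Assumed templates, reconstructed from the comment above; not the commit's literal code.
startOfString = "<s>"
template1 = " [INST] {message} [/INST]"
template2 = " {response}</s>"

prompt = ""
for user_message, bot_response in [("U1", "A1")]:
    prompt += template1.format(message=user_message)
    prompt += template2.format(response=bot_response)
prompt += template1.format(message="U2")
print(startOfString + prompt)
# -> <s> [INST] U1 [/INST] A1</s> [INST] U2 [/INST]
```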
  
+ #--------------------------------------------
+ # Function for converting pdf-files to text
+ #--------------------------------------------
+ 
  def convertPDF(pdf_file, allow_ocr=False):
      reader = PdfReader(pdf_file)
      full_text = ""
  
          return full_text.strip(), page_count, page_list
      # Check if there are any images
      image_count = sum(len(page.images) for page in reader.pages)
+     # If there are images and not much content, you may want to perform OCR on the document
      if allow_ocr:
          print(f"{image_count} Images")
          if image_count > 0 and len(full_text) < 1000:
  
      }
      return page_list, full_text, metadata
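A minimal usage sketch (the filename is a placeholder): the function returns per-page texts, the concatenated text, and a metadata dict, and only attempts OCR when allow_ocr=True and the extracted text layer looks thin.

```python
# Hypothetical call; "example.pdf" is a placeholder.
page_list, full_text, metadata = convertPDF("example.pdf", allow_ocr=False)
print(metadata)            # metadata dict assembled above
print(page_list[0][:200])  # first 200 characters of page 1
```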
  
+ 
+ #------------------------------------------
+ # Function for splitting text with overlap
+ #------------------------------------------
+ 
  def split_with_overlap(text,chunk_size=3500, overlap=700):
      chunks=[]
      step=max(1,chunk_size-overlap)
      for i in range(0,len(text),step):
          end=min(i+chunk_size,len(text))
          chunks.append(text[i:end])
      return chunks
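Worked example: with chunk_size=4 and overlap=2 the step is max(1, 4-2)=2, so consecutive chunks share their last two characters.

```python
print(split_with_overlap("abcdefghij", chunk_size=4, overlap=2))
# -> ['abcd', 'cdef', 'efgh', 'ghij', 'ij']
```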
  
+ #---------------------------------------------------------------
+ # Function for adding docs to ChromaDB and/or return collection
+ #---------------------------------------------------------------
+ 
  def add_doc(path, session):
      print("def add_doc!")
      print(path)
  
          anhang=True
      else:
          gr.Info("No PDF attached - answer based on DB_"+str(session)+".")
+     client = chromadb.PersistentClient(path=dbPath)
      print(str(client.list_collections()))
      print(str(session))
      dbName="DB_"+str(session)
      if(not "name="+dbName in str(client.list_collections())):
  
      print(now-then) # too many GB for sentences (GPU); about 0:00:10.375087 for chunks
      return(collection)
  
  #split_with_overlap("test me if you can",2,1)
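The collection creation and chunk insertion happen in lines elided between the hunks above. A hedged sketch of that step, assuming the helpers shown in this file (the id scheme is an assumption):

```python
# Assumed shape of the elided body; not the commit's literal code.
collection = client.get_or_create_collection(name=dbName, embedding_function=jina_ef)
page_list, full_text, metadata = convertPDF(path)
chunks = split_with_overlap(full_text, chunk_size=3500, overlap=700)
collection.add(documents=chunks, ids=[str(i) for i in range(len(chunks))])
```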
  
+ 
+ #--------------------------------------------------------------
+ # Function for response to user queries and potential addenda
+ #--------------------------------------------------------------
+ 
+ def multimodal_response(message, history, dropdown, hfToken, request: gr.Request):
      print("def multimodal response!")
      if(hfToken.startswith("hf_")): # use HF-hub with custom token if token is provided
          inferenceClient = InferenceClient(model=myModel, token=hfToken)
  
          collection=add_doc(message["files"][0], session)
      else: # otherwise, you still want to get the collection with the session-based db
          collection=add_doc(message["text"], session)
+     client = chromadb.PersistentClient(path=dbPath)
      print(str(client.list_collections()))
      x=collection.get(include=[])["ids"]
      context=collection.query(query_texts=[query], n_results=1)
  
      #formatted_prompt = format_prompt0(system+"\n"+query, history)
      formatted_prompt = format_prompt(query, history,system=system)
      print(formatted_prompt)
      output = ""
+     try:
+         stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         for response in stream:
+             output += response.token.text
+             yield output
+     except Exception as e:
+         output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
+         if(len(context)>0):
+             output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
          yield output
+         print(str(e))
+     if(len(context)>0):
+         output=output+"\n\n<br><details open><summary><strong>Quellen</strong></summary><br><ul>"+ "".join(["<li>" + c + "</li>" for c in context])+"</ul></details>"
      yield output
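generate_kwargs is defined in lines this diff elides; a plausible configuration for streaming Mixtral output via text_generation (values are assumptions, not the commit's settings):

```python
# Assumed example values; the actual generate_kwargs are not shown in this diff.
generate_kwargs = dict(
    temperature=0.7,         # sampling temperature
    max_new_tokens=512,      # cap on generated tokens
    top_p=0.95,              # nucleus sampling
    repetition_penalty=1.2,  # discourage loops
    do_sample=True,
)
```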
  
+ 
+ #------------------------------
+ # Launch Gradio-ChatInterface
+ #------------------------------
+ 
+ i=gr.ChatInterface(multimodal_response,
      title="Frag dein PDF",
      multimodal=True,
      additional_inputs=[