AFischer1985 committed
Commit
b7f29b3
1 Parent(s): a380202

Update run.py

Files changed (1): run.py +184 -87
run.py CHANGED
@@ -1,70 +1,51 @@
- #########################################################################################
- # Title: Gradio Interface to LLM-chatbot with RAG-functionality and ChromaDB on HF-Hub
  # Author: Andreas Fischer
- # Date: December 29th, 2023
- # Last update: December 31st, 2023
  ##########################################################################################
 
-
- # Chroma-DB
- #-----------
  import os
  import chromadb
- dbPath="/home/af/Schreibtisch/gradio/Chroma/db"
- if(os.path.exists(dbPath)==False):
-     dbPath="/home/user/app/db"
  print(dbPath)
- #client = chromadb.Client()
  path=dbPath
  client = chromadb.PersistentClient(path=path)
  print(client.heartbeat())
  print(client.get_version())
  print(client.list_collections())
- from chromadb.utils import embedding_functions
- default_ef = embedding_functions.DefaultEmbeddingFunction()
- sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
- #instructor_ef = embedding_functions.InstructorEmbeddingFunction(model_name="hkunlp/instructor-large", device="cuda")
- print(str(client.list_collections()))
-
- global collection
- if("name=ChromaDB1" in str(client.list_collections())):
-     print("ChromaDB1 found!")
-     collection = client.get_collection(name="ChromaDB1", embedding_function=sentence_transformer_ef)
- else:
-     print("ChromaDB1 created!")
-     collection = client.create_collection(
-         "ChromaDB1",
-         embedding_function=sentence_transformer_ef,
-         metadata={"hnsw:space": "cosine"})
-
-     collection.add(
-         documents=["The meaning of life is to love.", "This is a sentence", "This is a sentence too"],
-         metadatas=[{"source": "notion"}, {"source": "google-docs"}, {"source": "google-docs"}],
-         ids=["doc1", "doc2", "doc3"],
-     )
-
- print("Database ready!")
- print(collection.count())
 
 
- # Model
- #-------
-
  from huggingface_hub import InferenceClient
  import gradio as gr
-
- client = InferenceClient(
      "mistralai/Mixtral-8x7B-Instruct-v0.1"
      #"mistralai/Mistral-7B-Instruct-v0.1"
  )
-
-
- # Gradio-GUI
- #------------
-
- import gradio as gr
- import json
-
  def format_prompt(message, history):
      prompt = "<s>"
      #for user_prompt, bot_response in history:
@@ -73,45 +54,161 @@ def format_prompt(message, history):
      prompt += f"[INST] {message} [/INST]"
      return prompt
 
- def response(
-     prompt, history, temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0,
- ):
-     temperature = float(temperature)
-     if temperature < 1e-2: temperature = 1e-2
-     top_p = float(top_p)
-     generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
          do_sample=True,
          seed=42,
-     )
-     addon=""
-     results=collection.query(
-         query_texts=[prompt],
-         n_results=2,
-         #where={"source": "google-docs"}
-         #where_document={"$contains":"search_string"}
-     )
-     dists=["<small>(relevance: "+str(round((1-d)*100)/100)+";" for d in results['distances'][0]]
-     sources=["source: "+s["source"]+")</small>" for s in results['metadatas'][0]]
-     results=results['documents'][0]
-     combination = zip(results,dists,sources)
-     combination = [' '.join(triplets) for triplets in combination]
-     print(combination)
-     if(len(results)>1):
-         addon=" Bitte berücksichtige bei deiner Antwort ggf. folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Beantworte die Frage knapp und präzise. Ignoriere unpassende Datenbank-Auszüge OHNE sie zu kommentieren, zu erwähnen oder aufzulisten:\n"+"\n".join(results)
-     system="Du bist ein KI-basiertes Assistenzsystem."+addon+"\n\nUser-Anliegen:"
-     #body={"prompt":system+"### Instruktion:\n"+message+"\n\n### Antwort:","max_tokens":500, "echo":"False","stream":"True"} #e.g. SauerkrautLM
-     formatted_prompt = format_prompt(system+"\n"+prompt, history)
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-     for response in stream:
-         output += response.token.text
-         yield output
-     output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
      yield output
 
- gr.ChatInterface(response, chatbot=gr.Chatbot(render_markdown=True),title="German RAG-Interface to the Hugging Face Hub").queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864)
- print("Interface up and running!")
 
+ ###########################################################################################
+ # Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
  # Author: Andreas Fischer
+ # Date: October 10th, 2024
+ # Last update: October 10th, 2024
  ##########################################################################################
 
  import os
  import chromadb
+ from datetime import datetime
+ from chromadb import Documents, EmbeddingFunction, Embeddings
+ from chromadb.utils import embedding_functions
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
+ #jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
+ device='cuda' if torch.cuda.is_available() else 'cpu'
+ #device='cpu' # uncomment to force CPU
+ jina.to(device) #cuda:0
+ print(device)
+
+ class JinaEmbeddingFunction(EmbeddingFunction):
+     def __call__(self, input: Documents) -> Embeddings:
+         embeddings = jina.encode(input) #max_length=2048
+         return embeddings.tolist()
+
+ dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
+ onPrem = os.path.exists(dbPath)
+ if(onPrem==False): dbPath="/home/user/app/db"
+
+ #onPrem=True # uncomment to override automatic detection
  print(dbPath)
  path=dbPath
  client = chromadb.PersistentClient(path=path)
  print(client.heartbeat())
  print(client.get_version())
  print(client.list_collections())
+ jina_ef=JinaEmbeddingFunction()
+ embeddingModel=jina_ef
 
 
  from huggingface_hub import InferenceClient
  import gradio as gr
+ import json
+ inferenceClient = InferenceClient(
      "mistralai/Mixtral-8x7B-Instruct-v0.1"
      #"mistralai/Mistral-7B-Instruct-v0.1"
  )
  def format_prompt(message, history):
      prompt = "<s>"
      #for user_prompt, bot_response in history:
      prompt += f"[INST] {message} [/INST]"
      return prompt
 
+
+
+ from pypdf import PdfReader
+ import ocrmypdf
+ def convertPDF(pdf_file, allow_ocr=False):
+     reader = PdfReader(pdf_file)
+     full_text = ""
+     page_list = []
+     def extract_text_from_pdf(reader):
+         full_text = ""
+         page_list = []
+         page_count = 1
+         for idx, page in enumerate(reader.pages):
+             text = page.extract_text()
+             if len(text) > 0:
+                 page_list.append(text)
+                 #full_text += f"---- Page {idx} ----\n" + text + "\n\n"
+             page_count += 1
+         return full_text.strip(), page_count, page_list
+     # Check if there are any images
+     image_count = sum(len(page.images) for page in reader.pages)
+     # If there are images and not much content, perform OCR on the document
+     if allow_ocr:
+         print(f"{image_count} Images")
+         if image_count > 0 and len(full_text) < 1000:
+             out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
+             ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
+             reader = PdfReader(out_pdf_file)
+     # Extract text:
+     full_text, page_count, page_list = extract_text_from_pdf(reader)
+     l = len(page_list)
+     print(f"{l} Pages")
+     # Extract metadata
+     metadata = {
+         "author": reader.metadata.author,
+         "creator": reader.metadata.creator,
+         "producer": reader.metadata.producer,
+         "subject": reader.metadata.subject,
+         "title": reader.metadata.title,
+         "image_count": image_count,
+         "page_count": page_count,
+         "char_count": len(full_text),
+     }
+     return page_list, full_text, metadata
+
+ def split_with_overlap(text, chunk_size=3500, overlap=700):
+     chunks=[]
+     step=max(1,chunk_size-overlap)
+     for i in range(0,len(text),step):
+         end=min(i+chunk_size,len(text))
+         #chunk = text[i:i+chunk_size]
+         chunks.append(text[i:end])
+     return chunks
+
+ def add_doc(path):
+     print("def add_doc!")
+     print(path)
+     if(str.lower(path).endswith(".pdf")):
+         doc=convertPDF(path)
+         doc="\n\n".join(doc[0])
+         gr.Info("PDF uploaded, starting indexing!")
+     else:
+         gr.Info("Error: Only PDFs are accepted!")
+         return None
+     client = chromadb.PersistentClient(path="output/general_knowledge")
+     print(str(client.list_collections()))
+     #global collection
+     dbName="test"
+     if("name="+dbName in str(client.list_collections())):
+         client.delete_collection(name=dbName)
+     collection = client.create_collection(
+         dbName,
+         embedding_function=embeddingModel,
+         metadata={"hnsw:space": "cosine"})
+     corpus=split_with_overlap(doc,3500,700)
+     print(len(corpus))
+     then = datetime.now()
+     x=collection.get(include=[])["ids"]
+     print(len(x))
+     if(len(x)==0):
+         chunkSize=40000
+         for i in range(round(len(corpus)/chunkSize+0.5)): # 0 is the first batch, 3 the last (incomplete) batch given 133497 texts
+             print("embed batch "+str(i)+" of "+str(round(len(corpus)/chunkSize+0.5)))
+             ids=list(range(i*chunkSize,(i*chunkSize+chunkSize)))
+             batch=corpus[i*chunkSize:(i*chunkSize+chunkSize)]
+             textIDs=[str(id) for id in ids[0:len(batch)]]
+             ids=[str(id+len(x)+1) for id in ids[0:len(batch)]] # id refers to the chromadb-unique ID
+             collection.add(documents=batch, ids=ids,
+                 metadatas=[{"date": str("2024-10-10")} for b in batch]) #"textID":textIDs, "id":ids,
+             print("finished batch "+str(i)+" of "+str(round(len(corpus)/40000+0.5)))
+     now = datetime.now()
+     gr.Info("Indexing complete!")
+     print(now-then) # too much GPU memory for sentence-wise embedding; 0:00:10.375087 for chunk-wise
+     return collection
+
+ #split_with_overlap("test me if you can",2,1)
+
+ import gradio as gr
+ import re
+ def multimodalResponse(message, history, headerPattern=None, sentenceWiseSplitting=None): # additional_inputs arrive positionally; defaults keep the call robust
+     print("def multimodal response!")
+     length=str(len(history))
+     query=message["text"]
+     client = chromadb.PersistentClient(path="output/general_knowledge")
+     print(str(client.list_collections()))
+     if(len(message["files"])>0): # is there at least one file attached?
+         collection=add_doc(message["files"][0])
+     else: # no new upload: reuse the previously indexed collection
+         collection=client.get_collection(name="test", embedding_function=embeddingModel)
+     x=collection.get(include=[])["ids"]
+     context=collection.query(query_texts=[query], n_results=1)
+     print(str(context))
+     #context=["<context "+str(i+1)+">\n"+c+"\n</context "+str(i+1)+">" for i, c in enumerate(retrievedTexts)]
+     #context="\n\n".join(context)
+     #return context
+     generate_kwargs = dict(
+         temperature=float(0.9),
+         max_new_tokens=5000,
+         top_p=0.95,
+         repetition_penalty=1.0,
          do_sample=True,
          seed=42,
+     )
+     system="Given the following conversation, relevant context, and a follow up question, "+\
+         "reply with an answer to the current question the user is asking. "+\
+         "Return only your response to the question given the above information "+\
+         "following the user's instructions as needed.\n\nContext:"+\
+         str(context)
+     print(system)
+     formatted_prompt = format_prompt(system+"\n"+query, history)
+     stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+     for response in stream:
+         output += response.token.text
          yield output
+     #output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
+     yield output
+
+ i=gr.ChatInterface(multimodalResponse,
+     title="pdfChatbot",
+     multimodal=True,
+     additional_inputs=[
+         gr.Dropdown(
+             info="select retrieval version",
+             choices=["1","2","3"],
+             value="1",
+             label="Retrieval Version")])
+ i.launch() #allowed_paths=["."])
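
A quick sanity check of the chunking logic used above (plain Python, no app dependencies): each chunk is at most chunk_size characters long, and consecutive chunks share overlap characters, because the stride between chunk starts is chunk_size - overlap.

def split_with_overlap(text, chunk_size=3500, overlap=700):
    chunks = []
    step = max(1, chunk_size - overlap)  # stride between chunk starts
    for i in range(0, len(text), step):
        chunks.append(text[i:min(i + chunk_size, len(text))])
    return chunks

print(split_with_overlap("abcdefghij", chunk_size=4, overlap=2))
# ['abcd', 'cdef', 'efgh', 'ghij', 'ij']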
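
The prompt formatting is unchanged between the two versions: the history loop is commented out, so each call wraps only the system text plus the current message in Mixtral's instruction tags. For example:

def format_prompt(message, history):
    prompt = "<s>"
    prompt += f"[INST] {message} [/INST]"
    return prompt

print(format_prompt("What is in the PDF?", []))
# <s>[INST] What is in the PDF? [/INST]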
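
A minimal sketch of how the custom embedding function plugs into ChromaDB, assuming the jinaai/jina-embeddings-v2-base-de weights can be downloaded; the in-memory client and the demo texts are illustrative, not part of the commit:

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from transformers import AutoModel

jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True)

class JinaEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        return jina.encode(input).tolist()  # encode() is provided by the model's remote code

client = chromadb.Client()  # in-memory client, sufficient for a quick test
collection = client.create_collection("demo", embedding_function=JinaEmbeddingFunction(),
    metadata={"hnsw:space": "cosine"})
collection.add(documents=["Der Himmel ist blau.", "Katzen schlafen gern."], ids=["0", "1"])
print(collection.query(query_texts=["Welche Farbe hat der Himmel?"], n_results=1)["documents"])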
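
The streaming loop relies on huggingface_hub's text-generation API: with stream=True and details=True, text_generation yields one object per generated token, whose text fields are concatenated incrementally. A minimal sketch, assuming a valid Hugging Face token is configured in the environment:

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
stream = client.text_generation("<s>[INST] Say hi [/INST]", max_new_tokens=20,
                                stream=True, details=True, return_full_text=False)
output = ""
for response in stream:
    output += response.token.text  # each item carries one generated token
print(output)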