Chris4K committed
Commit c97201a · verified · 1 Parent(s): 2520078

Update app.py

Files changed (1):
  app.py +63 -277
app.py CHANGED
@@ -1,165 +1,19 @@
- #####################################
- ##
- #####################################
-
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
  from langchain_community.llms import HuggingFaceHub
-
-
- ###### other models:
- # "Trelis/Llama-2-7b-chat-hf-sharded-bf16"
- # "bn22/Mistral-7B-Instruct-v0.1-sharded"
- # "HuggingFaceH4/zephyr-7b-beta"
-
- # function for loading 4-bit quantized model
- def load_model( ):
-
-     model = HuggingFaceHub(
-         repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
-         model_kwargs={"max_length": 1048, "temperature":0.2, "max_new_tokens":256, "top_p":0.95, "repetition_penalty":1.0},
-     )
-
-     return model
-
- ##################################################
- ## vs chat
- ##################################################
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
-
- from langchain_core.messages import AIMessage, HumanMessage
- from langchain_community.document_loaders import WebBaseLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import Chroma
-
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
- from langchain.vectorstores.faiss import FAISS
-
-
- from dotenv import load_dotenv
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
  from langchain.chains.combine_documents import create_stuff_documents_chain
-
-
- load_dotenv()
-
- def get_vectorstore():
-     '''
-     FAISS
-     A FAISS vector store containing the embeddings of the text chunks.
-     '''
-     model = "BAAI/bge-base-en-v1.5"
-     encode_kwargs = {
-         "normalize_embeddings": True
-     } # set True to compute cosine similarity
-     embeddings = HuggingFaceBgeEmbeddings(
-         model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
-     )
-     # load from disk
-     vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db", embedding_function=embeddings)
-     return vector_store
-
- def get_vectorstore_from_url(url):
-     # get the text in document form
-     loader = WebBaseLoader(url)
-     document = loader.load()
-
-     # split the document into chunks
-     text_splitter = RecursiveCharacterTextSplitter()
-     document_chunks = text_splitter.split_documents(document)
-     #######
-     '''
-     FAISS
-     A FAISS vector store containing the embeddings of the text chunks.
-     '''
-     model = "BAAI/bge-base-en-v1.5"
-     encode_kwargs = {
-         "normalize_embeddings": True
-     } # set True to compute cosine similarity
-     embeddings = HuggingFaceBgeEmbeddings(
-         model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
-     )
-     # load from disk
-     #vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db", embedding_function=embeddings)
-
-     #vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
-     vector_store = Chroma.from_documents(document_chunks, embeddings, persist_directory="/home/user/.cache/chroma_db")
-
-     all_documents = vector_store.get()['documents']
-     total_records = len(all_documents)
-     print("Total records in the collection: ", total_records)
-
-     return vector_store
-
- def get_context_retriever_chain(vector_store):
-
-     llm = load_model( )
-
-     retriever = vector_store.as_retriever()
-
-     prompt = ChatPromptTemplate.from_messages([
-         MessagesPlaceholder(variable_name="chat_history"),
-         ("user", "{input}"),
-         ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation")
-     ])
-
-     retriever_chain = create_history_aware_retriever(llm, retriever, prompt)
-
-     return retriever_chain
-
- def get_conversational_rag_chain(retriever_chain):
-
-     llm = load_model( )
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", "Du bist eine freundlicher Mitarbeiterin Namens Susie und arbeitest in einenm Call Center. Du beantwortest basierend auf dem Context. Benutze nur den Inhalt des Context. Füge wenn möglich die Quelle hinzu. Antworte mit: Ich bin mir nicht sicher. Wenn die Antwort nicht aus dem Context hervorgeht. Antworte auf Deutsch, bitte? CONTEXT:\n\n{context}"),
-         MessagesPlaceholder(variable_name="chat_history"),
-         ("user", "{input}"),
-     ])
-
-     stuff_documents_chain = create_stuff_documents_chain(llm,prompt)
-
-     return create_retrieval_chain(retriever_chain, stuff_documents_chain)
-
-
- ###################
-
- ###################
- import gradio as gr
-
-
- chat_history = [] # Set your chat history here
-
- # Define your function here
- def get_response(user_input):
-
-     vs = get_vectorstore()
-     chat_history =[]
-     retriever_chain = get_context_retriever_chain(vs)
-     conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
-
-     response = conversation_rag_chain.invoke({
-         "chat_history": chat_history,
-         "input": user_input
-     })
-     #print("get_response " +response)
-     res = response['answer']
-     parts = res.split(" Assistant: ")
-     last_part = parts[-1]
-     return last_part
-
-
- ###############
- #####
- #####
- #####
- ####
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
+ from langchain_community.document_loaders import WebBaseLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import Chroma
+ from urllib.parse import urlparse, urljoin
+ import requests
+ from bs4 import BeautifulSoup

  app = FastAPI()

- # middlewares to allow cross orgin communications
+ # Middleware to allow cross-origin communications
  app.add_middleware(
      CORSMiddleware,
      allow_origins=['*'],
@@ -168,100 +22,20 @@ app.add_middleware(
      allow_headers=['*'],
  )

+ # Function to crawl all URLs from a domain
+ def get_all_links_from_domain(domain_url):
+     visited_urls = set()
+     domain_links = set()
+     parsed_initial_url = urlparse(domain_url)
+     base_domain = parsed_initial_url.netloc
+     get_links_from_page(domain_url, visited_urls, domain_links, base_domain)
+     return domain_links

- @app.post("/generate/")
- def generate(user_input, history=[]):
-     print("----yuhu -----")
-     return get_response(user_input, history)
- ##################
-
- def history_to_dialog_format(chat_history: list[str]):
-     dialog = []
-     if len(chat_history) > 0:
-         for idx, message in enumerate(chat_history[0]):
-             role = "user" if idx % 2 == 0 else "assistant"
-             dialog.append({
-                 "role": role,
-                 "content": message,
-             })
-     return dialog
-
- def get_response(message, history):
-     dialog = history_to_dialog_format(history)
-     dialog.append({"role": "user", "content": message})
-
-     # Define the prompt as a ChatPromptValue object
-     #user_input = ChatPromptValue(user_input)
-
-     # Convert the prompt to a tensor
-     #input_ids = user_input.tensor
-
-
-     #vs = get_vectorstore_from_url(user_url, all_domain)
-     vs = get_vectorstore()
-
-     history =[]
-     retriever_chain = get_context_retriever_chain(vs)
-     conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
-
-     response = conversation_rag_chain.invoke({
-         "chat_history": history,
-         "input": message + " Assistant: ",
-         "chat_message": message + " Assistant: "
-     })
-     #print("get_response " +response)
-     res = response['answer']
-     parts = res.split(" Assistant: ")
-     last_part = parts[-1]
-     return last_part#[-1]['generation']['content']
-
-
- ######
-
- ########
- import requests
- from bs4 import BeautifulSoup
- from urllib.parse import urlparse, urljoin
-
-
- def get_links_from_page(url, visited_urls, all_links):
-     if url in visited_urls:
-         return
-
-     if len(visited_urls) > 2:
-         return
-
-     visited_urls.add(url)
-
-     print("Getting next" + url)
-     response = requests.get(url)
-
-     if response.status_code == 200:
-         soup = BeautifulSoup(response.content, 'html.parser')
-         base_url = urlparse(url).scheme + '://' + urlparse(url).netloc
-         links = soup.find_all('a', href=True)
-
-         for link in links:
-             href = link.get('href')
-             absolute_url = urljoin(base_url, href)
-             all_links.add(absolute_url)
-             get_links_from_page(absolute_url, visited_urls, all_links)
-
-     else:
-         print(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
-
-
+ # Function to crawl links from a page within the same domain
  def get_links_from_page(url, visited_urls, all_links, base_domain):
      if url in visited_urls:
          return

-     if len(visited_urls) > 2:
-         return
-
      visited_urls.add(url)
      print("Getting next" + url)
      response = requests.get(url)
@@ -283,48 +57,60 @@ def get_links_from_page(url, visited_urls, all_links, base_domain):
      else:
          print(f"Failed to retrieve content from {url}. Status code: {response.status_code}")

+ # Function to load the RAG model
+ def load_rag_model():
+     model = HuggingFaceHub(
+         repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+         model_kwargs={"max_length": 1048, "temperature":0.2, "max_new_tokens":256, "top_p":0.95, "repetition_penalty":1.0},
+     )
+     return model

- def get_all_links_from_domain(domain_url):
-     visited_urls = set()
-     domain_links = set()
-     parsed_initial_url = urlparse(domain_url)
-     base_domain = parsed_initial_url.netloc
-     get_links_from_page(domain_url, visited_urls, domain_links, base_domain)
-     return domain_links
-
-
+ # Function to index URLs in RAG
+ def index_urls_in_rag(urls):
+     # Create a vector store for storing embeddings of documents (no
+     # embedding_function is passed, so Chroma falls back to its default embeddings)
+     vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db")
+
+     # Load the RAG model
+     rag_model = load_rag_model()

+     for url in urls:
+         # Get text from the URL
+         loader = WebBaseLoader(url)
+         document = loader.load()

+         # Split the document into chunks
+         text_splitter = RecursiveCharacterTextSplitter()
+         document_chunks = text_splitter.split_documents(document)

- def simple(text:str):
-     return text +" hhhmmm "
+         # Index document chunks into the vector store
+         vector_store.add_documents(document_chunks)

- fe_app = gr.ChatInterface(
-     fn=get_response,
-     #fn=simple,
-     # inputs=["text"],
-     # outputs="text",
-     title="Chat with Websites",
-     description="Schreibe hier deine Frage rein...",
-     #allow_flagging=False
-     retry_btn=None,
-     undo_btn=None,
-     clear_btn=None
- )
+     # Convert vector store to retriever
+     retriever = vector_store.as_retriever()

- #fe_app.launch(debug=True, share=True)
+     # Define prompt for the history-aware retriever
+     prompt = ChatPromptTemplate.from_messages([
+         MessagesPlaceholder(variable_name="chat_history"),
+         ("user", "{input}"),
+     ])

+     # Create history-aware retriever chain
+     retriever_chain = create_history_aware_retriever(rag_model, retriever, prompt)

- # load the model asynchronously on startup and save it into memory
+     # Combine retrieval with answer generation so that invoking the returned
+     # chain yields a dict with an 'answer' key
+     answer_prompt = ChatPromptTemplate.from_messages([
+         ("system", "Answer the user's question using only the following context:\n\n{context}"),
+         MessagesPlaceholder(variable_name="chat_history"),
+         ("user", "{input}"),
+     ])
+     stuff_documents_chain = create_stuff_documents_chain(rag_model, answer_prompt)
+
+     return create_retrieval_chain(retriever_chain, stuff_documents_chain)
+
+ # Index URLs on app startup
  @app.on_event("startup")
  async def startup():
      domain_url = 'https://www.bofrost.de/faq/'
-     links = get_all_links_from_domain(domain_url)
-     print("Links from the domain:", links)
-
-     #########
-     # Assuming visited_urls is a list of URLs
-     for url in links:
-         vs = get_vectorstore_from_url(url)
-         #load_model()
+     # Keep the chain in module scope so the /generate/ endpoint can use it
+     global retriever_chain
+     urls = get_all_links_from_domain(domain_url)
+     retriever_chain = index_urls_in_rag(urls)
+
+ # Define API endpoint to receive queries and provide responses
+ @app.post("/generate/")
+ def generate(user_input):
+     response = retriever_chain.invoke({
+         "chat_history": [],
+         "input": user_input
+     })
+     return response['answer']
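
For reference, a minimal sketch of serving and querying the new endpoint is below; the module name (app), host, port, and sample question are assumptions, since the commit does not show how the service is run.

    # Serve the app (assuming the file is app.py and uvicorn is installed):
    #   uvicorn app:app --host 0.0.0.0 --port 8000

    import requests

    # user_input is declared as a bare parameter with no request body model,
    # so FastAPI expects it as a query parameter on the POST request.
    resp = requests.post(
        "http://localhost:8000/generate/",
        params={"user_input": "How can I change my order?"},
    )
    print(resp.json())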