Chris4K committed on
Commit 520e637 · verified · 1 Parent(s): 7a83584
Files changed (1):
  1. app.py +80 -60

app.py CHANGED
@@ -23,8 +23,6 @@ app.add_middleware(
 )
 
 
-# Define retriever_chain as a global variable
-retriever_chain = None
 
 # Function to crawl all URLs from a domain
 def get_all_links_from_domain(domain_url):
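
Removing the module-level `retriever_chain = None` is only safe if every reader of that name gets the chain some other way; the `get_response` added later in this diff still references `retriever_chain`. A minimal sketch of one conventional alternative, keeping the object on FastAPI's `app.state` (the `build_chain` stub below is hypothetical, standing in for this commit's indexing code):

```python
from fastapi import FastAPI, Request

app = FastAPI()

def build_chain():
    # Hypothetical stand-in for the crawling/indexing/chain setup in this commit.
    class EchoChain:
        def invoke(self, payload: dict) -> dict:
            return {"answer": payload["input"]}
    return EchoChain()

@app.on_event("startup")
async def startup():
    # Keep shared objects on application state instead of a module-level global.
    app.state.retriever_chain = build_chain()

@app.post("/generate/")
def generate(user_input: str, request: Request) -> dict:
    chain = request.app.state.retriever_chain  # read back from app.state
    return chain.invoke({"input": user_input})
```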
@@ -47,9 +45,9 @@ def get_links_from_page(url, visited_urls, all_links, base_domain):
         return
 
     visited_urls.add(url)
-    print("Getting next" + url)
+    print("Getting next " + url)
     response = requests.get(url)
-
+    print(response)
     if response.status_code == 200:
         soup = BeautifulSoup(response.content, 'html.parser')
         base_url = urlparse(url).scheme + '://' + urlparse(url).netloc
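
The crawl still calls `requests.get(url)` with no timeout or error handling, so a single slow or unreachable page can stall or crash the whole startup crawl. A small defensive wrapper, as a sketch (the 10-second timeout is an assumed value, not from the commit):

```python
import requests

def fetch(url: str, timeout: float = 10.0):
    """Fetch a page defensively; returns None instead of raising on network errors."""
    try:
        response = requests.get(url, timeout=timeout)  # bound the per-request wait
        response.raise_for_status()                    # treat 4xx/5xx as failures too
        return response
    except requests.RequestException as exc:
        print(f"Failed to retrieve content from {url}: {exc}")
        return None
```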
@@ -59,29 +57,39 @@ def get_links_from_page(url, visited_urls, all_links, base_domain):
             href = link.get('href')
             absolute_url = urljoin(base_url, href)
             parsed_url = urlparse(absolute_url)
-
-            if parsed_url.netloc == base_domain:
-                all_links.add(absolute_url)
-                get_links_from_page(absolute_url, visited_urls, all_links, base_domain)
+            if absolute_url not in visited_urls:
+                if absolute_url.startswith(base_domain):
+                    print("href " + absolute_url)
+                    all_links.add(absolute_url)
+                    get_links_from_page(absolute_url, visited_urls, all_links, base_domain)
 
     else:
         print(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
 
-# Function to load the RAG model
-def load_rag_model():
-    model = HuggingFaceHub(
-        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
-        model_kwargs={"max_length": 1048, "temperature": 0.2, "max_new_tokens": 256, "top_p": 0.95, "repetition_penalty": 1.0},
-    )
-    return model
+
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain.chains import RetrievalQA
+
+from langchain.memory import ConversationBufferMemory
+
+
 
 # Function to index URLs in RAG
-def index_urls_in_rag(urls):
+def index_urls_in_rag(urls=[]):
+    # Load the embedding model
+    rag_model = "BAAI/bge-base-en-v1.5"
+    encode_kwargs = {
+        "normalize_embeddings": True
+    }  # set True to compute cosine similarity
+    embeddings = HuggingFaceBgeEmbeddings(
+        model_name=rag_model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
+    )
+
     # Create a vector store for storing embeddings of documents
-    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db")
+    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db", embedding_function=embeddings)
 
-    # Load the RAG model
-    rag_model = load_rag_model()
+    print(urls)
+
 
     for url in urls:
         # Get text from the URL
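
The reworked `index_urls_in_rag` now builds BGE embeddings explicitly and hands them to Chroma, where the old version constructed the store with no `embedding_function` at all. A minimal round-trip sketch of the same pattern; the import paths assume a recent `langchain_community` (the commit's own import line for `Chroma` is not visible in this diff), and the sample document is illustrative:

```python
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

# Same BGE model and normalization as the commit; CPU inference assumed.
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en-v1.5",
    encode_kwargs={"normalize_embeddings": True},  # cosine-style similarity
    model_kwargs={"device": "cpu"},
)

# Persist to the same directory the app uses, so the retriever can reload it.
vector_store = Chroma(
    persist_directory="/home/user/.cache/chroma_db",
    embedding_function=embeddings,
)
vector_store.add_documents([Document(page_content="bofrost* liefert Tiefkühlkost.")])

retriever = vector_store.as_retriever()
print(retriever.invoke("Was liefert bofrost*?"))  # nearest chunks for the query
```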
@@ -95,37 +103,56 @@ def index_urls_in_rag(urls):
         # Index document chunks into the vector store
         vector_store.add_documents(document_chunks)
 
+
     # Convert vector store to retriever
     retriever = vector_store.as_retriever()
+
+    return retriever
+
+
+
+# Function to load the RAG model
+def load_model():
+    model = HuggingFaceHub(
+        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        model_kwargs={"max_length": 1048, "temperature": 0.1, "max_new_tokens": 512, "top_p": 0.95, "repetition_penalty": 1.0},
+    )
+    return model
 
-    # Define prompt for RAG model
+def get_conversational_rag_chain(retriever_chain):
+
+    llm = load_model()
+
     prompt = ChatPromptTemplate.from_messages([
+        ("system", "Du bist eine freundliche Mitarbeiterin namens Susie und arbeitest in einem Call Center. Antworte mit: 'Ich bin mir nicht sicher. Wollen Sie einen Mitarbeiter sprechen?', wenn die Antwort nicht aus dem Context hervorgeht. Antworte bitte immer auf Deutsch. CONTEXT:\n\n{context}"),
         MessagesPlaceholder(variable_name="chat_history"),
         ("user", "{input}"),
     ])
-
-    # Create history-aware retriever chain
-    retriever_chain = create_history_aware_retriever(rag_model, retriever, prompt)
-
-    return retriever_chain
-
-# Index URLs on app startup
-@app.on_event("startup")
-async def startup():
-    domain_url = 'https://www.bofrost.de/faq/'
-    urls = get_all_links_from_domain(domain_url)
 
-    print(urls)
+    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
+
+    return create_retrieval_chain(retriever_chain, stuff_documents_chain)
 
-# Define API endpoint to receive queries and provide responses
-@app.post("/generate/")
-def generate(user_input):
+
+def get_response(message, history=[]):
 
-    return get_response(user_input, [])
+    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
+
+    response = conversation_rag_chain.invoke({
+        "chat_history": history,
+        "input": message + " Assistant: ",
+        "chat_message": message + " Assistant: "
+    })
+    # print("get_response " + response)
+    res = response['answer']
+    parts = res.split(" Assistant: ")
+    last_part = parts[-1]
+    return last_part  # [-1]['generation']['content']
 
+
 def get_conversational_rag_chain(retriever_chain):
 
-    llm = load_model(model_name)
+    llm = load_model()
 
     prompt = ChatPromptTemplate.from_messages([
         ("system", "Du bist eine freundliche Mitarbeiterin namens Susie und arbeitest in einem Call Center. Du beantwortest basierend auf dem Context. Benutze nur den Inhalt des Context. Füge wenn möglich die Quelle hinzu. Antworte mit: Ich bin mir nicht sicher, wenn die Antwort nicht aus dem Context hervorgeht. Antworte auf Deutsch, bitte. CONTEXT:\n\n{context}"),
@@ -138,30 +165,8 @@ def get_conversational_rag_chain(retriever_chain):
     return create_retrieval_chain(retriever_chain, stuff_documents_chain)
 
 
-def get_response(message, history=[]):
-    # dialog = history_to_dialog_format(history)
-    # dialog.append({"role": "user", "content": message})
-
-    # Define the prompt as a ChatPromptValue object
-    # user_input = ChatPromptValue(user_input)
-
-    # Convert the prompt to a tensor
-    # input_ids = user_input.tensor
-
-
-    model = "BAAI/bge-base-en-v1.5"
-    encode_kwargs = {
-        "normalize_embeddings": True
-    }  # set True to compute cosine similarity
-    embeddings = HuggingFaceBgeEmbeddings(
-        model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
-    )
-    # load from disk
-    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db", embedding_function=embeddings)
+def get_response(message, history=[]):
 
-
-    history = []
-    retriever_chain = get_context_retriever_chain(vector_store)
     conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
 
     response = conversation_rag_chain.invoke({
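
Note that after this commit `get_conversational_rag_chain` and `get_response` are each defined twice in the file; Python silently keeps whichever definition comes last. A short sketch of the intended multi-turn call pattern, accumulating history across calls (wrapping turns in message objects is an assumption; the commit itself passes a plain list):

```python
from langchain_core.messages import AIMessage, HumanMessage

history = []
for question in ["Was ist bofrost*?", "Liefert ihr auch samstags?"]:
    answer = get_response(question, history)        # helper defined in this commit
    history.append(HumanMessage(content=question))  # feed the turn back as context
    history.append(AIMessage(content=answer))
    print(answer)
```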
@@ -174,3 +179,18 @@ def get_response(message, history=[]):
     parts = res.split(" Assistant: ")
     last_part = parts[-1]
     return last_part  # [-1]['generation']['content']
+
+
+# Index URLs on app startup
+@app.on_event("startup")
+async def startup():
+    domain_url = 'https://www.bofrost.de/faq/'
+    links = get_all_links_from_domain(domain_url)
+    retriever_chain = index_urls_in_rag(links)
+    retriever_chain.invoke("Was ist bofrost*")
+    get_response("Wie lautet die URL der Schoko-Osterhasen?")
+
+# Define API endpoint to receive queries and provide responses
+@app.post("/generate/")
+def generate(user_input):
+    return get_response(user_input, [])
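
One caveat worth flagging: `startup()` binds `retriever_chain` as a local variable, while `get_response` reads a module-level `retriever_chain` that the first hunk of this diff deleted, so `/generate/` would raise a `NameError` at request time. A minimal sketch of one possible fix, restoring the module-level binding with `global` (using this file's own functions):

```python
retriever_chain = None  # module-level slot, filled in at startup

@app.on_event("startup")
async def startup():
    global retriever_chain                       # rebind the module-level name
    links = get_all_links_from_domain('https://www.bofrost.de/faq/')
    retriever_chain = index_urls_in_rag(links)   # now visible inside get_response
```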