Chris4K commited on
Commit
9effd9a
·
verified ·
1 Parent(s): 3f11a95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -41
app.py CHANGED
@@ -88,7 +88,7 @@ def index_urls_in_rag(urls=[]):
88
  # Create a vector store for storing embeddings of documents
89
  vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db", embedding_function=embeddings)
90
 
91
- print(urls)
92
 
93
 
94
  for url in urls:
@@ -99,7 +99,7 @@ def index_urls_in_rag(urls=[]):
99
  # Split the document into chunks
100
  text_splitter = RecursiveCharacterTextSplitter()
101
  document_chunks = text_splitter.split_documents(document)
102
-
103
  # Index document chunks into the vector store
104
  vector_store.add_documents(document_chunks)
105
 
@@ -149,50 +149,25 @@ def get_response(message, history=[]):
149
  last_part = parts[-1]
150
  return last_part#[-1]['generation']['content']
151
 
152
-
153
- def get_conversational_rag_chain(retriever_chain):
154
-
155
- llm = load_model()
156
-
157
- prompt = ChatPromptTemplate.from_messages([
158
- ("system", "Du bist eine freundlicher Mitarbeiterin Namens Susie und arbeitest in einenm Call Center. Du beantwortest basierend auf dem Context. Benutze nur den Inhalt des Context. Füge wenn möglich die Quelle hinzu. Antworte mit: Ich bin mir nicht sicher. Wenn die Antwort nicht aus dem Context hervorgeht. Antworte auf Deutsch, bitte? CONTEXT:\n\n{context}"),
159
- MessagesPlaceholder(variable_name="chat_history"),
160
- ("user", "{input}"),
161
- ])
162
-
163
- stuff_documents_chain = create_stuff_documents_chain(llm,prompt)
164
-
165
- return create_retrieval_chain(retriever_chain, stuff_documents_chain)
166
-
167
-
168
- def get_response(message, history=[]):
169
-
170
- retriever_chain = index_urls_in_rag()
171
- conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
172
-
173
- response = conversation_rag_chain.invoke({
174
- "chat_history": history,
175
- "input": message + " Assistant: ",
176
- "chat_message": message + " Assistant: "
177
- })
178
- #print("get_response " +response)
179
- res = response['answer']
180
- parts = res.split(" Assistant: ")
181
- last_part = parts[-1]
182
- return last_part#[-1]['generation']['content']
183
-
184
 
185
  # Index URLs on app startup
186
  @app.on_event("startup")
187
  async def startup():
188
- domain_url = 'https://www.bofrost.de/faq/service-infos-fuer-neukunden.html'
189
- links = get_all_links_from_domain(domain_url)
190
- print(links)
191
- retriever_chain = index_urls_in_rag(links)
192
- retriever_chain.invoke("Was ist bofrost*")
193
- get_response("Was kosten Schoko Osterhasen?")
194
 
195
  # Define API endpoint to receive queries and provide responses
196
  @app.post("/generate/")
197
  def generate(user_input):
198
- return get_response(user_input, [])
 
 
 
 
 
 
 
 
88
  # Create a vector store for storing embeddings of documents
89
  vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db", embedding_function=embeddings)
90
 
91
+ print("Embedding " +urls)
92
 
93
 
94
  for url in urls:
 
99
  # Split the document into chunks
100
  text_splitter = RecursiveCharacterTextSplitter()
101
  document_chunks = text_splitter.split_documents(document)
102
+ print(document_chunks)
103
  # Index document chunks into the vector store
104
  vector_store.add_documents(document_chunks)
105
 
 
149
  last_part = parts[-1]
150
  return last_part#[-1]['generation']['content']
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  # Index URLs on app startup
154
  @app.on_event("startup")
155
  async def startup():
156
+ # domain_url = 'https://www.bofrost.de/faq/service-infos-fuer-neukunden.html'
157
+ # links = get_all_links_from_domain(domain_url)
158
+ # print(links)
159
+ # retriever_chain = index_urls_in_rag(links)
160
+ # retriever_chain.invoke("Was ist bofrost*")
161
+ # get_response("Was kosten Schoko Osterhasen?")
162
 
163
  # Define API endpoint to receive queries and provide responses
164
  @app.post("/generate/")
165
  def generate(user_input):
166
+ return get_response(user_input, [])
167
+
168
+ # Define API endpoint to receive queries and provide responses
169
+ @app.post("/update/")
170
+ def generate(index_url):
171
+ retriever_chain = index_urls_in_rag(domain_url)
172
+ retriever_chain.invoke("Was ist bofrost*")
173
+ get_response("Was kosten Schoko Osterhasen?")