SiraH committed
Commit 98eb843
Parent: 0820ec4

change to retrieval method

Files changed (1):
  app.py  +28 -18
app.py CHANGED

@@ -181,11 +181,9 @@ def load_llama2_llamaCpp():

def set_custom_prompt():
    custom_prompt_template = """ Use the following pieces of information to answer the user's question.
-   If you don't know the answer, please just say that you don't know the answer, don't try to make up
-   an answer.
+   If you don't know the answer, don't try to make up an answer.

    Context : {context}
-   chat_history : {chat_history}
    Question : {question}

    Only returns the helpful answer below and nothing else.
@@ -193,7 +191,7 @@ def set_custom_prompt():
    """
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context',
                                                                              'question',
-                                                                             'chat_history'])
+                                                                             ])
    return prompt


@@ -228,10 +226,10 @@ def main():

    llm = load_llama2_llamaCpp()
    qa_prompt = set_custom_prompt()
-   memory = ConversationBufferWindowMemory(k = 0, return_messages=True, input_key= 'question', output_key='answer', memory_key="chat_history")
+   #memory = ConversationBufferWindowMemory(k = 0, return_messages=True, input_key= 'question', output_key='answer', memory_key="chat_history")
    #memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-   doc_chain = load_qa_chain(llm, chain_type="stuff", prompt = qa_prompt)
-   question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
+   #doc_chain = load_qa_chain(llm, chain_type="stuff", prompt = qa_prompt)
+   #question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
    embeddings = load_embeddings()


@@ -243,15 +241,27 @@ def main():
        for page in pdf_reader.pages:
            text += page.extract_text()
        db = FAISS.from_texts(text, embeddings)
-       qa_chain = ConversationalRetrievalChain(
-           retriever =db.as_retriever(search_kwargs={'k':2}),
-           question_generator=question_generator,
-           #condense_question_prompt=CONDENSE_QUESTION_PROMPT,
-           combine_docs_chain=doc_chain,
-           return_source_documents=True,
+
+       memory = ConversationBufferMemory(memory_key="chat_history",
+                                         return_messages=True,
+                                         input_key="query",
+                                         output_key="result")
+       qa_chain = RetrievalQA.from_chain_type(
+           llm = llm,
+           chain_type = "stuff",
+           retriever = db.as_retriever(search_kwargs = {'k':3}),
+           return_source_documents = True,
            memory = memory,
-           #get_chat_history=lambda h :h
-           )
+           chain_type_kwargs = {"prompt":qa_prompt})
+       # qa_chain = ConversationalRetrievalChain(
+       #     retriever =db.as_retriever(search_kwargs={'k':2}),
+       #     question_generator=question_generator,
+       #     #condense_question_prompt=CONDENSE_QUESTION_PROMPT,
+       #     combine_docs_chain=doc_chain,
+       #     return_source_documents=True,
+       #     memory = memory,
+       #     #get_chat_history=lambda h :h
+       #     )

        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
@@ -267,20 +277,20 @@ def main():

            start = time.time()

-           response = qa_chain({'question': query})
+           response = qa_chain({'query': query})

            # url_list = set([i.metadata['source'] for i in response['source_documents']])
            #print(f"condensed quesion : {question_generator.run({'chat_history': response['chat_history'], 'question' : query})}")

            with st.chat_message("assistant"):
-               st.markdown(response['answer'])
+               st.markdown(response['result'])

            end = time.time()
            st.write("Respone time:",int(end-start),"sec")
            print(response)

            # Add assistant response to chat history
-           st.session_state.messages.append({"role": "assistant", "content": response['answer']})
+           st.session_state.messages.append({"role": "assistant", "content": response['result']})

            # with st.expander("See the related documents"):
            #     for count, url in enumerate(url_list):
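
For readers who want to exercise the new flow outside the Streamlit app, the sketch below reconstructs just the retrieval wiring this commit lands on: a RetrievalQA chain over a FAISS index with the trimmed custom prompt. It is a minimal sketch, not the app itself — the llama.cpp model path, the embedding model name, and the sample chunks are placeholder assumptions, and the imports assume the classic pre-0.1 langchain package layout that the diff's API usage implies.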
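
    # Minimal sketch of the RetrievalQA wiring this commit switches to.
    # Assumptions (not from the diff): model path, embedding model name,
    # and sample chunks are placeholders; classic langchain package layout.
    from langchain.chains import RetrievalQA
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.llms import LlamaCpp
    from langchain.memory import ConversationBufferMemory
    from langchain.prompts import PromptTemplate
    from langchain.vectorstores import FAISS

    # Same template as the new set_custom_prompt(): context + question only,
    # no chat_history variable.
    prompt = PromptTemplate(
        template=(
            "Use the following pieces of information to answer the user's question.\n"
            "If you don't know the answer, don't try to make up an answer.\n\n"
            "Context : {context}\n"
            "Question : {question}\n\n"
            "Only returns the helpful answer below and nothing else.\n"
        ),
        input_variables=["context", "question"],
    )

    llm = LlamaCpp(model_path="llama-2-7b-chat.Q4_K_M.gguf")  # placeholder path
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"  # assumed model
    )

    # from_texts expects a list of strings; the diff passes one concatenated
    # string, which would embed one character at a time. Chunk first instead.
    chunks = [
        "LlamaCpp runs GGUF models locally.",
        "FAISS performs the similarity search over embedded chunks.",
    ]
    db = FAISS.from_texts(chunks, embeddings)

    # output_key="result" tells the memory which of the chain's two outputs
    # (result, source_documents) to record.
    memory = ConversationBufferMemory(memory_key="chat_history",
                                      return_messages=True,
                                      input_key="query",
                                      output_key="result")

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        memory=memory,
        chain_type_kwargs={"prompt": prompt},
    )

    response = qa_chain({"query": "What does FAISS do here?"})
    print(response["result"])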
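
Two key renames ride along with the chain swap and explain the Streamlit edits in the last hunk: RetrievalQA takes its input under 'query' and returns the answer text under 'result', where ConversationalRetrievalChain used 'question' and 'answer'. Two caveats worth flagging: RetrievalQA has no question_generator, so follow-up questions are no longer condensed against chat history (the memory only records exchanges); and the untouched db = FAISS.from_texts(text, embeddings) line passes a single concatenated string where from_texts expects a list of strings, so it will effectively index one character at a time — splitting the PDF text into chunks before indexing is the conventional fix.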