lunarflu HF staff committed on
Commit
1d486a2
1 Parent(s): 16cc722

Update qa_engine/qa_engine.py

Browse files
Files changed (1) hide show
  1. qa_engine/qa_engine.py +30 -0
qa_engine/qa_engine.py CHANGED
@@ -227,6 +227,33 @@ class QAEngine():
227
  self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
228
  self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  def get_response(self, question: str, messages_context: str = '') -> Response:
232
  """
@@ -271,6 +298,9 @@ class QAEngine():
271
  response.set_sources(sources=[str(m['source']) for m in metadata])
272
 
273
  logger.info('Running LLM chain')
 
 
 
274
  answer = self.llm_chain.run(question=question, context=context)
275
  response.set_answer(answer)
276
  logger.info('Received answer')
 
227
  self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
228
  self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
229
 
230
+
231
+ @staticmethod
232
+ def _preprocess_question(question: str) -> str:
233
+ if question[-1] != '?':
234
+ question += '?'
235
+ return question
236
+
237
+
238
+ @staticmethod
239
+ def _postprocess_answer(answer: str) -> str:
240
+ '''
241
+ Preprocess the answer by removing unnecessary sequences and stop sequences.
242
+ '''
243
+ REMOVE_SEQUENCES = [
244
+ 'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
245
+ ]
246
+ STOP_SEQUENCES = [
247
+ '\nUser:', '\nYou:'
248
+ ]
249
+ for seq in REMOVE_SEQUENCES:
250
+ answer = answer.replace(seq, '')
251
+ for seq in STOP_SEQUENCES:
252
+ if seq in answer:
253
+ answer = answer[:answer.index(seq)]
254
+ answer = answer.strip()
255
+ return answer
256
+
257
 
258
  def get_response(self, question: str, messages_context: str = '') -> Response:
259
  """
 
response.set_sources(sources=[str(m['source']) for m in metadata])

logger.info('Running LLM chain')
# Normalize the question before the chain call, then clean the answer.
question_processed = QAEngine._preprocess_question(question)
answer = self.llm_chain.run(question=question_processed, context=context)
answer = QAEngine._postprocess_answer(answer)
# BUG FIX: the previous version left the old, unprocessed call
# `answer = self.llm_chain.run(question=question, context=context)`
# after the lines above, which ran the LLM a second time and
# discarded the post-processed answer. That stale call is removed.
response.set_answer(answer)
logger.info('Received answer')