Spaces:
Runtime error
Runtime error
Update qa_engine/qa_engine.py
Browse files- qa_engine/qa_engine.py +30 -0
qa_engine/qa_engine.py
CHANGED
|
@@ -227,6 +227,33 @@ class QAEngine():
|
|
| 227 |
self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
|
| 228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
| 232 |
"""
|
|
@@ -271,6 +298,9 @@ class QAEngine():
|
|
| 271 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
| 272 |
|
| 273 |
logger.info('Running LLM chain')
|
|
|
|
|
|
|
|
|
|
| 274 |
answer = self.llm_chain.run(question=question, context=context)
|
| 275 |
response.set_answer(answer)
|
| 276 |
logger.info('Received answer')
|
|
|
|
| 227 |
self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
|
| 228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
| 229 |
|
| 230 |
+
|
| 231 |
+
@staticmethod
|
| 232 |
+
def _preprocess_question(question: str) -> str:
|
| 233 |
+
if question[-1] != '?':
|
| 234 |
+
question += '?'
|
| 235 |
+
return question
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@staticmethod
|
| 239 |
+
def _postprocess_answer(answer: str) -> str:
|
| 240 |
+
'''
|
| 241 |
+
Preprocess the answer by removing unnecessary sequences and stop sequences.
|
| 242 |
+
'''
|
| 243 |
+
REMOVE_SEQUENCES = [
|
| 244 |
+
'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
|
| 245 |
+
]
|
| 246 |
+
STOP_SEQUENCES = [
|
| 247 |
+
'\nUser:', '\nYou:'
|
| 248 |
+
]
|
| 249 |
+
for seq in REMOVE_SEQUENCES:
|
| 250 |
+
answer = answer.replace(seq, '')
|
| 251 |
+
for seq in STOP_SEQUENCES:
|
| 252 |
+
if seq in answer:
|
| 253 |
+
answer = answer[:answer.index(seq)]
|
| 254 |
+
answer = answer.strip()
|
| 255 |
+
return answer
|
| 256 |
+
|
| 257 |
|
| 258 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
| 259 |
"""
|
|
|
|
| 298 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
| 299 |
|
| 300 |
logger.info('Running LLM chain')
|
| 301 |
+
question_processed = QAEngine._preprocess_question(question)
|
| 302 |
+
answer = self.llm_chain.run(question=question_processed, context=context)
|
| 303 |
+
answer = QAEngine._postprocess_answer(answer)
|
| 304 |
answer = self.llm_chain.run(question=question, context=context)
|
| 305 |
response.set_answer(answer)
|
| 306 |
logger.info('Received answer')
|