Spaces:
Running
Running
import inspect | |
from typing import Dict, Any, Optional, List | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForChainRun, | |
CallbackManagerForChainRun, | |
) | |
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain | |
from langchain.docstore.document import Document | |
from logger import logger | |
class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain): | |
"""QA with source chain for Chat ArXiv app with references | |
This chain will automatically assign reference number to the article, | |
Then parse it back to titles or anything else. | |
""" | |
def _call( | |
self, | |
inputs: Dict[str, Any], | |
run_manager: Optional[CallbackManagerForChainRun] = None, | |
) -> Dict[str, str]: | |
logger.info(f"\033[91m\033[1m{self._chain_type}\033[0m") | |
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() | |
accepts_run_manager = ( | |
"run_manager" in inspect.signature(self._get_docs).parameters | |
) | |
if accepts_run_manager: | |
docs: List[Document] = self._get_docs(inputs, run_manager=_run_manager) | |
else: | |
docs: List[Document] = self._get_docs(inputs) # type: ignore[call-arg] | |
answer = self.combine_documents_chain.run( | |
input_documents=docs, callbacks=_run_manager.get_child(), **inputs | |
) | |
# parse source with ref_id | |
sources = [] | |
ref_cnt = 1 | |
for d in docs: | |
ref_id = d.metadata['ref_id'] | |
if f"Doc #{ref_id}" in answer: | |
answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}") | |
if f"#{ref_id}" in answer: | |
title = d.metadata['title'].replace('\n', '') | |
d.metadata['ref_id'] = ref_cnt | |
answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]") | |
sources.append(d) | |
ref_cnt += 1 | |
result: Dict[str, Any] = { | |
self.answer_key: answer, | |
self.sources_answer_key: sources, | |
} | |
if self.return_source_documents: | |
result["source_documents"] = docs | |
return result | |
async def _acall( | |
self, | |
inputs: Dict[str, Any], | |
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, | |
) -> Dict[str, Any]: | |
raise NotImplementedError | |
def _chain_type(self) -> str: | |
return "custom_retrieval_qa_with_sources_chain" | |