anything-question-answering / climateqa /custom_retrieval_chain.py
LOUIS SANNA
feat(loader)
cc2ce8c
raw
history blame
1.45 kB
from __future__ import annotations
import inspect
from typing import Any, Dict
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from typing import Any, Dict
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from langchain.chains import RetrievalQAWithSourcesChain
class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
    """RetrievalQAWithSourcesChain that short-circuits when retrieval finds nothing.

    If document retrieval returns no documents, the combine-documents chain is
    not invoked at all; ``fallback_answer`` is returned as the answer and the
    sources value is an empty string.

    NOTE(review): only the synchronous ``_call`` is overridden here. Async
    invocations presumably still go through the parent class and will not use
    ``fallback_answer`` -- confirm whether any async callers exist.
    """

    # Answer returned verbatim when retrieval yields zero documents.
    fallback_answer: str = "No sources available to answer this question."

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: CallbackManagerForChainRun | None = None,
    ) -> Dict[str, Any]:
        """Retrieve documents and answer from them, or fall back if none are found.

        Args:
            inputs: Chain inputs. Passed to ``self._get_docs`` and forwarded as
                keyword arguments to ``self.combine_documents_chain.run``.
            run_manager: Callback manager for this run; a no-op manager is used
                when it is omitted.

        Returns:
            A dict mapping ``self.answer_key`` to the answer text and
            ``self.sources_answer_key`` to the sources text. When
            ``self.return_source_documents`` is true it also maps
            ``"source_documents"`` to the retrieved documents.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        # Depending on the langchain version / override, _get_docs may or may not
        # accept a run_manager keyword, so probe its signature before passing one.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)  # type: ignore[call-arg]

        if not docs:
            # Nothing retrieved: skip the LLM call entirely. Sources is "" rather
            # than [] so its type matches the str that langchain's inherited
            # _split_sources returns on the normal path (Tuple[str, str]).
            answer = self.fallback_answer
            sources = ""
        else:
            answer = self.combine_documents_chain.run(
                input_documents=docs, callbacks=_run_manager.get_child(), **inputs
            )
            answer, sources = self._split_sources(answer)

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result