|
from __future__ import annotations |
|
import inspect |
|
from typing import Any, Dict |
|
|
|
from langchain.callbacks.manager import ( |
|
CallbackManagerForChainRun, |
|
) |
|
|
|
from typing import Any, Dict |
|
|
|
from langchain.callbacks.manager import ( |
|
CallbackManagerForChainRun, |
|
) |
|
from langchain.chains import RetrievalQAWithSourcesChain |
|
|
|
|
|
class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
    """RetrievalQAWithSourcesChain that short-circuits on empty retrieval.

    When ``_get_docs`` returns no documents, the combine-documents chain is
    not invoked at all. ``fallback_answer`` is returned as the answer and
    the sources entry is an empty string.
    """

    # Answer returned verbatim when retrieval yields no documents.
    fallback_answer: str = "No sources available to answer this question."

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: CallbackManagerForChainRun | None = None,
    ) -> Dict[str, Any]:
        """Retrieve documents and answer the question with sources.

        Args:
            inputs: Chain inputs. Passed to ``_get_docs`` and forwarded as
                keyword arguments to ``combine_documents_chain.run``.
            run_manager: Callback manager for this run. A no-op manager is
                used when omitted.

        Returns:
            Dict containing ``self.answer_key`` (the answer text) and
            ``self.sources_answer_key`` (the sources string). It also
            contains ``"source_documents"`` (the retrieved documents) when
            ``self.return_source_documents`` is set.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        # Overrides of _get_docs may or may not declare a run_manager
        # parameter, so only pass it when the signature accepts it.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)

        if not docs:
            # Nothing retrieved: skip the LLM call entirely. LangChain's
            # _split_sources returns sources as a str ("" when none are
            # found), so use "" rather than [] to keep the output type
            # consistent across both branches.
            answer = self.fallback_answer
            sources = ""
        else:
            raw_answer = self.combine_documents_chain.run(
                input_documents=docs,
                callbacks=_run_manager.get_child(),
                **inputs,
            )
            answer, sources = self._split_sources(raw_answer)

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result
|
|