Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import inspect | |
from typing import Any, Dict, List, Optional | |
from pydantic import Extra | |
from langchain.schema.language_model import BaseLanguageModel | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForChainRun, | |
CallbackManagerForChainRun, | |
) | |
from langchain.chains.base import Chain | |
from langchain.prompts.base import BasePromptTemplate | |
from typing import Any, Dict, List | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForChainRun, | |
CallbackManagerForChainRun, | |
) | |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain | |
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain | |
from langchain.docstore.document import Document | |
from langchain.pydantic_v1 import Field | |
from langchain.schema import BaseRetriever | |
from langchain.chains import RetrievalQAWithSourcesChain | |
from langchain.chains.router.llm_router import LLMRouterChain | |
class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain): | |
fallback_answer:str = "No sources available to answer this question." | |
def _call(self,inputs,run_manager=None): | |
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() | |
accepts_run_manager = ( | |
"run_manager" in inspect.signature(self._get_docs).parameters | |
) | |
if accepts_run_manager: | |
docs = self._get_docs(inputs, run_manager=_run_manager) | |
else: | |
docs = self._get_docs(inputs) # type: ignore[call-arg] | |
if len(docs) == 0: | |
answer = self.fallback_answer | |
sources = [] | |
else: | |
answer = self.combine_documents_chain.run( | |
input_documents=docs, callbacks=_run_manager.get_child(), **inputs | |
) | |
answer, sources = self._split_sources(answer) | |
result: Dict[str, Any] = { | |
self.answer_key: answer, | |
self.sources_answer_key: sources, | |
} | |
if self.return_source_documents: | |
result["source_documents"] = docs | |
return result | |