File size: 1,159 Bytes
3ac9dae 4ce9985 3ac9dae 4ce9985 3ac9dae 4ce9985 3ac9dae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.vectorstores.base import VectorStore
from app_modules.llm_inference import LLMInference
class QAChain(LLMInference):
    """Build a ``ConversationalRetrievalChain`` over a vector store.

    A default vector store is used for retrieval unless the chain inputs carry
    a ``chat_id`` that has its own entry in ``doc_id_to_vectorstore_mapping``.
    """

    vectorstore: VectorStore

    def __init__(self, vectorstore, llm_loader, doc_id_to_vectorstore_mapping=None):
        """Store the default vector store and the optional per-chat overrides.

        Args:
            vectorstore: Default vector store used when no per-chat store applies.
            llm_loader: Forwarded to ``LLMInference.__init__``. ``create_chain``
                reads its ``llm``, ``search_kwargs`` and ``max_tokens_limit``
                attributes (presumably the base class stores it as
                ``self.llm_loader`` — confirm in ``LLMInference``).
            doc_id_to_vectorstore_mapping: Optional mapping from chat id to a
                vector store. ``None`` (the default) means no per-chat overrides.
        """
        super().__init__(llm_loader)
        self.vectorstore = vectorstore
        self.doc_id_to_vectorstore_mapping = doc_id_to_vectorstore_mapping

    def get_chain(self, inputs) -> Chain:
        """Return a chain for ``inputs``; delegates to ``create_chain``."""
        return self.create_chain(inputs)

    def _resolve_vectorstore(self, inputs) -> VectorStore:
        """Return the per-chat vector store for ``inputs``, else the default one."""
        mapping = self.doc_id_to_vectorstore_mapping
        # The mapping defaults to None; testing membership against None would
        # raise TypeError, so only consult it when one was actually provided.
        if mapping is not None and "chat_id" in inputs:
            chat_id = inputs["chat_id"]
            if chat_id in mapping:
                return mapping[chat_id]
        return self.vectorstore

    def create_chain(self, inputs) -> Chain:
        """Create a conversational retrieval chain for ``inputs``.

        Args:
            inputs: Chain inputs; an optional ``"chat_id"`` key selects a
                per-chat vector store from ``doc_id_to_vectorstore_mapping``.

        Returns:
            A ``ConversationalRetrievalChain`` built from the loader's LLM,
            retrieving with the loader's ``search_kwargs``, limited by its
            ``max_tokens_limit``, and configured to return source documents.
        """
        vectorstore = self._resolve_vectorstore(inputs)
        retriever = vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs)
        return ConversationalRetrievalChain.from_llm(
            self.llm_loader.llm,
            retriever,
            max_tokens_limit=self.llm_loader.max_tokens_limit,
            return_source_documents=True,
        )
|