"""Chain for chatting with a vector database.""" from __future__ import annotations from abc import abstractmethod from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, Extra, Field from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseLanguageModel, BaseRetriever, Document from langchain.vectorstores.base import VectorStore def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str: buffer = "" for human_s, ai_s in chat_history: human = "Human: " + human_s ai = "Assistant: " + ai_s buffer += "\n" + "\n".join([human, ai]) return buffer class BaseConversationalRetrievalChain(Chain, BaseModel): """Chain for chatting with an index.""" combine_docs_chain: BaseCombineDocumentsChain question_generator: LLMChain output_key: str = "answer" return_source_documents: bool = False get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None """Return the source documents.""" class Config: """Configuration for this pydantic object.""" extra = Extra.forbid arbitrary_types_allowed = True allow_population_by_field_name = True @property def input_keys(self) -> List[str]: """Input keys.""" return ["question", "chat_history"] @property def output_keys(self) -> List[str]: """Return the output keys. :meta private: """ _output_keys = [self.output_key] if self.return_source_documents: _output_keys = _output_keys + ["source_documents"] return _output_keys @abstractmethod def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: """Get docs.""" def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs["question"] get_chat_history = self.get_chat_history or _get_chat_history chat_history_str = get_chat_history(inputs["chat_history"]) if chat_history_str: new_question = self.question_generator.run( question=question, chat_history=chat_history_str ) else: new_question = question docs = self._get_docs(new_question, inputs) new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: return {self.output_key: answer} async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs["question"] get_chat_history = self.get_chat_history or _get_chat_history chat_history_str = get_chat_history(inputs["chat_history"]) if chat_history_str: new_question = await self.question_generator.arun( question=question, chat_history=chat_history_str ) else: new_question = question # TODO: This blocks the event loop, but it's not clear how to avoid it. 
        docs = self._get_docs(new_question, inputs)
        new_inputs = inputs.copy()
        new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs)
        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}

    def save(self, file_path: Union[Path, str]) -> None:
        if self.get_chat_history:
            raise ValueError("Chain not savable when `get_chat_history` is not None.")
        super().save(file_path)


class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
    """Chain for chatting with an index."""

    retriever: BaseRetriever

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        return self.retriever.get_relevant_documents(question)

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        retriever: BaseRetriever,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        qa_prompt: Optional[BasePromptTemplate] = None,
        chain_type: str = "stuff",
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Load chain from LLM."""
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            prompt=qa_prompt,
        )
        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
        return cls(
            retriever=retriever,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            **kwargs,
        )


class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
    """Chain for chatting with a vector database."""

    vectorstore: VectorStore = Field(alias="vectorstore")
    top_k_docs_for_context: int = 4
    search_kwargs: dict = Field(default_factory=dict)

    @property
    def _chain_type(self) -> str:
        return "chat-vector-db"

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        vectordbkwargs = inputs.get("vectordbkwargs", {})
        full_kwargs = {**self.search_kwargs, **vectordbkwargs}
        return self.vectorstore.similarity_search(
            question, k=self.top_k_docs_for_context, **full_kwargs
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        qa_prompt: Optional[BasePromptTemplate] = None,
        chain_type: str = "stuff",
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Load chain from LLM."""
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            prompt=qa_prompt,
        )
        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
        return cls(
            vectorstore=vectorstore,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            **kwargs,
        )
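

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). It assumes
# the OpenAI LLM, OpenAIEmbeddings, and FAISS integrations are installed and
# that the vector store's `as_retriever()` helper is available; any
# BaseLanguageModel and any retriever backed by a VectorStore would work the
# same way.
#
#     from langchain.embeddings import OpenAIEmbeddings
#     from langchain.llms import OpenAI
#     from langchain.vectorstores import FAISS
#
#     vectorstore = FAISS.from_texts(
#         ["LangChain chains can be composed with retrievers."],
#         OpenAIEmbeddings(),
#     )
#     chain = ConversationalRetrievalChain.from_llm(
#         llm=OpenAI(temperature=0),
#         retriever=vectorstore.as_retriever(),
#     )
#
#     chat_history: List[Tuple[str, str]] = []
#     question = "What can chains be composed with?"
#     result = chain({"question": question, "chat_history": chat_history})
#     # Append the turn so the next question is condensed against it.
#     chat_history.append((question, result["answer"]))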