Spaces:
Runtime error
Runtime error
"""Chain for chatting with a vector database.""" | |
from __future__ import annotations | |
from abc import abstractmethod | |
from pathlib import Path | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
from pydantic import BaseModel, Extra, Field | |
from langchain.chains.base import Chain | |
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain | |
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.question_answering import load_qa_chain | |
from langchain.prompts.base import BasePromptTemplate | |
from langchain.schema import BaseLanguageModel, BaseRetriever, Document | |
from langchain.vectorstores.base import VectorStore | |
def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str: | |
buffer = "" | |
for human_s, ai_s in chat_history: | |
human = "Human: " + human_s | |
ai = "Assistant: " + ai_s | |
buffer += "\n" + "\n".join([human, ai]) | |
return buffer | |
class BaseConversationalRetrievalChain(Chain, BaseModel): | |
"""Chain for chatting with an index.""" | |
combine_docs_chain: BaseCombineDocumentsChain | |
question_generator: LLMChain | |
output_key: str = "answer" | |
return_source_documents: bool = False | |
get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None | |
"""Return the source documents.""" | |
class Config: | |
"""Configuration for this pydantic object.""" | |
extra = Extra.forbid | |
arbitrary_types_allowed = True | |
allow_population_by_field_name = True | |
def input_keys(self) -> List[str]: | |
"""Input keys.""" | |
return ["question", "chat_history"] | |
def output_keys(self) -> List[str]: | |
"""Return the output keys. | |
:meta private: | |
""" | |
_output_keys = [self.output_key] | |
if self.return_source_documents: | |
_output_keys = _output_keys + ["source_documents"] | |
return _output_keys | |
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: | |
"""Get docs.""" | |
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |
question = inputs["question"] | |
get_chat_history = self.get_chat_history or _get_chat_history | |
chat_history_str = get_chat_history(inputs["chat_history"]) | |
if chat_history_str: | |
new_question = self.question_generator.run( | |
question=question, chat_history=chat_history_str | |
) | |
else: | |
new_question = question | |
docs = self._get_docs(new_question, inputs) | |
new_inputs = inputs.copy() | |
new_inputs["question"] = new_question | |
new_inputs["chat_history"] = chat_history_str | |
answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs) | |
if self.return_source_documents: | |
return {self.output_key: answer, "source_documents": docs} | |
else: | |
return {self.output_key: answer} | |
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |
question = inputs["question"] | |
get_chat_history = self.get_chat_history or _get_chat_history | |
chat_history_str = get_chat_history(inputs["chat_history"]) | |
if chat_history_str: | |
new_question = await self.question_generator.arun( | |
question=question, chat_history=chat_history_str | |
) | |
else: | |
new_question = question | |
# TODO: This blocks the event loop, but it's not clear how to avoid it. | |
docs = self._get_docs(new_question, inputs) | |
new_inputs = inputs.copy() | |
new_inputs["question"] = new_question | |
new_inputs["chat_history"] = chat_history_str | |
answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs) | |
if self.return_source_documents: | |
return {self.output_key: answer, "source_documents": docs} | |
else: | |
return {self.output_key: answer} | |
def save(self, file_path: Union[Path, str]) -> None: | |
if self.get_chat_history: | |
raise ValueError("Chain not savable when `get_chat_history` is not None.") | |
super().save(file_path) | |
class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel): | |
"""Chain for chatting with an index.""" | |
retriever: BaseRetriever | |
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: | |
return self.retriever.get_relevant_texts(question) | |
def from_llm( | |
cls, | |
llm: BaseLanguageModel, | |
retriever: BaseRetriever, | |
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT, | |
qa_prompt: Optional[BasePromptTemplate] = None, | |
chain_type: str = "stuff", | |
**kwargs: Any, | |
) -> BaseConversationalRetrievalChain: | |
"""Load chain from LLM.""" | |
doc_chain = load_qa_chain( | |
llm, | |
chain_type=chain_type, | |
prompt=qa_prompt, | |
) | |
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt) | |
return cls( | |
retriever=retriever, | |
combine_docs_chain=doc_chain, | |
question_generator=condense_question_chain, | |
**kwargs, | |
) | |
class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel): | |
"""Chain for chatting with a vector database.""" | |
vectorstore: VectorStore = Field(alias="vectorstore") | |
top_k_docs_for_context: int = 4 | |
search_kwargs: dict = Field(default_factory=dict) | |
def _chain_type(self) -> str: | |
return "chat-vector-db" | |
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: | |
vectordbkwargs = inputs.get("vectordbkwargs", {}) | |
full_kwargs = {**self.search_kwargs, **vectordbkwargs} | |
return self.vectorstore.similarity_search( | |
question, k=self.top_k_docs_for_context, **full_kwargs | |
) | |
def from_llm( | |
cls, | |
llm: BaseLanguageModel, | |
vectorstore: VectorStore, | |
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT, | |
qa_prompt: Optional[BasePromptTemplate] = None, | |
chain_type: str = "stuff", | |
**kwargs: Any, | |
) -> BaseConversationalRetrievalChain: | |
"""Load chain from LLM.""" | |
doc_chain = load_qa_chain( | |
llm, | |
chain_type=chain_type, | |
prompt=qa_prompt, | |
) | |
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt) | |
return cls( | |
vectorstore=vectorstore, | |
combine_docs_chain=doc_chain, | |
question_generator=condense_question_chain, | |
**kwargs, | |
) | |