"""Chain that combines documents by stuffing into context.""" | |
from typing import Any, Dict, List, Optional, Tuple | |
from langchain_core.prompts import BasePromptTemplate, format_document | |
from langchain_core.prompts.prompt import PromptTemplate | |
from langchain_core.pydantic_v1 import Extra, Field, root_validator | |
from langchain.callbacks.manager import Callbacks | |
from langchain.chains.combine_documents.base import ( | |
BaseCombineDocumentsChain, | |
) | |
from langchain.chains.llm import LLMChain | |
from langchain.docstore.document import Document | |
def _get_default_document_prompt() -> PromptTemplate: | |
return PromptTemplate(input_variables=["page_content"], template="{page_content}") | |

class StuffDocumentsChain(BaseCombineDocumentsChain):
    """Chain that combines documents by stuffing into context.

    This chain takes a list of documents and first combines them into a single string.
    It does this by formatting each document into a string with the `document_prompt`
    and then joining them together with `document_separator`. It then adds that new
    string to the inputs with the variable name set by `document_variable_name`.
    Those inputs are then passed to the `llm_chain`.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
    """

    llm_chain: LLMChain
    """LLM chain which is called with the formatted document string,
    along with any other inputs."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    document_separator: str = "\n\n"
    """The string with which to join the formatted documents."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided.

        If only one variable is present in the llm_chain.prompt,
        we can infer that the formatted documents should be passed in
        with this variable name.
        """
        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values
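
    # Illustrative note (not from the original source): if llm_chain.prompt has
    # the single input variable "context", a missing document_variable_name is
    # inferred to be "context"; with two or more input variables it must be
    # passed explicitly or this validator raises a ValueError.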

    @property
    def input_keys(self) -> List[str]:
        extra_keys = [
            k for k in self.llm_chain.input_keys if k != self.document_variable_name
        ]
        return super().input_keys + extra_keys

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and then join all the documents together into one input with name
        `self.document_variable_name`. Also pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Format each document according to the prompt
        doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
        # Join the documents together to put them in the prompt.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
        return inputs
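
    # Illustrative note (not from the original source): two documents with
    # page_content "A" and "B", the default document_prompt, and the default
    # separator yield {self.document_variable_name: "A\n\nB"}, merged with any
    # kwargs that match the llm_chain prompt's input variables.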

    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
        """Return the prompt length given the documents passed in.

        This can be used by a caller to determine whether passing in a list
        of documents would exceed a certain prompt length. This is useful when
        trying to ensure that the size of a prompt remains below a certain
        context limit.

        Args:
            docs: List[Document], a list of documents to use to calculate the
                total prompt length.

        Returns:
            Returns None if the method does not depend on the prompt length,
            otherwise the length of the prompt in tokens.
        """
        inputs = self._get_inputs(docs, **kwargs)
        prompt = self.llm_chain.prompt.format(**inputs)
        return self.llm_chain._get_num_tokens(prompt)
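
    # Caller-side sketch (hypothetical names, not from the original source):
    # with a token budget `max_tokens`, a caller could drop trailing documents
    # until the stuffed prompt fits, e.g.
    #     while len(docs) > 1 and chain.prompt_length(docs) > max_tokens:
    #         docs.pop()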

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict(callbacks=callbacks, **inputs), {}

    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Async stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}

    @property
    def _chain_type(self) -> str:
        return "stuff_documents_chain"