# Spaces: Runtime error
"""Map-reduce chain. | |
Splits up a document, sends the smaller parts to the LLM with one prompt, | |
then combines the results with another one. | |
""" | |
from __future__ import annotations | |
from typing import Any, Dict, List, Mapping, Optional | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.prompts import BasePromptTemplate | |
from langchain_core.pydantic_v1 import Extra | |
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks | |
from langchain.chains import ReduceDocumentsChain | |
from langchain.chains.base import Chain | |
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain | |
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain | |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain | |
from langchain.chains.llm import LLMChain | |
from langchain.docstore.document import Document | |
from langchain.text_splitter import TextSplitter | |
class MapReduceChain(Chain):
    """Map-reduce chain over a single long input text.

    The text found under ``input_key`` is split into chunks with
    ``text_splitter``; each chunk is wrapped in a ``Document`` and the list is
    handed to ``combine_documents_chain``. Its result is returned under
    ``output_key``.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine documents."""
    text_splitter: TextSplitter
    """Text splitter to use."""
    input_key: str = "input_text"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:

    @classmethod
    def from_params(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate,
        text_splitter: TextSplitter,
        callbacks: Callbacks = None,
        combine_chain_kwargs: Optional[Mapping[str, Any]] = None,
        reduce_chain_kwargs: Optional[Mapping[str, Any]] = None,
        **kwargs: Any,
    ) -> MapReduceChain:
        """Construct a map-reduce chain that uses the chain for map and reduce.

        Args:
            llm: Language model for the shared ``LLMChain``.
            prompt: Prompt for the shared ``LLMChain``.
            text_splitter: Splitter applied to the input text in ``_call``.
            callbacks: Callbacks passed to the ``LLMChain``, the
                ``StuffDocumentsChain``, the ``MapReduceDocumentsChain`` and
                this chain.
            combine_chain_kwargs: Extra keyword arguments forwarded to the
                ``MapReduceDocumentsChain`` constructor.
            reduce_chain_kwargs: Extra keyword arguments forwarded to the
                ``StuffDocumentsChain`` used for the reduce step.
            **kwargs: Extra keyword arguments forwarded to this class.

        Returns:
            A configured ``MapReduceChain``.
        """
        # One LLMChain (one llm + one prompt) serves both the map step and,
        # through the StuffDocumentsChain, the reduce step.
        llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
        stuff_chain = StuffDocumentsChain(
            llm_chain=llm_chain,
            callbacks=callbacks,
            **(reduce_chain_kwargs or {}),
        )
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=stuff_chain
        )
        combine_documents_chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            reduce_documents_chain=reduce_documents_chain,
            callbacks=callbacks,
            **(combine_chain_kwargs or {}),
        )
        return cls(
            combine_documents_chain=combine_documents_chain,
            text_splitter=text_splitter,
            callbacks=callbacks,
            **kwargs,
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Split the input text and run the combine chain over the chunks.

        Args:
            inputs: Chain inputs; must contain ``self.input_key``. All other
                keys are forwarded unchanged to ``combine_documents_chain``.
                This dict is not modified.
            run_manager: Callback manager for this run; a no-op manager is
                used when ``None``.

        Returns:
            ``{self.output_key: <result of combine_documents_chain.run>}``.

        Raises:
            KeyError: If ``inputs`` has no ``self.input_key`` entry.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Read (not pop) the text: ``inputs`` may be the caller's own dict, and
        # popping would also drop the input from the merged chain result.
        doc_text = inputs[self.input_key]
        # Split the larger text into smaller chunks.
        texts = self.text_splitter.split_text(doc_text)
        docs = [Document(page_content=text) for text in texts]
        # Forward every other input; the raw text is replaced by the chunk
        # documents under the combine chain's own input key.
        _inputs: Dict[str, Any] = {
            key: value for key, value in inputs.items() if key != self.input_key
        }
        _inputs[self.combine_documents_chain.input_key] = docs
        outputs = self.combine_documents_chain.run(
            _inputs, callbacks=_run_manager.get_child()
        )
        return {self.output_key: outputs}