Spaces:
Runtime error
Runtime error
| """Combining documents by mapping a chain over them first, then reranking results.""" | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast | |
| from pydantic import BaseModel, Extra, root_validator | |
| from langchain.chains.combine_documents.base import BaseCombineDocumentsChain | |
| from langchain.chains.llm import LLMChain | |
| from langchain.docstore.document import Document | |
| from langchain.output_parsers.regex import RegexParser | |
class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):
    """Combining documents by mapping a chain over them, then reranking results.

    Each document is run through ``llm_chain`` on its own. The chain's prompt
    must carry a ``RegexParser`` whose output keys include ``rank_key`` and
    ``answer_key``. The answer of the result whose ``rank_key`` value is the
    largest (compared as an ``int``) is returned; ties go to the earliest
    document.
    """

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    rank_key: str
    """Key in output of llm_chain to rank on."""
    answer_key: str
    """Key in output of llm_chain to return as answer."""
    # Keys of the winning document's ``metadata`` to copy into the extra
    # outputs returned next to the answer.
    metadata_keys: Optional[List[str]] = None
    # When True, the list of parsed per-document results is also returned
    # under the ``intermediate_steps`` key.
    return_intermediate_steps: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys of this chain.

        These are the parent chain's output keys, followed by
        ``intermediate_steps`` when ``return_intermediate_steps`` is set, and
        then by every entry of ``metadata_keys`` (if any).

        :meta private:
        """
        # Copy rather than extend in place, so the list obtained from the
        # parent property is never mutated.
        _output_keys = list(super().output_keys)
        if self.return_intermediate_steps:
            _output_keys.append("intermediate_steps")
        if self.metadata_keys is not None:
            _output_keys.extend(self.metadata_keys)
        return _output_keys

    # skip_on_failure: if field validation already failed (e.g. a required key
    # is missing from ``values``), let pydantic report that instead of this
    # validator crashing with a KeyError.
    @root_validator(skip_on_failure=True)
    def validate_llm_output(cls, values: Dict) -> Dict:
        """Validate that llm_chain's output parser can supply the ranking keys.

        Raises:
            ValueError: if the prompt's output parser is not a ``RegexParser``,
                or if ``rank_key`` / ``answer_key`` is not among its output keys.
        """
        output_parser = values["llm_chain"].prompt.output_parser
        if not isinstance(output_parser, RegexParser):
            raise ValueError(
                "Output parser of llm_chain should be a RegexParser,"
                f" got {output_parser}"
            )
        output_keys = output_parser.output_keys
        if values["rank_key"] not in output_keys:
            raise ValueError(
                f"Got {values['rank_key']} as key to rank on, but did not find "
                f"it in the llm_chain output keys ({output_keys})"
            )
        if values["answer_key"] not in output_keys:
            raise ValueError(
                f"Got {values['answer_key']} as key to return, but did not find "
                f"it in the llm_chain output keys ({output_keys})"
            )
        return values

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided.

        If ``document_variable_name`` is omitted, it defaults to the prompt's
        single input variable. If it is given, it must be one of the prompt's
        input variables.

        Raises:
            ValueError: if the name is omitted and the prompt does not have
                exactly one input variable, or if the given name is not one of
                the prompt's input variables.
        """
        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        elif values["document_variable_name"] not in llm_chain_variables:
            raise ValueError(
                f"document_variable_name {values['document_variable_name']} was "
                f"not found in llm_chain input_variables: {llm_chain_variables}"
            )
        return values

    def _build_inputs(
        self, docs: List[Document], kwargs: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Build one llm_chain input dict per document, in document order.

        The document text goes under ``document_variable_name``. ``kwargs`` are
        merged after it, so a kwarg with that same name overrides the text.
        """
        return [
            {self.document_variable_name: doc.page_content, **kwargs} for doc in docs
        ]

    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
        """Combine documents in a map rerank manner.

        Combine by mapping first chain over all documents, then reranking the results.

        Args:
            docs: Documents to score; each is sent through ``llm_chain`` once.
            **kwargs: Extra prompt inputs shared by every document.

        Returns:
            The best-ranked answer and a dict of extra outputs.
        """
        # FYI - this is parallelized and so it is fast.
        results = self.llm_chain.apply_and_parse(self._build_inputs(docs, kwargs))
        return self._process_results(docs, results)

    async def acombine_docs(
        self, docs: List[Document], **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents in a map rerank manner (async variant).

        Combine by mapping first chain over all documents, then reranking the results.

        Args:
            docs: Documents to score; each is sent through ``llm_chain`` once.
            **kwargs: Extra prompt inputs shared by every document.

        Returns:
            The best-ranked answer and a dict of extra outputs.
        """
        # FYI - this is parallelized and so it is fast.
        results = await self.llm_chain.aapply_and_parse(
            self._build_inputs(docs, kwargs)
        )
        return self._process_results(docs, results)

    def _process_results(
        self,
        docs: List[Document],
        results: Sequence[Union[str, List[str], Dict[str, str]]],
    ) -> Tuple[str, dict]:
        """Pick the highest-ranked result and assemble the chain outputs.

        Args:
            docs: The input documents, in the same order as ``results``.
            results: Parsed llm_chain outputs, one per document. Each is
                treated as a dict holding ``rank_key`` and ``answer_key``.

        Returns:
            The ``answer_key`` value of the winning result, plus a dict with the
            requested ``metadata_keys`` of the winning document and, if enabled,
            the raw ``results`` under ``intermediate_steps``.

        Raises:
            ValueError: if there is nothing to rank, or if a ``rank_key`` value
                cannot be converted with ``int()``.
        """
        typed_results = cast(List[dict], results)
        if not typed_results or not docs:
            raise ValueError("MapRerankDocumentsChain received no results to rank.")
        # max() keeps the first maximal element, so ties go to the earliest
        # document -- the same pick a stable descending sort would make.
        output, document = max(
            zip(typed_results, docs), key=lambda pair: int(pair[0][self.rank_key])
        )
        extra_info: Dict[str, Any] = {
            key: document.metadata[key] for key in self.metadata_keys or []
        }
        if self.return_intermediate_steps:
            extra_info["intermediate_steps"] = results
        return output[self.answer_key], extra_info

    @property
    def _chain_type(self) -> str:
        return "map_rerank_documents_chain"