Spaces:
Runtime error
Runtime error
"""Chain for summarization with self-verification.""" | |
from __future__ import annotations | |
import warnings | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.prompts.prompt import PromptTemplate | |
from langchain_core.pydantic_v1 import Extra, root_validator | |
from langchain.callbacks.manager import CallbackManagerForChainRun | |
from langchain.chains.base import Chain | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.sequential import SequentialChain | |
PROMPTS_DIR = Path(__file__).parent / "prompts" | |
CREATE_ASSERTIONS_PROMPT = PromptTemplate.from_file( | |
PROMPTS_DIR / "create_facts.txt", ["summary"] | |
) | |
CHECK_ASSERTIONS_PROMPT = PromptTemplate.from_file( | |
PROMPTS_DIR / "check_facts.txt", ["assertions"] | |
) | |
REVISED_SUMMARY_PROMPT = PromptTemplate.from_file( | |
PROMPTS_DIR / "revise_summary.txt", ["checked_assertions", "summary"] | |
) | |
ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file( | |
PROMPTS_DIR / "are_all_true_prompt.txt", ["checked_assertions"] | |
) | |
def _load_sequential_chain( | |
llm: BaseLanguageModel, | |
create_assertions_prompt: PromptTemplate, | |
check_assertions_prompt: PromptTemplate, | |
revised_summary_prompt: PromptTemplate, | |
are_all_true_prompt: PromptTemplate, | |
verbose: bool = False, | |
) -> SequentialChain: | |
chain = SequentialChain( | |
chains=[ | |
LLMChain( | |
llm=llm, | |
prompt=create_assertions_prompt, | |
output_key="assertions", | |
verbose=verbose, | |
), | |
LLMChain( | |
llm=llm, | |
prompt=check_assertions_prompt, | |
output_key="checked_assertions", | |
verbose=verbose, | |
), | |
LLMChain( | |
llm=llm, | |
prompt=revised_summary_prompt, | |
output_key="revised_summary", | |
verbose=verbose, | |
), | |
LLMChain( | |
llm=llm, | |
output_key="all_true", | |
prompt=are_all_true_prompt, | |
verbose=verbose, | |
), | |
], | |
input_variables=["summary"], | |
output_variables=["all_true", "revised_summary"], | |
verbose=verbose, | |
) | |
return chain | |
class LLMSummarizationCheckerChain(Chain): | |
"""Chain for question-answering with self-verification. | |
Example: | |
.. code-block:: python | |
from langchain.llms import OpenAI | |
from langchain.chains import LLMSummarizationCheckerChain | |
llm = OpenAI(temperature=0.0) | |
checker_chain = LLMSummarizationCheckerChain.from_llm(llm) | |
""" | |
sequential_chain: SequentialChain | |
llm: Optional[BaseLanguageModel] = None | |
"""[Deprecated] LLM wrapper to use.""" | |
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT | |
"""[Deprecated]""" | |
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT | |
"""[Deprecated]""" | |
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT | |
"""[Deprecated]""" | |
are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT | |
"""[Deprecated]""" | |
input_key: str = "query" #: :meta private: | |
output_key: str = "result" #: :meta private: | |
max_checks: int = 2 | |
"""Maximum number of times to check the assertions. Default to double-checking.""" | |
class Config: | |
"""Configuration for this pydantic object.""" | |
extra = Extra.forbid | |
arbitrary_types_allowed = True | |
def raise_deprecation(cls, values: Dict) -> Dict: | |
if "llm" in values: | |
warnings.warn( | |
"Directly instantiating an LLMSummarizationCheckerChain with an llm is " | |
"deprecated. Please instantiate with" | |
" sequential_chain argument or using the from_llm class method." | |
) | |
if "sequential_chain" not in values and values["llm"] is not None: | |
values["sequential_chain"] = _load_sequential_chain( | |
values["llm"], | |
values.get("create_assertions_prompt", CREATE_ASSERTIONS_PROMPT), | |
values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT), | |
values.get("revised_summary_prompt", REVISED_SUMMARY_PROMPT), | |
values.get("are_all_true_prompt", ARE_ALL_TRUE_PROMPT), | |
verbose=values.get("verbose", False), | |
) | |
return values | |
def input_keys(self) -> List[str]: | |
"""Return the singular input key. | |
:meta private: | |
""" | |
return [self.input_key] | |
def output_keys(self) -> List[str]: | |
"""Return the singular output key. | |
:meta private: | |
""" | |
return [self.output_key] | |
def _call( | |
self, | |
inputs: Dict[str, Any], | |
run_manager: Optional[CallbackManagerForChainRun] = None, | |
) -> Dict[str, str]: | |
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() | |
all_true = False | |
count = 0 | |
output = None | |
original_input = inputs[self.input_key] | |
chain_input = original_input | |
while not all_true and count < self.max_checks: | |
output = self.sequential_chain( | |
{"summary": chain_input}, callbacks=_run_manager.get_child() | |
) | |
count += 1 | |
if output["all_true"].strip() == "True": | |
break | |
if self.verbose: | |
print(output["revised_summary"]) | |
chain_input = output["revised_summary"] | |
if not output: | |
raise ValueError("No output from chain") | |
return {self.output_key: output["revised_summary"].strip()} | |
def _chain_type(self) -> str: | |
return "llm_summarization_checker_chain" | |
def from_llm( | |
cls, | |
llm: BaseLanguageModel, | |
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT, | |
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, | |
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT, | |
are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT, | |
verbose: bool = False, | |
**kwargs: Any, | |
) -> LLMSummarizationCheckerChain: | |
chain = _load_sequential_chain( | |
llm, | |
create_assertions_prompt, | |
check_assertions_prompt, | |
revised_summary_prompt, | |
are_all_true_prompt, | |
verbose=verbose, | |
) | |
return cls(sequential_chain=chain, verbose=verbose, **kwargs) | |