Spaces:
Runtime error
Runtime error
"""Base classes for LLM-powered router chains.""" | |
from __future__ import annotations | |
from typing import Any, Dict, List, Optional, Type, cast | |
from langchain_core.exceptions import OutputParserException | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.output_parsers import BaseOutputParser | |
from langchain_core.prompts import BasePromptTemplate | |
from langchain_core.pydantic_v1 import root_validator | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForChainRun, | |
CallbackManagerForChainRun, | |
) | |
from langchain.chains import LLMChain | |
from langchain.chains.router.base import RouterChain | |
from langchain.output_parsers.json import parse_and_check_json_markdown | |
class LLMRouterChain(RouterChain):
    """A router chain that uses an LLM chain to perform routing.

    The wrapped ``llm_chain`` is invoked with the chain inputs and its parsed
    output is returned as this chain's output.
    """

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @root_validator()
    def validate_prompt(cls, values: dict) -> dict:
        """Check that the routing prompt has an output parser attached.

        Args:
            values: Field values of the model being validated; must contain
                ``"llm_chain"``.

        Returns:
            The unchanged ``values`` mapping.

        Raises:
            ValueError: If ``llm_chain.prompt.output_parser`` is ``None``.
        """
        prompt = values["llm_chain"].prompt
        if prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        """Run the base-class output validation, then require dict ``next_inputs``.

        Raises:
            ValueError: If ``outputs["next_inputs"]`` is not a ``dict``.
        """
        super()._validate_outputs(outputs)
        next_inputs = outputs["next_inputs"]
        if not isinstance(next_inputs, dict):
            raise ValueError(
                "Expected 'next_inputs' to be a dict, got "
                f"{type(next_inputs).__name__}."
            )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the LLM chain synchronously and return its parsed output.

        Args:
            inputs: Keyword inputs forwarded to ``llm_chain.predict_and_parse``.
            run_manager: Optional run manager; a no-op manager is used when
                none is supplied.

        Returns:
            The parsed output of the LLM chain, treated as a dictionary.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        output = cast(
            Dict[str, Any],
            self.llm_chain.predict_and_parse(callbacks=callbacks, **inputs),
        )
        return output

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the LLM chain asynchronously and return its parsed output.

        Args:
            inputs: Keyword inputs forwarded to ``llm_chain.apredict_and_parse``.
            run_manager: Optional async run manager; a no-op manager is used
                when none is supplied.

        Returns:
            The parsed output of the LLM chain, treated as a dictionary.
        """
        # Fall back to the *async* no-op manager so the child callbacks match
        # the declared type of ``run_manager`` on the async path.
        _run_manager = (
            run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        )
        callbacks = _run_manager.get_child()
        output = cast(
            Dict[str, Any],
            await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
        )
        return output

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> LLMRouterChain:
        """Convenience constructor.

        Args:
            llm: Language model used to build the inner ``LLMChain``.
            prompt: Prompt template used to build the inner ``LLMChain``.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            A new instance of ``cls`` wrapping ``LLMChain(llm=llm, prompt=prompt)``.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for output of router chain in the multi-prompt chain."""

    # Destination name (compared case-insensitively) that is mapped to ``None``.
    default_destination: str = "DEFAULT"
    # Type the raw ``next_inputs`` value must be an instance of.
    next_inputs_type: Type = str
    # Key under which the raw ``next_inputs`` value is wrapped in the result.
    next_inputs_inner_key: str = "input"

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse router LLM output into ``destination`` and ``next_inputs``.

        Args:
            text: Raw text passed to ``parse_and_check_json_markdown`` with the
                expected keys ``"destination"`` and ``"next_inputs"``.

        Returns:
            The parsed dictionary, with ``"next_inputs"`` wrapped as
            ``{next_inputs_inner_key: <value>}`` and ``"destination"`` either
            stripped of surrounding whitespace or ``None`` when it matches
            ``default_destination`` (case-insensitive).

        Raises:
            OutputParserException: If any step of parsing or validation fails;
                the underlying error is attached as the exception's cause.
        """
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
            if (
                parsed["destination"].strip().lower()
                == self.default_destination.lower()
            ):
                parsed["destination"] = None
            else:
                parsed["destination"] = parsed["destination"].strip()
            return parsed
        except Exception as e:
            # Chain the original error so the root cause stays in the traceback.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}"
            ) from e