|
"""Base classes for chain routing.""" |
|
|
|
from __future__ import annotations |
|
|
|
from abc import ABC |
|
from typing import Any, Dict, List, Mapping, NamedTuple, Optional |
|
|
|
from langchain_core.callbacks import ( |
|
AsyncCallbackManagerForChainRun, |
|
CallbackManagerForChainRun, |
|
Callbacks, |
|
) |
|
from pydantic import ConfigDict |
|
|
|
from langchain.chains.base import Chain |
|
|
|
|
|
class Route(NamedTuple):
    """Routing decision: which destination chain to use and what to pass it.

    Produced by ``RouterChain.route`` / ``RouterChain.aroute`` from the
    router's ``"destination"`` and ``"next_inputs"`` outputs.
    """

    # Name of the chosen destination chain. MultiRouteChain sends a falsy
    # value (None or "") to its default chain.
    destination: Optional[str]
    # Inputs to forward to the chosen chain.
    next_inputs: Dict[str, Any]
|
|
|
|
|
class RouterChain(Chain, ABC):
    """Chain that outputs the name of a destination chain and the inputs to it."""

    @property
    def output_keys(self) -> List[str]:
        """Keys every router emits: the destination name and its inputs."""
        keys = ["destination", "next_inputs"]
        return keys

    def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
        """Run this router synchronously and wrap its output in a ``Route``.

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the ``"destination"`` and ``"next_inputs"``
            entries of this chain's output
        """
        outputs = self(inputs, callbacks=callbacks)
        return Route(
            destination=outputs["destination"],
            next_inputs=outputs["next_inputs"],
        )

    async def aroute(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Route:
        """Asynchronous counterpart of ``route``; runs the chain via ``acall``.

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the ``"destination"`` and ``"next_inputs"``
            entries of this chain's output
        """
        outputs = await self.acall(inputs, callbacks=callbacks)
        return Route(
            destination=outputs["destination"],
            next_inputs=outputs["next_inputs"],
        )
|
|
|
|
|
class MultiRouteChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: RouterChain
    """Chain that routes inputs to destination chains."""
    destination_chains: Mapping[str, Chain]
    """Chains that return final answer to inputs."""
    default_chain: Chain
    """Default chain to use when none of the destination chains are suitable."""
    silent_errors: bool = False
    """If True, use default_chain when an invalid destination name is provided.
    Defaults to False."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys expected by ``router_chain``.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Return an empty list; no output keys are declared.

        ``_call``/``_acall`` return the output dict of whichever chain handles
        the request. Those chains may produce different keys, so no fixed set
        is declared here.

        :meta private:
        """
        return []

    def _get_destination_chain(self, destination: Optional[str]) -> Chain:
        """Resolve the chain that should handle a routing decision.

        Shared by ``_call`` and ``_acall`` so both paths apply identical
        fallback rules.

        Args:
            destination: destination name chosen by ``router_chain``.

        Returns:
            ``default_chain`` when ``destination`` is falsy; the matching entry
            of ``destination_chains`` when one exists; otherwise
            ``default_chain`` if ``silent_errors`` is True.

        Raises:
            ValueError: if ``destination`` is not a key of
                ``destination_chains`` and ``silent_errors`` is False.
        """
        if not destination:
            return self.default_chain
        if destination in self.destination_chains:
            return self.destination_chains[destination]
        if self.silent_errors:
            return self.default_chain
        raise ValueError(
            f"Received invalid destination chain name '{destination}'"
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain synchronously.

        Args:
            inputs: inputs passed to ``router_chain``.
            run_manager: callback manager for this run; a no-op manager is used
                when None.

        Returns:
            The output dict of the chain that handled the routed inputs.

        Raises:
            ValueError: if the router picks an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = self.router_chain.route(inputs, callbacks=callbacks)

        # Report the routing decision before dispatching, so it is emitted even
        # when the destination turns out to be invalid.
        _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._get_destination_chain(route.destination)
        return chain(route.next_inputs, callbacks=callbacks)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain asynchronously.

        Args:
            inputs: inputs passed to ``router_chain``.
            run_manager: async callback manager for this run; a no-op manager
                is used when None.

        Returns:
            The output dict of the chain that handled the routed inputs.

        Raises:
            ValueError: if the router picks an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = await self.router_chain.aroute(inputs, callbacks=callbacks)

        # Report the routing decision before dispatching, so it is emitted even
        # when the destination turns out to be invalid.
        await _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._get_destination_chain(route.destination)
        return await chain.acall(route.next_inputs, callbacks=callbacks)
|
|