Spaces:
Runtime error
Runtime error
File size: 4,574 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
"""Base classes for chain routing."""
from __future__ import annotations
from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from langchain_core.pydantic_v1 import Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
class Route(NamedTuple):
    """Result of a routing decision.

    Pairs the name of the chain that should handle a request with the inputs
    to hand to that chain.
    """

    # Name of the destination chain. ``MultiRouteChain`` treats a falsy value
    # (e.g. ``None``) as "use the default chain".
    destination: Optional[str]
    # Inputs forwarded to the destination chain.
    next_inputs: Dict[str, Any]
class RouterChain(Chain, ABC):
    """Chain that outputs the name of a destination chain and the inputs to it.

    Running a router must produce a ``destination`` entry (the chosen chain's
    name) and a ``next_inputs`` entry (the inputs for that chain). ``route`` and
    ``aroute`` run the router and package those two outputs as a ``Route``.
    """

    @property
    def output_keys(self) -> List[str]:
        """The two keys every router emits: destination name and its inputs."""
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
        """Run the router synchronously and return its decision.

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the ``destination`` and ``next_inputs`` outputs

        Raises:
            KeyError: if the router's output lacks either expected key.
        """
        outputs = self(inputs, callbacks=callbacks)
        return Route(
            destination=outputs["destination"],
            next_inputs=outputs["next_inputs"],
        )

    async def aroute(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Route:
        """Asynchronous counterpart of ``route`` (runs the router via ``acall``).

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the ``destination`` and ``next_inputs`` outputs

        Raises:
            KeyError: if the router's output lacks either expected key.
        """
        outputs = await self.acall(inputs, callbacks=callbacks)
        return Route(
            destination=outputs["destination"],
            next_inputs=outputs["next_inputs"],
        )
class MultiRouteChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: RouterChain
    """Chain that routes inputs to destination chains."""
    destination_chains: Mapping[str, Chain]
    """Chains that return final answer to inputs."""
    default_chain: Chain
    """Default chain to use when none of the destination chains are suitable."""
    silent_errors: bool = False
    """If True, use default_chain when an invalid destination name is provided.

    Defaults to False."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Declare no fixed output keys.

        The output of this chain is whatever the selected destination (or
        default) chain returns, so its keys are not known ahead of time.
        NOTE(review): returning an empty list presumably lets the base ``Chain``
        accept any output keys — confirm against ``Chain`` output validation.

        :meta private:
        """
        return []

    def _select_chain(self, route: Route) -> Chain:
        """Return the chain that should handle ``route``.

        Shared by ``_call`` and ``_acall`` so both apply the same rules:

        * a falsy destination (e.g. ``None`` or ``""``) selects ``default_chain``;
        * a destination present in ``destination_chains`` selects that chain;
        * any other destination selects ``default_chain`` if ``silent_errors``.

        Raises:
            ValueError: if the destination is unknown and ``silent_errors`` is
                False.
        """
        if not route.destination:
            return self.default_chain
        if route.destination in self.destination_chains:
            return self.destination_chains[route.destination]
        if self.silent_errors:
            return self.default_chain
        raise ValueError(
            f"Received invalid destination chain name '{route.destination}'"
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain synchronously.

        Args:
            inputs: inputs passed to the router chain.
            run_manager: callback manager for this run; a no-op manager is used
                when omitted.

        Returns:
            The output dict of the selected destination (or default) chain.

        Raises:
            ValueError: if the router picks an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # The router and the downstream chain share the same child callbacks.
        callbacks = _run_manager.get_child()
        route = self.router_chain.route(inputs, callbacks=callbacks)
        # Report the routing decision before dispatching (and before any error).
        _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._select_chain(route)
        return chain(route.next_inputs, callbacks=callbacks)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain asynchronously.

        Args:
            inputs: inputs passed to the router chain.
            run_manager: async callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            The output dict of the selected destination (or default) chain.

        Raises:
            ValueError: if the router picks an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        # The router and the downstream chain share the same child callbacks.
        callbacks = _run_manager.get_child()
        route = await self.router_chain.aroute(inputs, callbacks=callbacks)
        # Report the routing decision before dispatching (and before any error).
        await _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._select_chain(route)
        return await chain.acall(route.next_inputs, callbacks=callbacks)
|