File size: 4,574 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Base classes for chain routing."""
from __future__ import annotations

from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional

from langchain_core.pydantic_v1 import Extra

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
    Callbacks,
)
from langchain.chains.base import Chain


class Route(NamedTuple):
    """Routing decision: a destination name plus the inputs to forward to it."""

    # Name of the chain to run next. Typed Optional, so it may be None;
    # MultiRouteChain in this file treats any falsy value as "use the default chain".
    destination: Optional[str]
    # Inputs dictionary handed to whichever chain ends up being run.
    next_inputs: Dict[str, Any]


class RouterChain(Chain, ABC):
    """Chain whose outputs name a destination chain and carry the inputs for it."""

    @property
    def output_keys(self) -> List[str]:
        """Return the two output keys that ``route`` and ``aroute`` read."""
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
        """
        Run this chain on ``inputs`` and package the outcome as a ``Route``.

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the "destination" and "next_inputs" outputs
        """
        outputs = self(inputs, callbacks=callbacks)
        destination = outputs["destination"]
        next_inputs = outputs["next_inputs"]
        return Route(destination=destination, next_inputs=next_inputs)

    async def aroute(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Route:
        """
        Asynchronous counterpart of ``route``; awaits ``acall`` instead of calling.

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the "destination" and "next_inputs" outputs
        """
        outputs = await self.acall(inputs, callbacks=callbacks)
        destination = outputs["destination"]
        next_inputs = outputs["next_inputs"]
        return Route(destination=destination, next_inputs=next_inputs)


class MultiRouteChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: RouterChain
    """Chain that routes inputs to destination chains."""
    destination_chains: Mapping[str, Chain]
    """Chains that return final answer to inputs."""
    default_chain: Chain
    """Default chain to use when none of the destination chains are suitable."""
    silent_errors: bool = False
    """If True, use default_chain when an invalid destination name is provided.
    Defaults to False."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Declare no fixed output keys (always an empty list).

        ``_call``/``_acall`` return whatever the selected chain returns, so no
        key names are declared here.

        :meta private:
        """
        return []

    def _select_chain(self, destination: Optional[str]) -> Chain:
        """Pick the chain that should handle a routed request.

        Shared by ``_call`` and ``_acall`` so both paths apply identical rules.

        Args:
            destination: destination name produced by the router chain.

        Returns:
            ``default_chain`` when ``destination`` is falsy, or when it is not a
            key of ``destination_chains`` and ``silent_errors`` is True;
            otherwise the matching entry of ``destination_chains``.

        Raises:
            ValueError: if ``destination`` is not a key of
                ``destination_chains`` and ``silent_errors`` is False.
        """
        if not destination:
            # The router gave no destination (None or empty): use the default.
            return self.default_chain
        if destination in self.destination_chains:
            return self.destination_chains[destination]
        if self.silent_errors:
            # Unknown name, but the caller opted into a silent fallback.
            return self.default_chain
        raise ValueError(f"Received invalid destination chain name '{destination}'")

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` with the router chain, then run the selected chain.

        Args:
            inputs: inputs forwarded to ``router_chain.route``.
            run_manager: callback manager for this run; a no-op manager is used
                when omitted.

        Returns:
            The output of the selected chain, called with the route's
            ``next_inputs``.

        Raises:
            ValueError: if the routed destination is unknown and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = self.router_chain.route(inputs, callbacks=callbacks)

        # Report the routing decision before acting on it, so it is emitted
        # even when the destination turns out to be invalid.
        _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._select_chain(route.destination)
        return chain(route.next_inputs, callbacks=callbacks)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronous counterpart of ``_call``.

        Args:
            inputs: inputs forwarded to ``router_chain.aroute``.
            run_manager: async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            The output of the selected chain's ``acall``, invoked with the
            route's ``next_inputs``.

        Raises:
            ValueError: if the routed destination is unknown and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = await self.router_chain.aroute(inputs, callbacks=callbacks)

        # Report the routing decision before acting on it, so it is emitted
        # even when the destination turns out to be invalid.
        await _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._select_chain(route.destination)
        return await chain.acall(route.next_inputs, callbacks=callbacks)