Spaces:
Runtime error
Runtime error
| from operator import itemgetter | |
| from typing import Any, Callable, List, Mapping, Optional, Union | |
| from langchain_core.messages import BaseMessage | |
| from langchain_core.runnables import RouterRunnable, Runnable | |
| from langchain_core.runnables.base import RunnableBindingBase | |
| from typing_extensions import TypedDict | |
| from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser | |
class OpenAIFunction(TypedDict):
    """Typed mapping describing one function for ChatOpenAI."""

    # The name of the function.
    name: str
    # The description of the function.
    description: str
    # The parameters to the function.
    parameters: dict
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
    """A runnable that routes to the selected function."""

    # Optional function descriptions; when given, they are validated against
    # the keys of ``runnables`` in ``__init__``.
    functions: Optional[List[OpenAIFunction]]

    def __init__(
        self,
        runnables: Mapping[
            str,
            Union[
                Runnable[dict, Any],
                Callable[[dict], Any],
            ],
        ],
        functions: Optional[List[OpenAIFunction]] = None,
    ):
        """Build the routing pipeline and bind it to this runnable.

        Args:
            runnables: Mapping from a key (checked against each function's
                ``"name"``) to the runnable or callable that handles it.
            functions: Optional function descriptions. When provided, there
                must be exactly as many as ``runnables`` and every
                ``"name"`` must be a key of ``runnables``.

        Raises:
            ValueError: If ``functions`` is provided and its length differs
                from ``runnables``, or one of its names is not a key of
                ``runnables``.
        """
        if functions is not None:
            # Explicit raises instead of ``assert``: assertions are removed
            # under ``python -O``, which would silently skip this validation.
            if len(functions) != len(runnables):
                raise ValueError(
                    "The number of functions does not match the number of "
                    "runnables."
                )
            if not all(func["name"] in runnables for func in functions):
                raise ValueError(
                    "One or more function names are not found in runnables."
                )
        # The parser output is indexed by "name" and "arguments" below, so it
        # is presumably a dict with those keys (args_only=False) -- confirm
        # against JsonOutputFunctionsParser. The dict step reshapes it into
        # the {"key", "input"} mapping that is piped into RouterRunnable.
        router = (
            JsonOutputFunctionsParser(args_only=False)
            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
            | RouterRunnable(runnables)
        )
        super().__init__(bound=router, kwargs={}, functions=functions)