Spaces:
Runtime error
Runtime error
"""Chain that runs an arbitrary python function.""" | |
import functools | |
import logging | |
from typing import Any, Awaitable, Callable, Dict, List, Optional | |
from langchain_core.pydantic_v1 import Field | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForChainRun, | |
CallbackManagerForChainRun, | |
) | |
from langchain.chains.base import Chain | |
# Module logger; TransformChain routes its sync-fallback warning through it.
logger = logging.getLogger(name=__name__)
class TransformChain(Chain):
    """Chain that transforms the chain output.

    Runs a user-supplied callable over the input dictionary and returns the
    dictionary that callable produces. An optional async callable may also be
    supplied for the async code path (``_acall``). Without it, that path
    falls back to the synchronous callable.

    Example:
        .. code-block:: python

            from langchain.chains import TransformChain

            transform_chain = TransformChain(
                input_variables=["text"],
                output_variables=["entities"],
                transform=func,
            )
    """

    input_variables: List[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: List[str]
    """The keys returned by the transform's output dictionary."""
    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
    """The transform function (passed to the constructor as ``transform``)."""
    atransform_cb: Optional[
        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
    ] = Field(None, alias="atransform")
    """The async coroutine transform function (constructor key ``atransform``)."""

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _log_once(msg: str) -> None:
        """Log ``msg`` as a warning only the first time it is seen.

        Calls are memoized on ``msg`` by ``lru_cache``. A repeated message hits
        the cache and is not logged again. ``@staticmethod`` is required
        because callers invoke this as ``self._log_once(msg)``. Without it,
        ``self`` would be bound to ``msg``. The cache is unbounded, but callers
        in this class only pass a fixed literal.

        :meta private:
        """
        logger.warning(msg)

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys this chain expects.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys this chain produces.

        :meta private:
        """
        return self.output_variables

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Apply the synchronous transform to ``inputs`` and return its result.

        ``run_manager`` is not used by this implementation.
        """
        return self.transform_cb(inputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Apply the async transform, or fall back to the synchronous one.

        If ``atransform`` was not provided, a warning is logged once and the
        synchronous ``transform`` is called directly. That call runs inline
        inside this coroutine, so a slow sync transform blocks the event loop.
        ``run_manager`` is not used by this implementation.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)
        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform"
        )
        return self.transform_cb(inputs)