Spaces:
Runtime error
Runtime error
File size: 2,372 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
"""Chain that runs an arbitrary python function."""
import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from langchain_core.pydantic_v1 import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
# Module-level logger; TransformChain._log_once emits its warning through it.
logger = logging.getLogger(__name__)
class TransformChain(Chain):
    """Chain that transforms the chain output.

    The chain hands its input dictionary to a user-supplied callable and
    returns the dictionary that callable produces. An optional coroutine
    can be supplied for the async path; when it is missing, the async path
    falls back to the synchronous callable.

    Example:
        .. code-block:: python

            from langchain.chains import TransformChain

            transform_chain = TransformChain(
                input_variables=["text"],
                output_variables=["entities"],
                transform=func,
            )
    """

    input_variables: List[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: List[str]
    """The keys returned by the transform's output dictionary."""
    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
    """The transform function."""
    atransform_cb: Optional[
        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
    ] = Field(None, alias="atransform")
    """The async coroutine transform function."""

    @staticmethod
    @functools.lru_cache
    def _log_once(msg: str) -> None:
        """Log a warning once per distinct message.

        The ``lru_cache`` wrapper memoizes on ``msg``, so a repeated call
        with a message that is still cached returns without logging again.

        Args:
            msg: The warning text to emit.

        :meta private:
        """
        logger.warning(msg)

    @property
    def input_keys(self) -> List[str]:
        """Expect input keys.

        Returns:
            The configured ``input_variables``.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output keys.

        Returns:
            The configured ``output_variables``.

        :meta private:
        """
        return self.output_variables

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the synchronous transform on ``inputs``.

        Args:
            inputs: The dictionary passed to the transform function.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            Whatever dictionary the transform function returns.
        """
        return self.transform_cb(inputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the async transform, or fall back to the synchronous one.

        Args:
            inputs: The dictionary passed to the transform function.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The awaited result of ``atransform_cb`` when one was provided,
            otherwise the result of the synchronous ``transform_cb``.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)

        # No coroutine was configured: warn (once) and run the sync transform
        # directly inside this coroutine.
        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform"
        )
        return self.transform_cb(inputs)
|