Lagent / lagent /actions /base_action.py
Superkingjcj's picture
Upload 111 files
e679d69 verified
raw
history blame
14.6 kB
import inspect
import logging
import re
from abc import ABCMeta
from copy import deepcopy
from functools import wraps
from typing import Callable, Optional, Type, get_args, get_origin
try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated
from griffe import Docstring
try:
from griffe import DocstringSectionKind
except ImportError:
from griffe.enumerations import DocstringSectionKind
from ..schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser, ParseError
# Only let the ``griffe`` logger emit ERROR and above, hiding lower-level
# messages it may log while docstrings are parsed by the tooling below.
logging.getLogger('griffe').setLevel(logging.ERROR)
def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools. It will parse typehints as well as docstrings
    to build the tool description and attach it to functions via an attribute
    ``api_description``.

    Examples:

        .. code-block:: python

            # typehints has higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    a (int): a
                    b (int): b
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``.
        explode_return (bool): whether to flatten the dictionary or tuple return
            as the ``return_data`` field. When enabled, it is recommended to
            annotate the member in docstrings. Defaults to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``. Defaults to ``False``.
        **kwargs: extra options forwarded to griffe's Google-style docstring
            parser when the decorated function's docstring is parsed.

    Returns:
        Callable: wrapped function or decorator

    Important:
        ``return_data`` field will be added to ``api_description`` only
        when ``explode_return`` or ``returns_named_value`` is enabled and the
        docstring's ``Returns`` section provides the corresponding content.

    Note:
        Parameter types are mapped heuristically to ``'STRING'``,
        ``'NUMBER'``, ``'FLOAT'``, ``'BOOLEAN'`` or ``'Array'``. Type hints
        take priority over docstring types; list hints such as ``list``,
        ``List[int]`` or ``Optional[list[str]]`` map to ``'Array'``.
        The returned wrapper forwards its first positional argument as
        ``self``, so it is meant for methods.
    """

    def _detect_type(string):
        """Map a type description string to a tool field type by keywords."""
        field_type = 'STRING'
        if 'list' in string:
            field_type = 'Array'
        elif 'str' not in string:
            # Any mention of ``str`` keeps the STRING default.
            if 'float' in string:
                field_type = 'FLOAT'
            elif 'int' in string:
                field_type = 'NUMBER'
            elif 'bool' in string:
                field_type = 'BOOLEAN'
        return field_type

    def _annotation_type(annotation):
        """Map a signature type hint (already unwrapped from ``Annotated``)
        to a tool field type."""
        origin = get_origin(annotation)
        # Check for list *before* looking at generic arguments: stripping the
        # origin first would turn ``List[int]`` into ``(int,)`` -> NUMBER and
        # bare ``List`` into ``()`` -> STRING.
        if annotation is list or origin is list:
            return 'Array'
        if origin is not None:
            args = get_args(annotation)
            members = [arg for arg in args if arg is not type(None)]
            # ``Optional[X]`` / ``X | None``: classify the wrapped type.
            if len(members) == 1 and len(args) > 1:
                return _annotation_type(members[0])
        while get_origin(annotation):
            annotation = get_args(annotation)
        return _detect_type(str(annotation))

    def _doc_item(element):
        """Convert a griffe docstring element into a plain dict, replacing
        ``annotation`` by a detected ``type`` (omitted when undocumented)."""
        item = element.as_dict()
        annotation = item.pop('annotation', None)
        # Entries written without ``(type)`` have no annotation; they must
        # not crash, and simply carry no ``type`` key.
        if annotation:
            item['type'] = _detect_type(str(annotation).lower())
        return item

    def _explode(desc):
        """Parse the bullet lines after the first line of a return
        description into a list of ``{name, description[, type]}`` dicts."""
        kvs = []
        # Re-shape the bullets as an ``Args:`` section so griffe parses each
        # one as a named entry.
        desc = '\nArgs:\n' + '\n'.join([
            ' ' + item.lstrip(' -+*#.')
            for item in desc.split('\n')[1:] if item.strip()
        ])
        docs = Docstring(desc).parse('google')
        if not docs:
            return kvs
        if docs[0].kind is DocstringSectionKind.parameters:
            kvs = [_doc_item(element) for element in docs[0].value]
        return kvs

    def _parse_tool(function):
        """Build the ``api_description`` dict for ``function``."""
        # remove rst syntax, e.g. :class:`Foo` -> Foo
        docs = Docstring(
            re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
                'google', returns_named_value=returns_named_value, **kwargs)
        desc = dict(
            name=function.__name__,
            description=docs[0].value
            if docs and docs[0].kind is DocstringSectionKind.text else '',
            parameters=[],
            required=[],
        )
        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for element in doc.value:
                    item = _doc_item(element)
                    args_doc[item['name']] = item
            elif doc.kind is DocstringSectionKind.returns:
                for element in doc.value:
                    item = _doc_item(element)
                    if not item.get('name'):
                        item.pop('name', None)
                    returns_doc.append(item)

        for name, param in inspect.signature(function).parameters.items():
            if name == 'self':
                continue
            documented = args_doc.get(param.name, {})
            parameter = dict(
                name=param.name,
                type='STRING',
                description=documented.get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                # No type hint: fall back to the docstring's declared type.
                parameter['type'] = documented.get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    # Annotated[T, info, ...]: the first metadata item is
                    # used as the description; extra metadata is ignored.
                    annotation, info, *_ = get_args(annotation)
                    if info:
                        parameter['description'] = info
                parameter['type'] = _annotation_type(annotation)
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(param.name)

        return_data = []
        if explode_return:
            # Without a Returns section there is nothing to explode.
            if returns_doc:
                return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    def decorate(function):
        """Wrap ``function`` (sync or async) and attach ``api_description``."""
        if inspect.iscoroutinefunction(function):

            @wraps(function)
            async def wrapper(self, *call_args, **call_kwargs):
                return await function(self, *call_args, **call_kwargs)
        else:

            @wraps(function)
            def wrapper(self, *call_args, **call_kwargs):
                return function(self, *call_args, **call_kwargs)

        wrapper.api_description = _parse_tool(function)
        return wrapper

    # Support both bare ``@tool_api`` and ``@tool_api(...)`` usage.
    if callable(func):
        return decorate(func)
    return decorate
class ToolMeta(ABCMeta):
    """Metaclass of tools.

    For every class it creates, builds ``__tool_description__`` (name,
    description and API schema) and sets ``_is_toolkit``. Only the class
    body (``attrs``) is inspected:

    * a ``run`` decorated with :func:`tool_api` makes the class a single tool
      whose schema comes from ``run``;
    * otherwise, methods decorated with :func:`tool_api` are collected into
      ``api_list`` and the class is a toolkit;
    * otherwise, a callable ``run`` is wrapped with :func:`tool_api` and used
      as the single tool; without ``run`` the tool has no parameters.

    Raises:
        KeyError: if a decorated ``run`` and other :func:`tool_api` methods
            are defined in the same class body.
    """

    def __new__(mcs, name, base, attrs):

        def _apply_run_description(tool_desc, api_desc):
            # Copy the ``run`` schema onto the tool; the class docstring
            # summary is kept unless ``run`` provides its own description.
            tool_desc['parameters'] = api_desc['parameters']
            tool_desc['required'] = api_desc['required']
            if api_desc['description']:
                tool_desc['description'] = api_desc['description']
            if api_desc.get('return_data'):
                tool_desc['return_data'] = api_desc['return_data']

        # ``__doc__`` may be missing from ``attrs`` or explicitly ``None``;
        # only a leading free-text section is used as the description.
        sections = Docstring(attrs.get('__doc__') or '').parse('google')
        description = (
            sections[0].value
            if sections and sections[0].kind is DocstringSectionKind.text
            else '')
        is_toolkit, tool_desc = True, dict(name=name, description=description)

        for key, value in attrs.items():
            if callable(value) and hasattr(value, 'api_description'):
                api_desc = getattr(value, 'api_description')
                if key == 'run':
                    _apply_run_description(tool_desc, api_desc)
                    is_toolkit = False
                else:
                    tool_desc.setdefault('api_list', []).append(api_desc)

        if not is_toolkit and 'api_list' in tool_desc:
            raise KeyError('`run` and other tool APIs can not be implemented '
                           'at the same time')
        if is_toolkit and 'api_list' not in tool_desc:
            is_toolkit = False
            if callable(attrs.get('run')):
                run_api = tool_api(attrs['run'])
                _apply_run_description(tool_desc, run_api.api_description)
                attrs['run'] = run_api
            else:
                tool_desc['parameters'], tool_desc['required'] = [], []

        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)
class BaseAction(metaclass=ToolMeta):
    """Base class for all actions.

    Args:
        description (:class:`Optional[dict]`): The description of the action.
            Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(
        self,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        # A falsy ``description`` falls back to the class-level description;
        # either way a private deep copy is kept.
        chosen = description or self.__tool_description__
        self._description = deepcopy(chosen)
        self._name = self._description['name']
        # The parser is built last because it receives the action itself.
        self._parser = parser(self)

    def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Invoke the API ``name`` with the raw ``inputs`` string.

        Args:
            inputs (str): raw arguments, decoded by ``self._parser``.
            name (str): attribute name of the API to invoke. Defaults to
                ``'run'``.

        Returns:
            ActionReturn: the API's own :class:`ActionReturn` (with empty
            ``args``/``type`` filled in), or a new one whose ``result`` is
            the parser-formatted output. An unknown ``name`` yields
            ``API_ERROR``, a :class:`ParseError` yields ``ARGS_ERROR``, and
            any exception raised by the API call yields ``API_ERROR``.
        """
        fallback_args = {'inputs': inputs, 'name': name}

        if not hasattr(self, name):
            return ActionReturn(fallback_args,
                                type=self.name,
                                errmsg=f'invalid API: {name}',
                                state=ActionStatusCode.API_ERROR)

        try:
            parsed = self._parser.parse_inputs(inputs, name)
        except ParseError as err:
            return ActionReturn(fallback_args,
                                type=self.name,
                                errmsg=err.err_msg,
                                state=ActionStatusCode.ARGS_ERROR)

        try:
            outputs = getattr(self, name)(**parsed)
        except Exception as err:
            return ActionReturn(parsed,
                                type=self.name,
                                errmsg=str(err),
                                state=ActionStatusCode.API_ERROR)

        if not isinstance(outputs, ActionReturn):
            formatted = self._parser.parse_outputs(outputs)
            return ActionReturn(parsed, type=self.name, result=formatted)

        # The API built its own ActionReturn: only fill in missing metadata.
        if not outputs.args:
            outputs.args = parsed
        if not outputs.type:
            outputs.type = self.name
        return outputs

    @property
    def name(self):
        """Name of the action, read from its description's ``'name'`` key."""
        return self._name

    @property
    def is_toolkit(self):
        """Whether the class was registered as a toolkit of ``tool_api``
        methods rather than a single ``run`` tool."""
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return f'{self.description}'

    __str__ = __repr__
class AsyncActionMixin:
    """Mixin providing an asynchronous ``__call__`` for actions.

    Follows the same call protocol as the synchronous action call, but awaits
    the API result. It reads ``self.name`` and ``self._parser``, so the class
    it is mixed into must provide them (as ``BaseAction`` does).
    """

    async def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Invoke the API ``name`` with the raw ``inputs`` string.

        Args:
            inputs (str): raw arguments, decoded by ``self._parser``.
            name (str): attribute name of the API to invoke. Defaults to
                ``'run'``.

        Returns:
            ActionReturn: the API's own :class:`ActionReturn` (with empty
            ``args``/``type`` filled in), or a new one whose ``result`` is
            the parser-formatted output. An unknown ``name`` yields
            ``API_ERROR``, a :class:`ParseError` yields ``ARGS_ERROR``, and
            any exception raised by the API call yields ``API_ERROR``.
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**inputs)
            # A synchronous API has already run by now. Awaiting its plain
            # return value would raise TypeError and discard the result, so
            # only await real awaitables.
            if inspect.isawaitable(outputs):
                outputs = await outputs
        except Exception as exc:
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            action_return = outputs
            # Only fill in metadata the API left empty.
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return