| from copy import deepcopy |
|
|
| from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall |
| from .hook import Hook |
|
|
|
|
class ActionPreprocessor(Hook):
    """Hook that preprocesses the action message and postprocesses the
    action return message.

    ``before_action`` normalizes ``message.formatted`` into a
    ``dict(name=..., parameters=...)`` stored in ``message.content``.
    ``after_action`` replaces an ``ActionReturn`` held in
    ``message.content`` with its formatted result or its error message.
    """

    def before_action(self, executor, message, session_id):
        """Extract the action name and parameters from ``message.formatted``.

        Three layouts of ``message.formatted`` are accepted:

        * a ``FunctionCall`` exposing ``name`` and ``parameters``;
        * a flat dict with top-level ``'name'`` and ``'parameters'`` keys;
        * a container with an ``'action'`` entry that itself holds
          ``'name'`` and ``'parameters'``.

        For dicts, a top-level key takes precedence over the one nested
        under ``'action'``.

        Args:
            executor: The executor invoking this hook (not used here).
            message: The message to preprocess; its ``content`` is
                overwritten in place.
            session_id: Identifier of the current session (not used here).

        Returns:
            The same ``message`` object, with ``content`` set to
            ``dict(name=name, parameters=parameters)``.

        Raises:
            AssertionError: If ``message.formatted`` matches none of the
                accepted layouts.
        """
        formatted = message.formatted
        # Validate against ``formatted`` itself; inspecting ``message.content``
        # here would test an unrelated attribute.
        assert isinstance(formatted, FunctionCall) or (
            isinstance(formatted, dict) and 'name' in formatted
            and 'parameters' in formatted) or (
                'action' in formatted
                and 'parameters' in formatted['action']
                and 'name' in formatted['action'])
        if isinstance(formatted, dict):
            # The nested fallback must be evaluated lazily: ``dict.get``
            # computes its default eagerly, which raised ``KeyError`` for
            # flat dicts that have no 'action' entry.
            if 'name' in formatted:
                name = formatted['name']
            else:
                name = formatted['action']['name']
            if 'parameters' in formatted:
                parameters = formatted['parameters']
            else:
                parameters = formatted['action']['parameters']
        else:
            name = formatted.name
            parameters = formatted.parameters
        message.content = dict(name=name, parameters=parameters)
        return message

    def after_action(self, executor, message, session_id):
        """Turn the action result in ``message.content`` into a response.

        An ``ActionReturn`` whose ``state`` is ``ActionStatusCode.SUCCESS``
        is replaced by ``format_result()``; any other ``ActionReturn`` is
        replaced by its ``errmsg``. Content that is not an ``ActionReturn``
        is left as it is.

        Args:
            executor: The executor invoking this hook (not used here).
            message: The message to postprocess; its ``content`` is
                overwritten in place.
            session_id: Identifier of the current session (not used here).

        Returns:
            The same ``message`` object with the updated ``content``.
        """
        action_return = message.content
        if isinstance(action_return, ActionReturn):
            if action_return.state == ActionStatusCode.SUCCESS:
                response = action_return.format_result()
            else:
                response = action_return.errmsg
        else:
            response = action_return
        message.content = response
        return message
|
|
|
|
class InternLMActionProcessor(ActionPreprocessor):
    """Action preprocessor for messages whose ``formatted`` payload is a dict
    carrying at least ``'tool_type'``, ``'thought'``, ``'action'`` and
    ``'status'``.

    When ``'action'`` is a plain string it is wrapped into a
    ``dict(name=..., parameters=...)`` before the parent preprocessing runs.

    Args:
        code_parameter: Name of the parameter under which a string action
            is passed on. Defaults to ``'command'``.
    """

    def __init__(self, code_parameter: str = 'command'):
        self.code_parameter = code_parameter

    def before_action(self, executor, message, session_id):
        """Normalize a string action, then delegate to the parent hook.

        The incoming message is deep-copied first, so the caller's object
        is not modified by this method.

        Args:
            executor: Executor whose ``actions`` attribute is iterated to
                pick the action name for a string action.
            message: Message whose ``formatted`` dict is inspected.
            session_id: Session identifier; added to the parameters when
                the chosen action is ``'AsyncIPythonInterpreter'``.

        Returns:
            The result of ``ActionPreprocessor.before_action`` applied to
            the copied message.

        Raises:
            AssertionError: If ``formatted`` is not a dict or lacks one of
                the required keys.
        """
        message = deepcopy(message)
        payload = message.formatted
        required_keys = {'tool_type', 'thought', 'action', 'status'}
        assert isinstance(payload, dict) and set(payload).issuperset(
            required_keys)

        raw_action = payload['action']
        if isinstance(raw_action, str):
            # A bare string carries no tool name, so take the first entry
            # obtained by iterating ``executor.actions``.
            action_name = next(iter(executor.actions))
            parameters = {self.code_parameter: raw_action}
            if action_name == 'AsyncIPythonInterpreter':
                parameters['session_id'] = session_id
            message.formatted['action'] = dict(
                name=action_name, parameters=parameters)
        return super().before_action(executor, message, session_id)
|
|