Spaces:
Running
Running
import json | |
from enum import IntEnum | |
# import re | |
from typing import Any, Callable, List, Optional | |
from lagent.prompts.parsers import StrParser | |
from lagent.utils import create_object, load_class_from_string | |
def default_plugin_validate(plugin: str):
    """Decode a plugin call payload that must be written as a JSON object.

    Args:
        plugin: Raw text of the plugin action. Leading and trailing
            whitespace is ignored.

    Returns:
        The value produced by ``json.loads`` for the stripped text.

    Raises:
        json.JSONDecodeError: If the stripped text is not enclosed in
            ``{`` and ``}``, or if it is not valid JSON.
    """
    plugin = plugin.strip()
    if not (plugin.startswith('{') and plugin.endswith('}')):
        # JSONDecodeError must be built with (msg, doc, pos); raising the
        # bare class would surface as a TypeError instead of the intended
        # decoding error.
        raise json.JSONDecodeError(
            'plugin call must be a JSON object enclosed in braces', plugin, 0)
    return json.loads(plugin)
class ToolStatusCode(IntEnum):
    """Integer status attached to the result of parsing a model response."""

    # The response contains no tool call marker.
    NO_TOOL = 0
    # A tool call was extracted and, if a validator is set, it did not raise.
    VALID_TOOL = 1
    # A tool call was extracted but its validator raised an exception.
    PARSING_ERROR = -1
class ToolParser(StrParser):
    """Parser for responses that embed one tool call between two markers.

    Args:
        tool_type: Label reported as ``tool_type`` in parsed results.
        template: Template string forwarded to ``StrParser``.
        begin: Marker that opens the tool call.
        end: Marker that closes the tool call.
        validate: Optional callable applied to the extracted action text.
            A string value is resolved with ``load_class_from_string``.
        **kwargs: Additional keyword arguments forwarded to ``StrParser``.
    """

    def __init__(self,
                 tool_type: str,
                 template: str = '',
                 begin: str = '<tool>\n',
                 end: str = '</tool>\n',
                 validate: Callable[[str], Any] = None,
                 **kwargs):
        super().__init__(template, begin=begin, end=end, **kwargs)
        self.template = template
        self.tool_type = tool_type
        # Validators may be given by dotted-path string instead of a callable.
        if isinstance(validate, str):
            validate = load_class_from_string(validate)
        self.validate = validate

    def parse_response(self, data: str) -> dict:
        """Split a response into its thought and tool action.

        Args:
            data: Raw response text.

        Returns:
            A dict with keys ``tool_type``, ``thought``, ``action`` and
            ``status``. When the begin marker is absent the whole text is
            the thought and ``status`` is ``ToolStatusCode.NO_TOOL``.
        """
        begin_marker = self.format_field['begin']
        if begin_marker not in data:
            return {
                'tool_type': None,
                'thought': data,
                'action': None,
                'status': ToolStatusCode.NO_TOOL,
            }

        # Only the text before the first marker and the segment right after
        # it are used; any further segments are ignored.
        segments = data.split(begin_marker)
        thought = segments[0]
        action = segments[1].split(self.format_field['end'])[0]

        status = ToolStatusCode.VALID_TOOL
        if self.validate:
            try:
                action = self.validate(action)
            except Exception:
                # Keep the raw action text and report the failure via status.
                status = ToolStatusCode.PARSING_ERROR

        return {
            'tool_type': self.tool_type,
            'thought': thought,
            'action': action,
            'status': status,
        }

    def format_response(self, parsed: dict) -> str:
        """Render a parsed result back into response text.

        Args:
            parsed: Dict with ``thought``, ``action`` and ``tool_type`` keys.

        Returns:
            The thought alone when ``action`` is ``None``; otherwise the
            thought followed by the action wrapped in the begin/end markers.
        """
        action = parsed['action']
        if action is None:
            return parsed['thought']
        assert parsed['tool_type'] == self.tool_type
        # Dict actions are serialised as JSON; anything else is stringified.
        action_text = (
            json.dumps(action, ensure_ascii=False)
            if isinstance(action, dict) else str(action))
        opening = self.format_field['begin']
        closing = self.format_field['end']
        return parsed['thought'] + opening + action_text + closing
class InterpreterParser(ToolParser):
    """``ToolParser`` preset with interpreter markers and tool type.

    Args:
        tool_type: Label reported in parsed results.
        template: Template string forwarded to ``ToolParser``.
        begin: Marker that opens an interpreter call.
        end: Marker that closes an interpreter call.
        validate: Optional validator forwarded to ``ToolParser``.
        **kwargs: Additional keyword arguments forwarded to ``ToolParser``.
    """

    def __init__(self,
                 tool_type: str = 'interpreter',
                 template: str = '',
                 begin: str = '<|action_start|><|interpreter|>\n',
                 end: str = '<|action_end|>\n',
                 validate: Callable[[str], Any] = None,
                 **kwargs):
        super().__init__(
            tool_type=tool_type,
            template=template,
            begin=begin,
            end=end,
            validate=validate,
            **kwargs)
class PluginParser(ToolParser):
    """``ToolParser`` preset with plugin markers and a JSON validator.

    Args:
        tool_type: Label reported in parsed results.
        template: Template string forwarded to ``ToolParser``.
        begin: Marker that opens a plugin call.
        end: Marker that closes a plugin call.
        validate: Validator forwarded to ``ToolParser``; defaults to
            ``default_plugin_validate``.
        **kwargs: Additional keyword arguments forwarded to ``ToolParser``.
    """

    def __init__(self,
                 tool_type: str = 'plugin',
                 template: str = '',
                 begin: str = '<|action_start|><|plugin|>\n',
                 end: str = '<|action_end|>\n',
                 validate: Callable[[str], Any] = default_plugin_validate,
                 **kwargs):
        super().__init__(
            tool_type=tool_type,
            template=template,
            begin=begin,
            end=end,
            validate=validate,
            **kwargs)
class MixedToolParser(StrParser):
    """Parser that delegates to several tool parsers keyed by tool type.

    Args:
        tool_type: Optional name attached to this parser's own instruction
            message.
        template: Template string forwarded to ``StrParser``.
        parsers: Tool parsers (or specs accepted by ``create_object``),
            stored under each parser's ``tool_type``.
        **format_field: Keyword arguments forwarded to ``StrParser``.
    """

    def __init__(self,
                 tool_type: Optional[str] = None,
                 template='',
                 parsers: List[ToolParser] = None,
                 **format_field):
        # Build lazily so each parser is created and keyed one at a time;
        # a later parser with the same tool_type replaces an earlier one.
        built = (create_object(spec) for spec in parsers or [])
        self.parsers = {parser.tool_type: parser for parser in built}
        self.tool_type = tool_type
        super().__init__(template, **format_field)

    def format_instruction(self) -> List[dict]:
        """Collect system messages from this parser and each sub-parser.

        Returns:
            A list of message dicts with ``role`` and ``content`` keys (and
            ``name`` where set); instructions that are blank after
            stripping are skipped.
        """
        messages = []
        own_text = super().format_instruction()
        if own_text.strip():
            header = {'role': 'system', 'content': own_text}
            if self.tool_type:
                header['name'] = self.tool_type
            messages.append(header)
        for name, parser in self.parsers.items():
            text = parser.format_instruction()
            if text.strip():
                messages.append(
                    {'role': 'system', 'content': text, 'name': name})
        return messages

    def parse_response(self, data: str) -> dict:
        """Parse a response with the first sub-parser that recognises it.

        Args:
            data: Raw response text.

        Returns:
            The result of the first sub-parser whose reported ``tool_type``
            equals its registered name. If none matches, the last
            sub-parser's result is returned, or a ``NO_TOOL`` result when
            no sub-parsers are registered.
        """
        outcome = {
            'tool_type': None,
            'thought': data,
            'action': None,
            'status': ToolStatusCode.NO_TOOL,
        }
        for name, parser in self.parsers.items():
            outcome = parser.parse_response(data)
            if outcome['tool_type'] == name:
                return outcome
        return outcome

    def format_response(self, parsed: dict) -> str:
        """Render a parsed result using the matching sub-parser.

        Args:
            parsed: Dict with ``thought``, ``action`` and ``tool_type`` keys.

        Returns:
            The thought alone when ``action`` is ``None``; otherwise the
            text produced by the sub-parser registered for ``tool_type``.
        """
        if parsed['action'] is None:
            return parsed['thought']
        tool_type = parsed['tool_type']
        assert tool_type in self.parsers
        return self.parsers[tool_type].format_response(parsed)