|
|
|
|
|
import re |
|
|
from typing import TYPE_CHECKING, List, Tuple, Union |
|
|
|
|
|
import json |
|
|
|
|
|
from .base import BaseAgentTemplate |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from swift.llm.infer import Function |
|
|
from swift.llm.template import Prompt |
|
|
|
|
|
|
|
|
class HermesAgentTemplate(BaseAgentTemplate): |
|
|
|
|
|
def get_toolcall(self, response: str) -> List['Function']: |
|
|
from swift.llm.infer import Function |
|
|
res_list = re.findall(r'<tool_call>(.+?)</tool_call>', response, re.DOTALL) |
|
|
functions = [] |
|
|
for res in res_list: |
|
|
res = self._parse_json(res) |
|
|
if isinstance(res, dict) and 'name' in res and 'arguments' in res: |
|
|
functions.append(Function(name=res['name'], arguments=res['arguments'])) |
|
|
if len(functions) == 0: |
|
|
|
|
|
return super().get_toolcall(response) |
|
|
return functions |
|
|
|
|
|
def _format_tool_responses( |
|
|
self, |
|
|
assistant_content: str, |
|
|
tool_messages, |
|
|
) -> Tuple[str, 'Prompt']: |
|
|
with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content |
|
|
if with_action: |
|
|
return super()._format_tool_responses(assistant_content, tool_messages) |
|
|
if hasattr(self, 'template_meta'): |
|
|
prompt = self.template_meta.prompt |
|
|
chat_sep = self.template_meta.chat_sep |
|
|
else: |
|
|
prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'] |
|
|
chat_sep = ['<|im_end|>\n'] |
|
|
res = chat_sep.copy() |
|
|
res_tool = [] |
|
|
for tool_message in tool_messages: |
|
|
tool_content = tool_message['content'] |
|
|
res_tool.append(f'<tool_response>\n{tool_content}\n</tool_response>') |
|
|
total_tool = '\n'.join(res_tool) |
|
|
for context in prompt: |
|
|
if isinstance(context, str): |
|
|
context = context.replace('{{QUERY}}', total_tool) |
|
|
res.append(context) |
|
|
return assistant_content, res |
|
|
|
|
|
def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: |
|
|
tool_descs = [json.dumps(self.wrap_tool(tool), ensure_ascii=False) for tool in tools] |
|
|
return f"""{system} |
|
|
|
|
|
# Tools |
|
|
|
|
|
You may call one or more functions to assist with the user query. |
|
|
|
|
|
You are provided with function signatures within <tools></tools> XML tags: |
|
|
<tools> |
|
|
""" + '\n'.join(tool_descs) + """ |
|
|
</tools> |
|
|
|
|
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags: |
|
|
<tool_call> |
|
|
{"name": <function-name>, "arguments": <args-json-object>} |
|
|
</tool_call>""" |
|
|
|
|
|
def _format_tool_calls(self, tool_call_messages): |
|
|
tool_calls = [] |
|
|
for message in tool_call_messages: |
|
|
tool_call = self._parse_tool_call(message['content']) |
|
|
tool_calls.append(f'<tool_call>\n{json.dumps(tool_call, ensure_ascii=False)}\n</tool_call>') |
|
|
return '\n'.join(tool_calls) |
|
|
|