import re
import time
from typing import Union, List, Dict

from werkzeug.datastructures import FileStorage

from .. import BaseAgent
from ...exceptions.exceptions import InternalErrorException, LLMException, SandboxException
from ...schemas import (
    AgentType, AgentRequest, AgentFinish, AgentAction, AgentResponse,
    BaseAgentResponse, AgentObservation, RunCodeOutput, MediaFile
)
from ...tools import PythonSandBoxToolResponse, AsyncPythonSandBoxTool
from ...utils import get_logger, replace_latex_format, extract_and_replace_url, \
    OBSERVATION_PREFIX_CN, OBSERVATION_PREFIX_EN, AGENT_FAILED_CN, AGENT_FAILED_EN, \
    TOOL_INPUT_PREFIX_CN, TOOL_INPUT_PREFIX_EN


SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
STOP_WORD = ['Observation:']

logger = get_logger()


class AsyncReactAgent(BaseAgent):
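    """Asynchronous ReAct-style agent that alternates LLM reasoning steps with Python sandbox tool calls."""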

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._name = self._name or "AsyncReactAgent"
        self._type = AgentType.react
        self.__intermediate_steps: List[BaseAgentResponse] = []

    @property
    def intermediate_steps(self):
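        """Intermediate agent actions and observations accumulated so far."""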
        return self.__intermediate_steps

    def run(self, *args, **kwargs):
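        """Synchronous execution is intentionally a no-op; use async_run instead."""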
        pass

    async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]):
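        """Sync the given file into the Python sandbox so generated code can access it."""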
        sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME)
        if not isinstance(sandbox_plugin, AsyncPythonSandBoxTool):
            raise InternalErrorException("Sandbox client is not ready for the agent, please check the init logic.")
        return await sandbox_plugin.sync_to_sandbox(file)

    async def async_run(self, agent_req: AgentRequest):
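        """Join the request messages into a single instruction and stream ReAct responses back to the caller."""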
        instruction = '\n'.join(message.content for message in agent_req.messages)
        async for response in self._chat(instruction, is_cn=agent_req.is_cn):
            yield response

    async def _chat(self, instruction: str, is_cn=False, max_iterations=10,
                    max_single_step_iterations=3):
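        """Run the thought -> action -> observation loop, yielding an AgentResponse after each LLM and tool step."""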
        current_iteration = 0

        for _ in range(max_iterations):
            current_iteration += 1
            llm_response = await self._single_round_thought(
                instruction, max_llm_iteration=max_single_step_iterations, is_cn=is_cn)
            logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, llm_response.raw_output,
                                llm_response.formatted_output))
            yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output)

            if isinstance(llm_response, AgentFinish):
                logger.info("Found final answer, stopping iteration.")
                break

            self.intermediate_steps.append(llm_response)
            action_response, cur_output_files = await self._process_agent_action(
                llm_response, current_iteration, max_iterations, is_cn)
            logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, action_response.raw_output,
                                action_response.formatted_output))
            self.intermediate_steps.append(action_response)

            yield self.create_agent_response(action_response.formatted_output,
                                             cur_output_files,
                                             action_response.raw_output)

        logger.info(f"Finished after {current_iteration} iteration(s).")

    async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False):
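        """Execute the parsed action in the Python sandbox and wrap the result as an AgentObservation plus output files."""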
        try:
            response.tool = SAND_BOX_PLUGIN_NAME
            action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input)
            logger.info(
                f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n"
                f"{action_response.output_text}")

            if "STDERR" in action_response.output_text:
                formatted_output = self._process_sandbox_output(action_response.output_text)
            else:
                formatted_output = action_response.output_text

            formatted_output = replace_latex_format(formatted_output)
            observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN
            formatted_output = f"{observation_prefix}\n{formatted_output}\n"

            action_observation = AgentObservation(tool=response.tool,
                                                  formatted_output=formatted_output,
                                                  raw_output=action_response.output_text)
            cur_output_files = self._get_output_files(action_response)
            return action_observation, cur_output_files

        except Exception as e:
            logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. "
                         f"Error: {str(e)}", exc_info=True)
            raise SandboxException("Error occurred while running the tool") from e

    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.
        """
        if self.prompt_template is None:
            raise InternalErrorException("Agent prompt is None, please check the init process.")

        agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps)
        tool_description = self._get_plugin_description()
        tool_names = ", ".join(self.plugins_map.keys())

        return self.prompt_template.format(
            instruction=instruction,
            agent_scratchpad=agent_scratchpad,
            tool_description=tool_description,
            tool_names=tool_names
        )

    async def _single_round_thought(self, instruction: str, max_llm_iteration=3, is_cn: bool = False) -> \
            Union[AgentAction, AgentFinish]:
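        """Ask the LLM for the next step and parse it, retrying on failure before falling back to a failed AgentFinish."""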
        llm_iteration_count = 0
        llm_response = None
        while llm_iteration_count <= max_llm_iteration:
            llm_iteration_count += 1
            try:
                llm_response = await self._get_llm_response(instruction)
                action_response = self._parse_output(llm_response.content, is_cn)
                return action_response
            except Exception as e:
                logger.error("LLM iteration {} of {} failed. Error: {}"
                             .format(llm_iteration_count, max_llm_iteration, str(e)), exc_info=True)

        logger.error("LLM iteration {} exceeded max retries {}. Aborting."
                     .format(llm_iteration_count, max_llm_iteration))
        return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN,
                           raw_output=str(llm_response))

    async def _get_llm_response(self, instruction: str):
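        """Compose the prompt, send it to the LLM asynchronously, and return the raw response."""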
        prompt = self._compose_prompt(instruction)
        logger.info("Sending prompt to LLM:\n{}".format(prompt))
        response = await self.llm.async_completion(prompt)
        if response.state == "error":
            raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content)))

        logger.info("Got response from LLM, raw response content:\n{}".format(response.content))
        return response

    def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]:
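        """Parse raw LLM output into an AgentFinish (final answer) or an AgentAction (sandbox code to run)."""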
        # Drop anything the model produced after the first stop word.
        for stop_word in STOP_WORD:
            if stop_word in llm_output:
                llm_output = llm_output.split(stop_word)[0].rstrip()
                break

        # Final answer: remove the indicator itself but keep the surrounding text.
        for indicator in FINAL_ANSWER_INDICATORS:
            if indicator in llm_output:
                parts = llm_output.split(indicator)
                formatted_output = ''.join(parts).strip()
                formatted_output = replace_latex_format(formatted_output)
                return AgentFinish(raw_output=llm_output, formatted_output=formatted_output)

        # Match "Action: ... Action Input:" followed by a fenced python/py code block
        # (or a triple-quoted block as a fallback).
        ACTION_REGEX_1 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"
        ACTION_REGEX_2 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"

        action_match = re.search(ACTION_REGEX_1, llm_output, re.DOTALL) or \
            re.search(ACTION_REGEX_2, llm_output, re.DOTALL)

        if action_match:
            context = action_match.group(1).strip()
            action_tool_description = action_match.group(2).strip()
            action_input = action_match.group(3).strip()

            format_code_block = self._format_code_block(action_input)

            prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN
            formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block)
            formatted_output = replace_latex_format(formatted_output)

            return AgentAction(tool=action_tool_description,
                               tool_input=format_code_block,
                               formatted_output=formatted_output,
                               raw_output=llm_output)

        if not re.search(r"Action\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`")
        elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`")
        else:
            raise LLMException(f"Unrecognized LLM output format: `{llm_output}`")

    def _format_code_block(self, tool_input):
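        """Normalize model-provided code so it is always wrapped in a python-tagged fenced code block."""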
        stripped_tool_input = tool_input.strip()

        if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Already a python-tagged fence; make sure the opening tag sits on its own line.
            if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'):
                stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \
                    '\n'
            formatted_code = stripped_tool_input
        elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith(
                CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Bare fence without a language tag: promote it to a python-tagged fence.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n'
        else:
            # Raw code with no fence at all: wrap it.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n'

        return formatted_code.encode("utf-8").decode("utf-8")

    def _process_sandbox_output(self, output: str):
        """Trim long sandbox output that contains STDERR, keeping the head and tail around an ellipsis."""
        if len(output) <= 1000:
            return output

        logger.info("Output contains an error and is over 1000 characters, trimming it for the response. "
                    "Original output:\n{}".format(output))
        rows = output.split("\n")

        # Keep roughly the first 500 characters of lines from the top...
        top_segment = []
        length = 0
        for sub_p in rows:
            if length + len(sub_p) > 500:
                break
            top_segment.append(sub_p)
            length += len(sub_p)

        # ...and roughly the last 500 characters of lines from the bottom.
        bottom_segment = []
        length = 0
        for sub_p in reversed(rows):
            if length + len(sub_p) > 500:
                break
            bottom_segment.insert(0, sub_p)
            length += len(sub_p)

        trimmed_output = "\n".join(top_segment + ["......"] + bottom_segment)
        return trimmed_output

    def _get_output_files(self, tool_response) -> List[MediaFile]:
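        """Collect newly generated files and image outputs from a successful sandbox run as MediaFile objects."""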
        output_files = []

        if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput):
            raw_output = tool_response.raw_output

            # Only collect files from complete, successful runs.
            if raw_output.code == 0 and not raw_output.data.is_partial:
                result_data = raw_output.data.result

                if len(result_data.new_generated_files) > 0:
                    output_files.extend([MediaFile(tos_path=file.download_link) for file in
                                         result_data.new_generated_files])

                if len(result_data.code_output_result) > 0:
                    output_files.extend(
                        [MediaFile(tos_path=image.content) for image in result_data.code_output_result
                         if image.type == 'image'])

        return output_files

    def _replace_csv_path(self, input_string):
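        """Replace concrete CSV paths in pd.read_csv calls with a generic placeholder path."""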
        pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)'
        replacement = "pd.read_csv('/path/to/your/dataset')"
        updated_string = re.sub(pattern, replacement, input_string)
        return updated_string

    @staticmethod
    def create_agent_response(formatted_output, output_files, raw_output):
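        """Package formatted text, output files, and the raw text into an AgentResponse."""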
        return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output)