InfiAgent / src /infiagent /agent /react /async_react_agent.py
g3eIL's picture
Upload 80 files
77320e4 verified
import re
import time
from typing import Union, List, Dict
from werkzeug.datastructures import FileStorage
from .. import BaseAgent
from ...exceptions.exceptions import InternalErrorException, LLMException, SandboxException
from ...schemas import (
AgentType, AgentRequest, AgentFinish, AgentAction, AgentResponse,
BaseAgentResponse, AgentObservation, RunCodeOutput, MediaFile
)
from ...tools import PythonSandBoxToolResponse, AsyncPythonSandBoxTool
from ...utils import get_logger, replace_latex_format, extract_and_replace_url, \
OBSERVATION_PREFIX_CN, OBSERVATION_PREFIX_EN, AGENT_FAILED_CN, AGENT_FAILED_EN, \
TOOL_INPUT_PREFIX_CN, TOOL_INPUT_PREFIX_EN
# Key under which the Python sandbox tool is looked up in the agent's ``plugins_map``.
SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
# Substrings that mark an LLM reply as the final answer (plain, case-sensitive ``in`` checks).
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
# Markdown fences used when normalising code sent to the sandbox.
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
# LLM output is cut at the first of these markers before parsing
# (presumably so the model cannot fabricate its own tool observation).
STOP_WORD = ['Observation:']
logger = get_logger()
class AsyncReactAgent(BaseAgent):
    """ReAct-style agent that alternates LLM "thought" rounds with Python sandbox runs.

    Each round asks the LLM for the next step and parses the reply into either an
    ``AgentFinish`` (the loop stops) or an ``AgentAction`` whose code is executed by
    the ``python_code_sandbox`` plugin. Both the LLM step and the tool observation
    are yielded to the caller as ``AgentResponse`` objects. They are also recorded
    as intermediate steps, which feed the prompt scratchpad.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._name = self._name or "AsyncReactAgent"
        self._type = AgentType.react
        # Ordered LLM actions and tool observations; used to build the prompt scratchpad.
        # NOTE(review): nothing in this class clears the list, so reusing one instance
        # for several requests carries earlier steps into later prompts -- confirm intent.
        self.__intermediate_steps: List[BaseAgentResponse] = []

    @property
    def intermediate_steps(self):
        """The steps recorded so far (the live list, not a copy)."""
        return self.__intermediate_steps

    def run(self, *args, **kwargs):
        """Synchronous entry point; not implemented here (returns None). Use ``async_run``."""
        pass

    async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]):
        """Forward ``file`` to the sandbox plugin's ``sync_to_sandbox`` and return its result.

        Raises:
            InternalErrorException: if the plugin registered under ``SAND_BOX_PLUGIN_NAME``
                is missing or is not an ``AsyncPythonSandBoxTool``.
        """
        sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME)
        if not isinstance(sandbox_plugin, AsyncPythonSandBoxTool):
            raise InternalErrorException("SandBox client is not ready for agent, please check init logic.")
        return await sandbox_plugin.sync_to_sandbox(file)

    async def async_run(self, agent_req: AgentRequest):
        """Run the ReAct loop for ``agent_req``, asynchronously yielding each ``AgentResponse``.

        The contents of all request messages are joined with newlines into one instruction.
        """
        instruction = '\n'.join(message.content for message in agent_req.messages)
        async for response in self._chat(instruction, is_cn=agent_req.is_cn):
            yield response

    async def _chat(self, instruction: str, is_cn=False, max_iterations=10,
                    max_single_step_iterations=3):
        """Drive up to ``max_iterations`` thought/action rounds, yielding responses as they occur.

        Every round yields the LLM step. An ``AgentFinish`` stops the loop. Otherwise
        the action is executed and its observation is yielded too.

        Args:
            instruction: User instruction inserted into every prompt.
            is_cn: Use the ``*_CN`` prefixes/messages instead of the ``*_EN`` ones.
            max_iterations: Maximum number of thought/action rounds.
            max_single_step_iterations: Retry budget forwarded to ``_single_round_thought``.

        Raises:
            SandboxException: propagated from ``_process_agent_action``.
        """
        # Initialised so the final log line works even when max_iterations <= 0.
        current_iteration = 0
        for current_iteration in range(1, max_iterations + 1):
            llm_response = await self._single_round_thought(instruction,
                                                            max_llm_iteration=max_single_step_iterations,
                                                            is_cn=is_cn)
            logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, llm_response.raw_output,
                                llm_response.formatted_output))
            yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output)

            if isinstance(llm_response, AgentFinish):
                logger.info("Find final answer, stop iteration.")
                break

            self.intermediate_steps.append(llm_response)
            action_response, cur_output_files = await self._process_agent_action(llm_response, current_iteration,
                                                                                 max_iterations, is_cn)
            logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, action_response.raw_output,
                                action_response.formatted_output))
            self.intermediate_steps.append(action_response)
            yield self.create_agent_response(action_response.formatted_output,
                                             cur_output_files,
                                             action_response.raw_output)
        logger.info(f"Finished iteration in {current_iteration}.")

    # TODO update logic to not be sandbox specific, sandbox related logic should be handled in sandbox client
    async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False):
        """Run ``response``'s code in the sandbox and build the resulting observation.

        ``response.tool`` is overwritten with ``SAND_BOX_PLUGIN_NAME`` because the sandbox
        is the only supported tool. Output containing "STDERR" is trimmed by
        ``_process_sandbox_output``. The text is then passed through
        ``replace_latex_format`` and prefixed with the observation prefix.

        Returns:
            ``(AgentObservation, list of MediaFile)`` for this step.

        Raises:
            SandboxException: wrapping any error raised while executing the tool.
        """
        try:
            # TODO: We hard code here as we only have one tool
            response.tool = SAND_BOX_PLUGIN_NAME
            action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input)
            logger.info(
                f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n"
                f"{action_response.output_text}")

            if "STDERR" in action_response.output_text:
                formatted_output = self._process_sandbox_output(action_response.output_text)
            else:
                formatted_output = action_response.output_text
            formatted_output = replace_latex_format(formatted_output)
            observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN
            formatted_output = f"{observation_prefix}\n{formatted_output}\n"

            action_observation = AgentObservation(tool=response.tool,
                                                  formatted_output=formatted_output,
                                                  raw_output=action_response.output_text)
            cur_output_files = self._get_output_files(action_response)
            return action_observation, cur_output_files
        except Exception as e:
            logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. "
                         f"Error: {str(e)}", exc_info=True)
            raise SandboxException("Error occurred while running the tool") from e

    def _compose_prompt(self, instruction) -> str:
        """Compose the prompt from the template, tool descriptions, scratchpad and instruction.

        Raises:
            InternalErrorException: if no prompt template is configured.
        """
        # Validate before the template is first used, so a missing template raises the
        # intended InternalErrorException rather than an AttributeError.
        if self.prompt_template is None:
            raise InternalErrorException("Agent prompt is none, please check init process")
        agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps)
        tool_description = self._get_plugin_description()
        tool_names = ", ".join(self.plugins_map.keys())
        return self.prompt_template.format(
            instruction=instruction,
            agent_scratchpad=agent_scratchpad,
            tool_description=tool_description,
            tool_names=tool_names
        )

    async def _single_round_thought(self, instruction: str, max_llm_iteration=3, is_cn: bool = False) -> \
            Union[AgentAction, AgentFinish]:
        """Ask the LLM for the next step, retrying on LLM or parse errors.

        One initial attempt is made plus ``max_llm_iteration`` retries. If every
        attempt fails, an ``AgentFinish`` carrying the failure message is returned
        instead of raising. Its ``raw_output`` is the last LLM response received,
        or "None" if none was.
        """
        total_attempts = max_llm_iteration + 1
        llm_response = None
        for attempt in range(1, total_attempts + 1):
            try:
                llm_response = await self._get_llm_response(instruction)
                return self._parse_output(llm_response.content, is_cn)
            except Exception as e:
                logger.error("LLM iteration {} out of {} failed. Error: {}".
                             format(attempt, total_attempts, str(e)), exc_info=True)

        logger.error("LLM iteration {} exceed max retry {}. Aborting".
                     format(total_attempts, max_llm_iteration))
        return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN,
                           raw_output=str(llm_response))

    async def _get_llm_response(self, instruction: str):
        """Send the composed prompt to the LLM and return its response.

        Raises:
            LLMException: if the LLM response has ``state == "error"``.
        """
        prompt = self._compose_prompt(instruction)
        logger.info("Send prompt to LLM:\n{}".format(prompt))
        response = await self.llm.async_completion(prompt)
        if response.state == "error":
            raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content)))

        logger.info("Got response from llm, raw response content: \n{}".format(response.content))
        return response

    def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]:
        """Parse raw LLM text into an ``AgentFinish`` or an ``AgentAction``.

        Text from the first stop word onwards is discarded. If any final-answer
        indicator is present, it is removed and an ``AgentFinish`` is returned.
        Otherwise an action block is extracted and its code is normalised into a
        ```python fence.

        Raises:
            LLMException: if neither a final answer nor a recognisable action is found.
        """
        for stop_word in STOP_WORD:
            if stop_word in llm_output:
                llm_output = llm_output.split(stop_word)[0].rstrip()
                break

        for indicator in FINAL_ANSWER_INDICATORS:
            if indicator in llm_output:
                # Drop every occurrence of the indicator and keep the surrounding text.
                formatted_output = llm_output.replace(indicator, '').strip()
                formatted_output = replace_latex_format(formatted_output)
                return AgentFinish(raw_output=llm_output, formatted_output=formatted_output)

        action = self._match_action(llm_output)
        if action is not None:
            context, action_tool_description, action_input = action
            # TODO: currently we only have one plugin which is sandbox, update to support multiple tools
            format_code_block = self._format_code_block(action_input)
            prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN
            formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block)
            formatted_output = replace_latex_format(formatted_output)
            return AgentAction(tool=action_tool_description,
                               tool_input=format_code_block,
                               formatted_output=formatted_output,
                               raw_output=llm_output)

        # Not a final answer and not an action: report the most specific problem.
        if not re.search(r"Action\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`")
        elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`")
        else:
            raise LLMException(f"Unrecognized LLM output format: `{llm_output}`")

    @staticmethod
    def _match_action(llm_output: str):
        """Return stripped ``(context, tool, code)`` for the first recognised action block, or None.

        Each pattern has two alternatives:
          * groups 1-4: ``<context> Action: <tool> Action Input: ```python|py\\n<code>```<rest>``
          * groups 5-8: ``<context> '''<tag>\\n<code>'''<rest>``
        For a match of the second alternative, groups 1-3 are None. An explicit
        ``Action:`` block from either pattern wins over the '''-fenced fallback. A
        fallback whose code is blank is ignored, so that a stray triple-quoted string
        is not mistaken for an empty action.
        """
        patterns = (
            r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$"
            r"|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$",
            r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$"
            r"|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$",
        )
        matches = [m for m in (re.search(p, llm_output, re.DOTALL) for p in patterns) if m is not None]
        for match in matches:
            if match.group(3) is not None:
                return match.group(1).strip(), match.group(2).strip(), match.group(3).strip()
        for match in matches:
            code = (match.group(7) or "").strip()
            if code:
                return (match.group(5) or "").strip(), (match.group(6) or "").strip(), code
        return None

    def _format_code_block(self, tool_input):
        """Normalise ``tool_input`` into a ```python fenced code block.

        - Input already fenced with ```python and closed with ``` is kept; a newline is
          inserted after the opening tag if missing.
        - Input fenced with a bare ``` gets the opening replaced by ```python plus newline.
        - Anything else is wrapped in a new ```python ... ``` fence.
        """
        stripped_tool_input = tool_input.strip()
        if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'):
                stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \
                                      '\n'
            formatted_code = stripped_tool_input
        elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith(
                CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n'
        else:
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n'
        # UTF-8 round trip: returns the same text, but raises on unpaired surrogates.
        return formatted_code.encode("utf-8").decode("utf-8")

    def _process_sandbox_output(self, output: str):
        """Trim STDERR-bearing sandbox output that is longer than 1000 characters.

        Whole lines are kept from the start (up to 500 characters, newlines not
        counted) and from the end (same budget), joined by a "......" line. If the
        two segments together already cover every line, nothing would be omitted.
        In that case the output is returned unchanged instead of repeating
        overlapping lines.
        """
        if len(output) <= 1000:
            return output
        logger.info("Output contains error, original message is over 1000, trim it for response. ori output: \n{}".
                    format(output))
        rows = output.split("\n")

        # Number of leading rows that fit in the head budget.
        head_end = 0
        budget_used = 0
        for row in rows:
            if budget_used + len(row) > 500:
                break
            budget_used += len(row)
            head_end += 1

        # Index of the first tail row; it never moves into rows already kept as head.
        tail_start = len(rows)
        budget_used = 0
        while tail_start > head_end and budget_used + len(rows[tail_start - 1]) <= 500:
            tail_start -= 1
            budget_used += len(rows[tail_start])

        if tail_start <= head_end:
            return output

        trimmed_output = "\n".join(rows[:head_end] + ["......"] + rows[tail_start:])
        return trimmed_output

    def _get_output_files(self, tool_response) -> List[MediaFile]:
        """Collect generated files and image outputs from a complete sandbox run with exit code 0.

        Any other response type, a non-zero ``code`` or a partial result yields an empty list.
        """
        output_files = []
        if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput):
            raw_output = tool_response.raw_output
            if raw_output.code == 0 and not raw_output.data.is_partial:
                result_data = raw_output.data.result
                # TODO confirm if we still need output and format
                output_files.extend(MediaFile(tos_path=generated.download_link)
                                    for generated in result_data.new_generated_files)
                output_files.extend(MediaFile(tos_path=item.content)
                                    for item in result_data.code_output_result
                                    if item.type == 'image')
        return output_files

    def _replace_csv_path(self, input_string):
        """Replace the quoted path in every ``pd.read_csv('<...>.csv')`` call with a placeholder."""
        pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)'
        replacement = "pd.read_csv('/path/to/your/dataset')"
        updated_string = re.sub(pattern, replacement, input_string)
        return updated_string

    @staticmethod
    def create_agent_response(formatted_output, output_files, raw_output):
        """Wrap the given texts and files in an ``AgentResponse``."""
        return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output)