|
|
|
|
|
|
|
import re |
|
import time |
|
import traceback |
|
import json |
|
from typing import Dict, Any, Optional, List, Union, Tuple |
|
from dataclasses import dataclass, field |
|
|
|
from langchain_core.messages import HumanMessage, BaseMessage, AIMessage, ToolMessage |
|
from pydantic import ValidationError |
|
|
|
from aworld.core.agent.base import AgentFactory, AgentResult |
|
from aworld.agents.llm_agent import Agent |
|
from examples.browsers.prompts import SystemPrompt |
|
from examples.browsers.utils import convert_input_messages, extract_json_from_model_output, estimate_messages_tokens |
|
from examples.browsers.common import AgentState, AgentStepInfo, AgentHistory, PolicyMetadata, AgentBrain |
|
from aworld.config.conf import AgentConfig, ConfigDict |
|
from aworld.core.common import Observation, ActionModel, ToolActionInfo, ActionResult |
|
from aworld.logs.util import logger |
|
from examples.browsers.prompts import AgentMessagePrompt |
|
from examples.tools.tool_action import BrowserAction |
|
|
|
|
|
@dataclass
class Trajectory:
    """Ordered record of agent steps.

    Each step is the tuple ``(input_messages, observation, info,
    output_message, agent_result)``. ``output_message`` may be ``None`` for a
    step that was recorded after the LLM call itself failed.
    """

    history: List[tuple[List[BaseMessage], Observation, Dict[str, Any], Optional[AIMessage], AgentResult]] = field(
        default_factory=list)

    def add_step(self, input_messages: List[BaseMessage], observation: Observation, info: Dict[str, Any],
                 output_message: Optional[AIMessage], agent_result: AgentResult):
        """Append one step to the history.

        Args:
            input_messages: Messages that were sent to the LLM for this step.
            observation: Observation the step was based on.
            info: Extra info passed along with the observation.
            output_message: Raw LLM reply, or ``None`` if no reply was obtained.
            agent_result: Parsed result (or error placeholder) for this step.
        """
        self.history.append((input_messages, observation, info, output_message, agent_result))

    def get_history(self) -> List[
            tuple[List[BaseMessage], Observation, Dict[str, Any], Optional[AIMessage], AgentResult]]:
        """Return the complete list of recorded steps (the live list, not a copy)."""
        return self.history

    def save_history(self, file_path: str):
        """Dump the LLM inputs and outputs of every step to ``file_path`` as JSON.

        Each entry has the keys ``llm_input`` (a list of ``{"type", "content"}``
        dicts) and ``llm_output`` (the reply content, or ``null`` when the step
        has no output message).

        Args:
            file_path: Destination file; it is overwritten if it exists.
        """
        records = []
        for input_messages, _observation, _info, output_message, _agent_result in self.history:
            llm_input = [{"type": message.type, "content": message.content} for message in input_messages]
            # Failed steps are stored without an output message; keep them in
            # the dump instead of crashing on ``None.content``.
            llm_output = output_message.content if output_message is not None else None
            records.append({"llm_input": llm_input, "llm_output": llm_output})
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, ensure_ascii=False, indent=4)
|
|
|
|
|
@AgentFactory.register(name='browser_agent', desc="browser agent")
class BrowserAgent(Agent):
    """Agent that asks an LLM for the next browser actions.

    Every call to :meth:`policy` rebuilds the prompt from the stored
    trajectory plus the current observation, invokes the LLM, parses the JSON
    reply into ``ActionModel`` objects and records the step in the trajectory
    and in ``self.state.history``.
    """

    def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
        """Initialise agent state, provider name, action prompt and trajectory.

        Args:
            conf: Agent configuration, forwarded to the base ``Agent``.
            **kwargs: Forwarded to the base ``Agent``.
        """
        super(BrowserAgent, self).__init__(conf, **kwargs)
        self.state = AgentState()
        self.settings = self.conf
        # Prefix the configured provider with "chat" and write it back to the
        # config field it was read from.
        provider = self.conf.llm_config.llm_provider if self.conf.llm_config.llm_provider else self.conf.llm_provider
        if self.conf.llm_config.llm_provider:
            self.conf.llm_config.llm_provider = "chat" + provider
        else:
            self.conf.llm_provider = "chat" + provider

        self.save_file_path = self.conf.save_file_path
        self.available_actions = self._build_action_prompt()

        self.trajectory = Trajectory()
        # Flipped to True by reset(); policy() triggers the first reset lazily.
        self._init = False

    def reset(self, options: Dict[str, Any]):
        """Reset the base agent, start a fresh trajectory and mark the agent initialised.

        Args:
            options: Reset options forwarded to the base ``Agent.reset``.
        """
        super(BrowserAgent, self).reset(options)
        self.trajectory = Trajectory()
        self._init = True

    def _build_action_prompt(self) -> str:
        """Return one description line pair per ``BrowserAction`` member, newline-joined."""

        def _prompt(info: ToolActionInfo) -> str:
            # "<desc>: \n{<name>: <params dict>}" -- the params dict is omitted
            # when the action declares no input parameters.
            s = f'{info.desc}: \n'
            s += '{' + str(info.name) + ': '
            if info.input_params:
                s += str({k: {"title": k, "type": v.type} for k, v in info.input_params.items()})
            s += '}'
            return s

        return "\n".join(_prompt(member.value) for member in BrowserAction.__members__.values())

    def _log_message_sequence(self, input_messages: List[BaseMessage]) -> None:
        """Log the sequence of messages for debugging purposes.

        Base64 image payloads are not logged, and long text content is split
        into 500-character chunks.
        """
        logger.info(f"[agent] 🔍 Invoking LLM with {len(input_messages)} messages")
        logger.info("[agent] 📝 Messages sequence:")
        for i, msg in enumerate(input_messages):
            prefix = msg.type
            logger.info(f"[agent] Message {i + 1}: {prefix} ===================================")
            if isinstance(msg.content, list):
                for item in msg.content:
                    if not isinstance(item, dict):
                        # Content lists may also hold plain strings, which
                        # have no ``.get``.
                        logger.info(f"[agent] Text content: {item}")
                        continue
                    if item.get('type') == 'text':
                        logger.info(f"[agent] Text content: {item.get('text')}")
                    elif item.get('type') == 'image_url':
                        image_url = item.get('image_url', {}).get('url', '')
                        if image_url.startswith('data:image'):
                            logger.info("[agent] Image: [Base64 image data]")
                        else:
                            logger.info(f"[agent] Image URL: {image_url[:30]}...")
            else:
                content = str(msg.content)
                chunk_size = 500
                for j in range(0, len(content), chunk_size):
                    chunk = content[j:j + chunk_size]
                    if j == 0:
                        logger.info(f"[agent] Content: {chunk}")
                    else:
                        logger.info(f"[agent] Content (continued): {chunk}")

            if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None):
                for tool_call in msg.tool_calls:
                    logger.info(f"[agent] Tool call: {tool_call.get('name')} - ID: {tool_call.get('id')}")
                    args = str(tool_call.get('args', {}))[:1000]
                    logger.info(f"[agent] Tool args: {args}...")

    def save_process(self, file_path: str):
        """Write the recorded trajectory to ``file_path`` (see ``Trajectory.save_history``)."""
        self.trajectory.save_history(file_path)

    def _stop_actions(self) -> List[ActionModel]:
        """Return the single-element action list that tells the browser tool to stop."""
        return [ActionModel(agent_name=self.id(), tool_name='browser', action_name="stop")]

    def policy(self,
               observation: Observation,
               info: Optional[Dict[str, Any]] = None, **kwargs) -> Union[List[ActionModel], None]:
        """Decide the next browser actions for ``observation``.

        Args:
            observation: Latest observation; its ``content`` is used as the task
                on the first call, and its ``action_result`` is replayed to the LLM.
            info: Extra info stored alongside the step in the trajectory.
            **kwargs: Unused.

        Returns:
            The actions chosen by the LLM, or a single "stop" action when the
            LLM produced no result or the agent state is stopped/paused.

        Raises:
            RuntimeError: If invoking the LLM or parsing its reply fails. The
                failed step is recorded in the trajectory before raising.
        """
        start_time = time.time()

        if self._init is False:
            self.reset({"task": observation.content})

        self._finished = False
        self.state.last_result = observation.action_result

        if self.conf.max_steps <= self.state.n_steps:
            logger.info('Last step finishing up')

        logger.info(f'[agent] step {self.state.n_steps}')

        # The whole prompt (history, current observation, last-step notice) is
        # assembled in one place.
        input_messages = self.build_messages_from_trajectory_and_observation(observation=observation)
        tokens = self._estimate_tokens_for_messages(input_messages)

        llm_result = None
        output_message = None
        try:
            self._log_message_sequence(input_messages)

            output_message, llm_result = self._do_policy(input_messages)

            if not llm_result:
                logger.error("[agent] ❌ Failed to parse LLM response")
                return self._stop_actions()

            self.state.n_steps += 1

            if self.state.stopped or self.state.paused:
                logger.info('Browser gent paused after getting state')
                return self._stop_actions()

            tool_action = llm_result.actions

            self.trajectory.add_step(input_messages, observation, info, output_message, llm_result)

        except Exception as e:
            logger.warning(traceback.format_exc())
            logger.error(f"[agent] ❌ Error parsing LLM response: {str(e)}")

            # Record the failure as a step with no actions so the next prompt
            # can show the model what went wrong.
            error_result = AgentResult(
                current_state=AgentBrain(
                    evaluation_previous_goal="Failed due to error",
                    memory=f"Error occurred: {str(e)}",
                    thought="Recover from error",
                    next_goal="Recover from error"
                ),
                actions=[]
            )
            self.trajectory.add_step(input_messages, observation, info, output_message, error_result)

            raise RuntimeError("Browser agent encountered exception while making the policy.", e) from e
        finally:
            # Runs on every exit path, including the early "stop" returns.
            if llm_result:
                metadata = PolicyMetadata(
                    number=self.state.n_steps,
                    start_time=start_time,
                    end_time=time.time(),
                    input_tokens=tokens,
                )
                self._make_history_item(llm_result, observation, observation.action_result, metadata)
            else:
                logger.warning("no result to record!")

        return tool_action

    def _strip_think_tags(self, output_message: AIMessage) -> None:
        """Remove ``<think>...</think>`` sections in place for the 'deepseek-reasoner' model."""
        if self.model_name == 'deepseek-reasoner':
            output_message.content = re.sub(r'<think>.*?</think>', '', output_message.content, flags=re.DOTALL)

    def _do_policy(self, input_messages: list[BaseMessage]) -> Tuple[AIMessage, AgentResult]:
        """Invoke the LLM and parse its JSON reply into an ``AgentResult``.

        If the reply cannot be parsed as JSON, the LLM is re-invoked with a
        format reminder appended. At most ``max_llm_json_retries`` (default 3,
        minimum 1) parse attempts are made.

        Args:
            input_messages: Prompt messages for the LLM.

        Returns:
            ``(output_message, agent_result)``. An empty first reply yields a
            single "stop" action; a reply without actions yields a single
            "done" action.

        Raises:
            RuntimeError: If the initial LLM invocation raises.
            ValueError: If the reply cannot be parsed after all attempts.
        """
        input_messages = self._convert_input_messages(input_messages)
        output_message = None
        try:
            output_message = self.llm.invoke(input_messages)
        except Exception as e:
            logger.error(f"[agent] Response content: {output_message}")
            raise RuntimeError('call llm fail, please check llm conf and network.') from e

        if not output_message or not output_message.content:
            logger.warning("[agent] LLM returned empty response")
            return output_message, AgentResult(
                current_state=AgentBrain(evaluation_previous_goal="", memory="", thought="", next_goal=""),
                actions=[ActionModel(agent_name=self.id(), tool_name='browser', action_name="stop")])

        self._strip_think_tags(output_message)

        try:
            # Always make at least one parse attempt, even if the setting is
            # 0 or None, so ``parsed_json`` is never left unbound.
            max_retries = max(1, self.settings.get('max_llm_json_retries', 3) or 0)
            retry_count = 0
            parsed_json = None
            json_parse_error = None

            while retry_count < max_retries:
                try:
                    # An empty retry reply counts as a failed parse attempt.
                    if not output_message or not output_message.content:
                        raise ValueError("LLM returned empty response")
                    parsed_json = extract_json_from_model_output(output_message.content)
                    json_parse_error = None
                    break
                except ValueError as e:
                    json_parse_error = e
                    retry_count += 1
                    logger.warning(f"[agent] Failed to parse JSON (attempt {retry_count}/{max_retries}): {str(e)}")

                    if retry_count < max_retries:
                        format_reminder = HumanMessage(
                            content="Your responses must be always JSON with the specified format. Make sure your response includes a 'current_state' object with 'evaluation_previous_goal', 'memory', and 'next_goal' fields, and an 'action' array with the actions to perform. Do not include any explanatory text, only return the raw JSON.")
                        retry_messages = list(input_messages)
                        retry_messages.append(format_reminder)

                        logger.info(
                            f"[agent] Retrying LLM invocation ({retry_count}/{max_retries}) with format reminder")
                        output_message = self.llm.invoke(retry_messages)

                        if not output_message or not output_message.content:
                            logger.warning(
                                f"[agent] LLM returned empty response on retry attempt {retry_count}/{max_retries}")
                            continue

                        self._strip_think_tags(output_message)

            if json_parse_error:
                logger.error(f"[agent] ❌ All {max_retries} attempts to parse JSON failed")
                raise json_parse_error

            logger.info((f"llm response: {parsed_json}"))
            try:
                agent_brain = AgentBrain(**parsed_json['current_state'])
            except Exception as e:
                # A missing or malformed 'current_state' is tolerated; the
                # actions are still usable without it.
                logger.warning(f"[agent] Could not build AgentBrain from LLM response: {str(e)}")
                agent_brain = None

            # Both the 'action' and 'actions' keys are accepted.
            actions = parsed_json.get('action') or parsed_json.get("actions")
            if not actions:
                logger.warning("agent not policy an action.")
                self._finished = True
                return output_message, AgentResult(current_state=agent_brain,
                                                   actions=[ActionModel(tool_name='browser',
                                                                        agent_name=self.id(),
                                                                        action_name="done")])

            result = []
            for action in actions:
                if "action_name" in action:
                    # Explicit form: {"action_name": ..., "params": {...}}
                    action_name = action['action_name']
                    if not BrowserAction.get_value_by_name(action_name):
                        logger.warning(f"Unsupported action: {action_name}")
                    if action_name == "done":
                        self._finished = True
                    result.append(ActionModel(agent_name=self.id(),
                                              tool_name='browser',
                                              action_name=action_name,
                                              params=action.get('params', {})))
                else:
                    # Compact form: {"<action_name>": {...params...}, ...}
                    for name, params in action.items():
                        if not BrowserAction.get_value_by_name(name):
                            logger.warning(f"Unsupported action: {name}")

                        result.append(
                            ActionModel(agent_name=self.id(), tool_name='browser', action_name=name, params=params))
                        if name == "done":
                            self._finished = True
            return output_message, AgentResult(current_state=agent_brain, actions=result)
        except (ValueError, ValidationError) as e:
            logger.warning(f'Failed to parse model output: {output_message} {str(e)}')
            raise ValueError('Could not parse response.') from e

    def _convert_input_messages(self, input_messages: list[BaseMessage]) -> list[BaseMessage]:
        """Convert input messages for deepseek reasoning models; return others unchanged."""
        if self.model_name == 'deepseek-reasoner' or self.model_name.startswith('deepseek-r1'):
            return convert_input_messages(input_messages, self.model_name)
        return input_messages

    def _make_history_item(self,
                           model_output: Optional[AgentResult],
                           state: Observation,
                           result: list[ActionResult],
                           metadata: Optional[PolicyMetadata] = None) -> None:
        """Append an ``AgentHistory`` entry for this step to ``self.state.history.history``.

        Args:
            model_output: Parsed LLM result for the step.
            state: Observation the step was based on.
            result: Currently unused; the stored result is read from
                ``state.action_result`` instead.
            metadata: Timing and token metadata for the step.
        """
        content = ""
        if hasattr(state, 'dom_tree') and state.dom_tree is not None:
            if hasattr(state.dom_tree, 'element_tree'):
                content = repr(state.dom_tree.element_tree)
            else:
                content = str(state.dom_tree)

        history_item = AgentHistory(model_output=model_output,
                                    result=state.action_result,
                                    metadata=metadata,
                                    content=content,
                                    base64_img=state.image if hasattr(state, 'image') else None)

        self.state.history.history.append(history_item)

    def _process_action_result(self, action_result, messages, tool_call=None):
        """Append a message describing ``action_result`` to ``messages``.

        Args:
            action_result: Result with ``content``, ``error`` and ``success`` attributes.
            messages: Message list that is extended in place.
            tool_call: Action entry the result belongs to; used for logging only.

        Returns:
            ``True`` if the result carries an error, ``False`` otherwise.

        Raises:
            ValueError: If the result has an error (and no content) but claims success.
        """
        if action_result.content is not None:
            # The f-string avoids a TypeError when content is not a string.
            messages.append(HumanMessage(content=f'Action result: {action_result.content}'))
        elif action_result.error is not None:
            messages.append(HumanMessage(content=f'Action result: {action_result.error}'))
            if tool_call is not None:
                logger.warning(f"Action {tool_call} failed: {action_result.error}")
            else:
                logger.warning(f"Action failed: {action_result.error}")

            if action_result.success is True:
                error_msg = f"Invalid result: success=True but error message exists: {action_result.error}"
                logger.error(error_msg)
                raise ValueError(error_msg)
        return action_result.error is not None

    def _append_action_results(self, action_results, previous_action_entries, messages, source: str) -> bool:
        """Replay ``action_results`` into ``messages``, paired with the actions that produced them.

        Args:
            action_results: Results to replay, or ``None`` for nothing to replay.
            previous_action_entries: Action entries from the preceding agent response.
            messages: Message list that is extended in place.
            source: Label used in the log line when no previous actions exist.

        Returns:
            ``True`` if any result carried an error or the counts did not match.
        """
        if action_results is None:
            return False
        if len(previous_action_entries) == 0:
            # There is nothing to pair the results with, so they are not replayed.
            logger.info(
                f"{source} with action_result count ({len(action_results)}) with empty previous actions - skipping count check")
            return False
        if len(previous_action_entries) != len(action_results):
            error_msg = f"Action results count ({len(action_results)}) doesn't match action entries count ({len(previous_action_entries)})"
            logger.error(error_msg)
            return True

        has_error = False
        for action_entry, one_action_result in zip(previous_action_entries, action_results):
            has_error = self._process_action_result(one_action_result, messages, action_entry) or has_error
        return has_error

    def build_messages_from_trajectory_and_observation(self, observation: Optional[Observation] = None) -> List[
            BaseMessage]:
        """Build the complete message history from the trajectory and the current observation.

        The result contains, in order: the system message, optional available
        actions context, the task, an example tool call, the replayed
        trajectory, the current observation, and a last-step notice once the
        step limit is reached.

        Args:
            observation: Current observation object; if None, the current
                observation won't be added.

        Returns:
            The list of messages to send to the LLM.
        """
        messages = []

        system_message = SystemPrompt(
            max_actions_per_step=self.settings.get('max_actions_per_step')
        ).get_system_message()
        if isinstance(system_message, tuple):
            system_message = system_message[0]
        messages.append(system_message)

        tool_calling_method = self.settings.get("tool_calling_method")
        llm_provider = self.conf.llm_provider if self.conf.llm_provider else self.conf.llm_config.llm_provider
        # Guard against no provider being configured at all.
        llm_provider = llm_provider or ''

        if tool_calling_method == 'raw' or (tool_calling_method == 'auto' and (
                llm_provider == 'deepseek-reasoner' or llm_provider.startswith('deepseek-r1'))):
            message_context = f'\n\nAvailable actions: {self.available_actions}'
        else:
            message_context = None

        if message_context:
            messages.append(HumanMessage(content='Context for the task' + message_context))

        task_message = HumanMessage(
            content=f'Your ultimate task is: """{self.task}""". If you achieved your ultimate task, stop everything and use the done action in the next step to complete the task. If not, continue as usual.'
        )
        messages.append(task_message)

        messages.append(HumanMessage(content='Example output:'))

        # One example exchange showing the expected output format.
        example_tool_calls = [
            {
                'name': 'AgentOutput',
                'args': {
                    'current_state': {
                        'evaluation_previous_goal': 'Success - I opend the first page',
                        'memory': 'Starting with the new task. I have completed 1/10 steps',
                        'thought': 'From the current page I can get information about all the companies.',
                        'next_goal': 'Click on company a',
                    },
                    'action': [{'click_element': {'index': 0}}],
                },
                'id': '1',
                'type': 'tool_call',
            }
        ]
        messages.append(AIMessage(content='', tool_calls=example_tool_calls))
        messages.append(ToolMessage(content='Browser started', tool_call_id='1'))

        messages.append(HumanMessage(content='[Your task history memory starts here]'))

        if self.settings.get('available_file_paths'):
            filepaths_msg = HumanMessage(
                content=f'Here are file paths you can use: {self.settings.get("available_file_paths")}')
            messages.append(filepaths_msg)

        previous_action_entries = []

        for step_index, (_, obs, _, _, llm_result) in enumerate(self.trajectory.get_history()):
            # Replay the results of the previous step's actions first.
            self._append_action_results(obs.action_result, previous_action_entries, messages, "History item")

            if llm_result:
                output_data = llm_result.model_dump(mode='json', exclude_unset=True)
                action_entries = [{action.action_name: action.params} for action in llm_result.actions]
                output_data["action"] = action_entries
                if "actions" in output_data:
                    del output_data["actions"]

                # Id '1' belongs to the example above, so history ids start at
                # 2. Each replayed step gets its own id, stable across calls.
                tool_id = step_index + 2
                tool_calls = [
                    {
                        'name': 'AgentOutput',
                        'args': output_data,
                        'id': str(tool_id),
                        'type': 'tool_call',
                    }
                ]
                previous_action_entries = action_entries
                messages.append(AIMessage(content='', tool_calls=tool_calls))
                # Empty tool response paired with the AI message via the same id.
                messages.append(ToolMessage(content='', tool_call_id=str(tool_id)))

        if observation:
            has_error = self._append_action_results(getattr(observation, 'action_result', None),
                                                    previous_action_entries, messages, "Current observation")

            if has_error and observation.content:
                messages.append(HumanMessage(content=observation.content))
            elif not has_error:
                step_info = AgentStepInfo(number=self.state.n_steps, max_steps=self.conf.max_steps)
                if hasattr(observation, 'dom_tree') and observation.dom_tree:
                    state_message = AgentMessagePrompt(
                        observation,
                        self.state.last_result,
                        include_attributes=self.settings.get('include_attributes'),
                        step_info=step_info,
                    ).get_user_message(self.settings.get('use_vision'))
                    messages.append(state_message)
                elif observation.content:
                    messages.append(HumanMessage(content=observation.content))

        # Once the step limit is reached, tell the model to finish with "done".
        if self.conf.max_steps <= self.state.n_steps:
            last_step_message = f"""
            Now comes your last step. Use only the "done" action now. No other actions - so here your action sequence must have length 1.
            \nIf the task is not yet fully finished as requested by the user, set success in "done" to false! E.g. if not all steps are fully completed.
            \nIf the task is fully finished, set success in "done" to true.
            \nInclude everything you found out for the ultimate task in the done text.
            """
            messages.append(HumanMessage(content=[{'type': 'text', 'text': last_step_message}]))

        return messages

    def _estimate_tokens_for_messages(self, messages: List[BaseMessage]) -> int:
        """Roughly estimate the token count for a message list.

        Uses the ``image_tokens`` (default 800) and
        ``estimated_characters_per_token`` (default 3) settings.
        """
        return estimate_messages_tokens(
            messages,
            image_tokens=self.settings.get('image_tokens', 800),
            estimated_characters_per_token=self.settings.get('estimated_characters_per_token', 3)
        )
|
|