Spaces:
Paused
Paused
| from __future__ import annotations | |
| import logging | |
| from typing import List, Optional, Type | |
| from browser_use.agent.message_manager.service import MessageManager | |
| from browser_use.agent.message_manager.views import MessageHistory | |
| from browser_use.agent.prompts import SystemPrompt, AgentMessagePrompt | |
| from browser_use.agent.views import ActionResult, AgentStepInfo, ActionModel | |
| from browser_use.browser.views import BrowserState | |
| from langchain_core.language_models import BaseChatModel | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_core.language_models import BaseChatModel | |
| from langchain_core.messages import ( | |
| AIMessage, | |
| BaseMessage, | |
| HumanMessage, | |
| ToolMessage | |
| ) | |
| from langchain_openai import ChatOpenAI | |
| from ..utils.llm import DeepSeekR1ChatOpenAI | |
| from .custom_prompts import CustomAgentMessagePrompt | |
| logger = logging.getLogger(__name__) | |
| class CustomMessageManager(MessageManager): | |
| def __init__( | |
| self, | |
| llm: BaseChatModel, | |
| task: str, | |
| action_descriptions: str, | |
| system_prompt_class: Type[SystemPrompt], | |
| agent_prompt_class: Type[AgentMessagePrompt], | |
| max_input_tokens: int = 128000, | |
| estimated_characters_per_token: int = 3, | |
| image_tokens: int = 800, | |
| include_attributes: list[str] = [], | |
| max_error_length: int = 400, | |
| max_actions_per_step: int = 10, | |
| message_context: Optional[str] = None | |
| ): | |
| super().__init__( | |
| llm=llm, | |
| task=task, | |
| action_descriptions=action_descriptions, | |
| system_prompt_class=system_prompt_class, | |
| max_input_tokens=max_input_tokens, | |
| estimated_characters_per_token=estimated_characters_per_token, | |
| image_tokens=image_tokens, | |
| include_attributes=include_attributes, | |
| max_error_length=max_error_length, | |
| max_actions_per_step=max_actions_per_step, | |
| message_context=message_context | |
| ) | |
| self.agent_prompt_class = agent_prompt_class | |
| # Custom: Move Task info to state_message | |
| self.history = MessageHistory() | |
| self._add_message_with_tokens(self.system_prompt) | |
| if self.message_context: | |
| context_message = HumanMessage(content=self.message_context) | |
| self._add_message_with_tokens(context_message) | |
| def cut_messages(self): | |
| """Get current message list, potentially trimmed to max tokens""" | |
| diff = self.history.total_tokens - self.max_input_tokens | |
| min_message_len = 2 if self.message_context is not None else 1 | |
| while diff > 0 and len(self.history.messages) > min_message_len: | |
| self.history.remove_message(min_message_len) # alway remove the oldest message | |
| diff = self.history.total_tokens - self.max_input_tokens | |
| def add_state_message( | |
| self, | |
| state: BrowserState, | |
| actions: Optional[List[ActionModel]] = None, | |
| result: Optional[List[ActionResult]] = None, | |
| step_info: Optional[AgentStepInfo] = None, | |
| ) -> None: | |
| """Add browser state as human message""" | |
| # otherwise add state message and result to next message (which will not stay in memory) | |
| state_message = self.agent_prompt_class( | |
| state, | |
| actions, | |
| result, | |
| include_attributes=self.include_attributes, | |
| max_error_length=self.max_error_length, | |
| step_info=step_info, | |
| ).get_user_message() | |
| self._add_message_with_tokens(state_message) | |
| def _count_text_tokens(self, text: str) -> int: | |
| if isinstance(self.llm, (ChatOpenAI, ChatAnthropic, DeepSeekR1ChatOpenAI)): | |
| try: | |
| tokens = self.llm.get_num_tokens(text) | |
| except Exception: | |
| tokens = ( | |
| len(text) // self.estimated_characters_per_token | |
| ) # Rough estimate if no tokenizer available | |
| else: | |
| tokens = ( | |
| len(text) // self.estimated_characters_per_token | |
| ) # Rough estimate if no tokenizer available | |
| return tokens | |
| def _remove_state_message_by_index(self, remove_ind=-1) -> None: | |
| """Remove last state message from history""" | |
| i = len(self.history.messages) - 1 | |
| remove_cnt = 0 | |
| while i >= 0: | |
| if isinstance(self.history.messages[i].message, HumanMessage): | |
| remove_cnt += 1 | |
| if remove_cnt == abs(remove_ind): | |
| self.history.remove_message(i) | |
| break | |
| i -= 1 |