Spaces:
Running
Running
File size: 7,673 Bytes
f0f6e5c 7b40088 f0f6e5c 719511c f0f6e5c 7b40088 f0f6e5c 719511c f0f6e5c 7b40088 f0f6e5c 7b40088 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import base64
from functools import cached_property
from typing import Any, Literal, Optional, Self
from proxy_lite.browser.browser import BrowserSession
from proxy_lite.environments.environment_base import (
Action,
BaseEnvironment,
BaseEnvironmentConfig,
Environments,
Observation,
State,
)
from proxy_lite.tools import BrowserTool, Tool, ToolExecutionResponse
# Module-level proxy_lite logger (the environment classes below log through self.logger).
from proxy_lite.logger import logger # Assuming you want to use the same logger
@Environments.register_environment_config("webbrowser")
class WebBrowserEnvironmentConfig(BaseEnvironmentConfig):
    """Settings for the ``"webbrowser"`` environment.

    Field order, names, types and defaults are part of the config schema and
    must not be changed casually.
    """

    # Discriminator under which this config is registered.
    name: Literal["webbrowser"] = "webbrowser"

    # Page the browser navigates to when the environment is initialised.
    homepage: str = "https://google.com"

    # NOTE(review): not read by the environment class in this file —
    # presumably consumed by the browser layer; confirm before changing.
    annotate_image: bool = True

    # Passed as ``delay`` to the browser's screenshot call.
    screenshot_delay: float = 1.0  # seconds

    # Attach the current page's HTML to the observation state.
    include_html: bool = True

    # Append the browser's POI text to the observation's text.
    include_poi_text: bool = True

    # Store the browser's POIs under ``info["pois"]``.
    record_pois: bool = True

    # Viewport dimensions handed to the browser session on start-up.
    viewport_width: int = 1280
    viewport_height: int = 720

    # NOTE(review): not read by the environment class in this file —
    # presumably a Browserbase session timeout; unit unconfirmed.
    browserbase_timeout: int = 7200

    # Handed to the browser session on start-up.
    headless: bool = True

    # Also store the un-annotated screenshot under ``info["original_image"]``.
    keep_original_image: bool = False

    # Use the un-annotated screenshot as the observation image.
    no_pois_in_image: bool = False
@Environments.register_environment("webbrowser")
class WebBrowserEnvironment(BaseEnvironment):
    """Environment that drives a ``BrowserSession`` and reports page observations.

    Use it as an async context manager: entering starts the browser session,
    exiting tears it down. ``initialise`` produces the first observation and
    ``execute_action`` produces one observation per action.
    """

    config: WebBrowserEnvironmentConfig
    # Populated by ``__aenter__``; ``None`` until the environment is entered.
    browser: Optional[BrowserSession] = None
    # True when the previous action was cancelled because the page changed.
    cancelled_last_action: bool = False

    class Config:
        # NOTE(review): presumably needed so the model accepts the
        # non-validated ``BrowserSession`` field — confirm against the base.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        """Start the browser session, install any cookies, and return ``self``."""
        # Initialize the BrowserSession
        self.browser = self.browser_session(
            viewport_width=self.config.viewport_width,
            viewport_height=self.config.viewport_height,
            headless=self.config.headless,
        )
        await self.browser.__aenter__()
        # Initialize other resources if necessary
        if self.cookies:
            await self.browser.context.add_cookies(self.cookies)
        self.logger.info("π [bold blue]Browser session started.[/]")
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Tear down the browser session, if one was started."""
        # ``browser`` defaults to None, so guard against exiting an
        # environment whose session was never created.
        if self.browser is not None:
            await self.browser.__aexit__(exc_type, exc_value, traceback)

    @property
    def info_for_user(self) -> str:
        """Short human-readable description of this environment."""
        return "This is a web browser environment. You can navigate the web, search the web, and perform actions on the web."  # noqa: E501

    @cached_property
    def tools(self) -> list[Tool]:
        """Tools exposed by this environment (cached on first access).

        The tool is bound to whatever ``self.browser`` is at first access.
        """
        return [BrowserTool(session=self.browser)]

    @cached_property
    def browser_session(self) -> type[BrowserSession]:
        """Class used to construct the browser session in ``__aenter__``."""
        return BrowserSession

    @property
    def cookies(self) -> list[dict]:
        """Cookies added to the browser context on start-up (none by default)."""
        return []

    async def _capture_observation(
        self,
        *,
        tool_responses: Optional[list[ToolExecutionResponse]] = None,
        extra_info: Optional[dict[str, Any]] = None,
    ) -> Observation:
        """Screenshot the current page and package it as an ``Observation``.

        Shared by ``initialise`` and ``execute_action`` so both honour the
        same config flags (``no_pois_in_image``, ``include_html``,
        ``include_poi_text``, ``record_pois``, ``keep_original_image``).

        Args:
            tool_responses: Responses to attach to the state. When ``None``
                the ``tool_responses`` field is left at the ``State`` default.
            extra_info: Extra entries placed in ``info`` right after ``"url"``.

        Returns:
            A non-terminated observation with no reward.
        """
        original_img, annotated_img = await self.browser.screenshot(
            delay=self.config.screenshot_delay,
        )
        # Honour ``no_pois_in_image`` for every observation, not just the first.
        shown_img = original_img if self.config.no_pois_in_image else annotated_img
        base64_image = base64.b64encode(shown_img).decode("utf-8")

        html_content = await self.browser.current_page.content() if self.config.include_html else None

        info: dict[str, Any] = {"url": self.browser.current_url, **(extra_info or {})}
        if self.config.record_pois:
            info["pois"] = self.browser.pois
        if self.config.keep_original_image:
            info["original_image"] = base64.b64encode(original_img).decode("utf-8")

        state_fields: dict[str, Any] = {
            "text": f"URL: {self.browser.current_url}"
            + (f"\n{self.browser.poi_text}" if self.config.include_poi_text else ""),
            "image": base64_image,
            "html": html_content,
        }
        if tool_responses is not None:
            state_fields["tool_responses"] = tool_responses

        return Observation(
            state=State(**state_fields),
            terminated=False,
            reward=None,
            info=info,
        )

    async def initialise(self) -> Observation:
        """Navigate to the configured homepage and return the first observation.

        Raises:
            Exception: Whatever the navigation raised; it is logged and
                re-raised unchanged.
        """
        self.logger.debug(f"DEBUG: Initialising WebBrowserEnvironment. Homepage: {self.config.homepage}")
        try:
            await self.browser.goto(self.config.homepage)
            self.logger.debug(f"DEBUG: Browser navigated to homepage. Current URL: {self.browser.current_url}")
        except Exception as e:
            self.logger.error(f"ERROR: Failed to navigate to homepage {self.config.homepage}: {e}")
            raise  # Re-raise to propagate the error

        observation = await self._capture_observation()
        self.logger.debug(f"DEBUG: Initial observation captured. URL: {self.browser.current_url}")
        return observation

    async def should_perform_action(self) -> bool:
        """Decide whether the pending action may run against the current page.

        Returns:
            ``True`` if the previous action was cancelled (the flag is cleared
            and POIs are not refreshed) or if the POI centroids are unchanged
            after refreshing them. ``False`` if the centroids changed, in
            which case ``cancelled_last_action`` is set.
        """
        # if cancelled last action, run the action without updating POIs
        if self.cancelled_last_action:
            self.cancelled_last_action = False
            return True

        # check for page changes
        old_points = [tuple(point) for point in self.browser.poi_centroids]
        await self.browser.update_poi()
        new_points = [tuple(point) for point in self.browser.poi_centroids]

        # record if the last action was cancelled
        if old_points != new_points:
            self.cancelled_last_action = True
            return False
        return True

    async def execute_action(self, action: Action) -> Observation:
        """Run the action's tool calls, or cancel them, and observe the page.

        Each tool call gets exactly one response carrying the call's ``id``:
        the tool's result, the stringified exception if it raised, or a
        cancellation notice if the page changed since the last observation.

        Args:
            action: The action whose ``tool_calls`` should be executed.

        Returns:
            An observation of the page afterwards. ``info["cancelled_tools"]``
            tells whether the tool calls were cancelled.
        """
        responses = []
        cancelled_tools_flag = False

        if await self.should_perform_action():
            for tool_call in action.tool_calls:
                # Perform the chosen action
                try:
                    tool_response: ToolExecutionResponse = await self.execute_tool(
                        tool_call,
                    )
                    tool_response.id = tool_call.id
                    responses.append(tool_response)
                except Exception as e:  # noqa: PERF203
                    # Report the failure as the tool's response instead of raising.
                    self.logger.warning("π An error occurred taking action: %s", str(e), exc_info=False)
                    tool_response = ToolExecutionResponse(content=str(e), id=tool_call.id)
                    responses.append(tool_response)
        else:
            self.logger.warning("π Page changed since last observation, cancelling action.")
            self.cancelled_last_action = True
            for tool_call in action.tool_calls:
                tool_response = ToolExecutionResponse(
                    content="The page changed before the action could be executed, instead of being ran it was cancelled.",  # noqa: E501
                    id=tool_call.id,
                )
                responses.append(tool_response)
            cancelled_tools_flag = True

        return await self._capture_observation(
            tool_responses=responses,
            extra_info={"cancelled_tools": cancelled_tools_flag},
        )

    async def observe(self) -> Observation:
        """Return whatever the browser session's own ``observe`` produces."""
        return await self.browser.observe()

    async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
        """Return an empty result; the keyword arguments are ignored."""
        return {}

    async def get_info(self) -> dict[str, Any]:
        """Return an empty info dict."""
        return {}