Spaces:
Running
Running
Commit ·
ab65628
1
Parent(s): ff3e1be
feat: add core RL environment models (observation, action, reward, env)
Browse files- backend/app/core/__init__.py +18 -0
- backend/app/core/__pycache__/__init__.cpython-314.pyc +0 -0
- backend/app/core/__pycache__/action.cpython-314.pyc +0 -0
- backend/app/core/__pycache__/env.cpython-314.pyc +0 -0
- backend/app/core/__pycache__/episode.cpython-314.pyc +0 -0
- backend/app/core/__pycache__/observation.cpython-314.pyc +0 -0
- backend/app/core/__pycache__/reward.cpython-314.pyc +0 -0
- backend/app/core/action.py +347 -0
- backend/app/core/env.py +497 -0
- backend/app/core/episode.py +261 -0
- backend/app/core/observation.py +228 -0
- backend/app/core/reward.py +398 -0
backend/app/core/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core module - RL environment, observations, actions, and rewards."""
|
| 2 |
+
|
| 3 |
+
from app.core.action import Action, ActionType
|
| 4 |
+
from app.core.env import WebScraperEnv
|
| 5 |
+
from app.core.episode import Episode, EpisodeStatus
|
| 6 |
+
from app.core.observation import Observation
|
| 7 |
+
from app.core.reward import RewardEngine, RewardBreakdown
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"Action",
|
| 11 |
+
"ActionType",
|
| 12 |
+
"WebScraperEnv",
|
| 13 |
+
"Episode",
|
| 14 |
+
"EpisodeStatus",
|
| 15 |
+
"Observation",
|
| 16 |
+
"RewardEngine",
|
| 17 |
+
"RewardBreakdown",
|
| 18 |
+
]
|
backend/app/core/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (627 Bytes). View file
|
|
|
backend/app/core/__pycache__/action.cpython-314.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
backend/app/core/__pycache__/env.cpython-314.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
backend/app/core/__pycache__/episode.cpython-314.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
backend/app/core/__pycache__/observation.cpython-314.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
backend/app/core/__pycache__/reward.cpython-314.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
backend/app/core/action.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Action model for the RL environment."""
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ActionType(str, Enum):
|
| 10 |
+
"""All possible action types in the environment."""
|
| 11 |
+
|
| 12 |
+
# Navigation actions
|
| 13 |
+
NAVIGATE = "navigate"
|
| 14 |
+
GO_BACK = "go_back"
|
| 15 |
+
GO_FORWARD = "go_forward"
|
| 16 |
+
REFRESH = "refresh"
|
| 17 |
+
|
| 18 |
+
# Interaction actions
|
| 19 |
+
CLICK = "click"
|
| 20 |
+
FILL = "fill"
|
| 21 |
+
SELECT = "select"
|
| 22 |
+
SCROLL = "scroll"
|
| 23 |
+
HOVER = "hover"
|
| 24 |
+
|
| 25 |
+
# Extraction actions
|
| 26 |
+
EXTRACT_FIELD = "extract_field"
|
| 27 |
+
EXTRACT_TABLE = "extract_table"
|
| 28 |
+
EXTRACT_LIST = "extract_list"
|
| 29 |
+
|
| 30 |
+
# Search actions
|
| 31 |
+
SEARCH_PAGE = "search_page"
|
| 32 |
+
SEARCH_ENGINE = "search_engine"
|
| 33 |
+
|
| 34 |
+
# Verification actions
|
| 35 |
+
VERIFY_FACT = "verify_fact"
|
| 36 |
+
VERIFY_FIELD = "verify_field"
|
| 37 |
+
|
| 38 |
+
# Memory actions
|
| 39 |
+
STORE_MEMORY = "store_memory"
|
| 40 |
+
RECALL_MEMORY = "recall_memory"
|
| 41 |
+
|
| 42 |
+
# Tool actions
|
| 43 |
+
MCP_TOOL_CALL = "mcp_tool_call"
|
| 44 |
+
|
| 45 |
+
# Planning actions
|
| 46 |
+
CREATE_PLAN = "create_plan"
|
| 47 |
+
UPDATE_PLAN = "update_plan"
|
| 48 |
+
|
| 49 |
+
# Communication actions
|
| 50 |
+
SEND_MESSAGE = "send_message"
|
| 51 |
+
|
| 52 |
+
# Control actions
|
| 53 |
+
WAIT = "wait"
|
| 54 |
+
DONE = "done"
|
| 55 |
+
FAIL = "fail"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class NavigateParams(BaseModel):
|
| 59 |
+
"""Parameters for navigation actions."""
|
| 60 |
+
|
| 61 |
+
url: str
|
| 62 |
+
wait_for: str | None = None
|
| 63 |
+
timeout_ms: int = 30000
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ClickParams(BaseModel):
|
| 67 |
+
"""Parameters for click actions."""
|
| 68 |
+
|
| 69 |
+
selector: str
|
| 70 |
+
button: str = "left"
|
| 71 |
+
click_count: int = 1
|
| 72 |
+
wait_after_ms: int = 500
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class FillParams(BaseModel):
|
| 76 |
+
"""Parameters for form fill actions."""
|
| 77 |
+
|
| 78 |
+
selector: str
|
| 79 |
+
value: str
|
| 80 |
+
clear_first: bool = True
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SelectParams(BaseModel):
|
| 84 |
+
"""Parameters for select dropdown actions."""
|
| 85 |
+
|
| 86 |
+
selector: str
|
| 87 |
+
value: str | None = None
|
| 88 |
+
label: str | None = None
|
| 89 |
+
index: int | None = None
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ScrollParams(BaseModel):
|
| 93 |
+
"""Parameters for scroll actions."""
|
| 94 |
+
|
| 95 |
+
direction: str = "down"
|
| 96 |
+
amount: int | str = "page"
|
| 97 |
+
selector: str | None = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class ExtractFieldParams(BaseModel):
|
| 101 |
+
"""Parameters for field extraction actions."""
|
| 102 |
+
|
| 103 |
+
field_name: str
|
| 104 |
+
selector: str | None = None
|
| 105 |
+
extraction_method: str = "text"
|
| 106 |
+
attribute: str | None = None
|
| 107 |
+
regex_pattern: str | None = None
|
| 108 |
+
post_process: str | None = None
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class ExtractTableParams(BaseModel):
|
| 112 |
+
"""Parameters for table extraction actions."""
|
| 113 |
+
|
| 114 |
+
table_selector: str
|
| 115 |
+
headers: list[str] | None = None
|
| 116 |
+
row_selector: str | None = None
|
| 117 |
+
cell_selectors: dict[str, str] | None = None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class ExtractListParams(BaseModel):
|
| 121 |
+
"""Parameters for list extraction actions."""
|
| 122 |
+
|
| 123 |
+
container_selector: str
|
| 124 |
+
item_selector: str
|
| 125 |
+
field_selectors: dict[str, str]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class SearchPageParams(BaseModel):
|
| 129 |
+
"""Parameters for searching within the current page."""
|
| 130 |
+
|
| 131 |
+
query: str
|
| 132 |
+
search_type: str = "text"
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class SearchEngineParams(BaseModel):
|
| 136 |
+
"""Parameters for search engine queries."""
|
| 137 |
+
|
| 138 |
+
query: str
|
| 139 |
+
engine: str = "google"
|
| 140 |
+
num_results: int = 10
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class VerifyFactParams(BaseModel):
|
| 144 |
+
"""Parameters for fact verification."""
|
| 145 |
+
|
| 146 |
+
claim: str
|
| 147 |
+
sources: list[str] | None = None
|
| 148 |
+
confidence_threshold: float = 0.8
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class VerifyFieldParams(BaseModel):
|
| 152 |
+
"""Parameters for field verification."""
|
| 153 |
+
|
| 154 |
+
field_name: str
|
| 155 |
+
expected_type: str | None = None
|
| 156 |
+
expected_format: str | None = None
|
| 157 |
+
validation_rules: list[str] = Field(default_factory=list)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class MemoryParams(BaseModel):
|
| 161 |
+
"""Parameters for memory operations."""
|
| 162 |
+
|
| 163 |
+
key: str
|
| 164 |
+
value: Any | None = None
|
| 165 |
+
memory_type: str = "working"
|
| 166 |
+
ttl_seconds: int | None = None
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class MCPToolCallParams(BaseModel):
|
| 170 |
+
"""Parameters for MCP tool calls."""
|
| 171 |
+
|
| 172 |
+
tool_name: str
|
| 173 |
+
arguments: dict[str, Any] = Field(default_factory=dict)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class PlanParams(BaseModel):
|
| 177 |
+
"""Parameters for planning actions."""
|
| 178 |
+
|
| 179 |
+
plan_description: str | None = None
|
| 180 |
+
steps: list[dict[str, Any]] | None = None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class MessageParams(BaseModel):
|
| 184 |
+
"""Parameters for inter-agent messages."""
|
| 185 |
+
|
| 186 |
+
target_agent: str
|
| 187 |
+
message_type: str
|
| 188 |
+
content: dict[str, Any] = Field(default_factory=dict)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class WaitParams(BaseModel):
|
| 192 |
+
"""Parameters for wait actions."""
|
| 193 |
+
|
| 194 |
+
duration_ms: int = 1000
|
| 195 |
+
wait_for_selector: str | None = None
|
| 196 |
+
wait_for_navigation: bool = False
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class DoneParams(BaseModel):
|
| 200 |
+
"""Parameters for completion."""
|
| 201 |
+
|
| 202 |
+
success: bool = True
|
| 203 |
+
message: str | None = None
|
| 204 |
+
final_result: dict[str, Any] | None = None
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class Action(BaseModel):
|
| 208 |
+
"""
|
| 209 |
+
Represents an action to be taken in the environment.
|
| 210 |
+
|
| 211 |
+
An action consists of:
|
| 212 |
+
- action_type: The type of action
|
| 213 |
+
- parameters: Action-specific parameters
|
| 214 |
+
- reasoning: Why this action was chosen
|
| 215 |
+
- confidence: How confident the agent is
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
action_type: ActionType = Field(..., description="Type of action to execute")
|
| 219 |
+
parameters: dict[str, Any] = Field(
|
| 220 |
+
default_factory=dict,
|
| 221 |
+
description="Action-specific parameters",
|
| 222 |
+
)
|
| 223 |
+
reasoning: str | None = Field(
|
| 224 |
+
default=None,
|
| 225 |
+
description="Agent's reasoning for this action",
|
| 226 |
+
)
|
| 227 |
+
confidence: float = Field(
|
| 228 |
+
default=1.0,
|
| 229 |
+
ge=0.0,
|
| 230 |
+
le=1.0,
|
| 231 |
+
description="Confidence in this action (0-1)",
|
| 232 |
+
)
|
| 233 |
+
agent_id: str | None = Field(
|
| 234 |
+
default=None,
|
| 235 |
+
description="ID of the agent that produced this action",
|
| 236 |
+
)
|
| 237 |
+
plan_step: int | None = Field(
|
| 238 |
+
default=None,
|
| 239 |
+
description="Which step of the plan this corresponds to",
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
@field_validator("confidence")
|
| 243 |
+
@classmethod
|
| 244 |
+
def validate_confidence(cls, v: float) -> float:
|
| 245 |
+
"""Ensure confidence is between 0 and 1."""
|
| 246 |
+
return max(0.0, min(1.0, v))
|
| 247 |
+
|
| 248 |
+
model_config = ConfigDict(
|
| 249 |
+
json_schema_extra={
|
| 250 |
+
"example": {
|
| 251 |
+
"action_type": "extract_field",
|
| 252 |
+
"parameters": {
|
| 253 |
+
"field_name": "price",
|
| 254 |
+
"selector": ".product-price",
|
| 255 |
+
"extraction_method": "text",
|
| 256 |
+
},
|
| 257 |
+
"reasoning": "The price element is visible with class .product-price",
|
| 258 |
+
"confidence": 0.92,
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
@classmethod
|
| 264 |
+
def navigate(cls, url: str, **kwargs: Any) -> "Action":
|
| 265 |
+
"""Create a navigate action."""
|
| 266 |
+
return cls(
|
| 267 |
+
action_type=ActionType.NAVIGATE,
|
| 268 |
+
parameters={"url": url, **kwargs},
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
@classmethod
|
| 272 |
+
def click(cls, selector: str, **kwargs: Any) -> "Action":
|
| 273 |
+
"""Create a click action."""
|
| 274 |
+
return cls(
|
| 275 |
+
action_type=ActionType.CLICK,
|
| 276 |
+
parameters={"selector": selector, **kwargs},
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
@classmethod
|
| 280 |
+
def extract_field(
|
| 281 |
+
cls,
|
| 282 |
+
field_name: str,
|
| 283 |
+
selector: str | None = None,
|
| 284 |
+
**kwargs: Any,
|
| 285 |
+
) -> "Action":
|
| 286 |
+
"""Create an extract field action."""
|
| 287 |
+
return cls(
|
| 288 |
+
action_type=ActionType.EXTRACT_FIELD,
|
| 289 |
+
parameters={"field_name": field_name, "selector": selector, **kwargs},
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
@classmethod
|
| 293 |
+
def search_engine(cls, query: str, engine: str = "google", **kwargs: Any) -> "Action":
|
| 294 |
+
"""Create a search engine action."""
|
| 295 |
+
return cls(
|
| 296 |
+
action_type=ActionType.SEARCH_ENGINE,
|
| 297 |
+
parameters={"query": query, "engine": engine, **kwargs},
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
@classmethod
|
| 301 |
+
def done(cls, success: bool = True, message: str | None = None) -> "Action":
|
| 302 |
+
"""Create a done action."""
|
| 303 |
+
return cls(
|
| 304 |
+
action_type=ActionType.DONE,
|
| 305 |
+
parameters={"success": success, "message": message},
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
@classmethod
|
| 309 |
+
def wait(cls, duration_ms: int = 1000) -> "Action":
|
| 310 |
+
"""Create a wait action."""
|
| 311 |
+
return cls(
|
| 312 |
+
action_type=ActionType.WAIT,
|
| 313 |
+
parameters={"duration_ms": duration_ms},
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
@classmethod
|
| 317 |
+
def mcp_tool_call(cls, tool_name: str, **arguments: Any) -> "Action":
|
| 318 |
+
"""Create an MCP tool call action."""
|
| 319 |
+
return cls(
|
| 320 |
+
action_type=ActionType.MCP_TOOL_CALL,
|
| 321 |
+
parameters={"tool_name": tool_name, "arguments": arguments},
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
def get_param(self, key: str, default: Any = None) -> Any:
|
| 325 |
+
"""Get a parameter value with optional default."""
|
| 326 |
+
return self.parameters.get(key, default)
|
| 327 |
+
|
| 328 |
+
def validate_params(self) -> list[str]:
|
| 329 |
+
"""Validate parameters for this action type. Returns list of errors."""
|
| 330 |
+
errors = []
|
| 331 |
+
|
| 332 |
+
required_params = {
|
| 333 |
+
ActionType.NAVIGATE: ["url"],
|
| 334 |
+
ActionType.CLICK: ["selector"],
|
| 335 |
+
ActionType.FILL: ["selector", "value"],
|
| 336 |
+
ActionType.EXTRACT_FIELD: ["field_name"],
|
| 337 |
+
ActionType.SEARCH_ENGINE: ["query"],
|
| 338 |
+
ActionType.MCP_TOOL_CALL: ["tool_name"],
|
| 339 |
+
ActionType.SEND_MESSAGE: ["target_agent", "message_type"],
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
if self.action_type in required_params:
|
| 343 |
+
for param in required_params[self.action_type]:
|
| 344 |
+
if param not in self.parameters or self.parameters[param] is None:
|
| 345 |
+
errors.append(f"Missing required parameter: {param}")
|
| 346 |
+
|
| 347 |
+
return errors
|
backend/app/core/env.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Web scraper RL environment."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from app.config import Settings, get_settings
|
| 8 |
+
from app.core.action import Action, ActionType
|
| 9 |
+
from app.core.episode import Episode, EpisodeManager
|
| 10 |
+
from app.core.observation import (
|
| 11 |
+
AvailableAction,
|
| 12 |
+
ExtractedField,
|
| 13 |
+
MemoryContext,
|
| 14 |
+
Observation,
|
| 15 |
+
TaskContext,
|
| 16 |
+
)
|
| 17 |
+
from app.core.reward import RewardBreakdown, RewardEngine
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class WebScraperEnv:
|
| 23 |
+
"""
|
| 24 |
+
Reinforcement Learning environment for web scraping.
|
| 25 |
+
|
| 26 |
+
Follows the Gymnasium API pattern:
|
| 27 |
+
- reset(task_id, seed) -> observation, info
|
| 28 |
+
- step(action) -> observation, reward, terminated, truncated, info
|
| 29 |
+
- get_state() -> state dict
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
episode_id: str,
|
| 35 |
+
settings: Settings | None = None,
|
| 36 |
+
) -> None:
|
| 37 |
+
"""
|
| 38 |
+
Initialize the environment.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
episode_id: Unique identifier for this episode.
|
| 42 |
+
settings: Application settings.
|
| 43 |
+
"""
|
| 44 |
+
self.episode_id = episode_id
|
| 45 |
+
self.settings = settings or get_settings()
|
| 46 |
+
self.reward_engine = RewardEngine(settings)
|
| 47 |
+
self.episode_manager = EpisodeManager()
|
| 48 |
+
|
| 49 |
+
# State
|
| 50 |
+
self._episode: Episode | None = None
|
| 51 |
+
self._current_observation: Observation | None = None
|
| 52 |
+
self._task_context: TaskContext | None = None
|
| 53 |
+
self._ground_truth: dict[str, Any] | None = None
|
| 54 |
+
|
| 55 |
+
# Browser state (placeholder - would use Playwright in production)
|
| 56 |
+
self._current_url: str | None = None
|
| 57 |
+
self._page_html: str | None = None
|
| 58 |
+
self._page_title: str | None = None
|
| 59 |
+
|
| 60 |
+
# Extraction state
|
| 61 |
+
self._extracted_fields: list[ExtractedField] = []
|
| 62 |
+
self._navigation_history: list[str] = []
|
| 63 |
+
|
| 64 |
+
# Timing
|
| 65 |
+
self._start_time: float | None = None
|
| 66 |
+
|
| 67 |
+
async def reset(
|
| 68 |
+
self,
|
| 69 |
+
task_id: str,
|
| 70 |
+
seed: int | None = None,
|
| 71 |
+
config: dict[str, Any] | None = None,
|
| 72 |
+
) -> tuple[Observation, dict[str, Any]]:
|
| 73 |
+
"""
|
| 74 |
+
Reset the environment for a new episode.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
task_id: ID of the task to execute.
|
| 78 |
+
seed: Random seed for reproducibility.
|
| 79 |
+
config: Optional episode configuration.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Tuple of (initial_observation, info_dict).
|
| 83 |
+
"""
|
| 84 |
+
logger.info(f"Resetting environment for task {task_id}")
|
| 85 |
+
|
| 86 |
+
# Reset state
|
| 87 |
+
self.reward_engine.reset()
|
| 88 |
+
self._extracted_fields = []
|
| 89 |
+
self._navigation_history = []
|
| 90 |
+
self._start_time = time.time()
|
| 91 |
+
self._current_url = None
|
| 92 |
+
self._page_html = None
|
| 93 |
+
self._page_title = None
|
| 94 |
+
|
| 95 |
+
# Create episode
|
| 96 |
+
self._episode = self.episode_manager.create_episode(
|
| 97 |
+
episode_id=self.episode_id,
|
| 98 |
+
task_id=task_id,
|
| 99 |
+
max_steps=self.settings.max_steps_per_episode,
|
| 100 |
+
seed=seed,
|
| 101 |
+
config=config or {},
|
| 102 |
+
)
|
| 103 |
+
self._episode.start()
|
| 104 |
+
|
| 105 |
+
# Load task context
|
| 106 |
+
self._task_context = await self._load_task_context(task_id)
|
| 107 |
+
|
| 108 |
+
# Create initial observation
|
| 109 |
+
self._current_observation = self._create_observation()
|
| 110 |
+
|
| 111 |
+
info = {
|
| 112 |
+
"episode_id": self.episode_id,
|
| 113 |
+
"task_id": task_id,
|
| 114 |
+
"max_steps": self._episode.max_steps,
|
| 115 |
+
"target_fields": self._task_context.target_fields if self._task_context else [],
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
return self._current_observation, info
|
| 119 |
+
|
| 120 |
+
async def step(
|
| 121 |
+
self,
|
| 122 |
+
action: Action,
|
| 123 |
+
) -> tuple[Observation, float, dict[str, float], bool, bool, dict[str, Any]]:
|
| 124 |
+
"""
|
| 125 |
+
Execute an action and return the result.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
action: The action to execute.
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
Tuple of (observation, reward, reward_breakdown, terminated, truncated, info).
|
| 132 |
+
"""
|
| 133 |
+
if self._episode is None or self._current_observation is None:
|
| 134 |
+
raise RuntimeError("Environment not reset. Call reset() first.")
|
| 135 |
+
|
| 136 |
+
if self._episode.is_terminal:
|
| 137 |
+
raise RuntimeError("Episode has already terminated.")
|
| 138 |
+
|
| 139 |
+
step_start = time.time()
|
| 140 |
+
prev_observation = self._current_observation
|
| 141 |
+
|
| 142 |
+
# Validate action
|
| 143 |
+
errors = action.validate_params()
|
| 144 |
+
if errors:
|
| 145 |
+
logger.warning(f"Invalid action parameters: {errors}")
|
| 146 |
+
|
| 147 |
+
# Execute action
|
| 148 |
+
action_result = await self._execute_action(action)
|
| 149 |
+
|
| 150 |
+
# Update observation
|
| 151 |
+
self._current_observation = self._create_observation()
|
| 152 |
+
if action_result.get("error"):
|
| 153 |
+
self._current_observation.last_action_error = action_result["error"]
|
| 154 |
+
self._current_observation.consecutive_errors = (
|
| 155 |
+
prev_observation.consecutive_errors + 1
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
self._current_observation.consecutive_errors = 0
|
| 159 |
+
|
| 160 |
+
# Compute reward
|
| 161 |
+
reward, breakdown = self.reward_engine.compute_reward(
|
| 162 |
+
action=action,
|
| 163 |
+
prev_observation=prev_observation,
|
| 164 |
+
new_observation=self._current_observation,
|
| 165 |
+
ground_truth=self._ground_truth,
|
| 166 |
+
max_steps=self._episode.max_steps,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Check termination
|
| 170 |
+
terminated = self._check_terminated(action)
|
| 171 |
+
truncated = self._check_truncated()
|
| 172 |
+
|
| 173 |
+
# Update episode
|
| 174 |
+
step_duration = (time.time() - step_start) * 1000
|
| 175 |
+
self._episode.add_step(
|
| 176 |
+
action_type=action.action_type.value,
|
| 177 |
+
action_params=action.parameters,
|
| 178 |
+
action_reasoning=action.reasoning,
|
| 179 |
+
reward=reward,
|
| 180 |
+
reward_breakdown=breakdown.to_dict(),
|
| 181 |
+
observation_summary={
|
| 182 |
+
"url": self._current_observation.current_url,
|
| 183 |
+
"progress": self._current_observation.extraction_progress,
|
| 184 |
+
"fields_extracted": len(self._current_observation.extracted_so_far),
|
| 185 |
+
},
|
| 186 |
+
error=action_result.get("error"),
|
| 187 |
+
duration_ms=step_duration,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# Handle terminal states
|
| 191 |
+
if terminated:
|
| 192 |
+
success = action.action_type == ActionType.DONE and action.get_param(
|
| 193 |
+
"success", True
|
| 194 |
+
)
|
| 195 |
+
self._episode.complete(
|
| 196 |
+
success=success,
|
| 197 |
+
extracted_data=self._current_observation.get_extraction_dict(),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Add terminal reward
|
| 201 |
+
terminal_reward, terminal_breakdown = (
|
| 202 |
+
self.reward_engine.compute_terminal_reward(
|
| 203 |
+
self._current_observation,
|
| 204 |
+
success=success,
|
| 205 |
+
ground_truth=self._ground_truth,
|
| 206 |
+
)
|
| 207 |
+
)
|
| 208 |
+
reward += terminal_reward
|
| 209 |
+
breakdown.total += terminal_reward
|
| 210 |
+
elif truncated:
|
| 211 |
+
self._episode.truncate()
|
| 212 |
+
|
| 213 |
+
info = {
|
| 214 |
+
"action_result": action_result,
|
| 215 |
+
"step_duration_ms": step_duration,
|
| 216 |
+
"episode_step": self._episode.current_step,
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
return (
|
| 220 |
+
self._current_observation,
|
| 221 |
+
reward,
|
| 222 |
+
breakdown.to_dict(),
|
| 223 |
+
terminated,
|
| 224 |
+
truncated,
|
| 225 |
+
info,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def get_state(self) -> dict[str, Any]:
|
| 229 |
+
"""Get the current state of the environment."""
|
| 230 |
+
if self._episode is None:
|
| 231 |
+
return {
|
| 232 |
+
"episode_id": self.episode_id,
|
| 233 |
+
"status": "not_started",
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
return {
|
| 237 |
+
"episode_id": self.episode_id,
|
| 238 |
+
"task_id": self._episode.task_id,
|
| 239 |
+
"step_number": self._episode.current_step,
|
| 240 |
+
"current_url": self._current_url,
|
| 241 |
+
"is_terminal": self._episode.is_terminal,
|
| 242 |
+
"total_reward": self._episode.total_reward,
|
| 243 |
+
"extracted_data": (
|
| 244 |
+
self._current_observation.get_extraction_dict()
|
| 245 |
+
if self._current_observation
|
| 246 |
+
else {}
|
| 247 |
+
),
|
| 248 |
+
"status": self._episode.status.value,
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
async def _load_task_context(self, task_id: str) -> TaskContext:
|
| 252 |
+
"""Load task context from task repository."""
|
| 253 |
+
# In production, this would fetch from database
|
| 254 |
+
from app.api.routes.tasks import TASK_REPOSITORY
|
| 255 |
+
|
| 256 |
+
task = TASK_REPOSITORY.get(task_id)
|
| 257 |
+
if task:
|
| 258 |
+
return TaskContext(
|
| 259 |
+
task_id=task.id,
|
| 260 |
+
task_name=task.name,
|
| 261 |
+
task_type=task.task_type.value,
|
| 262 |
+
target_fields=[f.name for f in task.fields_to_extract],
|
| 263 |
+
required_fields=task.success_criteria.get("required_fields", []),
|
| 264 |
+
hints=task.hints,
|
| 265 |
+
success_criteria=task.success_criteria,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Default context
|
| 269 |
+
return TaskContext(
|
| 270 |
+
task_id=task_id,
|
| 271 |
+
task_name=f"Task {task_id}",
|
| 272 |
+
task_type="unknown",
|
| 273 |
+
target_fields=[],
|
| 274 |
+
required_fields=[],
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def _create_observation(self) -> Observation:
|
| 278 |
+
"""Create an observation from current state."""
|
| 279 |
+
if self._episode is None:
|
| 280 |
+
raise RuntimeError("Episode not initialized")
|
| 281 |
+
|
| 282 |
+
elapsed = time.time() - (self._start_time or time.time())
|
| 283 |
+
|
| 284 |
+
# Get available actions
|
| 285 |
+
available_actions = self._get_available_actions()
|
| 286 |
+
|
| 287 |
+
# Calculate progress
|
| 288 |
+
target_fields = (
|
| 289 |
+
self._task_context.target_fields if self._task_context else []
|
| 290 |
+
)
|
| 291 |
+
extracted_names = {f.field_name for f in self._extracted_fields}
|
| 292 |
+
fields_remaining = [f for f in target_fields if f not in extracted_names]
|
| 293 |
+
progress = (
|
| 294 |
+
len(self._extracted_fields) / len(target_fields)
|
| 295 |
+
if target_fields
|
| 296 |
+
else 0.0
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
return Observation(
|
| 300 |
+
episode_id=self.episode_id,
|
| 301 |
+
task_id=self._episode.task_id,
|
| 302 |
+
step_number=self._episode.current_step,
|
| 303 |
+
elapsed_seconds=elapsed,
|
| 304 |
+
current_url=self._current_url,
|
| 305 |
+
page_title=self._page_title,
|
| 306 |
+
page_html=self._page_html,
|
| 307 |
+
navigation_history=self._navigation_history.copy(),
|
| 308 |
+
can_go_back=len(self._navigation_history) > 1,
|
| 309 |
+
task_context=self._task_context,
|
| 310 |
+
extracted_so_far=self._extracted_fields.copy(),
|
| 311 |
+
extraction_progress=progress,
|
| 312 |
+
fields_remaining=fields_remaining,
|
| 313 |
+
memory_context=MemoryContext(),
|
| 314 |
+
available_actions=available_actions,
|
| 315 |
+
tokens_used=self._episode.tokens_used,
|
| 316 |
+
api_calls_made=self._episode.api_calls,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
def _get_available_actions(self) -> list[AvailableAction]:
|
| 320 |
+
"""Get list of currently available actions."""
|
| 321 |
+
actions = []
|
| 322 |
+
|
| 323 |
+
# Navigation actions
|
| 324 |
+
actions.append(
|
| 325 |
+
AvailableAction(
|
| 326 |
+
action_type="navigate",
|
| 327 |
+
description="Navigate to a URL",
|
| 328 |
+
parameters={"url": "required"},
|
| 329 |
+
)
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if self._current_url:
|
| 333 |
+
# Page interaction actions
|
| 334 |
+
actions.extend([
|
| 335 |
+
AvailableAction(
|
| 336 |
+
action_type="click",
|
| 337 |
+
description="Click on an element",
|
| 338 |
+
parameters={"selector": "required"},
|
| 339 |
+
),
|
| 340 |
+
AvailableAction(
|
| 341 |
+
action_type="extract_field",
|
| 342 |
+
description="Extract a field from the page",
|
| 343 |
+
parameters={"field_name": "required", "selector": "optional"},
|
| 344 |
+
),
|
| 345 |
+
AvailableAction(
|
| 346 |
+
action_type="search_page",
|
| 347 |
+
description="Search within the current page",
|
| 348 |
+
parameters={"query": "required"},
|
| 349 |
+
),
|
| 350 |
+
])
|
| 351 |
+
|
| 352 |
+
# Always available
|
| 353 |
+
actions.extend([
|
| 354 |
+
AvailableAction(
|
| 355 |
+
action_type="search_engine",
|
| 356 |
+
description="Perform a web search",
|
| 357 |
+
parameters={"query": "required", "engine": "optional"},
|
| 358 |
+
),
|
| 359 |
+
AvailableAction(
|
| 360 |
+
action_type="done",
|
| 361 |
+
description="Mark task as complete",
|
| 362 |
+
parameters={"success": "boolean"},
|
| 363 |
+
),
|
| 364 |
+
])
|
| 365 |
+
|
| 366 |
+
return actions
|
| 367 |
+
|
| 368 |
+
async def _execute_action(self, action: Action) -> dict[str, Any]:
|
| 369 |
+
"""Execute an action and return the result."""
|
| 370 |
+
result: dict[str, Any] = {"success": False}
|
| 371 |
+
|
| 372 |
+
try:
|
| 373 |
+
match action.action_type:
|
| 374 |
+
case ActionType.NAVIGATE:
|
| 375 |
+
result = await self._execute_navigate(action)
|
| 376 |
+
case ActionType.CLICK:
|
| 377 |
+
result = await self._execute_click(action)
|
| 378 |
+
case ActionType.FILL:
|
| 379 |
+
result = await self._execute_fill(action)
|
| 380 |
+
case ActionType.EXTRACT_FIELD:
|
| 381 |
+
result = await self._execute_extract(action)
|
| 382 |
+
case ActionType.SEARCH_ENGINE:
|
| 383 |
+
result = await self._execute_search_engine(action)
|
| 384 |
+
case ActionType.DONE:
|
| 385 |
+
result = {"success": True, "done": True}
|
| 386 |
+
case ActionType.WAIT:
|
| 387 |
+
await self._execute_wait(action)
|
| 388 |
+
result = {"success": True}
|
| 389 |
+
case _:
|
| 390 |
+
result = {
|
| 391 |
+
"success": False,
|
| 392 |
+
"error": f"Action type {action.action_type} not implemented",
|
| 393 |
+
}
|
| 394 |
+
except Exception as e:
|
| 395 |
+
logger.error(f"Action execution failed: {e}")
|
| 396 |
+
result = {"success": False, "error": str(e)}
|
| 397 |
+
|
| 398 |
+
return result
|
| 399 |
+
|
| 400 |
+
async def _execute_navigate(self, action: Action) -> dict[str, Any]:
|
| 401 |
+
"""Execute a navigate action."""
|
| 402 |
+
url = action.get_param("url")
|
| 403 |
+
if not url:
|
| 404 |
+
return {"success": False, "error": "URL is required"}
|
| 405 |
+
|
| 406 |
+
# Placeholder - in production would use Playwright
|
| 407 |
+
self._current_url = url
|
| 408 |
+
self._navigation_history.append(url)
|
| 409 |
+
self._page_title = f"Page at {url}"
|
| 410 |
+
self._page_html = f"<html><body><h1>Mock page for {url}</h1></body></html>"
|
| 411 |
+
|
| 412 |
+
return {"success": True, "url": url}
|
| 413 |
+
|
| 414 |
+
async def _execute_click(self, action: Action) -> dict[str, Any]:
|
| 415 |
+
"""Execute a click action."""
|
| 416 |
+
selector = action.get_param("selector")
|
| 417 |
+
if not selector:
|
| 418 |
+
return {"success": False, "error": "Selector is required"}
|
| 419 |
+
|
| 420 |
+
# Placeholder
|
| 421 |
+
return {"success": True, "selector": selector, "clicked": True}
|
| 422 |
+
|
| 423 |
+
async def _execute_fill(self, action: Action) -> dict[str, Any]:
|
| 424 |
+
"""Execute a fill action."""
|
| 425 |
+
selector = action.get_param("selector")
|
| 426 |
+
value = action.get_param("value")
|
| 427 |
+
|
| 428 |
+
if not selector or value is None:
|
| 429 |
+
return {"success": False, "error": "Selector and value are required"}
|
| 430 |
+
|
| 431 |
+
# Placeholder
|
| 432 |
+
return {"success": True, "selector": selector, "filled": True}
|
| 433 |
+
|
| 434 |
+
async def _execute_extract(self, action: Action) -> dict[str, Any]:
|
| 435 |
+
"""Execute an extract action."""
|
| 436 |
+
field_name = action.get_param("field_name")
|
| 437 |
+
if not field_name:
|
| 438 |
+
return {"success": False, "error": "field_name is required"}
|
| 439 |
+
|
| 440 |
+
# Placeholder - in production would actually extract from page
|
| 441 |
+
extracted_field = ExtractedField(
|
| 442 |
+
field_name=field_name,
|
| 443 |
+
value=f"mock_value_for_{field_name}",
|
| 444 |
+
confidence=0.9,
|
| 445 |
+
source_selector=action.get_param("selector"),
|
| 446 |
+
extraction_step=self._episode.current_step if self._episode else 0,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
self._extracted_fields.append(extracted_field)
|
| 450 |
+
|
| 451 |
+
return {
|
| 452 |
+
"success": True,
|
| 453 |
+
"field_name": field_name,
|
| 454 |
+
"value": extracted_field.value,
|
| 455 |
+
"confidence": extracted_field.confidence,
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
async def _execute_search_engine(self, action: Action) -> dict[str, Any]:
|
| 459 |
+
"""Execute a search engine action."""
|
| 460 |
+
query = action.get_param("query")
|
| 461 |
+
if not query:
|
| 462 |
+
return {"success": False, "error": "Query is required"}
|
| 463 |
+
|
| 464 |
+
engine = action.get_param("engine", "google")
|
| 465 |
+
|
| 466 |
+
# Placeholder
|
| 467 |
+
return {
|
| 468 |
+
"success": True,
|
| 469 |
+
"query": query,
|
| 470 |
+
"engine": engine,
|
| 471 |
+
"results": [
|
| 472 |
+
{"title": f"Result 1 for {query}", "url": "https://example.com/1"},
|
| 473 |
+
{"title": f"Result 2 for {query}", "url": "https://example.com/2"},
|
| 474 |
+
],
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
async def _execute_wait(self, action: Action) -> None:
|
| 478 |
+
"""Execute a wait action."""
|
| 479 |
+
import asyncio
|
| 480 |
+
duration_ms = action.get_param("duration_ms", 1000)
|
| 481 |
+
await asyncio.sleep(duration_ms / 1000)
|
| 482 |
+
|
| 483 |
+
def _check_terminated(self, action: Action) -> bool:
|
| 484 |
+
"""Check if the episode should terminate."""
|
| 485 |
+
if action.action_type == ActionType.DONE:
|
| 486 |
+
return True
|
| 487 |
+
if action.action_type == ActionType.FAIL:
|
| 488 |
+
return True
|
| 489 |
+
return False
|
| 490 |
+
|
| 491 |
+
def _check_truncated(self) -> bool:
|
| 492 |
+
"""Check if the episode should be truncated."""
|
| 493 |
+
if self._episode is None:
|
| 494 |
+
return False
|
| 495 |
+
if self._episode.current_step >= self._episode.max_steps:
|
| 496 |
+
return True
|
| 497 |
+
return False
|
backend/app/core/episode.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Episode state machine and management."""
|
| 2 |
+
|
| 3 |
+
from datetime import datetime, timezone
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class EpisodeStatus(str, Enum):
|
| 11 |
+
"""Status of an episode."""
|
| 12 |
+
|
| 13 |
+
PENDING = "pending"
|
| 14 |
+
RUNNING = "running"
|
| 15 |
+
COMPLETED = "completed"
|
| 16 |
+
FAILED = "failed"
|
| 17 |
+
TRUNCATED = "truncated"
|
| 18 |
+
CANCELLED = "cancelled"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class EpisodeStep(BaseModel):
|
| 22 |
+
"""Record of a single step in the episode."""
|
| 23 |
+
|
| 24 |
+
step_number: int
|
| 25 |
+
timestamp: str
|
| 26 |
+
action_type: str
|
| 27 |
+
action_params: dict[str, Any]
|
| 28 |
+
action_reasoning: str | None = None
|
| 29 |
+
reward: float
|
| 30 |
+
reward_breakdown: dict[str, float]
|
| 31 |
+
observation_summary: dict[str, Any]
|
| 32 |
+
error: str | None = None
|
| 33 |
+
duration_ms: float = 0.0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Episode(BaseModel):
|
| 37 |
+
"""
|
| 38 |
+
Represents a complete episode in the RL environment.
|
| 39 |
+
|
| 40 |
+
An episode is a sequence of steps from reset to termination,
|
| 41 |
+
tracking all actions, rewards, and observations.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
# Identification
|
| 45 |
+
episode_id: str
|
| 46 |
+
task_id: str
|
| 47 |
+
|
| 48 |
+
# Timing
|
| 49 |
+
created_at: str = Field(
|
| 50 |
+
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
| 51 |
+
)
|
| 52 |
+
started_at: str | None = None
|
| 53 |
+
ended_at: str | None = None
|
| 54 |
+
|
| 55 |
+
# State
|
| 56 |
+
status: EpisodeStatus = EpisodeStatus.PENDING
|
| 57 |
+
current_step: int = 0
|
| 58 |
+
max_steps: int = 50
|
| 59 |
+
|
| 60 |
+
# Seed for reproducibility
|
| 61 |
+
seed: int | None = None
|
| 62 |
+
|
| 63 |
+
# Configuration
|
| 64 |
+
config: dict[str, Any] = Field(default_factory=dict)
|
| 65 |
+
|
| 66 |
+
# Step history
|
| 67 |
+
steps: list[EpisodeStep] = Field(default_factory=list)
|
| 68 |
+
|
| 69 |
+
# Aggregates
|
| 70 |
+
total_reward: float = 0.0
|
| 71 |
+
tokens_used: int = 0
|
| 72 |
+
api_calls: int = 0
|
| 73 |
+
estimated_cost_usd: float = 0.0
|
| 74 |
+
|
| 75 |
+
# Results
|
| 76 |
+
extracted_data: dict[str, Any] = Field(default_factory=dict)
|
| 77 |
+
final_accuracy: float | None = None
|
| 78 |
+
success: bool | None = None
|
| 79 |
+
failure_reason: str | None = None
|
| 80 |
+
|
| 81 |
+
# Navigation history
|
| 82 |
+
urls_visited: list[str] = Field(default_factory=list)
|
| 83 |
+
|
| 84 |
+
def start(self) -> None:
|
| 85 |
+
"""Mark the episode as started."""
|
| 86 |
+
self.status = EpisodeStatus.RUNNING
|
| 87 |
+
self.started_at = datetime.now(timezone.utc).isoformat()
|
| 88 |
+
|
| 89 |
+
def add_step(
|
| 90 |
+
self,
|
| 91 |
+
action_type: str,
|
| 92 |
+
action_params: dict[str, Any],
|
| 93 |
+
reward: float,
|
| 94 |
+
reward_breakdown: dict[str, float],
|
| 95 |
+
observation_summary: dict[str, Any],
|
| 96 |
+
action_reasoning: str | None = None,
|
| 97 |
+
error: str | None = None,
|
| 98 |
+
duration_ms: float = 0.0,
|
| 99 |
+
) -> EpisodeStep:
|
| 100 |
+
"""Add a step to the episode."""
|
| 101 |
+
self.current_step += 1
|
| 102 |
+
|
| 103 |
+
step = EpisodeStep(
|
| 104 |
+
step_number=self.current_step,
|
| 105 |
+
timestamp=datetime.now(timezone.utc).isoformat(),
|
| 106 |
+
action_type=action_type,
|
| 107 |
+
action_params=action_params,
|
| 108 |
+
action_reasoning=action_reasoning,
|
| 109 |
+
reward=reward,
|
| 110 |
+
reward_breakdown=reward_breakdown,
|
| 111 |
+
observation_summary=observation_summary,
|
| 112 |
+
error=error,
|
| 113 |
+
duration_ms=duration_ms,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
self.steps.append(step)
|
| 117 |
+
self.total_reward += reward
|
| 118 |
+
|
| 119 |
+
return step
|
| 120 |
+
|
| 121 |
+
def complete(
|
| 122 |
+
self,
|
| 123 |
+
success: bool,
|
| 124 |
+
extracted_data: dict[str, Any] | None = None,
|
| 125 |
+
final_accuracy: float | None = None,
|
| 126 |
+
) -> None:
|
| 127 |
+
"""Mark the episode as completed."""
|
| 128 |
+
self.status = EpisodeStatus.COMPLETED
|
| 129 |
+
self.ended_at = datetime.now(timezone.utc).isoformat()
|
| 130 |
+
self.success = success
|
| 131 |
+
if extracted_data:
|
| 132 |
+
self.extracted_data = extracted_data
|
| 133 |
+
self.final_accuracy = final_accuracy
|
| 134 |
+
|
| 135 |
+
def fail(self, reason: str) -> None:
|
| 136 |
+
"""Mark the episode as failed."""
|
| 137 |
+
self.status = EpisodeStatus.FAILED
|
| 138 |
+
self.ended_at = datetime.now(timezone.utc).isoformat()
|
| 139 |
+
self.success = False
|
| 140 |
+
self.failure_reason = reason
|
| 141 |
+
|
| 142 |
+
def truncate(self, reason: str = "max_steps_reached") -> None:
|
| 143 |
+
"""Mark the episode as truncated (stopped early)."""
|
| 144 |
+
self.status = EpisodeStatus.TRUNCATED
|
| 145 |
+
self.ended_at = datetime.now(timezone.utc).isoformat()
|
| 146 |
+
self.failure_reason = reason
|
| 147 |
+
|
| 148 |
+
def cancel(self) -> None:
|
| 149 |
+
"""Mark the episode as cancelled."""
|
| 150 |
+
self.status = EpisodeStatus.CANCELLED
|
| 151 |
+
self.ended_at = datetime.now(timezone.utc).isoformat()
|
| 152 |
+
|
| 153 |
+
@property
|
| 154 |
+
def is_terminal(self) -> bool:
|
| 155 |
+
"""Check if the episode has terminated."""
|
| 156 |
+
return self.status in [
|
| 157 |
+
EpisodeStatus.COMPLETED,
|
| 158 |
+
EpisodeStatus.FAILED,
|
| 159 |
+
EpisodeStatus.TRUNCATED,
|
| 160 |
+
EpisodeStatus.CANCELLED,
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def duration_seconds(self) -> float | None:
|
| 165 |
+
"""Get episode duration in seconds."""
|
| 166 |
+
if not self.started_at:
|
| 167 |
+
return None
|
| 168 |
+
end = self.ended_at or datetime.now(timezone.utc).isoformat()
|
| 169 |
+
start_dt = datetime.fromisoformat(self.started_at.replace("Z", "+00:00"))
|
| 170 |
+
end_dt = datetime.fromisoformat(end.replace("Z", "+00:00"))
|
| 171 |
+
return (end_dt - start_dt).total_seconds()
|
| 172 |
+
|
| 173 |
+
@property
|
| 174 |
+
def average_reward(self) -> float:
|
| 175 |
+
"""Get average reward per step."""
|
| 176 |
+
if not self.steps:
|
| 177 |
+
return 0.0
|
| 178 |
+
return self.total_reward / len(self.steps)
|
| 179 |
+
|
| 180 |
+
def get_summary(self) -> dict[str, Any]:
|
| 181 |
+
"""Get a summary of the episode."""
|
| 182 |
+
return {
|
| 183 |
+
"episode_id": self.episode_id,
|
| 184 |
+
"task_id": self.task_id,
|
| 185 |
+
"status": self.status.value,
|
| 186 |
+
"steps": self.current_step,
|
| 187 |
+
"total_reward": self.total_reward,
|
| 188 |
+
"average_reward": self.average_reward,
|
| 189 |
+
"duration_seconds": self.duration_seconds,
|
| 190 |
+
"tokens_used": self.tokens_used,
|
| 191 |
+
"estimated_cost_usd": self.estimated_cost_usd,
|
| 192 |
+
"success": self.success,
|
| 193 |
+
"fields_extracted": len(self.extracted_data),
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
def get_step_history(
|
| 197 |
+
self,
|
| 198 |
+
start: int = 0,
|
| 199 |
+
end: int | None = None,
|
| 200 |
+
) -> list[EpisodeStep]:
|
| 201 |
+
"""Get a slice of the step history."""
|
| 202 |
+
return self.steps[start:end]
|
| 203 |
+
|
| 204 |
+
def get_action_sequence(self) -> list[str]:
|
| 205 |
+
"""Get the sequence of action types taken."""
|
| 206 |
+
return [step.action_type for step in self.steps]
|
| 207 |
+
|
| 208 |
+
def get_reward_history(self) -> list[float]:
|
| 209 |
+
"""Get the sequence of rewards received."""
|
| 210 |
+
return [step.reward for step in self.steps]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class EpisodeManager:
|
| 214 |
+
"""Manager for episode lifecycle."""
|
| 215 |
+
|
| 216 |
+
def __init__(self) -> None:
|
| 217 |
+
"""Initialize the episode manager."""
|
| 218 |
+
self._episodes: dict[str, Episode] = {}
|
| 219 |
+
|
| 220 |
+
def create_episode(
|
| 221 |
+
self,
|
| 222 |
+
episode_id: str,
|
| 223 |
+
task_id: str,
|
| 224 |
+
max_steps: int = 50,
|
| 225 |
+
seed: int | None = None,
|
| 226 |
+
config: dict[str, Any] | None = None,
|
| 227 |
+
) -> Episode:
|
| 228 |
+
"""Create a new episode."""
|
| 229 |
+
episode = Episode(
|
| 230 |
+
episode_id=episode_id,
|
| 231 |
+
task_id=task_id,
|
| 232 |
+
max_steps=max_steps,
|
| 233 |
+
seed=seed,
|
| 234 |
+
config=config or {},
|
| 235 |
+
)
|
| 236 |
+
self._episodes[episode_id] = episode
|
| 237 |
+
return episode
|
| 238 |
+
|
| 239 |
+
def get_episode(self, episode_id: str) -> Episode | None:
|
| 240 |
+
"""Get an episode by ID."""
|
| 241 |
+
return self._episodes.get(episode_id)
|
| 242 |
+
|
| 243 |
+
def remove_episode(self, episode_id: str) -> bool:
|
| 244 |
+
"""Remove an episode."""
|
| 245 |
+
if episode_id in self._episodes:
|
| 246 |
+
del self._episodes[episode_id]
|
| 247 |
+
return True
|
| 248 |
+
return False
|
| 249 |
+
|
| 250 |
+
def list_episodes(
|
| 251 |
+
self,
|
| 252 |
+
status: EpisodeStatus | None = None,
|
| 253 |
+
task_id: str | None = None,
|
| 254 |
+
) -> list[Episode]:
|
| 255 |
+
"""List episodes with optional filtering."""
|
| 256 |
+
episodes = list(self._episodes.values())
|
| 257 |
+
if status:
|
| 258 |
+
episodes = [e for e in episodes if e.status == status]
|
| 259 |
+
if task_id:
|
| 260 |
+
episodes = [e for e in episodes if e.task_id == task_id]
|
| 261 |
+
return episodes
|
backend/app/core/observation.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Observation model for the RL environment."""
|
| 2 |
+
|
| 3 |
+
from datetime import datetime, timezone
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ToolSnapshot(BaseModel):
|
| 10 |
+
"""Snapshot of a tool from the registry."""
|
| 11 |
+
|
| 12 |
+
name: str
|
| 13 |
+
description: str
|
| 14 |
+
parameters: list[dict[str, Any]]
|
| 15 |
+
enabled: bool = True
|
| 16 |
+
cost_estimate: float = 0.0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MemoryContext(BaseModel):
|
| 20 |
+
"""Context from memory systems."""
|
| 21 |
+
|
| 22 |
+
short_term: list[dict[str, Any]] = Field(default_factory=list)
|
| 23 |
+
working: list[dict[str, Any]] = Field(default_factory=list)
|
| 24 |
+
long_term_relevant: list[dict[str, Any]] = Field(default_factory=list)
|
| 25 |
+
shared: dict[str, Any] = Field(default_factory=dict)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PageElement(BaseModel):
|
| 29 |
+
"""A significant element on the page."""
|
| 30 |
+
|
| 31 |
+
selector: str
|
| 32 |
+
tag: str
|
| 33 |
+
text: str | None = None
|
| 34 |
+
attributes: dict[str, str] = Field(default_factory=dict)
|
| 35 |
+
is_interactive: bool = False
|
| 36 |
+
is_visible: bool = True
|
| 37 |
+
bounding_box: dict[str, float] | None = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ExtractedField(BaseModel):
|
| 41 |
+
"""A field that has been extracted."""
|
| 42 |
+
|
| 43 |
+
field_name: str
|
| 44 |
+
value: Any
|
| 45 |
+
confidence: float = 1.0
|
| 46 |
+
source_selector: str | None = None
|
| 47 |
+
extraction_step: int = 0
|
| 48 |
+
verified: bool = False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AvailableAction(BaseModel):
|
| 52 |
+
"""An action that is currently available."""
|
| 53 |
+
|
| 54 |
+
action_type: str
|
| 55 |
+
description: str
|
| 56 |
+
parameters: dict[str, Any] = Field(default_factory=dict)
|
| 57 |
+
estimated_reward: float | None = None
|
| 58 |
+
risk_level: str = "low"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TaskContext(BaseModel):
|
| 62 |
+
"""Context about the current task."""
|
| 63 |
+
|
| 64 |
+
task_id: str
|
| 65 |
+
task_name: str
|
| 66 |
+
task_type: str
|
| 67 |
+
target_fields: list[str]
|
| 68 |
+
required_fields: list[str]
|
| 69 |
+
hints: list[str] = Field(default_factory=list)
|
| 70 |
+
success_criteria: dict[str, Any] = Field(default_factory=dict)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Observation(BaseModel):
|
| 74 |
+
"""
|
| 75 |
+
Complete observation provided to the agent after each step.
|
| 76 |
+
|
| 77 |
+
Contains all information the agent needs to make decisions:
|
| 78 |
+
- Episode and task context
|
| 79 |
+
- Current page state
|
| 80 |
+
- Extracted data so far
|
| 81 |
+
- Memory context
|
| 82 |
+
- Available tools and actions
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
# Episode identification
|
| 86 |
+
episode_id: str = Field(..., description="Unique episode identifier")
|
| 87 |
+
task_id: str = Field(..., description="Task being executed")
|
| 88 |
+
step_number: int = Field(..., description="Current step in the episode")
|
| 89 |
+
|
| 90 |
+
# Timing
|
| 91 |
+
timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
| 92 |
+
elapsed_seconds: float = Field(default=0.0, description="Time elapsed in episode")
|
| 93 |
+
|
| 94 |
+
# Page state
|
| 95 |
+
current_url: str | None = Field(default=None, description="Current page URL")
|
| 96 |
+
page_title: str | None = Field(default=None, description="Current page title")
|
| 97 |
+
page_html: str | None = Field(default=None, description="Full HTML of current page")
|
| 98 |
+
page_html_chunked: list[str] = Field(
|
| 99 |
+
default_factory=list,
|
| 100 |
+
description="HTML split into semantic chunks",
|
| 101 |
+
)
|
| 102 |
+
page_text: str | None = Field(default=None, description="Visible text content")
|
| 103 |
+
page_elements: list[PageElement] = Field(
|
| 104 |
+
default_factory=list,
|
| 105 |
+
description="Significant page elements",
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Navigation state
|
| 109 |
+
navigation_history: list[str] = Field(
|
| 110 |
+
default_factory=list,
|
| 111 |
+
description="URLs visited in this episode",
|
| 112 |
+
)
|
| 113 |
+
can_go_back: bool = Field(default=False)
|
| 114 |
+
can_go_forward: bool = Field(default=False)
|
| 115 |
+
|
| 116 |
+
# Task context
|
| 117 |
+
task_context: TaskContext | None = Field(
|
| 118 |
+
default=None,
|
| 119 |
+
description="Information about the current task",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Extraction state
|
| 123 |
+
extracted_so_far: list[ExtractedField] = Field(
|
| 124 |
+
default_factory=list,
|
| 125 |
+
description="Fields extracted so far",
|
| 126 |
+
)
|
| 127 |
+
extraction_progress: float = Field(
|
| 128 |
+
default=0.0,
|
| 129 |
+
description="Progress towards task completion (0-1)",
|
| 130 |
+
)
|
| 131 |
+
fields_remaining: list[str] = Field(
|
| 132 |
+
default_factory=list,
|
| 133 |
+
description="Fields still to be extracted",
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Memory context
|
| 137 |
+
memory_context: MemoryContext = Field(
|
| 138 |
+
default_factory=MemoryContext,
|
| 139 |
+
description="Relevant memories from all layers",
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Tool registry snapshot
|
| 143 |
+
tool_registry_snapshot: list[ToolSnapshot] = Field(
|
| 144 |
+
default_factory=list,
|
| 145 |
+
description="Available tools and their state",
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Available actions
|
| 149 |
+
available_actions: list[AvailableAction] = Field(
|
| 150 |
+
default_factory=list,
|
| 151 |
+
description="Actions available in current state",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Agent coordination
|
| 155 |
+
pending_messages: list[dict[str, Any]] = Field(
|
| 156 |
+
default_factory=list,
|
| 157 |
+
description="Messages from other agents",
|
| 158 |
+
)
|
| 159 |
+
active_plan: dict[str, Any] | None = Field(
|
| 160 |
+
default=None,
|
| 161 |
+
description="Current execution plan if any",
|
| 162 |
+
)
|
| 163 |
+
current_plan_step: int | None = Field(
|
| 164 |
+
default=None,
|
| 165 |
+
description="Current step in the plan",
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Error state
|
| 169 |
+
last_action_error: str | None = Field(
|
| 170 |
+
default=None,
|
| 171 |
+
description="Error from last action if any",
|
| 172 |
+
)
|
| 173 |
+
consecutive_errors: int = Field(
|
| 174 |
+
default=0,
|
| 175 |
+
description="Number of consecutive action errors",
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Cost tracking
|
| 179 |
+
tokens_used: int = Field(default=0, description="LLM tokens used so far")
|
| 180 |
+
api_calls_made: int = Field(default=0, description="API calls made")
|
| 181 |
+
estimated_cost_usd: float = Field(default=0.0, description="Estimated cost so far")
|
| 182 |
+
|
| 183 |
+
# Hints and guidance
|
| 184 |
+
system_hints: list[str] = Field(
|
| 185 |
+
default_factory=list,
|
| 186 |
+
description="Hints from the environment or previous steps",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
model_config = ConfigDict(
|
| 190 |
+
json_schema_extra={
|
| 191 |
+
"example": {
|
| 192 |
+
"episode_id": "ep_abc123",
|
| 193 |
+
"task_id": "task_001",
|
| 194 |
+
"step_number": 5,
|
| 195 |
+
"current_url": "https://example.com/product/123",
|
| 196 |
+
"page_title": "Product Details - Example Store",
|
| 197 |
+
"extracted_so_far": [
|
| 198 |
+
{
|
| 199 |
+
"field_name": "product_name",
|
| 200 |
+
"value": "Example Product",
|
| 201 |
+
"confidence": 0.95,
|
| 202 |
+
}
|
| 203 |
+
],
|
| 204 |
+
"extraction_progress": 0.33,
|
| 205 |
+
"fields_remaining": ["price", "description"],
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
def get_extraction_dict(self) -> dict[str, Any]:
|
| 211 |
+
"""Get extracted fields as a dictionary."""
|
| 212 |
+
return {field.field_name: field.value for field in self.extracted_so_far}
|
| 213 |
+
|
| 214 |
+
def is_field_extracted(self, field_name: str) -> bool:
|
| 215 |
+
"""Check if a field has been extracted."""
|
| 216 |
+
return any(f.field_name == field_name for f in self.extracted_so_far)
|
| 217 |
+
|
| 218 |
+
def get_context_summary(self) -> str:
|
| 219 |
+
"""Get a summary of the current context for LLM prompts."""
|
| 220 |
+
parts = [
|
| 221 |
+
f"Step {self.step_number}",
|
| 222 |
+
f"URL: {self.current_url or 'None'}",
|
| 223 |
+
f"Progress: {self.extraction_progress:.0%}",
|
| 224 |
+
f"Extracted: {len(self.extracted_so_far)}/{len(self.extracted_so_far) + len(self.fields_remaining)} fields",
|
| 225 |
+
]
|
| 226 |
+
if self.last_action_error:
|
| 227 |
+
parts.append(f"Last error: {self.last_action_error}")
|
| 228 |
+
return " | ".join(parts)
|
backend/app/core/reward.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward computation engine with component breakdown."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from app.config import Settings, get_settings
|
| 7 |
+
from app.core.action import Action, ActionType
|
| 8 |
+
from app.core.observation import Observation
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class RewardBreakdown:
|
| 13 |
+
"""Detailed breakdown of reward components."""
|
| 14 |
+
|
| 15 |
+
# Core components
|
| 16 |
+
accuracy: float = 0.0
|
| 17 |
+
efficiency: float = 0.0
|
| 18 |
+
cost: float = 0.0
|
| 19 |
+
completeness: float = 0.0
|
| 20 |
+
|
| 21 |
+
# Bonus/penalty components
|
| 22 |
+
progress_bonus: float = 0.0
|
| 23 |
+
error_penalty: float = 0.0
|
| 24 |
+
time_penalty: float = 0.0
|
| 25 |
+
redundancy_penalty: float = 0.0
|
| 26 |
+
exploration_bonus: float = 0.0
|
| 27 |
+
verification_bonus: float = 0.0
|
| 28 |
+
|
| 29 |
+
# Metadata
|
| 30 |
+
total: float = 0.0
|
| 31 |
+
components: dict[str, float] = field(default_factory=dict)
|
| 32 |
+
|
| 33 |
+
def compute_total(self, weights: dict[str, float]) -> float:
|
| 34 |
+
"""Compute total reward with weights."""
|
| 35 |
+
self.total = (
|
| 36 |
+
self.accuracy * weights.get("accuracy", 0.4)
|
| 37 |
+
+ self.efficiency * weights.get("efficiency", 0.2)
|
| 38 |
+
+ self.cost * weights.get("cost", 0.2)
|
| 39 |
+
+ self.completeness * weights.get("completeness", 0.2)
|
| 40 |
+
+ self.progress_bonus
|
| 41 |
+
+ self.exploration_bonus
|
| 42 |
+
+ self.verification_bonus
|
| 43 |
+
- self.error_penalty
|
| 44 |
+
- self.time_penalty
|
| 45 |
+
- self.redundancy_penalty
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.components = {
|
| 49 |
+
"accuracy": self.accuracy,
|
| 50 |
+
"efficiency": self.efficiency,
|
| 51 |
+
"cost": self.cost,
|
| 52 |
+
"completeness": self.completeness,
|
| 53 |
+
"progress_bonus": self.progress_bonus,
|
| 54 |
+
"error_penalty": self.error_penalty,
|
| 55 |
+
"time_penalty": self.time_penalty,
|
| 56 |
+
"redundancy_penalty": self.redundancy_penalty,
|
| 57 |
+
"exploration_bonus": self.exploration_bonus,
|
| 58 |
+
"verification_bonus": self.verification_bonus,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
return self.total
|
| 62 |
+
|
| 63 |
+
def to_dict(self) -> dict[str, float]:
|
| 64 |
+
"""Convert to dictionary."""
|
| 65 |
+
return {
|
| 66 |
+
"total": self.total,
|
| 67 |
+
**self.components,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class RewardEngine:
|
| 72 |
+
"""
|
| 73 |
+
Computes rewards for actions in the web scraping environment.
|
| 74 |
+
|
| 75 |
+
Reward components:
|
| 76 |
+
- Accuracy: How correct extracted data is
|
| 77 |
+
- Efficiency: Steps taken vs optimal
|
| 78 |
+
- Cost: API/compute costs
|
| 79 |
+
- Completeness: Progress towards task completion
|
| 80 |
+
|
| 81 |
+
Plus bonuses/penalties for:
|
| 82 |
+
- Progress: Making progress towards goal
|
| 83 |
+
- Errors: Failed actions or invalid extractions
|
| 84 |
+
- Time: Taking too long
|
| 85 |
+
- Redundancy: Repeating unsuccessful actions
|
| 86 |
+
- Exploration: Discovering new information
|
| 87 |
+
- Verification: Validating extracted data
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, settings: Settings | None = None) -> None:
|
| 91 |
+
"""Initialize the reward engine."""
|
| 92 |
+
self.settings = settings or get_settings()
|
| 93 |
+
self.weights = {
|
| 94 |
+
"accuracy": self.settings.reward_accuracy_weight,
|
| 95 |
+
"efficiency": self.settings.reward_efficiency_weight,
|
| 96 |
+
"cost": self.settings.reward_cost_weight,
|
| 97 |
+
"completeness": self.settings.reward_completeness_weight,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
# Tracking for penalties
|
| 101 |
+
self._action_history: list[Action] = []
|
| 102 |
+
self._extraction_attempts: dict[str, int] = {}
|
| 103 |
+
self._url_visits: dict[str, int] = {}
|
| 104 |
+
|
| 105 |
+
def reset(self) -> None:
|
| 106 |
+
"""Reset tracking state for a new episode."""
|
| 107 |
+
self._action_history.clear()
|
| 108 |
+
self._extraction_attempts.clear()
|
| 109 |
+
self._url_visits.clear()
|
| 110 |
+
|
| 111 |
+
def compute_reward(
|
| 112 |
+
self,
|
| 113 |
+
action: Action,
|
| 114 |
+
prev_observation: Observation,
|
| 115 |
+
new_observation: Observation,
|
| 116 |
+
ground_truth: dict[str, Any] | None = None,
|
| 117 |
+
max_steps: int = 50,
|
| 118 |
+
) -> tuple[float, RewardBreakdown]:
|
| 119 |
+
"""
|
| 120 |
+
Compute reward for an action.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
action: The action that was taken.
|
| 124 |
+
prev_observation: Observation before the action.
|
| 125 |
+
new_observation: Observation after the action.
|
| 126 |
+
ground_truth: Optional ground truth data for accuracy calculation.
|
| 127 |
+
max_steps: Maximum steps allowed in episode.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Tuple of (total_reward, breakdown).
|
| 131 |
+
"""
|
| 132 |
+
breakdown = RewardBreakdown()
|
| 133 |
+
|
| 134 |
+
# Track action
|
| 135 |
+
self._action_history.append(action)
|
| 136 |
+
|
| 137 |
+
# Compute accuracy component
|
| 138 |
+
breakdown.accuracy = self._compute_accuracy(
|
| 139 |
+
action, new_observation, ground_truth
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Compute efficiency component
|
| 143 |
+
breakdown.efficiency = self._compute_efficiency(
|
| 144 |
+
new_observation.step_number, max_steps
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Compute cost component
|
| 148 |
+
breakdown.cost = self._compute_cost_reward(new_observation)
|
| 149 |
+
|
| 150 |
+
# Compute completeness component
|
| 151 |
+
breakdown.completeness = self._compute_completeness(
|
| 152 |
+
prev_observation, new_observation
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Compute bonuses
|
| 156 |
+
breakdown.progress_bonus = self._compute_progress_bonus(
|
| 157 |
+
prev_observation, new_observation
|
| 158 |
+
)
|
| 159 |
+
breakdown.exploration_bonus = self._compute_exploration_bonus(
|
| 160 |
+
action, new_observation
|
| 161 |
+
)
|
| 162 |
+
breakdown.verification_bonus = self._compute_verification_bonus(
|
| 163 |
+
action, new_observation
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# Compute penalties
|
| 167 |
+
breakdown.error_penalty = self._compute_error_penalty(new_observation)
|
| 168 |
+
breakdown.time_penalty = self._compute_time_penalty(new_observation, max_steps)
|
| 169 |
+
breakdown.redundancy_penalty = self._compute_redundancy_penalty(action)
|
| 170 |
+
|
| 171 |
+
# Compute total
|
| 172 |
+
total = breakdown.compute_total(self.weights)
|
| 173 |
+
|
| 174 |
+
return total, breakdown
|
| 175 |
+
|
| 176 |
+
def _compute_accuracy(
|
| 177 |
+
self,
|
| 178 |
+
action: Action,
|
| 179 |
+
observation: Observation,
|
| 180 |
+
ground_truth: dict[str, Any] | None,
|
| 181 |
+
) -> float:
|
| 182 |
+
"""Compute accuracy reward component."""
|
| 183 |
+
if ground_truth is None:
|
| 184 |
+
# Without ground truth, use confidence scores
|
| 185 |
+
if observation.extracted_so_far:
|
| 186 |
+
avg_confidence = sum(
|
| 187 |
+
f.confidence for f in observation.extracted_so_far
|
| 188 |
+
) / len(observation.extracted_so_far)
|
| 189 |
+
return avg_confidence
|
| 190 |
+
return 0.5 # Neutral
|
| 191 |
+
|
| 192 |
+
# With ground truth, compute actual accuracy
|
| 193 |
+
extracted = observation.get_extraction_dict()
|
| 194 |
+
if not extracted:
|
| 195 |
+
return 0.0
|
| 196 |
+
|
| 197 |
+
correct = 0
|
| 198 |
+
total = 0
|
| 199 |
+
for field_name, expected_value in ground_truth.items():
|
| 200 |
+
if field_name in extracted:
|
| 201 |
+
total += 1
|
| 202 |
+
actual_value = extracted[field_name]
|
| 203 |
+
if self._values_match(actual_value, expected_value):
|
| 204 |
+
correct += 1
|
| 205 |
+
|
| 206 |
+
if total == 0:
|
| 207 |
+
return 0.0
|
| 208 |
+
|
| 209 |
+
return correct / total
|
| 210 |
+
|
| 211 |
+
def _values_match(self, actual: Any, expected: Any) -> bool:
|
| 212 |
+
"""Check if extracted value matches expected value."""
|
| 213 |
+
if actual == expected:
|
| 214 |
+
return True
|
| 215 |
+
|
| 216 |
+
# Fuzzy matching for strings
|
| 217 |
+
if isinstance(actual, str) and isinstance(expected, str):
|
| 218 |
+
actual_clean = actual.strip().lower()
|
| 219 |
+
expected_clean = expected.strip().lower()
|
| 220 |
+
if actual_clean == expected_clean:
|
| 221 |
+
return True
|
| 222 |
+
# Partial match
|
| 223 |
+
if expected_clean in actual_clean or actual_clean in expected_clean:
|
| 224 |
+
return True
|
| 225 |
+
|
| 226 |
+
# Numeric comparison with tolerance
|
| 227 |
+
if isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
|
| 228 |
+
tolerance = abs(expected) * 0.01 if expected != 0 else 0.01
|
| 229 |
+
return abs(actual - expected) <= tolerance
|
| 230 |
+
|
| 231 |
+
return False
|
| 232 |
+
|
| 233 |
+
def _compute_efficiency(self, current_step: int, max_steps: int) -> float:
|
| 234 |
+
"""Compute efficiency based on steps taken."""
|
| 235 |
+
# Higher reward for completing tasks in fewer steps
|
| 236 |
+
remaining_ratio = (max_steps - current_step) / max_steps
|
| 237 |
+
return max(0.0, remaining_ratio)
|
| 238 |
+
|
| 239 |
+
def _compute_cost_reward(self, observation: Observation) -> float:
|
| 240 |
+
"""Compute reward based on cost efficiency."""
|
| 241 |
+
# Penalize high token usage and API calls
|
| 242 |
+
max_expected_tokens = 10000
|
| 243 |
+
max_expected_calls = 50
|
| 244 |
+
|
| 245 |
+
token_efficiency = 1.0 - min(
|
| 246 |
+
observation.tokens_used / max_expected_tokens, 1.0
|
| 247 |
+
)
|
| 248 |
+
call_efficiency = 1.0 - min(
|
| 249 |
+
observation.api_calls_made / max_expected_calls, 1.0
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
return (token_efficiency + call_efficiency) / 2
|
| 253 |
+
|
| 254 |
+
def _compute_completeness(
|
| 255 |
+
self,
|
| 256 |
+
prev_observation: Observation,
|
| 257 |
+
new_observation: Observation,
|
| 258 |
+
) -> float:
|
| 259 |
+
"""Compute completeness based on extraction progress."""
|
| 260 |
+
return new_observation.extraction_progress
|
| 261 |
+
|
| 262 |
+
def _compute_progress_bonus(
|
| 263 |
+
self,
|
| 264 |
+
prev_observation: Observation,
|
| 265 |
+
new_observation: Observation,
|
| 266 |
+
) -> float:
|
| 267 |
+
"""Bonus for making progress."""
|
| 268 |
+
progress_delta = (
|
| 269 |
+
new_observation.extraction_progress - prev_observation.extraction_progress
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Bonus for new extractions
|
| 273 |
+
new_extractions = len(new_observation.extracted_so_far) - len(
|
| 274 |
+
prev_observation.extracted_so_far
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
bonus = 0.0
|
| 278 |
+
if progress_delta > 0:
|
| 279 |
+
bonus += progress_delta * 0.5
|
| 280 |
+
if new_extractions > 0:
|
| 281 |
+
bonus += new_extractions * 0.1
|
| 282 |
+
|
| 283 |
+
return bonus
|
| 284 |
+
|
| 285 |
+
def _compute_exploration_bonus(
|
| 286 |
+
self,
|
| 287 |
+
action: Action,
|
| 288 |
+
observation: Observation,
|
| 289 |
+
) -> float:
|
| 290 |
+
"""Bonus for exploring new pages."""
|
| 291 |
+
bonus = 0.0
|
| 292 |
+
|
| 293 |
+
if action.action_type == ActionType.NAVIGATE:
|
| 294 |
+
url = action.get_param("url", "")
|
| 295 |
+
if url and url not in self._url_visits:
|
| 296 |
+
bonus += 0.05
|
| 297 |
+
self._url_visits[url] = self._url_visits.get(url, 0) + 1
|
| 298 |
+
|
| 299 |
+
return bonus
|
| 300 |
+
|
| 301 |
+
def _compute_verification_bonus(
|
| 302 |
+
self,
|
| 303 |
+
action: Action,
|
| 304 |
+
observation: Observation,
|
| 305 |
+
) -> float:
|
| 306 |
+
"""Bonus for verification actions."""
|
| 307 |
+
if action.action_type in [ActionType.VERIFY_FACT, ActionType.VERIFY_FIELD]:
|
| 308 |
+
return 0.05
|
| 309 |
+
return 0.0
|
| 310 |
+
|
| 311 |
+
def _compute_error_penalty(self, observation: Observation) -> float:
|
| 312 |
+
"""Penalty for errors."""
|
| 313 |
+
if observation.last_action_error:
|
| 314 |
+
base_penalty = 0.1
|
| 315 |
+
consecutive_penalty = observation.consecutive_errors * 0.05
|
| 316 |
+
return base_penalty + consecutive_penalty
|
| 317 |
+
return 0.0
|
| 318 |
+
|
| 319 |
+
def _compute_time_penalty(
|
| 320 |
+
self,
|
| 321 |
+
observation: Observation,
|
| 322 |
+
max_steps: int,
|
| 323 |
+
) -> float:
|
| 324 |
+
"""Penalty for taking too long."""
|
| 325 |
+
step_ratio = observation.step_number / max_steps
|
| 326 |
+
if step_ratio > 0.8:
|
| 327 |
+
return (step_ratio - 0.8) * 0.5
|
| 328 |
+
return 0.0
|
| 329 |
+
|
| 330 |
+
def _compute_redundancy_penalty(self, action: Action) -> float:
|
| 331 |
+
"""Penalty for redundant actions."""
|
| 332 |
+
if len(self._action_history) < 2:
|
| 333 |
+
return 0.0
|
| 334 |
+
|
| 335 |
+
# Check for repeated extract attempts on same field
|
| 336 |
+
if action.action_type == ActionType.EXTRACT_FIELD:
|
| 337 |
+
field = action.get_param("field_name", "")
|
| 338 |
+
attempts = self._extraction_attempts.get(field, 0)
|
| 339 |
+
self._extraction_attempts[field] = attempts + 1
|
| 340 |
+
if attempts > 0:
|
| 341 |
+
return min(attempts * 0.05, 0.2)
|
| 342 |
+
|
| 343 |
+
# Check for repeated navigation to same URL
|
| 344 |
+
if action.action_type == ActionType.NAVIGATE:
|
| 345 |
+
url = action.get_param("url", "")
|
| 346 |
+
visits = self._url_visits.get(url, 0)
|
| 347 |
+
if visits > 1:
|
| 348 |
+
return min((visits - 1) * 0.03, 0.15)
|
| 349 |
+
|
| 350 |
+
return 0.0
|
| 351 |
+
|
| 352 |
+
def compute_terminal_reward(
|
| 353 |
+
self,
|
| 354 |
+
observation: Observation,
|
| 355 |
+
success: bool,
|
| 356 |
+
ground_truth: dict[str, Any] | None = None,
|
| 357 |
+
) -> tuple[float, RewardBreakdown]:
|
| 358 |
+
"""
|
| 359 |
+
Compute final reward at episode termination.
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
observation: Final observation.
|
| 363 |
+
success: Whether the task was completed successfully.
|
| 364 |
+
ground_truth: Optional ground truth for accuracy.
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
Tuple of (total_reward, breakdown).
|
| 368 |
+
"""
|
| 369 |
+
breakdown = RewardBreakdown()
|
| 370 |
+
|
| 371 |
+
if success:
|
| 372 |
+
# Big bonus for successful completion
|
| 373 |
+
breakdown.completeness = 1.0
|
| 374 |
+
breakdown.progress_bonus = 0.5
|
| 375 |
+
|
| 376 |
+
# Compute final accuracy
|
| 377 |
+
if ground_truth:
|
| 378 |
+
extracted = observation.get_extraction_dict()
|
| 379 |
+
correct = sum(
|
| 380 |
+
1 for k, v in ground_truth.items()
|
| 381 |
+
if k in extracted and self._values_match(extracted[k], v)
|
| 382 |
+
)
|
| 383 |
+
total = len(ground_truth)
|
| 384 |
+
breakdown.accuracy = correct / total if total > 0 else 1.0
|
| 385 |
+
else:
|
| 386 |
+
breakdown.accuracy = observation.extraction_progress
|
| 387 |
+
|
| 388 |
+
# Efficiency bonus for fast completion
|
| 389 |
+
breakdown.efficiency = 1.0 - (
|
| 390 |
+
observation.step_number / self.settings.max_steps_per_episode
|
| 391 |
+
)
|
| 392 |
+
else:
|
| 393 |
+
# Partial credit for progress made
|
| 394 |
+
breakdown.completeness = observation.extraction_progress * 0.5
|
| 395 |
+
breakdown.error_penalty = 0.3
|
| 396 |
+
|
| 397 |
+
total = breakdown.compute_total(self.weights)
|
| 398 |
+
return total, breakdown
|