NeerajCodz committed on
Commit
ab65628
·
1 Parent(s): ff3e1be

feat: add core RL environment models (observation, action, reward, env)

Browse files
backend/app/core/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core module - RL environment, observations, actions, and rewards."""
2
+
3
+ from app.core.action import Action, ActionType
4
+ from app.core.env import WebScraperEnv
5
+ from app.core.episode import Episode, EpisodeStatus
6
+ from app.core.observation import Observation
7
+ from app.core.reward import RewardEngine, RewardBreakdown
8
+
9
+ __all__ = [
10
+ "Action",
11
+ "ActionType",
12
+ "WebScraperEnv",
13
+ "Episode",
14
+ "EpisodeStatus",
15
+ "Observation",
16
+ "RewardEngine",
17
+ "RewardBreakdown",
18
+ ]
backend/app/core/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (627 Bytes). View file
 
backend/app/core/__pycache__/action.cpython-314.pyc ADDED
Binary file (20.3 kB). View file
 
backend/app/core/__pycache__/env.cpython-314.pyc ADDED
Binary file (22.6 kB). View file
 
backend/app/core/__pycache__/episode.cpython-314.pyc ADDED
Binary file (16.1 kB). View file
 
backend/app/core/__pycache__/observation.cpython-314.pyc ADDED
Binary file (12.9 kB). View file
 
backend/app/core/__pycache__/reward.cpython-314.pyc ADDED
Binary file (19.3 kB). View file
 
backend/app/core/action.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Action model for the RL environment."""
2
+
3
+ from enum import Enum
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
7
+
8
+
9
class ActionType(str, Enum):
    """Closed set of action identifiers understood by the environment.

    Members are ``str`` subclasses, so each one compares equal to its
    wire value (for example ``ActionType.DONE == "done"``).
    """

    # --- Navigation ---
    NAVIGATE = "navigate"
    GO_BACK = "go_back"
    GO_FORWARD = "go_forward"
    REFRESH = "refresh"

    # --- Page interaction ---
    CLICK = "click"
    FILL = "fill"
    SELECT = "select"
    SCROLL = "scroll"
    HOVER = "hover"

    # --- Data extraction ---
    EXTRACT_FIELD = "extract_field"
    EXTRACT_TABLE = "extract_table"
    EXTRACT_LIST = "extract_list"

    # --- Search ---
    SEARCH_PAGE = "search_page"
    SEARCH_ENGINE = "search_engine"

    # --- Verification ---
    VERIFY_FACT = "verify_fact"
    VERIFY_FIELD = "verify_field"

    # --- Memory ---
    STORE_MEMORY = "store_memory"
    RECALL_MEMORY = "recall_memory"

    # --- Tools ---
    MCP_TOOL_CALL = "mcp_tool_call"

    # --- Planning ---
    CREATE_PLAN = "create_plan"
    UPDATE_PLAN = "update_plan"

    # --- Communication ---
    SEND_MESSAGE = "send_message"

    # --- Control flow ---
    WAIT = "wait"
    DONE = "done"
    FAIL = "fail"
56
+
57
+
58
class NavigateParams(BaseModel):
    """Parameter schema for navigation actions.

    Attributes:
        url: Required target URL.
        wait_for: Optional; ``None`` when not supplied.
        timeout_ms: Defaults to 30000.
    """

    url: str
    wait_for: str | None = None
    timeout_ms: int = 30000
64
+
65
+
66
class ClickParams(BaseModel):
    """Parameter schema for click actions.

    Attributes:
        selector: Required selector of the element to click.
        button: Defaults to ``"left"``.
        click_count: Defaults to 1.
        wait_after_ms: Defaults to 500.
    """

    selector: str
    button: str = "left"
    click_count: int = 1
    wait_after_ms: int = 500
73
+
74
+
75
class FillParams(BaseModel):
    """Parameter schema for form-fill actions.

    Attributes:
        selector: Required selector of the input to fill.
        value: Required text to enter.
        clear_first: Defaults to ``True``.
    """

    selector: str
    value: str
    clear_first: bool = True
81
+
82
+
83
class SelectParams(BaseModel):
    """Parameter schema for select-dropdown actions.

    Attributes:
        selector: Required selector of the dropdown.
        value: Optional; ``None`` when not supplied.
        label: Optional; ``None`` when not supplied.
        index: Optional; ``None`` when not supplied.
    """

    selector: str
    value: str | None = None
    label: str | None = None
    index: int | None = None
90
+
91
+
92
class ScrollParams(BaseModel):
    """Parameter schema for scroll actions.

    Attributes:
        direction: Defaults to ``"down"``.
        amount: Either an integer or a string; defaults to ``"page"``.
        selector: Optional; ``None`` when not supplied.
    """

    direction: str = "down"
    amount: int | str = "page"
    selector: str | None = None
98
+
99
+
100
class ExtractFieldParams(BaseModel):
    """Parameter schema for single-field extraction actions.

    Attributes:
        field_name: Required name of the field being extracted.
        selector: Optional; ``None`` when not supplied.
        extraction_method: Defaults to ``"text"``.
        attribute: Optional; ``None`` when not supplied.
        regex_pattern: Optional; ``None`` when not supplied.
        post_process: Optional; ``None`` when not supplied.
    """

    field_name: str
    selector: str | None = None
    extraction_method: str = "text"
    attribute: str | None = None
    regex_pattern: str | None = None
    post_process: str | None = None
109
+
110
+
111
class ExtractTableParams(BaseModel):
    """Parameter schema for table extraction actions.

    Attributes:
        table_selector: Required selector of the table.
        headers: Optional list of strings; ``None`` when not supplied.
        row_selector: Optional; ``None`` when not supplied.
        cell_selectors: Optional string-to-string mapping; ``None`` when
            not supplied.
    """

    table_selector: str
    headers: list[str] | None = None
    row_selector: str | None = None
    cell_selectors: dict[str, str] | None = None
118
+
119
+
120
class ExtractListParams(BaseModel):
    """Parameter schema for list extraction actions.

    All three fields are required.

    Attributes:
        container_selector: Selector of the list container.
        item_selector: Selector of an individual item.
        field_selectors: String-to-string mapping of selectors.
    """

    container_selector: str
    item_selector: str
    field_selectors: dict[str, str]
126
+
127
+
128
class SearchPageParams(BaseModel):
    """Parameter schema for searching within the current page.

    Attributes:
        query: Required search query.
        search_type: Defaults to ``"text"``.
    """

    query: str
    search_type: str = "text"
133
+
134
+
135
class SearchEngineParams(BaseModel):
    """Parameter schema for search-engine queries.

    Attributes:
        query: Required search query.
        engine: Defaults to ``"google"``.
        num_results: Defaults to 10.
    """

    query: str
    engine: str = "google"
    num_results: int = 10
141
+
142
+
143
class VerifyFactParams(BaseModel):
    """Parameter schema for fact verification.

    Attributes:
        claim: Required claim to verify.
        sources: Optional list of strings; ``None`` when not supplied.
        confidence_threshold: Defaults to 0.8.
    """

    claim: str
    sources: list[str] | None = None
    confidence_threshold: float = 0.8
149
+
150
+
151
class VerifyFieldParams(BaseModel):
    """Parameter schema for field verification.

    Attributes:
        field_name: Required name of the field to verify.
        expected_type: Optional; ``None`` when not supplied.
        expected_format: Optional; ``None`` when not supplied.
        validation_rules: List of strings; a fresh empty list is created
            per instance when not supplied.
    """

    field_name: str
    expected_type: str | None = None
    expected_format: str | None = None
    validation_rules: list[str] = Field(default_factory=list)
158
+
159
+
160
class MemoryParams(BaseModel):
    """Parameter schema for memory operations.

    Attributes:
        key: Required memory key.
        value: Optional value of any type; ``None`` when not supplied.
        memory_type: Defaults to ``"working"``.
        ttl_seconds: Optional; ``None`` when not supplied.
    """

    key: str
    value: Any | None = None
    memory_type: str = "working"
    ttl_seconds: int | None = None
167
+
168
+
169
class MCPToolCallParams(BaseModel):
    """Parameter schema for MCP tool calls.

    Attributes:
        tool_name: Required name of the tool to call.
        arguments: Keyword arguments for the tool; a fresh empty dict is
            created per instance when not supplied.
    """

    tool_name: str
    arguments: dict[str, Any] = Field(default_factory=dict)
174
+
175
+
176
class PlanParams(BaseModel):
    """Parameter schema for planning actions.

    Both fields are optional.

    Attributes:
        plan_description: ``None`` when not supplied.
        steps: List of free-form dicts; ``None`` when not supplied.
    """

    plan_description: str | None = None
    steps: list[dict[str, Any]] | None = None
181
+
182
+
183
class MessageParams(BaseModel):
    """Parameter schema for inter-agent messages.

    Attributes:
        target_agent: Required recipient agent.
        message_type: Required message type.
        content: Free-form payload; a fresh empty dict is created per
            instance when not supplied.
    """

    target_agent: str
    message_type: str
    content: dict[str, Any] = Field(default_factory=dict)
189
+
190
+
191
class WaitParams(BaseModel):
    """Parameter schema for wait actions.

    Attributes:
        duration_ms: Defaults to 1000.
        wait_for_selector: Optional; ``None`` when not supplied.
        wait_for_navigation: Defaults to ``False``.
    """

    duration_ms: int = 1000
    wait_for_selector: str | None = None
    wait_for_navigation: bool = False
197
+
198
+
199
class DoneParams(BaseModel):
    """Parameter schema for the completion action.

    Attributes:
        success: Defaults to ``True``.
        message: Optional; ``None`` when not supplied.
        final_result: Optional free-form dict; ``None`` when not supplied.
    """

    success: bool = True
    message: str | None = None
    final_result: dict[str, Any] | None = None
205
+
206
+
207
class Action(BaseModel):
    """
    Represents an action to be taken in the environment.

    An action consists of:
    - action_type: The type of action (required)
    - parameters: Free-form, action-specific parameter dict
    - reasoning: Why this action was chosen (optional)
    - confidence: How confident the agent is, constrained to [0, 1]
    - agent_id: ID of the agent that produced the action (optional)
    - plan_step: Plan step this action corresponds to (optional)

    ``parameters`` is a plain dict and is not validated on construction;
    call :meth:`validate_params` to check that required keys are present.
    """

    action_type: ActionType = Field(..., description="Type of action to execute")
    parameters: dict[str, Any] = Field(
        default_factory=dict,
        description="Action-specific parameters",
    )
    reasoning: str | None = Field(
        default=None,
        description="Agent's reasoning for this action",
    )
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Confidence in this action (0-1)",
    )
    agent_id: str | None = Field(
        default=None,
        description="ID of the agent that produced this action",
    )
    plan_step: int | None = Field(
        default=None,
        description="Which step of the plan this corresponds to",
    )

    @field_validator("confidence")
    @classmethod
    def validate_confidence(cls, v: float) -> float:
        """Clamp confidence into [0, 1].

        NOTE(review): this is a default ("after") validator, and the field
        already declares ``ge=0.0`` / ``le=1.0``, so out-of-range input
        should be rejected before this runs and the clamp acts as a
        defensive no-op. Confirm whether clamping or rejecting is the
        intended contract before changing either side.
        """
        return max(0.0, min(1.0, v))

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "action_type": "extract_field",
                "parameters": {
                    "field_name": "price",
                    "selector": ".product-price",
                    "extraction_method": "text",
                },
                "reasoning": "The price element is visible with class .product-price",
                "confidence": 0.92,
            }
        }
    )

    @classmethod
    def navigate(cls, url: str, **kwargs: Any) -> "Action":
        """Create a navigate action; extra kwargs are stored as parameters."""
        return cls(
            action_type=ActionType.NAVIGATE,
            parameters={"url": url, **kwargs},
        )

    @classmethod
    def click(cls, selector: str, **kwargs: Any) -> "Action":
        """Create a click action; extra kwargs are stored as parameters."""
        return cls(
            action_type=ActionType.CLICK,
            parameters={"selector": selector, **kwargs},
        )

    @classmethod
    def extract_field(
        cls,
        field_name: str,
        selector: str | None = None,
        **kwargs: Any,
    ) -> "Action":
        """Create an extract field action; extra kwargs are stored as parameters."""
        return cls(
            action_type=ActionType.EXTRACT_FIELD,
            parameters={"field_name": field_name, "selector": selector, **kwargs},
        )

    @classmethod
    def search_engine(cls, query: str, engine: str = "google", **kwargs: Any) -> "Action":
        """Create a search engine action; extra kwargs are stored as parameters."""
        return cls(
            action_type=ActionType.SEARCH_ENGINE,
            parameters={"query": query, "engine": engine, **kwargs},
        )

    @classmethod
    def done(cls, success: bool = True, message: str | None = None) -> "Action":
        """Create a done action."""
        return cls(
            action_type=ActionType.DONE,
            parameters={"success": success, "message": message},
        )

    @classmethod
    def wait(cls, duration_ms: int = 1000) -> "Action":
        """Create a wait action."""
        return cls(
            action_type=ActionType.WAIT,
            parameters={"duration_ms": duration_ms},
        )

    @classmethod
    def mcp_tool_call(cls, tool_name: str, **arguments: Any) -> "Action":
        """Create an MCP tool call action; kwargs become the tool arguments."""
        return cls(
            action_type=ActionType.MCP_TOOL_CALL,
            parameters={"tool_name": tool_name, "arguments": arguments},
        )

    def get_param(self, key: str, default: Any = None) -> Any:
        """Get a parameter value with optional default."""
        return self.parameters.get(key, default)

    def validate_params(self) -> list[str]:
        """Validate parameters for this action type.

        A parameter counts as missing when its key is absent from
        ``parameters`` or its value is ``None``. Action types without an
        entry in the table below have no required parameters.

        Returns:
            One ``"Missing required parameter: <name>"`` message per missing
            parameter, in table order; an empty list when nothing is missing.
        """
        # Kept in step with the required (default-less) fields of the
        # *Params models declared in this module.
        required_params: dict[ActionType, list[str]] = {
            ActionType.NAVIGATE: ["url"],
            ActionType.CLICK: ["selector"],
            ActionType.FILL: ["selector", "value"],
            ActionType.SELECT: ["selector"],
            ActionType.EXTRACT_FIELD: ["field_name"],
            ActionType.EXTRACT_TABLE: ["table_selector"],
            ActionType.EXTRACT_LIST: [
                "container_selector",
                "item_selector",
                "field_selectors",
            ],
            ActionType.SEARCH_PAGE: ["query"],
            ActionType.SEARCH_ENGINE: ["query"],
            ActionType.VERIFY_FACT: ["claim"],
            ActionType.VERIFY_FIELD: ["field_name"],
            ActionType.STORE_MEMORY: ["key"],
            ActionType.RECALL_MEMORY: ["key"],
            ActionType.MCP_TOOL_CALL: ["tool_name"],
            ActionType.SEND_MESSAGE: ["target_agent", "message_type"],
        }

        return [
            f"Missing required parameter: {param}"
            for param in required_params.get(self.action_type, [])
            if self.parameters.get(param) is None
        ]
backend/app/core/env.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web scraper RL environment."""
2
+
3
+ import logging
4
+ import time
5
+ from typing import Any
6
+
7
+ from app.config import Settings, get_settings
8
+ from app.core.action import Action, ActionType
9
+ from app.core.episode import Episode, EpisodeManager
10
+ from app.core.observation import (
11
+ AvailableAction,
12
+ ExtractedField,
13
+ MemoryContext,
14
+ Observation,
15
+ TaskContext,
16
+ )
17
+ from app.core.reward import RewardBreakdown, RewardEngine
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class WebScraperEnv:
23
+ """
24
+ Reinforcement Learning environment for web scraping.
25
+
26
+ Follows the Gymnasium API pattern:
27
+ - reset(task_id, seed) -> observation, info
28
+ - step(action) -> observation, reward, terminated, truncated, info
29
+ - get_state() -> state dict
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ episode_id: str,
35
+ settings: Settings | None = None,
36
+ ) -> None:
37
+ """
38
+ Initialize the environment.
39
+
40
+ Args:
41
+ episode_id: Unique identifier for this episode.
42
+ settings: Application settings.
43
+ """
44
+ self.episode_id = episode_id
45
+ self.settings = settings or get_settings()
46
+ self.reward_engine = RewardEngine(settings)
47
+ self.episode_manager = EpisodeManager()
48
+
49
+ # State
50
+ self._episode: Episode | None = None
51
+ self._current_observation: Observation | None = None
52
+ self._task_context: TaskContext | None = None
53
+ self._ground_truth: dict[str, Any] | None = None
54
+
55
+ # Browser state (placeholder - would use Playwright in production)
56
+ self._current_url: str | None = None
57
+ self._page_html: str | None = None
58
+ self._page_title: str | None = None
59
+
60
+ # Extraction state
61
+ self._extracted_fields: list[ExtractedField] = []
62
+ self._navigation_history: list[str] = []
63
+
64
+ # Timing
65
+ self._start_time: float | None = None
66
+
67
+ async def reset(
68
+ self,
69
+ task_id: str,
70
+ seed: int | None = None,
71
+ config: dict[str, Any] | None = None,
72
+ ) -> tuple[Observation, dict[str, Any]]:
73
+ """
74
+ Reset the environment for a new episode.
75
+
76
+ Args:
77
+ task_id: ID of the task to execute.
78
+ seed: Random seed for reproducibility.
79
+ config: Optional episode configuration.
80
+
81
+ Returns:
82
+ Tuple of (initial_observation, info_dict).
83
+ """
84
+ logger.info(f"Resetting environment for task {task_id}")
85
+
86
+ # Reset state
87
+ self.reward_engine.reset()
88
+ self._extracted_fields = []
89
+ self._navigation_history = []
90
+ self._start_time = time.time()
91
+ self._current_url = None
92
+ self._page_html = None
93
+ self._page_title = None
94
+
95
+ # Create episode
96
+ self._episode = self.episode_manager.create_episode(
97
+ episode_id=self.episode_id,
98
+ task_id=task_id,
99
+ max_steps=self.settings.max_steps_per_episode,
100
+ seed=seed,
101
+ config=config or {},
102
+ )
103
+ self._episode.start()
104
+
105
+ # Load task context
106
+ self._task_context = await self._load_task_context(task_id)
107
+
108
+ # Create initial observation
109
+ self._current_observation = self._create_observation()
110
+
111
+ info = {
112
+ "episode_id": self.episode_id,
113
+ "task_id": task_id,
114
+ "max_steps": self._episode.max_steps,
115
+ "target_fields": self._task_context.target_fields if self._task_context else [],
116
+ }
117
+
118
+ return self._current_observation, info
119
+
120
+ async def step(
121
+ self,
122
+ action: Action,
123
+ ) -> tuple[Observation, float, dict[str, float], bool, bool, dict[str, Any]]:
124
+ """
125
+ Execute an action and return the result.
126
+
127
+ Args:
128
+ action: The action to execute.
129
+
130
+ Returns:
131
+ Tuple of (observation, reward, reward_breakdown, terminated, truncated, info).
132
+ """
133
+ if self._episode is None or self._current_observation is None:
134
+ raise RuntimeError("Environment not reset. Call reset() first.")
135
+
136
+ if self._episode.is_terminal:
137
+ raise RuntimeError("Episode has already terminated.")
138
+
139
+ step_start = time.time()
140
+ prev_observation = self._current_observation
141
+
142
+ # Validate action
143
+ errors = action.validate_params()
144
+ if errors:
145
+ logger.warning(f"Invalid action parameters: {errors}")
146
+
147
+ # Execute action
148
+ action_result = await self._execute_action(action)
149
+
150
+ # Update observation
151
+ self._current_observation = self._create_observation()
152
+ if action_result.get("error"):
153
+ self._current_observation.last_action_error = action_result["error"]
154
+ self._current_observation.consecutive_errors = (
155
+ prev_observation.consecutive_errors + 1
156
+ )
157
+ else:
158
+ self._current_observation.consecutive_errors = 0
159
+
160
+ # Compute reward
161
+ reward, breakdown = self.reward_engine.compute_reward(
162
+ action=action,
163
+ prev_observation=prev_observation,
164
+ new_observation=self._current_observation,
165
+ ground_truth=self._ground_truth,
166
+ max_steps=self._episode.max_steps,
167
+ )
168
+
169
+ # Check termination
170
+ terminated = self._check_terminated(action)
171
+ truncated = self._check_truncated()
172
+
173
+ # Update episode
174
+ step_duration = (time.time() - step_start) * 1000
175
+ self._episode.add_step(
176
+ action_type=action.action_type.value,
177
+ action_params=action.parameters,
178
+ action_reasoning=action.reasoning,
179
+ reward=reward,
180
+ reward_breakdown=breakdown.to_dict(),
181
+ observation_summary={
182
+ "url": self._current_observation.current_url,
183
+ "progress": self._current_observation.extraction_progress,
184
+ "fields_extracted": len(self._current_observation.extracted_so_far),
185
+ },
186
+ error=action_result.get("error"),
187
+ duration_ms=step_duration,
188
+ )
189
+
190
+ # Handle terminal states
191
+ if terminated:
192
+ success = action.action_type == ActionType.DONE and action.get_param(
193
+ "success", True
194
+ )
195
+ self._episode.complete(
196
+ success=success,
197
+ extracted_data=self._current_observation.get_extraction_dict(),
198
+ )
199
+
200
+ # Add terminal reward
201
+ terminal_reward, terminal_breakdown = (
202
+ self.reward_engine.compute_terminal_reward(
203
+ self._current_observation,
204
+ success=success,
205
+ ground_truth=self._ground_truth,
206
+ )
207
+ )
208
+ reward += terminal_reward
209
+ breakdown.total += terminal_reward
210
+ elif truncated:
211
+ self._episode.truncate()
212
+
213
+ info = {
214
+ "action_result": action_result,
215
+ "step_duration_ms": step_duration,
216
+ "episode_step": self._episode.current_step,
217
+ }
218
+
219
+ return (
220
+ self._current_observation,
221
+ reward,
222
+ breakdown.to_dict(),
223
+ terminated,
224
+ truncated,
225
+ info,
226
+ )
227
+
228
+ def get_state(self) -> dict[str, Any]:
229
+ """Get the current state of the environment."""
230
+ if self._episode is None:
231
+ return {
232
+ "episode_id": self.episode_id,
233
+ "status": "not_started",
234
+ }
235
+
236
+ return {
237
+ "episode_id": self.episode_id,
238
+ "task_id": self._episode.task_id,
239
+ "step_number": self._episode.current_step,
240
+ "current_url": self._current_url,
241
+ "is_terminal": self._episode.is_terminal,
242
+ "total_reward": self._episode.total_reward,
243
+ "extracted_data": (
244
+ self._current_observation.get_extraction_dict()
245
+ if self._current_observation
246
+ else {}
247
+ ),
248
+ "status": self._episode.status.value,
249
+ }
250
+
251
+ async def _load_task_context(self, task_id: str) -> TaskContext:
252
+ """Load task context from task repository."""
253
+ # In production, this would fetch from database
254
+ from app.api.routes.tasks import TASK_REPOSITORY
255
+
256
+ task = TASK_REPOSITORY.get(task_id)
257
+ if task:
258
+ return TaskContext(
259
+ task_id=task.id,
260
+ task_name=task.name,
261
+ task_type=task.task_type.value,
262
+ target_fields=[f.name for f in task.fields_to_extract],
263
+ required_fields=task.success_criteria.get("required_fields", []),
264
+ hints=task.hints,
265
+ success_criteria=task.success_criteria,
266
+ )
267
+
268
+ # Default context
269
+ return TaskContext(
270
+ task_id=task_id,
271
+ task_name=f"Task {task_id}",
272
+ task_type="unknown",
273
+ target_fields=[],
274
+ required_fields=[],
275
+ )
276
+
277
+ def _create_observation(self) -> Observation:
278
+ """Create an observation from current state."""
279
+ if self._episode is None:
280
+ raise RuntimeError("Episode not initialized")
281
+
282
+ elapsed = time.time() - (self._start_time or time.time())
283
+
284
+ # Get available actions
285
+ available_actions = self._get_available_actions()
286
+
287
+ # Calculate progress
288
+ target_fields = (
289
+ self._task_context.target_fields if self._task_context else []
290
+ )
291
+ extracted_names = {f.field_name for f in self._extracted_fields}
292
+ fields_remaining = [f for f in target_fields if f not in extracted_names]
293
+ progress = (
294
+ len(self._extracted_fields) / len(target_fields)
295
+ if target_fields
296
+ else 0.0
297
+ )
298
+
299
+ return Observation(
300
+ episode_id=self.episode_id,
301
+ task_id=self._episode.task_id,
302
+ step_number=self._episode.current_step,
303
+ elapsed_seconds=elapsed,
304
+ current_url=self._current_url,
305
+ page_title=self._page_title,
306
+ page_html=self._page_html,
307
+ navigation_history=self._navigation_history.copy(),
308
+ can_go_back=len(self._navigation_history) > 1,
309
+ task_context=self._task_context,
310
+ extracted_so_far=self._extracted_fields.copy(),
311
+ extraction_progress=progress,
312
+ fields_remaining=fields_remaining,
313
+ memory_context=MemoryContext(),
314
+ available_actions=available_actions,
315
+ tokens_used=self._episode.tokens_used,
316
+ api_calls_made=self._episode.api_calls,
317
+ )
318
+
319
+ def _get_available_actions(self) -> list[AvailableAction]:
320
+ """Get list of currently available actions."""
321
+ actions = []
322
+
323
+ # Navigation actions
324
+ actions.append(
325
+ AvailableAction(
326
+ action_type="navigate",
327
+ description="Navigate to a URL",
328
+ parameters={"url": "required"},
329
+ )
330
+ )
331
+
332
+ if self._current_url:
333
+ # Page interaction actions
334
+ actions.extend([
335
+ AvailableAction(
336
+ action_type="click",
337
+ description="Click on an element",
338
+ parameters={"selector": "required"},
339
+ ),
340
+ AvailableAction(
341
+ action_type="extract_field",
342
+ description="Extract a field from the page",
343
+ parameters={"field_name": "required", "selector": "optional"},
344
+ ),
345
+ AvailableAction(
346
+ action_type="search_page",
347
+ description="Search within the current page",
348
+ parameters={"query": "required"},
349
+ ),
350
+ ])
351
+
352
+ # Always available
353
+ actions.extend([
354
+ AvailableAction(
355
+ action_type="search_engine",
356
+ description="Perform a web search",
357
+ parameters={"query": "required", "engine": "optional"},
358
+ ),
359
+ AvailableAction(
360
+ action_type="done",
361
+ description="Mark task as complete",
362
+ parameters={"success": "boolean"},
363
+ ),
364
+ ])
365
+
366
+ return actions
367
+
368
+ async def _execute_action(self, action: Action) -> dict[str, Any]:
369
+ """Execute an action and return the result."""
370
+ result: dict[str, Any] = {"success": False}
371
+
372
+ try:
373
+ match action.action_type:
374
+ case ActionType.NAVIGATE:
375
+ result = await self._execute_navigate(action)
376
+ case ActionType.CLICK:
377
+ result = await self._execute_click(action)
378
+ case ActionType.FILL:
379
+ result = await self._execute_fill(action)
380
+ case ActionType.EXTRACT_FIELD:
381
+ result = await self._execute_extract(action)
382
+ case ActionType.SEARCH_ENGINE:
383
+ result = await self._execute_search_engine(action)
384
+ case ActionType.DONE:
385
+ result = {"success": True, "done": True}
386
+ case ActionType.WAIT:
387
+ await self._execute_wait(action)
388
+ result = {"success": True}
389
+ case _:
390
+ result = {
391
+ "success": False,
392
+ "error": f"Action type {action.action_type} not implemented",
393
+ }
394
+ except Exception as e:
395
+ logger.error(f"Action execution failed: {e}")
396
+ result = {"success": False, "error": str(e)}
397
+
398
+ return result
399
+
400
+ async def _execute_navigate(self, action: Action) -> dict[str, Any]:
401
+ """Execute a navigate action."""
402
+ url = action.get_param("url")
403
+ if not url:
404
+ return {"success": False, "error": "URL is required"}
405
+
406
+ # Placeholder - in production would use Playwright
407
+ self._current_url = url
408
+ self._navigation_history.append(url)
409
+ self._page_title = f"Page at {url}"
410
+ self._page_html = f"<html><body><h1>Mock page for {url}</h1></body></html>"
411
+
412
+ return {"success": True, "url": url}
413
+
414
+ async def _execute_click(self, action: Action) -> dict[str, Any]:
415
+ """Execute a click action."""
416
+ selector = action.get_param("selector")
417
+ if not selector:
418
+ return {"success": False, "error": "Selector is required"}
419
+
420
+ # Placeholder
421
+ return {"success": True, "selector": selector, "clicked": True}
422
+
423
+ async def _execute_fill(self, action: Action) -> dict[str, Any]:
424
+ """Execute a fill action."""
425
+ selector = action.get_param("selector")
426
+ value = action.get_param("value")
427
+
428
+ if not selector or value is None:
429
+ return {"success": False, "error": "Selector and value are required"}
430
+
431
+ # Placeholder
432
+ return {"success": True, "selector": selector, "filled": True}
433
+
434
+ async def _execute_extract(self, action: Action) -> dict[str, Any]:
435
+ """Execute an extract action."""
436
+ field_name = action.get_param("field_name")
437
+ if not field_name:
438
+ return {"success": False, "error": "field_name is required"}
439
+
440
+ # Placeholder - in production would actually extract from page
441
+ extracted_field = ExtractedField(
442
+ field_name=field_name,
443
+ value=f"mock_value_for_{field_name}",
444
+ confidence=0.9,
445
+ source_selector=action.get_param("selector"),
446
+ extraction_step=self._episode.current_step if self._episode else 0,
447
+ )
448
+
449
+ self._extracted_fields.append(extracted_field)
450
+
451
+ return {
452
+ "success": True,
453
+ "field_name": field_name,
454
+ "value": extracted_field.value,
455
+ "confidence": extracted_field.confidence,
456
+ }
457
+
458
+ async def _execute_search_engine(self, action: Action) -> dict[str, Any]:
459
+ """Execute a search engine action."""
460
+ query = action.get_param("query")
461
+ if not query:
462
+ return {"success": False, "error": "Query is required"}
463
+
464
+ engine = action.get_param("engine", "google")
465
+
466
+ # Placeholder
467
+ return {
468
+ "success": True,
469
+ "query": query,
470
+ "engine": engine,
471
+ "results": [
472
+ {"title": f"Result 1 for {query}", "url": "https://example.com/1"},
473
+ {"title": f"Result 2 for {query}", "url": "https://example.com/2"},
474
+ ],
475
+ }
476
+
477
+ async def _execute_wait(self, action: Action) -> None:
478
+ """Execute a wait action."""
479
+ import asyncio
480
+ duration_ms = action.get_param("duration_ms", 1000)
481
+ await asyncio.sleep(duration_ms / 1000)
482
+
483
+ def _check_terminated(self, action: Action) -> bool:
484
+ """Check if the episode should terminate."""
485
+ if action.action_type == ActionType.DONE:
486
+ return True
487
+ if action.action_type == ActionType.FAIL:
488
+ return True
489
+ return False
490
+
491
+ def _check_truncated(self) -> bool:
492
+ """Check if the episode should be truncated."""
493
+ if self._episode is None:
494
+ return False
495
+ if self._episode.current_step >= self._episode.max_steps:
496
+ return True
497
+ return False
backend/app/core/episode.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Episode state machine and management."""
2
+
3
+ from datetime import datetime, timezone
4
+ from enum import Enum
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class EpisodeStatus(str, Enum):
    """Lifecycle states an episode can be in.

    Members are ``str`` subclasses, so each compares equal to its value.
    """

    # Created but not yet started.
    PENDING = "pending"
    # Started and still accepting steps.
    RUNNING = "running"

    # Ended states.
    COMPLETED = "completed"
    FAILED = "failed"
    TRUNCATED = "truncated"
    CANCELLED = "cancelled"
19
+
20
+
21
class EpisodeStep(BaseModel):
    """Immutable-by-convention log entry for one step of an episode.

    Attributes:
        step_number: Position of the step within the episode.
        timestamp: Time of the step, stored as a string.
        action_type: Action identifier, stored as a string.
        action_params: Parameters the action was executed with.
        action_reasoning: Optional reasoning text; ``None`` when absent.
        reward: Reward recorded for the step.
        reward_breakdown: Named reward components.
        observation_summary: Free-form summary of the resulting observation.
        error: Optional error message; ``None`` when the step had no error.
        duration_ms: Defaults to 0.0.
    """

    step_number: int
    timestamp: str
    action_type: str
    action_params: dict[str, Any]
    action_reasoning: str | None = None
    reward: float
    reward_breakdown: dict[str, float]
    observation_summary: dict[str, Any]
    error: str | None = None
    duration_ms: float = 0.0
34
+
35
+
36
+ class Episode(BaseModel):
37
+ """
38
+ Represents a complete episode in the RL environment.
39
+
40
+ An episode is a sequence of steps from reset to termination,
41
+ tracking all actions, rewards, and observations.
42
+ """
43
+
44
+ # Identification
45
+ episode_id: str
46
+ task_id: str
47
+
48
+ # Timing
49
+ created_at: str = Field(
50
+ default_factory=lambda: datetime.now(timezone.utc).isoformat()
51
+ )
52
+ started_at: str | None = None
53
+ ended_at: str | None = None
54
+
55
+ # State
56
+ status: EpisodeStatus = EpisodeStatus.PENDING
57
+ current_step: int = 0
58
+ max_steps: int = 50
59
+
60
+ # Seed for reproducibility
61
+ seed: int | None = None
62
+
63
+ # Configuration
64
+ config: dict[str, Any] = Field(default_factory=dict)
65
+
66
+ # Step history
67
+ steps: list[EpisodeStep] = Field(default_factory=list)
68
+
69
+ # Aggregates
70
+ total_reward: float = 0.0
71
+ tokens_used: int = 0
72
+ api_calls: int = 0
73
+ estimated_cost_usd: float = 0.0
74
+
75
+ # Results
76
+ extracted_data: dict[str, Any] = Field(default_factory=dict)
77
+ final_accuracy: float | None = None
78
+ success: bool | None = None
79
+ failure_reason: str | None = None
80
+
81
+ # Navigation history
82
+ urls_visited: list[str] = Field(default_factory=list)
83
+
84
+ def start(self) -> None:
85
+ """Mark the episode as started."""
86
+ self.status = EpisodeStatus.RUNNING
87
+ self.started_at = datetime.now(timezone.utc).isoformat()
88
+
89
def add_step(
    self,
    action_type: str,
    action_params: dict[str, Any],
    reward: float,
    reward_breakdown: dict[str, float],
    observation_summary: dict[str, Any],
    action_reasoning: str | None = None,
    error: str | None = None,
    duration_ms: float = 0.0,
) -> EpisodeStep:
    """Record one executed action and fold its reward into the running total.

    Args:
        action_type: Name of the action that was executed.
        action_params: Parameters the action was called with.
        reward: Scalar reward earned by this step.
        reward_breakdown: Per-component reward values for this step.
        observation_summary: Condensed view of the resulting observation.
        action_reasoning: Optional rationale attached to the action.
        error: Optional error message produced by the action.
        duration_ms: Wall-clock duration of the step in milliseconds.

    Returns:
        The ``EpisodeStep`` that was appended to ``self.steps``.
    """
    # Step numbers are 1-based: bump the counter before using it.
    self.current_step += 1

    recorded = EpisodeStep(
        step_number=self.current_step,
        timestamp=datetime.now(timezone.utc).isoformat(),
        action_type=action_type,
        action_params=action_params,
        action_reasoning=action_reasoning,
        reward=reward,
        reward_breakdown=reward_breakdown,
        observation_summary=observation_summary,
        error=error,
        duration_ms=duration_ms,
    )

    self.steps.append(recorded)
    self.total_reward += reward
    return recorded
120
+
121
def complete(
    self,
    success: bool,
    extracted_data: dict[str, Any] | None = None,
    final_accuracy: float | None = None,
) -> None:
    """Mark the episode as completed and store its final results.

    Args:
        success: Whether the task was accomplished.
        extracted_data: Final extraction result. ``None`` (the default)
            leaves the currently stored ``extracted_data`` untouched; any
            dict, including an empty one, replaces it.
        final_accuracy: Final accuracy score, if one was computed.
    """
    self.status = EpisodeStatus.COMPLETED
    self.ended_at = datetime.now(timezone.utc).isoformat()
    self.success = success
    # Compare against None rather than relying on truthiness: an explicit
    # empty dict is a real result and must not be mistaken for "not given".
    if extracted_data is not None:
        self.extracted_data = extracted_data
    self.final_accuracy = final_accuracy
134
+
135
def fail(self, reason: str) -> None:
    """Put the episode into the FAILED state.

    Stamps ``ended_at`` with the current UTC time, forces ``success`` to
    ``False`` and stores ``reason`` as the failure reason.
    """
    finished_at = datetime.now(timezone.utc).isoformat()
    self.status = EpisodeStatus.FAILED
    self.ended_at = finished_at
    self.success = False
    self.failure_reason = reason
141
+
142
def truncate(self, reason: str = "max_steps_reached") -> None:
    """Put the episode into the TRUNCATED state (stopped early).

    Stamps ``ended_at`` with the current UTC time and stores ``reason`` in
    ``failure_reason``. ``success`` is left as it was.
    """
    finished_at = datetime.now(timezone.utc).isoformat()
    self.status = EpisodeStatus.TRUNCATED
    self.ended_at = finished_at
    self.failure_reason = reason
147
+
148
def cancel(self) -> None:
    """Put the episode into the CANCELLED state and stamp ``ended_at`` (UTC)."""
    finished_at = datetime.now(timezone.utc).isoformat()
    self.status = EpisodeStatus.CANCELLED
    self.ended_at = finished_at
152
+
153
@property
def is_terminal(self) -> bool:
    """Whether the episode has reached a state it cannot leave.

    True for COMPLETED, FAILED, TRUNCATED and CANCELLED; False otherwise.
    """
    terminal_states = (
        EpisodeStatus.COMPLETED,
        EpisodeStatus.FAILED,
        EpisodeStatus.TRUNCATED,
        EpisodeStatus.CANCELLED,
    )
    return self.status in terminal_states
162
+
163
@property
def duration_seconds(self) -> float | None:
    """Elapsed time of the episode in seconds.

    Returns ``None`` when the episode has not been started. While the
    episode is still running (no ``ended_at``), the current UTC time is
    used as the end point.
    """
    if not self.started_at:
        return None

    def _parse(stamp: str) -> datetime:
        # A trailing "Z" is rewritten to an explicit UTC offset before parsing.
        return datetime.fromisoformat(stamp.replace("Z", "+00:00"))

    finish = self.ended_at or datetime.now(timezone.utc).isoformat()
    begun_at = _parse(self.started_at)
    finished_at = _parse(finish)
    return (finished_at - begun_at).total_seconds()
172
+
173
@property
def average_reward(self) -> float:
    """Mean reward per recorded step; ``0.0`` when no steps exist yet."""
    step_count = len(self.steps)
    return self.total_reward / step_count if step_count else 0.0
179
+
180
def get_summary(self) -> dict[str, Any]:
    """Build a flat dictionary of headline facts about the episode.

    ``steps`` reports the step counter, ``status`` the enum's value, and
    ``fields_extracted`` the number of keys in ``extracted_data``.
    """
    summary: dict[str, Any] = dict(
        episode_id=self.episode_id,
        task_id=self.task_id,
        status=self.status.value,
        steps=self.current_step,
        total_reward=self.total_reward,
        average_reward=self.average_reward,
        duration_seconds=self.duration_seconds,
        tokens_used=self.tokens_used,
        estimated_cost_usd=self.estimated_cost_usd,
        success=self.success,
        fields_extracted=len(self.extracted_data),
    )
    return summary
195
+
196
def get_step_history(
    self,
    start: int = 0,
    end: int | None = None,
) -> list[EpisodeStep]:
    """Return the recorded steps in the half-open range ``[start, end)``.

    Standard slice semantics apply; ``end=None`` runs to the last step.
    """
    window = slice(start, end)
    return self.steps[window]
203
+
204
def get_action_sequence(self) -> list[str]:
    """List the ``action_type`` of every recorded step, in order."""
    sequence: list[str] = []
    for entry in self.steps:
        sequence.append(entry.action_type)
    return sequence
207
+
208
def get_reward_history(self) -> list[float]:
    """List the ``reward`` of every recorded step, in order."""
    rewards: list[float] = []
    for entry in self.steps:
        rewards.append(entry.reward)
    return rewards
211
+
212
+
213
class EpisodeManager:
    """In-memory registry of episodes, keyed by ``episode_id``."""

    def __init__(self) -> None:
        """Start with an empty registry."""
        self._episodes: dict[str, Episode] = {}

    def create_episode(
        self,
        episode_id: str,
        task_id: str,
        max_steps: int = 50,
        seed: int | None = None,
        config: dict[str, Any] | None = None,
    ) -> Episode:
        """Create an episode and register it under ``episode_id``.

        An episode already registered under the same id is replaced.

        Args:
            episode_id: Key under which the episode is stored.
            task_id: Task the episode belongs to.
            max_steps: Step budget passed to the episode.
            seed: Optional seed passed to the episode.
            config: Optional configuration; ``None`` becomes an empty dict.

        Returns:
            The newly created episode.
        """
        episode = Episode(
            episode_id=episode_id,
            task_id=task_id,
            max_steps=max_steps,
            seed=seed,
            config=config or {},
        )
        self._episodes[episode_id] = episode
        return episode

    def get_episode(self, episode_id: str) -> Episode | None:
        """Return the episode registered under ``episode_id``, or ``None``."""
        return self._episodes.get(episode_id)

    def remove_episode(self, episode_id: str) -> bool:
        """Drop an episode from the registry.

        Returns:
            ``True`` if an episode was removed, ``False`` if the id was unknown.
        """
        if episode_id in self._episodes:
            del self._episodes[episode_id]
            return True
        return False

    def list_episodes(
        self,
        status: EpisodeStatus | None = None,
        task_id: str | None = None,
    ) -> list[Episode]:
        """List registered episodes, optionally filtered.

        Args:
            status: When not ``None``, keep only episodes with this status.
            task_id: When not ``None``, keep only episodes for this task.

        Returns:
            Matching episodes in registration order.
        """
        episodes = list(self._episodes.values())
        # Filters are checked against None explicitly so that a falsy but
        # deliberate value (e.g. an empty task id) is still applied.
        if status is not None:
            episodes = [e for e in episodes if e.status == status]
        if task_id is not None:
            episodes = [e for e in episodes if e.task_id == task_id]
        return episodes
backend/app/core/observation.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Observation model for the RL environment."""
2
+
3
+ from datetime import datetime, timezone
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+
8
+
9
class ToolSnapshot(BaseModel):
    """Point-in-time description of one tool taken from the registry."""

    # Identity and human-readable summary of the tool.
    name: str
    description: str
    # One dict per parameter; the dict layout is not constrained here.
    parameters: list[dict[str, Any]]
    # Tools are enabled unless stated otherwise.
    enabled: bool = True
    # Estimated cost of using the tool; defaults to zero.
    cost_estimate: float = 0.0
17
+
18
+
19
class MemoryContext(BaseModel):
    """Memory entries handed to the agent, grouped by memory layer.

    Every layer defaults to an empty container.
    """

    # List-shaped layers: each entry is a free-form dict.
    short_term: list[dict[str, Any]] = Field(default_factory=list)
    working: list[dict[str, Any]] = Field(default_factory=list)
    long_term_relevant: list[dict[str, Any]] = Field(default_factory=list)
    # Shared layer: a single free-form mapping.
    shared: dict[str, Any] = Field(default_factory=dict)
26
+
27
+
28
class PageElement(BaseModel):
    """One noteworthy element found on the current page."""

    # How to locate the element and what kind of tag it is.
    selector: str
    tag: str
    # Text content, when the element has any.
    text: str | None = None
    # Attribute name -> attribute value.
    attributes: dict[str, str] = Field(default_factory=dict)
    # Interaction/visibility flags: non-interactive and visible by default.
    is_interactive: bool = False
    is_visible: bool = True
    # Optional geometry as a name -> number mapping; keys are not fixed here.
    bounding_box: dict[str, float] | None = None
38
+
39
+
40
class ExtractedField(BaseModel):
    """A single field value the agent has extracted so far."""

    # Name of the field and the value found for it (any type).
    field_name: str
    value: Any
    # Confidence attached to the value; defaults to full confidence.
    confidence: float = 1.0
    # Selector the value was taken from, if recorded.
    source_selector: str | None = None
    # Step number at which the extraction happened.
    extraction_step: int = 0
    # Whether the value has been verified; unverified by default.
    verified: bool = False
49
+
50
+
51
class AvailableAction(BaseModel):
    """Description of an action the agent may take in the current state."""

    # Which action this is and what it does.
    action_type: str
    description: str
    # Free-form parameter mapping for the action.
    parameters: dict[str, Any] = Field(default_factory=dict)
    # Optional reward estimate for taking the action.
    estimated_reward: float | None = None
    # Free-form risk label; "low" unless specified.
    risk_level: str = "low"
59
+
60
+
61
class TaskContext(BaseModel):
    """Static description of the task the agent is working on."""

    # Task identity.
    task_id: str
    task_name: str
    task_type: str
    # Field names the task targets, and the subset that is required.
    target_fields: list[str]
    required_fields: list[str]
    # Optional free-text hints.
    hints: list[str] = Field(default_factory=list)
    # Free-form success criteria mapping.
    success_criteria: dict[str, Any] = Field(default_factory=dict)
71
+
72
+
73
class Observation(BaseModel):
    """
    Everything the agent is shown after a step.

    The model bundles, in this order:
    - episode/task identification and timing
    - the current page (URL, title, HTML, text, elements)
    - navigation state
    - task context and extraction state
    - memory context
    - tool registry snapshot and available actions
    - agent-coordination data, error state, cost counters and hints
    """

    # --- Episode identification -------------------------------------------
    episode_id: str = Field(..., description="Unique episode identifier")
    task_id: str = Field(..., description="Task being executed")
    step_number: int = Field(..., description="Current step in the episode")

    # --- Timing -----------------------------------------------------------
    # Defaults to the creation time of the observation (UTC, ISO-8601).
    timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    elapsed_seconds: float = Field(default=0.0, description="Time elapsed in episode")

    # --- Page state -------------------------------------------------------
    current_url: str | None = Field(default=None, description="Current page URL")
    page_title: str | None = Field(default=None, description="Current page title")
    page_html: str | None = Field(default=None, description="Full HTML of current page")
    page_html_chunked: list[str] = Field(
        default_factory=list,
        description="HTML split into semantic chunks",
    )
    page_text: str | None = Field(default=None, description="Visible text content")
    page_elements: list[PageElement] = Field(
        default_factory=list,
        description="Significant page elements",
    )

    # --- Navigation state -------------------------------------------------
    navigation_history: list[str] = Field(
        default_factory=list,
        description="URLs visited in this episode",
    )
    can_go_back: bool = Field(default=False)
    can_go_forward: bool = Field(default=False)

    # --- Task context -----------------------------------------------------
    task_context: TaskContext | None = Field(
        default=None,
        description="Information about the current task",
    )

    # --- Extraction state -------------------------------------------------
    extracted_so_far: list[ExtractedField] = Field(
        default_factory=list,
        description="Fields extracted so far",
    )
    extraction_progress: float = Field(
        default=0.0,
        description="Progress towards task completion (0-1)",
    )
    fields_remaining: list[str] = Field(
        default_factory=list,
        description="Fields still to be extracted",
    )

    # --- Memory context ---------------------------------------------------
    memory_context: MemoryContext = Field(
        default_factory=MemoryContext,
        description="Relevant memories from all layers",
    )

    # --- Tool registry snapshot -------------------------------------------
    tool_registry_snapshot: list[ToolSnapshot] = Field(
        default_factory=list,
        description="Available tools and their state",
    )

    # --- Available actions ------------------------------------------------
    available_actions: list[AvailableAction] = Field(
        default_factory=list,
        description="Actions available in current state",
    )

    # --- Agent coordination -----------------------------------------------
    pending_messages: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Messages from other agents",
    )
    active_plan: dict[str, Any] | None = Field(
        default=None,
        description="Current execution plan if any",
    )
    current_plan_step: int | None = Field(
        default=None,
        description="Current step in the plan",
    )

    # --- Error state ------------------------------------------------------
    last_action_error: str | None = Field(
        default=None,
        description="Error from last action if any",
    )
    consecutive_errors: int = Field(
        default=0,
        description="Number of consecutive action errors",
    )

    # --- Cost tracking ----------------------------------------------------
    tokens_used: int = Field(default=0, description="LLM tokens used so far")
    api_calls_made: int = Field(default=0, description="API calls made")
    estimated_cost_usd: float = Field(default=0.0, description="Estimated cost so far")

    # --- Hints and guidance -----------------------------------------------
    system_hints: list[str] = Field(
        default_factory=list,
        description="Hints from the environment or previous steps",
    )

    # Example payload published with the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "episode_id": "ep_abc123",
                "task_id": "task_001",
                "step_number": 5,
                "current_url": "https://example.com/product/123",
                "page_title": "Product Details - Example Store",
                "extracted_so_far": [
                    {
                        "field_name": "product_name",
                        "value": "Example Product",
                        "confidence": 0.95,
                    }
                ],
                "extraction_progress": 0.33,
                "fields_remaining": ["price", "description"],
            }
        }
    )

    def get_extraction_dict(self) -> dict[str, Any]:
        """Map each extracted field name to its value.

        If a name occurs more than once, the last occurrence wins.
        """
        values_by_name: dict[str, Any] = {}
        for entry in self.extracted_so_far:
            values_by_name[entry.field_name] = entry.value
        return values_by_name

    def is_field_extracted(self, field_name: str) -> bool:
        """Tell whether ``field_name`` appears among the extracted fields."""
        for entry in self.extracted_so_far:
            if entry.field_name == field_name:
                return True
        return False

    def get_context_summary(self) -> str:
        """Render a one-line, pipe-separated status string for LLM prompts.

        The line contains the step number, current URL, progress percentage
        and an extracted/total field count, plus the last action error when
        one is set. The total is extracted fields plus remaining fields.
        """
        extracted_count = len(self.extracted_so_far)
        expected_count = extracted_count + len(self.fields_remaining)

        segments = [
            f"Step {self.step_number}",
            f"URL: {self.current_url or 'None'}",
            f"Progress: {self.extraction_progress:.0%}",
            f"Extracted: {extracted_count}/{expected_count} fields",
        ]
        if self.last_action_error:
            segments.append(f"Last error: {self.last_action_error}")
        return " | ".join(segments)
backend/app/core/reward.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward computation engine with component breakdown."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ from app.config import Settings, get_settings
7
+ from app.core.action import Action, ActionType
8
+ from app.core.observation import Observation
9
+
10
+
11
@dataclass
class RewardBreakdown:
    """Per-component view of a reward, plus the weighted total.

    The four core components are multiplied by caller-supplied weights;
    bonuses are added and penalties subtracted at face value.
    """

    # Core components (weighted in ``compute_total``)
    accuracy: float = 0.0
    efficiency: float = 0.0
    cost: float = 0.0
    completeness: float = 0.0

    # Bonus/penalty components (applied unweighted)
    progress_bonus: float = 0.0
    error_penalty: float = 0.0
    time_penalty: float = 0.0
    redundancy_penalty: float = 0.0
    exploration_bonus: float = 0.0
    verification_bonus: float = 0.0

    # Metadata filled in by ``compute_total``
    total: float = 0.0
    components: dict[str, float] = field(default_factory=dict)

    def _snapshot(self) -> dict[str, float]:
        """Return the current raw (unweighted) value of every component."""
        return {
            "accuracy": self.accuracy,
            "efficiency": self.efficiency,
            "cost": self.cost,
            "completeness": self.completeness,
            "progress_bonus": self.progress_bonus,
            "error_penalty": self.error_penalty,
            "time_penalty": self.time_penalty,
            "redundancy_penalty": self.redundancy_penalty,
            "exploration_bonus": self.exploration_bonus,
            "verification_bonus": self.verification_bonus,
        }

    def compute_total(self, weights: dict[str, float]) -> float:
        """Compute, store and return the weighted total reward.

        Args:
            weights: Weights for ``accuracy``, ``efficiency``, ``cost`` and
                ``completeness``. Missing keys fall back to 0.4, 0.2, 0.2
                and 0.2 respectively.

        Returns:
            The total, which is also stored in ``self.total``. As a side
            effect ``self.components`` is refreshed with the raw values.
        """
        self.total = (
            self.accuracy * weights.get("accuracy", 0.4)
            + self.efficiency * weights.get("efficiency", 0.2)
            + self.cost * weights.get("cost", 0.2)
            + self.completeness * weights.get("completeness", 0.2)
            + self.progress_bonus
            + self.exploration_bonus
            + self.verification_bonus
            - self.error_penalty
            - self.time_penalty
            - self.redundancy_penalty
        )

        self.components = self._snapshot()

        return self.total

    def to_dict(self) -> dict[str, float]:
        """Return ``total`` followed by every component value.

        When ``compute_total`` has not populated ``components`` yet, the
        live field values are used so the components are never dropped.
        """
        components = self.components or self._snapshot()
        return {
            "total": self.total,
            **components,
        }
69
+
70
+
71
class RewardEngine:
    """
    Computes rewards for actions in the web scraping environment.

    Reward components:
    - Accuracy: How correct extracted data is
    - Efficiency: Steps taken vs optimal
    - Cost: API/compute costs
    - Completeness: Progress towards task completion

    Plus bonuses/penalties for:
    - Progress: Making progress towards goal
    - Errors: Failed actions or invalid extractions
    - Time: Taking too long
    - Redundancy: Repeating unsuccessful actions
    - Exploration: Discovering new information
    - Verification: Validating extracted data

    The engine keeps per-episode tracking state (action history, extraction
    attempts, URL visit counts); call ``reset`` between episodes.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize the reward engine.

        Args:
            settings: Settings providing the four reward weights and
                ``max_steps_per_episode``. Falls back to ``get_settings()``.
        """
        self.settings = settings or get_settings()
        self.weights = {
            "accuracy": self.settings.reward_accuracy_weight,
            "efficiency": self.settings.reward_efficiency_weight,
            "cost": self.settings.reward_cost_weight,
            "completeness": self.settings.reward_completeness_weight,
        }

        # Tracking for penalties
        self._action_history: list[Action] = []
        self._extraction_attempts: dict[str, int] = {}
        self._url_visits: dict[str, int] = {}

    def reset(self) -> None:
        """Reset tracking state for a new episode."""
        self._action_history.clear()
        self._extraction_attempts.clear()
        self._url_visits.clear()

    def compute_reward(
        self,
        action: Action,
        prev_observation: Observation,
        new_observation: Observation,
        ground_truth: dict[str, Any] | None = None,
        max_steps: int = 50,
    ) -> tuple[float, RewardBreakdown]:
        """
        Compute reward for an action.

        Args:
            action: The action that was taken.
            prev_observation: Observation before the action.
            new_observation: Observation after the action.
            ground_truth: Optional ground truth data for accuracy calculation.
            max_steps: Maximum steps allowed in episode.

        Returns:
            Tuple of (total_reward, breakdown).
        """
        breakdown = RewardBreakdown()

        # Track action
        self._action_history.append(action)

        # Core components
        breakdown.accuracy = self._compute_accuracy(
            action, new_observation, ground_truth
        )
        breakdown.efficiency = self._compute_efficiency(
            new_observation.step_number, max_steps
        )
        breakdown.cost = self._compute_cost_reward(new_observation)
        breakdown.completeness = self._compute_completeness(
            prev_observation, new_observation
        )

        # Bonuses. The exploration bonus must run before the redundancy
        # penalty: it is the one that updates the URL visit counter the
        # penalty reads.
        breakdown.progress_bonus = self._compute_progress_bonus(
            prev_observation, new_observation
        )
        breakdown.exploration_bonus = self._compute_exploration_bonus(
            action, new_observation
        )
        breakdown.verification_bonus = self._compute_verification_bonus(
            action, new_observation
        )

        # Penalties
        breakdown.error_penalty = self._compute_error_penalty(new_observation)
        breakdown.time_penalty = self._compute_time_penalty(new_observation, max_steps)
        breakdown.redundancy_penalty = self._compute_redundancy_penalty(action)

        total = breakdown.compute_total(self.weights)

        return total, breakdown

    def _compute_accuracy(
        self,
        action: Action,
        observation: Observation,
        ground_truth: dict[str, Any] | None,
    ) -> float:
        """Compute accuracy reward component.

        Without ground truth, the mean confidence of the extracted fields is
        used (0.5 when nothing is extracted). With ground truth, the result
        is the share of extracted fields present in the ground truth whose
        values match; 0.0 when no such field exists.
        """
        if ground_truth is None:
            # Without ground truth, use confidence scores
            if observation.extracted_so_far:
                avg_confidence = sum(
                    f.confidence for f in observation.extracted_so_far
                ) / len(observation.extracted_so_far)
                return avg_confidence
            return 0.5  # Neutral

        # With ground truth, compute actual accuracy
        extracted = observation.get_extraction_dict()
        if not extracted:
            return 0.0

        correct = 0
        total = 0
        for field_name, expected_value in ground_truth.items():
            if field_name in extracted:
                total += 1
                if self._values_match(extracted[field_name], expected_value):
                    correct += 1

        if total == 0:
            return 0.0

        return correct / total

    def _values_match(self, actual: Any, expected: Any) -> bool:
        """Check if extracted value matches expected value.

        Matching rules, in order: plain equality; for two strings,
        case/whitespace-insensitive equality or containment in either
        direction; for two numbers, equality within 1% of the expected
        value (absolute 0.01 when the expected value is 0).
        """
        if actual == expected:
            return True

        # Fuzzy matching for strings
        if isinstance(actual, str) and isinstance(expected, str):
            actual_clean = actual.strip().lower()
            expected_clean = expected.strip().lower()
            if actual_clean == expected_clean:
                return True
            # Partial match. The empty string is contained in every string,
            # so containment is only meaningful when both sides are
            # non-empty; otherwise a blank extraction would match anything.
            if actual_clean and expected_clean:
                return expected_clean in actual_clean or actual_clean in expected_clean
            return False

        # Numeric comparison with tolerance
        if isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
            tolerance = abs(expected) * 0.01 if expected != 0 else 0.01
            return abs(actual - expected) <= tolerance

        return False

    def _compute_efficiency(self, current_step: int, max_steps: int) -> float:
        """Compute efficiency as the unused share of the step budget.

        Returns a value clamped at 0.0 from below; a non-positive budget
        yields 0.0 instead of dividing by zero.
        """
        if max_steps <= 0:
            return 0.0
        # Higher reward for completing tasks in fewer steps
        remaining_ratio = (max_steps - current_step) / max_steps
        return max(0.0, remaining_ratio)

    def _compute_cost_reward(self, observation: Observation) -> float:
        """Compute reward based on cost efficiency.

        Token usage and API calls are each scaled against a fixed expected
        maximum (10000 tokens, 50 calls), capped at 1.0, inverted, and the
        two results averaged.
        """
        max_expected_tokens = 10000
        max_expected_calls = 50

        token_efficiency = 1.0 - min(
            observation.tokens_used / max_expected_tokens, 1.0
        )
        call_efficiency = 1.0 - min(
            observation.api_calls_made / max_expected_calls, 1.0
        )

        return (token_efficiency + call_efficiency) / 2

    def _compute_completeness(
        self,
        prev_observation: Observation,
        new_observation: Observation,
    ) -> float:
        """Compute completeness as the new observation's extraction progress."""
        return new_observation.extraction_progress

    def _compute_progress_bonus(
        self,
        prev_observation: Observation,
        new_observation: Observation,
    ) -> float:
        """Bonus for making progress.

        Adds half of any positive progress delta plus 0.1 per newly
        extracted field.
        """
        progress_delta = (
            new_observation.extraction_progress - prev_observation.extraction_progress
        )

        # Bonus for new extractions
        new_extractions = len(new_observation.extracted_so_far) - len(
            prev_observation.extracted_so_far
        )

        bonus = 0.0
        if progress_delta > 0:
            bonus += progress_delta * 0.5
        if new_extractions > 0:
            bonus += new_extractions * 0.1

        return bonus

    def _compute_exploration_bonus(
        self,
        action: Action,
        observation: Observation,
    ) -> float:
        """Bonus for exploring new pages.

        A NAVIGATE action to a non-empty URL not seen before in this episode
        earns 0.05. Every NAVIGATE action increments the URL's visit count.
        """
        bonus = 0.0

        if action.action_type == ActionType.NAVIGATE:
            url = action.get_param("url", "")
            if url and url not in self._url_visits:
                bonus += 0.05
            # Count every visit (not only the first) so the redundancy
            # penalty can see repeats.
            self._url_visits[url] = self._url_visits.get(url, 0) + 1

        return bonus

    def _compute_verification_bonus(
        self,
        action: Action,
        observation: Observation,
    ) -> float:
        """Bonus of 0.05 for VERIFY_FACT / VERIFY_FIELD actions, else 0.0."""
        if action.action_type in [ActionType.VERIFY_FACT, ActionType.VERIFY_FIELD]:
            return 0.05
        return 0.0

    def _compute_error_penalty(self, observation: Observation) -> float:
        """Penalty for errors.

        When the last action errored: 0.1 plus 0.05 per consecutive error.
        """
        if observation.last_action_error:
            base_penalty = 0.1
            consecutive_penalty = observation.consecutive_errors * 0.05
            return base_penalty + consecutive_penalty
        return 0.0

    def _compute_time_penalty(
        self,
        observation: Observation,
        max_steps: int,
    ) -> float:
        """Penalty for taking too long.

        Applies once more than 80% of the step budget is used and grows
        linearly from there. A non-positive budget yields 0.0 instead of
        dividing by zero.
        """
        if max_steps <= 0:
            return 0.0
        step_ratio = observation.step_number / max_steps
        if step_ratio > 0.8:
            return (step_ratio - 0.8) * 0.5
        return 0.0

    def _compute_redundancy_penalty(self, action: Action) -> float:
        """Penalty for redundant actions.

        Repeated EXTRACT_FIELD on the same field costs 0.05 per earlier
        attempt (capped at 0.2); repeated NAVIGATE to the same URL costs
        0.03 per earlier visit (capped at 0.15).

        Every extraction attempt is recorded, including one made as the very
        first action of an episode, so that its repeat is penalised too. A
        first action never incurs a penalty because its counters are fresh.
        """
        # Check for repeated extract attempts on same field
        if action.action_type == ActionType.EXTRACT_FIELD:
            field = action.get_param("field_name", "")
            attempts = self._extraction_attempts.get(field, 0)
            self._extraction_attempts[field] = attempts + 1
            if attempts > 0:
                return min(attempts * 0.05, 0.2)

        # Check for repeated navigation to same URL. The count already
        # includes the current visit (see _compute_exploration_bonus).
        if action.action_type == ActionType.NAVIGATE:
            url = action.get_param("url", "")
            visits = self._url_visits.get(url, 0)
            if visits > 1:
                return min((visits - 1) * 0.03, 0.15)

        return 0.0

    def compute_terminal_reward(
        self,
        observation: Observation,
        success: bool,
        ground_truth: dict[str, Any] | None = None,
    ) -> tuple[float, RewardBreakdown]:
        """
        Compute final reward at episode termination.

        Args:
            observation: Final observation.
            success: Whether the task was completed successfully.
            ground_truth: Optional ground truth for accuracy.

        Returns:
            Tuple of (total_reward, breakdown).
        """
        breakdown = RewardBreakdown()

        if success:
            # Big bonus for successful completion
            breakdown.completeness = 1.0
            breakdown.progress_bonus = 0.5

            # Compute final accuracy
            if ground_truth:
                extracted = observation.get_extraction_dict()
                correct = sum(
                    1 for k, v in ground_truth.items()
                    if k in extracted and self._values_match(extracted[k], v)
                )
                breakdown.accuracy = correct / len(ground_truth)
            else:
                breakdown.accuracy = observation.extraction_progress

            # Efficiency bonus for fast completion. Clamped at zero (and
            # guarded against a zero budget) so an over-budget episode is
            # not punished through a negative "bonus".
            max_steps = self.settings.max_steps_per_episode
            if max_steps > 0:
                breakdown.efficiency = max(
                    0.0, 1.0 - (observation.step_number / max_steps)
                )
            else:
                breakdown.efficiency = 0.0
        else:
            # Partial credit for progress made
            breakdown.completeness = observation.extraction_progress * 0.5
            breakdown.error_penalty = 0.3

        total = breakdown.compute_total(self.weights)
        return total, breakdown