"""
state.py — Typed run/session/node state for durable execution.

RunState captures everything needed to resume a run after interruption:
- Current node/step position
- Accumulated state data
- Pending actions
- Memory snapshot references
- Tool call idempotency keys
"""
| from __future__ import annotations |
|
|
| import time |
| import uuid |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from typing import Any |
|
|
|
|
class RunStatus(str, Enum):
    """Lifecycle status of a run.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value and can be rebuilt from it with
    ``RunStatus("<value>")``.
    """

    # Declared but not yet running.
    PENDING = "pending"
    # Currently executing.
    RUNNING = "running"
    # Suspended (presumably resumable — confirm against the executor).
    PAUSED = "paused"

    # Terminal outcomes.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
|
|
def _new_idempotency_key() -> str:
    """Return a fresh 16-character hex token cut from a random UUID4."""
    token = uuid.uuid4().hex
    return token[:16]


@dataclass
class NodeState:
    """Execution record for one node of a Flow/Graph.

    Only ``node_id`` is required; every other field has a default, and the
    mutable defaults are created per instance.
    """

    node_id: str
    # Free-form status string; this file uses "pending", "running",
    # "completed" and "failed".
    status: str = "pending"
    input_data: dict[str, Any] = field(default_factory=dict)
    output_data: dict[str, Any] = field(default_factory=dict)
    # Timestamps stay 0.0 until something records a time in them.
    started_at: float = 0.0
    finished_at: float = 0.0
    # Error text for the node, or None when no error has been recorded.
    error: str | None = None
    attempts: int = 0
    # Random per-instance key unless the caller supplies one.
    idempotency_key: str = field(default_factory=_new_idempotency_key)
|
|
|
|
@dataclass
class RunState:
    """
    Complete state of an execution run — serializable for checkpointing.

    Contains everything needed to resume from any point: the current
    position, accumulated data, per-node records, history, timestamps and
    a cache of tool results keyed by idempotency key.
    """

    # --- Identity and lifecycle -------------------------------------------
    run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    session_id: str = ""
    status: RunStatus = RunStatus.PENDING
    purpose: str = ""

    # --- Position -----------------------------------------------------------
    current_node: str = ""
    current_step: int = 0
    max_steps: int = 20

    # --- Accumulated data ---------------------------------------------------
    data: dict[str, Any] = field(default_factory=dict)
    node_states: dict[str, NodeState] = field(default_factory=dict)

    # --- History ------------------------------------------------------------
    completed_nodes: list[str] = field(default_factory=list)
    action_history: list[dict[str, Any]] = field(default_factory=list)

    # --- Timestamps (time.time() values) ------------------------------------
    started_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    # Not written by any method of this class; callers set it.
    finished_at: float = 0.0

    # --- Tool results, looked up by idempotency key -------------------------
    tool_results_cache: dict[str, Any] = field(default_factory=dict)

    # --- Counters recorded at checkpoint time -------------------------------
    # NOTE(review): presumably sizes of external memory stores — confirm.
    heuristic_count_at_checkpoint: int = 0
    experience_count_at_checkpoint: int = 0

    # --- Misc ---------------------------------------------------------------
    metadata: dict[str, Any] = field(default_factory=dict)
    version: int = 1

    def mark_node_started(self, node_id: str) -> None:
        """Record that ``node_id`` has begun (another) attempt.

        Creates the node record on first use, bumps its attempt counter and
        makes it the run's current node.
        """
        now = time.time()  # one reading so node and run timestamps agree
        node = self.node_states.get(node_id)
        if node is None:
            node = self.node_states[node_id] = NodeState(node_id=node_id)
        node.status = "running"
        node.started_at = now
        # Clear leftovers from a previous attempt so a running node does not
        # carry a stale error or an end time earlier than its start time.
        node.finished_at = 0.0
        node.error = None
        node.attempts += 1
        self.current_node = node_id
        self.updated_at = now

    def mark_node_completed(self, node_id: str, output: dict[str, Any] | None = None) -> None:
        """Record that ``node_id`` finished successfully.

        Args:
            node_id: Node to mark. If it has no record in ``node_states``
                only ``completed_nodes`` and ``updated_at`` are touched.
            output: Output to store on the node. ``None`` leaves the stored
                output untouched; any dict (including an empty one)
                replaces it.
        """
        now = time.time()
        node = self.node_states.get(node_id)
        if node is not None:
            node.status = "completed"
            node.finished_at = now
            # `is not None` so an explicit empty result overwrites output
            # left behind by an earlier attempt.
            if output is not None:
                node.output_data = output
        self.completed_nodes.append(node_id)
        self.updated_at = now

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Record that ``node_id`` failed with the given error text.

        A node without a record in ``node_states`` is left alone; only the
        run's ``updated_at`` is refreshed.
        """
        now = time.time()
        node = self.node_states.get(node_id)
        if node is not None:
            node.status = "failed"
            node.error = error
            node.finished_at = now
        self.updated_at = now

    def get_idempotency_key(self, tool_name: str, args_hash: str) -> str:
        """Generate a stable idempotency key for a tool call.

        The key combines the run id, the current step, the tool name and
        the caller-supplied argument hash, separated by colons.
        """
        return f"{self.run_id}:{self.current_step}:{tool_name}:{args_hash}"

    def has_cached_result(self, key: str) -> bool:
        """Return True if a tool result is stored under ``key``."""
        return key in self.tool_results_cache

    def get_cached_result(self, key: str) -> Any:
        """Return the tool result stored under ``key``, or None if absent."""
        return self.tool_results_cache.get(key)

    def cache_result(self, key: str, result: Any) -> None:
        """Store ``result`` under ``key``, replacing any previous value."""
        self.tool_results_cache[key] = result

    def to_dict(self) -> dict[str, Any]:
        """Serialize for checkpointing.

        Returns:
            A plain dict of the run's fields. ``status`` is stored as its
            string value, each node state is flattened to a dict (including
            its timestamps), and only the last 50 ``action_history``
            entries are kept. Nested containers are shared with this
            object, not copied.
        """
        node_states = {
            key: {
                "node_id": node.node_id,
                "status": node.status,
                "input_data": node.input_data,
                "output_data": node.output_data,
                # Timing is part of the node record and must survive a
                # checkpoint round-trip.
                "started_at": node.started_at,
                "finished_at": node.finished_at,
                "attempts": node.attempts,
                "error": node.error,
                "idempotency_key": node.idempotency_key,
            }
            for key, node in self.node_states.items()
        }
        return {
            "run_id": self.run_id,
            "session_id": self.session_id,
            "status": self.status.value,
            "purpose": self.purpose,
            "current_node": self.current_node,
            "current_step": self.current_step,
            "max_steps": self.max_steps,
            "data": self.data,
            "node_states": node_states,
            "completed_nodes": self.completed_nodes,
            "action_history": self.action_history[-50:],
            "started_at": self.started_at,
            "updated_at": self.updated_at,
            "finished_at": self.finished_at,
            "tool_results_cache": self.tool_results_cache,
            "heuristic_count_at_checkpoint": self.heuristic_count_at_checkpoint,
            "experience_count_at_checkpoint": self.experience_count_at_checkpoint,
            "metadata": self.metadata,
            "version": self.version,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "RunState":
        """Deserialize from checkpoint.

        Missing keys fall back to fixed defaults (empty strings/containers,
        zeros, ``"pending"`` status) rather than the dataclass default
        factories, so a missing ``run_id`` stays empty instead of being
        regenerated. Node timestamps absent from older checkpoints default
        to 0.0. Nested containers from ``d`` are reused, not copied.

        Raises:
            ValueError: If ``d["status"]`` is not a valid RunStatus value.
        """
        state = cls(
            run_id=d.get("run_id", ""),
            session_id=d.get("session_id", ""),
            status=RunStatus(d.get("status", "pending")),
            purpose=d.get("purpose", ""),
            current_node=d.get("current_node", ""),
            current_step=d.get("current_step", 0),
            max_steps=d.get("max_steps", 20),
            data=d.get("data", {}),
            completed_nodes=d.get("completed_nodes", []),
            action_history=d.get("action_history", []),
            started_at=d.get("started_at", 0),
            updated_at=d.get("updated_at", 0),
            finished_at=d.get("finished_at", 0),
            tool_results_cache=d.get("tool_results_cache", {}),
            heuristic_count_at_checkpoint=d.get("heuristic_count_at_checkpoint", 0),
            experience_count_at_checkpoint=d.get("experience_count_at_checkpoint", 0),
            metadata=d.get("metadata", {}),
            version=d.get("version", 1),
        )
        for key, raw in d.get("node_states", {}).items():
            state.node_states[key] = NodeState(
                node_id=raw.get("node_id", key),
                status=raw.get("status", "pending"),
                input_data=raw.get("input_data", {}),
                output_data=raw.get("output_data", {}),
                started_at=raw.get("started_at", 0.0),
                finished_at=raw.get("finished_at", 0.0),
                attempts=raw.get("attempts", 0),
                error=raw.get("error"),
                idempotency_key=raw.get("idempotency_key", ""),
            )
        return state
|
|