# Provenance: Rohan03 — Sprint 2: runtime/state.py — typed RunState for durable execution
# Commit: 8da4ecb (verified)
"""
state.py — Typed run/session/node state for durable execution.
RunState captures everything needed to resume a run after interruption:
- Current node/step position
- Accumulated state data
- Pending actions
- Memory snapshot references
- Tool call idempotency keys
"""
from __future__ import annotations
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class RunStatus(str, Enum):
    """Lifecycle status of a run.

    Subclasses ``str`` so members compare equal to their string values
    and serialize cleanly (see ``RunState.to_dict``, which stores
    ``status.value``).
    """
    PENDING = "pending"      # created, not yet started
    RUNNING = "running"      # actively executing
    PAUSED = "paused"  # HITL or checkpoint pause
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # terminated with an error
    CANCELLED = "cancelled"  # stopped before completion
@dataclass
class NodeState:
    """State of a single node in a Flow/Graph.

    Tracks per-node execution status, input/output payloads, wall-clock
    timing, retry count, and an idempotency key for de-duplicating
    repeated executions.
    """
    node_id: str
    status: str = "pending"  # pending, running, completed, failed, skipped
    # Payload the node was invoked with.
    input_data: dict[str, Any] = field(default_factory=dict)
    # Payload produced when the node completes.
    output_data: dict[str, Any] = field(default_factory=dict)
    # time.time() timestamps; 0.0 means "not yet started/finished".
    started_at: float = 0.0
    finished_at: float = 0.0
    # Error message when status == "failed"; None otherwise.
    error: str | None = None
    # Incremented each time the node is (re)started — see RunState.mark_node_started.
    attempts: int = 0
    # Random 16-hex-char key, generated once per NodeState instance.
    idempotency_key: str = field(default_factory=lambda: uuid.uuid4().hex[:16])
@dataclass
class RunState:
    """
    Complete state of an execution run — serializable for checkpointing.
    Contains everything needed to resume from any point.
    """
    run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    session_id: str = ""
    status: RunStatus = RunStatus.PENDING
    purpose: str = ""
    # Position
    current_node: str = ""  # node currently (or most recently) executing
    current_step: int = 0
    max_steps: int = 20     # step budget for the run
    # State data
    data: dict[str, Any] = field(default_factory=dict)
    node_states: dict[str, NodeState] = field(default_factory=dict)
    # History (completed_nodes is kept duplicate-free, in completion order)
    completed_nodes: list[str] = field(default_factory=list)
    action_history: list[dict[str, Any]] = field(default_factory=list)
    # Timing (time.time() wall clock; finished_at == 0.0 means still open)
    started_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    finished_at: float = 0.0
    # Idempotency tracking (tool_call_id → result)
    tool_results_cache: dict[str, Any] = field(default_factory=dict)
    # Memory references (for post-resume sync)
    heuristic_count_at_checkpoint: int = 0
    experience_count_at_checkpoint: int = 0
    # Metadata
    metadata: dict[str, Any] = field(default_factory=dict)
    version: int = 1  # For state schema migration

    def _touch(self) -> None:
        """Bump the last-modified timestamp."""
        self.updated_at = time.time()

    def mark_node_started(self, node_id: str) -> None:
        """Record that *node_id* started (or restarted) executing.

        Creates the NodeState on first sight, increments its attempt
        counter, and makes it the current node.
        """
        node = self.node_states.setdefault(node_id, NodeState(node_id=node_id))
        node.status = "running"
        node.started_at = time.time()
        node.attempts += 1
        self.current_node = node_id
        self._touch()

    def mark_node_completed(self, node_id: str, output: dict[str, Any] | None = None) -> None:
        """Record successful completion of *node_id*, optionally storing its output.

        A falsy/empty *output* leaves the existing output_data untouched.
        """
        node = self.node_states.get(node_id)
        if node is not None:
            node.status = "completed"
            node.finished_at = time.time()
            if output:
                node.output_data = output
        # Fix: don't record duplicates when a retried node completes again.
        if node_id not in self.completed_nodes:
            self.completed_nodes.append(node_id)
        self._touch()

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Record failure of *node_id* with an error message."""
        node = self.node_states.get(node_id)
        if node is not None:
            node.status = "failed"
            node.error = error
            node.finished_at = time.time()
        self._touch()

    def get_idempotency_key(self, tool_name: str, args_hash: str) -> str:
        """Generate a stable idempotency key for a tool call.

        Stable for a given (run, step, tool, args) so a resumed run can
        re-derive the same key and hit the cache instead of re-calling.
        """
        return f"{self.run_id}:{self.current_step}:{tool_name}:{args_hash}"

    def has_cached_result(self, key: str) -> bool:
        """Return True if a tool result is already cached under *key*."""
        return key in self.tool_results_cache

    def get_cached_result(self, key: str) -> Any:
        """Return the cached tool result for *key*, or None if absent."""
        return self.tool_results_cache.get(key)

    def cache_result(self, key: str, result: Any) -> None:
        """Store *result* under idempotency *key* for replay after resume."""
        self.tool_results_cache[key] = result

    def to_dict(self) -> dict[str, Any]:
        """Serialize for checkpointing.

        Only the last 50 action-history entries are kept, to bound
        checkpoint size.
        """
        return {
            "run_id": self.run_id,
            "session_id": self.session_id,
            "status": self.status.value,
            "purpose": self.purpose,
            "current_node": self.current_node,
            "current_step": self.current_step,
            "max_steps": self.max_steps,
            "data": self.data,
            "node_states": {k: {
                "node_id": v.node_id, "status": v.status,
                "input_data": v.input_data, "output_data": v.output_data,
                # Fix: persist node timing so it survives a checkpoint round-trip.
                "started_at": v.started_at, "finished_at": v.finished_at,
                "attempts": v.attempts, "error": v.error,
                "idempotency_key": v.idempotency_key,
            } for k, v in self.node_states.items()},
            "completed_nodes": self.completed_nodes,
            "action_history": self.action_history[-50:],  # Keep last 50
            "started_at": self.started_at,
            "updated_at": self.updated_at,
            "finished_at": self.finished_at,
            "tool_results_cache": self.tool_results_cache,
            "heuristic_count_at_checkpoint": self.heuristic_count_at_checkpoint,
            "experience_count_at_checkpoint": self.experience_count_at_checkpoint,
            "metadata": self.metadata,
            "version": self.version,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "RunState":
        """Deserialize from checkpoint.

        Missing keys fall back to safe defaults so older checkpoints
        (e.g. ones written before node timing was serialized) still load.
        """
        state = cls(
            # Fix: a checkpoint lacking a run_id gets a fresh one instead of "".
            run_id=d.get("run_id") or uuid.uuid4().hex[:12],
            session_id=d.get("session_id", ""),
            status=RunStatus(d.get("status", "pending")),
            purpose=d.get("purpose", ""),
            current_node=d.get("current_node", ""),
            current_step=d.get("current_step", 0),
            max_steps=d.get("max_steps", 20),
            data=d.get("data", {}),
            completed_nodes=d.get("completed_nodes", []),
            action_history=d.get("action_history", []),
            started_at=d.get("started_at", 0),
            updated_at=d.get("updated_at", 0),
            finished_at=d.get("finished_at", 0),
            tool_results_cache=d.get("tool_results_cache", {}),
            heuristic_count_at_checkpoint=d.get("heuristic_count_at_checkpoint", 0),
            experience_count_at_checkpoint=d.get("experience_count_at_checkpoint", 0),
            metadata=d.get("metadata", {}),
            version=d.get("version", 1),
        )
        for k, v in d.get("node_states", {}).items():
            state.node_states[k] = NodeState(
                node_id=v.get("node_id", k),
                status=v.get("status", "pending"),
                input_data=v.get("input_data", {}),
                output_data=v.get("output_data", {}),
                # Fix: restore timing serialized by to_dict (0.0 for old checkpoints).
                started_at=v.get("started_at", 0.0),
                finished_at=v.get("finished_at", 0.0),
                attempts=v.get("attempts", 0),
                error=v.get("error"),
                idempotency_key=v.get("idempotency_key", ""),
            )
        return state