stabilizer-forge / models.py
ronitraj's picture
Upload folder using huggingface_hub
b1100bc verified
"""Data models for the StabilizerForge environment.
Action: a single Clifford gate (H, S, CX) or FINALIZE.
Observation: target stabilizers + circuit-so-far + current match fraction + bookkeeping.
"""
from __future__ import annotations
from typing import Literal
from openenv.core.env_server.types import Action, Observation
from pydantic import Field
class StabilizerAction(Action):
    """A single agent move: apply one Clifford gate, or end the episode.

    ``op`` is mandatory and limited to ``H``, ``S``, ``CX`` or ``FINALIZE``.
    ``qubits`` falls back to an empty list when omitted.  The number of
    indices expected for each op is stated in the field description only;
    nothing in this model enforces it.
    """

    # The Ellipsis default marks the field as required.
    op: Literal["H", "S", "CX", "FINALIZE"] = Field(
        default=...,
        description="Gate to apply, or FINALIZE to end the episode.",
    )

    # default_factory gives every instance its own fresh list.
    qubits: list[int] = Field(
        description="Target qubit indices. 1 for H/S, 2 for CX (control, target). Empty for FINALIZE.",
        default_factory=list,
    )
class StabilizerObservation(Observation):
    """Everything reported back to the agent after each step.

    Every field has a default, so the model can be built with no arguments.
    Fields fall into four groups: the task definition, the circuit built so
    far together with how well it matches the target, gate counters and
    budget, and per-step validity/status bookkeeping.
    """

    # --- Task definition -------------------------------------------------
    task_id: str = Field(description="Active task identifier.", default="")
    target_stabilizers: list[str] = Field(
        description="Target stabilizer generators as Pauli strings (e.g., 'XZZXI').",
        default_factory=list,
    )
    n_qubits: int = Field(description="Number of physical qubits.", default=0)
    connectivity_edges: list[list[int]] | None = Field(
        description="Adjacency edge list. None means all-to-all.",
        default=None,
    )

    # --- Circuit built so far and its agreement with the target ----------
    gates_so_far: list[str] = Field(
        description="Gates applied this episode, as Stim instruction strings.",
        default_factory=list,
    )
    current_circuit: str = Field(
        description="Concatenated Stim text of all gates emitted so far.",
        default="",
    )
    current_match: list[bool] = Field(
        description="Per-stabilizer preservation under current circuit.",
        default_factory=list,
    )
    match_fraction: float = Field(
        description="Fraction of target stabilizers preserved (0..1).",
        default=0.0,
    )

    # --- Gate counters and budget ----------------------------------------
    gates_emitted: int = Field(
        description="Number of valid gates applied.", default=0
    )
    cnot_count: int = Field(description="Number of CX gates applied.", default=0)
    nonadj_cnot_count: int = Field(
        description="Number of CX gates applied across non-adjacent qubits.",
        default=0,
    )
    gate_budget: int = Field(description="Hard cap on total gates.", default=0)
    # No description in the schema; presumably gate_budget minus gates used
    # so far -- confirm against the environment that fills it in.
    gate_budget_remaining: int = 0
    benchmark_optimum: int = Field(
        description="Reference encoding's gate count (volume-style).",
        default=0,
    )
    benchmark_optimum_2q: int = Field(
        description="Reference encoding's two-qubit gate count.",
        default=0,
    )

    # --- Validity / status bookkeeping -----------------------------------
    # These carry no schema descriptions.  Their exact semantics are set by
    # the environment code, which is not visible in this module.
    format_violations: int = 0
    consecutive_violations: int = 0
    last_action_valid: bool = True
    last_action_error: str = ""
    step_count: int = 0
    finalized: bool = False