Spaces:
Sleeping
Sleeping
| """scripts/validate_env.py - the Section 1.1 environment validation. | |
| Five gates, in order: | |
| 1. Imports succeed (catches install issues). | |
| 2. Stim generates a tiny distance-3 surface code. | |
| 3. PyMatching decodes 100 syndromes. | |
| 4. Logical-error rate at p=0.001 is in the expected range. | |
| 5. ``DecoderEnvironment`` reset+step works end-to-end (proves the wire | |
| contract is intact). | |
| Run with:: | |
| .venv/bin/python -m scripts.validate_env | |
| Exit code is 0 iff every gate passes. The participant guide explicitly | |
| warns: *"if any of these fail on any team member's machine, fix it now - | |
| not at 11pm on Day 1."* | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| import time | |
| from typing import Iterable | |
GATES = []


def gate(name: str):
    """Build a decorator that registers a function as a named validation gate.

    ``gate(name)`` returns a decorator. Applying it appends ``(name, fn)`` to
    the module-level ``GATES`` list and hands ``fn`` back unchanged. Because
    ``main`` iterates ``GATES`` front to back, gates run in the order they
    were decorated.
    """

    def register(check):
        GATES.append((name, check))
        return check

    return register
def _ok(name: str, msg: str = "") -> None:
    """Print a PASS line for gate *name*, followed by *msg* when it is non-empty."""
    if msg:
        print(f" PASS {name} {msg}")
    else:
        print(f" PASS {name}")
def _fail(name: str, msg: str) -> None:
    """Print a FAIL line for gate *name* with the failure reason *msg*."""
    report = f" FAIL {name} -- {msg}"
    print(report)
| # --------------------------------------------------------------------------- # | |
@gate("imports")
def _imports() -> None:
    """Gate 1: every runtime dependency and project module imports cleanly.

    Catches install issues before later gates try to use these packages.
    On success, prints the stim, pymatching and qubit_medic versions.

    Raises:
        ImportError: (or whatever the failing import raises) when a
            dependency or project module cannot be loaded.
    """
    # Imports stay inside the function so a broken install is reported as a
    # FAIL line by main() instead of crashing the script at import time.
    import stim
    import pymatching
    import numpy  # noqa: F401
    import fastapi  # noqa: F401
    import pydantic  # noqa: F401

    import qubit_medic
    import qubit_medic.config
    import qubit_medic.models
    import qubit_medic.prompts
    import qubit_medic.server.physics
    import qubit_medic.server.rewards
    import qubit_medic.server.curriculum
    import qubit_medic.server.environment

    print(f" stim={stim.__version__} pymatching={pymatching.__version__} "
          f"qubit_medic={qubit_medic.__version__}")
@gate("stim-circuit")
def _stim_gen() -> None:
    """Gate 2: Stim builds the primary-level circuit with the expected layout.

    Builds the circuit and detector error model for ``primary_level()``, then
    checks the extracted layout: 9 data qubits, 8 ancilla qubits and
    ``z_observable_support == (1, 3, 5)``. These are the numbers the module
    docstring's "tiny distance-3 surface code" is expected to yield.

    Raises:
        AssertionError: if any layout property differs from the expected
            value. The checks use explicit raises rather than ``assert`` so
            they still run under ``python -O``.
    """
    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem, extract_layout

    circuit = build_circuit(primary_level())
    dem = build_dem(circuit)
    layout = extract_layout(circuit)

    if layout.num_data_qubits != 9:
        raise AssertionError(
            f"expected 9 data qubits, got {layout.num_data_qubits}"
        )
    if layout.num_ancilla_qubits != 8:
        raise AssertionError(
            f"expected 8 ancilla qubits, got {layout.num_ancilla_qubits}"
        )
    if layout.z_observable_support != (1, 3, 5):
        raise AssertionError(
            "expected Z-observable support (1, 3, 5), "
            f"got {layout.z_observable_support}"
        )

    print(f" circuit={len(str(circuit))} chars, DEM={len(str(dem))} chars, "
          f"obs_support={layout.z_observable_support}")
@gate("pymatching-decode")
def _pm_decoding() -> None:
    """Gate 3: PyMatching decodes 100 sampled syndromes.

    Samples 100 shots from the primary-level circuit (sampler seed 42) and
    decodes them with a matcher built from the circuit's detector error
    model. It then prints the fraction of shots whose predicted observables
    disagree with the sampled ones.

    Only a crash or a prediction/observable shape mismatch fails this gate.
    The rate itself is bounded separately by ``_ler_range``.

    Raises:
        AssertionError: if ``decode_batch`` returns predictions whose shape
            differs from the sampled observables.
    """
    import numpy as np
    import pymatching

    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem

    circuit = build_circuit(primary_level())
    dem = build_dem(circuit)
    sampler = circuit.compile_detector_sampler(seed=42)
    det, obs = sampler.sample(100, separate_observables=True)
    matcher = pymatching.Matching.from_detector_error_model(dem)
    pred = matcher.decode_batch(det)
    if pred.shape != obs.shape:
        # A mismatch would broadcast silently in ``pred != obs`` and yield a
        # meaningless rate instead of an error.
        raise AssertionError(
            f"decode_batch returned shape {pred.shape}, expected {obs.shape}"
        )
    err_rate = float(np.mean(np.any(pred != obs, axis=1)))
    print(f" logical-error rate (100 shots): {err_rate:.4f}")
@gate("ler-range")
def _ler_range() -> None:
    """Gate 4: PyMatching's logical-error rate is at most 1%.

    Samples 5000 shots from the primary-level circuit (sampler seed 2024),
    decodes them with PyMatching and requires the fraction of mis-predicted
    shots to lie in the inclusive range ``[0.0, 0.01]``.
    ``primary_level()`` is assumed to be the distance-3, p=0.001
    configuration named in the module docstring (not verified here).

    Raises:
        AssertionError: if the prediction shape does not match the sampled
            observables, or if the measured rate falls outside the range.
    """
    import numpy as np
    import pymatching

    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem

    circuit = build_circuit(primary_level())
    dem = build_dem(circuit)
    sampler = circuit.compile_detector_sampler(seed=2024)
    det, obs = sampler.sample(5000, separate_observables=True)
    matcher = pymatching.Matching.from_detector_error_model(dem)
    pred = matcher.decode_batch(det)
    if pred.shape != obs.shape:
        # Without this check a zero-column prediction array would broadcast
        # against ``obs`` and produce an LER of 0.0, falsely passing the gate.
        raise AssertionError(
            f"decode_batch returned shape {pred.shape}, expected {obs.shape}"
        )
    err = float(np.mean(np.any(pred != obs, axis=1)))
    expected_lo, expected_hi = 0.0, 0.01
    if not (expected_lo <= err <= expected_hi):
        raise AssertionError(
            f"PyMatching LER {err:.4f} outside [{expected_lo}, {expected_hi}]"
        )
    print(f" PyMatching LER on 5000 shots: {err:.4f} "
          f"(expected ~0.001 - 0.01)")
@gate("env-roundtrip")
def _env_roundtrip() -> None:
    """Gate 5: reset + step round-trip through ``LocalDecoderClient``.

    1. Resets a forced ``L2_target`` episode (seed 1) and checks that it
       reports distance 3, 3 rounds and the requested curriculum level.
    2. Steps that episode with the all-zeros policy, i.e. a completion that
       claims no errors, and checks that the step result is terminal and
       carries a ``rewards`` breakdown.
    3. Resets a second, forced ``L1_warmup`` episode (seed 2) and checks
       distance 3 with 1 round.

    Only the all-zeros policy is actually stepped.

    Raises:
        AssertionError: on any mismatch. The checks use explicit raises so
            they still run under ``python -O``.
    """
    from qubit_medic.client.client import LocalDecoderClient
    from qubit_medic.prompts import format_completion

    client = LocalDecoderClient()

    obs = client.reset(forced_level="L2_target", seed=1)
    if not (obs.distance == 3 and obs.rounds == 3):
        raise AssertionError(
            "L2_target reset: expected distance=3, rounds=3; "
            f"got distance={obs.distance}, rounds={obs.rounds}"
        )
    if obs.curriculum_level != "L2_target":
        raise AssertionError(
            f"expected curriculum_level 'L2_target', got {obs.curriculum_level!r}"
        )

    # All-zeros policy: claim no errors.
    result = client.step(
        raw_response=format_completion([], []),
        episode_id=obs.episode_id,
    )
    if result.done is not True:
        raise AssertionError(f"expected done=True after step, got {result.done!r}")
    if "rewards" not in result.info:
        raise AssertionError("step result info is missing the 'rewards' entry")
    print(f" reset->step round-trip ok; "
          f"all-zeros total reward={result.reward:.3f}, "
          f"breakdown={result.info['rewards']}")

    # Second episode under forced L1, reset only.
    obs2 = client.reset(forced_level="L1_warmup", seed=2)
    if not (obs2.distance == 3 and obs2.rounds == 1):
        raise AssertionError(
            "L1_warmup reset: expected distance=3, rounds=1; "
            f"got distance={obs2.distance}, rounds={obs2.rounds}"
        )
    print(f" L1 warmup reset OK; prompt is {len(obs2.prompt)} chars long")
| # --------------------------------------------------------------------------- # | |
def main(argv: Iterable[str] = ()) -> int:
    """Run every registered gate and return the process exit code.

    Gates in ``GATES`` run in registration order. An exception marks the
    gate as failed, but the remaining gates still run, so one invocation
    reports every problem.

    Args:
        argv: command-line arguments passed by the ``__main__`` entry point.
            They are currently unused.

    Returns:
        0 when at least one gate is registered and every gate passed,
        otherwise 1.
    """
    print("Qubit-Medic environment validation")
    print("=" * 60)
    if not GATES:
        # An empty registry would otherwise print "all 0 gates passed" and
        # exit 0 without having validated anything.
        print("no gates registered -- nothing was validated")
        print("=" * 60)
        return 1
    failures = 0
    started = time.monotonic()
    for name, fn in GATES:
        try:
            fn()
        except Exception as exc:  # noqa: BLE001 - we want to keep going
            _fail(name, repr(exc))
            failures += 1
        else:
            _ok(name)
    elapsed = time.monotonic() - started
    print("=" * 60)
    if failures:
        print(f"{failures} gate(s) failed in {elapsed:.2f}s")
        return 1
    print(f"all {len(GATES)} gates passed in {elapsed:.2f}s")
    return 0
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the exit status.
    raise SystemExit(main(sys.argv[1:]))