# --- Hugging Face repo-viewer residue, commented out so the file parses as Python ---
# feather-a10g-large-runtime / overlay / scripts / htm_gpu_micro_canary.py
# Jackoatmon's picture
# Update Feather a10g-large training runtime image
# fc102e3 verified
#!/usr/bin/env python3
"""Standalone GPU HTM micro-canary for HYDRA/Feather.
This intentionally bypasses the full language-model forward path and exercises
only the HTMLayer CUDA path that failed in the H200 optimal-strict canary. It
prints JSON lines so HF job logs can be parsed mechanically.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
import traceback
from pathlib import Path
from typing import Any
import torch
def ensure_repo_on_path() -> None:
    """Put the Feather repo root on sys.path so overlay imports resolve.

    Works whether the script runs from /app/scripts or from a repo-root
    checkout; the first candidate containing subsystems/htm.py wins.
    """
    script_path = Path(__file__).resolve()
    roots: list[Path | None] = [Path('/workspace/feather')]
    roots.append(script_path.parents[1] if len(script_path.parents) > 1 else None)
    for root in roots:
        if root is None:
            continue
        if not (root / 'subsystems' / 'htm.py').exists():
            continue
        entry = str(root)
        if entry not in sys.path:
            sys.path.insert(0, entry)
        return
def build_htm_env(mode: str) -> dict[str, str]:
    """Map a diagnostic mode name onto the HYDRA_* env-var overrides.

    Raises ValueError for anything other than the three known modes.
    """
    if mode not in {"batched-fused", "fused", "cuda"}:
        raise ValueError(f"unknown mode: {mode}")
    fused_flag = "0" if mode == "cuda" else "1"
    batched_flag = "1" if mode == "batched-fused" else "0"
    # Strict component checking is only meaningful for batched-fused, where a
    # missing batched entrypoint must fail loudly.  The other two modes are
    # deliberate bisection aids and may legitimately exercise narrower paths.
    strict_flag = "1" if mode == "batched-fused" else "0"
    return {
        "HYDRA_FORCE_HTM_CPU": "0",
        "HYDRA_HTM_FUSED": fused_flag,
        "HYDRA_HTM_BATCHED_FUSED": batched_flag,
        "HYDRA_STRICT_OPTIMAL_COMPONENTS": strict_flag,
    }
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--mode", choices=["batched-fused", "fused", "cuda"], default="batched-fused")
parser.add_argument("--batch", type=int, default=int(os.environ.get("HYDRA_BATCH_SIZE", "4")))
parser.add_argument("--seq", type=int, default=int(os.environ.get("HYDRA_HTM_MICRO_SEQ", os.environ.get("HYDRA_MAX_SEQ_LEN", "512"))))
parser.add_argument("--input-bits", type=int, default=int(os.environ.get("HYDRA_HTM_INPUT_BITS", "16384")))
parser.add_argument("--n-columns", type=int, default=int(os.environ.get("HYDRA_HTM_COLUMNS", "2048")))
parser.add_argument("--cells-per-column", type=int, default=int(os.environ.get("HYDRA_HTM_CELLS_PER_COLUMN", "32")))
parser.add_argument("--active-bits", type=int, default=int(os.environ.get("HYDRA_HTM_ACTIVE_BITS", "256")))
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--learn", action="store_true")
parser.add_argument("--sync-each", action="store_true", help="use HTMLayer.forward instead of forward_async/forward_await")
parser.add_argument("--dry-run", action="store_true")
return parser.parse_args(argv)
def emit(event: str, **payload: Any) -> None:
    """Print one machine-parseable JSON log line (keys sorted, stdout flushed)."""
    record: dict[str, Any] = dict(payload)
    record["event"] = event
    print(json.dumps(record, sort_keys=True), flush=True)
def make_sparse_sdr(*, batch: int, seq: int, input_bits: int, active_bits: int, device: str, seed: int):
    """Build a (batch, seq, input_bits) uint8 SDR with exactly `active_bits`
    ones per (batch, time) position, deterministically from `seed`.

    Generation happens on CPU (seeded CPU generator) and the finished tensor
    is then moved to `device`.
    """
    import torch

    if not (0 < active_bits <= input_bits):
        raise ValueError("active_bits must be in [1, input_bits]")
    rng = torch.Generator(device="cpu")
    rng.manual_seed(seed)
    out = torch.zeros((batch, seq, input_bits), dtype=torch.uint8, device="cpu")
    # One randperm draw per position, in (batch-major, time-minor) order, so
    # the bit pattern is a pure function of the seed.
    for row in range(batch):
        for step in range(seq):
            picks = torch.randperm(input_bits, generator=rng)[:active_bits]
            out[row, step, picks] = 1
    return out.to(device, non_blocking=False)
def _plan_payload(args: argparse.Namespace, env: dict[str, str]) -> dict[str, Any]:
return {
"mode": args.mode,
"shape": {"batch": args.batch, "seq": args.seq, "input_bits": args.input_bits},
"htm": {"n_columns": args.n_columns, "cells_per_column": args.cells_per_column, "active_bits": args.active_bits},
"learn": bool(args.learn),
"sync_each": bool(args.sync_each),
"env": env,
}
def main(argv: list[str] | None = None) -> int:
    """Run the HTM GPU micro-canary end to end; return 0 on success.

    Ordering is load-bearing: the HYDRA_* overrides must be written into
    os.environ *before* `subsystems.htm` is imported, since the HTM
    subsystem presumably reads those flags at import/construction time
    (TODO confirm against subsystems/htm.py).
    """
    args = parse_args(argv)
    env = build_htm_env(args.mode)
    # Apply mode-specific env overrides before any HTM code is imported.
    os.environ.update(env)
    emit("plan", **_plan_payload(args, env))
    if args.dry_run:
        # Dry run only echoes the planned configuration; no CUDA, no imports.
        return 0
    import torch
    ensure_repo_on_path()
    from subsystems.htm import HTMLayer
    # Log CUDA availability before the hard requirement check so failed runs
    # still leave a parseable record of the device state.
    emit(
        "cuda_state",
        torch_cuda_available=torch.cuda.is_available(),
        device_count=torch.cuda.device_count() if torch.cuda.is_available() else 0,
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    )
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for HTM GPU micro-canary")
    device = "cuda"
    # Deterministic, CPU-generated SDR, moved onto the GPU by the helper.
    sdr = make_sparse_sdr(
        batch=args.batch,
        seq=args.seq,
        input_bits=args.input_bits,
        active_bits=args.active_bits,
        device=device,
        seed=args.seed,
    )
    emit("sdr_ready", dtype=str(sdr.dtype), shape=list(sdr.shape), active_total=int(sdr.sum().item()))
    layer = HTMLayer(
        input_bits=args.input_bits,
        n_columns=args.n_columns,
        cells_per_column=args.cells_per_column,
        batch_size=args.batch,
        seed=args.seed,
        learn=args.learn,
        use_gpu=True,
        reset_each_forward=True,
    ).to(device)
    if args.learn:
        layer.train()
    else:
        layer.eval()
    # getattr with defaults: tolerate HTMLayer versions without these privates.
    emit("layer_ready", use_gpu=bool(getattr(layer, "_use_gpu", False)), region_count=len(getattr(layer, "_regions", [])))
    start = time.perf_counter()
    if args.sync_each:
        # Synchronous path: a single blocking forward call.
        out = layer(sdr)
    else:
        # Async path: submit, log the handle contents (assumes a dict-like
        # handle with .keys() -- TODO confirm), then await the result.
        handle = layer.forward_async(sdr)
        emit("forward_submitted", handle_keys=sorted(handle.keys()))
        out = layer.forward_await(handle)
    # Synchronize so elapsed_ms covers the actual GPU work, not just launch.
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    emit("success", elapsed_ms=round(elapsed_ms, 3), output_shape=list(out.shape), output_dtype=str(out.dtype))
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer status to the shell as the process exit code.
    raise SystemExit(main())