Marlin Lee committed on
Commit 02c5c77 · 1 Parent(s): 3529287

Fall back to CPU for backbone+SAE inference when no CUDA available

Files changed (1)
  1. scripts/explorer/inference.py +1 -4
scripts/explorer/inference.py CHANGED
@@ -47,16 +47,13 @@ def get_gpu_runner():
         return _gpu_runner
     if not args.sae_path:
         return None
-    if not torch.cuda.is_available():
-        print("[GPU runner] No CUDA device — on-the-fly inference disabled.")
-        return None

     src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
     sys.path.insert(0, src_dir)
     from backbone_runners import load_batched_backbone
     from precompute_utils import load_sae, extract_tokens

-    dev = torch.device("cuda:0")
+    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     d_model = _all_datasets[0]['d_model']  # SAE output dim fixed to primary dataset
     print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {dev} ...")
     fwd, d_hidden, n_reg, tfm = load_batched_backbone(args.backbone, args.layer, dev)
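
For context, the behavioral change is in how the runner picks its device: instead of disabling on-the-fly inference when CUDA is absent, it now selects "cuda:0" when available and otherwise falls back to "cpu", then loads the backbone and SAE onto whichever device it chose. The sketch below illustrates that fallback pattern in isolation; it is not code from this repo, and the small Linear model is only a stand-in for the real backbone + SAE stack.

import torch

def pick_device() -> torch.device:
    # Prefer the first CUDA GPU; fall back to CPU when none is available.
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dev = pick_device()
print(f"[GPU runner] Running on {dev}")

# Stand-in model: a small linear layer in place of the backbone + SAE.
model = torch.nn.Linear(16, 4).to(dev)

# Inputs must live on the same device as the model's weights.
x = torch.randn(2, 16, device=dev)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([2, 4])

On a CPU-only machine this runs end to end, just more slowly; before this commit, the equivalent path in get_gpu_runner() returned None and the explorer had no runner at all.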