shank committed on
Commit
dc8001b
·
1 Parent(s): 4bac574

Fix eval device selection with CUDA-safe fallback

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +36 -2
training/train_grpo.py CHANGED
@@ -253,6 +253,40 @@ model = FastLanguageModel.get_peft_model(
253
  )
254
  print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  # ── Reward function ───────────────────────────────────────────────────────────
257
  calculator = DebugRewardCalculator()
258
 
@@ -305,7 +339,7 @@ def run_baseline(n: int = 20) -> dict:
305
  solved = 0
306
  for bug in bugs:
307
  prompt = bug_to_prompt(bug)
308
- inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
309
  with torch.no_grad():
310
  out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
311
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
@@ -379,7 +413,7 @@ post_rewards = []
379
  post_solved = 0
380
  for bug in bugs:
381
  prompt = bug_to_prompt(bug)
382
- inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
383
  with torch.no_grad():
384
  out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
385
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
253
  )
254
  print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
255
 
256
+ # ── Runtime device selection ──────────────────────────────────────────────────
257
+ def _select_runtime_device(model) -> str:
258
+ """
259
+ Pick the safest generation device without forcing CUDA init on broken drivers.
260
+ """
261
+ def _cuda_usable() -> bool:
262
+ try:
263
+ if not torch.cuda.is_available():
264
+ return False
265
+ # Force lightweight CUDA init probe.
266
+ _ = torch.zeros(1, device="cuda")
267
+ return True
268
+ except Exception as e:
269
+ print(f"WARNING: CUDA initialization failed ({e}). Falling back to CPU.")
270
+ return False
271
+
272
+ # Prefer model's current device when available.
273
+ try:
274
+ model_device = str(next(model.parameters()).device)
275
+ if model_device.startswith("cuda") and not _cuda_usable():
276
+ return "cpu"
277
+ return model_device
278
+ except Exception:
279
+ pass
280
+
281
+ # Fallback to torch capability checks.
282
+ if _cuda_usable():
283
+ return "cuda"
284
+ return "cpu"
285
+
286
+
287
+ RUNTIME_DEVICE = _select_runtime_device(model)
288
+ print(f"Using generation/training runtime device: {RUNTIME_DEVICE}")
289
+
290
  # ── Reward function ───────────────────────────────────────────────────────────
291
  calculator = DebugRewardCalculator()
292
 
 
339
  solved = 0
340
  for bug in bugs:
341
  prompt = bug_to_prompt(bug)
342
+ inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
343
  with torch.no_grad():
344
  out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
345
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
413
  post_solved = 0
414
  for bug in bugs:
415
  prompt = bug_to_prompt(bug)
416
+ inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
417
  with torch.no_grad():
418
  out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
419
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)