shank committed on
Commit
b8172c5
·
1 Parent(s): 77156dd

Cuda returns false fixed

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +24 -3
training/train_grpo.py CHANGED
@@ -41,8 +41,29 @@ args = parser.parse_args()
41
  # requirements.txt only has torch (too large to install at runtime).
42
  # Everything else is installed here, after gradio is already up.
43
  # NOTE: mergekit intentionally excluded — conflicts with accelerate/peft/trl.
44
- # NOTE: torch excluded — installed at Docker build time via requirements.txt.
45
  if not args.test_local:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  _TRAIN_DEPS = [
47
  "wandb==0.18.7",
48
  "datasets==3.0.2",
@@ -425,7 +446,7 @@ def run_baseline(n: int = 20) -> dict:
425
  prompt = bug_to_prompt(bug)
426
  inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
427
  with torch.no_grad():
428
- out = model.generate(**inputs, max_new_tokens=200, temperature=0.1, do_sample=False)
429
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
430
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
431
  rewards.append(r[0])
@@ -499,7 +520,7 @@ for bug in bugs:
499
  prompt = bug_to_prompt(bug)
500
  inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
501
  with torch.no_grad():
502
- out = model.generate(**inputs, max_new_tokens=200, temperature=0.1, do_sample=False)
503
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
504
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
505
  post_rewards.append(r[0])
 
41
  # requirements.txt only has torch (too large to install at runtime).
42
  # Everything else is installed here, after gradio is already up.
43
  # NOTE: mergekit intentionally excluded — conflicts with accelerate/peft/trl.
 
44
  if not args.test_local:
45
+ # ── Ensure CUDA-enabled torch is present before anything else imports it ──
46
+ # The default PyPI torch wheel is CPU-only. We must install from the
47
+ # PyTorch CUDA index so that torch.cuda.is_available() returns True and
48
+ # device_map="auto" maps the model to GPU, not RAM.
49
+ import importlib.util, importlib
50
+ _needs_cuda_torch = True
51
+ if importlib.util.find_spec("torch") is not None:
52
+ import torch as _t
53
+ if _t.cuda.is_available():
54
+ _needs_cuda_torch = False
55
+ del _t
56
+ if _needs_cuda_torch:
57
+ print("Installing CUDA-enabled torch (cu121)...", flush=True)
58
+ _r = os.system(
59
+ f"{sys.executable} -m pip install -q --no-cache-dir "
60
+ "torch --index-url https://download.pytorch.org/whl/cu121"
61
+ )
62
+ if _r != 0:
63
+ print("ERROR: CUDA torch install failed.", flush=True)
64
+ sys.exit(1)
65
+ print("CUDA torch installed.", flush=True)
66
+
67
  _TRAIN_DEPS = [
68
  "wandb==0.18.7",
69
  "datasets==3.0.2",
 
446
  prompt = bug_to_prompt(bug)
447
  inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
448
  with torch.no_grad():
449
+ out = model.generate(**inputs, max_new_tokens=200, do_sample=False)
450
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
451
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
452
  rewards.append(r[0])
 
520
  prompt = bug_to_prompt(bug)
521
  inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
522
  with torch.no_grad():
523
+ out = model.generate(**inputs, max_new_tokens=200, do_sample=False)
524
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
525
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
526
  post_rewards.append(r[0])