shank committed
Commit 73f957d · Parent(s): 8b16369

Optimize for Kaggle P100: float16, batch=1, grad_accum=8, num_gen=4, max_completion=256, lora_r=8

Files changed:
- requirements_kaggle.txt +9 -0
- training/train_grpo.py +16 -16
requirements_kaggle.txt
ADDED

@@ -0,0 +1,9 @@
+# Kaggle P100 — torch is pre-installed, skip it
+# pip install -r requirements_kaggle.txt
+wandb==0.18.7
+datasets==3.0.2
+transformers==4.46.3
+accelerate==1.0.1
+trl==0.14.0
+bitsandbytes==0.43.3
+peft==0.13.2
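Kaggle's base image ships its own torch build, hence the comment telling you to skip it. A quick post-install sanity check that the pins above actually won over the preinstalled copies (a minimal sketch; the PINS dict just restates the file):

    from importlib.metadata import version

    # Pins copied from requirements_kaggle.txt above.
    PINS = {
        "wandb": "0.18.7",
        "datasets": "3.0.2",
        "transformers": "4.46.3",
        "accelerate": "1.0.1",
        "trl": "0.14.0",
        "bitsandbytes": "0.43.3",
        "peft": "0.13.2",
    }

    for pkg, pinned in PINS.items():
        installed = version(pkg)
        flag = "OK" if installed == pinned else f"MISMATCH (pinned {pinned})"
        print(f"{pkg}=={installed} {flag}")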
training/train_grpo.py
CHANGED
@@ -1,8 +1,8 @@
 """
 AgentDebuggerEnv — GRPO Training Script
-Model: Qwen2.5-Coder-7B-Instruct (4-bit quantized via …
+Model: Qwen2.5-Coder-7B-Instruct (4-bit quantized via bitsandbytes)
 Algorithm: GRPO (Group Relative Policy Optimization) via HuggingFace TRL
-GPU: …
+GPU: Kaggle P100 (16GB) — float16 only, no bfloat16

 Usage:
     # Local reward sanity-check (no GPU, no model loading):
@@ -262,7 +262,7 @@ print(f"Loading {MODEL_NAME}...")
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_compute_dtype=torch.float16,  # P100 has no bfloat16 hardware support
     bnb_4bit_use_double_quant=True,
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
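The dtype change in this hunk is forced by hardware: the P100 is a Pascal card (compute capability 6.0), and bfloat16 arithmetic needs Ampere (8.0) or newer. A one-cell check on whatever GPU the Kaggle kernel assigned (torch.cuda.is_bf16_supported() is the relevant torch API):

    import torch

    # On a Kaggle P100 this prints roughly:
    #   Tesla P100-PCIE-16GB (6, 0) False
    # which is why the compute dtype above must be float16.
    print(torch.cuda.get_device_name(0),
          torch.cuda.get_device_capability(0),
          torch.cuda.is_bf16_supported())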
@@ -274,12 +274,12 @@ model = AutoModelForCausalLM.from_pretrained(
     quantization_config=bnb_config,
     device_map="auto",
     trust_remote_code=True,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.float16,  # P100 has no bfloat16 hardware support
 )
 model.config.use_cache = False

 lora_config = LoraConfig(
-    r=16,
+    r=8,  # P100: 16GB VRAM, halved from r=16
     lora_alpha=16,
     target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                     "gate_proj", "up_proj", "down_proj"],
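Note that lora_alpha stays at 16 while r drops to 8, so the LoRA scaling factor alpha/r doubles from 1 to 2; only the adapter parameter count is halved. Attaching the adapter typically looks like this (a sketch: prepare_model_for_kbit_training and get_peft_model are the standard peft calls, but the script's exact wiring and any lora_dropout/task_type values are not visible in this diff):

    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    lora_config = LoraConfig(
        r=8,                # halved from 16 for the P100's 16GB
        lora_alpha=16,      # unchanged, so alpha/r scaling doubles (1 -> 2)
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",  # assumed; not shown in the hunk
    )

    model = prepare_model_for_kbit_training(model)  # grad/cast fixups for 4-bit training
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()  # a fraction of a percent of the 7B weights at r=8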
@@ -380,7 +380,7 @@ def run_baseline(n: int = 20) -> dict:
     prompt = bug_to_prompt(bug)
     inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
     with torch.no_grad():
-        out = model.generate(**inputs, max_new_tokens=…
+        out = model.generate(**inputs, max_new_tokens=200, temperature=0.1, do_sample=False)
     completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
     r = reward_fn([completion], [prompt], bug_metadata=[bug])
     rewards.append(r[0])
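One side note on this evaluation call: with do_sample=False decoding is greedy, and transformers ignores temperature (recent versions warn that the flag is unused), so temperature=0.1 is inert here. The minimal equivalent call:

    # Greedy decoding; equivalent to the hunk above because temperature
    # only takes effect when do_sample=True.
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=200, do_sample=False)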
@@ -407,16 +407,16 @@ def make_dataset(step: int) -> Dataset:
 config = GRPOConfig(
     output_dir=CHECKPOINT_DIR,
     max_steps=MAX_STEPS,
-    per_device_train_batch_size=…
-    gradient_accumulation_steps=…
-    learning_rate=2e-5,
+    per_device_train_batch_size=1,  # P100 16GB: must be 1
+    gradient_accumulation_steps=8,  # effective batch = 8 (compensates for batch=1)
+    learning_rate=2e-5,
     lr_scheduler_type="cosine",
-    warmup_steps=…
-    num_generations=8,
-    max_completion_length=512,
-    temperature=0.9,
-    logging_steps=5…
-    save_steps=50 if args.test else …
+    warmup_steps=10 if args.test else 30,
+    num_generations=4,  # P100: halved from 8 to fit in 16GB
+    max_completion_length=256,  # P100: halved from 512 to fit in 16GB
+    temperature=0.9,
+    logging_steps=5,
+    save_steps=50 if args.test else 50,
     report_to="wandb" if WANDB_API_KEY else "none",
 )
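The call site that consumes this config is outside the diff; with trl==0.14.0 it would be wired up roughly like this (a sketch: reward_fn, make_dataset, and lora_config mirror names used elsewhere in the script):

    from trl import GRPOTrainer

    trainer = GRPOTrainer(
        model=model,               # the 4-bit float16 base model
        reward_funcs=reward_fn,    # same reward used by run_baseline()
        args=config,               # the GRPOConfig from this hunk
        train_dataset=make_dataset(0),
        peft_config=lora_config,   # attach the r=8 adapters
    )
    trainer.train()

The memory trade is plain arithmetic: 1 sequence per device × 8 accumulation steps still yields 8 completions per optimizer step, each now capped at 256 rather than 512 tokens, sampled in groups of 4 instead of 8.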
@@ -454,7 +454,7 @@ for bug in bugs:
     prompt = bug_to_prompt(bug)
     inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
     with torch.no_grad():
-        out = model.generate(**inputs, max_new_tokens=…
+        out = model.generate(**inputs, max_new_tokens=200, temperature=0.1, do_sample=False)
     completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
     r = reward_fn([completion], [prompt], bug_metadata=[bug])
     post_rewards.append(r[0])
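This post-training loop mirrors run_baseline(), so the natural readout is the mean reward delta (a sketch; rewards and post_rewards are the lists the two loops build, and rewards may first need to be returned from run_baseline()):

    import statistics

    # Pre- vs post-training mean reward on the same bug set.
    delta = statistics.mean(post_rewards) - statistics.mean(rewards)
    print(f"baseline={statistics.mean(rewards):.3f} "
          f"post={statistics.mean(post_rewards):.3f} delta={delta:+.3f}")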