# ghostexec/scripts/train_sft_then_grpo.py
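"""Two-stage trainer for the Ghostexec AI chief-of-staff policy.

Stage 1 (SFT): warm-start a QLoRA adapter on prompt/action pairs generated
from the environment's /reset briefings via a small rule-based heuristic.
Stage 2 (GRPO): refine the same adapter against per-step environment rewards
plus local shaping rewards (JSON format, semantic action match, anti-idle),
with an optional easy-to-full reward curriculum and EMA reward smoothing.

Example invocation (a sketch; assumes the Ghostexec environment server is
already running at --env-url):

    python scripts/train_sft_then_grpo.py \
        --model-preset small_iter_fast \
        --training-preset hackathon_turbo \
        --env-url http://127.0.0.1:8000 \
        --generate-sft-from-env
"""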
from __future__ import annotations
import argparse
import json
import random
import re
from pathlib import Path
from typing import Any
import requests
from transformers import TrainerCallback
LEGAL_ACTION_TYPES = [
"reply_email",
"archive_email",
"reschedule_meeting",
"cancel_meeting",
"complete_task",
"delegate_task",
"send_message",
"do_nothing",
]
MODEL_PRESETS: dict[str, str] = {
# Fast iteration winner preset: small, strong instruction following, QLoRA-friendly.
"small_iter_fast": "unsloth/Qwen2.5-3B-Instruct",
# Existing baseline used in this repo.
"balanced_3b": "unsloth/Llama-3.2-3B-Instruct",
# Larger option when compute budget is stable.
"bigger_4b": "unsloth/Qwen3-4B-Instruct-2507",
}
TRAINING_PRESETS: dict[str, dict[str, float | int | str]] = {
"hackathon_turbo": {
"max_sft_steps": 80,
"max_grpo_steps": 180,
"env_reward_scale": 1.00,
"local_reward_scale": 0.45,
"complexity_curriculum": "easy_to_full",
"curriculum_ramp_ratio": 0.65,
"sft_samples": 180,
# Optimizer / schedule knobs (stability-first defaults for iterative runs)
"sft_lr": 1.2e-5,
"sft_grad_accum": 8,
"grpo_lr": 3.0e-6,
"grpo_grad_accum": 8,
"grpo_beta": 0.08,
"reward_ema_decay": 0.35,
},
# Quicker loop for smoke iterations on weaker hardware.
"quick_smoke": {
"max_sft_steps": 30,
"max_grpo_steps": 80,
"env_reward_scale": 0.95,
"local_reward_scale": 0.35,
"complexity_curriculum": "easy_to_full",
"curriculum_ramp_ratio": 0.50,
"sft_samples": 90,
"sft_lr": 1.5e-5,
"sft_grad_accum": 4,
"grpo_lr": 4.0e-6,
"grpo_grad_accum": 4,
"grpo_beta": 0.06,
"reward_ema_decay": 0.25,
},
}
def _extract_briefing(reset_payload: dict[str, Any]) -> str:
obs = reset_payload.get("observation", reset_payload)
if isinstance(obs, dict):
return str(obs.get("echoed_message", "")).strip()
return ""
def _legal_action_heuristic(briefing: str) -> dict[str, Any]:
# Minimal heuristic used only for SFT warm-start data generation.
# Keeps the action schema valid and non-idle-biased.
lower = briefing.lower()
if "e01" in lower:
return {
"action_type": "reply_email",
"email_id": "e01",
"message_body": "Acknowledged. Sharing a concise update shortly.",
}
if "m02" in lower:
return {
"action_type": "reschedule_meeting",
"meeting_id": "m02",
"new_time": "2026-04-21T18:00:00",
"reason": "Resolve overlap with higher priority commitments.",
}
if "t06" in lower:
return {"action_type": "complete_task", "task_id": "t06"}
return {"action_type": random.choice(LEGAL_ACTION_TYPES)}
def generate_sft_jsonl_from_env(
env_url: str,
out_jsonl: Path,
samples: int = 120,
task_id: str = "phase2_core",
) -> None:
out_jsonl.parent.mkdir(parents=True, exist_ok=True)
rows: list[dict[str, str]] = []
for _ in range(samples):
r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": task_id}, timeout=30)
r.raise_for_status()
payload = r.json()
briefing = _extract_briefing(payload)
if not briefing:
continue
action = _legal_action_heuristic(briefing)
prompt = (
"You are Ghostexec AI Chief-of-Staff.\n"
"Output one valid GhostexecAction JSON only.\n\n"
f"{briefing}"
)
rows.append({"prompt": prompt, "completion": json.dumps(action, ensure_ascii=True)})
with out_jsonl.open("w", encoding="utf-8") as fh:
for row in rows:
fh.write(json.dumps(row, ensure_ascii=True) + "\n")
print(f"Wrote {len(rows)} SFT rows to {out_jsonl}")
def run_sft_then_grpo(
model_name: str,
env_url: str,
sft_jsonl: Path,
out_dir: Path,
env_reward_scale: float,
local_reward_scale: float,
max_sft_steps: int,
max_grpo_steps: int,
complexity_curriculum: str,
curriculum_ramp_ratio: float,
*,
sft_lr: float,
sft_grad_accum: int,
grpo_lr: float,
grpo_grad_accum: int,
grpo_beta: float,
reward_ema_decay: float,
) -> None:
try:
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
except Exception as exc: # pragma: no cover
raise RuntimeError(
"Missing training deps. Install unsloth, trl, datasets, transformers before running."
) from exc
out_dir.mkdir(parents=True, exist_ok=True)
def _trainable_lora_sum_abs(model) -> float:
total = 0.0
for n, p in model.named_parameters():
if not p.requires_grad:
continue
if "lora" not in n.lower():
continue
total += float(p.detach().abs().sum().item())
return total
policy, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=2048,
dtype=None,
load_in_4bit=True,
)
policy = FastLanguageModel.get_peft_model(
policy,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_alpha=16,
lora_dropout=0.0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
)
ds = load_dataset("json", data_files=str(sft_jsonl), split="train")
sft_cfg = SFTConfig(
output_dir=str(out_dir / "sft"),
max_steps=max_sft_steps,
per_device_train_batch_size=1,
gradient_accumulation_steps=sft_grad_accum,
learning_rate=sft_lr,
lr_scheduler_type="cosine",
warmup_ratio=0.06,
max_grad_norm=1.0,
adam_beta1=0.9,
adam_beta2=0.95,
logging_steps=5,
save_steps=max(10, max_sft_steps),
report_to=[],
)
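    # The SFT objective here is plain next-token prediction over the
    # concatenated prompt + completion text (no completion-only loss masking);
    # it serves only as a warm start before GRPO.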
    sft_trainer = SFTTrainer(
        model=policy,
        processing_class=tokenizer,
        train_dataset=ds,
        args=sft_cfg,
        # formatting_func already builds the full training text from prompt +
        # completion, so no separate dataset_text_field is passed.
        formatting_func=lambda ex: [f"{p}\n\n{c}" for p, c in zip(ex["prompt"], ex["completion"])],
    )
sft_before = _trainable_lora_sum_abs(policy)
sft_trainer.train()
sft_after = _trainable_lora_sum_abs(sft_trainer.model)
sft_delta = abs(sft_after - sft_before)
print(f"SFT LoRA delta(abs-sum): {sft_delta:.6f}")
if sft_delta <= 1e-6:
raise RuntimeError("SFT appears not to have updated LoRA weights (delta too small).")
sft_dir = out_dir / "sft_adapter"
sft_trainer.model.save_pretrained(sft_dir)
tokenizer.save_pretrained(sft_dir)
print(f"SFT complete. Adapter saved: {sft_dir}")
def _extract_json(text: str) -> dict[str, Any] | None:
m = re.search(r"\{.*\}", text, flags=re.S)
if not m:
return None
try:
obj = json.loads(m.group(0))
except Exception:
return None
return obj if isinstance(obj, dict) else None
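    # Reward probe: each scored completion triggers a fresh /reset followed by
    # a single /step with the parsed action. Unparseable JSON is penalized at
    # -0.25, while transport errors fall back to a neutral 0.0 reward.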
def _env_step_reward_from_completion(text: str) -> float:
payload = _extract_json(text)
if payload is None:
return -0.25
payload.setdefault("action_type", "do_nothing")
try:
r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": "phase2_core"}, timeout=30)
r.raise_for_status()
s = requests.post(
f"{env_url.rstrip('/')}/step",
json={"action": payload},
timeout=30,
)
s.raise_for_status()
raw = s.json()
except Exception:
return 0.0
rew = raw.get("reward")
if rew is None and isinstance(raw.get("observation"), dict):
rew = raw["observation"].get("reward", 0.0)
try:
return float(rew)
except Exception:
return 0.0
progress = {"step": 0, "total": max(1, max_grpo_steps)}
reward_ema_state = {"env": None}
class _ProgressCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs): # type: ignore[override]
progress["step"] = int(getattr(state, "global_step", progress["step"]))
return control
def _progress_frac() -> float:
return min(1.0, progress["step"] / progress["total"])
def _curriculum_phase_weight() -> float:
frac = _progress_frac()
ramp = max(0.05, min(1.0, curriculum_ramp_ratio))
if complexity_curriculum == "off":
return 1.0
# easy_to_full: start with strong scaffold guidance, then smoothly
# transition to full env-dominant optimization.
if frac >= ramp:
return 0.0
return max(0.0, 1.0 - (frac / ramp))
def _annealed_local_scale() -> float:
frac = _progress_frac()
base = local_reward_scale * (1.20 - 0.70 * frac)
return base * (1.0 + 0.70 * _curriculum_phase_weight())
def _annealed_env_scale() -> float:
w = _curriculum_phase_weight()
# Slightly downweight env reward in early easy phase to reduce variance,
# then recover to full strength by the end of ramp.
return env_reward_scale * (1.0 - 0.30 * w)
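    # Worked example with curriculum_ramp_ratio=0.65: at step 0 the phase
    # weight is 1.0, so the local scale is boosted by 1.20 * 1.70 = 2.04x and
    # the env scale reduced to 0.70x; once 65% of GRPO steps have passed the
    # weight reaches 0.0, restoring the env reward to full strength while the
    # local scale anneals from 0.745x down to 0.50x at the final step.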
def env_reward(completions, **_):
scale = _annealed_env_scale()
raw = [scale * _env_step_reward_from_completion(str(c)) for c in completions]
if reward_ema_decay <= 0.0:
return raw
batch_mean = sum(raw) / max(len(raw), 1)
prev = reward_ema_state["env"]
d = max(0.0, min(1.0, reward_ema_decay))
if prev is None:
smoothed_mean = batch_mean
else:
smoothed_mean = (1.0 - d) * prev + d * batch_mean
reward_ema_state["env"] = smoothed_mean
delta = smoothed_mean - batch_mean
return [r + delta for r in raw]
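    # Because GRPO normalizes rewards within each generation group, the uniform
    # EMA shift applied in env_reward should leave relative advantages
    # unchanged; its practical effect is to smooth the logged reward signal.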
def format_reward(completions, **_):
scale = _annealed_local_scale()
outs: list[float] = []
for c in completions:
txt = str(c).strip()
obj = _extract_json(txt)
if obj is None:
outs.append(-0.20 * scale)
continue
if obj.get("action_type") not in LEGAL_ACTION_TYPES:
outs.append(-0.20 * scale)
continue
# Encourage concise, parseable schema-correct JSON.
length_pen = -0.04 * scale if len(txt) > 500 else 0.0
outs.append(0.12 * scale + length_pen)
return outs
def semantic_action_reward(completions, prompts=None, **_):
scale = _annealed_local_scale()
outs: list[float] = []
for i, c in enumerate(completions):
obj = _extract_json(str(c))
if obj is None:
outs.append(-0.10 * scale)
continue
at = str(obj.get("action_type", ""))
ptxt = str(prompts[i] if prompts and i < len(prompts) else "").lower()
bonus = 0.0
if "critical" in ptxt and at == "reply_email":
bonus += 0.08
if "clash" in ptxt and at in ("reschedule_meeting", "cancel_meeting"):
bonus += 0.08
if ("overdue" in ptxt or "due soon" in ptxt) and at in ("complete_task", "delegate_task"):
bonus += 0.08
outs.append(scale * bonus)
return outs
def anti_idle_reward(completions, **_):
scale = _annealed_local_scale()
outs = []
for c in completions:
txt = str(c).lower()
outs.append((-0.20 if "do_nothing" in txt else 0.02) * scale)
return outs
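    # Note: TRL's GRPOTrainer requires the effective generation batch to be
    # divisible by num_generations (the exact rule varies across TRL releases);
    # if the config below raises a divisibility error, increase
    # per_device_train_batch_size to a multiple of num_generations.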
grpo_cfg = GRPOConfig(
output_dir=str(out_dir / "grpo"),
learning_rate=grpo_lr,
per_device_train_batch_size=1,
gradient_accumulation_steps=grpo_grad_accum,
max_steps=max_grpo_steps,
logging_steps=5,
num_generations=2,
beta=grpo_beta,
lr_scheduler_type="cosine",
warmup_ratio=0.06,
max_grad_norm=1.0,
adam_beta1=0.9,
adam_beta2=0.95,
report_to=[],
)
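    # GRPOTrainer sums the per-completion outputs of all reward_funcs with
    # equal weight by default, so the env reward and the local shaping rewards
    # combine additively.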
grpo_trainer = GRPOTrainer(
model=sft_trainer.model,
processing_class=tokenizer,
reward_funcs=[env_reward, format_reward, semantic_action_reward, anti_idle_reward],
train_dataset=ds,
args=grpo_cfg,
callbacks=[_ProgressCallback()],
)
grpo_before = _trainable_lora_sum_abs(sft_trainer.model)
grpo_trainer.train()
progress["step"] = progress["total"]
grpo_after = _trainable_lora_sum_abs(grpo_trainer.model)
grpo_delta = abs(grpo_after - grpo_before)
print(f"GRPO LoRA delta(abs-sum): {grpo_delta:.6f}")
if grpo_delta <= 1e-6:
raise RuntimeError("GRPO appears not to have updated LoRA weights (delta too small).")
final_dir = out_dir / "grpo_adapter"
grpo_trainer.model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"GRPO complete. Adapter saved: {final_dir}")
def main() -> None:
parser = argparse.ArgumentParser(description="Run SFT warmup before GRPO.")
parser.add_argument(
"--model-name",
default="",
help="Optional explicit model id. If omitted, --model-preset is used.",
)
parser.add_argument(
"--model-preset",
choices=sorted(MODEL_PRESETS.keys()),
default="small_iter_fast",
help="Recommended compute-aware preset. small_iter_fast is best for iteration speed.",
)
parser.add_argument(
"--training-preset",
choices=sorted(TRAINING_PRESETS.keys()),
default="hackathon_turbo",
help="Compute-aware run preset. hackathon_turbo is best default for iterative winning loops.",
)
parser.add_argument("--env-url", default="http://127.0.0.1:8000")
parser.add_argument("--sft-jsonl", type=Path, default=Path("outputs/sft_from_env.jsonl"))
parser.add_argument("--out-dir", type=Path, default=Path("outputs/train_runs/sft_then_grpo"))
parser.add_argument("--generate-sft-from-env", action="store_true")
parser.add_argument("--sft-samples", type=int, default=120)
parser.add_argument("--max-sft-steps", type=int, default=60)
parser.add_argument("--max-grpo-steps", type=int, default=120)
parser.add_argument("--env-reward-scale", type=float, default=1.0)
parser.add_argument("--local-reward-scale", type=float, default=0.35)
parser.add_argument(
"--complexity-curriculum",
choices=["off", "easy_to_full"],
default="easy_to_full",
help="Reward curriculum: easy_to_full starts with stronger local scaffold and anneals to env-dominant.",
)
parser.add_argument(
"--curriculum-ramp-ratio",
type=float,
default=0.60,
help="Fraction of GRPO steps used to ramp from easy scaffold to full env weighting.",
)
parser.add_argument(
"--reward-ema-decay",
type=float,
default=-1.0,
help="EMA decay in [0,1] for env reward smoothing; -1 uses training preset default.",
)
args = parser.parse_args()
model_name = args.model_name.strip() or MODEL_PRESETS[args.model_preset]
p = TRAINING_PRESETS[args.training_preset]
max_sft_steps = int(p["max_sft_steps"])
max_grpo_steps = int(p["max_grpo_steps"])
env_reward_scale = float(p["env_reward_scale"])
local_reward_scale = float(p["local_reward_scale"])
complexity_curriculum = str(p["complexity_curriculum"])
curriculum_ramp_ratio = float(p["curriculum_ramp_ratio"])
sft_samples = int(p["sft_samples"])
sft_lr = float(p["sft_lr"])
sft_grad_accum = int(p["sft_grad_accum"])
grpo_lr = float(p["grpo_lr"])
grpo_grad_accum = int(p["grpo_grad_accum"])
grpo_beta = float(p["grpo_beta"])
reward_ema_decay = float(p["reward_ema_decay"])
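    # CLI overrides: a flag replaces the preset value only when it differs from
    # the argparse default, so explicitly passing the default is treated the
    # same as omitting the flag.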
if args.max_sft_steps != 60:
max_sft_steps = args.max_sft_steps
if args.max_grpo_steps != 120:
max_grpo_steps = args.max_grpo_steps
if args.env_reward_scale != 1.0:
env_reward_scale = args.env_reward_scale
if args.local_reward_scale != 0.35:
local_reward_scale = args.local_reward_scale
if args.complexity_curriculum != "easy_to_full":
complexity_curriculum = args.complexity_curriculum
if args.curriculum_ramp_ratio != 0.60:
curriculum_ramp_ratio = args.curriculum_ramp_ratio
if args.sft_samples != 120:
sft_samples = args.sft_samples
if args.reward_ema_decay >= 0.0:
reward_ema_decay = float(args.reward_ema_decay)
print(f"Model preset: {args.model_preset} -> {model_name}")
print(
"Training preset:"
f" {args.training_preset} -> sft={max_sft_steps}, grpo={max_grpo_steps},"
f" env_scale={env_reward_scale}, local_scale={local_reward_scale},"
f" curriculum={complexity_curriculum}, ramp={curriculum_ramp_ratio}"
)
if args.generate_sft_from_env or not args.sft_jsonl.exists():
generate_sft_jsonl_from_env(
env_url=args.env_url,
out_jsonl=args.sft_jsonl,
samples=sft_samples,
task_id="phase2_core",
)
run_sft_then_grpo(
model_name=model_name,
env_url=args.env_url,
sft_jsonl=args.sft_jsonl,
out_dir=args.out_dir,
env_reward_scale=env_reward_scale,
local_reward_scale=local_reward_scale,
max_sft_steps=max_sft_steps,
max_grpo_steps=max_grpo_steps,
complexity_curriculum=complexity_curriculum,
curriculum_ramp_ratio=curriculum_ramp_ratio,
sft_lr=sft_lr,
sft_grad_accum=sft_grad_accum,
grpo_lr=grpo_lr,
grpo_grad_accum=grpo_grad_accum,
grpo_beta=grpo_beta,
reward_ema_decay=reward_ema_decay,
)
if __name__ == "__main__":
main()