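"""SFT warm-start followed by GRPO for the Ghostexec AI Chief-of-Staff policy.

The script generates schema-valid SFT examples from the environment's /reset
endpoint, runs a short LoRA SFT phase, then continues with GRPO against the
environment reward plus local shaping rewards (format, semantic hints,
anti-idle) under an easy-to-full reward curriculum.
"""
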
from __future__ import annotations

import argparse
import json
import random
import re
from pathlib import Path
from typing import Any

import requests
from transformers import TrainerCallback

LEGAL_ACTION_TYPES = [
    "reply_email",
    "archive_email",
    "reschedule_meeting",
    "cancel_meeting",
    "complete_task",
    "delegate_task",
    "send_message",
    "do_nothing",
]

MODEL_PRESETS: dict[str, str] = {
    # Fast-iteration preset: small, strong instruction following, QLoRA-friendly.
    "small_iter_fast": "unsloth/Qwen2.5-3B-Instruct",
    # Existing baseline used in this repo.
    "balanced_3b": "unsloth/Llama-3.2-3B-Instruct",
    # Larger option when the compute budget is stable.
    "bigger_4b": "unsloth/Qwen3-4B-Instruct-2507",
}

TRAINING_PRESETS: dict[str, dict[str, float | int | str]] = {
    "hackathon_turbo": {
        "max_sft_steps": 80,
        "max_grpo_steps": 180,
        "env_reward_scale": 1.00,
        "local_reward_scale": 0.45,
        "complexity_curriculum": "easy_to_full",
        "curriculum_ramp_ratio": 0.65,
        "sft_samples": 180,
        # Optimizer / schedule knobs (stability-first defaults for iterative runs).
        "sft_lr": 1.2e-5,
        "sft_grad_accum": 8,
        "grpo_lr": 3.0e-6,
        "grpo_grad_accum": 8,
        "grpo_beta": 0.08,
        "reward_ema_decay": 0.35,
    },
    # Quicker loop for smoke runs on weaker hardware.
    "quick_smoke": {
        "max_sft_steps": 30,
        "max_grpo_steps": 80,
        "env_reward_scale": 0.95,
        "local_reward_scale": 0.35,
        "complexity_curriculum": "easy_to_full",
        "curriculum_ramp_ratio": 0.50,
        "sft_samples": 90,
        "sft_lr": 1.5e-5,
        "sft_grad_accum": 4,
        "grpo_lr": 4.0e-6,
        "grpo_grad_accum": 4,
        "grpo_beta": 0.06,
        "reward_ema_decay": 0.25,
    },
}

def _extract_briefing(reset_payload: dict[str, Any]) -> str:
    obs = reset_payload.get("observation", reset_payload)
    if isinstance(obs, dict):
        return str(obs.get("echoed_message", "")).strip()
    return ""

def _legal_action_heuristic(briefing: str) -> dict[str, Any]:
    # Minimal heuristic used only for SFT warm-start data generation.
    # Keeps the action schema valid and non-idle-biased.
    lower = briefing.lower()
    if "e01" in lower:
        return {
            "action_type": "reply_email",
            "email_id": "e01",
            "message_body": "Acknowledged. Sharing a concise update shortly.",
        }
    if "m02" in lower:
        return {
            "action_type": "reschedule_meeting",
            "meeting_id": "m02",
            "new_time": "2026-04-21T18:00:00",
            "reason": "Resolve overlap with higher priority commitments.",
        }
    if "t06" in lower:
        return {"action_type": "complete_task", "task_id": "t06"}
    return {"action_type": random.choice(LEGAL_ACTION_TYPES)}

def generate_sft_jsonl_from_env(
    env_url: str,
    out_jsonl: Path,
    samples: int = 120,
    task_id: str = "phase2_core",
) -> None:
    out_jsonl.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict[str, str]] = []
    for _ in range(samples):
        r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": task_id}, timeout=30)
        r.raise_for_status()
        payload = r.json()
        briefing = _extract_briefing(payload)
        if not briefing:
            continue
        action = _legal_action_heuristic(briefing)
        prompt = (
            "You are Ghostexec AI Chief-of-Staff.\n"
            "Output one valid GhostexecAction JSON only.\n\n"
            f"{briefing}"
        )
        rows.append({"prompt": prompt, "completion": json.dumps(action, ensure_ascii=True)})
    with out_jsonl.open("w", encoding="utf-8") as fh:
        for row in rows:
            fh.write(json.dumps(row, ensure_ascii=True) + "\n")
    print(f"Wrote {len(rows)} SFT rows to {out_jsonl}")
def run_sft_then_grpo(
    model_name: str,
    env_url: str,
    sft_jsonl: Path,
    out_dir: Path,
    env_reward_scale: float,
    local_reward_scale: float,
    max_sft_steps: int,
    max_grpo_steps: int,
    complexity_curriculum: str,
    curriculum_ramp_ratio: float,
    *,
    sft_lr: float,
    sft_grad_accum: int,
    grpo_lr: float,
    grpo_grad_accum: int,
    grpo_beta: float,
    reward_ema_decay: float,
) -> None:
    try:
        from datasets import load_dataset
        from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
        from unsloth import FastLanguageModel
    except Exception as exc:  # pragma: no cover
        raise RuntimeError(
            "Missing training deps. Install unsloth, trl, datasets, transformers before running."
        ) from exc

    out_dir.mkdir(parents=True, exist_ok=True)

    def _trainable_lora_sum_abs(model) -> float:
        # Sum of |w| over trainable LoRA parameters; a cheap "did the adapter
        # actually move" check run before and after each training phase.
        total = 0.0
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "lora" not in n.lower():
                continue
            total += float(p.detach().abs().sum().item())
        return total

    policy, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
    )
    policy = FastLanguageModel.get_peft_model(
        policy,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    ds = load_dataset("json", data_files=str(sft_jsonl), split="train")
    sft_cfg = SFTConfig(
        output_dir=str(out_dir / "sft"),
        max_steps=max_sft_steps,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=sft_grad_accum,
        learning_rate=sft_lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.06,
        max_grad_norm=1.0,
        adam_beta1=0.9,
        adam_beta2=0.95,
        logging_steps=5,
        save_steps=max(10, max_sft_steps),
        report_to=[],
    )
    sft_trainer = SFTTrainer(
        model=policy,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=sft_cfg,
        # Train on prompt + completion joined into one text per row. dataset_text_field
        # is intentionally not passed, so formatting_func is the source of training text.
        formatting_func=lambda ex: [f"{p}\n\n{c}" for p, c in zip(ex["prompt"], ex["completion"])],
    )

    sft_before = _trainable_lora_sum_abs(policy)
    sft_trainer.train()
    sft_after = _trainable_lora_sum_abs(sft_trainer.model)
    sft_delta = abs(sft_after - sft_before)
    print(f"SFT LoRA delta(abs-sum): {sft_delta:.6f}")
    if sft_delta <= 1e-6:
        raise RuntimeError("SFT appears not to have updated LoRA weights (delta too small).")

    sft_dir = out_dir / "sft_adapter"
    sft_trainer.model.save_pretrained(sft_dir)
    tokenizer.save_pretrained(sft_dir)
    print(f"SFT complete. Adapter saved: {sft_dir}")

    def _extract_json(text: str) -> dict[str, Any] | None:
        m = re.search(r"\{.*\}", text, flags=re.S)
        if not m:
            return None
        try:
            obj = json.loads(m.group(0))
        except Exception:
            return None
        return obj if isinstance(obj, dict) else None

    def _env_step_reward_from_completion(text: str) -> float:
        payload = _extract_json(text)
        if payload is None:
            return -0.25
        payload.setdefault("action_type", "do_nothing")
        try:
            # Score each completion in a fresh one-step episode: reset, then step
            # with the parsed action.
            r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": "phase2_core"}, timeout=30)
            r.raise_for_status()
            s = requests.post(
                f"{env_url.rstrip('/')}/step",
                json={"action": payload},
                timeout=30,
            )
            s.raise_for_status()
            raw = s.json()
        except Exception:
            return 0.0
        rew = raw.get("reward")
        if rew is None and isinstance(raw.get("observation"), dict):
            rew = raw["observation"].get("reward", 0.0)
        try:
            return float(rew)
        except Exception:
            return 0.0

| progress = {"step": 0, "total": max(1, max_grpo_steps)} | |
| reward_ema_state = {"env": None} | |
| class _ProgressCallback(TrainerCallback): | |
| def on_step_end(self, args, state, control, **kwargs): # type: ignore[override] | |
| progress["step"] = int(getattr(state, "global_step", progress["step"])) | |
| return control | |
| def _progress_frac() -> float: | |
| return min(1.0, progress["step"] / progress["total"]) | |
    def _curriculum_phase_weight() -> float:
        frac = _progress_frac()
        ramp = max(0.05, min(1.0, curriculum_ramp_ratio))
        if complexity_curriculum == "off":
            # "off" disables the curriculum entirely: no extra scaffold weighting.
            return 0.0
        # easy_to_full: start with strong scaffold guidance, then smoothly
        # transition to full env-dominant optimization.
        if frac >= ramp:
            return 0.0
        return max(0.0, 1.0 - (frac / ramp))

    def _annealed_local_scale() -> float:
        frac = _progress_frac()
        base = local_reward_scale * (1.20 - 0.70 * frac)
        return base * (1.0 + 0.70 * _curriculum_phase_weight())

    def _annealed_env_scale() -> float:
        w = _curriculum_phase_weight()
        # Slightly downweight env reward in early easy phase to reduce variance,
        # then recover to full strength by the end of ramp.
        return env_reward_scale * (1.0 - 0.30 * w)
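
    # For example, with the easy_to_full curriculum, ramp=0.65, and progress frac=0.3:
    # w = 1 - 0.3 / 0.65 ~ 0.54, so the local scaffold scale is
    # local_reward_scale * (1.20 - 0.70 * 0.3) * (1 + 0.70 * 0.54) ~ 1.36 * local_reward_scale,
    # while the env scale is env_reward_scale * (1 - 0.30 * 0.54) ~ 0.84 * env_reward_scale.
    # Once frac >= ramp, w = 0: the env reward runs at full scale and the local
    # scale decays toward 0.50 * local_reward_scale at the final step.
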
    def env_reward(completions, **_):
        scale = _annealed_env_scale()
        raw = [scale * _env_step_reward_from_completion(str(c)) for c in completions]
        if reward_ema_decay <= 0.0:
            return raw
        batch_mean = sum(raw) / max(len(raw), 1)
        prev = reward_ema_state["env"]
        d = max(0.0, min(1.0, reward_ema_decay))
        if prev is None:
            smoothed_mean = batch_mean
        else:
            smoothed_mean = (1.0 - d) * prev + d * batch_mean
        reward_ema_state["env"] = smoothed_mean
        delta = smoothed_mean - batch_mean
        return [r + delta for r in raw]
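
    # Example of the smoothing above: with decay d=0.35, a previous EMA of 0.20,
    # and a batch mean of 0.60, the EMA becomes 0.65 * 0.20 + 0.35 * 0.60 = 0.34,
    # so every reward in the batch is shifted by 0.34 - 0.60 = -0.26. Batch-level
    # swings are damped while differences between completions are preserved.
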
    def format_reward(completions, **_):
        scale = _annealed_local_scale()
        outs: list[float] = []
        for c in completions:
            txt = str(c).strip()
            obj = _extract_json(txt)
            if obj is None:
                outs.append(-0.20 * scale)
                continue
            if obj.get("action_type") not in LEGAL_ACTION_TYPES:
                outs.append(-0.20 * scale)
                continue
            # Encourage concise, parseable, schema-correct JSON.
            length_pen = -0.04 * scale if len(txt) > 500 else 0.0
            outs.append(0.12 * scale + length_pen)
        return outs

    def semantic_action_reward(completions, prompts=None, **_):
        scale = _annealed_local_scale()
        outs: list[float] = []
        for i, c in enumerate(completions):
            obj = _extract_json(str(c))
            if obj is None:
                outs.append(-0.10 * scale)
                continue
            at = str(obj.get("action_type", ""))
            ptxt = str(prompts[i] if prompts and i < len(prompts) else "").lower()
            bonus = 0.0
            if "critical" in ptxt and at == "reply_email":
                bonus += 0.08
            if "clash" in ptxt and at in ("reschedule_meeting", "cancel_meeting"):
                bonus += 0.08
            if ("overdue" in ptxt or "due soon" in ptxt) and at in ("complete_task", "delegate_task"):
                bonus += 0.08
            outs.append(scale * bonus)
        return outs

    def anti_idle_reward(completions, **_):
        scale = _annealed_local_scale()
        outs = []
        for c in completions:
            txt = str(c).lower()
            outs.append((-0.20 if "do_nothing" in txt else 0.02) * scale)
        return outs

    grpo_cfg = GRPOConfig(
        output_dir=str(out_dir / "grpo"),
        learning_rate=grpo_lr,
        # Some TRL versions require per_device_train_batch_size * num_processes to be
        # divisible by num_generations, so pair a batch of 2 with num_generations=2.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=grpo_grad_accum,
        max_steps=max_grpo_steps,
        logging_steps=5,
        num_generations=2,
        beta=grpo_beta,
        lr_scheduler_type="cosine",
        warmup_ratio=0.06,
        max_grad_norm=1.0,
        adam_beta1=0.9,
        adam_beta2=0.95,
        report_to=[],
    )
    grpo_trainer = GRPOTrainer(
        model=sft_trainer.model,
        processing_class=tokenizer,
        reward_funcs=[env_reward, format_reward, semantic_action_reward, anti_idle_reward],
        train_dataset=ds,
        args=grpo_cfg,
        callbacks=[_ProgressCallback()],
    )

    grpo_before = _trainable_lora_sum_abs(sft_trainer.model)
    grpo_trainer.train()
    progress["step"] = progress["total"]
    grpo_after = _trainable_lora_sum_abs(grpo_trainer.model)
    grpo_delta = abs(grpo_after - grpo_before)
    print(f"GRPO LoRA delta(abs-sum): {grpo_delta:.6f}")
    if grpo_delta <= 1e-6:
        raise RuntimeError("GRPO appears not to have updated LoRA weights (delta too small).")

    final_dir = out_dir / "grpo_adapter"
    grpo_trainer.model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"GRPO complete. Adapter saved: {final_dir}")

def main() -> None:
    parser = argparse.ArgumentParser(description="Run SFT warmup before GRPO.")
    parser.add_argument(
        "--model-name",
        default="",
        help="Optional explicit model id. If omitted, --model-preset is used.",
    )
    parser.add_argument(
        "--model-preset",
        choices=sorted(MODEL_PRESETS.keys()),
        default="small_iter_fast",
        help="Compute-aware model preset; small_iter_fast favors iteration speed.",
    )
    parser.add_argument(
        "--training-preset",
        choices=sorted(TRAINING_PRESETS.keys()),
        default="hackathon_turbo",
        help="Compute-aware run preset; hackathon_turbo is the recommended default for iterative runs.",
    )
    parser.add_argument("--env-url", default="http://127.0.0.1:8000")
    parser.add_argument("--sft-jsonl", type=Path, default=Path("outputs/sft_from_env.jsonl"))
    parser.add_argument("--out-dir", type=Path, default=Path("outputs/train_runs/sft_then_grpo"))
    parser.add_argument("--generate-sft-from-env", action="store_true")
    parser.add_argument("--sft-samples", type=int, default=120)
    parser.add_argument("--max-sft-steps", type=int, default=60)
    parser.add_argument("--max-grpo-steps", type=int, default=120)
    parser.add_argument("--env-reward-scale", type=float, default=1.0)
    parser.add_argument("--local-reward-scale", type=float, default=0.35)
    parser.add_argument(
        "--complexity-curriculum",
        choices=["off", "easy_to_full"],
        default="easy_to_full",
        help="Reward curriculum: easy_to_full starts with a stronger local scaffold and anneals to env-dominant rewards.",
    )
    parser.add_argument(
        "--curriculum-ramp-ratio",
        type=float,
        default=0.60,
        help="Fraction of GRPO steps used to ramp from the easy scaffold to full env weighting.",
    )
    parser.add_argument(
        "--reward-ema-decay",
        type=float,
        default=-1.0,
        help="EMA decay in [0, 1] for env reward smoothing; -1 uses the training preset default.",
    )
    args = parser.parse_args()

    model_name = args.model_name.strip() or MODEL_PRESETS[args.model_preset]
    p = TRAINING_PRESETS[args.training_preset]
    max_sft_steps = int(p["max_sft_steps"])
    max_grpo_steps = int(p["max_grpo_steps"])
    env_reward_scale = float(p["env_reward_scale"])
    local_reward_scale = float(p["local_reward_scale"])
    complexity_curriculum = str(p["complexity_curriculum"])
    curriculum_ramp_ratio = float(p["curriculum_ramp_ratio"])
    sft_samples = int(p["sft_samples"])
    sft_lr = float(p["sft_lr"])
    sft_grad_accum = int(p["sft_grad_accum"])
    grpo_lr = float(p["grpo_lr"])
    grpo_grad_accum = int(p["grpo_grad_accum"])
    grpo_beta = float(p["grpo_beta"])
    reward_ema_decay = float(p["reward_ema_decay"])

    # CLI flags override the preset only when they differ from the argparse
    # defaults; a flag explicitly set to its default value keeps the preset.
    if args.max_sft_steps != 60:
        max_sft_steps = args.max_sft_steps
    if args.max_grpo_steps != 120:
        max_grpo_steps = args.max_grpo_steps
    if args.env_reward_scale != 1.0:
        env_reward_scale = args.env_reward_scale
    if args.local_reward_scale != 0.35:
        local_reward_scale = args.local_reward_scale
    if args.complexity_curriculum != "easy_to_full":
        complexity_curriculum = args.complexity_curriculum
    if args.curriculum_ramp_ratio != 0.60:
        curriculum_ramp_ratio = args.curriculum_ramp_ratio
    if args.sft_samples != 120:
        sft_samples = args.sft_samples
    if args.reward_ema_decay >= 0.0:
        reward_ema_decay = float(args.reward_ema_decay)

    print(f"Model preset: {args.model_preset} -> {model_name}")
    print(
        "Training preset:"
        f" {args.training_preset} -> sft={max_sft_steps}, grpo={max_grpo_steps},"
        f" env_scale={env_reward_scale}, local_scale={local_reward_scale},"
        f" curriculum={complexity_curriculum}, ramp={curriculum_ramp_ratio}"
    )

    if args.generate_sft_from_env or not args.sft_jsonl.exists():
        generate_sft_jsonl_from_env(
            env_url=args.env_url,
            out_jsonl=args.sft_jsonl,
            samples=sft_samples,
            task_id="phase2_core",
        )

    run_sft_then_grpo(
        model_name=model_name,
        env_url=args.env_url,
        sft_jsonl=args.sft_jsonl,
        out_dir=args.out_dir,
        env_reward_scale=env_reward_scale,
        local_reward_scale=local_reward_scale,
        max_sft_steps=max_sft_steps,
        max_grpo_steps=max_grpo_steps,
        complexity_curriculum=complexity_curriculum,
        curriculum_ramp_ratio=curriculum_ramp_ratio,
        sft_lr=sft_lr,
        sft_grad_accum=sft_grad_accum,
        grpo_lr=grpo_lr,
        grpo_grad_accum=grpo_grad_accum,
        grpo_beta=grpo_beta,
        reward_ema_decay=reward_ema_decay,
    )


if __name__ == "__main__":
    main()
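
# Example invocations (the script filename below is illustrative; flags match the
# argument parser defined in main()):
#   python sft_then_grpo.py --generate-sft-from-env --training-preset quick_smoke
#   python sft_then_grpo.py --model-preset bigger_4b --max-grpo-steps 240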