from __future__ import annotations

# unsloth must be imported before trl / transformers / peft so its monkey-patches
# take effect. we attempt it here at module load time so the import order is
# always correct regardless of which backend is ultimately selected.
try:
    import unsloth  # noqa: F401
except ImportError:
    pass

import argparse
import os
import random
import sys
import time
import warnings
from pathlib import Path


def _silence_noisy_warnings() -> None:
    """suppress benign hf / torch generation warnings so the kaggle log is readable.

    each filter targets a specific message we have confirmed is either a
    false positive (we already configure the thing the warning complains
    about) or is an upstream deprecation we cannot act on from here.

    - ``max_new_tokens`` vs ``max_length``: trl's internal generate call
      inherits the base model's default ``max_length=32768`` but our
      ``max_new_tokens=384`` correctly takes precedence, as documented
    - right-padding detected: our tokenizer is configured with
      ``padding_side='left'`` (see ``_load_model_and_tokenizer``); trl
      also re-fixes padding per batch
    - ``AttentionMaskConverter`` / ``attention_mask_utils`` deprecation:
      an upstream transformers internal migration, unrelated to our code
    """
    os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
    warnings.filterwarnings("ignore", message=r".*max_new_tokens.*max_length.*")
    warnings.filterwarnings("ignore", message=r".*right-padding was detected.*")
    warnings.filterwarnings("ignore", message=r".*AttentionMaskConverter.*")
    warnings.filterwarnings("ignore", message=r".*attention_mask_utils.*")
    warnings.filterwarnings("ignore", category=FutureWarning, module=r"transformers(\..*)?")
    try:
        from transformers.utils import logging as _hf_logging  # type: ignore

        _hf_logging.set_verbosity_error()
    except Exception:  # noqa: BLE001
        pass


_silence_noisy_warnings()
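
# example invocations (the script filename below is illustrative; use whatever
# path this file has in your checkout). all flags are defined in _parse_args:
#   random-policy smoke test against a hosted space:
#     python train_grpo.py --env-urls https://<user>-sysadmin-env.hf.space --dry-run
#   full run with curriculum sampling and adapter-only saving:
#     python train_grpo.py --env-urls http://localhost:8000 \
#         --curriculum --save-adapter-only --num-train-steps 200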


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "--env-urls",
        nargs="+",
        required=True,
        help="one or more openenv sysadmin server base urls. hosted hf spaces work directly",
    )
    parser.add_argument(
        "--env-api-key",
        default=os.environ.get("OPENENV_API_KEY", ""),
        help="bearer token required by the sysadmin-env server (set OPENENV_API_KEY env var or pass directly)",
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("HPC_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct"),
        help=(
            "hf hub id. defaults to Qwen/Qwen2.5-Coder-7B-Instruct (the kaggle a100 profile). "
            "use Qwen/Qwen2.5-Coder-3B-Instruct for t4 colab"
        ),
    )
    parser.add_argument("--output-dir", default="./runs/hpc_openenv_gemma")
    parser.add_argument("--group-size", type=int, default=8)
    # bumped from 16: scenarios like hpc_pid_stale / hpc_nfs_stale routinely
    # take 10+ turns to even surface a useful observation, and a small
    # instruct model spends several turns getting the format right. with
    # the old 16 ceiling most rollouts truncated before the health signal
    # moved. keep --max-turns a cli override.
    parser.add_argument("--max-turns", type=int, default=24)
    parser.add_argument("--max-seq-length", type=int, default=4096)
    parser.add_argument("--num-train-steps", type=int, default=200)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--max-new-tokens", type=int, default=384)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument(
        "--scenarios",
        default="hpc_outage,hpc_munge,hpc_pid_stale,hpc_gpu_ecc,hpc_nfs_stale,hpc_ood_apache",
    )
    parser.add_argument("--logging-steps", type=int, default=5)
    parser.add_argument("--save-steps", type=int, default=50)
    parser.add_argument("--report-to", default="tensorboard")
    parser.add_argument("--wandb-project", default=os.environ.get("WANDB_PROJECT"))
    parser.add_argument("--hub-repo", default=os.environ.get("HF_HUB_REPO"))
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="skip heavy deps and run a single random-policy rollout through the remote servers",
    )
    parser.add_argument(
        "--backend",
        choices=["unsloth", "transformers"],
        default="unsloth",
        help="model loader. unsloth (default) for colab/single gpu, transformers for vertex/hf jobs",
    )
    parser.add_argument(
        "--curriculum",
        action="store_true",
        help=(
            "enable curriculum sampling. early grpo steps only sample the "
            "easiest scenario bucket (hpc_pid_stale, hpc_gpu_ecc, "
            "hpc_ood_apache) and new buckets are introduced as training "
            "progresses. addresses the judge guide section on avoiding "
            "zero-reward starts"
        ),
    )
    parser.add_argument(
        "--save-adapter-only",
        action="store_true",
        help=(
            "save only the lora adapter weights and skip the risky "
            "upcast-then-merge path. matches the unsloth qlora save warning "
            "from section 16 of the judge guide"
        ),
    )
    return parser.parse_args()


def _resolve_scenarios(raw: str) -> list[str]:
    names = [part.strip() for part in raw.split(",") if part.strip()]
    if not names:
        raise ValueError("at least one scenario id must be provided")
    return names
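
# example: whitespace is stripped and empty segments are dropped, so
#   _resolve_scenarios("hpc_outage, hpc_munge,") == ["hpc_outage", "hpc_munge"]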


def _random_policy(rng: random.Random):
    pool = [
        "sinfo",
        "squeue",
        "ssh compute-01",
        "cat /etc/sysconfig/network-scripts/route-eth0",
        "printf 'default via 10.0.0.1 dev eth0\\n10.0.0.0/24 dev eth0 proto kernel scope link src 10.0.0.11\\n' > /etc/sysconfig/network-scripts/route-eth0",
        "systemctl restart slurmd",
        "chmod 0400 /etc/munge/munge.key",
        "systemctl restart munge",
        "rm /var/run/slurmd.pid",
        "nvidia-smi",
        "nvidia-smi -r -i 0",
        "umount -l /mnt/shared",
        "mount /mnt/shared",
        "apachectl configtest",
        "apachectl graceful",
        "exit",
        "curl -I http://localhost:8080/",
        "curl -I http://localhost:8081/",
    ]

    def generate(batches):
        return [f"<bash>{rng.choice(pool)}</bash>" for _ in batches]

    return generate
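
# example: _random_policy(random.Random(7))(range(3)) yields three strings of
# the form "<bash>sinfo</bash>", one uniformly sampled command per batch
# element. the pool mixes diagnostic and remediation commands, so a random
# policy can occasionally complete a scenario during --dry-run.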


def _env_factory(env_urls: list[str], scenarios: list[str], api_key: str | None = None):
    from training.remote_env import HttpEnterpriseHPCEnv
    from training.remote_env import RemoteEndpointPool

    pool = RemoteEndpointPool(env_urls, api_key=api_key or None)
    active_scenarios = list(scenarios)

    def make_env():
        return HttpEnterpriseHPCEnv(
            env_urls=env_urls, scenario_pool=active_scenarios, pool=pool
        )

    def set_scenarios(new_scenarios: list[str]) -> None:
        active_scenarios[:] = new_scenarios

    return make_env, pool, set_scenarios


# curriculum buckets ordered from lowest to highest expected difficulty. the
# guide section 6 ("keep the task simple at first") and section 14
# ("curriculum") both argue for this so the policy sees non-zero reward
# quickly.
CURRICULUM_BUCKETS: list[list[str]] = [
    ["hpc_pid_stale", "hpc_gpu_ecc", "hpc_ood_apache"],
    ["hpc_nfs_stale"],
    ["hpc_outage", "hpc_munge"],
]


def _curriculum_scenarios(step: int, total_steps: int, full_pool: list[str]) -> list[str]:
    if total_steps <= 0:
        return full_pool
    progress = min(1.0, step / max(1, total_steps))
    # split training into thirds; each third unlocks the next difficulty bucket
    if progress < 0.34:
        unlocked = CURRICULUM_BUCKETS[0]
    elif progress < 0.67:
        unlocked = CURRICULUM_BUCKETS[0] + CURRICULUM_BUCKETS[1]
    else:
        unlocked = [s for bucket in CURRICULUM_BUCKETS for s in bucket]
    filtered = [s for s in unlocked if s in full_pool]
    return filtered or full_pool
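
# worked example with the default pool and --num-train-steps 200:
#   _curriculum_scenarios(0, 200, full_pool)    # progress 0.00 -> bucket 0 only
#   _curriculum_scenarios(80, 200, full_pool)   # progress 0.40 -> buckets 0 + 1
#   _curriculum_scenarios(150, 200, full_pool)  # progress 0.75 -> all three buckets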


def _dry_run(args: argparse.Namespace) -> int:
    from training.logger import RewardLogger
    from training.rollout import run_interactive_group
    from training.rollout import summarize_group

    scenarios = _resolve_scenarios(args.scenarios)
    rng = random.Random(args.seed)
    make_env, pool, _set_scenarios = _env_factory(args.env_urls, scenarios, api_key=args.env_api_key or None)
    logger = RewardLogger(args.output_dir, run_name="dry_run", hub_repo=args.hub_repo, wandb_project=args.wandb_project)
    try:
        records = run_interactive_group(
            group_size=args.group_size,
            generate_fn=_random_policy(rng),
            env_factory=make_env,
            max_turns=args.max_turns,
            seed_start=args.seed,
        )
        logger.log(step=0, records=records)
        print(f"dry_run summary {summarize_group(records)}")
    finally:
        logger.close()
        pool.close()
    return 0


def _load_model_and_tokenizer(args: argparse.Namespace):
    if args.backend == "unsloth":
        try:
            from unsloth import FastLanguageModel  # type: ignore

            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=args.model,
                max_seq_length=args.max_seq_length,
                dtype=None,
                load_in_4bit=True,
            )
            model = FastLanguageModel.get_peft_model(
                model,
                r=args.lora_r,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                ],
                lora_alpha=args.lora_alpha,
                lora_dropout=0,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=args.seed,
            )
            FastLanguageModel.for_inference(model)
            tokenizer.padding_side = "left"
            return model, tokenizer, "unsloth"
        except Exception as _ue:  # noqa: BLE001
            # Unsloth raises RuntimeError/AssertionError on CUDA/version mismatch, not just ImportError
            print(f"unsloth unavailable ({_ue.__class__.__name__}: {_ue}) — falling back to transformers backend", file=sys.stderr)

    import torch  # type: ignore
    from peft import LoraConfig  # type: ignore
    from peft import get_peft_model  # type: ignore
    from transformers import AutoModelForCausalLM  # type: ignore
    from transformers import AutoTokenizer  # type: ignore

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, padding_side="left")
    try:
        # AutoModelForImageTextToText (recent transformers releases) covers
        # multimodal checkpoints; the except branch keeps text-only models working
        from transformers import AutoModelForImageTextToText  # type: ignore

        model = AutoModelForImageTextToText.from_pretrained(
            args.model,
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            device_map="auto",
        )
    except Exception:
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            device_map="auto",
        )

    _lora_kwargs: dict = dict(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    # multimodal models (eg gemma 3n) wrap vision-encoder linears in custom
    # subclasses that older PEFT can't inject adapters into. Qwen2.5-Coder is
    # text-only so this branch is a no-op for it, but we keep the guard so the
    # script still works when pointed at a vision model like google/gemma-3n-E4B-it.
    _vision_substrings = ("vision_tower", "multi_modal_projector", "image_newline", "patch_embedding")
    _has_vision = any(
        any(s in name for s in _vision_substrings) for name, _ in model.named_modules()
    )
    if _has_vision:
        import inspect as _inspect  # noqa: PLC0415

        if "exclude_modules" in _inspect.signature(LoraConfig.__init__).parameters:
            _lora_kwargs["exclude_modules"] = list(_vision_substrings)
        else:
            # Older PEFT: filter target_modules to only plain nn.Linear
            # instances, which skips the wrapped linears in the vision tower.
            import torch.nn as _nn  # noqa: PLC0415

            _suffixes = set(_lora_kwargs["target_modules"])
            _safe_targets: set[str] = set()
            for _name, _mod in model.named_modules():
                if type(_mod) is _nn.Linear:
                    for _sfx in _suffixes:
                        if _name.endswith(f".{_sfx}"):
                            _safe_targets.add(_sfx)
            _lora_kwargs["target_modules"] = sorted(_safe_targets) or list(_suffixes)
    lora = LoraConfig(**_lora_kwargs)
    model = get_peft_model(model, lora)
    return model, tokenizer, "transformers"


def _train(args: argparse.Namespace) -> int:
    try:
        from datasets import Dataset  # type: ignore
        from trl import GRPOConfig  # type: ignore
        from trl import GRPOTrainer  # type: ignore
    except ImportError as exc:
| print(f"trl or datasets missing install them first {exc}", file=sys.stderr) | |
        return 2

    import torch  # type: ignore

    from training.agent_prompt import SYSTEM_PROMPT
    from training.agent_prompt import USER_PROMPT
    from training.logger import RewardLogger
    from training.rollout import run_interactive_group
    from training.rollout import summarize_group

    scenarios = _resolve_scenarios(args.scenarios)
    make_env, pool, set_scenarios = _env_factory(args.env_urls, scenarios, api_key=args.env_api_key or None)
    print(f"train load model {args.model} backend {args.backend}")
    model, tokenizer, backend = _load_model_and_tokenizer(args)

    prompt_text = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    dataset = Dataset.from_dict({"prompt": [prompt_text] * max(args.num_train_steps, 32)})

    def generate_fn(batch_messages):
        texts = [
            tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in batch_messages
        ]
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_seq_length,
        ).to(model.device)
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                do_sample=True,
                temperature=args.temperature,
                top_p=args.top_p,
                max_new_tokens=args.max_new_tokens,
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            )
        new_tokens = out[:, inputs["input_ids"].shape[1]:]
        return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
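
    # note: generate_fn receives one chat-message list per group member and
    # returns only the newly generated text; slicing at input_ids.shape[1]
    # strips the (left-padded) prompt tokens before decoding.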

    logger = RewardLogger(
        args.output_dir,
        run_name="hpc_openenv_gemma",
        hub_repo=args.hub_repo,
        wandb_project=args.wandb_project,
    )
    step_counter = {"n": 0}

    from training.reward_functions import make_reward_functions

    def _runner(group_size: int, _seed: int | None, completions: list[str] | None = None):
        if args.curriculum:
            set_scenarios(
                _curriculum_scenarios(
                    step_counter["n"], args.num_train_steps, scenarios
                )
            )
        return run_interactive_group(
            group_size=group_size,
            generate_fn=generate_fn,
            env_factory=make_env,
            max_turns=args.max_turns,
            seed_start=random.randrange(1 << 30),
            initial_completions=completions,
        )

    def _on_rollout(records, wall_seconds):
        step_counter["n"] += 1
        summary = summarize_group(records)
        logger.log(step=step_counter["n"], records=records)
        print(
            f"grpo group summary {summary} rollout_seconds {wall_seconds:.2f}"
        )

    reward_funcs, _cache = make_reward_functions(
        runner=_runner,
        max_turns=args.max_turns,
        on_rollout=_on_rollout,
    )

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        num_generations=args.group_size,
        max_prompt_length=args.max_seq_length // 2,
        max_completion_length=args.max_new_tokens,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        max_steps=args.num_train_steps,
        bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        fp16=(not torch.cuda.is_bf16_supported()) if torch.cuda.is_available() else False,
        report_to=args.report_to,
        seed=args.seed,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )
    try:
        print(f"train start backend {backend} steps {args.num_train_steps} group {args.group_size}")
        started = time.time()
        trainer.train()
        print(f"train done elapsed {time.time() - started:.1f}s")
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        _save_trained_model(trainer, tokenizer, args)
    finally:
        logger.close()
        pool.close()
    return 0


def _save_trained_model(trainer, tokenizer, args: argparse.Namespace) -> None:
    """save the trained model. with ``--save-adapter-only`` we persist just the
    lora adapter, following the judge guide section 16 warning about upcasting
    a 4-bit model to 16-bit and naively merging the adapter."""
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)
    try:
        model = trainer.model
        if args.save_adapter_only and hasattr(model, "save_pretrained"):
            adapter_dir = out / "lora_adapter"
            model.save_pretrained(str(adapter_dir))
            tokenizer.save_pretrained(str(adapter_dir))
            print(f"save adapter only wrote {adapter_dir}")
            return
        trainer.save_model(str(out))
        tokenizer.save_pretrained(str(out))
        print(f"save full model wrote {out}")
    except Exception as exc:  # noqa: BLE001
        print(f"save failed {type(exc).__name__} {exc}")


def main() -> int:
    args = _parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    if args.dry_run:
        return _dry_run(args)
    return _train(args)


if __name__ == "__main__":
    raise SystemExit(main())