from __future__ import annotations
# unsloth must be imported before trl / transformers / peft so its monkey-patches
# take effect. we attempt it here at module load time so the import order is
# always correct regardless of which backend is ultimately selected.
try:
import unsloth # noqa: F401
except ImportError:
pass
import argparse
import os
import random
import sys
import time
import warnings
from pathlib import Path
def _silence_noisy_warnings() -> None:
"""suppress benign hf / torch generation warnings so the kaggle log is readable.
each filter targets a specific message we have confirmed is either a
false positive (we already configure the thing the warning complains
about) or is an upstream deprecation we cannot act on from here.
- ``max_new_tokens`` vs ``max_length``: trl's internal generate call
inherits the base model's default ``max_length=32768`` but our
``max_new_tokens=384`` correctly takes precedence, as documented
- right-padding detected: our tokenizer is configured with
``padding_side='left'`` (see ``_load_model_and_tokenizer``); trl
also re-fixes padding per batch
    - ``AttentionMaskConverter`` / ``attention_mask_utils`` deprecation:
      an upstream transformers internal migration, unrelated to our code
"""
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
warnings.filterwarnings("ignore", message=r".*max_new_tokens.*max_length.*")
warnings.filterwarnings("ignore", message=r".*right-padding was detected.*")
warnings.filterwarnings("ignore", message=r".*AttentionMaskConverter.*")
warnings.filterwarnings("ignore", message=r".*attention_mask_utils.*")
warnings.filterwarnings("ignore", category=FutureWarning, module=r"transformers(\..*)?")
try:
from transformers.utils import logging as _hf_logging # type: ignore
_hf_logging.set_verbosity_error()
except Exception: # noqa: BLE001
pass
_silence_noisy_warnings()
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument(
"--env-urls",
nargs="+",
required=True,
help="one or more openenv sysadmin server base urls. hosted hf spaces work directly",
)
parser.add_argument(
"--env-api-key",
default=os.environ.get("OPENENV_API_KEY", ""),
help="bearer token required by the sysadmin-env server (set OPENENV_API_KEY env var or pass directly)",
)
parser.add_argument(
"--model",
default=os.environ.get("HPC_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct"),
help=(
"hf hub id. defaults to Qwen/Qwen2.5-Coder-7B-Instruct (the kaggle a100 profile). "
"use Qwen/Qwen2.5-Coder-3B-Instruct for t4 colab"
),
)
parser.add_argument("--output-dir", default="./runs/hpc_openenv_gemma")
parser.add_argument("--group-size", type=int, default=8)
# bumped from 16: scenarios like hpc_pid_stale / hpc_nfs_stale routinely
# take 10+ turns just to surface a useful observation, and a small
# instruct model spends several turns getting the format right. with
# the old 16-turn ceiling most rollouts truncated before the health
# signal moved. --max-turns remains available as a cli override.
parser.add_argument("--max-turns", type=int, default=24)
parser.add_argument("--max-seq-length", type=int, default=4096)
parser.add_argument("--num-train-steps", type=int, default=200)
parser.add_argument("--learning-rate", type=float, default=2e-5)
parser.add_argument("--lora-r", type=int, default=16)
parser.add_argument("--lora-alpha", type=int, default=32)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top-p", type=float, default=0.95)
parser.add_argument("--max-new-tokens", type=int, default=384)
parser.add_argument("--seed", type=int, default=7)
parser.add_argument(
"--scenarios",
default="hpc_outage,hpc_munge,hpc_pid_stale,hpc_gpu_ecc,hpc_nfs_stale,hpc_ood_apache",
)
parser.add_argument("--logging-steps", type=int, default=5)
parser.add_argument("--save-steps", type=int, default=50)
parser.add_argument("--report-to", default="tensorboard")
parser.add_argument("--wandb-project", default=os.environ.get("WANDB_PROJECT"))
parser.add_argument("--hub-repo", default=os.environ.get("HF_HUB_REPO"))
parser.add_argument(
"--dry-run",
action="store_true",
help="skip heavy deps and run a single random-policy rollout through the remote servers",
)
parser.add_argument(
"--backend",
choices=["unsloth", "transformers"],
default="unsloth",
help="model loader. unsloth (default) for colab/single gpu, transformers for vertex/hf jobs",
)
parser.add_argument(
"--curriculum",
action="store_true",
help=(
"enable curriculum sampling. early grpo steps only sample the "
"easiest scenario bucket (hpc_pid_stale, hpc_gpu_ecc, "
"hpc_ood_apache) and new buckets are introduced as training "
"progresses. addresses the judge guide section on avoiding "
"zero-reward starts"
),
)
parser.add_argument(
"--save-adapter-only",
action="store_true",
help=(
"save only the lora adapter weights and skip the risky "
"upcast-then-merge path. matches the unsloth qlora save warning "
"from section 16 of the judge guide"
),
)
return parser.parse_args()
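# example invocation from the repo root (the space url is illustrative, not a
# real endpoint):
#
#   OPENENV_API_KEY=... python -m training.hpc_openenv_gemma \
#       --env-urls https://<user>-sysadmin-env.hf.space \
#       --model Qwen/Qwen2.5-Coder-3B-Instruct \
#       --backend unsloth --curriculum --save-adapter-only
#
# add --dry-run first to sanity-check server connectivity with the random
# policy before loading any model weights.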
def _resolve_scenarios(raw: str) -> list[str]:
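    """split a comma-separated scenario string into a clean list.

    e.g. "hpc_outage, hpc_munge," -> ["hpc_outage", "hpc_munge"]
    """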
names = [part.strip() for part in raw.split(",") if part.strip()]
if not names:
raise ValueError("at least one scenario id must be provided")
return names
def _random_policy(rng: random.Random):
pool = [
"sinfo",
"squeue",
"ssh compute-01",
"cat /etc/sysconfig/network-scripts/route-eth0",
"printf 'default via 10.0.0.1 dev eth0\\n10.0.0.0/24 dev eth0 proto kernel scope link src 10.0.0.11\\n' > /etc/sysconfig/network-scripts/route-eth0",
"systemctl restart slurmd",
"chmod 0400 /etc/munge/munge.key",
"systemctl restart munge",
"rm /var/run/slurmd.pid",
"nvidia-smi",
"nvidia-smi -r -i 0",
"umount -l /mnt/shared",
"mount /mnt/shared",
"apachectl configtest",
"apachectl graceful",
"exit",
"curl -I http://localhost:8080/",
"curl -I http://localhost:8081/",
]
def generate(batches):
return [f"<bash>{rng.choice(pool)}</bash>" for _ in batches]
return generate
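# actions are sent to the env as <bash>...</bash> strings; the random policy
# emits one per rollout slot, which is enough to exercise the remote servers
# end to end in --dry-run mode without loading a model.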
def _env_factory(env_urls: list[str], scenarios: list[str], api_key: str | None = None):
from training.remote_env import HttpEnterpriseHPCEnv
from training.remote_env import RemoteEndpointPool
pool = RemoteEndpointPool(env_urls, api_key=api_key or None)
active_scenarios = list(scenarios)
def make_env():
return HttpEnterpriseHPCEnv(
env_urls=env_urls, scenario_pool=active_scenarios, pool=pool
)
def set_scenarios(new_scenarios: list[str]) -> None:
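        # assign through a slice so the list object captured by make_env is
        # mutated in place; rebinding active_scenarios would leave the
        # closure pointing at the old list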
active_scenarios[:] = new_scenarios
return make_env, pool, set_scenarios
# curriculum buckets ordered from lowest to highest expected difficulty. the
# guide section 6 ("keep the task simple at first") and section 14
# ("curriculum") both argue for this so the policy sees non-zero reward
# quickly.
CURRICULUM_BUCKETS: list[list[str]] = [
["hpc_pid_stale", "hpc_gpu_ecc", "hpc_ood_apache"],
["hpc_nfs_stale"],
["hpc_outage", "hpc_munge"],
]
def _curriculum_scenarios(step: int, total_steps: int, full_pool: list[str]) -> list[str]:
if total_steps <= 0:
return full_pool
progress = min(1.0, step / max(1, total_steps))
# split training into three thirds; each unlocks the next bucket
if progress < 0.34:
unlocked = CURRICULUM_BUCKETS[0]
elif progress < 0.67:
unlocked = CURRICULUM_BUCKETS[0] + CURRICULUM_BUCKETS[1]
else:
unlocked = [s for bucket in CURRICULUM_BUCKETS for s in bucket]
filtered = [s for s in unlocked if s in full_pool]
return filtered or full_pool
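# worked example: with --num-train-steps 200 and the default scenario pool,
#   steps   0..67  sample ["hpc_pid_stale", "hpc_gpu_ecc", "hpc_ood_apache"]
#   steps  68..133 also unlock "hpc_nfs_stale"
#   steps 134..    sample the full pool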
def _dry_run(args: argparse.Namespace) -> int:
from training.logger import RewardLogger
from training.rollout import run_interactive_group
from training.rollout import summarize_group
scenarios = _resolve_scenarios(args.scenarios)
rng = random.Random(args.seed)
make_env, pool, _set_scenarios = _env_factory(args.env_urls, scenarios, api_key=args.env_api_key or None)
logger = RewardLogger(args.output_dir, run_name="dry_run", hub_repo=args.hub_repo, wandb_project=args.wandb_project)
try:
records = run_interactive_group(
group_size=args.group_size,
generate_fn=_random_policy(rng),
env_factory=make_env,
max_turns=args.max_turns,
seed_start=args.seed,
)
logger.log(step=0, records=records)
print(f"dry_run summary {summarize_group(records)}")
finally:
logger.close()
pool.close()
return 0
def _load_model_and_tokenizer(args: argparse.Namespace):
if args.backend == "unsloth":
try:
from unsloth import FastLanguageModel # type: ignore
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model,
max_seq_length=args.max_seq_length,
dtype=None,
load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model,
r=args.lora_r,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=args.lora_alpha,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=args.seed,
)
            # note: we deliberately do not call FastLanguageModel.for_inference
            # here; GRPO training needs the LoRA parameters trainable, and
            # for_inference switches the model to generation-only mode
            tokenizer.padding_side = "left"
            return model, tokenizer, "unsloth"
except Exception as _ue: # noqa: BLE001
# Unsloth raises RuntimeError/AssertionError on CUDA/version mismatch, not just ImportError
print(f"unsloth unavailable ({_ue.__class__.__name__}: {_ue}) — falling back to transformers backend", file=sys.stderr)
import torch # type: ignore
from peft import LoraConfig # type: ignore
from peft import get_peft_model # type: ignore
from transformers import AutoModelForCausalLM # type: ignore
from transformers import AutoTokenizer # type: ignore
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, padding_side="left")
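    if tokenizer.pad_token is None:
        # defensive: some checkpoints ship without a pad token, and batched
        # left-padding in generate_fn needs one; eos is the usual stand-in
        # (Qwen2.5-Coder already defines a pad token, so this is a no-op there)
        tokenizer.pad_token = tokenizer.eos_token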
    try:
        # prefer the image-text-to-text auto class so multimodal checkpoints
        # load with their vision tower attached; older transformers versions
        # without this class fall through to the causal-LM path below
        from transformers import AutoModelForImageTextToText  # type: ignore
        model = AutoModelForImageTextToText.from_pretrained(
            args.model,
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            device_map="auto",
        )
except Exception:
model = AutoModelForCausalLM.from_pretrained(
args.model,
torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
device_map="auto",
)
_lora_kwargs: dict = dict(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.0,
bias="none",
task_type="CAUSAL_LM",
)
    # multimodal models (e.g. Gemma 3n) wrap some of their linears in
    # non-standard subclasses that older PEFT versions can't inject adapters
    # into. Qwen2.5-Coder is text-only so this branch is a no-op for it, but
    # we keep the guard so the script still works when pointed at a vision
    # model such as google/gemma-3n-E4B-it.
_vision_substrings = ("vision_tower", "multi_modal_projector", "image_newline", "patch_embedding")
_has_vision = any(
any(s in name for s in _vision_substrings) for name, _ in model.named_modules()
)
if _has_vision:
import inspect as _inspect # noqa: PLC0415
if "exclude_modules" in _inspect.signature(LoraConfig.__init__).parameters:
_lora_kwargs["exclude_modules"] = list(_vision_substrings)
else:
            # Older PEFT: filter target_modules down to plain nn.Linear
            # instances, which excludes the wrapped linears in the vision tower.
import torch.nn as _nn # noqa: PLC0415
_suffixes = set(_lora_kwargs["target_modules"])
_safe_targets: set[str] = set()
for _name, _mod in model.named_modules():
if type(_mod) is _nn.Linear:
for _sfx in _suffixes:
if _name.endswith(f".{_sfx}"):
_safe_targets.add(_sfx)
_lora_kwargs["target_modules"] = sorted(_safe_targets) or list(_suffixes)
lora = LoraConfig(**_lora_kwargs)
model = get_peft_model(model, lora)
return model, tokenizer, "transformers"
def _train(args: argparse.Namespace) -> int:
try:
from datasets import Dataset # type: ignore
from trl import GRPOConfig # type: ignore
from trl import GRPOTrainer # type: ignore
except ImportError as exc:
        print(f"trl or datasets missing; install them first: {exc}", file=sys.stderr)
return 2
import torch # type: ignore
from training.agent_prompt import SYSTEM_PROMPT
from training.agent_prompt import USER_PROMPT
from training.logger import RewardLogger
from training.rollout import run_interactive_group
from training.rollout import summarize_group
scenarios = _resolve_scenarios(args.scenarios)
make_env, pool, set_scenarios = _env_factory(args.env_urls, scenarios, api_key=args.env_api_key or None)
print(f"train load model {args.model} backend {args.backend}")
model, tokenizer, backend = _load_model_and_tokenizer(args)
prompt_text = tokenizer.apply_chat_template(
[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT},
],
tokenize=False,
add_generation_prompt=True,
)
dataset = Dataset.from_dict({"prompt": [prompt_text] * max(args.num_train_steps, 32)})
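    # GRPOTrainer requires a prompt dataset, but the real multi-turn
    # interaction happens inside the reward functions via _runner; the
    # dataset therefore just repeats one templated prompt (the max(..., 32)
    # floor simply guarantees a minimum dataset size on short runs)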
def generate_fn(batch_messages):
texts = [
tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
for m in batch_messages
]
inputs = tokenizer(
texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=args.max_seq_length,
).to(model.device)
with torch.inference_mode():
out = model.generate(
**inputs,
do_sample=True,
temperature=args.temperature,
top_p=args.top_p,
max_new_tokens=args.max_new_tokens,
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
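        # model.generate returns prompt + completion concatenated; slice off
        # the echoed prompt so only newly generated tokens are decoded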
new_tokens = out[:, inputs["input_ids"].shape[1]:]
return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
logger = RewardLogger(
args.output_dir,
run_name="hpc_openenv_gemma",
hub_repo=args.hub_repo,
wandb_project=args.wandb_project,
)
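    # a dict rather than a bare int so the nested closures below can mutate it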
step_counter = {"n": 0}
from training.reward_functions import make_reward_functions
def _runner(group_size: int, _seed: int | None, completions: list[str] | None = None):
if args.curriculum:
set_scenarios(
_curriculum_scenarios(
step_counter["n"], args.num_train_steps, scenarios
)
)
return run_interactive_group(
group_size=group_size,
generate_fn=generate_fn,
env_factory=make_env,
max_turns=args.max_turns,
seed_start=random.randrange(1 << 30),
initial_completions=completions,
)
def _on_rollout(records, wall_seconds):
step_counter["n"] += 1
summary = summarize_group(records)
logger.log(step=step_counter["n"], records=records)
print(
f"grpo group summary {summary} rollout_seconds {wall_seconds:.2f}"
)
reward_funcs, _cache = make_reward_functions(
runner=_runner,
max_turns=args.max_turns,
on_rollout=_on_rollout,
)
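    # the returned reward callables drive the actual environment interaction:
    # on each GRPO step they invoke _runner for a fresh rollout group, then
    # score the resulting trajectories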
training_args = GRPOConfig(
output_dir=args.output_dir,
learning_rate=args.learning_rate,
per_device_train_batch_size=1,
gradient_accumulation_steps=1,
num_generations=args.group_size,
max_prompt_length=args.max_seq_length // 2,
max_completion_length=args.max_new_tokens,
logging_steps=args.logging_steps,
save_steps=args.save_steps,
max_steps=args.num_train_steps,
bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
fp16=(not torch.cuda.is_bf16_supported()) if torch.cuda.is_available() else False,
report_to=args.report_to,
seed=args.seed,
temperature=args.temperature,
top_p=args.top_p,
)
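    # num_generations above mirrors --group-size so trl's per-prompt
    # generation count matches the rollout group the reward functions score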
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset,
)
try:
print(f"train start backend {backend} steps {args.num_train_steps} group {args.group_size}")
started = time.time()
trainer.train()
print(f"train done elapsed {time.time() - started:.1f}s")
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
_save_trained_model(trainer, tokenizer, args)
finally:
logger.close()
pool.close()
return 0
def _save_trained_model(trainer, tokenizer, args: argparse.Namespace) -> None:
"""save the trained model. by default we only persist the lora adapter,
following the judge guide section 16 warning about upcasting a 4-bit
model to 16-bit and merging the adapter naively."""
out = Path(args.output_dir)
out.mkdir(parents=True, exist_ok=True)
try:
model = trainer.model
if args.save_adapter_only and hasattr(model, "save_pretrained"):
adapter_dir = out / "lora_adapter"
model.save_pretrained(str(adapter_dir))
tokenizer.save_pretrained(str(adapter_dir))
print(f"save adapter only wrote {adapter_dir}")
return
trainer.save_model(str(out))
tokenizer.save_pretrained(str(out))
print(f"save full model wrote {out}")
except Exception as exc: # noqa: BLE001
print(f"save failed {type(exc).__name__} {exc}")
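# reloading a saved adapter later, as a sketch (assumes peft and transformers
# are installed, and that BASE_ID matches the --model used for training):
#
#   from transformers import AutoModelForCausalLM
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(BASE_ID, device_map="auto")
#   model = PeftModel.from_pretrained(base, "<output-dir>/lora_adapter")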
def main() -> int:
args = _parse_args()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
if args.dry_run:
return _dry_run(args)
return _train(args)
if __name__ == "__main__":
raise SystemExit(main())