# SimMart / train.py
# HF Space: 4-dept SimMart env + 1.5B SFT+GRPO training (hackathon submission)
# (page-scrape residue converted to comments: "Viani's picture", commit 5c35138)
"""SimMart RL training — lightweight GRPO-style loop.
Trains a Qwen 2.5 Instruct CEO against the SimMart environment using
group-normalised REINFORCE (the "GRPO" special case from DeepSeek).
Design notes (tuned for a 2-day hackathon timeline + MI250X + Unsloth):
• Rollout batch B env instances with B distinct seeds → same policy
parameters at each step, different trajectories. Variance reduction
comes from the group baseline rather than multiple completions for the
same prompt (which would require state-checkpointing).
• Each training step:
1. 13-week rollout across B parallel envs
2. Per-week advantage = (reward - group_mean) / (group_std + eps)
3. Policy-gradient loss = -E[A_t * log π(a_t | s_t)]
4. KL penalty vs. frozen reference policy (optional; β from --kl)
5. Adam step on LoRA adapters only (Unsloth 4-bit base)
• Log reward mean/max and parse-error rate to stdout (+ W&B optional).
Usage (inside edaamd/unsloth-vllm container):
python train.py --model Qwen/Qwen2.5-1.5B-Instruct --steps 30 --batch 4 \\
--lr 1e-5 --max-new-tokens 768 --kl 0.02 \\
--out /mnt/dcgpuval/hkandala/simmart-runs/smoke-1p5b
Hero run (7B overnight):
python train.py --model Qwen/Qwen2.5-7B-Instruct --steps 120 --batch 6 \\
--lr 5e-6 --kl 0.02 \\
--out /mnt/dcgpuval/hkandala/simmart-runs/hero-7b
"""
from __future__ import annotations
import argparse
import json
import os
import random
import statistics
import sys
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
sys.path.insert(0, HERE)
# Unsloth must be imported BEFORE transformers / peft for its patches
from unsloth import FastLanguageModel # noqa: E402
import torch # noqa: E402
import torch.distributed as dist # noqa: E402
import torch.nn.functional as F # noqa: E402
from models import ProposalDecision, SimMartAction # noqa: E402
from prompts import ( # noqa: E402
SYSTEM_PROMPT, build_chat, parse_response, render_observation,
build_action_chat, build_journal_chat, parse_journal_response,
)
from server.environment import SimMartEnvironment # noqa: E402
# ---------------------------------------------------------------------------
# Distributed setup
# ---------------------------------------------------------------------------
def init_distributed() -> Tuple[int, int, int]:
    """Bring up torch.distributed when launched via accelerate/torchrun.

    A missing ``LOCAL_RANK`` env var means plain single-process mode, in
    which case no process group is created and (0, 0, 1) is returned.

    Returns:
        (rank, local_rank, world_size)
    """
    env_rank = os.environ.get("LOCAL_RANK")
    if env_rank is None:
        # Plain `python train.py` invocation — nothing to initialise.
        return 0, 0, 1
    local_rank = int(env_rank)
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    return dist.get_rank(), local_rank, dist.get_world_size()
def all_reduce_mean_gradients(model, world_size: int) -> None:
    """Average gradients across ranks by hand: all-reduce SUM, then divide.

    Used instead of wrapping the model in DistributedDataParallel, which
    has known quirks with Unsloth's 4-bit quantised models. No-op in
    single-process mode.
    """
    if world_size == 1:
        return
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad.div_(world_size)
def all_reduce_scalar(value: float, world_size: int, device) -> float:
    """Mean of a Python float across ranks; identity when single-process."""
    if world_size == 1:
        return value
    buf = torch.tensor([value], device=device, dtype=torch.float32)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf.item() / world_size
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
@dataclass
class TrainConfig:
    """All trainer knobs in one place (serialised to ``config.json`` at start).

    CLI flags in ``main()`` map 1:1 onto these fields; the DDP fields at the
    bottom are filled in at runtime from the launcher environment, never
    from the command line.
    """
    model: str = "Qwen/Qwen2.5-1.5B-Instruct"
    steps: int = 30
    batch: int = 4  # parallel envs per rank per training step
    max_seq_len: int = 4096
    max_new_tokens: int = 768  # single-pass (legacy) generation budget
    lr: float = 1e-5
    lr_min: float = 0.0  # if > 0 and < lr, cosine-anneal lr -> lr_min over `steps`
    beta_kl: float = 0.02  # KL penalty vs. reference policy
    entropy_coef: float = 0.01  # entropy bonus weight (encourages exploration)
    clip_grad: float = 1.0  # global grad-norm clip
    seed_offset: int = 0  # shifts every env seed (for reproducible reruns)
    out_dir: str = "./simmart-run"
    log_every: int = 1
    save_every: int = 10  # checkpoint cadence, in training steps
    wandb_project: Optional[str] = None
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.0
    load_in_4bit: bool = True
    dtype: str = "bfloat16"  # "float16" or "bfloat16"
    init_adapter: Optional[str] = None  # load LoRA weights from an SFT checkpoint (action head)
    mb_size: int = 8  # minibatch size for the PG pass
    # ----- Dual-head (two-pass) config ----------------------------------
    # If ``journal_adapter`` is set, the trainer loads it as a SECOND LoRA
    # named "journal", freezes its params, and uses a two-pass rollout:
    # action (trainable) + journal (frozen). The RL reward attribution is
    # then cleanly scoped to action tokens only.
    journal_adapter: Optional[str] = None
    action_max_tokens: int = 300  # action-head output budget (JSON only)
    journal_max_tokens: int = 400  # journal-head output budget (free text)
    # ----- Rollout sampling + reward shaping ----------------------------
    rollout_temperature: float = 0.9
    rollout_top_p: float = 0.95
    fmt_penalty: float = 0.0  # subtracted from env reward when parse fails
    # Mixed-temperature rollouts: when both > 0 and rollout_temp_low !=
    # rollout_temp_high, the action batch is split into halves and the two
    # halves are sampled at the two temperatures respectively. Otherwise the
    # full batch is sampled at ``rollout_temperature``. Sentinel <=0 = unset.
    rollout_temp_low: float = -1.0
    rollout_temp_high: float = -1.0
    # DDP state (set at runtime, not from CLI)
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1
@dataclass
class StepLog:
    """Per-step training metrics (one JSONL row / W&B point per step)."""
    step: int
    mean_reward: float  # mean shaped reward over all (env, week) transitions
    max_reward: float
    min_reward: float
    reward_std: float
    mean_episode_return: float  # mean over envs of per-episode summed reward
    parse_error_rate: float  # fraction of transitions whose output failed to parse
    rogue_recall: float  # caught-rogue weeks / rogue-active weeks
    loss: float  # total loss (pg + kl + entropy terms), per-sample mean
    pg_loss: float
    kl_loss: float
    entropy: float
    elapsed_s: float  # wall-clock seconds for the whole training step
# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def log_kv(step: int, kvs: Dict[str, float]) -> None:
    """Print a single `[step NNN] k=v k=v ...` line to stdout, flushed.

    Floats are rendered sign-prefixed with 4 decimals; all other values
    fall back to their plain string form.
    """
    def _fmt(key: str, val) -> str:
        if isinstance(val, float):
            return f"{key}={val:+.4f}"
        return f"{key}={val}"

    pieces = [f"[step {step:03d}]"] + [_fmt(k, v) for k, v in kvs.items()]
    print(" ".join(pieces), flush=True)
# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------
def load_policy(cfg: TrainConfig):
    """Load the base model + trainable LoRA (and optional extra adapters).

    Pipeline:
      1. Unsloth ``FastLanguageModel`` loads the base checkpoint (4-bit by
         default per ``cfg.load_in_4bit``).
      2. A LoRA adapter (name "default") is attached to the attention and
         MLP projection modules; only these params train.
      3. Optional: warm-start the default adapter from an SFT checkpoint
         (``cfg.init_adapter``).
      4. Optional: attach a second, fully frozen "journal" adapter for the
         dual-head rollout mode (``cfg.journal_adapter``).

    Returns:
        (model, tokenizer) — model has the "default" adapter active.
    """
    dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16
    # Pin each rank to its own GPU under DDP; let HF place layers otherwise.
    device_map = {"": cfg.local_rank} if cfg.world_size > 1 else "auto"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=cfg.model,
        max_seq_length=cfg.max_seq_len,
        dtype=dtype,
        load_in_4bit=cfg.load_in_4bit,
        device_map=device_map,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    # Qwen-style tokenizers may ship without a pad token; reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Optionally warm-start the action head from an SFT checkpoint.
    if cfg.init_adapter:
        if cfg.rank == 0:
            print(f"[init] loading SFT adapter weights (action) from {cfg.init_adapter}")
        from peft import set_peft_model_state_dict, load_peft_weights
        state_dict = load_peft_weights(cfg.init_adapter)
        # set_peft_model_state_dict handles key-name mapping across peft versions
        set_peft_model_state_dict(model, state_dict)
    # Optionally load a FROZEN journal adapter (dual-head architecture).
    if cfg.journal_adapter:
        if cfg.rank == 0:
            print(f"[init] loading FROZEN journal adapter from {cfg.journal_adapter}")
        model.load_adapter(cfg.journal_adapter, adapter_name="journal")
        # Freeze all journal-adapter parameters. PEFT names LoRA params like
        # ``...lora_A.journal.weight`` / ``...lora_B.journal.weight``, so
        # filtering on ``.journal.`` isolates them from the action/default
        # adapter.
        n_frozen = 0
        for name, param in model.named_parameters():
            if ".journal." in name:
                param.requires_grad = False
                n_frozen += 1
        if cfg.rank == 0:
            print(f"[init] froze {n_frozen} journal-adapter params")
        # Make the action (default) adapter active for training.
        model.set_adapter("default")
    return model, tokenizer
# ---------------------------------------------------------------------------
# Rollout
# ---------------------------------------------------------------------------
def _apply_chat(tokenizer, chat) -> str:
    """Serialise ``chat`` through the model's chat template.

    ``add_generation_prompt=True`` leaves the assistant turn open so that
    generation continues directly from the assistant header.
    """
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def batched_generate(
    model, tokenizer, prompts: List[str], max_new_tokens: int,
    temperature: float = 0.9, top_p: float = 0.95,
) -> Tuple[List[str], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample B completions in one batch. Returns:
        completions_text : List[str], len B (special tokens stripped)
        input_ids        : (B, L_in_max) padded-left prompt ids
        completion_ids   : (B, L_gen_max) padded-right generated ids
        input_mask       : (B, L_in_max) 1 where real prompt token
        completion_mask  : (B, L_gen_max) 1 where real generated token
    """
    # Left-padding aligns every prompt's last token, so generation starts
    # at the same position for the whole batch.
    tokenizer.padding_side = "left"
    enc = tokenizer(
        prompts, return_tensors="pt", padding=True, truncation=True,
        max_length=model.config.max_position_embeddings,
    ).to(model.device)
    # Unsloth-accelerated inference path: keeps LoRA active but switches to
    # faster attention/cache kernels. We flip back to train mode afterwards.
    FastLanguageModel.for_inference(model)
    try:
        with torch.inference_mode():
            out = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )
    finally:
        # Always restore training kernels, even if generate() raises.
        FastLanguageModel.for_training(model)
    # generate() returns prompt+completion; slice off the prompt prefix.
    input_len = enc.input_ids.size(1)
    completion_ids = out[:, input_len:]
    # NOTE(review): if pad_token aliases eos_token (as set up when the
    # tokenizer had no pad token), genuine EOS tokens in a completion are
    # masked out here as padding — confirm this is intended.
    completion_mask = (completion_ids != tokenizer.pad_token_id).long()
    completions_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    return completions_text, enc.input_ids, completion_ids, enc.attention_mask, completion_mask
def compute_completion_logprobs(
    model, input_ids: torch.Tensor, completion_ids: torch.Tensor,
    input_mask: torch.Tensor, completion_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token log-probs of the completion under ``model`` (teacher forcing).

    Args:
        model: causal LM invoked as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits`` of shape (B, L, V).
        input_ids: (B, L_in) left-padded prompt token ids.
        completion_ids: (B, L_gen) right-padded completion token ids.
        input_mask: (B, L_in) 1 where the prompt token is real, 0 for padding.
        completion_mask: (B, L_gen) 1 where the generated token is real.

    Returns:
        (completion_logp, comp_mask):
            completion_logp — (B, L_gen) log π(token) with padding zeroed,
            comp_mask       — (B, L_gen) float validity mask.

    Fix: the original annotation advertised a single ``torch.Tensor`` but the
    function has always returned this 2-tuple (callers unpack two values);
    the annotation is corrected to match.
    """
    full_ids = torch.cat([input_ids, completion_ids], dim=1)
    full_mask = torch.cat([input_mask, completion_mask], dim=1)
    out = model(input_ids=full_ids, attention_mask=full_mask)
    # Shift so logits[:, t] scores targets[:, t] (next-token alignment).
    logits = out.logits[:, :-1, :]
    targets = full_ids[:, 1:]
    logp = F.log_softmax(logits.float(), dim=-1)
    tok_logp = logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    L_in = input_ids.size(1)
    # In the shifted sequence, position L_in-1 predicts the first completion token.
    completion_logp = tok_logp[:, L_in - 1:]
    comp_mask = completion_mask.float()
    # Trim to the completion length and zero out padding positions.
    completion_logp = completion_logp[:, :comp_mask.size(1)] * comp_mask
    return completion_logp, comp_mask
def compute_entropy_and_logp(
    model, input_ids: torch.Tensor, completion_ids: torch.Tensor,
    input_mask: torch.Tensor, completion_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Teacher-forced pass returning per-token stats for the completion.

    Returns a 3-tuple:
        completion_logp — (B, L_gen) log π(a_t), zeroed on padding
        entropy         — (B, L_gen) token-level H(π), zeroed on padding
        comp_mask       — (B, L_gen) float validity mask
    """
    seq_ids = torch.cat((input_ids, completion_ids), dim=1)
    seq_mask = torch.cat((input_mask, completion_mask), dim=1)
    logits = model(input_ids=seq_ids, attention_mask=seq_mask).logits[:, :-1, :]
    next_tokens = seq_ids[:, 1:]
    log_probs = F.log_softmax(logits.float(), dim=-1)
    # H(π) = -Σ_v p(v) log p(v) over the vocabulary at every position.
    token_entropy = -(log_probs.exp() * log_probs).sum(-1)
    chosen_logp = log_probs.gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)
    prompt_len = input_ids.size(1)
    comp_mask = completion_mask.float()
    gen_len = comp_mask.size(1)
    # Position prompt_len-1 of the shifted sequence scores the first
    # completion token; trim to L_gen and mask out padding.
    completion_logp = chosen_logp[:, prompt_len - 1:][:, :gen_len] * comp_mask
    entropy = token_entropy[:, prompt_len - 1:][:, :gen_len] * comp_mask
    return completion_logp, entropy, comp_mask
@dataclass
class RolloutStep:
    """One (env, week) transition recorded during rollout.

    Tensors are stored on CPU and re-padded into minibatches in the PG
    pass; in dual-head mode they cover the ACTION pass only, so the policy
    gradient never touches journal tokens.
    """
    week: int  # 1-based simulation week
    env_idx: int  # batch-slot index on this rank
    prompt_text: str
    completion_text: str
    prompt_ids: torch.Tensor  # (L_in,)  prompt ids incl. left padding
    completion_ids: torch.Tensor  # (L_gen,) generated ids incl. right padding
    prompt_mask: torch.Tensor  # 1 where real prompt token
    completion_mask: torch.Tensor  # 1 where real generated token
    reward: float  # env reward after optional format penalty
    parse_ok: bool  # full or partial parse of the action JSON succeeded
    was_rogue_week: bool = False  # a rogue incident was active this week
    caught_rogue: bool = False  # policy flagged a rogue proposal this week
def rollout_batch(
    model, tokenizer, cfg: TrainConfig, step_idx: int,
) -> List[RolloutStep]:
    """Run ``cfg.batch`` parallel 13-week episodes on this rank.

    Dual-head mode (``cfg.journal_adapter`` set):
        Week t ──> action pass  (default adapter, trainable) → JSON
                   journal pass (journal adapter, frozen)    → text
                   env.step(combined_action)
        The RolloutStep stores action-pass tensors only so policy-gradient
        flows exclusively through action tokens (clean credit assignment).

    Legacy mode (``cfg.journal_adapter`` = None):
        Single ``build_chat`` generation — identical to v5 behaviour.

    Seeds spread across ranks so different GPUs explore different env
    trajectories: global_seed = seed_offset + step * world_size * batch
    + rank * batch + i.
    """
    dual_head = cfg.journal_adapter is not None
    # One fresh environment per batch slot, each with a unique global seed.
    envs: List[SimMartEnvironment] = []
    obss = []
    for i in range(cfg.batch):
        env = SimMartEnvironment()
        seed = (cfg.seed_offset
                + step_idx * cfg.world_size * cfg.batch
                + cfg.rank * cfg.batch + i)
        obs = env.reset(seed=seed, episode_id=f"train-r{cfg.rank}-{step_idx}-{i}")
        envs.append(env)
        obss.append(obs)
    rollout: List[RolloutStep] = []
    for week in range(1, SimMartEnvironment.MAX_WEEKS + 1):
        # ---- Pass 1: ACTION (trainable "default" adapter) --------------------
        if dual_head:
            model.set_adapter("default")
            prompts_text = [
                _apply_chat(tokenizer, build_action_chat(obs)) for obs in obss
            ]
            action_budget = cfg.action_max_tokens
        else:
            prompts_text = [_apply_chat(tokenizer, build_chat(obs)) for obs in obss]
            action_budget = cfg.max_new_tokens
        # Mixed-temperature exploration: first half of the batch samples at
        # the low temperature, second half at the high one.
        mixed_temp = (
            cfg.rollout_temp_low > 0
            and cfg.rollout_temp_high > 0
            and cfg.rollout_temp_low != cfg.rollout_temp_high
            and len(prompts_text) >= 2
        )
        if mixed_temp:
            mid = len(prompts_text) // 2
            txt_lo, in_lo, comp_lo, im_lo, cm_lo = batched_generate(
                model, tokenizer, prompts_text[:mid], max_new_tokens=action_budget,
                temperature=cfg.rollout_temp_low, top_p=cfg.rollout_top_p,
            )
            txt_hi, in_hi, comp_hi, im_hi, cm_hi = batched_generate(
                model, tokenizer, prompts_text[mid:], max_new_tokens=action_budget,
                temperature=cfg.rollout_temp_high, top_p=cfg.rollout_top_p,
            )
            completions_text = list(txt_lo) + list(txt_hi)
            # Per-sample tensors are moved to CPU immediately to cap GPU
            # memory across the 13-week rollout.
            input_ids_list = (
                [in_lo[i].detach().cpu() for i in range(in_lo.size(0))]
                + [in_hi[i].detach().cpu() for i in range(in_hi.size(0))]
            )
            completion_ids_list = (
                [comp_lo[i].detach().cpu() for i in range(comp_lo.size(0))]
                + [comp_hi[i].detach().cpu() for i in range(comp_hi.size(0))]
            )
            input_mask_list = (
                [im_lo[i].detach().cpu() for i in range(im_lo.size(0))]
                + [im_hi[i].detach().cpu() for i in range(im_hi.size(0))]
            )
            completion_mask_list = (
                [cm_lo[i].detach().cpu() for i in range(cm_lo.size(0))]
                + [cm_hi[i].detach().cpu() for i in range(cm_hi.size(0))]
            )
        else:
            completions_text_t, input_ids, completion_ids, input_mask, completion_mask = \
                batched_generate(
                    model, tokenizer, prompts_text, max_new_tokens=action_budget,
                    temperature=cfg.rollout_temperature, top_p=cfg.rollout_top_p,
                )
            completions_text = list(completions_text_t)
            input_ids_list = [input_ids[i].detach().cpu() for i in range(input_ids.size(0))]
            completion_ids_list = [completion_ids[i].detach().cpu() for i in range(completion_ids.size(0))]
            input_mask_list = [input_mask[i].detach().cpu() for i in range(input_mask.size(0))]
            completion_mask_list = [completion_mask[i].detach().cpu() for i in range(completion_mask.size(0))]
        # Parse actions — in dual-head mode these carry decisions + budget but
        # an empty journal; in single-pass mode they carry everything.
        # NOTE(review): `Any` is not imported from typing in this file;
        # harmless only because PEP 526 variable annotations are never
        # evaluated at runtime — worth importing anyway.
        parsed: List[Tuple[Any, Dict[str, Any]]] = []
        for obs, comp_text in zip(obss, completions_text):
            parsed.append(parse_response(comp_text, obs.inbox))
        # ---- Pass 2: JOURNAL (frozen "journal" adapter) ----------------------
        if dual_head:
            model.set_adapter("journal")
            journal_prompts = [
                _apply_chat(
                    tokenizer,
                    build_journal_chat(
                        obs, action.decisions, action.budget_allocations,
                    ),
                )
                for obs, (action, _) in zip(obss, parsed)
            ]
            # Journal head is frozen — no gradients needed for this pass.
            with torch.inference_mode():
                jc_text, _, _, _, _ = batched_generate(
                    model, tokenizer, journal_prompts,
                    max_new_tokens=cfg.journal_max_tokens,
                    temperature=cfg.rollout_temperature, top_p=cfg.rollout_top_p,
                )
            for (action, _), jt in zip(parsed, jc_text):
                action.journal_entry = parse_journal_response(jt)
            # Switch back to the trainable adapter so the logprob recompute
            # in train_step uses the same adapter configuration as rollout.
            model.set_adapter("default")
        # ---- Env step + RolloutStep record (action tokens only for grad) ----
        for i, (env, obs, (action, tel), comp_text) in enumerate(
            zip(envs, obss, parsed, completions_text),
        ):
            # Collect the proposal ids that are rogue this week (metrics only).
            rogue_ids_this_week: set = set()
            for r in env.state.rogue_incidents:
                if week in r.active_weeks:
                    rogue_ids_this_week.update(r.associated_proposal_ids)
            was_rogue_week = len(rogue_ids_this_week) > 0
            step_obs = env.step(action)
            caught = any(
                d.verdict == "flag_suspicious" and d.proposal_id in rogue_ids_this_week
                for d in action.decisions
            )
            env_reward = float(step_obs.reward or 0.0)
            parse_ok = tel["parse_ok"] or tel.get("parse_partial", False)
            # Optional shaping: dock fmt_penalty when the action failed to parse.
            shaped_reward = env_reward - cfg.fmt_penalty * (0.0 if parse_ok else 1.0)
            rollout.append(RolloutStep(
                week=week,
                env_idx=i,
                prompt_text=prompts_text[i],
                completion_text=comp_text,
                prompt_ids=input_ids_list[i],
                completion_ids=completion_ids_list[i],
                prompt_mask=input_mask_list[i],
                completion_mask=completion_mask_list[i],
                reward=shaped_reward,
                parse_ok=parse_ok,
                was_rogue_week=was_rogue_week,
                caught_rogue=caught,
            ))
            obss[i] = step_obs
        # Clean up large GPU tensors from the action pass
        if mixed_temp:
            del in_lo, in_hi, comp_lo, comp_hi, im_lo, im_hi, cm_lo, cm_hi
        else:
            del input_ids, completion_ids, input_mask, completion_mask
        torch.cuda.empty_cache()
    return rollout
# ---------------------------------------------------------------------------
# Advantage + loss
# ---------------------------------------------------------------------------
def compute_advantages(rollout: List[RolloutStep]) -> List[float]:
    """Week-grouped reward normalisation: (r - mean_week) / (std_week + eps).

    The "group" in GRPO terms is the set of B parallel env instances at the
    same week index, which keeps advantages well-scaled even as absolute
    reward drifts over the course of training. A single-member group yields
    advantage 0 (its reward equals its own mean).
    """
    eps = 1e-6
    grouped: Dict[int, List[float]] = {}
    for step in rollout:
        grouped.setdefault(step.week, []).append(step.reward)
    stats: Dict[int, Tuple[float, float]] = {}
    for week, rewards in grouped.items():
        mu = statistics.mean(rewards)
        sigma = statistics.stdev(rewards) if len(rewards) > 1 else 1e-6
        stats[week] = (mu, sigma)
    advantages: List[float] = []
    for step in rollout:
        mu, sigma = stats[step.week]
        advantages.append((step.reward - mu) / (sigma + eps))
    return advantages
# ---------------------------------------------------------------------------
# Training step
# ---------------------------------------------------------------------------
def train_step(
    model, tokenizer, ref_model, optimizer, cfg: TrainConfig, step_idx: int,
) -> StepLog:
    """One full GRPO-style update: rollout → advantages → minibatched PG pass.

    ``ref_model`` shares weights with ``model``; its LoRA adapter is toggled
    off at KL-computation time to produce reference log-probs, so there is
    no second model copy in memory. Returns a StepLog with this rank's
    batch-level metrics (cross-rank reduction happens in ``_reduce_log``).
    """
    t0 = time.time()
    rollout = rollout_batch(model, tokenizer, cfg, step_idx)
    rewards = [r.reward for r in rollout]
    parse_ok_count = sum(1 for r in rollout if r.parse_ok)
    n_rogue_weeks = sum(1 for r in rollout if r.was_rogue_week)
    n_caught = sum(1 for r in rollout if r.caught_rogue)
    advantages = compute_advantages(rollout)
    # -------------------------------------------------------------------
    # Minibatch PG update (one pass over the rollout)
    # -------------------------------------------------------------------
    total_loss = 0.0
    total_pg = 0.0
    total_kl = 0.0
    total_ent = 0.0
    seen = 0
    mb_size = cfg.mb_size
    model.train()
    for mb_start in range(0, len(rollout), mb_size):
        mb = rollout[mb_start:mb_start + mb_size]
        advs_mb = advantages[mb_start:mb_start + mb_size]
        # Re-pad the minibatch fresh (rollout rows have mixed lengths).
        # NOTE(review): prompt rows keep their original left padding and are
        # additionally right-padded here; the masks account for both, but
        # absolute positions can differ from generation time — verify this
        # is acceptable for the position encoding in use.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [r.prompt_ids for r in mb],
            batch_first=True, padding_value=tokenizer.pad_token_id,
        ).to(model.device)
        completion_ids = torch.nn.utils.rnn.pad_sequence(
            [r.completion_ids for r in mb],
            batch_first=True, padding_value=tokenizer.pad_token_id,
        ).to(model.device)
        input_mask = torch.nn.utils.rnn.pad_sequence(
            [r.prompt_mask for r in mb],
            batch_first=True, padding_value=0,
        ).to(model.device)
        completion_mask = torch.nn.utils.rnn.pad_sequence(
            [r.completion_mask for r in mb],
            batch_first=True, padding_value=0,
        ).to(model.device)
        # Current policy log-probs + entropy
        completion_logp, entropy, mask = compute_entropy_and_logp(
            model, input_ids, completion_ids, input_mask, completion_mask,
        )
        # Reference policy log-probs (LoRA-disabled base model) for KL.
        # We share weights with the policy and toggle the adapter off so
        # there's no second copy in GPU memory.
        if ref_model is not None and cfg.beta_kl > 0:
            with torch.inference_mode(), ref_model.disable_adapter():
                ref_logp, _ = compute_completion_logprobs(
                    ref_model, input_ids, completion_ids, input_mask, completion_mask,
                )
            # Naive log-ratio KL estimate (log π - log π_ref); individual
            # token terms can be negative.
            kl_per_tok = (completion_logp - ref_logp.detach()) * mask
        else:
            kl_per_tok = torch.zeros_like(completion_logp)
        # Per-sample log-prob sum (masked avg)
        denom = mask.sum(dim=1).clamp_min(1.0)
        logp_per_sample = (completion_logp * mask).sum(dim=1) / denom
        entropy_per_sample = (entropy * mask).sum(dim=1) / denom
        kl_per_sample = kl_per_tok.sum(dim=1) / denom
        adv_t = torch.tensor(advs_mb, device=model.device, dtype=logp_per_sample.dtype)
        # REINFORCE with group-normalised advantages, entropy bonus
        # (negative coefficient → maximise entropy), and KL anchor.
        pg = -(adv_t * logp_per_sample).mean()
        ent_term = -cfg.entropy_coef * entropy_per_sample.mean()
        kl_term = cfg.beta_kl * kl_per_sample.mean()
        loss = pg + kl_term + ent_term
        optimizer.zero_grad()
        loss.backward()
        # Manual DDP: sum grads across ranks, average. Cheaper than wrapping
        # Unsloth's 4-bit model in torch.nn.parallel.DistributedDataParallel.
        all_reduce_mean_gradients(model, cfg.world_size)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)
        optimizer.step()
        # Weight the running sums by minibatch size so the final averages
        # are true per-sample means.
        total_loss += float(loss.item()) * len(mb)
        total_pg += float(pg.item()) * len(mb)
        total_kl += float(kl_term.item()) * len(mb)
        total_ent += float(entropy_per_sample.mean().item()) * len(mb)
        seen += len(mb)
        del input_ids, completion_ids, input_mask, completion_mask
        del completion_logp, entropy, mask, kl_per_tok
        torch.cuda.empty_cache()
    # -------------------------------------------------------------------
    # Aggregate metrics for logging
    # -------------------------------------------------------------------
    episode_returns: Dict[int, float] = {}
    for r in rollout:
        episode_returns[r.env_idx] = episode_returns.get(r.env_idx, 0.0) + r.reward
    elapsed = time.time() - t0
    return StepLog(
        step=step_idx,
        mean_reward=statistics.mean(rewards),
        max_reward=max(rewards),
        min_reward=min(rewards),
        reward_std=statistics.stdev(rewards) if len(rewards) > 1 else 0.0,
        mean_episode_return=statistics.mean(list(episode_returns.values())),
        parse_error_rate=1.0 - (parse_ok_count / max(1, len(rollout))),
        rogue_recall=(n_caught / n_rogue_weeks) if n_rogue_weeks > 0 else 0.0,
        loss=total_loss / max(1, seen),
        pg_loss=total_pg / max(1, seen),
        kl_loss=total_kl / max(1, seen),
        entropy=total_ent / max(1, seen),
        elapsed_s=elapsed,
    )
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _reduce_log(log: StepLog, world_size: int, device) -> StepLog:
    """Merge per-rank StepLog metrics into a global view for rank-0 logging.

    Mean-style fields are averaged across ranks; max/min rewards are
    reduced with MAX/MIN. Identity (same object) in single-process mode.
    """
    if world_size == 1:
        return log
    mean_fields = [
        "mean_reward", "max_reward", "min_reward", "reward_std",
        "mean_episode_return", "parse_error_rate", "rogue_recall",
        "loss", "pg_loss", "kl_loss", "entropy",
    ]
    summed = torch.tensor(
        [getattr(log, name) for name in mean_fields],
        device=device, dtype=torch.float32,
    )
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    merged = dict(zip(mean_fields, (summed / world_size).tolist()))
    # Extremes get proper MAX/MIN reductions, overwriting the averages.
    hi = torch.tensor([log.max_reward], device=device, dtype=torch.float32)
    lo = torch.tensor([log.min_reward], device=device, dtype=torch.float32)
    dist.all_reduce(hi, op=dist.ReduceOp.MAX)
    dist.all_reduce(lo, op=dist.ReduceOp.MIN)
    merged["max_reward"] = hi.item()
    merged["min_reward"] = lo.item()
    merged["step"] = log.step
    merged["elapsed_s"] = log.elapsed_s
    return StepLog(**merged)
def main() -> int:
    """CLI entry point: parse args, set up DDP + model, run the training loop.

    Side effects: creates ``--out`` dir, writes ``config.json`` and
    ``history.jsonl``, saves LoRA checkpoints every ``--save-every`` steps,
    optionally logs to W&B (rank 0 only). Returns 0 as the process exit code.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--steps", type=int, default=30)
    p.add_argument("--batch", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-5)
    p.add_argument("--lr-min", type=float, default=0.0,
                   help="If > 0, cosine-anneal lr -> lr_min over --steps (default: flat lr)")
    p.add_argument("--max-seq-len", type=int, default=4096)
    p.add_argument("--max-new-tokens", type=int, default=768)
    p.add_argument("--kl", type=float, default=0.02, dest="beta_kl")
    p.add_argument("--entropy", type=float, default=0.01, dest="entropy_coef")
    p.add_argument("--seed-offset", type=int, default=0)
    p.add_argument("--out", default="./simmart-run", dest="out_dir")
    p.add_argument("--save-every", type=int, default=10)
    p.add_argument("--wandb", type=str, default=None)
    p.add_argument("--dtype", choices=["float16", "bfloat16"], default="bfloat16")
    p.add_argument("--init-adapter", default=None,
                   help="Load LoRA weights from this SFT checkpoint at startup (action head)")
    p.add_argument("--journal-adapter", default=None,
                   help="Optional: load a frozen journal LoRA as second adapter. "
                        "Triggers dual-head (two-pass) rollout mode.")
    p.add_argument("--action-max-tokens", type=int, default=300,
                   help="Output budget for the action head JSON (dual-head mode)")
    p.add_argument("--journal-max-tokens", type=int, default=400,
                   help="Output budget for the journal head text (dual-head mode)")
    p.add_argument("--mb-size", type=int, default=8,
                   help="Minibatch size for the PG pass (was 2; bigger = fewer forward passes)")
    p.add_argument("--rollout-temperature", type=float, default=0.9, dest="rollout_temperature",
                   help="Sampling temperature for training rollouts (default 0.9; lower shrinks sampled-vs-greedy gap)")
    p.add_argument("--rollout-top-p", type=float, default=0.95, dest="rollout_top_p",
                   help="Nucleus sampling top_p for training rollouts (default 0.95)")
    p.add_argument("--fmt-penalty", type=float, default=0.0, dest="fmt_penalty",
                   help="Magnitude (>=0) subtracted from env reward when parse fails (default 0)")
    p.add_argument("--rollout-temp-low", type=float, default=-1.0, dest="rollout_temp_low",
                   help="Low temperature for mixed-temp rollouts (set both -low and -high to enable; default off)")
    p.add_argument("--rollout-temp-high", type=float, default=-1.0, dest="rollout_temp_high",
                   help="High temperature for mixed-temp rollouts (default off; falls back to --rollout-temperature)")
    args = p.parse_args()
    rank, local_rank, world_size = init_distributed()
    is_main = rank == 0
    cfg = TrainConfig(
        model=args.model, steps=args.steps, batch=args.batch,
        max_seq_len=args.max_seq_len, max_new_tokens=args.max_new_tokens,
        lr=args.lr, lr_min=args.lr_min, beta_kl=args.beta_kl, entropy_coef=args.entropy_coef,
        seed_offset=args.seed_offset, out_dir=args.out_dir,
        save_every=args.save_every, wandb_project=args.wandb, dtype=args.dtype,
        init_adapter=args.init_adapter,
        journal_adapter=args.journal_adapter,
        action_max_tokens=args.action_max_tokens,
        journal_max_tokens=args.journal_max_tokens,
        mb_size=args.mb_size,
        rollout_temperature=args.rollout_temperature,
        rollout_top_p=args.rollout_top_p,
        fmt_penalty=args.fmt_penalty,
        rollout_temp_low=args.rollout_temp_low,
        rollout_temp_high=args.rollout_temp_high,
        rank=rank, local_rank=local_rank, world_size=world_size,
    )
    out_dir = Path(cfg.out_dir)
    if is_main:
        out_dir.mkdir(parents=True, exist_ok=True)
        with open(out_dir / "config.json", "w") as f:
            json.dump(asdict(cfg), f, indent=2)
    # W&B is optional and rank-0 only; any init failure downgrades to
    # stdout-only logging rather than aborting the run.
    wandb = None
    if cfg.wandb_project and is_main:
        try:
            import wandb as _wb
            wandb = _wb
            wandb.init(project=cfg.wandb_project, config=asdict(cfg))
        except Exception as e:
            print(f"[warn] wandb disabled: {e}")
    if is_main:
        print(f"Loading policy: {cfg.model} (4-bit LoRA r={cfg.lora_r})")
        if cfg.init_adapter:
            print(f"Init adapter: {cfg.init_adapter}")
        print(f"DDP: rank={rank}/{world_size} local_rank={local_rank} "
              f"per_rank_batch={cfg.batch} global_batch={cfg.batch * world_size}")
    model, tokenizer = load_policy(cfg)
    if is_main:
        print(f"Device: {model.device} dtype: {cfg.dtype}")
    # The "reference" policy is the same model object: train_step disables
    # its adapter at KL time, so no frozen copy is ever materialised.
    ref_model = model
    # NOTE(review): the listcomp variable `p` shadows the argparse parser
    # above; harmless (parser is no longer used) but worth renaming.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.0,
    )
    scheduler = None
    if 0 < cfg.lr_min < cfg.lr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg.steps, eta_min=cfg.lr_min,
        )
        if is_main:
            print(f"[lr] cosine decay: {cfg.lr:.2e} -> {cfg.lr_min:.2e} over {cfg.steps} steps")
    # Barrier so all ranks have loaded weights before step 1 starts.
    if world_size > 1:
        dist.barrier()
    history: List[StepLog] = []
    for step in range(1, cfg.steps + 1):
        log = train_step(model, tokenizer, ref_model, optimizer, cfg, step)
        reduced = _reduce_log(log, world_size, model.device)
        history.append(reduced)
        if is_main:
            log_kv(step, {
                "mean_r": reduced.mean_reward,
                "ep_ret": reduced.mean_episode_return,
                "r_std": reduced.reward_std,
                "max_r": reduced.max_reward,
                "parse_err": reduced.parse_error_rate,
                "rogue_rec": reduced.rogue_recall,
                "loss": reduced.loss,
                "pg": reduced.pg_loss,
                "kl": reduced.kl_loss,
                "ent": reduced.entropy,
                "sec/step": reduced.elapsed_s,
                "lr": optimizer.param_groups[0]["lr"],
            })
            if wandb:
                wandb.log({f"train/{k}": v for k, v in asdict(reduced).items()}, step=step)
            with open(out_dir / "history.jsonl", "a") as f:
                f.write(json.dumps(asdict(reduced)) + "\n")
        # Rank 0 writes the checkpoint; everyone syncs afterwards so no rank
        # races ahead while the save is in flight.
        if (step % cfg.save_every == 0 or step == cfg.steps):
            if is_main:
                ckpt = out_dir / f"adapter-step-{step:03d}"
                model.save_pretrained(str(ckpt))
                tokenizer.save_pretrained(str(ckpt))
                print(f"[ckpt] saved {ckpt}")
            if world_size > 1:
                dist.barrier()
        if scheduler is not None:
            scheduler.step()
    if is_main:
        print("Training complete.")
        if wandb:
            wandb.finish()
    if world_size > 1:
        dist.destroy_process_group()
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value to the shell as the exit code.
    sys.exit(main())