"""End-to-end GRPO training on AWS RL Env, driven from Google Colab.
This script is pedagogical - it shows how the moving pieces connect:
[Central Curriculum] --picks task_id--> [G parallel rollouts via GrpoPool]
|
v
[Per-rollout trajectory of (prompt, action, reward)]
|
v
[Group-normalized advantages: A_i = (R_i - mean) / std]
|
v
[PPO-style policy-gradient loss on logprobs]
|
v
[Optimizer step]
Why this is "GRPO" and not vanilla REINFORCE:
GRPO (Group Relative Policy Optimization, DeepSeek) replaces the value
baseline with the **group mean** of rewards. For each task we sample G
trajectories; the advantage for rollout i is A_i = (R_i - mean(R)) / std(R).
This is variance-reduced and critic-free - perfect for our env where G=8.
Requirements (install in the Colab cell before running):
!pip install unsloth trl torch transformers accelerate bitsandbytes httpx websockets
Prerequisites:
The RL env server must be running somewhere Colab can reach, with
AWS_RL_ENV_POOL_SIZE=8 set. Easiest:
docker run -p 8000:8000 -e AWS_RL_ENV_POOL_SIZE=8 aws-rl-env:latest
And expose port 8000 via `cloudflared tunnel` or `ngrok http 8000`.
Set BASE_URL below to the public URL of that tunnel.
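    For example, with a hypothetical tunnel URL (substitute your own):

        export AWS_RL_ENV_BASE_URL=https://my-tunnel.trycloudflare.com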
Run:
python scripts/grpo_train.py
"""
from __future__ import annotations
import asyncio
import logging
import math
import os
from dataclasses import dataclass
from typing import List, Tuple
import torch
import torch.nn.functional as F
from client import AwsRlEnv
from models import AwsRlAction, Task
from scripts.grpo_pool import GrpoPool
from server.services.curriculum import Curriculum
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Config - tune for your setup
# ---------------------------------------------------------------------------
BASE_URL = os.getenv("AWS_RL_ENV_BASE_URL", "http://localhost:8000")
GROUP_SIZE = int(os.getenv("GRPO_GROUP_SIZE", "8")) # G in GRPO
NUM_GRPO_STEPS = int(os.getenv("GRPO_NUM_STEPS", "100")) # outer training steps
MAX_EPISODE_STEPS = int(
os.getenv("GRPO_MAX_STEPS", "15")
) # per-rollout step cap (matches MAX_STEPS in env)
LEARNING_RATE = float(os.getenv("GRPO_LR", "5e-6"))
KL_COEFF = float(os.getenv("GRPO_KL", "0.04"))  # KL penalty vs reference model (not applied in this minimal loss; kept for extension)
CLIP_EPS = float(os.getenv("GRPO_CLIP", "0.2")) # PPO clip for stability
TEMPERATURE = float(os.getenv("GRPO_TEMP", "0.9"))
MAX_NEW_TOKENS = int(os.getenv("GRPO_MAX_NEW", "96")) # per model generation
MODEL_NAME = os.getenv("GRPO_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
# ---------------------------------------------------------------------------
# Model loading (Unsloth - 4-bit LoRA, Colab-friendly)
# ---------------------------------------------------------------------------
def load_model_and_tokenizer():
"""Load a 4-bit LoRA-wrapped model via Unsloth.
    Unsloth is a drop-in replacement for transformers that speeds up
    fine-tuning roughly 2x on a single GPU and fits a 1.5B model on a free
    Colab T4.
"""
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=MODEL_NAME,
max_seq_length=4096,
load_in_4bit=True,
)
    # LoRA wrapping - only these params receive gradients
model = FastLanguageModel.get_peft_model(
model,
r=16,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.0,
bias="none",
use_gradient_checkpointing="unsloth",
)
FastLanguageModel.for_training(model)
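    # for_training() flips Unsloth back into training mode so the LoRA
    # params above actually receive gradients during the GRPO update.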
return model, tokenizer
# ---------------------------------------------------------------------------
# Prompt construction & action extraction
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = """You are an expert AWS SRE agent. You operate a simulated AWS cloud by \
emitting one AWS CLI command at a time. You will see the task description and the most \
recent command output, then reply with EXACTLY ONE AWS CLI command on a single line \
starting with 'aws '. No explanation, no markdown, no quotes - just the command."""
def build_prompt(tokenizer, task: Task, history: List[Tuple[str, str]]) -> str:
"""Build a chat prompt from the task + command/output history."""
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": f"TASK: {task.description}"},
]
for cmd, out in history[-4:]: # keep last 4 turns to fit context
messages.append({"role": "assistant", "content": cmd})
messages.append({"role": "user", "content": f"OUTPUT:\n{out[:400]}"})
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
def extract_command(raw: str) -> str:
"""Pull the first `aws …` line out of the model's raw decoded output."""
for line in raw.splitlines():
line = line.strip().strip("`").strip()
if line.startswith("aws "):
return line
return "aws help" # safe fallback so env always accepts the command
# ---------------------------------------------------------------------------
# Rollout - one trajectory, one env, one task
# ---------------------------------------------------------------------------
@dataclass
class Step:
"""One step of a trajectory. `prompt_ids` + `action_ids` are what we backprop on."""
prompt_ids: torch.Tensor # shape [prompt_len]
action_ids: torch.Tensor # shape [action_len]
    logprob_sum: torch.Tensor  # scalar - sum of model logprobs over action_ids at sample time
reward: float
@dataclass
class Trajectory:
steps: List[Step]
total_reward: float
@torch.no_grad()
def generate_action(
model,
tokenizer,
prompt: str,
device: torch.device,
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sample one command from the model; return (text, prompt_ids, action_ids, logprob_sum)."""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
prompt_ids = inputs["input_ids"][0]
out = model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=True,
temperature=TEMPERATURE,
top_p=0.95,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=True,
)
full_ids = out.sequences[0]
action_ids = full_ids[prompt_ids.size(0) :]
# Gather per-token logprobs of the sampled tokens (at generation time)
if out.scores:
logits = torch.stack(out.scores, dim=0) # [T, 1, V]
logprobs = torch.log_softmax(logits, dim=-1)[:, 0, :]
token_lp = logprobs.gather(1, action_ids.unsqueeze(-1)).squeeze(-1)
logprob_sum = token_lp.sum()
else:
logprob_sum = torch.tensor(0.0, device=device)
text = tokenizer.decode(action_ids, skip_special_tokens=True)
return text, prompt_ids, action_ids, logprob_sum
async def run_single_rollout(
env: AwsRlEnv,
task: Task,
model,
tokenizer,
device: torch.device,
) -> Trajectory:
"""Drive one env through up to MAX_EPISODE_STEPS, recording every step."""
result = await env.reset(task=task)
history: List[Tuple[str, str]] = []
steps: List[Step] = []
total_reward = 0.0
for _ in range(MAX_EPISODE_STEPS):
prompt = build_prompt(tokenizer, task, history)
raw, prompt_ids, action_ids, logprob_sum = generate_action(
model, tokenizer, prompt, device
)
command = extract_command(raw)
result = await env.step(AwsRlAction(command=command))
reward = float(result.reward)
total_reward += reward
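        # Tensors go to CPU so G concurrent rollouts don't hold GPU memory;
        # grpo_loss moves them back to the device when re-scoring.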
steps.append(
Step(
prompt_ids=prompt_ids.cpu(),
action_ids=action_ids.cpu(),
logprob_sum=logprob_sum.detach().cpu(),
reward=reward,
)
)
history.append((command, result.observation.command_output or ""))
if result.done:
break
return Trajectory(steps=steps, total_reward=total_reward)
# ---------------------------------------------------------------------------
# GRPO loss - group-normalized advantages, PPO-style clipped ratio
# ---------------------------------------------------------------------------
def compute_group_advantages(rewards: List[float]) -> List[float]:
"""Core GRPO step: subtract the group mean and divide by group std.
A_i = (R_i - mean(R_1..G)) / (std(R_1..G) + eps)
    This makes the "baseline" the group's own performance - no value network
needed. If all G rollouts tied, advantages are zero (no signal, correct).
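    Example: rewards [2.0, 0.0, 0.0, 2.0] have mean 1.0 and std 1.0, so the
    advantages come out as roughly [+1.0, -1.0, -1.0, +1.0].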
"""
mean = sum(rewards) / len(rewards)
var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
std = math.sqrt(var) + 1e-8
return [(r - mean) / std for r in rewards]
def logprob_under_current_model(
model, tokenizer, step: Step, device: torch.device
) -> torch.Tensor:
"""Re-score the sampled action under the CURRENT policy (for gradient).
At rollout time we recorded the old policy's logprob_sum. To get a
differentiable ratio we have to recompute it now with the current weights.
"""
full = torch.cat([step.prompt_ids, step.action_ids]).unsqueeze(0).to(device)
attn = torch.ones_like(full)
outputs = model(input_ids=full, attention_mask=attn)
logits = outputs.logits[0, :-1, :] # predict next token
targets = full[0, 1:]
prompt_len = step.prompt_ids.size(0)
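    # Shift-by-one: logits at position t score the token at position t + 1,
    # so the logits covering the action span start at index prompt_len - 1.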
# Only the action tokens contribute to the loss
action_logits = logits[prompt_len - 1 : prompt_len - 1 + step.action_ids.size(0)]
action_targets = targets[prompt_len - 1 : prompt_len - 1 + step.action_ids.size(0)]
logp = F.log_softmax(action_logits, dim=-1)
token_logp = logp.gather(1, action_targets.unsqueeze(-1)).squeeze(-1)
return token_logp.sum()
def grpo_loss(
model,
tokenizer,
trajectories: List[Trajectory],
device: torch.device,
) -> torch.Tensor:
"""GRPO objective: maximize clipped advantage-weighted logprob ratio.
loss = -mean_i [ min(ratio_i * A_i, clip(ratio_i, 1-eps, 1+eps) * A_i) ]
"""
rewards = [t.total_reward for t in trajectories]
advantages = compute_group_advantages(rewards)
losses: List[torch.Tensor] = []
for traj, adv in zip(trajectories, advantages):
if not traj.steps:
continue
adv_t = torch.tensor(adv, device=device, dtype=torch.float32)
for step in traj.steps:
new_logp = logprob_under_current_model(model, tokenizer, step, device)
old_logp = step.logprob_sum.to(device)
ratio = torch.exp(new_logp - old_logp)
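            # With a single update per sampled group, the old and new policies
            # coincide, so ratio starts near 1 and the clip mainly guards
            # against numerical drift.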
unclipped = ratio * adv_t
clipped = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * adv_t
losses.append(-torch.min(unclipped, clipped))
if not losses:
return torch.tensor(0.0, device=device, requires_grad=True)
return torch.stack(losses).mean()
# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------
async def train() -> None:
model, tokenizer = load_model_and_tokenizer()
device = next(model.parameters()).device
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad],
lr=LEARNING_RATE,
)
curriculum = Curriculum()
async with GrpoPool(
base_url=BASE_URL, size=GROUP_SIZE, curriculum=curriculum
) as pool:
logger.info("Connected pool of %d envs against %s", GROUP_SIZE, BASE_URL)
for step_idx in range(NUM_GRPO_STEPS):
# 1) central curriculum picks ONE task for the whole group
task = curriculum.next_task()
logger.info(
"[step %d/%d] task_id=%d tier=%s",
step_idx + 1,
NUM_GRPO_STEPS,
task.task_id,
task.difficulty.value,
)
# 2) launch G parallel rollouts, all on the same task_id
rollout_coros = [
run_single_rollout(e, task, model, tokenizer, device) for e in pool.envs
]
trajectories = await asyncio.gather(*rollout_coros)
rewards = [t.total_reward for t in trajectories]
logger.info(
" rewards: min=%.3f mean=%.3f max=%.3f",
min(rewards),
sum(rewards) / len(rewards),
max(rewards),
)
# 3) GRPO loss + update
model.train()
loss = grpo_loss(model, tokenizer, trajectories, device)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
logger.info(" loss=%.4f", loss.item())
# 4) feed result back to curriculum (one record per group, not per rollout)
pool.record_group_result(task, rewards)
# Save LoRA adapter
output_dir = os.getenv("GRPO_OUTPUT_DIR", "./grpo_lora_out")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
logger.info("Saved LoRA adapter to %s", output_dir)
if __name__ == "__main__":
asyncio.run(train())