""" |
|
Custom Ray Trainer that saves rollout data immediately after generation |
|
""" |
|
from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer |
|
from verl.utils.debug import marked_timer |
|
import os |
|
import json |
|
|
|
|
|
class CustomCodeIORayPPOTrainer(CodeIORayPPOTrainer):
    """Custom trainer that saves rollout data immediately after generation,
    not after the actor update.
    """
    def fit(self):
        """Override fit() to save rollout data right after generation."""
        from absolute_zero_reasoner.trainer.ppo.reason_rl_ray_trainer import (
            AdvantageEstimator, apply_kl_penalty, compute_advantage, reduce_metrics,
            compute_data_metrics, compute_timing_metrics, compute_response_mask,
            pad_dataproto_to_divisor, unpad_dataproto, DataProto, core_algos,
        )
        from absolute_zero_reasoner.utils.tracking import ReasonRLTracking
        from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter as pp
        from omegaconf import OmegaConf
        from copy import deepcopy  # deepcopy lives in the stdlib "copy" module
        import numpy as np
        import uuid
        import torch
        from datetime import datetime

        pp.section_header("Training Setup")
        logger = ReasonRLTracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
            tags=self.config.trainer.wandb_tags,
            resume="must" if self.config.trainer.resume_mode == 'auto'
                and self.config.trainer.wandb_run_id is not None else False,
            run_id=self.config.trainer.wandb_run_id
                if self.config.trainer.wandb_run_id is not None else None,
        )
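        # Note: with the Weights & Biases backend, resume="must" reattaches to the run
        # identified by wandb_run_id and raises an error if that run does not exist;
        # in all other cases a fresh run is created.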

        pp.status("Config", f"Project: {self.config.trainer.project_name}, Experiment: {self.config.trainer.experiment_name}", "info")
        pp.status("Algorithm", f"Using {self.config.algorithm.adv_estimator} advantage estimator", "info")
        pp.status("Setup", f"Critic enabled: {self.use_critic}, Reference policy: {self.use_reference_policy}", "info")

        self.global_steps = 0

        pp.status("Checkpoint", "Loading checkpoint if available...", "info")
        self._load_checkpoint()
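        # _load_checkpoint() is expected to restore training state, including global_steps,
        # so a resumed run skips the step-0 initial validation below.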

        if self.config.actor_rollout_ref.model.pretrained_tokenizer:
            self.tokenizer.chat_template = "{%- for message in messages -%}{{- '\\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}"
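            # Minimal Jinja template for base (non-chat) tokenizers: message contents are
            # simply concatenated, separated by newlines, with no role headers or special tokens.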

        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True) and self.global_steps == 0:
            pp.section_header("Initial Validation")
            pp.status("Validation", "Running initial validation...", "info")

            val_metrics = self._validate(do_sample=self.config.eval.do_sample)

            metrics_table = []
            for k, v in val_metrics.items():
                metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
            pp.table(["Metric", "Value"], metrics_table, "Initial Validation Results")
            logger.log(data=val_metrics, step=self.global_steps)

            if self.config.trainer.get('val_only', False):
                pp.status("Training", "Validation only mode, exiting", "success")
                return

        self.global_steps += 1
        last_val_metrics = None
        self.max_steps_duration = 0

        pp.section_header("Starting Training")
        total_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
        pp.status("Training", f"Starting training for {self.config.trainer.total_epochs} epochs ({total_steps} steps)", "info")

        for epoch in range(self.config.trainer.total_epochs):
            pp.status("Epoch", f"Starting epoch {epoch+1}/{self.config.trainer.total_epochs}", "info")
            for batch_idx, batch_dict in enumerate(self.train_dataloader):
                # Start profilers only on steps explicitly listed in trainer.profile_steps.
                do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
                if do_profile:
                    self.actor_rollout_wg.start_profile()
                    if self.use_reference_policy:
                        self.ref_policy_wg.start_profile()
                    if self.use_critic:
                        self.critic_wg.start_profile()
                    if self.use_rm:
                        self.rm_wg.start_profile()

                metrics = {}
                timing_raw = {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # Split off the fields needed for sequence generation.
                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
                if "multi_modal_data" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("multi_modal_data")
                if "raw_prompt" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("raw_prompt")
                if "tools_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("tools_kwargs")
                if "interaction_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("interaction_kwargs")

                gen_batch = batch.pop(
                    batch_keys=batch_keys_to_pop,
                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
                )
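                # gen_batch now holds only what the rollout workers need (prompt tensors plus the
                # selected non-tensor fields); everything else stays in `batch` for the rest of the step.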
                is_last_step = self.global_steps >= self.total_training_steps

                with marked_timer("step", timing_raw):
                    with marked_timer("gen", timing_raw, color="red"):
                        pp.status("Step", f"Generating sequences for batch {batch_idx+1}", "info")
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
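                        # In verl, the generation output DataProto typically carries both the prompt
                        # and response token ids ("prompts"/"responses"); the dump below relies on this.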

                    # Dump rollouts to disk immediately after generation, before any reward
                    # computation or policy update.
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir and gen_batch_output is not None:
                        try:
                            # Prompt and response token ids come from the generation output.
                            prompts = gen_batch_output.batch.get("prompts", [])
                            responses = gen_batch_output.batch.get("responses", [])

                            if len(prompts) > 0 and len(responses) > 0:
                                input_texts = self.tokenizer.batch_decode(prompts, skip_special_tokens=True)
                                output_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True)

                                # Rewards are not computed yet at this point, so store placeholder scores.
                                scores = [0.0] * len(input_texts)

                                os.makedirs(rollout_data_dir, exist_ok=True)
                                filename = os.path.join(rollout_data_dir, f"{self.global_steps}_rollout.jsonl")

                                with open(filename, "w") as f:
                                    for i in range(len(input_texts)):
                                        entry = {
                                            "step": self.global_steps,
                                            "input": input_texts[i],
                                            "output": output_texts[i],
                                            "score": scores[i],
                                            "saved_at": "after_rollout",
                                        }
                                        f.write(json.dumps(entry, ensure_ascii=False) + "\n")

                                print(f"✅ Saved rollout data to {filename} (immediately after generation)")
                        except Exception as e:
                            print(f"⚠️ Failed to save rollout data: {e}")

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with marked_timer("gen_max", timing_raw, color="purple"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor, _ = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch["reward_baselines"] = reward_baseline_tensor
                            del gen_baseline_batch, gen_baseline_output
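                    # The extra greedy pass above (do_sample=False) provides the REMAX reward
                    # baseline; its per-sequence reward is stored in "reward_baselines" and is
                    # subtracted from the sampled rollouts' rewards during advantage estimation.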

                    # ... (remaining reward, advantage, and update logic for the step) ...

                self.global_steps += 1

        pp.status("Training", "Training completed successfully!", "success")