"""
Custom Ray Trainer that saves rollout data immediately after generation
"""
from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer
from verl.utils.debug import marked_timer
import os
import json
class CustomCodeIORayPPOTrainer(CodeIORayPPOTrainer):
"""
Custom trainer that saves rollout data immediately after generation,
not after actor update.
"""
def fit(self):
"""
Override fit method to save rollout data right after generation
"""
# Import here to avoid circular dependencies
from absolute_zero_reasoner.trainer.ppo.reason_rl_ray_trainer import (
AdvantageEstimator, apply_kl_penalty, compute_advantage, reduce_metrics,
compute_data_metrics, compute_timing_metrics, compute_response_mask,
pad_dataproto_to_divisor, unpad_dataproto, DataProto, core_algos
)
from absolute_zero_reasoner.utils.tracking import ReasonRLTracking
from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter as pp
from omegaconf import OmegaConf
        from copy import deepcopy
import numpy as np
import uuid
import torch
from datetime import datetime
        # The basic initialization follows the start of the parent class's fit method
pp.section_header("Training Setup")
logger = ReasonRLTracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
tags=self.config.trainer.wandb_tags,
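            # Resume the existing W&B run only when resume_mode is 'auto' and a
            # run id was supplied; otherwise start a fresh run.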
resume="must" if self.config.trainer.resume_mode == 'auto' and \
self.config.trainer.wandb_run_id is not None else False,
run_id=self.config.trainer.wandb_run_id \
if self.config.trainer.wandb_run_id is not None else None
)
pp.status("Config", f"Project: {self.config.trainer.project_name}, Experiment: {self.config.trainer.experiment_name}", "info")
pp.status("Algorithm", f"Using {self.config.algorithm.adv_estimator} advantage estimator", "info")
pp.status("Setup", f"Critic enabled: {self.use_critic}, Reference policy: {self.use_reference_policy}", "info")
self.global_steps = 0
# load checkpoint before doing anything
pp.status("Checkpoint", "Loading checkpoint if available...", "info")
self._load_checkpoint()
        # For base (non-chat) models, fall back to a minimal chat template that
        # simply joins message contents with newlines
if self.config.actor_rollout_ref.model.pretrained_tokenizer:
self.tokenizer.chat_template = "{%- for message in messages -%}{{- '\\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}"
# perform validation before training
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True) and self.global_steps == 0:
pp.section_header("Initial Validation")
pp.status("Validation", "Running initial validation...", "info")
val_metrics = self._validate(do_sample=self.config.eval.do_sample)
# Convert metrics to table format
metrics_table = []
for k, v in val_metrics.items():
metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
pp.table(["Metric", "Value"], metrics_table, "Initial Validation Results")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
pp.status("Training", "Validation only mode, exiting", "success")
return
# we start from step 1
self.global_steps += 1
last_val_metrics = None
self.max_steps_duration = 0
pp.section_header("Starting Training")
total_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
pp.status("Training", f"Starting training for {self.config.trainer.total_epochs} epochs ({total_steps} steps)", "info")
for epoch in range(self.config.trainer.total_epochs):
pp.status("Epoch", f"Starting epoch {epoch+1}/{self.config.trainer.total_epochs}", "info")
for batch_idx, batch_dict in enumerate(self.train_dataloader):
do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
if do_profile:
self.actor_rollout_wg.start_profile()
if self.use_reference_policy:
self.ref_policy_wg.start_profile()
if self.use_critic:
self.critic_wg.start_profile()
if self.use_rm:
self.rm_wg.start_profile()
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_data" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("interaction_kwargs")
gen_batch = batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
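                # gen_batch now carries only what the rollout workers need
                # (token ids, masks, positions, plus optional raw/multi-modal
                # fields); everything else stays behind in `batch`.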
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
pp.status("Step", f"Generating sequences for batch {batch_idx+1}", "info")
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                    # πŸ”₯ Key fix: save rollout data immediately after generation
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir and gen_batch_output is not None:
try:
                            # Extract the needed data from gen_batch_output;
                            # verl's rollout output includes both "prompts" and "responses"
                            prompts = gen_batch_output.batch.get("prompts", [])
                            responses = gen_batch_output.batch.get("responses", [])
if len(prompts) > 0 and len(responses) > 0:
# ν…μŠ€νŠΈλ‘œ λ””μ½”λ“œ
input_texts = self.tokenizer.batch_decode(prompts, skip_special_tokens=True)
output_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
# μ μˆ˜λŠ” λ‚˜μ€‘μ— κ³„μ‚°λ˜λ―€λ‘œ 일단 0으둜 μ΄ˆκΈ°ν™”
scores = [0.0] * len(input_texts)
                                # Save to disk as JSONL
os.makedirs(rollout_data_dir, exist_ok=True)
filename = os.path.join(rollout_data_dir, f"{self.global_steps}_rollout.jsonl")
with open(filename, "w") as f:
for i in range(len(input_texts)):
entry = {
"step": self.global_steps,
"input": input_texts[i],
"output": output_texts[i],
"score": scores[i],
"saved_at": "after_rollout"
}
                                        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
print(f"βœ… Saved rollout data to {filename} (immediately after generation)")
except Exception as e:
print(f"⚠️ Failed to save rollout data: {e}")
# μ΄ν›„λŠ” μ›λž˜ μ½”λ“œ κ·ΈλŒ€λ‘œ...
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor, _ = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
                # In a real implementation, the rest of the parent class's fit
                # logic would have to be reproduced here; for brevity only the
                # core parts are shown.
                # Continue with normal training flow...
                # (An actual implementation must include the parent class's full
                # per-step logic.)
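                # A hedged outline of the omitted verl-style flow, using the names
                # imported above (exact signatures may differ in the real parent):
                #   batch = batch.union(gen_batch_output)
                #   batch.batch["response_mask"] = compute_response_mask(batch)
                #   reward_tensor, reward_extra = self.reward_fn(batch)
                #   batch, kl_metrics = apply_kl_penalty(...)      # if use_reference_policy
                #   batch = compute_advantage(batch, self.config.algorithm.adv_estimator)
                #   self.critic_wg.update_critic(batch)            # if use_critic
                #   self.actor_rollout_wg.update_actor(batch)
                #   metrics.update(compute_data_metrics(batch))
                #   metrics.update(compute_timing_metrics(batch, timing_raw))
                #   logger.log(data=metrics, step=self.global_steps)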
self.global_steps += 1
pp.status("Training", "Training completed successfully!", "success")
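

if __name__ == "__main__":
    # Minimal inspection sketch (not part of the trainer): read back the JSONL
    # files written above. "rollouts" is a hypothetical example directory; the
    # real path comes from config.trainer.rollout_data_dir.
    import glob
    for path in sorted(glob.glob("rollouts/*_rollout.jsonl")):
        with open(path) as f:
            for line in f:
                entry = json.loads(line)
                print(f"step {entry['step']}: score={entry['score']}, output={entry['output'][:60]!r}")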