"""
Custom Ray Trainer that saves rollout data immediately after generation
"""
from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer
from verl.utils.debug import marked_timer
import os
import json


class CustomCodeIORayPPOTrainer(CodeIORayPPOTrainer):
    """
    Custom trainer that saves rollout data immediately after generation,
    not after the actor update.
    """
    def fit(self):
        """
        Override the fit method to save rollout data right after generation.
        """
        # Import here to avoid circular dependencies
        from absolute_zero_reasoner.trainer.ppo.reason_rl_ray_trainer import (
            AdvantageEstimator, apply_kl_penalty, compute_advantage, reduce_metrics,
            compute_data_metrics, compute_timing_metrics, compute_response_mask,
            pad_dataproto_to_divisor, unpad_dataproto, DataProto, core_algos
        )
        from absolute_zero_reasoner.utils.tracking import ReasonRLTracking
        from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter as pp
        from omegaconf import OmegaConf
        from copy import deepcopy
        import numpy as np
        import uuid
        import torch
        from datetime import datetime

        # Basic setup mirrors the beginning of the parent class's fit method
        pp.section_header("Training Setup")
        logger = ReasonRLTracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
            tags=self.config.trainer.wandb_tags,
            resume="must" if self.config.trainer.resume_mode == 'auto'
                and self.config.trainer.wandb_run_id is not None else False,
            run_id=self.config.trainer.wandb_run_id
                if self.config.trainer.wandb_run_id is not None else None,
        )
        pp.status("Config", f"Project: {self.config.trainer.project_name}, Experiment: {self.config.trainer.experiment_name}", "info")
        pp.status("Algorithm", f"Using {self.config.algorithm.adv_estimator} advantage estimator", "info")
        pp.status("Setup", f"Critic enabled: {self.use_critic}, Reference policy: {self.use_reference_policy}", "info")

        self.global_steps = 0

        # load checkpoint before doing anything
        pp.status("Checkpoint", "Loading checkpoint if available...", "info")
        self._load_checkpoint()

        # base model chat template
        if self.config.actor_rollout_ref.model.pretrained_tokenizer:
            self.tokenizer.chat_template = "{%- for message in messages -%}{{- '\\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}"
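            # The template above renders a conversation as the raw concatenation of
            # message contents separated by newlines (no role headers or special
            # tokens), matching the plain-text format a base (non-chat) model expects.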
        # perform validation before training
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True) and self.global_steps == 0:
            pp.section_header("Initial Validation")
            pp.status("Validation", "Running initial validation...", "info")
            val_metrics = self._validate(do_sample=self.config.eval.do_sample)

            # Convert metrics to table format
            metrics_table = []
            for k, v in val_metrics.items():
                metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
            pp.table(["Metric", "Value"], metrics_table, "Initial Validation Results")

            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get('val_only', False):
                pp.status("Training", "Validation only mode, exiting", "success")
                return

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None
        self.max_steps_duration = 0

        pp.section_header("Starting Training")
        total_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
        pp.status("Training", f"Starting training for {self.config.trainer.total_epochs} epochs ({total_steps} steps)", "info")

        for epoch in range(self.config.trainer.total_epochs):
            pp.status("Epoch", f"Starting epoch {epoch+1}/{self.config.trainer.total_epochs}", "info")

            for batch_idx, batch_dict in enumerate(self.train_dataloader):
                do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
                if do_profile:
                    self.actor_rollout_wg.start_profile()
                    if self.use_reference_policy:
                        self.ref_policy_wg.start_profile()
                    if self.use_critic:
                        self.critic_wg.start_profile()
                    if self.use_rm:
                        self.rm_wg.start_profile()

                metrics = {}
                timing_raw = {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
                if "multi_modal_data" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("multi_modal_data")
                if "raw_prompt" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("raw_prompt")
                if "tools_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("tools_kwargs")
                if "interaction_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("interaction_kwargs")
                gen_batch = batch.pop(
                    batch_keys=batch_keys_to_pop,
                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
                )
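                # At this point gen_batch holds only the prompt tensors (and any
                # non-tensor generation inputs) needed by generate_sequences, while
                # the remaining fields stay in `batch` for later reward/advantage work.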
                is_last_step = self.global_steps >= self.total_training_steps

                with marked_timer("step", timing_raw):
                    # generate a batch
                    with marked_timer("gen", timing_raw, color="red"):
                        pp.status("Step", f"Generating sequences for batch {batch_idx+1}", "info")
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    # Key change: save the rollout data immediately after generation
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir and gen_batch_output is not None:
                        try:
                            # Extract the needed data from gen_batch_output
                            prompts = gen_batch_output.batch.get("prompts", [])
                            responses = gen_batch_output.batch.get("responses", [])
                            if len(prompts) > 0 and len(responses) > 0:
                                # Decode to text
                                input_texts = self.tokenizer.batch_decode(prompts, skip_special_tokens=True)
                                output_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                                # Scores are computed later, so initialize them to 0 for now
                                scores = [0.0] * len(input_texts)
                                # Write one JSON object per sample
                                os.makedirs(rollout_data_dir, exist_ok=True)
                                filename = os.path.join(rollout_data_dir, f"{self.global_steps}_rollout.jsonl")
                                with open(filename, "w") as f:
                                    for i in range(len(input_texts)):
                                        entry = {
                                            "step": self.global_steps,
                                            "input": input_texts[i],
                                            "output": output_texts[i],
                                            "score": scores[i],
                                            "saved_at": "after_rollout",
                                        }
                                        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
                                print(f"Saved rollout data to {filename} (immediately after generation)")
                        except Exception as e:
                            print(f"Warning: failed to save rollout data: {e}")
                    # From here on, the original flow continues unchanged...
                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with marked_timer("gen_max", timing_raw, color="purple"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor, _ = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output
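                            # The greedy (do_sample=False) rollout above supplies per-sample
                            # reward baselines ("reward_baselines") that REMAX subtracts from
                            # the sampled rollout's rewards to reduce gradient variance.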
                    # The remainder would normally reuse the parent class's fit logic.
                    # A real implementation must reproduce all of the parent's remaining
                    # per-step logic here; only the key part is implemented for brevity.
                    # Continue with the normal training flow...
                    # (In an actual implementation this section must include the parent
                    # class's full per-step logic.)
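                    # Sketch of the usual continuation (assumed to mirror the parent
                    # trainer's PPO step; illustrative only, not executed here):
                    #
                    #   batch = batch.union(gen_batch_output)
                    #   batch.batch["response_mask"] = compute_response_mask(batch)
                    #   reward_tensor, reward_metrics = self.reward_fn(batch)
                    #   batch.batch["token_level_scores"] = reward_tensor
                    #   batch = compute_advantage(batch, adv_estimator=self.config.algorithm.adv_estimator)
                    #   if self.use_critic:
                    #       critic_output = self.critic_wg.update_critic(batch)
                    #   actor_output = self.actor_rollout_wg.update_actor(batch)
                    #   metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                    #   logger.log(data=metrics, step=self.global_steps)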
                self.global_steps += 1

        pp.status("Training", "Training completed successfully!", "success")