File size: 10,469 Bytes
f50dc54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
Custom Ray Trainer that saves rollout data immediately after generation
"""
from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer
from verl.utils.debug import marked_timer
import os
import json


class CustomCodeIORayPPOTrainer(CodeIORayPPOTrainer):
    """
    Custom trainer that saves rollout data immediately after generation,
    not after actor update.
    """
    
    def fit(self):
        """
        Override fit method to save rollout data right after generation
        """
        # Import here to avoid circular dependencies
        from absolute_zero_reasoner.trainer.ppo.reason_rl_ray_trainer import (
            AdvantageEstimator, apply_kl_penalty, compute_advantage, reduce_metrics, 
            compute_data_metrics, compute_timing_metrics, compute_response_mask, 
            pad_dataproto_to_divisor, unpad_dataproto, DataProto, core_algos
        )
        from absolute_zero_reasoner.utils.tracking import ReasonRLTracking
        from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter as pp
        from omegaconf import OmegaConf
        from deepcopy import deepcopy
        import numpy as np
        import uuid
        import torch
        from datetime import datetime
        
        # ๊ธฐ๋ณธ ์ดˆ๊ธฐํ™”๋Š” ๋ถ€๋ชจ ํด๋ž˜์Šค์˜ fit ๋ฉ”์„œ๋“œ ์‹œ์ž‘ ๋ถ€๋ถ„์„ ๋”ฐ๋ฆ„
        pp.section_header("Training Setup")
        logger = ReasonRLTracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
            tags=self.config.trainer.wandb_tags,
            resume="must" if self.config.trainer.resume_mode == 'auto' and \
                self.config.trainer.wandb_run_id is not None else False,
            run_id=self.config.trainer.wandb_run_id \
                if self.config.trainer.wandb_run_id is not None else None
        )
        
        pp.status("Config", f"Project: {self.config.trainer.project_name}, Experiment: {self.config.trainer.experiment_name}", "info")
        pp.status("Algorithm", f"Using {self.config.algorithm.adv_estimator} advantage estimator", "info")
        pp.status("Setup", f"Critic enabled: {self.use_critic}, Reference policy: {self.use_reference_policy}", "info")
        
        self.global_steps = 0
        
        # load checkpoint before doing anything
        pp.status("Checkpoint", "Loading checkpoint if available...", "info")
        self._load_checkpoint()
        
        # base model chat template
        if self.config.actor_rollout_ref.model.pretrained_tokenizer:
            self.tokenizer.chat_template = "{%- for message in messages -%}{{- '\\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}"
        
        # perform validation before training
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True) and self.global_steps == 0:
            pp.section_header("Initial Validation")
            pp.status("Validation", "Running initial validation...", "info")
            
            val_metrics = self._validate(do_sample=self.config.eval.do_sample)
            # Convert metrics to table format
            metrics_table = []
            for k, v in val_metrics.items():
                metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
            pp.table(["Metric", "Value"], metrics_table, "Initial Validation Results")
            logger.log(data=val_metrics, step=self.global_steps)
            
            if self.config.trainer.get('val_only', False):
                pp.status("Training", "Validation only mode, exiting", "success")
                return
        
        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None
        self.max_steps_duration = 0
        
        pp.section_header("Starting Training")
        total_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
        pp.status("Training", f"Starting training for {self.config.trainer.total_epochs} epochs ({total_steps} steps)", "info")
        
        for epoch in range(self.config.trainer.total_epochs):
            pp.status("Epoch", f"Starting epoch {epoch+1}/{self.config.trainer.total_epochs}", "info")
            for batch_idx, batch_dict in enumerate(self.train_dataloader):
                do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
                if do_profile:
                    self.actor_rollout_wg.start_profile()
                    if self.use_reference_policy:
                        self.ref_policy_wg.start_profile()
                    if self.use_critic:
                        self.critic_wg.start_profile()
                    if self.use_rm:
                        self.rm_wg.start_profile()
                
                metrics = {}
                timing_raw = {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)
                
                # pop those keys for generation
                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
                if "multi_modal_data" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("multi_modal_data")
                if "raw_prompt" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("raw_prompt")
                if "tools_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("tools_kwargs")
                if "interaction_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("interaction_kwargs")
                
                gen_batch = batch.pop(
                    batch_keys=batch_keys_to_pop,
                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
                )
                is_last_step = self.global_steps >= self.total_training_steps
                
                with marked_timer("step", timing_raw):
                    # generate a batch
                    with marked_timer("gen", timing_raw, color="red"):
                        pp.status("Step", f"Generating sequences for batch {batch_idx+1}", "info")
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                    
                    # ๐Ÿ”ฅ ํ•ต์‹ฌ ์ˆ˜์ •: Rollout ์งํ›„ ๋ฐ”๋กœ ์ €์žฅ
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir and gen_batch_output is not None:
                        try:
                            # gen_batch_output์—์„œ ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
                            prompts = gen_batch.batch.get("prompts", [])
                            responses = gen_batch_output.batch.get("responses", [])
                            
                            if len(prompts) > 0 and len(responses) > 0:
                                # ํ…์ŠคํŠธ๋กœ ๋””์ฝ”๋“œ
                                input_texts = self.tokenizer.batch_decode(prompts, skip_special_tokens=True)
                                output_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                                
                                # ์ ์ˆ˜๋Š” ๋‚˜์ค‘์— ๊ณ„์‚ฐ๋˜๋ฏ€๋กœ ์ผ๋‹จ 0์œผ๋กœ ์ดˆ๊ธฐํ™”
                                scores = [0.0] * len(input_texts)
                                
                                # ์ €์žฅ
                                os.makedirs(rollout_data_dir, exist_ok=True)
                                filename = os.path.join(rollout_data_dir, f"{self.global_steps}_rollout.jsonl")
                                
                                with open(filename, "w") as f:
                                    for i in range(len(input_texts)):
                                        entry = {
                                            "step": self.global_steps,
                                            "input": input_texts[i],
                                            "output": output_texts[i],
                                            "score": scores[i],
                                            "saved_at": "after_rollout"
                                        }
                                        f.write(json.dumps(entry, ensure_ascii=False) + "\\n")
                                
                                print(f"โœ… Saved rollout data to {filename} (immediately after generation)")
                        except Exception as e:
                            print(f"โš ๏ธ Failed to save rollout data: {e}")
                    
                    # ์ดํ›„๋Š” ์›๋ž˜ ์ฝ”๋“œ ๊ทธ๋Œ€๋กœ...
                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with marked_timer("gen_max", timing_raw, color="purple"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor, _ = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch["reward_baselines"] = reward_baseline_tensor
                            del gen_baseline_batch, gen_baseline_output
                    
                    # ๋‚˜๋จธ์ง€๋Š” ๋ถ€๋ชจ ํด๋ž˜์Šค์˜ fit ๋ฉ”์„œ๋“œ๋ฅผ ํ˜ธ์ถœ
                    # ์‹ค์ œ๋กœ๋Š” ์—ฌ๊ธฐ์„œ ๋ถ€๋ชจ ํด๋ž˜์Šค์˜ ๋‚˜๋จธ์ง€ ๋กœ์ง์„ ๋ชจ๋‘ ๊ตฌํ˜„ํ•ด์•ผ ํ•˜์ง€๋งŒ,
                    # ๊ฐ„๋‹จํžˆ ํ•˜๊ธฐ ์œ„ํ•ด ํ•ต์‹ฌ ๋ถ€๋ถ„๋งŒ ๊ตฌํ˜„
                    
                    # Continue with normal training flow...
                    # (์ด ๋ถ€๋ถ„์€ ์‹ค์ œ ๊ตฌํ˜„ ์‹œ ๋ถ€๋ชจ ํด๋ž˜์Šค์˜ ์ „์ฒด ๋กœ์ง์„ ํฌํ•จํ•ด์•ผ ํ•จ)
                
                self.global_steps += 1
        
        pp.status("Training", "Training completed successfully!", "success")