
TTRLVR-AZR Integration Plan

Overview

Restructure TTRLVR to be fully integrated with the AZR approach, so that all Phases run inside a single VeRL session.

1. Overall Structural Changes

Current Structure (Separated)

train_ttrlvr_azr.py
├── for round in rounds:
│   ├── Phase 1-4: RemoteTestTimePipeline (separate vLLM)
│   │   ├── Step 1: program generation
│   │   ├── Step 2: I/O pair generation
│   │   ├── Step 3: task creation
│   │   └── Step 4: validation
│   ├── ray.kill(pipeline)  # vLLM destroyed
│   └── Phase 5: VeRL Training (new vLLM)
│       ├── trainer.init_workers()  # every round
│       └── trainer.fit()  # 1 epoch

Target Structure (Unified)

train_ttrlvr_azr_unified.py
├── trainer = UnifiedTTRLVRTrainer()
├── trainer.init_workers()  # only once!
└── trainer.fit()
    └── for round in rounds:  # handled internally
        ├── Phase 1-4: data generation (same vLLM)
        └── Phase 5: training (same vLLM)

2. Per-File Modification Plan

2.1 New Files

/test/trainer/unified_ttrlvr_trainer.py

"""
ํ†ตํ•ฉ TTRLVR Trainer - ๋ชจ๋“  Phase๋ฅผ ํ•˜๋‚˜์˜ ์„ธ์…˜์—์„œ ์ฒ˜๋ฆฌ
"""
from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer

class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer):
    def __init__(self, ttrlvr_config, problem_ids, total_rounds, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ttrlvr_config = ttrlvr_config
        self.problem_ids = problem_ids
        self.total_rounds = total_rounds
        self.current_round = 0
        
    def fit(self):
        """๋ฉ”์ธ ํ•™์Šต ๋ฃจํ”„ - ๋ผ์šด๋“œ๋ณ„ ์ฒ˜๋ฆฌ ํฌํ•จ"""
        # ๋กœ๊ฑฐ ์„ค์ •
        logger = self._setup_logger()
        
        # ์ „์ฒด ๋ผ์šด๋“œ ๋ฐ˜๋ณต
        for round_num in range(1, self.total_rounds + 1):
            self.current_round = round_num
            
            # Phase 1-4: ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
            round_data = self._generate_round_data()
            
            # Phase 5: 1 epoch ํ•™์Šต
            self._train_one_round(round_data)
            
            # ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ
            if round_num % 5 == 0:
                self._save_checkpoint()
    
    def _generate_round_data(self):
        """Phase 1-4๋ฅผ VeRL ๋‚ด๋ถ€์—์„œ ์ฒ˜๋ฆฌ"""
        # ๊ธฐ์กด TestTimePipeline ๋กœ์ง์„ ์ด๊ณณ์œผ๋กœ ์ด๋™
        pass

2.2 Existing File Modifications

/test/train_ttrlvr_azr.py → /test/train_ttrlvr_azr_unified.py

Before:

# Complex per-round handling
trainer = IterativeTrainer(...)
for round in rounds:
    # Phase 1-4
    pipeline = RemoteTestTimePipeline(...)
    data = pipeline.run_complete_pipeline(...)
    ray.kill(pipeline)
    
    # Phase 5
    trainer.train_with_data(data)

After:

# Simplified main logic
from trainer.unified_ttrlvr_trainer import UnifiedTTRLVRTrainer

# Configuration
config = load_config()
trainer = UnifiedTTRLVRTrainer(
    config=config,
    problem_ids=problem_ids,
    total_rounds=args.rounds,
    tokenizer=tokenizer,
    ...
)

# Initialize only once
trainer.init_workers()

# Run all rounds
trainer.fit()

/test/utils/testtime_pipeline.py logic migration

Move the existing Phase 1-4 logic into UnifiedTTRLVRTrainer:

class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer):
    def _generate_programs(self, problem_data):
        """Step 1: program generation - moved from TestTimePipeline"""
        prompt = self._create_program_prompt(problem_data)
        
        # Use VeRL's vLLM!
        prompts_proto = DataProto.from_dict({
            "input_ids": tokenize(prompt),
            "attention_mask": ...
        })
        
        # Reuse the existing actor_rollout_wg
        outputs = self.actor_rollout_wg.generate_sequences(prompts_proto)
        return self._parse_programs(outputs)
    
    def _generate_io_pairs(self, programs):
        """Step 2: I/O generation - moved from TestTimePipeline"""
        # Implemented the same way
        pass

2.3 Configuration File Changes

/test/configs/ttrlvr_azr_unified.yaml

# Unified settings
actor_rollout_ref:
  rollout:
    # dummy_dtensor is usable (the same vLLM instance is kept alive)
    load_format: dummy_dtensor
    
# TTRLVR-specific settings
ttrlvr:
  # Phase 1-4 settings
  num_programs: 4
  input_generation_rounds: 3
  
  # Phase 5 settings
  train_batch_size: 8
  epochs_per_round: 1  # 1 epoch per round
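To make the defaults explicit, here is a minimal sketch of how the ttrlvr section could be resolved inside the trainer (hypothetical helper; key names and defaults mirror the YAML above and the `__init__` in section 3.1):

```python
# Hypothetical helper: resolve the ttrlvr config section with the same
# defaults that UnifiedTTRLVRTrainer.__init__ applies via .get().
def resolve_ttrlvr_settings(ttrlvr_config: dict) -> dict:
    return {
        # Phase 1-4 settings
        "num_programs": ttrlvr_config.get("num_programs", 4),
        "input_generation_rounds": ttrlvr_config.get("input_generation_rounds", 3),
        # Phase 5 settings
        "train_batch_size": ttrlvr_config.get("train_batch_size", 8),
        "epochs_per_round": ttrlvr_config.get("epochs_per_round", 1),
    }

# Missing keys fall back to the documented defaults
settings = resolve_ttrlvr_settings({"num_programs": 8})
```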

3. Implementation Details

3.1 Full UnifiedTTRLVRTrainer Implementation

# /test/trainer/unified_ttrlvr_trainer.py

import os
import json
import time
import torch
import pandas as pd
from typing import List, Dict, Any, Optional
from datetime import datetime
import numpy as np

from verl import DataProto
from verl.utils.py_utils import merge_dict
from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer
from absolute_zero_reasoner.testtime.config import BenchmarkConfig
from absolute_zero_reasoner.testtime.execution import PythonExecutor


class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer):
    """
    Unified trainer that runs every TTRLVR Phase in a single VeRL session
    """
    
    def __init__(
        self,
        ttrlvr_config: Dict[str, Any],
        benchmark_config: BenchmarkConfig,
        problem_ids: List[str],
        total_rounds: int,
        output_dir: str,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        
        self.ttrlvr_config = ttrlvr_config
        self.benchmark_config = benchmark_config
        self.problem_ids = problem_ids
        self.total_rounds = total_rounds
        self.output_dir = output_dir
        self.current_round = 0
        
        # Phase 1-4 settings
        self.num_programs = ttrlvr_config.get('num_programs', 4)
        self.input_rounds = ttrlvr_config.get('input_generation_rounds', 3)
        self.parallel_batch_size = ttrlvr_config.get('parallel_batch_size', 4)
        
        # Python executor
        self.executor = PythonExecutor(timeout_length=10)
        
    def fit(self):
        """
        Unified training loop - extends AZR's fit() with per-round processing
        """
        # Default logger setup
        from verl.utils.tracking import Tracking
        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=self.config,
            tags=self.config.trainer.wandb_tags,
            entity=self.config.trainer.wandb.entity,
            wandb_run_id=self.config.trainer.wandb_run_id,
        )
        
        # Iterate over all rounds
        for round_num in range(1, self.total_rounds + 1):
            self.current_round = round_num
            logger.log({"round": round_num})
            
            print(f"\n{'='*80}")
            print(f"🔄 Round {round_num}/{self.total_rounds}")
            print(f"{'='*80}")
            
            # Phase 1-4: data generation
            round_start = datetime.now()
            round_data = self._generate_round_data()
            data_gen_time = (datetime.now() - round_start).total_seconds()
            
            print(f"✅ Data generation completed in {data_gen_time:.2f}s")
            print(f"📊 Generated {len(round_data)} training examples")
            
            # Persist the data as parquet files
            self._save_round_data(round_data, round_num)
            
            # Phase 5: PPO training (1 epoch)
            train_start = datetime.now()
            metrics = self._train_one_round(round_data, logger)
            train_time = (datetime.now() - train_start).total_seconds()
            
            print(f"✅ Training completed in {train_time:.2f}s")
            
            # Log metrics
            logger.log({
                "round_time/data_generation": data_gen_time,
                "round_time/training": train_time,
                "round_time/total": data_gen_time + train_time,
                **metrics
            })
            
            # Save a checkpoint
            if round_num % 5 == 0:
                self._save_checkpoint()
                
    def _generate_round_data(self) -> List[Dict[str, Any]]:
        """
        Phase 1-4: generate round data with the current model
        """
        all_tasks = []
        
        for problem_id in self.problem_ids:
            print(f"\n📝 Processing problem: {problem_id}")
            
            try:
                # Step 1: program generation
                programs = self._generate_programs(problem_id)
                print(f"  ✓ Generated {len(programs)} programs")
                
                # Step 2: I/O pair generation
                io_pairs = self._generate_io_pairs(problem_id, programs)
                print(f"  ✓ Generated {len(io_pairs)} I/O pairs")
                
                # Step 3: task creation
                tasks = self._create_reasoning_tasks(problem_id, programs, io_pairs)
                print(f"  ✓ Created {len(tasks)} tasks")
                
                # Step 4: validation
                valid_tasks = self._validate_tasks(tasks)
                print(f"  ✓ Validated {len(valid_tasks)}/{len(tasks)} tasks")
                
                all_tasks.extend(valid_tasks)
                
            except Exception as e:
                print(f"  ✗ Error processing {problem_id}: {e}")
                continue
                
        return all_tasks
    
    def _generate_programs(self, problem_id: str) -> List[str]:
        """
        Step 1: generate diverse programs using VeRL's vLLM
        """
        # Load the problem data
        problem_data = self._load_problem_data(problem_id)
        
        # Build the prompt
        prompt = f"""You are given a programming problem. Generate {self.num_programs} different solutions.

Problem: {problem_data['description']}

Generate {self.num_programs} different Python solutions:"""
        
        # Tokenize
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.config.data.max_prompt_length
        ).input_ids
        
        # Build the DataProto
        prompts_proto = DataProto.from_dict({
            "input_ids": input_ids.cuda(),
            "attention_mask": torch.ones_like(input_ids).cuda(),
            "position_ids": torch.arange(input_ids.size(1)).unsqueeze(0).cuda()
        })
        
        # Attach meta info
        prompts_proto.meta_info = {
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "temperature": 0.8,  # high temperature for diversity
            "do_sample": True,
            "top_p": 0.95,
            "response_length": 512
        }
        
        # Generate with VeRL's vLLM!
        outputs = self.actor_rollout_wg.generate_sequences(prompts_proto)
        
        # Extract the programs
        programs = []
        generated_text = self.tokenizer.decode(
            outputs.batch["input_ids"][0], 
            skip_special_tokens=True
        )
        
        # Parse the programs (extract code blocks)
        code_blocks = self._extract_code_blocks(generated_text)
        programs.extend(code_blocks[:self.num_programs])
        
        return programs
    
    def _generate_io_pairs(
        self, 
        problem_id: str, 
        programs: List[str]
    ) -> List[Dict[str, Any]]:
        """
        Step 2: generate I/O pairs from the programs
        """
        io_pairs = []
        
        for program in programs:
            # Generate several batches of inputs for each program
            for round_idx in range(self.input_rounds):
                prompt = f"""Given this Python function, generate 5 test inputs.

Function:
```python
{program}
```

Generate 5 different test inputs as a Python list:"""

                # Generate the inputs
                inputs = self._generate_with_vllm(prompt, temperature=0.7)
                
                # Compute the output for each input
                for test_input in inputs:
                    try:
                        output = self.executor.execute(program, test_input)
                        if output['success']:
                            io_pairs.append({
                                'input': test_input,
                                'output': output['result'],
                                'program': program
                            })
                    except Exception:
                        continue
                        
        return io_pairs

    def _create_reasoning_tasks(
        self,
        problem_id: str,
        programs: List[str],
        io_pairs: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Step 3: create induction, deduction, and abduction tasks
        """
        tasks = []
        
        for io_pair in io_pairs:
            # Induction: I/O → Program
            tasks.append({
                'problem_id': problem_id,
                'task_type': 'induction',
                'input': io_pair['input'],
                'output': io_pair['output'],
                'target': io_pair['program'],
                'prompt': self._create_induction_prompt(io_pair)
            })
            
            # Deduction: Program + Input → Output
            tasks.append({
                'problem_id': problem_id,
                'task_type': 'deduction',
                'input': io_pair['input'],
                'program': io_pair['program'],
                'target': io_pair['output'],
                'prompt': self._create_deduction_prompt(io_pair)
            })
            
            # Abduction: Program + Output → Input
            tasks.append({
                'problem_id': problem_id,
                'task_type': 'abduction',
                'program': io_pair['program'],
                'output': io_pair['output'],
                'target': io_pair['input'],
                'prompt': self._create_abduction_prompt(io_pair)
            })
            
        return tasks

    def _train_one_round(
        self, 
        round_data: List[Dict[str, Any]], 
        logger
    ) -> Dict[str, float]:
        """
        Phase 5: one round of PPO training
        """
        # Convert the data into VeRL's format
        train_dataset = self._convert_to_verl_dataset(round_data)
        
        # Rebuild the current dataloader
        self.train_dataloader = self._create_dataloader(
            train_dataset,
            self.val_dataset,
            self.collate_fn,
            self.train_sampler
        )
        
        # Train for 1 epoch
        epoch_metrics = {}
        for step, batch in enumerate(self.train_dataloader):
            # Prepare the batch
            gen_batch = self._prepare_generation_batch(batch)
            
            # Generate sequences (same vLLM!)
            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
            
            # Compute rewards
            batch = batch.union(gen_batch_output)
            reward_tensor = self.reward_fn(batch)
            
            # PPO update
            update_metrics = self._ppo_update(batch, reward_tensor)
            
            # Collect metrics
            for k, v in update_metrics.items():
                if k not in epoch_metrics:
                    epoch_metrics[k] = []
                epoch_metrics[k].append(v)
                
        # Average the metrics
        avg_metrics = {
            k: np.mean(v) for k, v in epoch_metrics.items()
        }
        
        return avg_metrics

    def _generate_with_vllm(
        self, 
        prompt: str, 
        temperature: float = 0.7
    ) -> Any:
        """
        Helper: generate text through VeRL's vLLM
        """
        # Tokenize
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).input_ids
        
        # Build the DataProto
        prompts_proto = DataProto.from_dict({
            "input_ids": input_ids.cuda(),
            "attention_mask": torch.ones_like(input_ids).cuda(),
        })
        
        prompts_proto.meta_info = {
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "temperature": temperature,
            "do_sample": True,
            "response_length": 256
        }
        
        # Generate
        outputs = self.actor_rollout_wg.generate_sequences(prompts_proto)
        
        # Decode
        generated_text = self.tokenizer.decode(
            outputs.batch["input_ids"][0],
            skip_special_tokens=True
        )
        
        return self._parse_output(generated_text)

    def _save_round_data(self, round_data: List[Dict], round_num: int):
        """Persist the round data as parquet files"""
        output_dir = os.path.join(self.output_dir, f"round_{round_num}")
        os.makedirs(output_dir, exist_ok=True)
        
        # Split by task type
        for task_type in ['induction', 'deduction', 'abduction']:
            tasks = [t for t in round_data if t['task_type'] == task_type]
            if tasks:
                df = pd.DataFrame(tasks)
                df.to_parquet(os.path.join(output_dir, f"{task_type}.parquet"))
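`_extract_code_blocks` is referenced above but not shown. A minimal regex-based sketch (hypothetical implementation, assuming the model emits markdown-fenced code; shown standalone rather than as a method):

```python
import re
from typing import List

FENCE = "`" * 3  # literal triple backtick, built here to avoid nesting fences

def extract_code_blocks(text: str) -> List[str]:
    """Pull fenced Python code blocks out of generated model text."""
    # Matches ```python ... ``` as well as bare ``` ... ``` blocks
    pattern = FENCE + r"(?:python)?[ \t]*\n(.*?)" + FENCE
    return [block.strip() for block in re.findall(pattern, text, re.DOTALL)]

sample = "\n".join([FENCE + "python", "def solve(x):", "    return x * 2", FENCE])
blocks = extract_code_blocks(sample)
```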

3.2 Data Flow Details

# How data actually flows through a round

# Round 1 begins
trainer.current_round = 1

# 1. Program generation
programs = trainer._generate_programs("Mbpp/1")
# → calls trainer.actor_rollout_wg.generate_sequences()
# → the FSDP model weights are synced into vLLM (first time only)
# → output: ["def solve(x): return x*2", "def solve(x): return 2*x", ...]

# 2. I/O generation
io_pairs = trainer._generate_io_pairs("Mbpp/1", programs)
# → same vLLM instance (sync skipped - base_sync_done=True)
# → output: [{"input": 5, "output": 10}, {"input": 3, "output": 6}, ...]

# 3. Task creation
tasks = trainer._create_reasoning_tasks(...)
# → in-memory only (no vLLM calls)

# 4. PPO training
trainer._train_one_round(tasks)
# → responses generated with the same vLLM
# → FSDP model updated
# → vLLM picks up the update automatically via shared memory references

# Round 2 begins - the same vLLM keeps running!

3.3 Memory Management Details

class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer):
    def _manage_memory(self):
        """Memory management at Phase transitions"""
        # Keep the Ray actors alive; only clear the GPU cache
        torch.cuda.empty_cache()
        
        # Clear vLLM's KV cache (optional)
        if hasattr(self.actor_rollout_wg, 'clear_kv_cache'):
            self.actor_rollout_wg.clear_kv_cache()
            
    def _monitor_memory(self):
        """Monitor memory usage"""
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            print(f"GPU {i}: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")

3.4 Synchronization Mechanism Details

# How synchronization is guaranteed

# 1. First generate_sequences call
with self.rollout_sharding_manager:  # __enter__() is called
    # With dummy_dtensor:
    # - self.base_sync_done = False (initial value)
    # - sync_model_weights() runs: FSDP → vLLM sync
    # - then self.base_sync_done = True is set
    ...

# 2. Subsequent generate_sequences calls
with self.rollout_sharding_manager:  # __enter__() is called
    # - self.base_sync_done = True
    # - sync_model_weights() is skipped
    # - but it is the same vLLM, so updates arrive through memory references
    ...

# 3. The memory-reference mechanism
# The FSDP model and the vLLM model share the same tensors:
# FSDP update → tensor values change → vLLM automatically sees the new values
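The flag logic above can be illustrated with a small context-manager sketch (illustrative only; the name `base_sync_done` follows the description, but this is not VeRL's actual sharding manager):

```python
class OneTimeSyncManager:
    """Toy model of the sharding manager's one-time weight sync."""

    def __init__(self, sync_model_weights):
        self.sync_model_weights = sync_model_weights  # e.g. FSDP -> vLLM copy
        self.base_sync_done = False                   # initial value

    def __enter__(self):
        if not self.base_sync_done:
            self.sync_model_weights()  # first call: full weight sync
            self.base_sync_done = True
        # later calls skip the sync; shared tensors keep vLLM current
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False

sync_calls = []
manager = OneTimeSyncManager(lambda: sync_calls.append("sync"))
with manager:  # first generate_sequences call: sync runs
    pass
with manager:  # subsequent calls: sync is skipped
    pass
```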

3.5 Error Handling and Recovery

class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer):
    def _safe_generate(self, prompt: str, max_retries: int = 3):
        """Safe generation with retries"""
        for attempt in range(max_retries):
            try:
                return self._generate_with_vllm(prompt)
            except Exception as e:
                print(f"Generation failed (attempt {attempt+1}): {e}")
                if attempt == max_retries - 1:
                    raise
                # Clear GPU memory, then retry
                torch.cuda.empty_cache()
                time.sleep(1)
                
    def _validate_tasks(self, tasks: List[Dict]) -> List[Dict]:
        """Validate the generated tasks"""
        valid_tasks = []
        for task in tasks:
            if self._is_valid_task(task):
                valid_tasks.append(task)
            else:
                print(f"Invalid task filtered: {task['task_type']}")
        return valid_tasks
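`_is_valid_task` is referenced but left undefined; a minimal sketch of what it might check (hypothetical; the required-field names follow `_create_reasoning_tasks` in section 3.1):

```python
from typing import Any, Dict

# Required fields per task type, mirroring _create_reasoning_tasks
REQUIRED_FIELDS = {
    "induction": ("input", "output", "target", "prompt"),
    "deduction": ("input", "program", "target", "prompt"),
    "abduction": ("program", "output", "target", "prompt"),
}

def is_valid_task(task: Dict[str, Any]) -> bool:
    """Reject tasks with an unknown type or missing/empty required fields."""
    fields = REQUIRED_FIELDS.get(task.get("task_type"))
    if fields is None:
        return False
    return all(task.get(field) not in (None, "") for field in fields)

ok = is_valid_task({
    "task_type": "deduction",
    "input": 5,
    "program": "def solve(x): return x * 2",
    "target": 10,
    "prompt": "What does solve(5) return?",
})
bad = is_valid_task({"task_type": "deduction", "input": 5})
```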

4. Migration Plan

Phase 1: Code Preparation

  1. Create the UnifiedTTRLVRTrainer class
  2. Move the TestTimePipeline logic
  3. Write unit tests

Phase 2: Integration Testing

  1. Test on a small problem set (1-2 rounds)
  2. Monitor memory usage
  3. Compare training performance

Phase 3: Switchover

  1. Back up the existing scripts
  2. Swap in the new script
  3. Run a full training

5. Expected Results

Advantages

  • ✅ Synchronization issues fully resolved
  • ✅ ~30-40% faster execution (no vLLM re-creation)
  • ✅ ~20-30% better memory efficiency
  • ✅ Simpler code structure

Drawbacks and Mitigations

  • ❌ Tighter coupling between Phases
    • → mitigate with clearly defined interfaces
  • ❌ Harder debugging
    • → add detailed logging
  • ❌ Compatibility with the existing code
    • → keep both versions in parallel

6. Implementation Priorities

  1. High: basic UnifiedTTRLVRTrainer structure
  2. High: migrating the Phase 1-4 logic
  3. Medium: unifying the configuration files
  4. Low: further optimizations

7. Test Plan

# Staged tests
# 1. Small-scale test
python train_ttrlvr_azr_unified.py --rounds 2 --problems 1

# 2. Medium-scale test
python train_ttrlvr_azr_unified.py --rounds 5 --problems 5

# 3. Full-scale test
python train_ttrlvr_azr_unified.py --rounds 30 --problems 10
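A minimal sketch of the CLI these commands assume (only the two flags used above; the real script would expose many more VeRL/TTRLVR options, and the defaults here are assumptions):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Only --rounds and --problems, as used in the staged tests above
    parser = argparse.ArgumentParser(prog="train_ttrlvr_azr_unified.py")
    parser.add_argument("--rounds", type=int, default=30,
                        help="number of generate-then-train rounds")
    parser.add_argument("--problems", type=int, default=10,
                        help="number of benchmark problems per round")
    return parser

# Small-scale test invocation from step 1 above
args = build_parser().parse_args(["--rounds", "2", "--problems", "1"])
```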

8. Rollback Plan

If problems occur:

  1. Revert immediately to the existing separated structure
  2. Switch to load_format: dtensor as a temporary workaround
  3. Integrate incrementally (Phase 5 first)