#!/usr/bin/env python3
"""
TTRLVR + AZR ํ†ตํ•ฉ ํ•™์Šต ๋ฉ”์ธ ์Šคํฌ๋ฆฝํŠธ (Unified Version)
UnifiedTTRLVRTrainer๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ•˜๋‚˜์˜ VeRL ์„ธ์…˜์—์„œ ์ „์ฒด ํ•™์Šต ์ง„ํ–‰:
1. VeRL worker ํ•œ ๋ฒˆ๋งŒ ์ดˆ๊ธฐํ™”
2. ๊ฐ ๋ผ์šด๋“œ๋งˆ๋‹ค ๊ฐ™์€ vLLM์œผ๋กœ Phase 1-4 ์‹คํ–‰
3. ๊ฐ™์€ vLLM์œผ๋กœ Phase 5 PPO ํ•™์Šต
4. ๋™๊ธฐํ™” ๋ฌธ์ œ ์™„์ „ ํ•ด๊ฒฐ (dummy_dtensor ์‚ฌ์šฉ ๊ฐ€๋Šฅ)
์‚ฌ์šฉ ์˜ˆ์‹œ:
# ์ผ๋ฐ˜ ํ•™์Šต
python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30
python train_ttrlvr_azr_unified.py --benchmark humaneval --problems 5 --rounds 10
python train_ttrlvr_azr_unified.py --benchmark mbpp --problem-id Mbpp/2 --rounds 5
# GPU ์ง€์ •
python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 0,1,2,3
"""
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List
import warnings

import ray
import torch

# Filter warnings related to gradient checkpointing
warnings.filterwarnings("ignore", message=".*Caching is incompatible with gradient checkpointing.*")

# Path setup - use relative paths
project_root = Path(__file__).parent.parent  # TestTime-RLVR-v2 directory
sys.path.append(str(project_root))

# Look for verl and Absolute-Zero-Reasoner in the parent directory
parent_dir = project_root.parent
for lib_name in ['verl', 'Absolute-Zero-Reasoner']:
    lib_path = parent_dir / lib_name
    if lib_path.exists():
        sys.path.append(str(lib_path))
# If installed via pip, they are importable without extending sys.path

# AZR/VeRL module imports (same structure as main_azr_ppo.py)
from verl import DataProto
from omegaconf import OmegaConf
from verl.utils import hf_tokenizer
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter

# TTRLVR module imports
from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig
from absolute_zero_reasoner.testtime.logger import TestTimeLogger

# Module-level handles used during Ray cleanup
_trainer_instance = None
_logger_instance = None


def cleanup_ray():
    """Clean up the Ray cluster."""
    global _trainer_instance, _logger_instance
    try:
        if _logger_instance:
            _logger_instance.log_info("🔄 Forced shutdown detected: cleaning up Ray cluster...")
    except Exception:
        print("🔄 Forced shutdown detected: cleaning up Ray cluster...")

    try:
        # Clean up the IterativeTrainer
        if _trainer_instance:
            _trainer_instance.cleanup_ray()
    except Exception as e:
        try:
            if _logger_instance:
                _logger_instance.log_error(f"IterativeTrainer cleanup failed: {e}")
        except Exception:
            print(f"IterativeTrainer cleanup failed: {e}")

    try:
        # Shut down only this program's Ray instance (the safe approach)
        if ray.is_initialized():
            ray.shutdown()
    except Exception as e:
        try:
            if _logger_instance:
                _logger_instance.log_error(f"Ray shutdown failed: {e}")
        except Exception:
            print(f"Ray shutdown failed: {e}")

    try:
        if _logger_instance:
            _logger_instance.log_info("✅ Ray cleanup complete")
    except Exception:
        print("✅ Ray cleanup complete")


def signal_handler(signum, frame):
    """Signal handler (Ctrl+C, forced termination, etc.)."""
    try:
        if _logger_instance:
            _logger_instance.log_info(f"🛑 Received signal {signum}: shutting down...")
    except Exception:
        print(f"🛑 Received signal {signum}: shutting down...")
    cleanup_ray()
    sys.exit(1)
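
# NOTE (assumption): the handler above is defined but never registered anywhere in this
# file. A typical wiring, if desired, would look like the sketch below; it is not part
# of the original script.
#
#   import signal
#   signal.signal(signal.SIGINT, signal_handler)   # Ctrl+C
#   signal.signal(signal.SIGTERM, signal_handler)  # kill / forced termination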


def parse_arguments():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description='TTRLVR + AZR unified iterative training',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train for 30 rounds on 10 MBPP problems
  python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30

  # Train for 10 rounds on 5 HumanEval problems
  python train_ttrlvr_azr_unified.py --benchmark humaneval --problems 5 --rounds 10

  # Resume from round 15
  python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --resume 15

  # Use a specific GPU
  python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 4
"""
    )
    parser.add_argument(
        '--benchmark',
        choices=['mbpp', 'humaneval'],
        default='mbpp',
        help='Benchmark to use (default: mbpp)'
    )
    parser.add_argument(
        '--problems',
        type=int,
        default=10,
        help='Number of problems (default: 10)'
    )
    parser.add_argument(
        '--problem-id',
        type=str,
        help='Specific problem ID (e.g. HumanEval/1, Mbpp/10)'
    )
    parser.add_argument(
        '--rounds',
        type=int,
        default=30,
        help='Total number of rounds (default: 30)'
    )
    parser.add_argument(
        '--resume',
        type=int,
        default=1,
        help='Round number to resume from (default: 1)'
    )
    parser.add_argument(
        '--gpu',
        type=str,
        default='5',
        help='GPU id(s) to use (single: 5, multiple: 1,2,3,5)'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./results/ttrlvr_azr',
        help='Directory for results (default: ./results/ttrlvr_azr)'
    )
    parser.add_argument(
        '--config',
        type=str,
        help='Path to a config file (optional)'
    )
    parser.add_argument(
        '--model',
        type=str,
        default='Qwen/Qwen2.5-7B',
        help='Model to use (default: Qwen/Qwen2.5-7B)'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug mode'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=24,
        help='Training batch size (default: 24; reduce on OOM)'
    )
    parser.add_argument(
        '--batch-epochs',
        type=int,
        default=1,
        help='Epochs per batch (default: 1; increase for more training)'
    )
    parser.add_argument(
        '--num-programs',
        type=int,
        default=4,
        help='Number of diverse programs to generate (default: 4; increase for more varied data)'
    )
    parser.add_argument(
        '--input-generation-rounds',
        type=int,
        default=3,
        help='Number of diverse-input generation rounds (default: 3; 5 inputs generated per round)'
    )
    parser.add_argument(
        '--parallel-batch-size',
        type=int,
        default=4,
        help='Number of prompts processed concurrently (default: 4; tune to GPU memory)'
    )
    parser.add_argument(
        '--eval-rounds',
        type=int,
        default=5,
        help='Accuracy evaluation runs per round (default: 5; increase for more reliable evaluation)'
    )
    parser.add_argument(
        '--skip-task-eval',
        action='store_true',
        help='Skip task evaluation (Phase 4) for quick tests (go straight to VeRL training after data generation)'
    )
    parser.add_argument(
        '--save-every-round',
        action='store_true',
        help='Save a checkpoint every round (default: False)'
    )
    parser.add_argument(
        '--save-round-interval',
        type=int,
        default=5,
        help='Checkpoint save interval (e.g. 5 = save every 5 rounds; default: 5)'
    )
    return parser.parse_args()
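
# Example invocation combining the data-generation knobs (values are illustrative, not defaults to copy):
#   python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 \
#       --gpu 0,1,2,3 --batch-size 24 --num-programs 4 --input-generation-rounds 3 \
#       --parallel-batch-size 4 --eval-rounds 5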


def setup_environment(gpu_id: str, batch_size: int = None):
    """Set environment variables - identical to run_ttrlvr_azr_training.sh."""
    # GPU setup - prefer the command-line argument, otherwise fall back to the existing env var
    if gpu_id:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
        print(f"🎯 Using command line GPU setting: {gpu_id}")
    elif 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES']:
        print(f"🎯 Using existing CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '5'  # default
        print("🎯 Using default GPU: 5")

    # vLLM settings (same as run_ttrlvr_azr_training.sh)
    os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN'

    # Ray settings (same as run_ttrlvr_azr_training.sh)
    os.environ['RAY_memory_monitor_refresh_ms'] = '0'
    os.environ['RAY_LOGGING_LEVEL'] = 'DEBUG'

    # Hydra settings
    os.environ['HYDRA_FULL_ERROR'] = '1'

    # Python path setup (add the verl path) - use relative paths
    pythonpath = os.environ.get('PYTHONPATH', '')
    project_root = Path(__file__).parent.parent  # TestTime-RLVR-v2 directory

    # Project paths to expose
    paths_to_add = [str(project_root)]
    parent_dir = project_root.parent

    # Add the verl and Absolute-Zero-Reasoner paths (if they exist)
    if (parent_dir / 'verl').exists():
        paths_to_add.append(str(parent_dir / 'verl'))
    if (parent_dir / 'Absolute-Zero-Reasoner').exists():
        paths_to_add.append(str(parent_dir / 'Absolute-Zero-Reasoner'))

    # Update PYTHONPATH
    for path in paths_to_add:
        if path not in pythonpath:
            pythonpath = f"{path}:{pythonpath}" if pythonpath else path
    os.environ['PYTHONPATH'] = pythonpath

    # Batch size setting
    if batch_size is not None:
        os.environ['TRAIN_BATCH_SIZE'] = str(batch_size)

    # Additional environment variables (only those not already set above)
    # Default to the home directory; can be overridden via environment variables
    os.environ.setdefault('HF_HOME', os.path.expanduser('~/.cache/huggingface'))
    os.environ.setdefault('TRANSFORMERS_CACHE', os.path.expanduser('~/.cache/huggingface'))
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    # PYTHONPATH setup - use relative paths
    current_pythonpath = os.environ.get('PYTHONPATH', '')
    project_root = Path(__file__).parent.parent  # TestTime-RLVR-v2 directory
    new_paths = [
        str(project_root)
        # site-packages is included automatically, so it is omitted here
    ]
    for path in new_paths:
        if path not in current_pythonpath:
            current_pythonpath = f"{path}:{current_pythonpath}" if current_pythonpath else path
    os.environ['PYTHONPATH'] = current_pythonpath


def load_benchmark_problems(benchmark_config: BenchmarkConfig) -> List[str]:
    """Load the list of problem IDs for a benchmark (using the existing TTRLVR approach)."""
    problems = []

    if benchmark_config.name == 'mbpp':
        # Load MBPP+ data via the standard EvalPlus API
        try:
            from evalplus.data.mbpp import get_mbpp_plus
            mbpp_problems = get_mbpp_plus()  # mbpp_deserialize_inputs is applied automatically
            problems = list(mbpp_problems.keys())
            print(f"✅ Loaded MBPP+ data: {len(problems)} problems (standard EvalPlus method)")
        except Exception as e:
            print(f"❌ MBPP+ EvalPlus loading failed, falling back to the original method: {e}")
            # Fallback to original method
            data_path = benchmark_config.data_path
            if os.path.exists(data_path):
                with open(data_path, 'r') as f:
                    for line in f:
                        problem = json.loads(line.strip())
                        problems.append(problem['task_id'])

    elif benchmark_config.name == 'humaneval':
        # Load HumanEval+ data via the standard EvalPlus API
        try:
            from evalplus.data.humaneval import get_human_eval_plus
            humaneval_problems = get_human_eval_plus()
            problems = list(humaneval_problems.keys())
            print(f"✅ Loaded HumanEval+ data: {len(problems)} problems (standard EvalPlus method)")
        except Exception as e:
            print(f"❌ HumanEval+ EvalPlus loading failed, falling back to the original method: {e}")
            # Fallback to original method
            data_path = benchmark_config.data_path
            if os.path.exists(data_path):
                with open(data_path, 'r') as f:
                    for line in f:
                        problem = json.loads(line.strip())
                        problems.append(problem['task_id'])

    return problems


def create_problem_list(benchmark: str, num_problems: int, specific_problem_id: str = None) -> list:
    """Create the list of problem IDs for a benchmark (using the existing TTRLVR approach)."""
    # Create the BenchmarkConfig
    benchmark_config = create_benchmark_config(benchmark)

    # Load the full problem list
    all_problems = load_benchmark_problems(benchmark_config)
    if not all_problems:
        raise ValueError(f"No problems found for benchmark: {benchmark}")

    # If a specific problem ID was requested
    if specific_problem_id:
        if specific_problem_id in all_problems:
            return [specific_problem_id]
        else:
            raise ValueError(f"Problem ID '{specific_problem_id}' not found in {benchmark} benchmark")

    # Select as many problems as requested
    if num_problems <= 0 or num_problems > len(all_problems):
        return all_problems
    else:
        return all_problems[:num_problems]
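
# Note on selection: --problem-id must match an EvalPlus task id exactly (e.g. "Mbpp/2",
# "HumanEval/0"); otherwise the first --problems entries are used, and a non-positive or
# out-of-range count falls back to the full problem set.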


def create_config(args) -> TestTimeConfig:
    """Create the TestTimeConfig."""
    config = TestTimeConfig()

    # Basic settings
    config.model_name = args.model  # use the model passed on the command line
    config.max_new_tokens = 512
    config.temperature = 0.05
    config.baseline_evaluation_rounds = args.eval_rounds  # number of evaluation runs

    # Program generation settings
    config.num_program_variations = args.num_programs  # number of diverse programs
    config.input_generation_rounds = args.input_generation_rounds  # input generation rounds
    config.parallel_batch_size = args.parallel_batch_size  # prompts processed concurrently

    # Task-evaluation skip setting
    config.skip_task_evaluation = args.skip_task_eval  # whether to skip task evaluation

    # Debug mode
    if args.debug:
        config.debug = True
        config.verbose = True

    return config


def create_benchmark_config(benchmark: str) -> BenchmarkConfig:
    """Create the BenchmarkConfig (using the existing TTRLVR approach)."""
    # Build the BenchmarkConfig the same way as the existing TTRLVR system,
    # using the TestTime-RLVR-v2 directory as the base
    base_dir = Path(__file__).parent.parent  # TestTime-RLVR-v2 directory

    if benchmark == 'mbpp':
        benchmark_config = BenchmarkConfig.get_mbpp_config()
        benchmark_config.data_path = str(base_dir / 'evaluation/code_eval/data/MbppPlus.jsonl')
        return benchmark_config
    elif benchmark == 'humaneval':
        benchmark_config = BenchmarkConfig.get_humaneval_config()
        benchmark_config.data_path = str(base_dir / 'evaluation/code_eval/data/HumanEvalPlus.jsonl')
        return benchmark_config
    else:
        raise ValueError(f"Unknown benchmark: {benchmark}")


def run_step5_only_mode(args):
    """Run Step-5-only mode (VeRL training on pre-generated data)."""
    # NOTE (assumption): this mode expects an `args.data_path` argument and an
    # `IterativeTrainer` class that are not defined in this file; they are presumed
    # to come from a separate Step-5-only entry point.
    print("🎓 Running Step 5 (VeRL training) only mode")
    print(f"📂 Data path: {args.data_path}")

    # Validate the data path
    data_path = Path(args.data_path)
    if not data_path.exists():
        print(f"❌ Error: Data path does not exist: {data_path}")
        return 1

    # Check the required files
    required_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
    missing_files = []
    for file_name in required_files:
        if not (data_path / file_name).exists():
            missing_files.append(file_name)
    if missing_files:
        print(f"❌ Error: Missing required files: {missing_files}")
        return 1

    print(f"✅ Found all required training data files in: {data_path}")

    # Print file size information
    for file_name in required_files:
        file_path = data_path / file_name
        file_size = file_path.stat().st_size
        print(f"   📄 {file_name}: {file_size:,} bytes")

    # Environment setup
    setup_environment(args.gpu, args.batch_size)

    # Decide on the config file path
    config_path = args.config
    if not config_path:
        # Pick a default config file based on the number of GPUs
        gpu_count = len(args.gpu.split(',')) if args.gpu else 1
        if gpu_count >= 4:
            config_path = str(Path(__file__).parent / 'configs/ttrlvr_azr_ppo_4gpu.yaml')
        else:
            config_path = str(Path(__file__).parent / 'configs/ttrlvr_azr_ppo_1gpu.yaml')

    print(f"🚀 Initializing trainer with config: {config_path}")

    # Create the TestTimeConfig (reusing the create_config function)
    config = create_config(args)

    # Initialize the logger
    logger = TestTimeLogger()

    # Initialize the IterativeTrainer
    global _trainer_instance
    _trainer_instance = IterativeTrainer(
        config=config,
        logger=logger,
        verl_config_path=config_path
    )

    # Run Step-5-only VeRL training
    try:
        result = _trainer_instance.run_verl_training_only(
            training_data_path=str(data_path),
            round_num=args.resume,  # reuse --resume as the round number
            experiment_name=f"step5_only_{args.benchmark}"
        )
        if result.get('success', False):
            print("✅ VeRL training completed successfully!")
            print(f"⏱️ Duration: {result.get('duration', 'N/A')} seconds")
            if 'model_path' in result:
                print(f"🤖 Updated model: {result['model_path']}")
            return 0
        else:
            print(f"❌ VeRL training failed: {result.get('error', 'Unknown error')}")
            return 1
    except Exception as e:
        print(f"💥 Training failed with exception: {e}")
        import traceback
        traceback.print_exc()
        return 1
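
# Expected layout for the Step-5-only data directory (derived from the checks above):
#   <data_path>/induction.parquet
#   <data_path>/deduction.parquet
#   <data_path>/abduction.parquet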


def main():
    """Main entry point - uses UnifiedTTRLVRTrainer."""
    # Parse arguments
    args = parse_arguments()

    # Environment setup
    setup_environment(args.gpu)

    # Create the output directory
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(
        args.output_dir,
        f'ttrlvr_unified_{args.benchmark}_{args.rounds}rounds_{timestamp}'
    )
    os.makedirs(output_dir, exist_ok=True)

    PrettyPrinter.section_header("🚀 TTRLVR Unified Training")
    PrettyPrinter.status("Config", f"Benchmark: {args.benchmark}", "info")
    PrettyPrinter.status("Config", f"Rounds: {args.rounds}", "info")
    PrettyPrinter.status("Config", f"Output: {output_dir}", "info")

    # Build the problem list
    problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
    PrettyPrinter.status("Problems", f"Selected {len(problem_ids)} problems", "info")

    # TTRLVR settings
    ttrlvr_config = {
        'num_programs': args.num_programs,
        'input_generation_rounds': args.input_generation_rounds,
        'parallel_batch_size': args.parallel_batch_size,
    }

    # VeRL config file path
    if args.config:
        config_path = os.path.abspath(args.config)
    else:
        # Only the 4-GPU config exists for now (update once a 1-GPU config is added)
        config_path = str(Path(__file__).parent / 'configs/ttrlvr_azr_unified_4gpu.yaml')

    PrettyPrinter.status("Config", f"Using VeRL config: {config_path}", "info")
    try:
        # ============================================
        # Run UnifiedTTRLVRTrainer through VeRL
        # ============================================

        # Environment variables for the VeRL run
        os.environ['TTRLVR_PROBLEM_IDS'] = json.dumps(problem_ids)
        os.environ['TTRLVR_TOTAL_ROUNDS'] = str(args.rounds)
        os.environ['TTRLVR_OUTPUT_DIR'] = output_dir
        os.environ['TTRLVR_CONFIG'] = json.dumps(ttrlvr_config)

        # ============================================
        # Initialize in AZR style, run in TTRLVR style
        # (follows the structure of main_azr_ppo.py but uses UnifiedTTRLVRTrainer)
        # ============================================
        PrettyPrinter.section_header("🎯 Starting UnifiedTTRLVRTrainer (AZR-style initialization)")

        # 1. Load the config (same as main_azr_ppo.py)
        PrettyPrinter.status("Config", f"Loading {config_path}", "info")
        verl_config = OmegaConf.load(config_path)

        # Update the config
        verl_config.trainer.project_name = f'ttrlvr_unified_{args.benchmark}'
        verl_config.trainer.experiment_name = f'round_{args.rounds}_{timestamp}'
        verl_config.trainer.total_epochs = args.rounds

        # 2. Initialize Ray (same as main_azr_ppo.py)
        if not ray.is_initialized():
            cuda_visible_devices = args.gpu or "0,1,2,3"
            PrettyPrinter.status("Ray", f"Initializing Ray cluster (GPUs: {cuda_visible_devices})", "info")
            ray.init(
                runtime_env={"env_vars": {
                    "TOKENIZERS_PARALLELISM": "true",
                    "NCCL_DEBUG": "WARN",
                    "VLLM_LOGGING_LEVEL": "WARN",
                    "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true",
                    "CUDA_VISIBLE_DEVICES": cuda_visible_devices
                }},
                num_cpus=verl_config.ray_init.num_cpus,
                # num_gpus not specified - Ray auto-detects GPUs (same as the AZR original)
            )
        # 3. Load the tokenizer (same as main_azr_ppo.py)
        model_path = verl_config.actor_rollout_ref.model.path
        PrettyPrinter.status("Model", f"Loading tokenizer from {model_path}", "info")
        tokenizer = hf_tokenizer(model_path)

        # 4. Set up the worker mapping (same as main_azr_ppo.py)
        role_worker_mapping = {}

        # Choose the Actor/Rollout worker
        if verl_config.actor_rollout_ref.rollout.name == 'vllm':
            if verl_config.actor_rollout_ref.rollout.mode == 'async':
                actor_rollout_cls = AsyncActorRolloutRefWorker
            else:
                actor_rollout_cls = ActorRolloutRefWorker
            # Use ray.remote(), same as the AZR original
            role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
            PrettyPrinter.status("Workers", f"Using {actor_rollout_cls.__name__} for ActorRollout", "info")

        # Critic worker (not used with REINFORCE++)
        if verl_config.critic.include_critic:
            # Use ray.remote(), same as the AZR original
            role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
            PrettyPrinter.status("Workers", "Including Critic worker", "info")
        else:
            PrettyPrinter.status("Workers", "No Critic (using REINFORCE++)", "info")

        # 5. Create the ResourcePoolManager (same as main_azr_ppo.py)
        # Build resource_pool_spec directly, AZR style
        global_pool_id = "global_pool"
        n_gpus_per_node = verl_config.trainer.n_gpus_per_node
        nnodes = verl_config.trainer.nnodes

        resource_pool_spec = {
            global_pool_id: [n_gpus_per_node] * nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
        }
        if verl_config.critic.include_critic:
            mapping[Role.Critic] = global_pool_id

        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
        PrettyPrinter.status("Resources", f"Created ResourcePoolManager with {len(resource_pool_spec)} pools", "info")

        # 6. Create the UnifiedTTRLVRTrainer (in place of CodeIORayPPOTrainer)
        from trainer.unified_ttrlvr_trainer import UnifiedTTRLVRTrainer

        PrettyPrinter.status("Trainer", "Creating UnifiedTTRLVRTrainer", "info")
        trainer = UnifiedTTRLVRTrainer(
            past_epoch_window=verl_config.azr.past_epoch_window,  # required AZR parameter (TTRLVR uses fresh data every round)
            config=verl_config,
            tokenizer=tokenizer,
            processor=None,  # not needed; TTRLVR is text-only
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=RayWorkerGroup,
            reward_fn=None,  # TTRLVR computes its own rewards (use_ttrlvr_rewards=True)
            val_reward_fn=None,  # TTRLVR has no validation
            # TTRLVR-specific parameters
            ttrlvr_config=ttrlvr_config,
            problem_ids=problem_ids,
            total_rounds=args.rounds,
            output_dir=output_dir
        )

        # 7. Run training (same as main_azr_ppo.py)
        PrettyPrinter.section_header("🚀 Starting Training")
        PrettyPrinter.status("Training", f"Running {args.rounds} rounds with {len(problem_ids)} problems", "info")

        trainer.fit()  # runs TTRLVR Phases 1-5 internally

        PrettyPrinter.section_header("✅ Training Complete")
        return 0

    except KeyboardInterrupt:
        PrettyPrinter.status("Interrupt", "Training interrupted by user", "warning")
        return 130
    except Exception as e:
        PrettyPrinter.status("Error", f"Training failed: {e}", "error")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        # Ray cleanup
        if ray.is_initialized():
            ray.shutdown()
        PrettyPrinter.status("Cleanup", "Resources cleaned up", "success")


if __name__ == '__main__':
    exit_code = main()
    sys.exit(exit_code)