# neural-mesh / test / train_ttrlvr_azr.py
# Uploaded by hjkim00: "Upload TestTime-RLVR-v2 from Full-pipeline-relative_0827 branch"
# (commit f50dc54, verified)
#!/usr/bin/env python3
"""
TTRLVR + AZR ํ†ตํ•ฉ ํ•™์Šต ๋ฉ”์ธ ์Šคํฌ๋ฆฝํŠธ
30๋ผ์šด๋“œ ๋ฐ˜๋ณต ํ•™์Šต์„ ํ†ตํ•ด TTRLVR ํŒŒ์ดํ”„๋ผ์ธ๊ณผ AZR ํ•™์Šต์„ ํ†ตํ•ฉ:
1. ๊ฐ ๋ผ์šด๋“œ๋งˆ๋‹ค ํ˜„์žฌ ๋ชจ๋ธ๋กœ (i,p,o) โ†’ induction/deduction/abduction tasks ์ƒ์„ฑ
2. ํ•ด๋‹น ๋ผ์šด๋“œ ๋ฐ์ดํ„ฐ๋กœ๋งŒ AZR ํ•™์Šต (๊ฐ ๋ผ์šด๋“œ์˜ task๋งŒ ์‚ฌ์šฉ)
3. ๊ฐœ์„ ๋œ ๋ชจ๋ธ๋กœ ๋‹ค์Œ ๋ผ์šด๋“œ ์ง„ํ–‰
4. 5๋ผ์šด๋“œ๋งˆ๋‹ค ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ
์‚ฌ์šฉ ์˜ˆ์‹œ:
# ์ผ๋ฐ˜ ํ•™์Šต
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30
python train_ttrlvr_azr.py --benchmark humaneval --problems 5 --rounds 10 --resume 15
python train_ttrlvr_azr.py --benchmark mbpp --problem-id Mbpp/2 --rounds 5 --num-programs 8 --eval-rounds 10
python train_ttrlvr_azr.py --benchmark mbpp --problem-id Mbpp/2 --rounds 5 --skip-task-eval
# Step 5 ์ „์šฉ ๋ชจ๋“œ (๊ธฐ์กด ๋ฐ์ดํ„ฐ๋กœ VeRL ํ•™์Šต๋งŒ)
python train_ttrlvr_azr.py --step5-only --data-path /path/to/azr_training_data --gpu 1,2,3,0 --config configs/ttrlvr_azr_ppo_4gpu.yaml
"""
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List
import warnings
import signal
import atexit
# Gradient checkpointing ๊ด€๋ จ ๊ฒฝ๊ณ  ํ•„ํ„ฐ๋ง
warnings.filterwarnings("ignore", message=".*Caching is incompatible with gradient checkpointing.*")
# ๊ฒฝ๋กœ ์„ค์ •
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2')
# EvalPlus ๊ฒฝ๋กœ ์ถ”๊ฐ€ (๊ธฐ์กด TTRLVR ๋ฐฉ์‹)
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2/evaluation/code_eval/coding')
# TTRLVR ๋ชจ๋“ˆ ์ž„ํฌํŠธ
from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig
from absolute_zero_reasoner.testtime.logger import TestTimeLogger
from utils.iterative_trainer import IterativeTrainer
# Ray ์ •๋ฆฌ ๋ณ€์ˆ˜
_trainer_instance = None
_logger_instance = None
def cleanup_ray():
    """Best-effort cleanup of the Ray cluster on shutdown.

    Called from both the atexit hook and the signal handler, so each stage is
    wrapped defensively: a failure in one stage must not stop the rest.
    Previously used bare ``except:`` clauses, which also swallow
    SystemExit/KeyboardInterrupt; narrowed to ``except Exception``.
    """
    global _trainer_instance, _logger_instance
    try:
        if _logger_instance:
            _logger_instance.log_info("🔄 강제 종료 감지: Ray 클러스터 정리 중...")
    except Exception:
        # The logger itself may be broken during interpreter teardown.
        print("🔄 강제 종료 감지: Ray 클러스터 정리 중...")
    try:
        # Let the trainer release its own Ray resources first.
        if _trainer_instance:
            _trainer_instance.cleanup_ray()
    except Exception as e:
        try:
            if _logger_instance:
                _logger_instance.log_error(f"IterativeTrainer 정리 실패: {e}")
        except Exception:
            print(f"IterativeTrainer 정리 실패: {e}")
    try:
        # Shut down only this process's Ray context (safe method).
        import ray
        if ray.is_initialized():
            ray.shutdown()
    except Exception as e:
        try:
            if _logger_instance:
                _logger_instance.log_error(f"Ray 종료 실패: {e}")
        except Exception:
            print(f"Ray 종료 실패: {e}")
    try:
        if _logger_instance:
            _logger_instance.log_info("✅ Ray 정리 완료")
    except Exception:
        print("✅ Ray 정리 완료")
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT/SIGTERM): clean up Ray, then exit.

    Args:
        signum: The received signal number.
        frame: Current stack frame (unused).

    Raises:
        SystemExit: Always exits with status 1 after cleanup.
    """
    try:
        if _logger_instance:
            _logger_instance.log_info(f"🛑 시그널 {signum} 수신: 프로그램 종료 중...")
    except Exception:
        # Fall back to stdout if the logger is unusable at this point.
        # (Was a bare `except:`, which would also swallow SystemExit.)
        print(f"🛑 시그널 {signum} 수신: 프로그램 종료 중...")
    cleanup_ray()
    sys.exit(1)
def parse_arguments():
    """Build the CLI parser for the TTRLVR + AZR training script and parse sys.argv."""
    p = argparse.ArgumentParser(
        description='TTRLVR + AZR 통합 반복 학습',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
예시:
# MBPP 10문제로 30라운드 학습
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30
# HumanEval 5문제로 10라운드 학습
python train_ttrlvr_azr.py --benchmark humaneval --problems 5 --rounds 10
# 15라운드부터 재개
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --resume 15
# 특정 GPU 사용
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --gpu 4
""")
    # Benchmark / problem selection.
    p.add_argument('--benchmark', choices=['mbpp', 'humaneval'], default='mbpp',
                   help='벤치마크 선택 (기본값: mbpp)')
    p.add_argument('--problems', type=int, default=10,
                   help='문제 수 (기본값: 10)')
    p.add_argument('--problem-id', type=str,
                   help='특정 문제 ID (예: HumanEval/1, Mbpp/10)')
    # Round control.
    p.add_argument('--rounds', type=int, default=30,
                   help='총 라운드 수 (기본값: 30)')
    p.add_argument('--resume', type=int, default=1,
                   help='재개할 라운드 번호 (기본값: 1)')
    # Hardware / paths / model.
    p.add_argument('--gpu', type=str, default='5',
                   help='사용할 GPU 번호 (단일: 5, 다중: 1,2,3,5)')
    p.add_argument('--output-dir', type=str, default='./results/ttrlvr_azr',
                   help='결과 저장 디렉토리 (기본값: ./results/ttrlvr_azr)')
    p.add_argument('--config', type=str,
                   help='설정 파일 경로 (선택사항)')
    p.add_argument('--model', type=str, default='Qwen/Qwen2.5-7B',
                   help='사용할 모델 (기본값: Qwen/Qwen2.5-7B)')
    p.add_argument('--debug', action='store_true',
                   help='디버그 모드 활성화')
    # Training / generation knobs.
    p.add_argument('--batch-size', type=int, default=24,
                   help='학습 배치 크기 (기본값: 24, OOM 시 줄이기)')
    p.add_argument('--batch-epochs', type=int, default=1,
                   help='배치당 에폭 수 (기본값: 1, 더 많은 학습을 위해 증가 가능)')
    p.add_argument('--num-programs', type=int, default=4,
                   help='생성할 다양한 프로그램 수 (기본값: 4, 더 다양한 데이터를 위해 증가 가능)')
    p.add_argument('--input-generation-rounds', type=int, default=3,
                   help='다양한 입력 생성 라운드 수 (기본값: 3, 라운드당 5개씩 생성)')
    p.add_argument('--parallel-batch-size', type=int, default=4,
                   help='동시 처리할 프롬프트 수 (기본값: 4, GPU 메모리에 따라 조정)')
    p.add_argument('--eval-rounds', type=int, default=5,
                   help='매 라운드 정확도 측정 횟수 (기본값: 5, 더 정확한 평가를 위해 증가 가능)')
    p.add_argument('--skip-task-eval', action='store_true',
                   help='Task evaluation(4단계) 스킵하여 빠른 테스트 (데이터 생성 후 바로 VeRL 학습)')
    # Checkpointing.
    p.add_argument('--save-every-round', action='store_true',
                   help='매 라운드마다 체크포인트 저장 (기본값: False)')
    p.add_argument('--save-round-interval', type=int, default=5,
                   help='체크포인트 저장 간격 (예: 5 = 5라운드마다 저장, 기본값: 5)')
    # Step-5-only mode.
    p.add_argument('--step5-only', action='store_true',
                   help='기존 데이터로 Step 5 (VeRL 학습)만 실행')
    p.add_argument('--data-path', type=str,
                   help='Step5 전용 모드에서 사용할 기존 azr_training_data 디렉토리 경로')
    return p.parse_args()
def setup_environment(gpu_id: str, batch_size: int = None):
    """Configure process environment variables (mirrors run_ttrlvr_azr_training.sh).

    The original set VLLM/RAY/HYDRA variables twice and rebuilt PYTHONPATH in
    two redundant passes; the duplicate assignments are consolidated here while
    preserving the exact final environment state.

    Args:
        gpu_id: Value for CUDA_VISIBLE_DEVICES (e.g. '5' or '1,2,3,5'). Falls
            back to an existing CUDA_VISIBLE_DEVICES, then to GPU 5.
        batch_size: Optional training batch size, exported as TRAIN_BATCH_SIZE.
    """
    # GPU selection: CLI argument wins, then pre-existing env var, then default.
    if gpu_id:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
        print(f"🎯 Using command line GPU setting: {gpu_id}")
    elif os.environ.get('CUDA_VISIBLE_DEVICES'):
        print(f"🎯 Using existing CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '5'  # default GPU
        print("🎯 Using default GPU: 5")

    if batch_size is not None:
        os.environ['TRAIN_BATCH_SIZE'] = str(batch_size)

    # VLLM / Ray / Hydra / HF cache settings (same as run_ttrlvr_azr_training.sh).
    os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN'
    os.environ['RAY_memory_monitor_refresh_ms'] = '0'
    os.environ['RAY_LOGGING_LEVEL'] = 'DEBUG'
    os.environ['HYDRA_FULL_ERROR'] = '1'
    os.environ['HF_HOME'] = '/data/.cache/huggingface'
    os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/huggingface'
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    # PYTHONPATH, pass 1: append verl + project root when verl is missing.
    pythonpath = os.environ.get('PYTHONPATH', '')
    if '/home/ubuntu/RLVR/verl' not in pythonpath:
        os.environ['PYTHONPATH'] = f"{pythonpath}:/home/ubuntu/RLVR/verl:/home/ubuntu/RLVR/TestTime-RLVR-v2"
    # PYTHONPATH, pass 2: prepend the project root and conda site-packages
    # when still missing (substring check, matching the original behavior).
    current_pythonpath = os.environ.get('PYTHONPATH', '')
    new_paths = [
        '/home/ubuntu/RLVR/TestTime-RLVR-v2',
        '/data/miniforge3/envs/azr/lib/python3.10/site-packages'
    ]
    for path in new_paths:
        if path not in current_pythonpath:
            current_pythonpath = f"{path}:{current_pythonpath}" if current_pythonpath else path
    os.environ['PYTHONPATH'] = current_pythonpath
def _load_problems_fallback(data_path: str) -> List[str]:
    """Fallback loader: read one JSON problem per line and collect 'task_id's."""
    problems: List[str] = []
    if os.path.exists(data_path):
        with open(data_path, 'r') as f:
            for line in f:
                problems.append(json.loads(line.strip())['task_id'])
    return problems


def load_benchmark_problems(benchmark_config: "BenchmarkConfig") -> List[str]:
    """Load the list of problem ids for a benchmark (original TTRLVR style).

    Tries the EvalPlus standard loader first; on any failure, falls back to
    reading the local JSONL file at ``benchmark_config.data_path``. The two
    benchmark branches previously duplicated the fallback logic verbatim;
    it is factored into ``_load_problems_fallback``.

    Args:
        benchmark_config: Config exposing ``.name`` ('mbpp' or 'humaneval')
            and ``.data_path`` (fallback JSONL path).

    Returns:
        List of task id strings; empty for unknown benchmark names or when
        both loaders yield nothing.
    """
    if benchmark_config.name == 'mbpp':
        label = 'MBPP+'

        def _loader():
            # get_mbpp_plus applies mbpp_deserialize_inputs automatically.
            from evalplus.data.mbpp import get_mbpp_plus
            return get_mbpp_plus()
    elif benchmark_config.name == 'humaneval':
        label = 'HumanEval+'

        def _loader():
            from evalplus.data.humaneval import get_human_eval_plus
            return get_human_eval_plus()
    else:
        return []

    try:
        problems = list(_loader().keys())
        print(f"✅ {label} 데이터 로드 성공: {len(problems)}개 문제 (EvalPlus 표준 방식)")
        return problems
    except Exception as e:
        print(f"❌ {label} EvalPlus 로딩 실패, 기존 방식 사용: {e}")
        return _load_problems_fallback(benchmark_config.data_path)
def create_problem_list(benchmark: str, num_problems: int, specific_problem_id: str = None) -> list:
    """Build the list of problem ids to train on (original TTRLVR selection).

    Loads every problem for the benchmark, then either returns the one
    requested id or the first ``num_problems`` entries (all of them when the
    requested count is non-positive or exceeds the total).

    Raises:
        ValueError: If no problems are found, or the specific id is unknown.
    """
    cfg = create_benchmark_config(benchmark)
    available = load_benchmark_problems(cfg)
    if not available:
        raise ValueError(f"No problems found for benchmark: {benchmark}")
    # A single explicitly-requested problem takes precedence.
    if specific_problem_id:
        if specific_problem_id not in available:
            raise ValueError(f"Problem ID '{specific_problem_id}' not found in {benchmark} benchmark")
        return [specific_problem_id]
    take_all = num_problems <= 0 or num_problems > len(available)
    return available if take_all else available[:num_problems]
def create_config(args) -> TestTimeConfig:
    """Translate parsed CLI arguments into a TestTimeConfig instance."""
    cfg = TestTimeConfig()
    # Core generation settings.
    cfg.model_name = args.model  # model chosen via --model
    cfg.max_new_tokens = 512
    cfg.temperature = 0.05
    cfg.baseline_evaluation_rounds = args.eval_rounds  # accuracy-eval repetitions
    # Program/input generation settings.
    cfg.num_program_variations = args.num_programs
    cfg.input_generation_rounds = args.input_generation_rounds
    cfg.parallel_batch_size = args.parallel_batch_size
    # Optionally skip the task-evaluation stage (step 4) for fast runs.
    cfg.skip_task_evaluation = args.skip_task_eval
    # Debug mode turns on verbose output as well.
    if args.debug:
        cfg.debug = True
        cfg.verbose = True
    return cfg
def create_benchmark_config(benchmark: str) -> BenchmarkConfig:
    """Build a BenchmarkConfig with the local EvalPlus data path attached.

    Raises:
        ValueError: If ``benchmark`` is neither 'mbpp' nor 'humaneval'.
    """
    # Data files live relative to the repository root (one level above this file).
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if benchmark == 'mbpp':
        cfg = BenchmarkConfig.get_mbpp_config()
        rel_path = 'evaluation/code_eval/data/MbppPlus.jsonl'
    elif benchmark == 'humaneval':
        cfg = BenchmarkConfig.get_humaneval_config()
        rel_path = 'evaluation/code_eval/data/HumanEvalPlus.jsonl'
    else:
        raise ValueError(f"Unknown benchmark: {benchmark}")
    cfg.data_path = os.path.join(base_dir, rel_path)
    return cfg
def save_run_config(args, problem_ids: list, output_dir: str):
    """Persist the run configuration as ``run_config.json`` inside output_dir.

    Returns:
        The path of the written config file.
    """
    payload = {
        'timestamp': datetime.now().isoformat(),
        'benchmark': args.benchmark,
        'num_problems': args.problems,
        'total_rounds': args.rounds,
        'resume_from': args.resume,
        'gpu': args.gpu,
        'problem_ids': problem_ids,
        'output_dir': output_dir,
        # Record the exact invocation for reproducibility.
        'command_line': ' '.join(sys.argv),
    }
    config_file = os.path.join(output_dir, 'run_config.json')
    with open(config_file, 'w') as fh:
        json.dump(payload, fh, indent=2)
    return config_file
def run_step5_only_mode(args):
    """Run only Step 5 (VeRL training) on pre-generated training data.

    Validates that ``args.data_path`` contains the three required parquet
    files, configures the environment, builds an IterativeTrainer, and runs
    a single VeRL training pass. The redundant function-local
    ``from pathlib import Path`` is removed (Path is imported at module level).

    Returns:
        0 on success, 1 on any validation or training failure.
    """
    print(f"🎓 Running Step 5 (VeRL training) only mode")
    print(f"📂 Data path: {args.data_path}")
    # Validate the data directory.
    data_path = Path(args.data_path)
    if not data_path.exists():
        print(f"❌ Error: Data path does not exist: {data_path}")
        return 1
    # All three task types must be present.
    required_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
    missing_files = [name for name in required_files if not (data_path / name).exists()]
    if missing_files:
        print(f"❌ Error: Missing required files: {missing_files}")
        return 1
    print(f"✅ Found all required training data files in: {data_path}")
    # Report file sizes for a quick sanity check.
    for file_name in required_files:
        file_size = (data_path / file_name).stat().st_size
        print(f" 📄 {file_name}: {file_size:,} bytes")
    # Environment setup (GPU selection, Ray/VLLM settings, PYTHONPATH).
    setup_environment(args.gpu, args.batch_size)
    # Pick a config file: explicit --config wins, otherwise choose by GPU count.
    config_path = args.config
    if not config_path:
        gpu_count = len(args.gpu.split(',')) if args.gpu else 1
        if gpu_count >= 4:
            config_path = '/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs/ttrlvr_azr_ppo_4gpu.yaml'
        else:
            config_path = '/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs/ttrlvr_azr_ppo_1gpu.yaml'
    print(f"🚀 Initializing trainer with config: {config_path}")
    config = create_config(args)
    logger = TestTimeLogger()
    # Keep a module-level handle so cleanup_ray()/signal handlers can reach it.
    global _trainer_instance
    _trainer_instance = IterativeTrainer(
        config=config,
        logger=logger,
        verl_config_path=config_path
    )
    try:
        result = _trainer_instance.run_verl_training_only(
            training_data_path=str(data_path),
            round_num=args.resume,  # --resume doubles as the round number here
            experiment_name=f"step5_only_{args.benchmark}"
        )
        if result.get('success', False):
            print(f"✅ VeRL training completed successfully!")
            print(f"⏱️ Duration: {result.get('duration', 'N/A')} seconds")
            if 'model_path' in result:
                print(f"🤖 Updated model: {result['model_path']}")
            return 0
        else:
            print(f"❌ VeRL training failed: {result.get('error', 'Unknown error')}")
            return 1
    except Exception as e:
        print(f"💥 Training failed with exception: {e}")
        import traceback
        traceback.print_exc()
        return 1
def main():
    """Entry point for the combined TTRLVR + AZR iterative training run.

    Parses CLI arguments, prepares the environment, output directory and
    logger, builds an IterativeTrainer, runs the requested number of rounds,
    and writes the final results to JSON.

    Returns:
        int: 0 on success, 1 on failure, 130 when interrupted by the user.
    """
    global _trainer_instance, _logger_instance
    # Register shutdown hooks so Ray is cleaned up on Ctrl+C/TERM and on
    # normal interpreter exit.
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # terminate
    atexit.register(cleanup_ray)  # cleanup on normal exit as well
    # Parse command-line arguments.
    args = parse_arguments()
    # Step-5-only mode: run VeRL training on existing data, then exit.
    if args.step5_only:
        if not args.data_path:
            print("❌ Error: --data-path is required when using --step5-only")
            return 1
        return run_step5_only_mode(args)
    # Provisional problem list, used only to name the output directory.
    temp_problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
    actual_problem_count = len(temp_problem_ids) if temp_problem_ids else args.problems
    # Create the timestamped output directory.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    if args.problem_id:
        problem_desc = f"problem_{args.problem_id.replace('/', '_')}"
    else:
        problem_desc = f"{actual_problem_count}problems"
    output_dir = os.path.join(
        args.output_dir,
        f'{args.benchmark}_{problem_desc}_{args.rounds}rounds_{timestamp}'
    )
    os.makedirs(output_dir, exist_ok=True)
    # Environment variables (GPU selection, Ray/VLLM settings, PYTHONPATH).
    setup_environment(args.gpu, args.batch_size)
    # Initialize the logger with a fixed log directory.
    log_dir = "/home/ubuntu/RLVR/TestTime-RLVR-v2/logs"
    os.makedirs(log_dir, exist_ok=True)
    logger = TestTimeLogger(log_dir=log_dir)
    _logger_instance = logger  # expose to the signal/atexit cleanup hooks
    # Report the actual log file name the logger created (file handlers only).
    log_handlers = [h for h in logger.logger.handlers if hasattr(h, 'baseFilename')]
    if log_handlers:
        actual_log_file = log_handlers[0].baseFilename
        print(f"📝 로그 파일: {actual_log_file}")
    else:
        print(f"📝 로그 디렉토리: {log_dir}")
    # Build configs and the definitive problem list.
    config = create_config(args)
    benchmark_config = create_benchmark_config(args.benchmark)
    problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
    # Log the effective settings (GPU may come from the environment).
    actual_problem_count = len(problem_ids) if problem_ids else args.problems
    actual_gpu = os.environ.get('CUDA_VISIBLE_DEVICES', args.gpu)
    logger.log_info("🚀 TTRLVR + AZR 통합 학습 시작")
    logger.log_info("=" * 80)
    logger.log_info(f"📊 설정:")
    logger.log_info(f" - 벤치마크: {args.benchmark}")
    if args.problem_id:
        logger.log_info(f" - 특정 문제 ID: {args.problem_id}")
        logger.log_info(f" - 문제 수: {actual_problem_count} (특정 문제)")
    else:
        logger.log_info(f" - 문제 수: {actual_problem_count}")
    logger.log_info(f" - 총 라운드: {args.rounds}")
    logger.log_info(f" - GPU: {actual_gpu}")
    logger.log_info(f" - 출력 디렉토리: {output_dir}")
    if args.resume > 1:
        logger.log_info(f" - 재개 라운드: {args.resume}")
    logger.log_info("=" * 80)
    logger.log_info(f"🎯 문제 리스트: {problem_ids}")
    try:
        # Persist the run configuration next to the results.
        config_file = save_run_config(args, problem_ids, output_dir)
        logger.log_info(f"📄 실행 설정 저장: {config_file}")
        # Build the trainer (optionally with an explicit VeRL config file path).
        verl_config_path = None
        if args.config:
            verl_config_path = os.path.abspath(args.config)
        trainer = IterativeTrainer(
            config,
            logger,
            batch_epochs=args.batch_epochs,
            verl_config_path=verl_config_path,
            save_every_round=args.save_every_round,
            save_round_interval=args.save_round_interval
        )
        _trainer_instance = trainer  # expose to the signal/atexit cleanup hooks
        # Run the iterative training loop.
        logger.log_info("🎓 반복 학습 시작")
        training_results = trainer.run_iterative_training(
            benchmark_config=benchmark_config,
            problem_ids=problem_ids,
            total_rounds=args.rounds,
            resume_from_round=args.resume
        )
        # Save final results in a JSON-serializable form: round-trip through
        # json with default=str so non-serializable objects (e.g.
        # BenchmarkConfig instances) are stringified instead of raising.
        results_file = os.path.join(output_dir, 'training_results.json')
        try:
            serializable_results = json.loads(json.dumps(training_results, default=str))
            with open(results_file, 'w') as f:
                json.dump(serializable_results, f, indent=2)
        except Exception as e:
            logger.log_warning(f"결과 저장 중 오류 (무시됨): {e}")
            # Fallback: save only the basic summary fields.
            basic_results = {
                'success': training_results.get('success', False),
                'benchmark': args.benchmark,
                'total_rounds': args.rounds,
                'completed_rounds': len(training_results.get('rounds', {})),
                'timestamp': training_results.get('end_time', 'unknown')
            }
            with open(results_file, 'w') as f:
                json.dump(basic_results, f, indent=2)
        logger.log_info(f"💾 최종 결과 저장: {results_file}")
        # NOTE(review): cleanup is also invoked in the finally block below, so
        # trainer.cleanup() may run twice on this path — presumably idempotent;
        # confirm against IterativeTrainer.
        trainer.cleanup()
        if training_results['success']:
            logger.log_info("🎉 TTRLVR + AZR 통합 학습 성공적으로 완료!")
            return 0
        else:
            logger.log_error(f"❌ 학습 실패: {training_results.get('error', 'Unknown error')}")
            return 1
    except KeyboardInterrupt:
        logger.log_info("⚠️ 사용자에 의해 중단됨")
        # Clean up even on user interrupt.
        if 'trainer' in locals():
            trainer.cleanup()
        return 130
    except Exception as e:
        logger.log_error(f"💥 예상치 못한 오류: {e}")
        # Clean up on unexpected errors too.
        if 'trainer' in locals():
            trainer.cleanup()
        return 1
    finally:
        # Guarantee cleanup on every exit path.
        if 'trainer' in locals() and hasattr(trainer, 'cleanup'):
            trainer.cleanup()
        # Best-effort: force-stop any leftover Ray processes.
        import subprocess
        try:
            subprocess.run(['ray', 'stop', '--force'], capture_output=True, timeout=10)
        except:
            pass
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())