|
|
|
""" |
|
TTRLVR + AZR ํตํฉ ํ์ต ๋ฉ์ธ ์คํฌ๋ฆฝํธ (Unified Version) |
|
|
|
UnifiedTTRLVRTrainer๋ฅผ ์ฌ์ฉํ์ฌ ํ๋์ VeRL ์ธ์
์์ ์ ์ฒด ํ์ต ์งํ: |
|
1. VeRL worker ํ ๋ฒ๋ง ์ด๊ธฐํ |
|
2. ๊ฐ ๋ผ์ด๋๋ง๋ค ๊ฐ์ vLLM์ผ๋ก Phase 1-4 ์คํ |
|
3. ๊ฐ์ vLLM์ผ๋ก Phase 5 PPO ํ์ต |
|
4. ๋๊ธฐํ ๋ฌธ์ ์์ ํด๊ฒฐ (dummy_dtensor ์ฌ์ฉ ๊ฐ๋ฅ) |
|
|
|
์ฌ์ฉ ์์: |
|
# ์ผ๋ฐ ํ์ต |
|
python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 |
|
python train_ttrlvr_azr_unified.py --benchmark humaneval --problems 5 --rounds 10 |
|
python train_ttrlvr_azr_unified.py --benchmark mbpp --problem-id Mbpp/2 --rounds 5 |
|
|
|
# GPU ์ง์ |
|
python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 0,1,2,3 |
|
""" |
|
|
|
import os |
|
import sys |
|
import argparse |
|
import json |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import List |
|
import warnings |
|
import ray |
|
import torch |
|
|
|
|
|
# Suppress warnings whose message matches the "caching is incompatible with
# gradient checkpointing" notice.
warnings.filterwarnings("ignore", message=".*Caching is incompatible with gradient checkpointing.*")

# The directory two levels above this file is appended to sys.path so that
# project-local packages can be imported by the statements further down.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

# Sibling checkouts named 'verl' and 'Absolute-Zero-Reasoner' (next to the
# project directory) are appended to sys.path when they exist on disk.
# This has to run before the verl / absolute_zero_reasoner imports below.
parent_dir = project_root.parent
for lib_name in ['verl', 'Absolute-Zero-Reasoner']:
    lib_path = parent_dir / lib_name
    if lib_path.exists():
        sys.path.append(str(lib_path))
|
|
|
|
|
|
|
from verl import DataProto |
|
from omegaconf import OmegaConf |
|
import ray |
|
from verl.utils import hf_tokenizer |
|
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role |
|
from verl.single_controller.ray import RayWorkerGroup |
|
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker |
|
from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter |
|
|
|
|
|
from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig |
|
from absolute_zero_reasoner.testtime.logger import TestTimeLogger |
|
|
|
|
|
|
|
# Module-level handles consulted by cleanup_ray() and signal_handler().
# Both start out unset (None).
_trainer_instance = None
_logger_instance = None
|
|
|
|
|
def _log_or_print(method_name, message):
    """Send *message* to ``_logger_instance.<method_name>``, else ``print`` it.

    The message is printed when no logger is configured or when the logger
    call itself fails, so cleanup diagnostics are never silently lost.
    """
    logger = _logger_instance
    if logger:
        try:
            getattr(logger, method_name)(message)
            return
        except Exception:
            # Logging must never break cleanup; fall through to print().
            pass
    print(message)


def cleanup_ray():
    """Best-effort cleanup of the trainer and the Ray cluster.

    Steps, each isolated so that one failure does not prevent the next:
      1. Call ``_trainer_instance.cleanup_ray()`` if a trainer is registered.
      2. Shut Ray down if it is initialized.

    Failures are reported through the module logger (or ``print``) and are
    never re-raised.
    """
    # NOTE(review): the non-ASCII message text in this function looks
    # mis-encoded in this copy of the file; it is kept as found ("\x85" marks a
    # control character embedded in the literal). Confirm against the original.
    _log_or_print("log_info", "๐ ๊ฐ์ ์ข\x85๋ฃ ๊ฐ์ง: Ray ํด๋ฌ์คํฐ ์ ๋ฆฌ ์ค...")

    try:
        if _trainer_instance:
            _trainer_instance.cleanup_ray()
    except Exception as e:
        _log_or_print("log_error", f"IterativeTrainer ์ ๋ฆฌ ์คํจ: {e}")

    try:
        # Imported lazily so that a missing/broken Ray install is reported
        # here instead of aborting the cleanup.
        import ray

        if ray.is_initialized():
            ray.shutdown()
    except Exception as e:
        _log_or_print("log_error", f"Ray ์ข\x85๋ฃ ์คํจ: {e}")

    _log_or_print("log_info", "โ\x85 Ray ์ ๋ฆฌ ์๋ฃ")
|
|
|
|
|
def signal_handler(signum, frame):
    """Signal handler (Ctrl+C, forced termination, etc.).

    Reports the received signal, runs ``cleanup_ray()`` and exits the process
    with status 1.

    Args:
        signum: Number of the signal that was received.
        frame: Current stack frame supplied by the ``signal`` machinery (unused).
    """
    # NOTE(review): the non-ASCII message text looks mis-encoded in this copy
    # of the file; it is kept as found ("\x85" marks an embedded control
    # character). Confirm against the original.
    message = f"๐ ์๊ทธ๋ {signum} ์์ : ํ๋ก๊ทธ๋จ ์ข\x85๋ฃ ์ค..."

    # Prefer the logger, but always fall back to print() so the message is
    # not lost when no logger is configured or the logger itself fails.
    delivered = False
    if _logger_instance:
        try:
            _logger_instance.log_info(message)
            delivered = True
        except Exception:
            delivered = False
    if not delivered:
        print(message)

    cleanup_ray()
    sys.exit(1)
|
|
|
|
|
def parse_arguments():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace: one attribute per option declared in the table
        below, parsed from ``sys.argv``.
    """
    # NOTE(review): the description/epilog/help text looks mis-encoded in this
    # copy of the file; it is kept as found ("\x85" marks an embedded control
    # character). Confirm against the original.
    usage_examples = """
์์:
# MBPP 10๋ฌธ์ ๋ก 30๋ผ์ด๋ ํ์ต
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30

# HumanEval 5๋ฌธ์ ๋ก 10๋ผ์ด๋ ํ์ต
python train_ttrlvr_azr.py --benchmark humaneval --problems 5 --rounds 10

# 15๋ผ์ด๋๋ถํฐ ์ฌ๊ฐ
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --resume 15

# ํน์ GPU ์ฌ์ฉ
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --gpu 4
"""

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ('--benchmark', dict(
            choices=['mbpp', 'humaneval'],
            default='mbpp',
            help='๋ฒค์น๋งํฌ ์ ํ (๊ธฐ๋ณธ๊ฐ: mbpp)',
        )),
        ('--problems', dict(
            type=int,
            default=10,
            help='๋ฌธ์ ์ (๊ธฐ๋ณธ๊ฐ: 10)',
        )),
        ('--problem-id', dict(
            type=str,
            help='ํน์ ๋ฌธ์ ID (์: HumanEval/1, Mbpp/10)',
        )),
        ('--rounds', dict(
            type=int,
            default=30,
            help='์ด ๋ผ์ด๋ ์ (๊ธฐ๋ณธ๊ฐ: 30)',
        )),
        ('--resume', dict(
            type=int,
            default=1,
            help='์ฌ๊ฐํ ๋ผ์ด๋ ๋ฒํธ (๊ธฐ๋ณธ๊ฐ: 1)',
        )),
        ('--gpu', dict(
            type=str,
            default='5',
            help='์ฌ์ฉํ GPU ๋ฒํธ (๋จ์ผ: 5, ๋ค์ค: 1,2,3,5)',
        )),
        ('--output-dir', dict(
            type=str,
            default='./results/ttrlvr_azr',
            help='๊ฒฐ๊ณผ ์ ์ฅ ๋๋ ํ ๋ฆฌ (๊ธฐ๋ณธ๊ฐ: ./results/ttrlvr_azr)',
        )),
        ('--config', dict(
            type=str,
            help='์ค์ ํ์ผ ๊ฒฝ๋ก (์ ํ์ฌํญ)',
        )),
        ('--model', dict(
            type=str,
            default='Qwen/Qwen2.5-7B',
            help='์ฌ์ฉํ ๋ชจ๋ธ (๊ธฐ๋ณธ๊ฐ: Qwen/Qwen2.5-7B)',
        )),
        ('--debug', dict(
            action='store_true',
            help='๋๋ฒ๊ทธ ๋ชจ๋ ํ์ฑํ',
        )),
        ('--batch-size', dict(
            type=int,
            default=24,
            help='ํ์ต ๋ฐฐ์น ํฌ๊ธฐ (๊ธฐ๋ณธ๊ฐ: 24, OOM ์ ์ค์ด๊ธฐ)',
        )),
        ('--batch-epochs', dict(
            type=int,
            default=1,
            help='๋ฐฐ์น๋น ์ํญ ์ (๊ธฐ๋ณธ๊ฐ: 1, ๋ ๋ง์ ํ์ต์ ์ํด ์ฆ๊ฐ ๊ฐ๋ฅ)',
        )),
        ('--num-programs', dict(
            type=int,
            default=4,
            help='์์ฑํ ๋ค์ํ ํ๋ก๊ทธ๋จ ์ (๊ธฐ๋ณธ๊ฐ: 4, ๋ ๋ค์ํ ๋ฐ์ดํฐ๋ฅผ ์ํด ์ฆ๊ฐ ๊ฐ๋ฅ)',
        )),
        ('--input-generation-rounds', dict(
            type=int,
            default=3,
            help='๋ค์ํ ์\x85๋ ฅ ์์ฑ ๋ผ์ด๋ ์ (๊ธฐ๋ณธ๊ฐ: 3, ๋ผ์ด๋๋น 5๊ฐ์ฉ ์์ฑ)',
        )),
        ('--parallel-batch-size', dict(
            type=int,
            default=4,
            help='๋์ ์ฒ๋ฆฌํ ํ๋กฌํํธ ์ (๊ธฐ๋ณธ๊ฐ: 4, GPU ๋ฉ๋ชจ๋ฆฌ์ ๋ฐ๋ผ ์กฐ์ )',
        )),
        ('--eval-rounds', dict(
            type=int,
            default=5,
            help='๋งค ๋ผ์ด๋ ์ ํ๋ ์ธก์ ํ์ (๊ธฐ๋ณธ๊ฐ: 5, ๋ ์ ํํ ํ๊ฐ๋ฅผ ์ํด ์ฆ๊ฐ ๊ฐ๋ฅ)',
        )),
        ('--skip-task-eval', dict(
            action='store_true',
            help='Task evaluation(4๋จ๊ณ) ์คํตํ์ฌ ๋น ๋ฅธ ํ\x85์คํธ (๋ฐ์ดํฐ ์์ฑ ํ ๋ฐ๋ก VeRL ํ์ต)',
        )),
        ('--save-every-round', dict(
            action='store_true',
            help='๋งค ๋ผ์ด๋๋ง๋ค ์ฒดํฌํฌ์ธํธ ์ ์ฅ (๊ธฐ๋ณธ๊ฐ: False)',
        )),
        ('--save-round-interval', dict(
            type=int,
            default=5,
            help='์ฒดํฌํฌ์ธํธ ์ ์ฅ ๊ฐ๊ฒฉ (์: 5 = 5๋ผ์ด๋๋ง๋ค ์ ์ฅ, ๊ธฐ๋ณธ๊ฐ: 5)',
        )),
    ]

    cli = argparse.ArgumentParser(
        description='TTRLVR + AZR ํตํฉ ๋ฐ๋ณต ํ์ต',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
|
|
|
|
|
def setup_environment(gpu_id: str, batch_size: int = None):
    """Set the process environment variables (mirrors run_ttrlvr_azr_training.sh).

    Args:
        gpu_id: Value for ``CUDA_VISIBLE_DEVICES`` (e.g. ``"5"`` or
            ``"0,1,2,3"``). When falsy, an existing non-empty
            ``CUDA_VISIBLE_DEVICES`` is kept, otherwise ``"5"`` is used.
        batch_size: When not ``None``, exported as ``TRAIN_BATCH_SIZE``.

    Side effects:
        Mutates ``os.environ`` (GPU selection, vLLM/Ray/Hydra settings,
        ``PYTHONPATH``, Hugging Face cache locations, tokenizer parallelism).
    """
    # --- GPU selection: explicit argument > existing environment > default.
    # NOTE(review): the emoji prefix of these messages looks mis-encoded in
    # this copy of the file; kept as found.
    if gpu_id:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
        print(f"๐ฏ Using command line GPU setting: {gpu_id}")
    elif os.environ.get('CUDA_VISIBLE_DEVICES'):
        print(f"๐ฏ Using existing CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '5'
        print(f"๐ฏ Using default GPU: 5")

    # --- vLLM / Ray / Hydra settings.
    os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN'
    os.environ['RAY_memory_monitor_refresh_ms'] = '0'
    os.environ['RAY_LOGGING_LEVEL'] = 'DEBUG'
    os.environ['HYDRA_FULL_ERROR'] = '1'

    # --- PYTHONPATH: project root plus sibling library checkouts, if present.
    project_root = Path(__file__).parent.parent
    parent_dir = project_root.parent
    paths_to_add = [str(project_root)]
    for lib_name in ('verl', 'Absolute-Zero-Reasoner'):
        lib_dir = parent_dir / lib_name
        if lib_dir.exists():
            paths_to_add.append(str(lib_dir))

    # Compare against individual PYTHONPATH entries, not the raw string: a
    # substring test wrongly treats "/a/b" as present when only "/a/b_x" is.
    existing = os.environ.get('PYTHONPATH', '')
    entries = existing.split(os.pathsep) if existing else []
    for path in paths_to_add:
        if path not in entries:
            entries.insert(0, path)  # prepend, so later additions take precedence
    os.environ['PYTHONPATH'] = os.pathsep.join(entries)

    # --- Optional training batch size.
    if batch_size is not None:
        os.environ['TRAIN_BATCH_SIZE'] = str(batch_size)

    # --- Hugging Face caches (only if not already configured) and tokenizers.
    os.environ.setdefault('HF_HOME', os.path.expanduser('~/.cache/huggingface'))
    os.environ.setdefault('TRANSFORMERS_CACHE', os.path.expanduser('~/.cache/huggingface'))
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
|
|
|
|
def _read_task_ids_from_jsonl(data_path):
    """Return the ``task_id`` of every JSON line in *data_path* ([] if the file is absent)."""
    task_ids = []
    if os.path.exists(data_path):
        with open(data_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                record = raw_line.strip()
                if not record:
                    # Blank lines (e.g. a trailing newline) are not JSON; skip them.
                    continue
                task_ids.append(json.loads(record)['task_id'])
    return task_ids


def load_benchmark_problems(benchmark_config: "BenchmarkConfig") -> List[str]:
    """Load the list of problem IDs for a benchmark (existing TTRLVR approach).

    EvalPlus is tried first; if importing or calling it fails, the IDs are
    read from the JSONL file at ``benchmark_config.data_path`` instead.

    Args:
        benchmark_config: Object exposing ``name`` (``'mbpp'`` or
            ``'humaneval'``) and ``data_path`` (JSONL file whose records carry
            a ``task_id`` key).

    Returns:
        The problem IDs; an empty list for an unknown benchmark name or when
        neither source yields anything.
    """
    import importlib

    # benchmark name -> (display label, EvalPlus module, loader function name)
    evalplus_sources = {
        'mbpp': ('MBPP+', 'evalplus.data.mbpp', 'get_mbpp_plus'),
        'humaneval': ('HumanEval+', 'evalplus.data.humaneval', 'get_human_eval_plus'),
    }

    source = evalplus_sources.get(benchmark_config.name)
    if source is None:
        return []
    label, module_name, loader_name = source

    # NOTE(review): the non-ASCII message text looks mis-encoded in this copy
    # of the file; it is kept as found ("\x85" marks an embedded control
    # character). Confirm against the original.
    try:
        loader = getattr(importlib.import_module(module_name), loader_name)
        problems = list(loader().keys())
        print(f"โ\x85 {label} ๋ฐ์ดํฐ ๋ก๋ ์ฑ๊ณต: {len(problems)}๊ฐ ๋ฌธ์ (EvalPlus ํ์ค ๋ฐฉ์)")
    except Exception as e:
        print(f"โ {label} EvalPlus ๋ก๋ฉ ์คํจ, ๊ธฐ์กด ๋ฐฉ์ ์ฌ์ฉ: {e}")
        problems = _read_task_ids_from_jsonl(benchmark_config.data_path)

    return problems
|
|
|
|
|
def create_problem_list(benchmark: str, num_problems: int, specific_problem_id: str = None) -> list:
    """Build the list of problem IDs for a benchmark (existing TTRLVR approach).

    Args:
        benchmark: Benchmark name passed to ``create_benchmark_config``.
        num_problems: How many problems to take from the front of the list;
            a value ``<= 0`` or larger than the list selects everything.
        specific_problem_id: When truthy, select exactly this problem.

    Returns:
        A list of problem IDs.

    Raises:
        ValueError: If the benchmark yields no problems, or if
            ``specific_problem_id`` is not among them.
    """
    available = load_benchmark_problems(create_benchmark_config(benchmark))
    if not available:
        raise ValueError(f"No problems found for benchmark: {benchmark}")

    # A specific ID overrides the count-based selection.
    if specific_problem_id:
        if specific_problem_id not in available:
            raise ValueError(f"Problem ID '{specific_problem_id}' not found in {benchmark} benchmark")
        return [specific_problem_id]

    if 0 < num_problems <= len(available):
        return available[:num_problems]
    return available
|
|
|
|
|
def create_config(args) -> TestTimeConfig:
    """Create a ``TestTimeConfig`` populated from the parsed CLI arguments.

    Args:
        args: Namespace providing ``model``, ``eval_rounds``, ``num_programs``,
            ``input_generation_rounds``, ``parallel_batch_size``,
            ``skip_task_eval`` and ``debug``.

    Returns:
        The configured ``TestTimeConfig`` instance.
    """
    cfg = TestTimeConfig()

    # (attribute, value) pairs, applied in this order.
    settings = (
        # model / generation settings
        ('model_name', args.model),
        ('max_new_tokens', 512),
        ('temperature', 0.05),
        ('baseline_evaluation_rounds', args.eval_rounds),
        # data-generation settings
        ('num_program_variations', args.num_programs),
        ('input_generation_rounds', args.input_generation_rounds),
        ('parallel_batch_size', args.parallel_batch_size),
        # evaluation toggle
        ('skip_task_evaluation', args.skip_task_eval),
    )
    for attribute, value in settings:
        setattr(cfg, attribute, value)

    # Debug mode also turns on verbose output.
    if args.debug:
        cfg.debug = True
        cfg.verbose = True

    return cfg
|
|
|
|
|
def create_benchmark_config(benchmark: str) -> BenchmarkConfig:
    """Create the ``BenchmarkConfig`` for a benchmark (existing TTRLVR approach).

    Args:
        benchmark: ``'mbpp'`` or ``'humaneval'``.

    Returns:
        The benchmark's config with ``data_path`` pointing at its JSONL file
        under the directory two levels above this file.

    Raises:
        ValueError: For any other benchmark name.
    """
    base_dir = Path(__file__).parent.parent

    # benchmark name -> (BenchmarkConfig factory method, data file relative to base_dir)
    known = {
        'mbpp': ('get_mbpp_config', 'evaluation/code_eval/data/MbppPlus.jsonl'),
        'humaneval': ('get_humaneval_config', 'evaluation/code_eval/data/HumanEvalPlus.jsonl'),
    }
    if benchmark not in known:
        raise ValueError(f"Unknown benchmark: {benchmark}")

    factory_name, relative_data_file = known[benchmark]
    cfg = getattr(BenchmarkConfig, factory_name)()
    cfg.data_path = str(base_dir / relative_data_file)
    return cfg
|
|
|
|
|
|
|
|
|
def run_step5_only_mode(args):
    """Run Step 5 (VeRL training) only, on already generated training data.

    Args:
        args: Parsed CLI namespace. Reads ``data_path``, ``gpu``,
            ``batch_size``, ``config``, ``resume`` and ``benchmark`` (plus
            whatever ``create_config`` reads).

    Returns:
        int: ``0`` on success, ``1`` on any validation or training failure.
    """
    global _trainer_instance

    # NOTE(review): the emoji/non-ASCII message text looks mis-encoded in this
    # copy of the file; it is kept as found ("\x85" marks an embedded control
    # character). Confirm against the original.
    print(f"๐ Running Step 5 (VeRL training) only mode")

    # NOTE(review): parse_arguments() in this file defines no --data-path
    # option, so the attribute may be absent; treat that as a usage error
    # instead of crashing with AttributeError.
    raw_data_path = getattr(args, 'data_path', None)
    if not raw_data_path:
        print("Error: no training data path provided (args.data_path is not set)")
        return 1
    print(f"๐ Data path: {raw_data_path}")

    # --- Validate the data directory.
    data_path = Path(raw_data_path)
    if not data_path.exists():
        print(f"โ Error: Data path does not exist: {data_path}")
        return 1

    required_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
    missing_files = [name for name in required_files if not (data_path / name).exists()]
    if missing_files:
        print(f"โ Error: Missing required files: {missing_files}")
        return 1

    print(f"โ\x85 Found all required training data files in: {data_path}")
    for file_name in required_files:
        file_size = (data_path / file_name).stat().st_size
        print(f" ๐ {file_name}: {file_size:,} bytes")

    # --- Environment.
    setup_environment(args.gpu, args.batch_size)

    # --- VeRL config: explicit path, else chosen by the number of GPUs.
    config_path = args.config
    if not config_path:
        gpu_count = len(args.gpu.split(',')) if args.gpu else 1
        config_name = 'configs/ttrlvr_azr_ppo_4gpu.yaml' if gpu_count >= 4 else 'configs/ttrlvr_azr_ppo_1gpu.yaml'
        config_path = str(Path(__file__).parent / config_name)

    print(f"๐ Initializing trainer with config: {config_path}")

    config = create_config(args)
    logger = TestTimeLogger()

    try:
        # NOTE(review): IterativeTrainer is not imported in the visible part of
        # this file — confirm the import exists. Constructing it inside the
        # try block makes any failure here reported like a training failure.
        _trainer_instance = IterativeTrainer(
            config=config,
            logger=logger,
            verl_config_path=config_path
        )

        result = _trainer_instance.run_verl_training_only(
            training_data_path=str(data_path),
            round_num=args.resume,
            experiment_name=f"step5_only_{args.benchmark}"
        )

        if not result.get('success', False):
            print(f"โ VeRL training failed: {result.get('error', 'Unknown error')}")
            return 1

        print(f"โ\x85 VeRL training completed successfully!")
        print(f"โฑ๏ธ Duration: {result.get('duration', 'N/A')} seconds")
        if 'model_path' in result:
            print(f"๐ค Updated model: {result['model_path']}")
        return 0

    except Exception as e:
        print(f"๐ฅ Training failed with exception: {e}")
        import traceback
        traceback.print_exc()
        return 1
|
|
|
|
|
def main():
    """Main entry point - runs training with UnifiedTTRLVRTrainer.

    Parses the CLI, prepares the environment and output directory, selects the
    problems, loads the VeRL config, starts Ray, builds the worker/resource
    mappings and runs ``UnifiedTTRLVRTrainer.fit()``.

    Returns:
        int: ``0`` on success, ``130`` when interrupted by the user, ``1`` on
        any other failure. Ray is shut down in all cases.
    """
    # NOTE(review): the emoji prefixes in the section headers below look
    # mis-encoded in this copy of the file; they are kept as found ("\x85"
    # marks an embedded control character). Confirm against the original.
    args = parse_arguments()

    # NOTE(review): --resume, --model, --debug, --batch-size, --batch-epochs,
    # --eval-rounds, --skip-task-eval, --save-every-round and
    # --save-round-interval are parsed but not used anywhere in this function.
    setup_environment(args.gpu)

    # --- Output directory, unique per run.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(
        args.output_dir,
        f'ttrlvr_unified_{args.benchmark}_{args.rounds}rounds_{timestamp}'
    )
    os.makedirs(output_dir, exist_ok=True)

    PrettyPrinter.section_header("๐ TTRLVR Unified Training")
    PrettyPrinter.status("Config", f"Benchmark: {args.benchmark}", "info")
    PrettyPrinter.status("Config", f"Rounds: {args.rounds}", "info")
    PrettyPrinter.status("Config", f"Output: {output_dir}", "info")

    # --- Problem selection.
    problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
    PrettyPrinter.status("Problems", f"Selected {len(problem_ids)} problems", "info")

    # --- TTRLVR-specific settings handed to the trainer.
    ttrlvr_config = {
        'num_programs': args.num_programs,
        'input_generation_rounds': args.input_generation_rounds,
        'parallel_batch_size': args.parallel_batch_size,
    }

    # --- VeRL config file: explicit path, else the bundled unified 4-GPU config.
    if args.config:
        config_path = os.path.abspath(args.config)
    else:
        config_path = str(Path(__file__).parent / 'configs/ttrlvr_azr_unified_4gpu.yaml')

    PrettyPrinter.status("Config", f"Using VeRL config: {config_path}", "info")

    try:
        # The same settings are also exported through the environment.
        os.environ['TTRLVR_PROBLEM_IDS'] = json.dumps(problem_ids)
        os.environ['TTRLVR_TOTAL_ROUNDS'] = str(args.rounds)
        os.environ['TTRLVR_OUTPUT_DIR'] = output_dir
        os.environ['TTRLVR_CONFIG'] = json.dumps(ttrlvr_config)

        PrettyPrinter.section_header("๐ฏ Starting UnifiedTTRLVRTrainer (AZR-style initialization)")

        # --- Load the VeRL config and override run-specific fields.
        PrettyPrinter.status("Config", f"Loading {config_path}", "info")
        verl_config = OmegaConf.load(config_path)

        verl_config.trainer.project_name = f'ttrlvr_unified_{args.benchmark}'
        verl_config.trainer.experiment_name = f'round_{args.rounds}_{timestamp}'
        verl_config.trainer.total_epochs = args.rounds

        # --- Ray.
        if not ray.is_initialized():
            cuda_visible_devices = args.gpu or "0,1,2,3"
            PrettyPrinter.status("Ray", f"Initializing Ray cluster (GPUs: {cuda_visible_devices})", "info")
            ray.init(
                runtime_env={"env_vars": {
                    "TOKENIZERS_PARALLELISM": "true",
                    "NCCL_DEBUG": "WARN",
                    "VLLM_LOGGING_LEVEL": "WARN",
                    "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true",
                    "CUDA_VISIBLE_DEVICES": cuda_visible_devices
                }},
                num_cpus=verl_config.ray_init.num_cpus,
            )

        # --- Tokenizer.
        model_path = verl_config.actor_rollout_ref.model.path
        PrettyPrinter.status("Model", f"Loading tokenizer from {model_path}", "info")
        tokenizer = hf_tokenizer(model_path)

        # --- Role -> worker class mapping.
        role_worker_mapping = {}

        rollout_cfg = verl_config.actor_rollout_ref.rollout
        if rollout_cfg.name != 'vllm':
            # Only the vLLM rollout is wired up here; fail with a clear message
            # instead of continuing without an ActorRollout worker.
            raise ValueError(
                f"Unsupported rollout backend {rollout_cfg.name!r}: only 'vllm' is handled by this script"
            )
        if rollout_cfg.mode == 'async':
            actor_rollout_cls = AsyncActorRolloutRefWorker
        else:
            actor_rollout_cls = ActorRolloutRefWorker

        role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
        PrettyPrinter.status("Workers", f"Using {actor_rollout_cls.__name__} for ActorRollout", "info")

        if verl_config.critic.include_critic:
            role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
            PrettyPrinter.status("Workers", "Including Critic worker", "info")
        else:
            PrettyPrinter.status("Workers", "No Critic (using REINFORCE++)", "info")

        # --- Resource pool: one global pool shared by every role.
        global_pool_id = "global_pool"
        n_gpus_per_node = verl_config.trainer.n_gpus_per_node
        nnodes = verl_config.trainer.nnodes
        resource_pool_spec = {
            global_pool_id: [n_gpus_per_node] * nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
        }
        if verl_config.critic.include_critic:
            mapping[Role.Critic] = global_pool_id

        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
        PrettyPrinter.status("Resources", f"Created ResourcePoolManager with {len(resource_pool_spec)} pools", "info")

        # --- Trainer (imported here, after sys.path / environment setup).
        from trainer.unified_ttrlvr_trainer import UnifiedTTRLVRTrainer

        PrettyPrinter.status("Trainer", "Creating UnifiedTTRLVRTrainer", "info")
        trainer = UnifiedTTRLVRTrainer(
            past_epoch_window=verl_config.azr.past_epoch_window,
            config=verl_config,
            tokenizer=tokenizer,
            processor=None,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=RayWorkerGroup,
            reward_fn=None,
            val_reward_fn=None,
            ttrlvr_config=ttrlvr_config,
            problem_ids=problem_ids,
            total_rounds=args.rounds,
            output_dir=output_dir
        )

        # --- Train.
        PrettyPrinter.section_header("๐ Starting Training")
        PrettyPrinter.status("Training", f"Running {args.rounds} rounds with {len(problem_ids)} problems", "info")
        trainer.fit()

        PrettyPrinter.section_header("โ\x85 Training Complete")
        return 0

    except KeyboardInterrupt:
        PrettyPrinter.status("Interrupt", "Training interrupted by user", "warning")
        return 130

    except Exception as e:
        PrettyPrinter.status("Error", f"Training failed: {e}", "error")
        import traceback
        traceback.print_exc()
        return 1

    finally:
        # Always release the Ray cluster, whatever the outcome.
        if ray.is_initialized():
            ray.shutdown()

        PrettyPrinter.status("Cleanup", "Resources cleaned up", "success")
|
|
|
|
|
# Script entry point: the value returned by main() becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main())