|
|
|
""" |
|
TTRLVR + AZR ํตํฉ ํ์ต ๋ฉ์ธ ์คํฌ๋ฆฝํธ |
|
|
|
30๋ผ์ด๋ ๋ฐ๋ณต ํ์ต์ ํตํด TTRLVR ํ์ดํ๋ผ์ธ๊ณผ AZR ํ์ต์ ํตํฉ: |
|
1. ๊ฐ ๋ผ์ด๋๋ง๋ค ํ์ฌ ๋ชจ๋ธ๋ก (i,p,o) โ induction/deduction/abduction tasks ์์ฑ |
|
2. ํด๋น ๋ผ์ด๋ ๋ฐ์ดํฐ๋ก๋ง AZR ํ์ต (๊ฐ ๋ผ์ด๋์ task๋ง ์ฌ์ฉ) |
|
3. ๊ฐ์ ๋ ๋ชจ๋ธ๋ก ๋ค์ ๋ผ์ด๋ ์งํ |
|
4. 5๋ผ์ด๋๋ง๋ค ์ฒดํฌํฌ์ธํธ ์ ์ฅ |
|
|
|
์ฌ์ฉ ์์: |
|
# ์ผ๋ฐ ํ์ต |
|
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 |
|
python train_ttrlvr_azr.py --benchmark humaneval --problems 5 --rounds 10 --resume 15 |
|
python train_ttrlvr_azr.py --benchmark mbpp --problem-id Mbpp/2 --rounds 5 --num-programs 8 --eval-rounds 10 |
|
python train_ttrlvr_azr.py --benchmark mbpp --problem-id Mbpp/2 --rounds 5 --skip-task-eval |
|
|
|
# Step 5 ์ ์ฉ ๋ชจ๋ (๊ธฐ์กด ๋ฐ์ดํฐ๋ก VeRL ํ์ต๋ง) |
|
python train_ttrlvr_azr.py --step5-only --data-path /path/to/azr_training_data --gpu 1,2,3,0 --config configs/ttrlvr_azr_ppo_4gpu.yaml |
|
""" |
|
|
|
import os |
|
import sys |
|
import argparse |
|
import json |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import List |
|
import warnings |
|
import signal |
|
import atexit |
|
|
|
|
|
# Silence a noisy-but-expected transformers warning triggered when KV caching
# is requested together with gradient checkpointing.
warnings.filterwarnings("ignore", message=".*Caching is incompatible with gradient checkpointing.*")

# Make the project checkout and the coding-eval helpers importable without
# installing them as packages.
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2')
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2/evaluation/code_eval/coding')

from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig
from absolute_zero_reasoner.testtime.logger import TestTimeLogger
from utils.iterative_trainer import IterativeTrainer

# Module-level handles shared with the signal/atexit hooks so a forced
# shutdown can still reach the trainer and the logger for cleanup.
_trainer_instance = None
_logger_instance = None
|
|
|
|
|
def cleanup_ray():
    """Best-effort teardown of the trainer and the Ray cluster.

    Registered with ``atexit`` and invoked from the signal handler, so it
    must never raise: every step is individually guarded and failures are
    only reported (via the logger when available, stdout otherwise).
    """
    global _trainer_instance, _logger_instance

    # Announce that cleanup is running; the logger itself may already be gone.
    # Narrowed from bare `except:` so SystemExit/KeyboardInterrupt propagate.
    try:
        if _logger_instance:
            _logger_instance.log_info("๐ ๊ฐ์ ์ข๋ฃ ๊ฐ์ง: Ray ํด๋ฌ์คํฐ ์ ๋ฆฌ ์ค...")
    except Exception:
        print("๐ ๊ฐ์ ์ข๋ฃ ๊ฐ์ง: Ray ํด๋ฌ์คํฐ ์ ๋ฆฌ ์ค...")

    # Let the trainer release its own Ray resources first.
    try:
        if _trainer_instance:
            _trainer_instance.cleanup_ray()
    except Exception as e:
        try:
            if _logger_instance:
                _logger_instance.log_error(f"IterativeTrainer ์ ๋ฆฌ ์คํจ: {e}")
        except Exception:
            print(f"IterativeTrainer ์ ๋ฆฌ ์คํจ: {e}")

    # Shut down Ray itself if it was initialised in this process.
    try:
        import ray
        if ray.is_initialized():
            ray.shutdown()
    except Exception as e:
        try:
            if _logger_instance:
                _logger_instance.log_error(f"Ray ์ข๋ฃ ์คํจ: {e}")
        except Exception:
            print(f"Ray ์ข๋ฃ ์คํจ: {e}")

    try:
        if _logger_instance:
            _logger_instance.log_info("โ Ray ์ ๋ฆฌ ์๋ฃ")
    except Exception:
        print("โ Ray ์ ๋ฆฌ ์๋ฃ")
|
|
|
|
|
def signal_handler(signum, frame):
    """Handle SIGINT/SIGTERM: report, tear down Ray, then exit non-zero.

    Args:
        signum: Signal number received.
        frame: Interrupted stack frame (unused).
    """
    # Narrowed from bare `except:` — a logging failure should fall back to
    # print, not swallow SystemExit/KeyboardInterrupt.
    try:
        if _logger_instance:
            _logger_instance.log_info(f"๐ ์๊ทธ๋ {signum} ์์ : ํ๋ก๊ทธ๋จ ์ข๋ฃ ์ค...")
    except Exception:
        print(f"๐ ์๊ทธ๋ {signum} ์์ : ํ๋ก๊ทธ๋จ ์ข๋ฃ ์ค...")

    cleanup_ray()
    sys.exit(1)
|
|
|
|
|
def parse_arguments():
    """Parse the command-line arguments for the TTRLVR + AZR training run.

    Returns:
        argparse.Namespace with all training, checkpointing and
        step5-only-mode options.
    """

    parser = argparse.ArgumentParser(
        description='TTRLVR + AZR ํตํฉ ๋ฐ๋ณต ํ์ต',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
์์:
# MBPP 10๋ฌธ์ ๋ก 30๋ผ์ด๋ ํ์ต
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30

# HumanEval 5๋ฌธ์ ๋ก 10๋ผ์ด๋ ํ์ต
python train_ttrlvr_azr.py --benchmark humaneval --problems 5 --rounds 10

# 15๋ผ์ด๋๋ถํฐ ์ฌ๊ฐ
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --resume 15

# ํน์ GPU ์ฌ์ฉ
python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --gpu 4
"""
    )

    # --- Benchmark / problem selection -------------------------------------
    parser.add_argument(
        '--benchmark',
        choices=['mbpp', 'humaneval'],
        default='mbpp',
        help='๋ฒค์น๋งํฌ ์ ํ (๊ธฐ๋ณธ๊ฐ: mbpp)'
    )

    parser.add_argument(
        '--problems',
        type=int,
        default=10,
        help='๋ฌธ์ ์ (๊ธฐ๋ณธ๊ฐ: 10)'
    )

    # Train on a single named problem instead of the first N.
    parser.add_argument(
        '--problem-id',
        type=str,
        help='ํน์ ๋ฌธ์ ID (์: HumanEval/1, Mbpp/10)'
    )

    # --- Round control ------------------------------------------------------
    parser.add_argument(
        '--rounds',
        type=int,
        default=30,
        help='์ด ๋ผ์ด๋ ์ (๊ธฐ๋ณธ๊ฐ: 30)'
    )

    parser.add_argument(
        '--resume',
        type=int,
        default=1,
        help='์ฌ๊ฐํ ๋ผ์ด๋ ๋ฒํธ (๊ธฐ๋ณธ๊ฐ: 1)'
    )

    # --- Hardware / paths ---------------------------------------------------
    parser.add_argument(
        '--gpu',
        type=str,
        default='5',
        help='์ฌ์ฉํ GPU ๋ฒํธ (๋จ์ผ: 5, ๋ค์ค: 1,2,3,5)'
    )

    parser.add_argument(
        '--output-dir',
        type=str,
        default='./results/ttrlvr_azr',
        help='๊ฒฐ๊ณผ ์ ์ฅ ๋๋ ํ ๋ฆฌ (๊ธฐ๋ณธ๊ฐ: ./results/ttrlvr_azr)'
    )

    # Optional explicit VeRL/PPO config file.
    parser.add_argument(
        '--config',
        type=str,
        help='์ค์ ํ์ผ ๊ฒฝ๋ก (์ ํ์ฌํญ)'
    )

    parser.add_argument(
        '--model',
        type=str,
        default='Qwen/Qwen2.5-7B',
        help='์ฌ์ฉํ ๋ชจ๋ธ (๊ธฐ๋ณธ๊ฐ: Qwen/Qwen2.5-7B)'
    )

    parser.add_argument(
        '--debug',
        action='store_true',
        help='๋๋ฒ๊ทธ ๋ชจ๋ ํ์ฑํ'
    )

    # --- Training knobs -----------------------------------------------------
    parser.add_argument(
        '--batch-size',
        type=int,
        default=24,
        help='ํ์ต ๋ฐฐ์น ํฌ๊ธฐ (๊ธฐ๋ณธ๊ฐ: 24, OOM ์ ์ค์ด๊ธฐ)'
    )

    parser.add_argument(
        '--batch-epochs',
        type=int,
        default=1,
        help='๋ฐฐ์น๋น ์ํญ ์ (๊ธฐ๋ณธ๊ฐ: 1, ๋ ๋ง์ ํ์ต์ ์ํด ์ฆ๊ฐ ๊ฐ๋ฅ)'
    )

    # --- Task-generation knobs ---------------------------------------------
    parser.add_argument(
        '--num-programs',
        type=int,
        default=4,
        help='์์ฑํ ๋ค์ํ ํ๋ก๊ทธ๋จ ์ (๊ธฐ๋ณธ๊ฐ: 4, ๋ ๋ค์ํ ๋ฐ์ดํฐ๋ฅผ ์ํด ์ฆ๊ฐ ๊ฐ๋ฅ)'
    )

    parser.add_argument(
        '--input-generation-rounds',
        type=int,
        default=3,
        help='๋ค์ํ ์๋ ฅ ์์ฑ ๋ผ์ด๋ ์ (๊ธฐ๋ณธ๊ฐ: 3, ๋ผ์ด๋๋น 5๊ฐ์ฉ ์์ฑ)'
    )

    parser.add_argument(
        '--parallel-batch-size',
        type=int,
        default=4,
        help='๋์ ์ฒ๋ฆฌํ ํ๋กฌํํธ ์ (๊ธฐ๋ณธ๊ฐ: 4, GPU ๋ฉ๋ชจ๋ฆฌ์ ๋ฐ๋ผ ์กฐ์ )'
    )

    # --- Evaluation knobs ---------------------------------------------------
    parser.add_argument(
        '--eval-rounds',
        type=int,
        default=5,
        help='๋งค ๋ผ์ด๋ ์ ํ๋ ์ธก์ ํ์ (๊ธฐ๋ณธ๊ฐ: 5, ๋ ์ ํํ ํ๊ฐ๋ฅผ ์ํด ์ฆ๊ฐ ๊ฐ๋ฅ)'
    )

    parser.add_argument(
        '--skip-task-eval',
        action='store_true',
        help='Task evaluation(4๋จ๊ณ) ์คํตํ์ฌ ๋น ๋ฅธ ํ์คํธ (๋ฐ์ดํฐ ์์ฑ ํ ๋ฐ๋ก VeRL ํ์ต)'
    )

    # --- Checkpointing ------------------------------------------------------
    parser.add_argument(
        '--save-every-round',
        action='store_true',
        help='๋งค ๋ผ์ด๋๋ง๋ค ์ฒดํฌํฌ์ธํธ ์ ์ฅ (๊ธฐ๋ณธ๊ฐ: False)'
    )

    parser.add_argument(
        '--save-round-interval',
        type=int,
        default=5,
        help='์ฒดํฌํฌ์ธํธ ์ ์ฅ ๊ฐ๊ฒฉ (์: 5 = 5๋ผ์ด๋๋ง๋ค ์ ์ฅ, ๊ธฐ๋ณธ๊ฐ: 5)'
    )

    # --- Step-5-only mode ---------------------------------------------------
    parser.add_argument(
        '--step5-only',
        action='store_true',
        help='๊ธฐ์กด ๋ฐ์ดํฐ๋ก Step 5 (VeRL ํ์ต)๋ง ์คํ'
    )

    parser.add_argument(
        '--data-path',
        type=str,
        help='Step5 ์ ์ฉ ๋ชจ๋์์ ์ฌ์ฉํ ๊ธฐ์กด azr_training_data ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก'
    )

    return parser.parse_args()
|
|
|
|
|
def setup_environment(gpu_id: str, batch_size: int = None):
    """Prepare process environment variables (mirrors run_ttrlvr_azr_training.sh).

    Args:
        gpu_id: GPU index/indices for CUDA_VISIBLE_DEVICES ('5' or '1,2,3,5');
            falls back to a pre-set CUDA_VISIBLE_DEVICES, then to GPU 5.
        batch_size: When given, exported as TRAIN_BATCH_SIZE together with
            the HuggingFace cache / tokenizer settings.
    """
    # --- GPU selection: CLI value > existing env var > default '5' ----------
    if gpu_id:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
        print(f"๐ฏ Using command line GPU setting: {gpu_id}")
    elif 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES']:
        print(f"๐ฏ Using existing CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '5'
        print(f"๐ฏ Using default GPU: 5")

    # vLLM attention backend.
    os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN'

    # Ray: disable the memory monitor, enable verbose logging.
    os.environ['RAY_memory_monitor_refresh_ms'] = '0'
    os.environ['RAY_LOGGING_LEVEL'] = 'DEBUG'

    # Full Hydra tracebacks on config errors.
    os.environ['HYDRA_FULL_ERROR'] = '1'

    # Make the verl and TTRLVR checkouts importable.
    pythonpath = os.environ.get('PYTHONPATH', '')
    if '/home/ubuntu/RLVR/verl' not in pythonpath:
        os.environ['PYTHONPATH'] = f"{pythonpath}:/home/ubuntu/RLVR/verl:/home/ubuntu/RLVR/TestTime-RLVR-v2"

    # Export the training batch size plus the cache/tokenizer settings used
    # by the shell launcher.
    # NOTE(review): several assignments here repeat the unconditional ones
    # above with identical values — harmless, but worth deduplicating.
    if batch_size is not None:
        os.environ['TRAIN_BATCH_SIZE'] = str(batch_size)
        os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN'
        os.environ['RAY_memory_monitor_refresh_ms'] = '0'
        os.environ['RAY_LOGGING_LEVEL'] = 'DEBUG'
        os.environ['HYDRA_FULL_ERROR'] = '1'
        os.environ['HF_HOME'] = '/data/.cache/huggingface'
        os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/huggingface'
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    # Prepend project paths (if absent) so they win over installed copies.
    current_pythonpath = os.environ.get('PYTHONPATH', '')
    new_paths = [
        '/home/ubuntu/RLVR/TestTime-RLVR-v2',
        '/data/miniforge3/envs/azr/lib/python3.10/site-packages'
    ]

    for path in new_paths:
        if path not in current_pythonpath:
            current_pythonpath = f"{path}:{current_pythonpath}" if current_pythonpath else path

    os.environ['PYTHONPATH'] = current_pythonpath
|
|
|
|
|
def load_benchmark_problems(benchmark_config: BenchmarkConfig) -> List[str]:
    """Load the list of problem task IDs for the configured benchmark.

    Prefers the EvalPlus loaders (MBPP+ / HumanEval+); when the import or
    load fails, falls back to reading `task_id`s from the local JSONL file
    referenced by `benchmark_config.data_path`.

    Args:
        benchmark_config: Benchmark descriptor providing `name` and `data_path`.

    Returns:
        List of task-id strings (empty if nothing could be loaded or the
        benchmark name is unknown).
    """
    problems = []

    if benchmark_config.name == 'mbpp':
        try:
            from evalplus.data.mbpp import get_mbpp_plus
            mbpp_problems = get_mbpp_plus()
            problems = list(mbpp_problems.keys())
            print(f"โ MBPP+ ๋ฐ์ดํฐ ๋ก๋ ์ฑ๊ณต: {len(problems)}๊ฐ ๋ฌธ์ (EvalPlus ํ์ค ๋ฐฉ์)")
        except Exception as e:
            print(f"โ MBPP+ EvalPlus ๋ก๋ฉ ์คํจ, ๊ธฐ์กด ๋ฐฉ์ ์ฌ์ฉ: {e}")
            # Fallback: read task ids line-by-line from the bundled JSONL.
            data_path = benchmark_config.data_path
            if os.path.exists(data_path):
                with open(data_path, 'r') as f:
                    for line in f:
                        problem = json.loads(line.strip())
                        problems.append(problem['task_id'])

    elif benchmark_config.name == 'humaneval':
        try:
            from evalplus.data.humaneval import get_human_eval_plus
            humaneval_problems = get_human_eval_plus()
            problems = list(humaneval_problems.keys())
            print(f"โ HumanEval+ ๋ฐ์ดํฐ ๋ก๋ ์ฑ๊ณต: {len(problems)}๊ฐ ๋ฌธ์ (EvalPlus ํ์ค ๋ฐฉ์)")
        except Exception as e:
            print(f"โ HumanEval+ EvalPlus ๋ก๋ฉ ์คํจ, ๊ธฐ์กด ๋ฐฉ์ ์ฌ์ฉ: {e}")
            # Same JSONL fallback as above.
            data_path = benchmark_config.data_path
            if os.path.exists(data_path):
                with open(data_path, 'r') as f:
                    for line in f:
                        problem = json.loads(line.strip())
                        problems.append(problem['task_id'])

    return problems
|
|
|
|
|
def create_problem_list(benchmark: str, num_problems: int, specific_problem_id: str = None) -> list:
    """Select the problem IDs to train on for the given benchmark.

    Args:
        benchmark: Benchmark name ('mbpp' or 'humaneval').
        num_problems: Number of problems to take from the front of the list;
            values <= 0 or larger than the benchmark mean "use all".
        specific_problem_id: When given, restrict the run to this single ID.

    Returns:
        List of problem-id strings.

    Raises:
        ValueError: If the benchmark yields no problems, or the requested
            specific ID is not part of the benchmark.
    """
    available = load_benchmark_problems(create_benchmark_config(benchmark))

    if not available:
        raise ValueError(f"No problems found for benchmark: {benchmark}")

    # Single-problem mode: validate and return just that ID.
    if specific_problem_id:
        if specific_problem_id not in available:
            raise ValueError(f"Problem ID '{specific_problem_id}' not found in {benchmark} benchmark")
        return [specific_problem_id]

    # Take a prefix when the count is sensible, otherwise everything.
    if 0 < num_problems <= len(available):
        return available[:num_problems]
    return available
|
|
|
|
|
def create_config(args) -> TestTimeConfig:
    """Translate parsed CLI arguments into a TestTimeConfig.

    Args:
        args: Parsed argparse namespace from parse_arguments().

    Returns:
        A populated TestTimeConfig instance.
    """
    config = TestTimeConfig()

    # Collect every setting first, then apply in one pass.
    settings = {
        # Model / generation parameters.
        'model_name': args.model,
        'max_new_tokens': 512,
        'temperature': 0.05,
        'baseline_evaluation_rounds': args.eval_rounds,
        # Task/data generation parameters.
        'num_program_variations': args.num_programs,
        'input_generation_rounds': args.input_generation_rounds,
        'parallel_batch_size': args.parallel_batch_size,
        # Pipeline switches.
        'skip_task_evaluation': args.skip_task_eval,
    }
    if args.debug:
        settings['debug'] = True
        settings['verbose'] = True

    for attr, value in settings.items():
        setattr(config, attr, value)

    return config
|
|
|
|
|
def create_benchmark_config(benchmark: str) -> BenchmarkConfig:
    """Build a BenchmarkConfig with the local EvalPlus data path filled in.

    Args:
        benchmark: Benchmark name ('mbpp' or 'humaneval').

    Returns:
        BenchmarkConfig pointing at the bundled *Plus.jsonl data file.

    Raises:
        ValueError: For an unrecognised benchmark name.
    """
    # Project root: two directories above this script.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Dispatch table: factory + relative data-file path per benchmark.
    factories = {
        'mbpp': (BenchmarkConfig.get_mbpp_config, 'evaluation/code_eval/data/MbppPlus.jsonl'),
        'humaneval': (BenchmarkConfig.get_humaneval_config, 'evaluation/code_eval/data/HumanEvalPlus.jsonl'),
    }
    if benchmark not in factories:
        raise ValueError(f"Unknown benchmark: {benchmark}")

    factory, relative_path = factories[benchmark]
    benchmark_config = factory()
    benchmark_config.data_path = os.path.join(base_dir, relative_path)
    return benchmark_config
|
|
|
|
|
def save_run_config(args, problem_ids: list, output_dir: str):
    """Persist the run configuration as JSON for reproducibility.

    Args:
        args: Parsed CLI namespace (benchmark, problems, rounds, resume, gpu).
        problem_ids: Concrete list of problem IDs selected for this run.
        output_dir: Directory in which run_config.json is written.

    Returns:
        Path of the written configuration file.
    """
    run_record = {
        'timestamp': datetime.now().isoformat(),
        'benchmark': args.benchmark,
        'num_problems': args.problems,
        'total_rounds': args.rounds,
        'resume_from': args.resume,
        'gpu': args.gpu,
        'problem_ids': problem_ids,
        'output_dir': output_dir,
        'command_line': ' '.join(sys.argv),
    }

    target = os.path.join(output_dir, 'run_config.json')
    with open(target, 'w') as fh:
        json.dump(run_record, fh, indent=2)

    return target
|
|
|
|
|
def run_step5_only_mode(args):
    """Run only Step 5 (VeRL training) on previously generated data.

    Validates that ``args.data_path`` contains the three expected AZR task
    parquet files, prepares the environment, builds an IterativeTrainer and
    invokes its VeRL-training-only entry point.

    Args:
        args: Parsed CLI namespace (data_path, gpu, batch_size, config,
            resume and benchmark are read).

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    from pathlib import Path  # local import kept from original (Path is also imported at file top)

    print(f"๐ Running Step 5 (VeRL training) only mode")
    print(f"๐ Data path: {args.data_path}")

    # Validate that the data directory exists at all.
    data_path = Path(args.data_path)
    if not data_path.exists():
        print(f"โ Error: Data path does not exist: {data_path}")
        return 1

    # All three AZR task types must be present.
    required_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
    missing_files = []
    for file_name in required_files:
        if not (data_path / file_name).exists():
            missing_files.append(file_name)

    if missing_files:
        print(f"โ Error: Missing required files: {missing_files}")
        return 1

    print(f"โ Found all required training data files in: {data_path}")

    # Report file sizes as a quick sanity check on the data.
    for file_name in required_files:
        file_path = data_path / file_name
        file_size = file_path.stat().st_size
        print(f"  ๐ {file_name}: {file_size:,} bytes")

    # Same environment preparation as the full pipeline.
    setup_environment(args.gpu, args.batch_size)

    # Pick a VeRL config: explicit --config wins, otherwise choose by GPU count.
    config_path = args.config
    if not config_path:
        gpu_count = len(args.gpu.split(',')) if args.gpu else 1
        if gpu_count >= 4:
            config_path = '/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs/ttrlvr_azr_ppo_4gpu.yaml'
        else:
            config_path = '/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs/ttrlvr_azr_ppo_1gpu.yaml'

    print(f"๐ Initializing trainer with config: {config_path}")

    config = create_config(args)

    logger = TestTimeLogger()

    # Keep a module-level reference so the atexit/signal cleanup hooks can
    # reach the trainer on forced shutdown.
    global _trainer_instance
    _trainer_instance = IterativeTrainer(
        config=config,
        logger=logger,
        verl_config_path=config_path
    )

    try:
        result = _trainer_instance.run_verl_training_only(
            training_data_path=str(data_path),
            round_num=args.resume,
            experiment_name=f"step5_only_{args.benchmark}"
        )

        if result.get('success', False):
            print(f"โ VeRL training completed successfully!")
            print(f"โฑ๏ธ Duration: {result.get('duration', 'N/A')} seconds")
            if 'model_path' in result:
                print(f"๐ค Updated model: {result['model_path']}")
            return 0
        else:
            print(f"โ VeRL training failed: {result.get('error', 'Unknown error')}")
            return 1

    except Exception as e:
        print(f"๐ฅ Training failed with exception: {e}")
        import traceback
        traceback.print_exc()
        return 1
|
|
|
|
|
def main():
    """Entry point: parse args, set up the environment, run the training loop.

    Returns:
        Process exit code: 0 on success, 1 on failure, 130 on Ctrl+C.
    """
    global _trainer_instance, _logger_instance

    # Install cleanup hooks so Ray is torn down on Ctrl+C / kill / normal exit.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    atexit.register(cleanup_ray)

    args = parse_arguments()

    # Step-5-only short-circuit: VeRL training on pre-generated data only.
    if args.step5_only:
        if not args.data_path:
            print("โ Error: --data-path is required when using --step5-only")
            return 1
        return run_step5_only_mode(args)

    # Resolve the problem list early so the output dir name reflects reality.
    temp_problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
    actual_problem_count = len(temp_problem_ids) if temp_problem_ids else args.problems

    # Timestamped output directory named after benchmark / problems / rounds.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    if args.problem_id:
        problem_desc = f"problem_{args.problem_id.replace('/', '_')}"
    else:
        problem_desc = f"{actual_problem_count}problems"

    output_dir = os.path.join(
        args.output_dir,
        f'{args.benchmark}_{problem_desc}_{args.rounds}rounds_{timestamp}'
    )
    os.makedirs(output_dir, exist_ok=True)

    # Environment: GPU selection, Ray/vLLM/Hydra settings, PYTHONPATH.
    setup_environment(args.gpu, args.batch_size)

    # Logger writes to a fixed log directory.
    log_dir = "/home/ubuntu/RLVR/TestTime-RLVR-v2/logs"
    os.makedirs(log_dir, exist_ok=True)
    logger = TestTimeLogger(log_dir=log_dir)
    _logger_instance = logger

    # Report where the log file actually landed (if a file handler exists).
    log_handlers = [h for h in logger.logger.handlers if hasattr(h, 'baseFilename')]
    if log_handlers:
        actual_log_file = log_handlers[0].baseFilename
        print(f"๐ ๋ก๊ทธ ํ์ผ: {actual_log_file}")
    else:
        print(f"๐ ๋ก๊ทธ ๋๋ ํ ๋ฆฌ: {log_dir}")

    config = create_config(args)
    benchmark_config = create_benchmark_config(args.benchmark)

    # NOTE(review): this repeats the create_problem_list call from above —
    # presumably to get a fresh list after setup; confirm it is intentional.
    problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)

    actual_problem_count = len(problem_ids) if problem_ids else args.problems
    actual_gpu = os.environ.get('CUDA_VISIBLE_DEVICES', args.gpu)

    # --- Run banner ---------------------------------------------------------
    logger.log_info("๐ TTRLVR + AZR ํตํฉ ํ์ต ์์")
    logger.log_info("=" * 80)
    logger.log_info(f"๐ ์ค์ :")
    logger.log_info(f"  - ๋ฒค์น๋งํฌ: {args.benchmark}")
    if args.problem_id:
        logger.log_info(f"  - ํน์ ๋ฌธ์ ID: {args.problem_id}")
        logger.log_info(f"  - ๋ฌธ์ ์: {actual_problem_count} (ํน์ ๋ฌธ์ )")
    else:
        logger.log_info(f"  - ๋ฌธ์ ์: {actual_problem_count}")
    logger.log_info(f"  - ์ด ๋ผ์ด๋: {args.rounds}")
    logger.log_info(f"  - GPU: {actual_gpu}")
    logger.log_info(f"  - ์ถ๋ ฅ ๋๋ ํ ๋ฆฌ: {output_dir}")
    if args.resume > 1:
        logger.log_info(f"  - ์ฌ๊ฐ ๋ผ์ด๋: {args.resume}")
    logger.log_info("=" * 80)
    logger.log_info(f"๐ฏ ๋ฌธ์ ๋ฆฌ์คํธ: {problem_ids}")

    try:
        # Persist the exact run configuration for reproducibility.
        config_file = save_run_config(args, problem_ids, output_dir)
        logger.log_info(f"๐ ์คํ ์ค์ ์ ์ฅ: {config_file}")

        # Build the trainer (optionally with an explicit VeRL config file).
        verl_config_path = None
        if args.config:
            verl_config_path = os.path.abspath(args.config)
        trainer = IterativeTrainer(
            config,
            logger,
            batch_epochs=args.batch_epochs,
            verl_config_path=verl_config_path,
            save_every_round=args.save_every_round,
            save_round_interval=args.save_round_interval
        )
        _trainer_instance = trainer

        # Main loop: task generation + AZR training, round by round.
        logger.log_info("๐ ๋ฐ๋ณต ํ์ต ์์")
        training_results = trainer.run_iterative_training(
            benchmark_config=benchmark_config,
            problem_ids=problem_ids,
            total_rounds=args.rounds,
            resume_from_round=args.resume
        )

        # Save the final results; round-trip through json with default=str to
        # coerce non-serializable values.
        results_file = os.path.join(output_dir, 'training_results.json')
        try:
            serializable_results = json.loads(json.dumps(training_results, default=str))
            with open(results_file, 'w') as f:
                json.dump(serializable_results, f, indent=2)
        except Exception as e:
            logger.log_warning(f"๊ฒฐ๊ณผ ์ ์ฅ ์ค ์ค๋ฅ (๋ฌด์๋จ): {e}")

            # Fallback: write a minimal summary instead of the full structure.
            # NOTE(review): assumed to live inside the except branch so it only
            # runs when the full dump fails — confirm original nesting.
            basic_results = {
                'success': training_results.get('success', False),
                'benchmark': args.benchmark,
                'total_rounds': args.rounds,
                'completed_rounds': len(training_results.get('rounds', {})),
                'timestamp': training_results.get('end_time', 'unknown')
            }
            with open(results_file, 'w') as f:
                json.dump(basic_results, f, indent=2)

        logger.log_info(f"๐พ ์ต์ข ๊ฒฐ๊ณผ ์ ์ฅ: {results_file}")

        trainer.cleanup()

        if training_results['success']:
            logger.log_info("๐ TTRLVR + AZR ํตํฉ ํ์ต ์ฑ๊ณต์ ์ผ๋ก ์๋ฃ!")
            return 0
        else:
            logger.log_error(f"โ ํ์ต ์คํจ: {training_results.get('error', 'Unknown error')}")
            return 1

    except KeyboardInterrupt:
        logger.log_info("โ ๏ธ ์ฌ์ฉ์์ ์ํด ์ค๋จ๋จ")
        if 'trainer' in locals():
            trainer.cleanup()
        return 130

    except Exception as e:
        logger.log_error(f"๐ฅ ์์์น ๋ชปํ ์ค๋ฅ: {e}")
        if 'trainer' in locals():
            trainer.cleanup()
        return 1

    finally:
        # Defensive final cleanup: the trainer plus any stray Ray processes.
        if 'trainer' in locals() and hasattr(trainer, 'cleanup'):
            trainer.cleanup()

        import subprocess
        try:
            subprocess.run(['ray', 'stop', '--force'], capture_output=True, timeout=10)
        except:
            pass
|
|
|
|
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit code.
    sys.exit(main())