File size: 3,276 Bytes
f50dc54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python3
"""
Step 5 μ „μš© μ‹€ν–‰ 슀크립트
κΈ°μ‘΄ AZR ν•™μŠ΅ λ°μ΄ν„°λ‘œ VeRL PPO ν•™μŠ΅λ§Œ μ‹€ν–‰
"""

import os
import sys
import argparse
from pathlib import Path

# 경둜 μ„€μ •
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2')
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2/test')

from test.utils.iterative_trainer import IterativeTrainer

def main(argv=None):
    """Run VeRL PPO training (Step 5) only, on existing AZR training data.

    Validates that the data directory holds the three required parquet
    files, prints their sizes, builds an ``IterativeTrainer`` from the given
    config and calls its ``run_verl_training_only`` method.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        int: 0 when the trainer reports success; 1 on a validation failure,
        a reported training failure, or any exception raised while creating
        or running the trainer.

    Note:
        Invalid command-line arguments make argparse terminate the process
        itself (``SystemExit``) before any of the checks below run.
    """
    parser = argparse.ArgumentParser(
        description='Run VeRL training (Step 5) only with existing data')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to existing azr_training_data directory')
    parser.add_argument('--round', type=int, default=1,
                        help='Round number for logging (default: 1)')
    parser.add_argument('--experiment_name', type=str, default=None,
                        help='Custom experiment name')
    parser.add_argument('--config', type=str,
                        default='/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs/ttrlvr_azr_ppo_4gpu.yaml',
                        help='VeRL config file path')

    args = parser.parse_args(argv)

    # Validate the data path.
    data_path = Path(args.data_path)
    if not data_path.exists():
        print(f"❌ Error: Data path does not exist: {data_path}")
        return 1
    if not data_path.is_dir():
        # A regular file would otherwise be reported as "missing files",
        # which hides the real mistake.
        print(f"❌ Error: Data path is not a directory: {data_path}")
        return 1

    # Check that every required file is present.
    required_files = ('induction.parquet', 'deduction.parquet', 'abduction.parquet')
    missing_files = [
        file_name for file_name in required_files
        if not (data_path / file_name).exists()
    ]
    if missing_files:
        print(f"❌ Error: Missing required files: {missing_files}")
        return 1

    print(f"✅ Found all required training data files in: {data_path}")

    # Print file size information.
    for file_name in required_files:
        file_size = (data_path / file_name).stat().st_size
        print(f"  📄 {file_name}: {file_size:,} bytes")

    try:
        # Trainer construction lives inside the try block so that a failure
        # while initializing goes through the same reporting path (message,
        # traceback, exit status 1) as a failure during training.
        print(f"🚀 Initializing trainer with config: {args.config}")
        trainer = IterativeTrainer(config_path=args.config)

        # Run the Step 5-only VeRL training.
        print("🎓 Starting VeRL training (Step 5 only)")
        print(f"📂 Data path: {data_path}")
        print(f"🔄 Round: {args.round}")

        result = trainer.run_verl_training_only(
            training_data_path=str(data_path),
            round_num=args.round,
            experiment_name=args.experiment_name
        )

        if not result.get('success', False):
            print(f"❌ VeRL training failed: {result.get('error', 'Unknown error')}")
            return 1

        print("✅ VeRL training completed successfully!")
        print(f"⏱️  Duration: {result.get('duration', 'N/A')} seconds")
        if 'model_path' in result:
            print(f"🤖 Updated model: {result['model_path']}")

    except Exception as e:
        # Top-level boundary of a CLI script: report, show the traceback,
        # and convert the failure into a non-zero exit status.
        print(f"💥 Training failed with exception: {e}")
        import traceback
        traceback.print_exc()
        return 1

    print("🎉 Step 5 training completed!")
    return 0

if __name__ == "__main__":
    # Hand main()'s integer status straight to the interpreter as the
    # process exit code.
    sys.exit(main())