File size: 3,276 Bytes
f50dc54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
#!/usr/bin/env python3
"""
Step 5-only run script.
Runs only the VeRL PPO training step (Step 5) using existing AZR training data.
"""
import os
import sys
import argparse
from pathlib import Path
# Path setup: make the repository root and its test/ directory importable.
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2')
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2/test')
from test.utils.iterative_trainer import IterativeTrainer
# Parquet files that must all be present in the --data_path directory.
_REQUIRED_FILES = ("induction.parquet", "deduction.parquet", "abduction.parquet")


def _build_parser():
    """Return the command-line parser for this script (same options as before)."""
    parser = argparse.ArgumentParser(
        description='Run VeRL training (Step 5) only with existing data'
    )
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to existing azr_training_data directory')
    parser.add_argument('--round', type=int, default=1,
                        help='Round number for logging (default: 1)')
    parser.add_argument('--experiment_name', type=str, default=None,
                        help='Custom experiment name')
    parser.add_argument(
        '--config',
        type=str,
        default='/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs/ttrlvr_azr_ppo_4gpu.yaml',
        help='VeRL config file path',
    )
    return parser


def _find_missing_files(directory, file_names):
    """Return the names in *file_names* that are not regular files inside *directory*."""
    return [name for name in file_names if not (directory / name).is_file()]


def main(argv=None):
    """Run VeRL PPO training (Step 5 only) on an existing AZR training-data directory.

    Parses the command line and checks that ``--data_path`` is a directory
    containing ``induction.parquet``, ``deduction.parquet`` and
    ``abduction.parquet``. It then builds an ``IterativeTrainer`` from
    ``--config`` and calls ``run_verl_training_only`` on it.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behavior.

    Returns:
        Process exit code: 0 on success, 1 if validation fails, the trainer
        reports failure, or an exception is raised during initialization or
        training.
    """
    import traceback

    args = _build_parser().parse_args(argv)

    # Validate the data directory.
    data_path = Path(args.data_path)
    if not data_path.is_dir():
        print(f"Error: Data path is not an existing directory: {data_path}",
              file=sys.stderr)
        return 1

    # Check that every required training file is present.
    missing_files = _find_missing_files(data_path, _REQUIRED_FILES)
    if missing_files:
        print(f"Error: Missing required files: {missing_files}", file=sys.stderr)
        return 1

    print(f"Found all required training data files in: {data_path}")
    # Print the size of each training file.
    for file_name in _REQUIRED_FILES:
        file_size = (data_path / file_name).stat().st_size
        print(f"  {file_name}: {file_size:,} bytes")

    print(f"Initializing trainer with config: {args.config}")
    try:
        # Trainer construction is inside the try so a bad config or
        # environment is reported the same way as a training failure
        # (exit code 1) instead of escaping as an unhandled exception.
        trainer = IterativeTrainer(config_path=args.config)

        print("Starting VeRL training (Step 5 only)")
        print(f"Data path: {data_path}")
        print(f"Round: {args.round}")
        result = trainer.run_verl_training_only(
            training_data_path=str(data_path),
            round_num=args.round,
            experiment_name=args.experiment_name,
        )
    except Exception as e:
        print(f"Training failed with exception: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    # The result is read as a mapping (.get / `in`). A None/empty return
    # is treated as a failure rather than crashing on `.get`.
    result = result or {}
    if not result.get('success', False):
        print(f"VeRL training failed: {result.get('error', 'Unknown error')}",
              file=sys.stderr)
        return 1

    print("VeRL training completed successfully!")
    print(f"Duration: {result.get('duration', 'N/A')} seconds")
    if 'model_path' in result:
        print(f"Updated model: {result['model_path']}")

    print("Step 5 training completed!")
    return 0
if __name__ == "__main__":
exit_code = main()
sys.exit(exit_code) |