#!/usr/bin/env python3 """ IPO Grouped Sampling 테스트 IPO 그룹 샘플링이 제대로 작동하는지 확인 """ import sys import os sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2') sys.path.append('/home/ubuntu/RLVR/verl') import pandas as pd import numpy as np from transformers import AutoTokenizer from absolute_zero_reasoner.utils.dataset.ttrlvr_dataset import TTRLVRDataset from absolute_zero_reasoner.utils.dataset.ipo_grouped_sampler import IPOGroupedBatchSampler def create_test_data(): """테스트용 parquet 파일 생성""" # 3개의 IPO 그룹, 각각 3개의 task (induction, deduction, abduction) data = [] for ipo_idx in range(3): ipo_group_id = f"Mbpp_2_program_var_0_ipo_{ipo_idx}" # 각 IPO 그룹에 3개의 task for task_type in ['induction', 'deduction', 'abduction']: record = { 'prompt': f"Test prompt for {task_type} task from IPO {ipo_idx}", 'ground_truth': f"Expected solution for {task_type}", 'uid': f"Mbpp_2_round_1_{task_type}_{ipo_idx}", 'ipo_group_id': ipo_group_id, 'problem': { 'input': f"test_input_{ipo_idx}", 'output': f"test_output_{ipo_idx}", 'snippet': f"def test_func_{ipo_idx}(): pass" }, 'basic_accuracy': 0.0 } data.append(record) # DataFrame 생성 및 저장 df = pd.DataFrame(data) test_file = '/tmp/test_ipo_grouped.parquet' df.to_parquet(test_file) print(f"✅ Created test data with {len(data)} samples in {len(df['ipo_group_id'].unique())} IPO groups") print(f" Saved to: {test_file}") return test_file def test_ipo_grouped_sampler(): """IPO 그룹 샘플러 테스트""" print("\n🔧 Testing IPO Grouped Sampler") print("=" * 60) # 1. 테스트 데이터 생성 test_file = create_test_data() # 2. 토크나이저 로드 tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # 3. 데이터셋 생성 dataset = TTRLVRDataset( parquet_files=test_file, tokenizer=tokenizer ) print(f"\n📊 Dataset loaded: {len(dataset)} samples") # 4. IPO 그룹 샘플러 생성 batch_size = 3 # 한 IPO 그룹의 3개 task가 한 배치에 들어가도록 sampler = IPOGroupedBatchSampler( dataset=dataset, batch_size=batch_size, shuffle=False, # 테스트를 위해 셔플 비활성화 drop_last=False ) print(f"\n🎯 Sampler created with batch_size={batch_size}") print(f" Total batches: {len(sampler)}") # 5. 배치 확인 print("\n📦 Checking batch composition:") for batch_idx, batch_indices in enumerate(sampler): print(f"\n Batch {batch_idx + 1}: {len(batch_indices)} samples") # 배치 내 IPO 그룹 확인 ipo_groups = [] for idx in batch_indices: row = dataset.dataframe.iloc[idx] ipo_group = row['ipo_group_id'] uid = row['uid'] ipo_groups.append(ipo_group) print(f" - idx={idx}: {uid} (IPO: {ipo_group})") # 같은 IPO 그룹인지 확인 unique_groups = set(ipo_groups) if len(unique_groups) == 1: print(f" ✅ All samples from same IPO group!") else: print(f" ⚠️ Mixed IPO groups: {unique_groups}") # 6. 셔플 테스트 print("\n\n🔀 Testing with shuffle=True:") sampler_shuffled = IPOGroupedBatchSampler( dataset=dataset, batch_size=batch_size, shuffle=True, seed=42 ) batch_order = [] for batch_idx, batch_indices in enumerate(sampler_shuffled): first_idx = batch_indices[0] row = dataset.dataframe.iloc[first_idx] ipo_group = row['ipo_group_id'] batch_order.append(ipo_group) print(f" Batch {batch_idx + 1}: IPO group = {ipo_group}") print("\n✅ IPO Grouped Sampler test completed!") return True def test_verl_integration(): """VeRL의 create_rl_sampler와 통합 테스트""" print("\n\n🔧 Testing VeRL Integration") print("=" * 60) # 테스트 설정 from omegaconf import OmegaConf data_config = OmegaConf.create({ 'train_batch_size': 3, 'shuffle': True, 'use_ipo_grouping': True, # IPO 그룹 샘플링 활성화 'drop_last': False, 'seed': 42 }) # 테스트 데이터 test_file = '/tmp/test_ipo_grouped.parquet' if not os.path.exists(test_file): test_file = create_test_data() tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") dataset = TTRLVRDataset( parquet_files=test_file, tokenizer=tokenizer ) # create_rl_sampler 호출 from verl.trainer.main_ppo import create_rl_sampler sampler = create_rl_sampler(data_config, dataset) # 샘플러 타입 확인 print(f"Sampler type: {type(sampler).__name__}") if hasattr(sampler, '__len__'): print(f"Number of batches: {len(sampler)}") # 배치 확인 (BatchSampler인 경우) if hasattr(sampler, '__iter__'): print("\nFirst 3 batches:") for i, batch in enumerate(sampler): if i >= 3: break if isinstance(batch, list): print(f" Batch {i+1}: {len(batch)} samples - indices: {batch}") else: print(f" Batch {i+1}: {batch}") print("\n✅ VeRL integration test completed!") return True if __name__ == "__main__": print("🚀 Starting IPO Grouped Sampling Tests") print("=" * 80) # 기본 샘플러 테스트 test_ipo_grouped_sampler() # VeRL 통합 테스트 test_verl_integration() print("\n" + "=" * 80) print("🎉 All tests completed successfully!")