|
|
|
""" |
|
IPO Grouped Sampling ํ
์คํธ |
|
|
|
IPO ๊ทธ๋ฃน ์ํ๋ง์ด ์ ๋๋ก ์๋ํ๋์ง ํ์ธ |
|
""" |
|
|
|
import sys |
|
import os |
|
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2') |
|
sys.path.append('/home/ubuntu/RLVR/verl') |
|
|
|
import pandas as pd |
|
import numpy as np |
|
from transformers import AutoTokenizer |
|
from absolute_zero_reasoner.utils.dataset.ttrlvr_dataset import TTRLVRDataset |
|
from absolute_zero_reasoner.utils.dataset.ipo_grouped_sampler import IPOGroupedBatchSampler |
|
|
|
|
|
def create_test_data():
    """Create a small parquet fixture: 3 IPO groups x 3 task types (9 rows).

    Each row carries an ``ipo_group_id`` shared by the three task types
    (induction / deduction / abduction) of the same IPO index, which is what
    the grouped sampler batches on.

    Returns:
        str: Path to the written parquet file (/tmp/test_ipo_grouped.parquet).
    """
    data = []

    for ipo_idx in range(3):
        ipo_group_id = f"Mbpp_2_program_var_0_ipo_{ipo_idx}"

        # One record per reasoning task type; all three share the same
        # ipo_group_id so a correctly grouped batch holds one full group.
        for task_type in ['induction', 'deduction', 'abduction']:
            record = {
                'prompt': f"Test prompt for {task_type} task from IPO {ipo_idx}",
                'ground_truth': f"Expected solution for {task_type}",
                'uid': f"Mbpp_2_round_1_{task_type}_{ipo_idx}",
                'ipo_group_id': ipo_group_id,
                'problem': {
                    'input': f"test_input_{ipo_idx}",
                    'output': f"test_output_{ipo_idx}",
                    'snippet': f"def test_func_{ipo_idx}(): pass"
                },
                'basic_accuracy': 0.0
            }
            data.append(record)

    df = pd.DataFrame(data)
    test_file = '/tmp/test_ipo_grouped.parquet'
    df.to_parquet(test_file)

    # NOTE: the original emoji here was mojibake that split the string literal
    # across two lines (a syntax error); restored to a single valid f-string.
    print(f"✅ Created test data with {len(data)} samples in {len(df['ipo_group_id'].unique())} IPO groups")
    print(f"   Saved to: {test_file}")

    return test_file
|
|
|
|
|
def test_ipo_grouped_sampler():
    """Smoke-test IPOGroupedBatchSampler.

    Checks (by printing, not asserting) that with shuffle=False every batch
    contains samples from exactly one IPO group, then that shuffle=True with a
    fixed seed still yields group-homogeneous batches in a shuffled order.

    Returns:
        bool: Always True when the walkthrough completes without raising.
    """
    print("\n🔧 Testing IPO Grouped Sampler")
    print("=" * 60)

    test_file = create_test_data()

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    dataset = TTRLVRDataset(
        parquet_files=test_file,
        tokenizer=tokenizer
    )

    print(f"\n📊 Dataset loaded: {len(dataset)} samples")

    # batch_size matches the number of task types per IPO group (3), so a
    # correctly grouped batch is exactly one full IPO group.
    batch_size = 3
    sampler = IPOGroupedBatchSampler(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False
    )

    print(f"\n🎯 Sampler created with batch_size={batch_size}")
    print(f"   Total batches: {len(sampler)}")

    print("\n📦 Checking batch composition:")
    for batch_idx, batch_indices in enumerate(sampler):
        print(f"\n  Batch {batch_idx + 1}: {len(batch_indices)} samples")

        ipo_groups = []
        for idx in batch_indices:
            row = dataset.dataframe.iloc[idx]
            ipo_group = row['ipo_group_id']
            uid = row['uid']
            ipo_groups.append(ipo_group)
            print(f"    - idx={idx}: {uid} (IPO: {ipo_group})")

        # Every index in the batch must map to the same IPO group.
        unique_groups = set(ipo_groups)
        if len(unique_groups) == 1:
            print(f"  ✅ All samples from same IPO group!")
        else:
            print(f"  ⚠️ Mixed IPO groups: {unique_groups}")

    print("\n\n🔄 Testing with shuffle=True:")
    sampler_shuffled = IPOGroupedBatchSampler(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        seed=42
    )

    # Record which IPO group leads each batch to show the shuffled order.
    batch_order = []
    for batch_idx, batch_indices in enumerate(sampler_shuffled):
        first_idx = batch_indices[0]
        row = dataset.dataframe.iloc[first_idx]
        ipo_group = row['ipo_group_id']
        batch_order.append(ipo_group)
        print(f"  Batch {batch_idx + 1}: IPO group = {ipo_group}")

    # NOTE: the original ✅/⚠️ emoji were mojibake that split these string
    # literals across lines (syntax errors); restored to valid f-strings.
    print("\n✅ IPO Grouped Sampler test completed!")

    return True
|
|
|
|
|
def test_verl_integration():
    """Integration check of VeRL's create_rl_sampler with IPO grouping.

    Builds a data config with ``use_ipo_grouping=True``, constructs the test
    dataset, and prints what sampler type create_rl_sampler returns plus the
    composition of its first three batches.

    Returns:
        bool: Always True when the walkthrough completes without raising.
    """
    print("\n\n🔧 Testing VeRL Integration")
    print("=" * 60)

    # Local import: omegaconf is only needed for this integration path.
    from omegaconf import OmegaConf

    data_config = OmegaConf.create({
        'train_batch_size': 3,
        'shuffle': True,
        'use_ipo_grouping': True,
        'drop_last': False,
        'seed': 42
    })

    # Reuse the fixture if a previous test already wrote it.
    test_file = '/tmp/test_ipo_grouped.parquet'
    if not os.path.exists(test_file):
        test_file = create_test_data()

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    dataset = TTRLVRDataset(
        parquet_files=test_file,
        tokenizer=tokenizer
    )

    from verl.trainer.main_ppo import create_rl_sampler

    sampler = create_rl_sampler(data_config, dataset)

    print(f"Sampler type: {type(sampler).__name__}")

    if hasattr(sampler, '__len__'):
        print(f"Number of batches: {len(sampler)}")

    # Peek at the first three batches; a batch sampler yields index lists,
    # while a plain sampler yields single indices.
    if hasattr(sampler, '__iter__'):
        print("\nFirst 3 batches:")
        for i, batch in enumerate(sampler):
            if i >= 3:
                break
            if isinstance(batch, list):
                print(f"  Batch {i+1}: {len(batch)} samples - indices: {batch}")
            else:
                print(f"  Batch {i+1}: {batch}")

    # NOTE: the original ✅ emoji was mojibake splitting the string literal
    # across two lines (a syntax error); restored to a valid string.
    print("\n✅ VeRL integration test completed!")

    return True
|
|
|
|
|
if __name__ == "__main__": |
|
print("๐ Starting IPO Grouped Sampling Tests") |
|
print("=" * 80) |
|
|
|
|
|
test_ipo_grouped_sampler() |
|
|
|
|
|
test_verl_integration() |
|
|
|
print("\n" + "=" * 80) |
|
print("๐ All tests completed successfully!") |