"""
IPO Group-aware Batch Sampler for TTRLVR
๋™์ผํ•œ ipo_group_id๋ฅผ ๊ฐ€์ง„ task๋“ค์„ ๊ฐ™์€ ๋ฐฐ์น˜์— ๋ฌถ๋Š” ์ปค์Šคํ…€ ์ƒ˜ํ”Œ๋Ÿฌ
์ด๋ฅผ ํ†ตํ•ด ๋™์ผํ•œ IPO triple์—์„œ ์ƒ์„ฑ๋œ induction/deduction/abduction task๋“ค์ด
ํ•จ๊ป˜ ํ•™์Šต๋˜๋„๋ก ๋ณด์žฅํ•ฉ๋‹ˆ๋‹ค.
"""
import torch
from torch.utils.data import Sampler, BatchSampler
from typing import Iterator, List, Optional
import random
from collections import defaultdict
import pandas as pd
import numpy as np
class IPOGroupedBatchSampler(Sampler):
"""๋™์ผํ•œ IPO์—์„œ ์ƒ์„ฑ๋œ task๋“ค์„ ๊ฐ™์€ ๋ฐฐ์น˜์— ๋ฌถ๋Š” ์ƒ˜ํ”Œ๋Ÿฌ"""
def __init__(self,
dataset,
batch_size: int,
shuffle: bool = True,
drop_last: bool = False,
seed: int = 42):
"""
Args:
dataset: ipo_group_id๋ฅผ ๊ฐ€์ง„ ๋ฐ์ดํ„ฐ์…‹ (TTRLVRDataset)
batch_size: ๋ฐฐ์น˜ ํฌ๊ธฐ
shuffle: ๊ทธ๋ฃน ์ˆœ์„œ๋ฅผ ์„ž์„์ง€ ์—ฌ๋ถ€
drop_last: ๋งˆ์ง€๋ง‰ ๋ถˆ์™„์ „ํ•œ ๋ฐฐ์น˜๋ฅผ ๋ฒ„๋ฆด์ง€ ์—ฌ๋ถ€
seed: ๋žœ๋ค ์‹œ๋“œ
"""
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.generator = torch.Generator()
self.generator.manual_seed(seed)
        # Group dataset indices by ipo_group_id
self.groups = defaultdict(list)
self._build_groups()
# ๋ฐฐ์น˜ ์ƒ์„ฑ
self._create_batches()
def _build_groups(self):
"""๋ฐ์ดํ„ฐ์…‹์—์„œ ipo_group_id๋ณ„๋กœ ์ธ๋ฑ์Šค๋ฅผ ๊ทธ๋ฃนํ•‘"""
for idx in range(len(self.dataset)):
            # Access the TTRLVRDataset dataframe directly
if hasattr(self.dataset, 'dataframe'):
row = self.dataset.dataframe.iloc[idx]
                ipo_group_id = row.get('ipo_group_id', None)
                # A missing, NaN, or empty ipo_group_id gets its own single-sample group
                if ipo_group_id is None or pd.isna(ipo_group_id) or ipo_group_id == '':
                    ipo_group_id = f'individual_{idx}'
            else:
                # Fallback: no dataframe available, so each sample is its own group
                ipo_group_id = f'individual_{idx}'
self.groups[ipo_group_id].append(idx)
print(f"[IPOGroupedBatchSampler] Built {len(self.groups)} IPO groups from {len(self.dataset)} samples")
# ๊ทธ๋ฃน ํฌ๊ธฐ ํ†ต๊ณ„
group_sizes = [len(indices) for indices in self.groups.values()]
if group_sizes:
print(f" - Group sizes: min={min(group_sizes)}, max={max(group_sizes)}, avg={np.mean(group_sizes):.2f}")
def _create_batches(self):
"""๊ทธ๋ฃน๋ณ„๋กœ ๋ฐฐ์น˜ ์ƒ์„ฑ"""
self.batches = []
# ๋ชจ๋“  ์ธ๋ฑ์Šค๋ฅผ ์ˆ˜์ง‘ (๊ทธ๋ฃน ๋‹จ์œ„๋กœ)
all_indices = []
for group_id, indices in self.groups.items():
# ๊ฐ™์€ IPO ๊ทธ๋ฃน์˜ task๋“ค์„ ํ•จ๊ป˜ ์œ ์ง€
# ์ผ๋ฐ˜์ ์œผ๋กœ 3๊ฐœ (induction, deduction, abduction)
if len(indices) <= self.batch_size:
# ๊ทธ๋ฃน์ด ๋ฐฐ์น˜ ํฌ๊ธฐ๋ณด๋‹ค ์ž‘์œผ๋ฉด ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ
all_indices.extend(indices)
else:
# ๊ทธ๋ฃน์ด ๋ฐฐ์น˜ ํฌ๊ธฐ๋ณด๋‹ค ํฌ๋ฉด ๋ถ„ํ•  (๋“œ๋ฌผ์ง€๋งŒ ๊ฐ€๋Šฅ)
for i in range(0, len(indices), self.batch_size):
chunk = indices[i:i + self.batch_size]
all_indices.extend(chunk)
# ๋ฐฐ์น˜ ์ƒ์„ฑ
current_batch = []
for idx in all_indices:
current_batch.append(idx)
if len(current_batch) == self.batch_size:
self.batches.append(current_batch)
current_batch = []
        # Handle the last, possibly incomplete batch
if current_batch and not self.drop_last:
self.batches.append(current_batch)
elif current_batch and self.drop_last:
print(f"[IPOGroupedBatchSampler] Dropped last incomplete batch of size {len(current_batch)}")
print(f"[IPOGroupedBatchSampler] Created {len(self.batches)} batches")
    def __iter__(self) -> Iterator[List[int]]:
        """Iterate over batches."""
        # Shuffle the batch order
        if self.shuffle:
            indices = torch.randperm(len(self.batches), generator=self.generator).tolist()
            shuffled_batches = [self.batches[i] for i in indices]
        else:
            shuffled_batches = self.batches
        # Yield each batch
        for batch in shuffled_batches:
            # Optionally shuffle within the batch (on a copy, so self.batches stays intact)
            if self.shuffle:
                batch = batch.copy()
                random.shuffle(batch)
            yield batch
def __len__(self) -> int:
"""์ „์ฒด ๋ฐฐ์น˜ ์ˆ˜"""
return len(self.batches)
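
# Usage sketch (hypothetical wiring; `dataset` and `collate_fn` are placeholders):
# because this sampler yields whole lists of indices, it plugs into DataLoader through
# the `batch_sampler` argument rather than `sampler`/`batch_size`:
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(dataset,
#                         batch_sampler=IPOGroupedBatchSampler(dataset, batch_size=24),
#                         collate_fn=collate_fn)
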
class IPOGroupPreservingBatchSampler(BatchSampler):
"""
IPO ๊ทธ๋ฃน์„ ์ตœ๋Œ€ํ•œ ๋ณด์กดํ•˜๋ฉด์„œ ๋ฐฐ์น˜๋ฅผ ์ƒ์„ฑํ•˜๋Š” ์ƒ˜ํ”Œ๋Ÿฌ
์ด ์ƒ˜ํ”Œ๋Ÿฌ๋Š” ๋‹ค์Œ ์šฐ์„ ์ˆœ์œ„๋กœ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค:
1. ๊ฐ™์€ ipo_group_id๋ฅผ ๊ฐ€์ง„ ์ƒ˜ํ”Œ๋“ค์„ ์šฐ์„ ์ ์œผ๋กœ ๊ฐ™์€ ๋ฐฐ์น˜์— ๋ฐฐ์น˜
2. ๋ฐฐ์น˜ ํฌ๊ธฐ๋ฅผ ์ฑ„์šฐ๊ธฐ ์œ„ํ•ด ํ•„์š”์‹œ ๋‹ค๋ฅธ ๊ทธ๋ฃน์˜ ์ƒ˜ํ”Œ ์ถ”๊ฐ€
3. ๋ชจ๋“  ์ƒ˜ํ”Œ์ด ์ •ํ™•ํžˆ ํ•œ ๋ฒˆ์”ฉ ์‚ฌ์šฉ๋˜๋„๋ก ๋ณด์žฅ
"""
def __init__(self,
dataset,
batch_size: int,
shuffle: bool = True,
drop_last: bool = False,
seed: int = 42):
"""
Args:
dataset: TTRLVRDataset ์ธ์Šคํ„ด์Šค
batch_size: ๋ฐฐ์น˜ ํฌ๊ธฐ
shuffle: ๋ฐฐ์น˜ ๋ฐ ๊ทธ๋ฃน ์ˆœ์„œ ์„ž๊ธฐ
drop_last: ๋งˆ์ง€๋ง‰ ๋ถˆ์™„์ „ํ•œ ๋ฐฐ์น˜ ๋ฒ„๋ฆฌ๊ธฐ
seed: ๋žœ๋ค ์‹œ๋“œ
"""
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.seed = seed
        # Build per-group index lists
self.groups = self._build_groups()
def _build_groups(self):
"""ipo_group_id๋ณ„๋กœ ์ƒ˜ํ”Œ ์ธ๋ฑ์Šค ๊ทธ๋ฃนํ•‘"""
groups = defaultdict(list)
for idx in range(len(self.dataset)):
if hasattr(self.dataset, 'dataframe'):
row = self.dataset.dataframe.iloc[idx]
                ipo_group_id = row.get('ipo_group_id', '')
                # A missing, NaN, or empty value falls back to a single-sample group
                if not ipo_group_id or pd.isna(ipo_group_id):
                    ipo_group_id = f'single_{idx}'
            else:
                ipo_group_id = f'single_{idx}'
groups[ipo_group_id].append(idx)
return groups
    def __iter__(self):
        """Build and iterate over batches."""
        # Materialize the groups as a list
        group_list = list(self.groups.items())
        # Use a local RNG so the global `random` state is left untouched
        rng = random.Random(self.seed)
        if self.shuffle:
            rng.shuffle(group_list)
        # Build batches
        current_batch = []
        for group_id, indices in group_list:
            # Shuffle within the group (on a copy, so self.groups stays intact)
            if self.shuffle:
                indices = list(indices)
                rng.shuffle(indices)
            for idx in indices:
                current_batch.append(idx)
                # Yield the batch once it is full
                if len(current_batch) == self.batch_size:
                    yield current_batch
                    current_batch = []
        # Handle the last, possibly incomplete batch
        if current_batch and not self.drop_last:
            yield current_batch
def __len__(self):
"""์ „์ฒด ๋ฐฐ์น˜ ์ˆ˜ ๊ณ„์‚ฐ"""
total_samples = len(self.dataset)
if self.drop_last:
return total_samples // self.batch_size
else:
return (total_samples + self.batch_size - 1) // self.batch_size
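

if __name__ == "__main__":
    # Minimal sanity check with a toy stand-in dataset (hypothetical; the real
    # TTRLVRDataset is assumed to expose a `dataframe` with an `ipo_group_id` column
    # and to implement `__len__`).
    class _ToyDataset:
        def __init__(self, dataframe):
            self.dataframe = dataframe

        def __len__(self):
            return len(self.dataframe)

    # 4 IPO groups x 3 task variants = 12 samples
    toy_df = pd.DataFrame({
        'ipo_group_id': [f'ipo_{i // 3}' for i in range(12)],
        'task_type': ['induction', 'deduction', 'abduction'] * 4,
    })
    toy_dataset = _ToyDataset(toy_df)

    sampler = IPOGroupedBatchSampler(toy_dataset, batch_size=6, shuffle=True)
    for batch in sampler:
        # Each batch of 6 should contain two whole IPO groups
        print(batch, toy_df.loc[batch, 'ipo_group_id'].tolist())

    preserving = IPOGroupPreservingBatchSampler(toy_dataset, batch_size=4, shuffle=True)
    print(f"IPOGroupPreservingBatchSampler yields {len(preserving)} batches")
    for batch in preserving:
        print(batch, toy_df.loc[batch, 'ipo_group_id'].tolist())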