"""
IPO Group-aware Batch Sampler for TTRLVR

Custom samplers that place tasks sharing the same ipo_group_id into the
same batch, ensuring that induction/deduction/abduction tasks generated
from the same IPO triple are trained together.
"""
|
|
|
import torch |
|
from torch.utils.data import Sampler, BatchSampler |
|
from typing import Iterator, List, Optional |
|
import random |
|
from collections import defaultdict |
|
import pandas as pd |
|
import numpy as np |
|
|
|
|
|
class IPOGroupedBatchSampler(Sampler):
    """Batch sampler that keeps tasks generated from the same IPO triple together.

    Dataset indices sharing an ``ipo_group_id`` are laid out adjacently, so
    sequential chunking into batches keeps a group's tasks (induction /
    deduction / abduction) in one batch whenever the group fits.
    """

    def __init__(self,
                 dataset,
                 batch_size: int,
                 shuffle: bool = True,
                 drop_last: bool = False,
                 seed: int = 42):
        """
        Args:
            dataset: dataset carrying an ``ipo_group_id`` column
                (TTRLVRDataset); falls back to per-sample groups when it has
                no ``dataframe`` attribute.
            batch_size: number of samples per batch.
            shuffle: whether to shuffle batch order and intra-batch order.
            drop_last: whether to drop the final incomplete batch.
            seed: seed for the torch generator driving all shuffling.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

        # ipo_group_id -> list of dataset indices belonging to that group.
        self.groups = defaultdict(list)
        self._build_groups()

        # Pre-compute batches once; __iter__ only permutes them.
        self._create_batches()

    def _build_groups(self):
        """Group dataset indices by ipo_group_id."""
        for idx in range(len(self.dataset)):
            if hasattr(self.dataset, 'dataframe'):
                row = self.dataset.dataframe.iloc[idx]
                ipo_group_id = row.get('ipo_group_id', None)
                # Missing / NaN / empty ids become singleton groups.
                # NOTE: the previous truthiness check (`if not ipo_group_id`)
                # let float('nan') through, merging every unlabeled row into
                # a single "nan" group; pd.isna handles None and NaN alike.
                if pd.isna(ipo_group_id) or ipo_group_id == '':
                    ipo_group_id = f'individual_{idx}'
            else:
                # No dataframe available: every sample is its own group.
                ipo_group_id = f'individual_{idx}'

            self.groups[ipo_group_id].append(idx)

        print(f"[IPOGroupedBatchSampler] Built {len(self.groups)} IPO groups from {len(self.dataset)} samples")

        # Log group-size statistics for debugging.
        group_sizes = [len(indices) for indices in self.groups.values()]
        if group_sizes:
            print(f"  - Group sizes: min={min(group_sizes)}, max={max(group_sizes)}, avg={np.mean(group_sizes):.2f}")

    def _create_batches(self):
        """Build the batch list, keeping each group's indices adjacent."""
        self.batches = []

        # Flatten groups in insertion order.  Sequential chunking below then
        # keeps group members in the same batch whenever the group fits, and
        # naturally splits oversized groups at batch boundaries.  (The old
        # code special-cased groups larger than batch_size, but both branches
        # produced the exact same flat order.)
        all_indices = []
        for group_id, indices in self.groups.items():
            all_indices.extend(indices)

        current_batch = []
        for idx in all_indices:
            current_batch.append(idx)
            if len(current_batch) == self.batch_size:
                self.batches.append(current_batch)
                current_batch = []

        # Handle the trailing partial batch according to drop_last.
        if current_batch and not self.drop_last:
            self.batches.append(current_batch)
        elif current_batch and self.drop_last:
            print(f"[IPOGroupedBatchSampler] Dropped last incomplete batch of size {len(current_batch)}")

        print(f"[IPOGroupedBatchSampler] Created {len(self.batches)} batches")

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches (lists of dataset indices), optionally shuffled."""
        if self.shuffle:
            order = torch.randperm(len(self.batches), generator=self.generator).tolist()
        else:
            order = range(len(self.batches))

        for i in order:
            # Copy so intra-batch shuffling never mutates the cached batches.
            batch = list(self.batches[i])
            if self.shuffle:
                # Use the seeded generator (the old code used the unseeded
                # global `random`, making epochs non-reproducible).
                perm = torch.randperm(len(batch), generator=self.generator).tolist()
                batch = [batch[j] for j in perm]
            yield batch

    def __len__(self) -> int:
        """Total number of batches."""
        return len(self.batches)
|
|
|
|
|
class IPOGroupPreservingBatchSampler(BatchSampler):
    """Batch sampler that preserves IPO groups as much as possible.

    Works in the following priority order:
    1. Samples with the same ipo_group_id are preferentially placed in the
       same batch.
    2. Batches are topped up with samples from other groups as needed.
    3. Every sample is used exactly once per epoch.
    """

    def __init__(self,
                 dataset,
                 batch_size: int,
                 shuffle: bool = True,
                 drop_last: bool = False,
                 seed: int = 42):
        """
        Args:
            dataset: TTRLVRDataset instance (or any sized dataset; samples
                become singleton groups when no ``dataframe`` is present).
            batch_size: number of samples per batch.
            shuffle: whether to shuffle group order and intra-group order.
            drop_last: whether to drop the final incomplete batch.
            seed: random seed (re-applied each epoch, so every epoch yields
                the same order — preserved from the original behavior).
        """
        # NOTE(review): BatchSampler.__init__ is deliberately not called —
        # all of its methods are overridden here, so its (sampler,
        # batch_size, drop_last) arguments are never needed.
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed

        # ipo_group_id -> list of dataset indices.
        self.groups = self._build_groups()

    def _build_groups(self):
        """Group sample indices by ipo_group_id."""
        groups = defaultdict(list)

        for idx in range(len(self.dataset)):
            if hasattr(self.dataset, 'dataframe'):
                row = self.dataset.dataframe.iloc[idx]
                ipo_group_id = row.get('ipo_group_id', '')
                # Missing / NaN / empty ids become singleton groups.
                # NOTE: a bare truthiness test let float('nan') through,
                # merging all unlabeled rows into one "nan" group.
                if pd.isna(ipo_group_id) or ipo_group_id == '':
                    ipo_group_id = f'single_{idx}'
            else:
                ipo_group_id = f'single_{idx}'

            groups[ipo_group_id].append(idx)

        return groups

    def __iter__(self):
        """Generate and yield batches, one epoch per call."""
        # Use a private Random instance: the old code called
        # random.seed(self.seed), silently reseeding the GLOBAL RNG every
        # epoch and perturbing any other code relying on it.
        rng = random.Random(self.seed)

        group_list = list(self.groups.items())
        if self.shuffle:
            rng.shuffle(group_list)

        current_batch = []
        for group_id, indices in group_list:
            # Copy before shuffling so self.groups is never mutated.
            indices = list(indices)
            if self.shuffle:
                rng.shuffle(indices)

            for idx in indices:
                current_batch.append(idx)
                # Emit a batch as soon as it is full; the next group's
                # samples top up whatever partial batch remains.
                if len(current_batch) == self.batch_size:
                    yield current_batch
                    current_batch = []

        # Trailing partial batch is kept unless drop_last is set.
        if current_batch and not self.drop_last:
            yield current_batch

    def __len__(self):
        """Total number of batches per epoch."""
        total_samples = len(self.dataset)

        if self.drop_last:
            return total_samples // self.batch_size
        else:
            # Ceiling division.
            return (total_samples + self.batch_size - 1) // self.batch_size