# hjkim00's picture
# Upload TestTime-RLVR-v2 from Full-pipeline-relative_0827 branch
# f50dc54 verified
"""
TTRLVR Dataset for AZR Integration
Reads TTRLVR parquet files and converts them so they can be used for AZR training.
"""
import os
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from .rl_dataset import RLHFDataset
class TTRLVRDataset(RLHFDataset):
    """Dataset that loads TTRLVR parquet data in the form used for AZR training.

    Every item is the dict produced by ``RLHFDataset.__getitem__`` with two
    additions: the ``prompt`` entry is set to plain text taken from the source
    row (when the row has a prompt), and a ``ttrlvr_metadata`` dict is attached.
    """

    # Known task names, in the order they are probed both when collecting
    # per-task parquet files and when matching against a row's uid.
    _TASK_TYPES = ('induction', 'deduction', 'abduction')

    def __init__(self,
                 parquet_files: str,
                 tokenizer: AutoTokenizer,
                 task_type: Optional[str] = None,
                 **kwargs):
        """Resolve the parquet file(s) to load and initialise the base dataset.

        Args:
            parquet_files: Path to a TTRLVR parquet file, or a directory that
                contains ``<task>.parquet`` files. A list of paths (including
                an omegaconf ``ListConfig``) is also accepted and forwarded.
            tokenizer: Tokenizer forwarded to the base class.
            task_type: When given together with a directory, only
                ``<task_type>.parquet`` from that directory is loaded
                (induction/deduction/abduction).
            **kwargs: Forwarded unchanged to ``RLHFDataset.__init__``.
        """
        # Kept as a function-scope import, exactly where the original had it.
        from omegaconf import ListConfig

        # A ListConfig is converted to a plain list before being passed on.
        if isinstance(parquet_files, ListConfig):
            parquet_files = list(parquet_files)

        # A directory is expanded into the concrete parquet file(s) inside it.
        if isinstance(parquet_files, str) and os.path.isdir(parquet_files):
            data_dir = parquet_files
            if task_type:
                parquet_files = os.path.join(data_dir, f"{task_type}.parquet")
            else:
                # Collect every per-task file that actually exists.
                candidates = (
                    os.path.join(data_dir, f"{name}.parquet")
                    for name in self._TASK_TYPES
                )
                parquet_files = [
                    path for path in candidates if os.path.exists(path)
                ]

        super().__init__(
            parquet_files=parquet_files,
            tokenizer=tokenizer,
            prompt_key='prompt',  # TTRLVR rows store the prompt under 'prompt'
            **kwargs
        )
        self.task_type = task_type

    def __getitem__(self, idx):
        """Return one sample with the plain-text prompt and TTRLVR metadata.

        Args:
            idx: Positional index of the sample.

        Returns:
            The dict returned by ``RLHFDataset.__getitem__`` extended with
            ``'prompt'`` (only when the source row has one) and
            ``'ttrlvr_metadata'``.
        """
        # Snapshot the raw row BEFORE delegating to the base class, so the
        # original prompt is still available afterwards.
        frame = getattr(self, 'dataframe', None)
        if frame is not None and hasattr(frame, 'iloc'):
            # pandas DataFrame case
            row = frame.iloc[idx].to_dict()
        else:
            # No positional pandas access available: no raw row to draw from.
            row = {}

        data = super().__getitem__(idx)

        # Re-attach the prompt as a string.
        prompt_text = self._prompt_to_text(row.get('prompt', None))
        if prompt_text is not None:
            data['prompt'] = prompt_text

        data['ttrlvr_metadata'] = {
            'task_type': self._extract_task_type(row),
            'expected_solution': row.get('ground_truth', ''),
            'problem': row.get('problem', {}),
            'ipo_group_id': row.get('ipo_group_id', ''),
            'uid': row.get('uid', ''),
            'evaluation_data': self._prepare_evaluation_data(row)
        }
        return data

    @staticmethod
    def _prompt_to_text(prompt: Any) -> Optional[str]:
        """Convert a raw prompt value (array, list, str, other) to text.

        For a non-empty list/ndarray whose first element is a dict, the
        ``'content'`` of that first element is used (``''`` if absent).
        Strings are returned unchanged, ``None`` stays ``None`` and anything
        else is passed through ``str()``.
        """
        if prompt is None:
            return None
        if isinstance(prompt, str):
            return prompt
        if isinstance(prompt, (np.ndarray, list)):
            if len(prompt) > 0 and isinstance(prompt[0], dict):
                return prompt[0].get('content', '')
        return str(prompt)

    def _extract_task_type(self, row: Dict[str, Any]) -> str:
        """Derive the task type from the row's ``uid``.

        Returns the first of induction/deduction/abduction that occurs as a
        substring of ``uid``, or ``'unknown'`` when none does or when ``uid``
        is missing or not a string.
        """
        uid = row.get('uid', '')
        # A null/NaN/non-string uid cannot be substring-tested.
        if not isinstance(uid, str):
            return 'unknown'
        for name in self._TASK_TYPES:
            if name in uid:
                return name
        return 'unknown'

    def _prepare_evaluation_data(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Build the per-task evaluation data for a row.

        A non-empty dict stored under ``'evaluation_data'`` is returned as is.
        Otherwise the data is reconstructed from the ``'problem'`` field
        according to the task type; an unknown task type yields ``{}``.
        """
        # 1. Prefer evaluation data that was stored with the row.
        stored = row.get('evaluation_data', None)
        # Unwrap single-element numpy containers into native Python objects.
        # Multi-element arrays are left alone: calling .item() on them or
        # testing their truthiness would raise.
        if hasattr(stored, 'item') and getattr(stored, 'size', None) == 1:
            stored = stored.item()
        if isinstance(stored, dict) and stored:
            return stored

        # 2. Fallback (kept for compatibility): rebuild from the problem field.
        problem = row.get('problem', {})
        if not isinstance(problem, dict):
            # Null or otherwise unusable problem entry.
            problem = {}

        task_type = self._extract_task_type(row)
        if task_type == 'induction':
            # Single input/output pair taken from the problem.
            return {
                'input_output_pairs': [
                    (problem.get('input', ''),
                     problem.get('output', ''))
                ]
            }
        if task_type == 'deduction':
            return {
                'function_code': problem.get('snippet', ''),
                'input': problem.get('input', '')
            }
        if task_type == 'abduction':
            return {
                'function_code': problem.get('snippet', ''),
                'expected_output': problem.get('output', '')
            }
        return {}