"""
TTRLVR Dataset for AZR Integration
TTRLVR parquet ํ์ผ์ ์ฝ์ด AZR ํ์ต์ ์ฌ์ฉํ ์ ์๋๋ก ๋ณํ
"""
import os
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from .rl_dataset import RLHFDataset
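
# Expected parquet schema, as read by this loader (derived from the field
# accesses below): each row provides a 'prompt' (plain string or a list of
# chat-message dicts), a 'uid' that encodes the task type, a 'ground_truth'
# solution, a 'problem' dict with 'input'/'output'/'snippet' fields, an
# 'ipo_group_id', and optionally a precomputed 'evaluation_data' dict.
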
class TTRLVRDataset(RLHFDataset):
    """Dataset that loads TTRLVR data in the AZR format."""

    def __init__(self,
                 parquet_files: str,
                 tokenizer: AutoTokenizer,
                 task_type: Optional[str] = None,
                 **kwargs):
        """
        Args:
            parquet_files: Path to a TTRLVR parquet file (or a directory of parquet files)
            tokenizer: Tokenizer used to encode prompts
            task_type: Load only this task type (induction/deduction/abduction)
        """
        # Convert parquet_files to a plain list if it is an OmegaConf ListConfig
        from omegaconf import ListConfig
        if isinstance(parquet_files, ListConfig):
            parquet_files = list(parquet_files)

        # Handle the case where parquet_files is a directory
        if isinstance(parquet_files, str) and os.path.isdir(parquet_files):
            if task_type:
                parquet_files = os.path.join(parquet_files, f"{task_type}.parquet")
            else:
                # Collect the parquet files for every task type
                files = []
                for t in ['induction', 'deduction', 'abduction']:
                    f = os.path.join(parquet_files, f"{t}.parquet")
                    if os.path.exists(f):
                        files.append(f)
                parquet_files = files

        super().__init__(
            parquet_files=parquet_files,
            tokenizer=tokenizer,
            prompt_key='prompt',  # TTRLVR stores prompts under the 'prompt' key
            **kwargs
        )

        self.task_type = task_type

    def __getitem__(self, idx):
        """Return a single sample."""
        # TTRLVR-specific handling: read the original row first,
        # before RLHFDataset.__getitem__ pops the prompt field.
        if hasattr(self, 'dataframe') and hasattr(self.dataframe, 'iloc'):
            # The underlying storage is a pandas DataFrame
            original_row = self.dataframe.iloc[idx].to_dict()
            # Back up the prompt field
            original_prompt = original_row.get('prompt', None)
        else:
            original_row = {}
            original_prompt = None

        # Fetch the base sample (calls RLHFDataset.__getitem__)
        data = super().__getitem__(idx)

        # Work from the backed-up row from here on
        row = original_row

        # Normalize the prompt, which may arrive as a numpy array,
        # a list of chat-message dicts, or a plain string
        prompt_text = None
        if original_prompt is not None:
            if isinstance(original_prompt, np.ndarray):
                if len(original_prompt) > 0 and isinstance(original_prompt[0], dict):
                    prompt_text = original_prompt[0].get('content', '')
                else:
                    prompt_text = str(original_prompt)
            elif isinstance(original_prompt, list):
                if len(original_prompt) > 0 and isinstance(original_prompt[0], dict):
                    prompt_text = original_prompt[0].get('content', '')
                else:
                    prompt_text = str(original_prompt)
            elif isinstance(original_prompt, str):
                prompt_text = original_prompt
            else:
                prompt_text = str(original_prompt)

        # Attach the prompt to the sample as a string
        if prompt_text is not None:
            data['prompt'] = prompt_text

        # Build the TTRLVR metadata
        ttrlvr_metadata = {
            'task_type': self._extract_task_type(row),
            'expected_solution': row.get('ground_truth', ''),
            'problem': row.get('problem', {}),
            'ipo_group_id': row.get('ipo_group_id', ''),
            'uid': row.get('uid', ''),
            'evaluation_data': self._prepare_evaluation_data(row)
        }

        # Attach the metadata to the sample
        data['ttrlvr_metadata'] = ttrlvr_metadata

        return data

    def _extract_task_type(self, row: pd.Series) -> str:
        """Extract the task type from a row's uid."""
        uid = row.get('uid', '')
        if 'induction' in uid:
            return 'induction'
        elif 'deduction' in uid:
            return 'deduction'
        elif 'abduction' in uid:
            return 'abduction'
        return 'unknown'

    def _prepare_evaluation_data(self, row: pd.Series) -> Dict[str, Any]:
        """Prepare the evaluation data for each task type."""
        # 1. Prefer evaluation_data already stored on the row (produced in Phases 1-4)
        if 'evaluation_data' in row and row['evaluation_data']:
            eval_data = row['evaluation_data']
            # Convert pandas/numpy objects to native Python types
            if hasattr(eval_data, 'item'):
                eval_data = eval_data.item()
            return eval_data if isinstance(eval_data, dict) else {}

        # 2. Fallback: build it from the problem field (kept for compatibility)
        task_type = self._extract_task_type(row)
        problem = row.get('problem', {})

        if task_type == 'induction':
            # Extract the input/output pair from the IPO triple
            return {
                'input_output_pairs': [
                    (problem.get('input', ''),
                     problem.get('output', ''))
                ]
            }
        elif task_type == 'deduction':
            return {
                'function_code': problem.get('snippet', ''),
                'input': problem.get('input', '')
            }
        elif task_type == 'abduction':
            return {
                'function_code': problem.get('snippet', ''),
                'expected_output': problem.get('output', '')
            }

        return {}
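

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the tokenizer
    # name and parquet directory below are placeholder assumptions, and any
    # extra kwargs required by RLHFDataset (e.g. sequence-length settings)
    # are omitted and may need to be supplied in a real setup.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")  # placeholder tokenizer
    dataset = TTRLVRDataset(
        parquet_files="data/ttrlvr",  # directory holding {induction,deduction,abduction}.parquet
        tokenizer=tokenizer,
        task_type=None,               # None loads all three task types from the directory
    )
    sample = dataset[0]
    print(sample['ttrlvr_metadata']['task_type'], sample['ttrlvr_metadata']['uid'])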