|
""" |
|
TTRLVR Dataset for AZR Integration |
|
|
|
TTRLVR parquet ํ์ผ์ ์ฝ์ด AZR ํ์ต์ ์ฌ์ฉํ ์ ์๋๋ก ๋ณํ |
|
""" |
|
|
|
import os
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from .rl_dataset import RLHFDataset


class TTRLVRDataset(RLHFDataset):
    """Dataset that loads TTRLVR data in the AZR format."""
|
|
|
    def __init__(self,
                 parquet_files: str,
                 tokenizer: AutoTokenizer,
                 task_type: Optional[str] = None,
                 **kwargs):
        """
        Args:
            parquet_files: Path to a TTRLVR parquet file (or a directory of parquet files)
            tokenizer: Tokenizer
            task_type: Load only this task type (induction/deduction/abduction); None loads all
        """
        # Config files may pass an OmegaConf ListConfig; normalize it to a plain list.
        from omegaconf import ListConfig
        if isinstance(parquet_files, ListConfig):
            parquet_files = list(parquet_files)

        # A directory is resolved to a single task-specific file, or to every
        # task file that exists in it when no task_type is given.
        if isinstance(parquet_files, str) and os.path.isdir(parquet_files):
            if task_type:
                parquet_files = os.path.join(parquet_files, f"{task_type}.parquet")
            else:
                files = []
                for t in ['induction', 'deduction', 'abduction']:
                    f = os.path.join(parquet_files, f"{t}.parquet")
                    if os.path.exists(f):
                        files.append(f)
                parquet_files = files

        super().__init__(
            parquet_files=parquet_files,
            tokenizer=tokenizer,
            prompt_key='prompt',
            **kwargs
        )

        self.task_type = task_type

    def __getitem__(self, idx):
        """Return a single sample."""
        # Keep the raw row so the original prompt and TTRLVR metadata survive
        # the parent class's preprocessing.
        if hasattr(self, 'dataframe') and hasattr(self.dataframe, 'iloc'):
            original_row = self.dataframe.iloc[idx].to_dict()
            original_prompt = original_row.get('prompt', None)
        else:
            original_row = {}
            original_prompt = None

        data = super().__getitem__(idx)

        row = original_row

        # Recover the prompt text regardless of how it was serialized in parquet
        # (numpy array or list of chat messages, or a plain string).
        prompt_text = None
        if original_prompt is not None:
            if isinstance(original_prompt, np.ndarray):
                if len(original_prompt) > 0 and isinstance(original_prompt[0], dict):
                    prompt_text = original_prompt[0].get('content', '')
                else:
                    prompt_text = str(original_prompt)
            elif isinstance(original_prompt, list):
                if len(original_prompt) > 0 and isinstance(original_prompt[0], dict):
                    prompt_text = original_prompt[0].get('content', '')
                else:
                    prompt_text = str(original_prompt)
            elif isinstance(original_prompt, str):
                prompt_text = original_prompt
            else:
                prompt_text = str(original_prompt)

        if prompt_text is not None:
            data['prompt'] = prompt_text

        # Attach TTRLVR-specific metadata alongside the tokenized sample.
        ttrlvr_metadata = {
            'task_type': self._extract_task_type(row),
            'expected_solution': row.get('ground_truth', ''),
            'problem': row.get('problem', {}),
            'ipo_group_id': row.get('ipo_group_id', ''),
            'uid': row.get('uid', ''),
            'evaluation_data': self._prepare_evaluation_data(row)
        }

        data['ttrlvr_metadata'] = ttrlvr_metadata

        return data

    def _extract_task_type(self, row: pd.Series) -> str:
        """Extract the task type from a row's uid."""
        uid = row.get('uid', '')
        if 'induction' in uid:
            return 'induction'
        elif 'deduction' in uid:
            return 'deduction'
        elif 'abduction' in uid:
            return 'abduction'
        return 'unknown'

    def _prepare_evaluation_data(self, row: pd.Series) -> Dict[str, Any]:
        """Prepare evaluation data for each task type."""
        # Prefer precomputed evaluation data when the row already carries it.
        if 'evaluation_data' in row and row['evaluation_data']:
            eval_data = row['evaluation_data']
            if hasattr(eval_data, 'item'):
                eval_data = eval_data.item()
            return eval_data if isinstance(eval_data, dict) else {}

        # Otherwise rebuild it from the problem fields, keyed by task type.
        task_type = self._extract_task_type(row)
        problem = row.get('problem', {})

        if task_type == 'induction':
            return {
                'input_output_pairs': [
                    (problem.get('input', ''),
                     problem.get('output', ''))
                ]
            }
        elif task_type == 'deduction':
            return {
                'function_code': problem.get('snippet', ''),
                'input': problem.get('input', '')
            }
        elif task_type == 'abduction':
            return {
                'function_code': problem.get('snippet', ''),
                'expected_output': problem.get('output', '')
            }

        return {}
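

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The model name and data directory
# below are assumptions, and the parent RLHFDataset may require additional
# keyword arguments (e.g. a max prompt length) depending on its signature.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed checkpoint
    dataset = TTRLVRDataset(
        parquet_files="/data/ttrlvr",  # assumed directory holding {induction,deduction,abduction}.parquet
        tokenizer=tokenizer,
        task_type="deduction",         # or None to load every task file found in the directory
    )
    sample = dataset[0]
    print(sample["ttrlvr_metadata"]["task_type"])
    print(sample["ttrlvr_metadata"]["expected_solution"])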