from omegaconf import ListConfig
import os
from typing import List, Union
import copy
import pandas as pd

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer

from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F


def collate_fn(data_list: list[dict]) -> dict:
    """Collate per-sample dicts into a single batch dict.

    Tensor values are stacked along a new batch dimension; all other values
    are gathered into object ndarrays so ragged entries (strings, lists,
    dicts) survive batching.
    """
    tensors = {}
    non_tensors = {}

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                if key not in tensors:
                    tensors[key] = []
                tensors[key].append(val)
            else:
                if key not in non_tensors:
                    non_tensors[key] = []
                non_tensors[key].append(val)

    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    for key, val in non_tensors.items():
        non_tensors[key] = np.array(val, dtype=object)

    output = {}
    output.update(tensors)
    output.update(non_tensors)
    return output
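
# A quick illustration of collate_fn's contract (the field names here are
# made up for the example, not required by the function):
#
#   batch = collate_fn([
#       {'input_ids': torch.zeros(4, dtype=torch.long), 'data_source': 'gsm8k'},
#       {'input_ids': torch.ones(4, dtype=torch.long), 'data_source': 'math'},
#   ])
#   batch['input_ids'].shape  # torch.Size([2, 4]): tensors are stacked
#   batch['data_source']      # array(['gsm8k', 'math'], dtype=object)
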
class RLHFDataset(Dataset):
    """Map-style dataset for RLHF training.

    We assume each parquet file contains a prompt column holding a chat in
    the HuggingFace messages format; all other columns are passed through to
    the trainer unchanged.
    """

    def __init__(self,
                 parquet_files: Union[str, List[str]],
                 tokenizer: PreTrainedTokenizer,
                 prompt_key='prompt',
                 max_prompt_length=1024,
                 filter_prompts=True,
                 cache_dir='~/.cache/verl/rlhf',
                 chat_template_func=None,
                 return_raw_chat=False,
                 truncation='error',
                 extra_source_key=None):
        if not isinstance(parquet_files, (List, ListConfig)):
            parquet_files = [parquet_files]

        # Expand any directory entries into the parquet files they contain,
        # in a deterministic (sorted) order.
        filtered_files = []
        for file_path in parquet_files:
            if os.path.isdir(file_path):
                parquet_files_in_dir = [
                    os.path.join(file_path, f)
                    for f in os.listdir(file_path)
                    if f.endswith('.parquet')
                ]
                filtered_files.extend(sorted(parquet_files_in_dir))
            else:
                filtered_files.append(file_path)

        self.parquet_files = copy.deepcopy(filtered_files)
        self.original_parquet_files = copy.deepcopy(filtered_files)
        self.cache_dir = os.path.expanduser(cache_dir)
        self.tokenizer = tokenizer
        self.extra_source_key = extra_source_key

        self.prompt_key = prompt_key
        self.max_prompt_length = max_prompt_length
        self.filter_prompts = filter_prompts

        self.return_raw_chat = return_raw_chat
        self.chat_template_func = chat_template_func
        self.truncation = truncation

        # Whether __getstate__ must pickle the materialized dataframe instead
        # of re-reading the parquet files on resume.
        self.serialize_dataset = False
        self._download()
        self._read_files_and_tokenize()

    def _download(self, use_origin_parquet=False):
        parquet_files = self.parquet_files if not use_origin_parquet else self.original_parquet_files
        for i, parquet_file in enumerate(parquet_files):
            # Remote (e.g. HDFS) paths are copied into the local cache;
            # local paths are returned unchanged.
            self.parquet_files[i] = copy_local_path_from_hdfs(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self):
        dataframes = []
        for parquet_file in self.parquet_files:
            dataframe = pd.read_parquet(parquet_file)
            dataframes.append(dataframe)
        self.dataframe = pd.concat(dataframes)

        source_suffix = f'. Source: {self.extra_source_key}' if self.extra_source_key else ''
        print(f'original dataset len: {len(self.dataframe)}{source_suffix}')

        if self.filter_prompts:
            # Drop rows whose templated prompt exceeds max_prompt_length
            # tokens, so truncation never fires on well-formed data.
            tokenizer = self.tokenizer
            prompt_key = self.prompt_key
            self.dataframe = self.dataframe[self.dataframe.apply(
                lambda doc: len(tokenizer.apply_chat_template(
                    doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
                axis=1)]

        print(f'filter dataset len: {len(self.dataframe)}{source_suffix}')

    def resume_dataset_state(self):
        # Checkpoints written before original_parquet_files existed must fall
        # back to the serialized dataframe they carry.
        self.serialize_dataset = not hasattr(self, 'original_parquet_files')

        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)
            self._read_files_and_tokenize()
        else:
            print('old dataloader ckpt file is used, please train from scratch for better ckpt performance')

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, item):
        """
        Note that when return_raw_chat is set, the untemplated chat is also
        returned so it can be combined with another chat template downstream.
        """
        row_dict = self.dataframe.iloc[item].to_dict()

        chat = row_dict.pop(self.prompt_key)

        prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)

        # Left-pad (or truncate, per self.truncation) to max_prompt_length so
        # every prompt in a batch shares one shape.
        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
                                                                         tokenizer=self.tokenizer,
                                                                         max_length=self.max_prompt_length,
                                                                         pad_token_id=self.tokenizer.pad_token_id,
                                                                         left_pad=True,
                                                                         truncation=self.truncation)

        position_ids = compute_position_id_with_mask(attention_mask)

        row_dict['input_ids'] = input_ids[0]
        row_dict['attention_mask'] = attention_mask[0]
        row_dict['position_ids'] = position_ids[0]

        # Keep the untemplated chat for rollout engines that apply their own template.
        if self.return_raw_chat:
            row_dict['raw_prompt'] = chat.tolist()

        # Stable per-sample index, used to group rollouts of the same prompt.
        index = row_dict.get("extra_info", {}).get("index", 0)
        row_dict["index"] = index

        return row_dict
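
    # Sketch of one returned sample, assuming max_prompt_length=1024 (the
    # default above); any extra parquet columns ride along unchanged:
    #   row_dict['input_ids']       LongTensor[1024], left-padded
    #   row_dict['attention_mask']  LongTensor[1024], 0 on pad, 1 on tokens
    #   row_dict['position_ids']    LongTensor[1024], 0 on pad, then 0,1,2,...
    #   row_dict['index']           int, taken from extra_info when present
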
    def __getstate__(self):
        if not self.serialize_dataset:
            state = self.__dict__.copy()
            # The dataframe is rebuilt from the parquet files on resume, so
            # it is dropped from the pickled state.
            if 'dataframe' in state:
                del state['dataframe']
            return state
        return self.__dict__.copy()
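

# A minimal smoke-test sketch showing how the pieces above fit together.
# The parquet path and tokenizer name are placeholders for illustration,
# not anything this module ships or requires.
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')  # placeholder model
    dataset = RLHFDataset(parquet_files='/path/to/train.parquet',  # placeholder path
                          tokenizer=tokenizer,
                          prompt_key='prompt',
                          max_prompt_length=1024)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    batch = next(iter(loader))
    print({key: getattr(val, 'shape', type(val)) for key, val in batch.items()})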