from typing import Union

from torch.utils.data import Dataset, DataLoader
from torch import LongTensor, cuda
from transformers import PreTrainedTokenizerFast
from datasets import load_dataset
import datasets
import pyarrow.parquet as pq
from numpy import array, int64
from numpy.random import shuffle

# import sys
# sys.path.extend(['.', '..'])

from config import PROJECT_ROOT
class MyDataset(Dataset):

    def __init__(self,
                 parquet_file: str,
                 tokenizer_dir: str,
                 keep_in_memory: bool = False,
                 max_seq_len: int = 512,
                 buffer_size: int = 40960,
                 ) -> None:
        '''
        keep_in_memory: whether to convert the parquet file to a pandas.DataFrame and keep it in memory.
        If False, a generator is used instead (the generator cannot shuffle the data),
        which reduces memory usage for large datasets.
        '''
        super().__init__()

        if cuda.device_count() >= 2 and not keep_in_memory:
            raise ValueError(f'When MyDataset is used with multiple GPUs, keep_in_memory must be True, '
                             f'otherwise distributed training will not work. Current keep_in_memory={keep_in_memory}')

        self.keep_in_memory = keep_in_memory
        self.max_seq_len = max_seq_len

        # Read with pyarrow.parquet: to_pandas and row iteration are faster this way
        parquet_table = pq.read_table(parquet_file)

        # Dataset length
        self.length = parquet_table.num_rows

        # The buffer cannot be larger than the dataset itself
        self.buffer_size = self.length if buffer_size > self.length else buffer_size

        if keep_in_memory:
            # Convert to pandas and keep it in memory
            self.data = parquet_table.to_pandas()
        else:
            self.data = parquet_table

        # Initialize the tokenizer
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)

        # Initialize the sample generator here
        self.sample_generator = self.item_generator()
    def item_generator(self) -> tuple:
        '''
        Generator that yields one sample at a time, to avoid OOM on large datasets.
        '''
        parquet_table = self.data

        # The generator is an infinite loop and never exits; training simply stops
        # calling next() at the end of each epoch.
        buffer_list = []
        while True:
            for prompt, response in zip(parquet_table['prompt'], parquet_table['response']):

                # Buffer not full yet: keep adding samples
                if len(buffer_list) < self.buffer_size:
                    buffer_list.append((prompt.as_py(), response.as_py()))
                    continue

                # Buffer is full: shuffle it and yield its samples one by one
                shuffle(buffer_list)
                for p, r in buffer_list:
                    yield p, r

                # Buffer consumed: start a new one with the current sample,
                # so it is not silently dropped
                buffer_list = [(prompt.as_py(), response.as_py())]
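    # Illustration (not part of the original code): with buffer_size=3 and the stream
    # a, b, c, d, e, ... the generator fills [a, b, c], shuffles and yields that buffer,
    # then starts the next buffer with d, and so on. Shuffling therefore only happens
    # within a buffer, never across the whole dataset.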
    def __getitem__(self, index):
        '''
        Return one sample.
        '''
        if self.keep_in_memory:
            data = self.data
            prompt, response = data.iloc[index].prompt, data.iloc[index].response
        else:
            prompt, response = next(self.sample_generator)

        max_seq_len = self.max_seq_len - 5  # len('[EOS]') = 5
        # Append an [EOS] token to mark the end of the response; it is used during generation.
        return f"{prompt[0: max_seq_len]}[EOS]", f"{response[0: max_seq_len]}[EOS]"
    def collate_fn(self, data: list[tuple]) -> dict:
        '''
        Collate a batch of samples into tensors.
        '''
        tokenizer = self.tokenizer

        prompt = tokenizer([item[0] for item in data], padding=True, return_token_type_ids=False)
        response = tokenizer([item[1] for item in data], padding=True, return_token_type_ids=False)

        input_ids = array(prompt.input_ids, dtype=int64)
        input_mask = array(prompt.attention_mask, dtype=int64)
        target_ids = array(response.input_ids, dtype=int64)

        ret = {
            'input_ids': LongTensor(input_ids),
            'input_mask': LongTensor(input_mask),
            'target_ids': LongTensor(target_ids),
        }
        return ret
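    # For a batch of B samples the returned tensors are padded per batch:
    # input_ids / input_mask have shape (B, longest prompt in the batch) and
    # target_ids has shape (B, longest response in the batch).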
    def __len__(self) -> int:
        return self.length
class ParquetDataset:

    def __init__(self,
                 parquet_file: Union[str, dict],
                 tokenizer_dir: str,
                 keep_in_memory: bool = False,
                 cache_dir: str = './.cache',
                 buffer_size: int = 10240,
                 max_len: int = 512,
                 seed: int = 23333,
                 ) -> None:
        '''
        Loads the data with Hugging Face's load_dataset.
        parquet_file: a single file, in which case only dataset['train'] is available;
            for multiple files use: parquet_file={'train': 'train.parquet', 'test': 'test.parquet', 'validation': 'validation.parquet'}
            See https://huggingface.co/docs/datasets/loading for other usages.
        keep_in_memory: whether to convert the parquet file to a pandas.DataFrame and keep it in memory.
        '''
        self.keep_in_memory = keep_in_memory
        self.len_dict = self.__get_all_parquet_file_size(parquet_file=parquet_file)

        self.max_len = max_len
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)

        # Use streaming unless the data is kept in memory, otherwise large datasets cause OOM
        streaming = not keep_in_memory
        dataset = load_dataset('parquet', data_files=parquet_file, cache_dir=cache_dir, streaming=streaming)

        # Note: this batch_size is not the training batch size; it is the batch size
        # passed to precess_batch_func for batched preprocessing
        dataset = dataset.map(self.precess_batch_func, batched=True, batch_size=buffer_size,
                              remove_columns=['prompt', 'response'], fn_kwargs={'max_len': max_len})
        dataset = dataset.with_format(type="torch")

        if keep_in_memory:
            dataset = dataset.shuffle(seed=seed, keep_in_memory=keep_in_memory)
        else:
            # Streaming can only shuffle within a buffer, not the whole dataset,
            # so it is worth making the buffer a bit larger
            dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

        self.dataset = dataset
    def precess_batch_func(self, item: dict, max_len: int = 512) -> dict:
        '''
        Truncate and append the [EOS] token.
        '''
        max_len -= 5  # len('[EOS]') = 5
        for i in range(len(item['prompt'])):
            item['prompt'][i] = f"{item['prompt'][i][0: max_len]}[EOS]"
        for i in range(len(item['response'])):
            item['response'][i] = f"{item['response'][i][0: max_len]}[EOS]"
        return {
            'prompt': item['prompt'],
            'response': item['response'],
        }
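    # Illustration (not part of the original code): with max_len=8, five characters are
    # reserved for the literal "[EOS]" string, so a prompt of "hello world" is truncated
    # to "hel" and becomes "hel[EOS]".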
    def collate_fn(self, data: list[dict]) -> dict:
        '''
        Collate a batch of samples into tensors.
        '''
        tokenizer = self.tokenizer

        prompt = [item['prompt'] for item in data]
        response = [item['response'] for item in data]

        # Pad per batch
        prompt_encoded = tokenizer(prompt, padding=True, return_token_type_ids=False)
        response_encoded = tokenizer(response, padding=True, return_token_type_ids=False)

        input_ids = array(prompt_encoded.input_ids, dtype=int64)
        input_mask = array(prompt_encoded.attention_mask, dtype=int64)
        target_ids = array(response_encoded.input_ids, dtype=int64)

        ret = {
            'input_ids': LongTensor(input_ids),
            'input_mask': LongTensor(input_mask),
            'target_ids': LongTensor(target_ids),
        }
        return ret
    def __getitem__(self, index: str) -> datasets.Dataset:
        '''
        Magic method for subscript access, e.g. dataset['train'], dataset['validation'], dataset['test'].
        '''
        return self.dataset[index]
    def __get_all_parquet_file_size(self, parquet_file: Union[str, dict]) -> dict:
        '''
        Get the number of rows of every parquet file.
        '''
        len_dict = dict()
        if isinstance(parquet_file, str):
            train_len = self.__get_size_of_parquet(parquet_file)
            len_dict['train'] = train_len

        if isinstance(parquet_file, dict):
            for split_type, file in parquet_file.items():
                len_dict[split_type] = self.__get_size_of_parquet(file)

        return len_dict

    def __get_size_of_parquet(self, file_name: str) -> int:
        '''
        Get the number of rows of a single parquet file.
        '''
        parquet_data = pq.read_table(file_name)
        return parquet_data.num_rows
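    # Possible optimization (a sketch, not in the original code): reading only the parquet
    # metadata avoids loading the whole table just to count rows, e.g.
    # pq.ParquetFile(file_name).metadata.num_rows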
    def __len__(self) -> int:
        '''
        Magic method: if there is only one split, return the default (train) split size.
        '''
        if len(self.len_dict) == 1:
            return self.len_dict['train']
        else:
            raise Exception("This dataset contains multiple splits; use the `get_dataset_size(split_name)` "
                            "method to get a length, e.g. get_dataset_size('train').")

    def get_dataset_size(self, split_name: str) -> int:
        '''
        Get the size of a single split.
        split_name can be: train, validation, test.
        '''
        return self.len_dict[split_name]

    def get_tokenizer(self) -> PreTrainedTokenizerFast:
        return self.tokenizer
if __name__ == '__main__':
    parquet_file = PROJECT_ROOT + '/data/my_valid_dataset.parquet'
    tokenizer_dir = PROJECT_ROOT + '/model_save/tokenizer'

    # example 1:
    dataset = MyDataset(parquet_file, tokenizer_dir, keep_in_memory=False, max_seq_len=128)
    print('\nexample 1, dataset size: ', len(dataset))
    dataloader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)

    for epoch in range(2):
        print('epoch: {}'.format(epoch))
        for step, batch in enumerate(dataloader):
            x, x_mask, y = batch['input_ids'], batch['input_mask'], batch['target_ids']
            print('step:{}'.format(step), x.shape, x_mask.shape, y.shape)
            if step == 5:
                break

    # exit(0)
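    # Sketch (not in the original script): decode part of the last batch back to text to
    # verify the [EOS] handling; batch_decode is standard PreTrainedTokenizerFast API.
    decoded = dataset.tokenizer.batch_decode(x[0: 2], skip_special_tokens=False)
    print('decoded prompts:', decoded)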
    # example 2:
    dataset = ParquetDataset(parquet_file, tokenizer_dir, keep_in_memory=True, max_len=32)
    dataloader = DataLoader(dataset['train'], batch_size=32, collate_fn=dataset.collate_fn)
    print('\nexample 2, dataset size: ', dataset.get_dataset_size('train'))
    for epoch in range(2):
        print('epoch: {}'.format(epoch))
        for step, batch in enumerate(dataloader):
            x, x_mask, y = batch['input_ids'], batch['input_mask'], batch['target_ids']
            print('step:{}'.format(step), x.shape, x_mask.shape, y.shape)
            if step == 5:
                break
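
    # example 3 (a sketch, not in the original script): streaming mode for large files.
    # With keep_in_memory=False the splits are streamed, so only buffer-level shuffling
    # is applied and the data never has to fit in memory.
    dataset = ParquetDataset(parquet_file, tokenizer_dir, keep_in_memory=False, max_len=32)
    dataloader = DataLoader(dataset['train'], batch_size=32, collate_fn=dataset.collate_fn)
    print('\nexample 3, dataset size: ', dataset.get_dataset_size('train'))
    for step, batch in enumerate(dataloader):
        x, x_mask, y = batch['input_ids'], batch['input_mask'], batch['target_ids']
        print('step:{}'.format(step), x.shape, x_mask.shape, y.shape)
        if step == 2:
            break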