# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from functools import partial

import numpy as np
from datasets import DatasetDict
from mmengine.config import Config

from xtuner.dataset.utils import Packer, encode_fn
from xtuner.registry import BUILDER


def parse_args():
    parser = argparse.ArgumentParser(
        description='Verify the correctness of the config file for the '
        'custom dataset.')
    parser.add_argument('config', help='config file name or path.')
    args = parser.parse_args()
    return args


def is_standard_format(dataset):
    """Check whether the first example holds a `conversation` list of dicts
    with string `input` and `output` fields, i.e. the XTuner-defined
    single-turn format."""
    example = next(iter(dataset))
    if 'conversation' not in example:
        return False
    conversation = example['conversation']
    if not isinstance(conversation, list):
        return False
    for item in conversation:
        if (not isinstance(item, dict)) or ('input' not in item) or (
                'output' not in item):
            return False
        input, output = item['input'], item['output']
        if (not isinstance(input, str)) or (not isinstance(output, str)):
            return False
    return True
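

# For reference, a record in the XTuner-defined standard format checked by
# `is_standard_format` looks like the following (field values are
# illustrative only):
#
#     {
#         "conversation": [
#             {
#                 "input": "What is the capital of France?",
#                 "output": "The capital of France is Paris."
#             }
#         ]
#     }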


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    tokenizer = BUILDER.build(cfg.tokenizer)

    # The dataset settings live in a different place depending on whether
    # the config targets the HuggingFace framework or the default MMEngine
    # runner.
    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        train_dataset = cfg.train_dataset
    else:
        train_dataset = cfg.train_dataloader.dataset

    dataset = train_dataset.dataset
    max_length = train_dataset.max_length
    dataset_map_fn = train_dataset.get('dataset_map_fn', None)
    template_map_fn = train_dataset.get('template_map_fn', None)
    max_dataset_length = train_dataset.get('max_dataset_length', 10)
    split = train_dataset.get('split', 'train')
    remove_unused_columns = train_dataset.get('remove_unused_columns', False)
    rename_maps = train_dataset.get('rename_maps', [])
    shuffle_before_pack = train_dataset.get('shuffle_before_pack', True)
    pack_to_max_length = train_dataset.get('pack_to_max_length', True)
    input_ids_with_output = train_dataset.get('input_ids_with_output', True)

    if dataset.get('path', '') != 'json':
        raise ValueError(
            'You are using custom datasets for SFT. '
            'The custom datasets should be in json format. To load your '
            'JSON file, you can use the following code snippet:\n'
            '"""\nfrom datasets import load_dataset\n'
            'dataset = dict(type=load_dataset, path=\'json\', '
            'data_files=\'your_json_file.json\')\n"""\n'
            'For more details, please refer to Step 5 in the '
            '`Using Custom Datasets` section of the documentation found at '
            '`docs/zh_cn/user_guides/single_turn_conversation.md`.')
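
    # A hypothetical `your_json_file.json` would hold a list of records in
    # the standard format shown above, e.g.
    # [{"conversation": [{"input": "...", "output": "..."}]}].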
    try:
        dataset = BUILDER.build(dataset)
    except RuntimeError as e:
        raise RuntimeError(
            'Unable to load the custom JSON file using '
            '`datasets.load_dataset`. Your data-related config is '
            f'{train_dataset}. Please refer to the official documentation on'
            ' `load_dataset` (https://huggingface.co/docs/datasets/loading) '
            'for more details.') from e

    if isinstance(dataset, DatasetDict):
        dataset = dataset[split]

    if not is_standard_format(dataset) and dataset_map_fn is None:
        raise ValueError(
            'If the custom dataset is not in the XTuner-defined format, '
            'please utilize `dataset_map_fn` to map the original data to '
            'the standard format. For more details, please refer to Step 1 '
            'and Step 5 in the `Using Custom Datasets` section of the '
            'documentation found at '
            '`docs/zh_cn/user_guides/single_turn_conversation.md`.')
    if is_standard_format(dataset) and dataset_map_fn is not None:
        raise ValueError(
            'If the custom dataset is already in the XTuner-defined format, '
            'please set `dataset_map_fn` to None. '
            'For more details, please refer to Step 1 and Step 5 in the '
            '`Using Custom Datasets` section of the documentation found at '
            '`docs/zh_cn/user_guides/single_turn_conversation.md`.')

    # Inspect only a small random sample of the dataset.
    max_dataset_length = min(max_dataset_length, len(dataset))
    indices = np.random.choice(
        len(dataset), max_dataset_length, replace=False)
    dataset = dataset.select(indices)

    if dataset_map_fn is not None:
        dataset = dataset.map(dataset_map_fn)
        print('#' * 20 + ' dataset after `dataset_map_fn` ' + '#' * 20)
        print(dataset[0]['conversation'])

    if template_map_fn is not None:
        template_map_fn = BUILDER.build(template_map_fn)
        dataset = dataset.map(template_map_fn)
        print('#' * 20 + ' dataset after adding templates ' + '#' * 20)
        print(dataset[0]['conversation'])

    for old, new in rename_maps:
        dataset = dataset.rename_column(old, new)

    if pack_to_max_length and (not remove_unused_columns):
        raise ValueError('We have to remove unused columns if '
                         '`pack_to_max_length` is set to True.')
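
    # Packing concatenates multiple samples into single sequences, which
    # changes the number of rows; any column left over after `encode_fn`
    # would then have a mismatched length, so unused columns must be
    # dropped first.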

    dataset = dataset.map(
        partial(
            encode_fn,
            tokenizer=tokenizer,
            max_length=max_length,
            input_ids_with_output=input_ids_with_output),
        remove_columns=list(dataset.column_names)
        if remove_unused_columns else None)

    print('#' * 20 + ' encoded input_ids ' + '#' * 20)
    print(dataset[0]['input_ids'])
    print('#' * 20 + ' encoded labels ' + '#' * 20)
    print(dataset[0]['labels'])
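
    # In the printed labels, prompt tokens should show up as -100
    # (IGNORE_INDEX), so that only the assistant's output contributes to
    # the loss; if no -100 entries appear, double-check your map functions.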

    if pack_to_max_length and split == 'train':
        if shuffle_before_pack:
            dataset = dataset.shuffle()
            dataset = dataset.flatten_indices()
        dataset = dataset.map(Packer(max_length), batched=True)

        print('#' * 20 + ' input_ids after packing to max_length ' +
              '#' * 20)
        print(dataset[0]['input_ids'])
        print('#' * 20 + ' labels after packing to max_length ' + '#' * 20)
        print(dataset[0]['labels'])


if __name__ == '__main__':
    main()
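
# Usage sketch (the file name below is an assumption; point it at wherever
# this script lives in your checkout):
#
#     python check_custom_dataset.py path/to/your_config.py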