# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import json

from swift.utils import get_logger, use_hf_hub
from .preprocessor import DATASET_TYPE, AutoPreprocessor, MessagesPreprocessor

PreprocessFunc = Callable[..., DATASET_TYPE]
LoadFunction = Callable[..., DATASET_TYPE]

logger = get_logger()

@dataclass
class SubsetDataset:
    # `name` is used for matching the subsets of the dataset, and `subset` refers to the subset_name on the hub.
    name: Optional[str] = None
    # If `name` is set to None, it falls back to `subset`.
    subset: str = 'default'

    # Higher priority. If set to None, the attributes of the DatasetMeta will be used.
    split: Optional[List[str]] = None
    preprocess_func: Optional[PreprocessFunc] = None

    # If the dataset specifies "all", weak subsets will be skipped.
    is_weak_subset: bool = False

    def __post_init__(self):
        if self.name is None:
            self.name = self.subset

    def set_default(self, dataset_meta: 'DatasetMeta') -> 'SubsetDataset':
        subset_dataset = deepcopy(self)
        for k in ['split', 'preprocess_func']:
            v = getattr(subset_dataset, k)
            if v is None:
                setattr(subset_dataset, k, deepcopy(getattr(dataset_meta, k)))
        return subset_dataset

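
# Illustrative sketch (not executed at import time; the dataset id below is
# hypothetical): subset-level attributes left as None fall back to the parent
# DatasetMeta through `set_default`:
#
#     meta = DatasetMeta(ms_dataset_id='my_org/my_dataset', split=['train', 'validation'])
#     sub = SubsetDataset(subset='zh')      # split/preprocess_func are still None
#     sub = sub.set_default(meta)
#     assert sub.split == ['train', 'validation']
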
@dataclass
class DatasetMeta:
    ms_dataset_id: Optional[str] = None
    hf_dataset_id: Optional[str] = None
    dataset_path: Optional[str] = None
    dataset_name: Optional[str] = None
    ms_revision: Optional[str] = None
    hf_revision: Optional[str] = None

    subsets: List[Union[SubsetDataset, str]] = field(default_factory=lambda: ['default'])
    # Applicable to all subsets.
    split: List[str] = field(default_factory=lambda: ['train'])
    # First perform column mapping, then proceed with the preprocess_func.
    preprocess_func: PreprocessFunc = field(default_factory=lambda: AutoPreprocessor())
    load_function: Optional[LoadFunction] = None

    tags: List[str] = field(default_factory=list)
    help: Optional[str] = None
    huge_dataset: bool = False

    def __post_init__(self):
        from .loader import DatasetLoader
        if self.load_function is None:
            self.load_function = DatasetLoader.load
        for i, subset in enumerate(self.subsets):
            if isinstance(subset, str):
                self.subsets[i] = SubsetDataset(subset=subset)

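
# Illustrative sketch (hypothetical dataset id): plain strings in `subsets` are
# normalized to `SubsetDataset` instances in `__post_init__`:
#
#     meta = DatasetMeta(ms_dataset_id='my_org/my_dataset', subsets=['zh', 'en'])
#     assert all(isinstance(s, SubsetDataset) for s in meta.subsets)
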
# Keyed by `dataset_name` if it is set, otherwise by (ms_dataset_id, hf_dataset_id, dataset_path).
DATASET_MAPPING: Dict[Union[str, Tuple[str, str, str]], DatasetMeta] = {}


def get_dataset_list():
    datasets = []
    for key in DATASET_MAPPING:
        if isinstance(key, str):
            # Registered by `dataset_name`.
            datasets.append(key)
        elif use_hf_hub():
            if key[1]:
                datasets.append(key[1])
        else:
            if key[0]:
                datasets.append(key[0])
    return datasets

def register_dataset(dataset_meta: DatasetMeta, *, exist_ok: bool = False) -> None:
    """Register a dataset

    Args:
        dataset_meta: The `DatasetMeta` info of the dataset.
        exist_ok: If False, raise an error when the dataset id has already been registered;
            if True, overwrite the existing entry.
    """
    if dataset_meta.dataset_name:
        dataset_name = dataset_meta.dataset_name
    else:
        dataset_name = dataset_meta.ms_dataset_id, dataset_meta.hf_dataset_id, dataset_meta.dataset_path
    if not exist_ok and dataset_name in DATASET_MAPPING:
        raise ValueError(f'The dataset `{dataset_name}` has already been registered in the DATASET_MAPPING.')
    DATASET_MAPPING[dataset_name] = dataset_meta

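
# Illustrative sketch (the name and path are hypothetical): register a local
# file so that it can later be referenced by its `dataset_name`:
#
#     register_dataset(
#         DatasetMeta(
#             dataset_name='my_local_qa',
#             dataset_path='/path/to/my_qa.jsonl',
#         ))
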
def _preprocess_d_info(d_info: Dict[str, Any], *, base_dir: Optional[str] = None) -> Dict[str, Any]:
    d_info = deepcopy(d_info)
    columns = None
    if 'columns' in d_info:
        columns = d_info.pop('columns')
    if 'messages' in d_info:
        d_info['preprocess_func'] = MessagesPreprocessor(**d_info.pop('messages'), columns=columns)
    else:
        d_info['preprocess_func'] = AutoPreprocessor(columns=columns)

    if 'dataset_path' in d_info:
        dataset_path = d_info.pop('dataset_path')
        if base_dir is not None and not os.path.isabs(dataset_path):
            dataset_path = os.path.join(base_dir, dataset_path)
        dataset_path = os.path.abspath(os.path.expanduser(dataset_path))
        d_info['dataset_path'] = dataset_path

    if 'subsets' in d_info:
        subsets = d_info.pop('subsets')
        for i, subset in enumerate(subsets):
            if isinstance(subset, dict):
                subsets[i] = SubsetDataset(**_preprocess_d_info(subset))
        d_info['subsets'] = subsets
    return d_info

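
# Illustrative sketch (hypothetical dataset id; the column mapping direction is
# assumed to be source name -> standard name): a raw info dict is turned into
# kwargs that can be splatted into `DatasetMeta`:
#
#     d_info = _preprocess_d_info({
#         'ms_dataset_id': 'my_org/my_dataset',
#         'columns': {'question': 'query', 'answer': 'response'},
#     })
#     # d_info['preprocess_func'] is now an AutoPreprocessor built with that column mapping.
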
def _register_d_info(d_info: Dict[str, Any], *, base_dir: Optional[str] = None) -> DatasetMeta:
    """Register a single dataset to the dataset mapping

    Args:
        d_info: The dataset info dict.
        base_dir: The base directory used to resolve a relative `dataset_path`.
    """
    d_info = _preprocess_d_info(d_info, base_dir=base_dir)
    dataset_meta = DatasetMeta(**d_info)
    register_dataset(dataset_meta)
    return dataset_meta

def register_dataset_info(dataset_info: Union[str, List[Dict[str, Any]], None] = None) -> List[DatasetMeta]:
    """Register datasets from `dataset_info.json` or a custom dataset info file

    This is used to deal with the datasets defined in the json info file.

    Args:
        dataset_info: A path to a dataset info JSON file, a JSON string, a list of dataset
            info dicts, or None to use the built-in `dataset_info.json`.
    """
    # dataset_info: path, json string/list, or None
    if dataset_info is None:
        dataset_info = os.path.join(os.path.dirname(__file__), 'data', 'dataset_info.json')
    assert isinstance(dataset_info, (str, list))
    base_dir = None
    log_msg = None
    if isinstance(dataset_info, str):
        dataset_path = os.path.abspath(os.path.expanduser(dataset_info))
        if os.path.isfile(dataset_path):
            log_msg = dataset_path
            base_dir = os.path.dirname(dataset_path)
            with open(dataset_path, 'r', encoding='utf-8') as f:
                dataset_info = json.load(f)
        else:
            # Not a file on disk: treat the string as a JSON document.
            dataset_info = json.loads(dataset_info)
    if len(dataset_info) == 0:
        return []
    res = []
    for d_info in dataset_info:
        res.append(_register_d_info(d_info, base_dir=base_dir))
    if log_msg is None:
        log_msg = dataset_info if len(dataset_info) < 5 else f'{len(dataset_info)} datasets'
    logger.info(f'Successfully registered `{log_msg}`.')
    return res
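
# Illustrative sketch (hypothetical id and column names): custom datasets can also
# be registered from a JSON string rather than a file:
#
#     register_dataset_info(json.dumps([{
#         'ms_dataset_id': 'my_org/my_sft_data',
#         'columns': {'instruction': 'query', 'output': 'response'},
#         'split': ['train'],
#     }]))
#     # The new entry is now visible in DATASET_MAPPING / get_dataset_list().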