import json
import os
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from swift.utils import get_logger, use_hf_hub

from .preprocessor import DATASET_TYPE, AutoPreprocessor, MessagesPreprocessor

PreprocessFunc = Callable[..., DATASET_TYPE]
LoadFunction = Callable[..., DATASET_TYPE]

logger = get_logger()


@dataclass
class SubsetDataset:
    name: Optional[str] = None
    subset: str = 'default'

    # These two fields fall back to the corresponding values on the parent
    # `DatasetMeta` (see `set_default`) when left as None.
    split: Optional[List[str]] = None
    preprocess_func: Optional[PreprocessFunc] = None

    is_weak_subset: bool = False

    def __post_init__(self):
        if self.name is None:
            self.name = self.subset

    def set_default(self, dataset_meta: 'DatasetMeta') -> 'SubsetDataset':
        subset_dataset = deepcopy(self)
        for k in ['split', 'preprocess_func']:
            v = getattr(subset_dataset, k)
            if v is None:
                setattr(subset_dataset, k, deepcopy(getattr(dataset_meta, k)))
        return subset_dataset
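
# Example (illustrative only): a subset declared with only `subset` set gets `name`
# defaulted to the same value by `__post_init__`:
#
#     s = SubsetDataset(subset='zh')
#     assert s.name == 'zh'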


@dataclass
class DatasetMeta:
    ms_dataset_id: Optional[str] = None
    hf_dataset_id: Optional[str] = None
    dataset_path: Optional[str] = None
    dataset_name: Optional[str] = None
    ms_revision: Optional[str] = None
    hf_revision: Optional[str] = None

    subsets: List[Union[SubsetDataset, str]] = field(default_factory=lambda: ['default'])
    split: List[str] = field(default_factory=lambda: ['train'])
    preprocess_func: PreprocessFunc = field(default_factory=lambda: AutoPreprocessor())
    load_function: Optional[LoadFunction] = None

    tags: List[str] = field(default_factory=list)
    help: Optional[str] = None
    huge_dataset: bool = False

    def __post_init__(self):
        from .loader import DatasetLoader
        if self.load_function is None:
            self.load_function = DatasetLoader.load
        for i, subset in enumerate(self.subsets):
            if isinstance(subset, str):
                self.subsets[i] = SubsetDataset(subset=subset)
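
# Illustrative sketch (the dataset id and subset names below are placeholders, not real
# datasets): string entries in `subsets` are wrapped into `SubsetDataset` instances by
# `__post_init__`, and `load_function` defaults to `DatasetLoader.load`:
#
#     meta = DatasetMeta(
#         ms_dataset_id='my-org/my-sft-data',
#         subsets=['zh', SubsetDataset(subset='en', split=['train', 'validation'])],
#         tags=['sft'])
#     assert all(isinstance(s, SubsetDataset) for s in meta.subsets)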


# Maps a registered dataset key to its `DatasetMeta`. The key is either the explicit
# `dataset_name` (a string) or the `(ms_dataset_id, hf_dataset_id, dataset_path)` tuple.
DATASET_MAPPING: Dict[Union[str, Tuple[str, str, str]], DatasetMeta] = {}


def get_dataset_list() -> List[str]:
    """Return the ids of all registered datasets (HuggingFace ids when `use_hf_hub()`
    is true, ModelScope ids otherwise)."""
    datasets = []
    for key in DATASET_MAPPING:
        if isinstance(key, str):
            # Datasets registered with an explicit `dataset_name` use a plain string key.
            datasets.append(key)
        elif use_hf_hub():
            if key[1]:
                datasets.append(key[1])
        elif key[0]:
            datasets.append(key[0])
    return datasets


def register_dataset(dataset_meta: DatasetMeta, *, exist_ok: bool = False) -> None:
    """Register a dataset.

    Args:
        dataset_meta: The `DatasetMeta` info of the dataset.
        exist_ok: If False, raise an error when the dataset id has already been registered;
            otherwise overwrite the existing registration.
    """
    if dataset_meta.dataset_name:
        dataset_name = dataset_meta.dataset_name
    else:
        dataset_name = dataset_meta.ms_dataset_id, dataset_meta.hf_dataset_id, dataset_meta.dataset_path
    if not exist_ok and dataset_name in DATASET_MAPPING:
        raise ValueError(f'The dataset `{dataset_name}` has already been registered in the DATASET_MAPPING.')

    DATASET_MAPPING[dataset_name] = dataset_meta
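
# Usage sketch (illustrative; 'my-org/my-dataset' is a placeholder id, not a real
# dataset): a registered dataset is keyed by its ids (or by `dataset_name` when set),
# and registering the same key again raises ValueError unless `exist_ok=True`:
#
#     register_dataset(
#         DatasetMeta(
#             ms_dataset_id='my-org/my-dataset',
#             split=['train'],
#             tags=['sft']))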


def _preprocess_d_info(d_info: Dict[str, Any], *, base_dir: Optional[str] = None) -> Dict[str, Any]:
    """Normalize a raw dataset-info dict into `DatasetMeta`/`SubsetDataset` keyword arguments:
    build `preprocess_func` from the `columns`/`messages` entries, resolve a relative
    `dataset_path` against `base_dir`, and convert nested `subsets` dicts into
    `SubsetDataset` instances."""
    d_info = deepcopy(d_info)

    columns = None
    if 'columns' in d_info:
        columns = d_info.pop('columns')

    if 'messages' in d_info:
        d_info['preprocess_func'] = MessagesPreprocessor(**d_info.pop('messages'), columns=columns)
    else:
        d_info['preprocess_func'] = AutoPreprocessor(columns=columns)

    if 'dataset_path' in d_info:
        dataset_path = d_info.pop('dataset_path')
        if base_dir is not None and not os.path.isabs(dataset_path):
            dataset_path = os.path.join(base_dir, dataset_path)
        dataset_path = os.path.abspath(os.path.expanduser(dataset_path))
        d_info['dataset_path'] = dataset_path

    if 'subsets' in d_info:
        subsets = d_info.pop('subsets')
        for i, subset in enumerate(subsets):
            if isinstance(subset, dict):
                subsets[i] = SubsetDataset(**_preprocess_d_info(subset))
        d_info['subsets'] = subsets
    return d_info
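
# Illustrative sketch of the normalization above (the values are made up; the mapping
# format of 'columns' is an assumption for illustration only):
#
#     raw = {
#         'dataset_path': 'data/chat.jsonl',    # relative to the info file's directory
#         'columns': {'instruction': 'query'},  # assumed column-mapping format
#     }
#     d_info = _preprocess_d_info(raw, base_dir='/path/to/info_dir')
#     # -> d_info['preprocess_func'] is an AutoPreprocessor(columns=...) and
#     #    d_info['dataset_path'] == '/path/to/info_dir/data/chat.jsonl'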


def _register_d_info(d_info: Dict[str, Any], *, base_dir: Optional[str] = None) -> DatasetMeta:
    """Register a single dataset described by a dataset-info dict.

    Args:
        d_info: The dataset info.
        base_dir: Directory used to resolve a relative `dataset_path`
            (typically the directory containing the dataset info file).
    """
    d_info = _preprocess_d_info(d_info, base_dir=base_dir)
    dataset_meta = DatasetMeta(**d_info)
    register_dataset(dataset_meta)
    return dataset_meta


def register_dataset_info(dataset_info: Union[str, List[str], None] = None) -> List[DatasetMeta]:
    """Register datasets defined in the built-in `dataset_info.json` or a custom dataset info file.

    Args:
        dataset_info: The path of a dataset info file, a JSON string describing the datasets,
            an already-parsed list of dataset-info dicts, or None to use the built-in
            `dataset_info.json`.
    """
    if dataset_info is None:
        dataset_info = os.path.join(os.path.dirname(__file__), 'data', 'dataset_info.json')
    assert isinstance(dataset_info, (str, list))
    base_dir = None
    log_msg = None
    if isinstance(dataset_info, str):
        dataset_path = os.path.abspath(os.path.expanduser(dataset_info))
        if os.path.isfile(dataset_path):
            log_msg = dataset_path
            base_dir = os.path.dirname(dataset_path)
            with open(dataset_path, 'r', encoding='utf-8') as f:
                dataset_info = json.load(f)
        else:
            # Not an existing file: treat the string as inline JSON.
            dataset_info = json.loads(dataset_info)
    if len(dataset_info) == 0:
        return []
    res = []
    for d_info in dataset_info:
        res.append(_register_d_info(d_info, base_dir=base_dir))

    if log_msg is None:
        log_msg = dataset_info if len(dataset_info) < 5 else list(dataset_info.keys())
    logger.info(f'Successfully registered `{log_msg}`.')
    return res
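
# Usage sketch (illustrative; the file name and its contents are made up): as consumed
# by this function, a custom dataset info file parses to a list of dicts, each of which
# is passed to `_register_d_info`:
#
#     # my_dataset_info.json
#     # [
#     #     {"ms_dataset_id": "my-org/my-sft-data", "tags": ["sft"]},
#     #     {"dataset_path": "data/local.jsonl", "columns": {"instruction": "query"}}
#     # ]
#     register_dataset_info('my_dataset_info.json')
#     # The same content may also be passed directly as a JSON string.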