# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope


if TYPE_CHECKING:
    from ..hparams import DataArguments


@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.

    Describes where a dataset is loaded from (``load_from`` / ``dataset_name``)
    and which column and tag names are used to read it. Instances are built by
    ``get_dataset_list`` from the entries of the dataset config file.
    """

    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
    # rlhf columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        # Deliberately replaces the dataclass-generated repr with the bare name.
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        r"""
        Sets attribute ``key`` to ``obj[key]``, or to ``default`` if ``key`` is absent.

        Note that an absent key does NOT keep the current value: the attribute is
        overwritten with ``default`` (``None`` unless specified).
        """
        setattr(self, key, obj.get(key, default))


def _parse_dataset_attr(info: Dict[str, Any]) -> "DatasetAttr":
    r"""
    Builds a ``DatasetAttr`` from a single entry of the dataset config file.

    The source is picked in this order: hub url (``hf_hub_url`` / ``ms_hub_url``),
    ``script_url``, then ``file_name`` (a missing ``file_name`` raises ``KeyError``).
    """
    has_hf_url = "hf_hub_url" in info
    has_ms_url = "ms_hub_url" in info
    if has_hf_url or has_ms_url:
        # ModelScope is used when requested and available, or when it is the only hub url.
        if (use_modelscope() and has_ms_url) or (not has_hf_url):
            dataset_attr = DatasetAttr("ms_hub", dataset_name=info["ms_hub_url"])
        else:
            dataset_attr = DatasetAttr("hf_hub", dataset_name=info["hf_hub_url"])
    elif "script_url" in info:
        dataset_attr = DatasetAttr("script", dataset_name=info["script_url"])
    else:
        dataset_attr = DatasetAttr("file", dataset_name=info["file_name"])

    dataset_attr.set_attr("formatting", info, default="alpaca")
    dataset_attr.set_attr("ranking", info, default=False)
    dataset_attr.set_attr("subset", info)
    dataset_attr.set_attr("folder", info)
    dataset_attr.set_attr("num_samples", info)

    if "columns" in info:
        column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
        if dataset_attr.formatting == "alpaca":
            column_names.extend(["prompt", "query", "response", "history"])
        else:
            column_names.extend(["messages"])

        # Once a "columns" mapping is given, every listed column it omits becomes None
        # (see ``DatasetAttr.set_attr``), including those with non-None class defaults.
        for column_name in column_names:
            dataset_attr.set_attr(column_name, info["columns"])

    if dataset_attr.formatting == "sharegpt" and "tags" in info:
        tag_names = (
            "role_tag",
            "content_tag",
            "user_tag",
            "assistant_tag",
            "observation_tag",
            "function_tag",
            "system_tag",
        )
        # Same rule as for columns: tags omitted from a given "tags" mapping become None.
        for tag in tag_names:
            dataset_attr.set_attr(tag, info["tags"])

    return dataset_attr


def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
    r"""
    Builds one ``DatasetAttr`` for each comma-separated name in ``data_args.dataset``.

    The dataset config file (``DATA_CONFIG`` inside ``data_args.dataset_dir``) is read
    unless ``dataset_dir`` is ``"ONLINE"``. Without config info, each name is treated as
    a hub dataset id (ModelScope if ``use_modelscope()`` else Hugging Face).

    Side effect: a string ``data_args.interleave_probs`` is replaced in place by the
    list of floats parsed from it.

    Raises:
        ValueError: if the config file cannot be loaded while datasets are requested,
            or if a requested name is not defined in the config file.
    """
    if data_args.dataset is not None:
        dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
    else:
        dataset_names = []

    if data_args.dataset_dir == "ONLINE":
        dataset_info = None
    else:
        try:
            # Explicit UTF-8: JSON must not be decoded with the platform's locale encoding.
            with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
                dataset_info = json.load(f)
        except Exception as err:
            # A missing/broken config is only fatal when datasets were actually requested.
            if len(dataset_names) != 0:
                raise ValueError(
                    "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
                ) from err

            dataset_info = None

    # Parse only while still a string, so calling this function twice with the same
    # ``data_args`` does not fail on the already-converted list.
    if isinstance(data_args.interleave_probs, str):
        data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]

    dataset_list: List[DatasetAttr] = []
    for name in dataset_names:
        if dataset_info is None:
            load_from = "ms_hub" if use_modelscope() else "hf_hub"
            dataset_list.append(DatasetAttr(load_from, dataset_name=name))
            continue

        if name not in dataset_info:
            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

        dataset_list.append(_parse_dataset_attr(dataset_info[name]))

    return dataset_list