import os import datasets from .artifact import Artifact, UnitxtArtifactNotFoundError from .artifact import __file__ as _ from .artifact import fetch_artifact from .blocks import __file__ as _ from .card import __file__ as _ from .catalog import __file__ as _ from .collections import __file__ as _ from .common import __file__ as _ from .dataclass import __file__ as _ from .dict_utils import __file__ as _ from .file_utils import __file__ as _ from .formats import __file__ as _ from .fusion import __file__ as _ from .generator_utils import __file__ as _ from .hf_utils import __file__ as _ from .instructions import __file__ as _ from .load import __file__ as _ from .loaders import __file__ as _ from .metric import __file__ as _ from .metrics import __file__ as _ from .normalizers import __file__ as _ from .operator import __file__ as _ from .operators import __file__ as _ from .processors import __file__ as _ from .random_utils import __file__ as _ from .recipe import __file__ as _ from .register import __file__ as _ from .register import _reset_env_local_catalogs, register_all_artifacts from .renderers import __file__ as _ from .schema import __file__ as _ from .split_utils import __file__ as _ from .splitters import __file__ as _ from .standard import __file__ as _ from .stream import __file__ as _ from .task import __file__ as _ from .templates import __file__ as _ from .text_utils import __file__ as _ from .type_utils import __file__ as _ from .utils import __file__ as _ from .validate import __file__ as _ from .version import __file__ as _ from .version import version __default_recipe__ = "common_recipe" def fetch(artifact_name): try: artifact, _ = fetch_artifact(artifact_name) return artifact except UnitxtArtifactNotFoundError: return None def parse(query: str): """ Parses a query of the form 'key1=value1,key2=value2,...' into a dictionary. """ result = {} kvs = query.split(",") if len(kvs) == 0: raise ValueError( 'Illegal query: "{query}" should contain at least one assignment of the form: key1=value1,key2=value2' ) for kv in kvs: key_val = kv.split("=") if len(key_val) != 2 or len(key_val[0].strip()) == 0 or len(key_val[1].strip()) == 0: raise ValueError('Illegal query: "{query}" with wrong assignment "{kv}" should be of the form: key=value.') key, val = key_val if val.isdigit(): result[key] = int(val) elif val.replace(".", "", 1).isdigit(): result[key] = float(val) else: result[key] = val return result def get_dataset_artifact(dataset_str): _reset_env_local_catalogs() register_all_artifacts() recipe = fetch(dataset_str) if recipe is None: args = parse(dataset_str) if "type" not in args: args["type"] = os.environ.get("UNITXT_DEFAULT_RECIPE", __default_recipe__) recipe = Artifact.from_dict(args) return recipe class Dataset(datasets.GeneratorBasedBuilder): """TODO: Short description of my dataset.""" VERSION = datasets.Version(version) builder_configs = {} @property def generators(self): if not hasattr(self, "_generators") or self._generators is None: try: from unitxt.dataset import ( get_dataset_artifact as get_dataset_artifact_installed, ) unitxt_installed = True except ImportError: unitxt_installed = False if unitxt_installed: print("Loading with installed unitxt library...") dataset = get_dataset_artifact_installed(self.config.name) else: print("Loading with installed unitxt library...") dataset = get_dataset_artifact(self.config.name) self._generators = dataset() return self._generators def _info(self): return datasets.DatasetInfo() def _split_generators(self, _): return [datasets.SplitGenerator(name=name, gen_kwargs={"split_name": name}) for name in self.generators.keys()] def _generate_examples(self, split_name): generator = self.generators[split_name] for i, row in enumerate(generator): yield i, row def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_kwargs): result = super()._download_and_prepare(dl_manager, "no_checks", **prepare_splits_kwargs) return result