# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
""" DatasetInfo and MetricInfo record information we know about a dataset and a metric.
This includes things that we know about the dataset statically, i.e.:
- description
- canonical location
- does it have validation and test splits
- size
- etc.
This also includes the things that can and should be computed once we've
processed the dataset as well:
- number of examples (in each split)
- etc.
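
Example (an illustrative sketch; assuming the package is importable as ``datasets``)::

    from datasets import DatasetInfo

    info = DatasetInfo(description="A toy dataset.", homepage="https://example.com")
    print(info.description)  # static metadata, set in the dataset script
    print(info.splits)       # None until the dataset has been processed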
"""
import copy
import dataclasses
import json
import os
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union
from .features import Features, Value
from .splits import SplitDict
from .utils import Version
from .utils.logging import get_logger
logger = get_logger(__name__)
# Name of the file to output the DatasetInfo JSON object.
DATASET_INFO_FILENAME = "dataset_info.json"
DATASET_INFOS_DICT_FILE_NAME = "dataset_infos.json"
LICENSE_FILENAME = "LICENSE"
METRIC_INFO_FILENAME = "metric_info.json"
@dataclass
class SupervisedKeysData:
input: str = ""
output: str = ""
@dataclass
class DownloadChecksumsEntryData:
key: str = ""
value: str = ""
class MissingCachedSizesConfigError(Exception):
"""The expected cached sizes of the download file are missing."""
class NonMatchingCachedSizesError(Exception):
"""The prepared split doesn't have expected sizes."""
@dataclass
class PostProcessedInfo:
features: Optional[Features] = None
resources_checksums: Optional[dict] = None
def __post_init__(self):
# Convert back to the correct classes when we reload from dict
if self.features is not None and not isinstance(self.features, Features):
self.features = Features.from_dict(self.features)
@classmethod
def from_dict(cls, post_processed_info_dict: dict) -> "PostProcessedInfo":
field_names = set(f.name for f in dataclasses.fields(cls))
return cls(**{k: v for k, v in post_processed_info_dict.items() if k in field_names})
@dataclass
class DatasetInfo:
"""Information about a dataset.
    `DatasetInfo` documents a dataset, including its name, version, and features.
See the constructor arguments and properties for a full list.
Note: Not all fields are known on construction and may be updated later.
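
    Example (an illustrative sketch; names and values are hypothetical)::

        from datasets import DatasetInfo, Features, Value

        info = DatasetInfo(
            description="A toy sentiment dataset.",
            features=Features({"text": Value("string"), "label": Value("int64")}),
            supervised_keys=("text", "label"),
        )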
"""
# Set in the dataset scripts
description: str = field(default_factory=str)
citation: str = field(default_factory=str)
homepage: str = field(default_factory=str)
license: str = field(default_factory=str)
features: Optional[Features] = None
post_processed: Optional[PostProcessedInfo] = None
supervised_keys: Optional[SupervisedKeysData] = None
# Set later by the builder
builder_name: Optional[str] = None
config_name: Optional[str] = None
version: Optional[Union[str, Version]] = None
# Set later by `download_and_prepare`
splits: Optional[dict] = None
download_checksums: Optional[dict] = None
download_size: Optional[int] = None
post_processing_size: Optional[int] = None
dataset_size: Optional[int] = None
size_in_bytes: Optional[int] = None
def __post_init__(self):
# Convert back to the correct classes when we reload from dict
if self.features is not None and not isinstance(self.features, Features):
self.features = Features.from_dict(self.features)
if self.post_processed is not None and not isinstance(self.post_processed, PostProcessedInfo):
self.post_processed = PostProcessedInfo.from_dict(self.post_processed)
if self.version is not None and not isinstance(self.version, Version):
if isinstance(self.version, str):
self.version = Version(self.version)
else:
self.version = Version.from_dict(self.version)
if self.splits is not None and not isinstance(self.splits, SplitDict):
self.splits = SplitDict.from_split_dict(self.splits)
if self.supervised_keys is not None and not isinstance(self.supervised_keys, SupervisedKeysData):
if isinstance(self.supervised_keys, (tuple, list)):
self.supervised_keys = SupervisedKeysData(*self.supervised_keys)
else:
self.supervised_keys = SupervisedKeysData(**self.supervised_keys)
def _license_path(self, dataset_info_dir):
return os.path.join(dataset_info_dir, LICENSE_FILENAME)
    def write_to_directory(self, dataset_info_dir):
"""Write `DatasetInfo` as JSON to `dataset_info_dir`.
        Also save the license separately in LICENSE.
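
        Example (``./my_dataset_dir`` is a hypothetical, pre-existing directory)::

            info.write_to_directory("./my_dataset_dir")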
"""
with open(os.path.join(dataset_info_dir, DATASET_INFO_FILENAME), "wb") as f:
self._dump_info(f)
with open(os.path.join(dataset_info_dir, LICENSE_FILENAME), "wb") as f:
self._dump_license(f)
def _dump_info(self, file):
"""Dump info in `file` file-like object open in bytes mode (to support remote files)"""
file.write(json.dumps(asdict(self)).encode("utf-8"))
def _dump_license(self, file):
"""Dump license in `file` file-like object open in bytes mode (to support remote files)"""
file.write(self.license.encode("utf-8"))
@classmethod
def from_merge(cls, dataset_infos: List["DatasetInfo"]):
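        """Create a single `DatasetInfo` by merging several others, concatenating
        their text fields (description, citation, homepage, license) and leaving
        `features` and `supervised_keys` unset.
        """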
dataset_infos = [dset_info.copy() for dset_info in dataset_infos if dset_info is not None]
description = "\n\n".join([info.description for info in dataset_infos])
citation = "\n\n".join([info.citation for info in dataset_infos])
homepage = "\n\n".join([info.homepage for info in dataset_infos])
license = "\n\n".join([info.license for info in dataset_infos])
features = None
supervised_keys = None
return cls(
description=description,
citation=citation,
homepage=homepage,
license=license,
features=features,
supervised_keys=supervised_keys,
)
    @classmethod
    def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo":
"""Create DatasetInfo from the JSON file in `dataset_info_dir`.
This function updates all the dynamically generated fields (num_examples,
hash, time of creation,...) of the DatasetInfo.
This will overwrite all previous metadata.
Args:
dataset_info_dir: `str` The directory containing the metadata file. This
should be the root directory of a specific dataset version.
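
        Example (``./my_dataset_dir`` is a hypothetical directory that already
        contains a ``dataset_info.json`` written by `write_to_directory`)::

            info = DatasetInfo.from_directory("./my_dataset_dir")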
"""
logger.info("Loading Dataset info from %s", dataset_info_dir)
if not dataset_info_dir:
raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.")
with open(os.path.join(dataset_info_dir, DATASET_INFO_FILENAME), "r", encoding="utf-8") as f:
dataset_info_dict = json.load(f)
return cls.from_dict(dataset_info_dict)
@classmethod
def from_dict(cls, dataset_info_dict: dict) -> "DatasetInfo":
field_names = set(f.name for f in dataclasses.fields(cls))
return cls(**{k: v for k, v in dataset_info_dict.items() if k in field_names})
def update(self, other_dataset_info: "DatasetInfo", ignore_none=True):
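        """Update this `DatasetInfo` in place with the attributes of `other_dataset_info`.

        Args:
            other_dataset_info: `DatasetInfo` whose attributes are deep-copied over.
            ignore_none: `bool`, if True (default), attributes of `other_dataset_info`
                that are None are skipped.
        """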
self_dict = self.__dict__
self_dict.update(
**{
k: copy.deepcopy(v)
for k, v in other_dataset_info.__dict__.items()
if (v is not None or not ignore_none)
}
)
def copy(self) -> "DatasetInfo":
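        """Return a deep copy of this `DatasetInfo`."""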
return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
class DatasetInfosDict(dict):
def write_to_directory(self, dataset_infos_dir, overwrite=False):
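        """Write this dict as `dataset_infos.json` in `dataset_infos_dir`, mapping
        each config name to its `DatasetInfo`. Unless `overwrite` is True, infos
        already present in the file are kept and completed with the ones from this dict.
        """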
total_dataset_infos = {}
dataset_infos_path = os.path.join(dataset_infos_dir, DATASET_INFOS_DICT_FILE_NAME)
if os.path.exists(dataset_infos_path) and not overwrite:
logger.info("Dataset Infos already exists in {}. Completing it with new infos.".format(dataset_infos_dir))
total_dataset_infos = self.from_directory(dataset_infos_dir)
else:
logger.info("Writing new Dataset Infos in {}".format(dataset_infos_dir))
total_dataset_infos.update(self)
with open(dataset_infos_path, "w", encoding="utf-8") as f:
json.dump({config_name: asdict(dset_info) for config_name, dset_info in total_dataset_infos.items()}, f)
@classmethod
def from_directory(cls, dataset_infos_dir):
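        """Create a `DatasetInfosDict` from the `dataset_infos.json` file in `dataset_infos_dir`."""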
logger.info("Loading Dataset Infos from {}".format(dataset_infos_dir))
with open(os.path.join(dataset_infos_dir, DATASET_INFOS_DICT_FILE_NAME), "r", encoding="utf-8") as f:
dataset_infos_dict = {
config_name: DatasetInfo.from_dict(dataset_info_dict)
for config_name, dataset_info_dict in json.load(f).items()
}
return cls(**dataset_infos_dict)
@dataclass
class MetricInfo:
"""Information about a metric.
`MetricInfo` documents a metric, including its name, version, and features.
See the constructor arguments and properties for a full list.
Note: Not all fields are known on construction and may be updated later.
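
    Example (an illustrative sketch; the feature schema is a minimal assumption)::

        from datasets import Features, Value

        info = MetricInfo(
            description="A toy accuracy metric.",
            citation="",
            features=Features({"predictions": Value("int64"), "references": Value("int64")}),
        )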
"""
# Set in the dataset scripts
description: str
citation: str
features: Features
inputs_description: str = field(default_factory=str)
homepage: str = field(default_factory=str)
license: str = field(default_factory=str)
codebase_urls: List[str] = field(default_factory=list)
reference_urls: List[str] = field(default_factory=list)
streamable: bool = False
format: Optional[str] = None
# Set later by the builder
metric_name: Optional[str] = None
config_name: Optional[str] = None
experiment_id: Optional[str] = None
def __post_init__(self):
assert "predictions" in self.features, "Need to have at least a 'predictions' field in 'features'."
if self.format is not None:
for key, value in self.features.items():
if not isinstance(value, Value):
raise ValueError(
f"When using 'numpy' format, all features should be a `datasets.Value` feature. "
f"Here {key} is an instance of {value.__class__.__name__}"
)
    def write_to_directory(self, metric_info_dir):
"""Write `MetricInfo` as JSON to `metric_info_dir`.
        Also save the license separately in LICENSE.
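
        Example (``./my_metric_dir`` is a hypothetical, pre-existing directory)::

            info.write_to_directory("./my_metric_dir")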
"""
with open(os.path.join(metric_info_dir, METRIC_INFO_FILENAME), "w", encoding="utf-8") as f:
json.dump(asdict(self), f)
with open(os.path.join(metric_info_dir, LICENSE_FILENAME), "w", encoding="utf-8") as f:
f.write(self.license)
    @classmethod
def from_directory(cls, metric_info_dir) -> "MetricInfo":
"""Create MetricInfo from the JSON file in `metric_info_dir`.
Args:
metric_info_dir: `str` The directory containing the metadata file. This
                should be the root directory of a specific metric version.
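
        Example (``./my_metric_dir`` is a hypothetical directory that already
        contains a ``metric_info.json``)::

            info = MetricInfo.from_directory("./my_metric_dir")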
"""
logger.info("Loading Metric info from %s", metric_info_dir)
if not metric_info_dir:
raise ValueError("Calling MetricInfo.from_directory() with undefined metric_info_dir.")
with open(os.path.join(metric_info_dir, METRIC_INFO_FILENAME), "r", encoding="utf-8") as f:
metric_info_dict = json.load(f)
return cls.from_dict(metric_info_dict)
@classmethod
def from_dict(cls, metric_info_dict: dict) -> "MetricInfo":
field_names = set(f.name for f in dataclasses.fields(cls))
return cls(**{k: v for k, v in metric_info_dict.items() if k in field_names})