# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
""" Metrics base class."""
import os
import types
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pyarrow as pa
from . import config
from .arrow_dataset import Dataset
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter
from .features import Features
from .info import DatasetInfo, MetricInfo
from .naming import camelcase_to_snakecase
from .utils import copyfunc, temp_seed
from .utils.download_manager import DownloadManager
from .utils.file_utils import DownloadConfig
from .utils.filelock import BaseFileLock, FileLock, Timeout
from .utils.logging import get_logger
logger = get_logger(__name__)
class FileFreeLock(BaseFileLock):
    """Inverted file lock: it is obtained only while ``lock_file`` is locked by someone else.

    A regular ``FileLock`` on the same path is used as a probe; when the probe
    cannot be taken, this lock records the probe's lock file as its handle.
    """

    def __init__(self, lock_file, *args, **kwargs):
        # Probe used by `_acquire` to test whether the path is currently locked.
        self.filelock = FileLock(lock_file)
        super().__init__(lock_file, *args, **kwargs)

    def _acquire(self):
        """Try the probe lock once and store the outcome in ``_lock_file_fd``."""
        try:
            # One short attempt only (keyword spelling follows the lock API used by this file).
            self.filelock.acquire(timeout=0.01, poll_intervall=0.02)
        except Timeout:
            # The probe failed, i.e. the file is locked by another holder.
            # NOTE(review): presumably BaseFileLock treats a non-None handle as "acquired" - confirm.
            self._lock_file_fd = self.filelock.lock_file
            return
        # The probe succeeded, so nobody holds the file yet: give the probe back
        # and leave the handle unset.
        self.filelock.release()
        self._lock_file_fd = None

    def _release(self):
        """Forget the handle; the probe itself is never kept locked."""
        self._lock_file_fd = None
class MetricInfoMixin:
    """Mixin re-exporting the fields of a ``MetricInfo`` as read-only properties.

    The wrapped object is stored in ``_metric_info`` and every property simply
    forwards to the attribute of the same purpose on that object.
    """

    def __init__(self, info: MetricInfo):
        self._metric_info = info

    @property
    def info(self):
        """:class:`datasets.MetricInfo` object containing all the metadata in the metric."""
        return self._metric_info

    @property
    def name(self) -> str:
        """Forward of ``info.metric_name``."""
        return self._metric_info.metric_name

    @property
    def experiment_id(self) -> Optional[str]:
        """Forward of ``info.experiment_id``."""
        return self._metric_info.experiment_id

    @property
    def description(self) -> str:
        """Forward of ``info.description``."""
        return self._metric_info.description

    @property
    def citation(self) -> str:
        """Forward of ``info.citation``."""
        return self._metric_info.citation

    @property
    def features(self) -> Features:
        """Forward of ``info.features``."""
        return self._metric_info.features

    @property
    def inputs_description(self) -> str:
        """Forward of ``info.inputs_description``."""
        return self._metric_info.inputs_description

    @property
    def homepage(self) -> Optional[str]:
        """Forward of ``info.homepage``."""
        return self._metric_info.homepage

    @property
    def license(self) -> str:
        """Forward of ``info.license``."""
        return self._metric_info.license

    @property
    def codebase_urls(self) -> Optional[List[str]]:
        """Forward of ``info.codebase_urls``."""
        return self._metric_info.codebase_urls

    @property
    def reference_urls(self) -> Optional[List[str]]:
        """Forward of ``info.reference_urls``."""
        return self._metric_info.reference_urls

    @property
    def streamable(self) -> bool:
        """Forward of ``info.streamable``."""
        return self._metric_info.streamable

    @property
    def format(self) -> Optional[str]:
        """Forward of ``info.format``."""
        return self._metric_info.format
[docs]class Metric(MetricInfoMixin):
"""A Metric is the base class and common API for all metrics.
Args:
config_name (``str``): This is used to define a hash specific to a metric computation script and prevents the metric's data
from being overridden when the metric loading script is modified.
keep_in_memory (``bool``): keep all predictions and references in memory. Not possible in distributed settings.
cache_dir (``str``): Path to a directory in which temporary prediction/references data will be stored.
The data directory should be located on a shared file-system in distributed setups.
num_process (``int``): specify the total number of nodes in a distributed setting.
This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
process_id (``int``): specify the id of the current process in a distributed setup (between 0 and num_process-1)
This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
seed (Optional ``int``): If specified, this will temporarily set numpy's random seed when :func:`datasets.Metric.compute` is run.
experiment_id (``str``): A specific experiment id. This is used if several distributed evaluations share the same file system.
This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
max_concurrent_cache_files (``int``): Max number of concurrent metrics cache files (default 10000).
timeout (``Union[int, float]``): Timeout in seconds for distributed setting synchronization.
"""
def __init__(
    self,
    config_name: Optional[str] = None,
    keep_in_memory: bool = False,
    cache_dir: Optional[str] = None,
    num_process: int = 1,
    process_id: int = 0,
    seed: Optional[int] = None,
    experiment_id: Optional[str] = None,
    max_concurrent_cache_files: int = 10000,
    timeout: Union[int, float] = 100,
    **kwargs,
):
    """Set up the metric's info, cache directory and (still unopened) writer state.

    Args:
        config_name: Configuration name; falls back to ``"default"``.
        keep_in_memory: Keep the data in an in-memory buffer instead of a cache file.
            Rejected when ``num_process > 1``.
        cache_dir: Root of the cache directory; falls back to ``config.HF_METRICS_CACHE``.
        num_process: Total number of processes; must be greater than ``process_id``.
        process_id: Index of this process; must be an int ``>= 0``.
        seed: Seed stored in ``self.seed``. Any int (including ``0``) is kept as given;
            ``None`` picks the first word of numpy's current random state.
        experiment_id: Experiment id; falls back to ``"default_experiment"``.
        max_concurrent_cache_files: Upper bound on cache file names tried when creating one.
        timeout: Timeout used when waiting on other processes' locks.
        **kwargs: Accepted and ignored by this constructor.

    Raises:
        AssertionError: If ``process_id``/``num_process``/``keep_in_memory`` are inconsistent.
    """
    # prepare info
    self.config_name = config_name or "default"
    info = self._info()
    info.metric_name = camelcase_to_snakecase(self.__class__.__name__)
    info.config_name = self.config_name
    info.experiment_id = experiment_id or "default_experiment"
    MetricInfoMixin.__init__(self, info)  # For easy access on low level

    # Safety checks on num_process and process_id
    assert (
        isinstance(process_id, int) and process_id >= 0
    ), "'process_id' should be a number greater than or equal to 0"
    assert (
        isinstance(num_process, int) and num_process > process_id
    ), "'num_process' should be a number greater than process_id"
    assert (
        num_process == 1 or not keep_in_memory
    ), "Using 'keep_in_memory' is not possible in distributed setting (num_process > 1)."

    self.num_process = num_process
    self.process_id = process_id
    self.max_concurrent_cache_files = max_concurrent_cache_files

    self.keep_in_memory = keep_in_memory
    self._data_dir_root = os.path.expanduser(cache_dir or config.HF_METRICS_CACHE)
    self.data_dir = self._build_data_dir()
    # `seed or ...` would silently replace an explicit seed of 0, so test for None instead.
    self.seed: int = seed if seed is not None else np.random.get_state()[1][0]
    self.timeout: Union[int, float] = timeout

    # Update 'compute' and 'add' docstring
    # methods need to be copied otherwise it changes the docstrings of every instance
    self.compute = types.MethodType(copyfunc(self.compute), self)
    self.add_batch = types.MethodType(copyfunc(self.add_batch), self)
    self.add = types.MethodType(copyfunc(self.add), self)
    self.compute.__func__.__doc__ += self.info.inputs_description
    self.add_batch.__func__.__doc__ += self.info.inputs_description
    self.add.__func__.__doc__ += self.info.inputs_description

    # self.arrow_schema = pa.schema(field for field in self.info.features.type)
    self.buf_writer = None
    self.writer = None
    self.writer_batch_size = None
    self.data = None

    # This is the cache file we store our predictions/references in
    # Keep it None for now so we can (cloud)pickle the object
    self.cache_file_name = None
    self.filelock = None
    self.rendez_vous_lock = None

    # This is all the cache files on which we have a lock when we are in a distributed setting
    self.file_paths = None
    self.filelocks = None
def __len__(self):
"""Return the number of examples (predictions or predictions/references pair)
currently stored in the metric's cache.
"""
return 0 if self.writer is None else len(self.writer)
def __repr__(self):
return (
f'Metric(name: "{self.name}", features: {self.features}, '
f'usage: """{self.inputs_description}""", '
f"stored examples: {len(self)})"
)
def _build_data_dir(self):
"""Path of this metric in cache_dir:
Will be:
self._data_dir_root/self.name/self.config_name/self.hash (if not none)/
If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped.
"""
builder_data_dir = self._data_dir_root
builder_data_dir = os.path.join(builder_data_dir, self.name, self.config_name)
os.makedirs(builder_data_dir, exist_ok=True)
return builder_data_dir
def _create_cache_file(self, timeout=1) -> Tuple[str, FileLock]:
    """Pick a cache file path and lock it.

    The first candidate is ``<experiment_id>-<num_process>-<process_id>.arrow`` inside
    ``self.data_dir``. When its lock cannot be taken within ``timeout`` and the metric
    is not distributed, a new candidate with a random uuid in its name is tried, up to
    ``self.max_concurrent_cache_files`` attempts.

    Returns:
        The chosen file path and the lock held on ``<path>.lock``.

    Raises:
        ValueError: If the lock is busy in a distributed setup (``num_process != 1``),
            or if every attempt failed.
    """
    candidate = os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-{self.process_id}.arrow")
    lock = None
    final_attempt = self.max_concurrent_cache_files - 1
    for attempt in range(self.max_concurrent_cache_files):
        lock = FileLock(candidate + ".lock")
        try:
            lock.acquire(timeout=timeout)
        except Timeout:
            # Distributed processes must use their fixed file name, so a busy lock is fatal there.
            if self.num_process != 1:
                raise ValueError(
                    f"Error in _create_cache_file: another metric instance is already using the local cache file at {candidate}. "
                    f"Please specify an experiment_id (currently: {self.experiment_id}) to avoid collision "
                    f"between distributed metric instances."
                ) from None
            # Out of attempts: give up.
            if attempt == final_attempt:
                raise ValueError(
                    f"Cannot acquire lock, too many metric instance are operating concurrently on this file system."
                    f"You should set a larger value of max_concurrent_cache_files when creating the metric "
                    f"(current value is {self.max_concurrent_cache_files})."
                ) from None
            # Otherwise sample a fresh file name and try again.
            suffix = str(uuid.uuid4())
            candidate = os.path.join(
                self.data_dir, f"{self.experiment_id}-{suffix}-{self.num_process}-{self.process_id}.arrow"
            )
        else:
            return candidate, lock
    # Only reached when the loop body never ran (no attempts allowed).
    return candidate, lock
def _get_all_cache_files(self) -> Tuple[List[str], List[FileLock]]:
"""Get a lock on all the cache files in a distributed setup.
We wait for timeout second to let all the distributed node finish their tasks (default is 100 seconds).
"""
if self.num_process == 1:
if self.cache_file_name is None:
raise ValueError(
"Metric cache file doesn't exist. Please make sure that you call `add` or `add_batch` "
"at least once before calling `compute`."
)
file_paths = [self.cache_file_name]
else:
file_paths = [
os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-{process_id}.arrow")
for process_id in range(self.num_process)
]
# Let's acquire a lock on each process files to be sure they are finished writing
filelocks = []
for process_id, file_path in enumerate(file_paths):
if process_id == 0: # process 0 already has its lock file
filelocks.append(self.filelock)
else:
filelock = FileLock(file_path + ".lock")
try:
filelock.acquire(timeout=self.timeout)
except Timeout:
raise ValueError(
f"Cannot acquire lock on cached file {file_path} for process {process_id}."
) from None
else:
filelocks.append(filelock)
return file_paths, filelocks
def _check_all_processes_locks(self):
    """Wait until the lock file of every process (0 .. num_process-1) is held.

    Each expected ``.arrow.lock`` file is probed with a ``FileFreeLock`` for up
    to ``self.timeout``.

    Raises:
        ValueError: If one of the expected lock files is not locked in time.
    """
    lock_names = [
        os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-{process_id}.arrow.lock")
        for process_id in range(self.num_process)
    ]
    for lock_name in lock_names:
        probe = FileFreeLock(lock_name)
        try:
            probe.acquire(timeout=self.timeout)
        except Timeout:
            raise ValueError(
                f"Expected to find locked file {lock_name} from process {self.process_id} but it doesn't exist."
            ) from None
        probe.release()
def _check_rendez_vous(self):
    """Synchronize with process 0 through its cache lock and the rendez-vous lock.

    First waits (up to ``self.timeout``) for process 0's ``.arrow.lock`` file to be
    held, then takes and immediately gives back the ``-rdv.lock`` file.

    Raises:
        ValueError: If either step times out.
    """
    prefix = f"{self.experiment_id}-{self.num_process}"

    # Step 1: process 0 must be holding its cache file lock.
    master_lock_name = os.path.join(self.data_dir, f"{prefix}-0.arrow.lock")
    probe = FileFreeLock(master_lock_name)
    try:
        probe.acquire(timeout=self.timeout)
    except Timeout:
        raise ValueError(
            f"Expected to find locked file {master_lock_name} from process {self.process_id} but it doesn't exist."
        ) from None
    probe.release()

    # Step 2: pass through the rendez-vous lock.
    rdv_name = os.path.join(self.data_dir, f"{prefix}-rdv.lock")
    rdv_lock = FileLock(rdv_name)
    try:
        rdv_lock.acquire(timeout=self.timeout)
    except Timeout:
        raise ValueError(f"Couldn't acquire lock on {rdv_name} from process {self.process_id}.") from None
    rdv_lock.release()
def _finalize(self):
"""Close all the writing process and load/gather the data
from all the nodes if main node or all_process is True.
"""
if self.writer is not None:
self.writer.finalize()
self.writer = None
# release the locks of the processes > 0 so that process 0 can lock them to read + delete the data
if self.filelock is not None and self.process_id > 0:
self.filelock.release()
if self.keep_in_memory:
# Read the predictions and references
reader = ArrowReader(path=self.data_dir, info=DatasetInfo(features=self.features))
self.data = Dataset.from_buffer(self.buf_writer.getvalue())
elif self.process_id == 0:
# Let's acquire a lock on each node files to be sure they are finished writing
file_paths, filelocks = self._get_all_cache_files()
# Read the predictions and references
try:
reader = ArrowReader(path="", info=DatasetInfo(features=self.features))
self.data = Dataset(**reader.read_files([{"filename": f} for f in file_paths]))
except FileNotFoundError:
raise ValueError(
"Error in finalize: another metric instance is already using the local cache file. "
"Please specify an experiment_id to avoid collision between distributed metric instances."
) from None
# Store file paths and locks and we will release/delete them after the computation.
self.file_paths = file_paths
self.filelocks = filelocks
def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[dict]:
    """Compute the metrics.

    Usage of positional arguments is not allowed to prevent mistakes.

    Args:
        predictions (list/array/tensor, optional): Predictions.
        references (list/array/tensor, optional): References.
        **kwargs (optional): Keyword arguments that will be forwarded to the metrics :meth:`_compute`
            method (see details in the docstring).

    Return:
        dict or None

        - Dictionary with the metrics if this metric is run on the main process (``process_id == 0``).
        - None if the metric is not run on the main process (``process_id != 0``).
    """
    if predictions is not None:
        self.add_batch(predictions=predictions, references=references)
    self._finalize()

    self.cache_file_name = None
    self.filelock = None

    if self.process_id != 0:
        return None

    try:
        self.data.set_format(type=self.info.format)

        predictions = self.data["predictions"]
        references = self.data["references"]
        with temp_seed(self.seed):
            output = self._compute(predictions=predictions, references=references, **kwargs)
    finally:
        # Clean up even when `_compute` raises, so cache files and locks are not left behind.
        in_memory = self.buf_writer is not None
        self.buf_writer = None
        # Drop the loaded data (and writer) once, before any file is removed.
        self.data = None
        if not in_memory:
            self.writer = None
            # Release locks and delete all the cache files. Process 0 is released last.
            for filelock, file_path in reversed(list(zip(self.filelocks, self.file_paths))):
                logger.info(f"Removing {file_path}")
                try:
                    os.remove(file_path)
                finally:
                    # The lock is given back even if the file could not be removed.
                    filelock.release()

    return output
def add_batch(self, *, predictions=None, references=None):
    """Add a batch of predictions and references for the metric's stack.

    The batch is encoded with the metric's features and written through the
    writer, which is initialized on first use.

    Args:
        predictions (list/array/tensor, optional): Predictions.
        references (list/array/tensor, optional): References.

    Raises:
        ValueError: If the writer rejects the batch with ``pa.ArrowInvalid``.
    """
    encoded = self.info.features.encode_batch({"predictions": predictions, "references": references})
    if self.writer is None:
        self._init_writer()
    try:
        self.writer.write_batch(encoded)
    except pa.ArrowInvalid:
        message = (
            f"Predictions and/or references don't match the expected format.\n"
            f"Expected format: {self.features},\n"
            f"Input predictions: {predictions},\n"
            f"Input references: {references}"
        )
        raise ValueError(message) from None
def add(self, *, prediction=None, reference=None):
    """Add one prediction and reference for the metric's stack.

    The example is encoded with the metric's features and written through the
    writer, which is initialized on first use.

    Args:
        prediction (list/array/tensor, optional): Predictions.
        reference (list/array/tensor, optional): References.

    Raises:
        ValueError: If the writer rejects the example with ``pa.ArrowInvalid``.
    """
    encoded = self.info.features.encode_example({"predictions": prediction, "references": reference})
    if self.writer is None:
        self._init_writer()
    try:
        self.writer.write(encoded)
    except pa.ArrowInvalid:
        message = (
            f"Prediction and/or reference don't match the expected format.\n"
            f"Expected format: {self.features},\n"
            f"Input predictions: {prediction},\n"
            f"Input references: {reference}"
        )
        raise ValueError(message) from None
def _init_writer(self, timeout=1):
    """Create the writer (in-memory or cache-file backed) and synchronize processes.

    In a distributed setup (``num_process > 1``) process 0 first takes the
    rendez-vous lock, then - once its writer exists - waits for every process
    to hold its cache file lock and releases the rendez-vous lock. The other
    processes wait for process 0 through ``_check_rendez_vous``.

    Args:
        timeout: Seconds to wait for the rendez-vous lock (process 0 only).

    Raises:
        ValueError: If the rendez-vous lock cannot be acquired in time.
    """
    if self.num_process > 1:
        if self.process_id == 0:
            file_path = os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-rdv.lock")
            self.rendez_vous_lock = FileLock(file_path)
            try:
                self.rendez_vous_lock.acquire(timeout=timeout)
            except (Timeout, TimeoutError):
                # `Timeout` is what the other lock call sites in this class catch; the builtin
                # `TimeoutError` is kept so anything handled before is still handled.
                raise ValueError(
                    f"Error in _init_writer: another metric instance is already using the local cache file at {file_path}. "
                    f"Please specify an experiment_id (currently: {self.experiment_id}) to avoid collision "
                    f"between distributed metric instances."
                ) from None

    if self.keep_in_memory:
        self.buf_writer = pa.BufferOutputStream()
        self.writer = ArrowWriter(
            features=self.info.features, stream=self.buf_writer, writer_batch_size=self.writer_batch_size
        )
    else:
        self.buf_writer = None

        # Get cache file name and lock it
        if self.cache_file_name is None or self.filelock is None:
            cache_file_name, filelock = self._create_cache_file()  # get ready
            self.cache_file_name = cache_file_name
            self.filelock = filelock

        self.writer = ArrowWriter(
            features=self.info.features, path=self.cache_file_name, writer_batch_size=self.writer_batch_size
        )
    # Rendez-vous between processes, only needed in a distributed setup
    if self.num_process > 1:
        if self.process_id == 0:
            self._check_all_processes_locks()  # wait for everyone to be ready
            self.rendez_vous_lock.release()  # let everyone go
        else:
            self._check_rendez_vous()  # wait for master to be ready and to let everyone go
def _info(self) -> MetricInfo:
"""Construct the MetricInfo object. See `MetricInfo` for details.
Warning: This function is only called once and the result is cached for all
following .info() calls.
Returns:
info: (MetricInfo) The metrics information
"""
raise NotImplementedError
def download_and_prepare(
    self,
    download_config: Optional[DownloadConfig] = None,
    dl_manager: Optional[DownloadManager] = None,
):
    """Download the resources of the metric via :meth:`_download_and_prepare`.

    When no ``dl_manager`` is given, one is built for ``self.name`` and
    ``self.data_dir``. If ``download_config`` is missing as well, a default
    one is created that caches under ``<data_dir>/downloads`` and does not
    force re-downloads.

    Args:
        download_config (:class:`DownloadConfig`, optional): Specific download configuration parameters.
        dl_manager (:class:`DownloadManager`, optional): Specific download manager to use.
    """
    manager = dl_manager
    if manager is None:
        settings = download_config
        if settings is None:
            settings = DownloadConfig()
            settings.cache_dir = os.path.join(self.data_dir, "downloads")
            settings.force_download = False

        manager = DownloadManager(dataset_name=self.name, download_config=settings, data_dir=self.data_dir)

    self._download_and_prepare(manager)
def _download_and_prepare(self, dl_manager):
"""Downloads and prepares resources for the metric.
This is the internal implementation to overwrite called when user calls
`download_and_prepare`. It should download all required resources for the metric.
Args:
dl_manager (:class:`DownloadManager`): `DownloadManager` used to download and cache data.
"""
return None
def _compute(self, *, predictions=None, references=None, **kwargs) -> Dict[str, Any]:
"""This method defines the common API for all the metrics in the library"""
raise NotImplementedError
def __del__(self):
if hasattr(self, "filelock") and self.filelock is not None:
self.filelock.release()
if hasattr(self, "rendez_vous_lock") and self.rendez_vous_lock is not None:
self.rendez_vous_lock.release()
if hasattr(self, "writer"): # in case it was already deleted
del self.writer
if hasattr(self, "data"): # in case it was already deleted
del self.data