|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Core data functions, dispatch calls to the requested dataset.""" |
|
import importlib |
|
|
|
|
|
|
|
|
|
class DataSource:
    """The API that any data source should implement.

    Every member of this base class raises `RuntimeError`; concrete data
    sources override the members they support.
    """

    def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
        """Creates this data object as a tf.data.Dataset.

        This will be called separately in each process, and it is up to the
        dataset implementation to shard it accordingly if desired!

        Args:
          ordered: if True, the dataset should use deterministic ordering, if
            False it may have undefined ordering. Think of True == val,
            False == train.
          process_split: if False then every process receives the entire
            dataset (e.g. for evaluators running in a single process).
          allow_cache: whether to allow caching the opened data or not.

        Returns:
          A tf.data.Dataset object.

        Raises:
          RuntimeError: if not implemented by the dataset, but called.
        """
        # f-string so the message names the concrete subclass that is missing
        # the override, rather than printing the placeholder verbatim.
        raise RuntimeError(f"not implemented for {self.__class__.__name__}")

    @property
    def total_examples(self):
        """Returns number of examples in the dataset, regardless of sharding.

        Raises:
          RuntimeError: if not implemented by the dataset, but accessed.
        """
        raise RuntimeError(f"not implemented for {self.__class__.__name__}")

    def num_examples_per_process(self):
        """Returns a list of the number of examples for each process.

        This is only needed for datasets that should go through
        make_for_inference.

        Returns:
          Returns a list of the number of examples for each process.

          Ideally, this would always be `[total_examples / nprocess] * nprocess`,
          but in reality we can almost never perfectly shard a dataset across
          arbitrary number of processes.

          One alternative option that can work in some cases is to not even
          shard the dataset and thus return `[total_examples] * nprocess`.

        Raises:
          RuntimeError: if not implemented by the dataset, but called.
        """
        raise RuntimeError(f"not implemented for {self.__class__.__name__}")
|
|
|
|
|
def get(name, **kw):
    """Instantiates the `DataSource` of the requested dataset.

    Args:
      name: dataset identifier. A name of the form "bv:<module>" loads
        `big_vision.datasets.<module>`; any other name is handed to the
        `big_vision.datasets.tfds` module as the first positional argument.
      **kw: keyword arguments forwarded unchanged to the `DataSource`
        constructor of the selected module.

    Returns:
      The object returned by the selected module's `DataSource(...)` call.
    """
    bv_prefix = "bv:"

    # Anything without the "bv:" marker is delegated to the tfds-backed source,
    # which receives the dataset name itself.
    if not name.startswith(bv_prefix):
        tfds_module = importlib.import_module("big_vision.datasets.tfds")
        return tfds_module.DataSource(name, **kw)

    # The text after the prefix is the module name inside big_vision.datasets.
    module_path = "big_vision.datasets." + name[len(bv_prefix):]
    custom_module = importlib.import_module(module_path)
    return custom_module.DataSource(**kw)
|
|