pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core data functions, dispatch calls to the requested dataset."""
import importlib
# Note: intentionally not using ABC to avoid forcing implementation of every
# method, since one can imagine train-only datasets for example.
class DataSource:
"""The API that any data source should implement."""
def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
"""Creates this data object as a tf.data.Dataset.
This will be called separately in each process, and it is up to the dataset
implementation to shard it accordingly if desired!
Args:
ordered: if True, the dataset should use deterministic ordering, if False
it may have undefined ordering. Think of True == val, False == train.
process_split: if False then every process receives the entire dataset
(e.g. for evaluators running in a single process).
allow_cache: whether to allow caching the opened data or not.
Returns:
A tf.data.Dataset object.
Raises:
RuntimeError: if not implemented by the dataset, but called.
"""
raise RuntimeError("not implemented for {self.__class__.__name__}")
@property
def total_examples(self):
"""Returns number of examples in the dataset, regardless of sharding."""
raise RuntimeError("not implemented for {self.__class__.__name__}")
def num_examples_per_process(self):
"""Returns a list of the numer of examples for each process.
This is only needed for datasets that should go through make_for_inference.
Returns:
Returns a list of the numer of examples for each process.
Ideally, this would always be `[total() / nprocess] * nprocess`, but in
reality we can almost never perfectly shard a dataset across arbitrary
number of processes.
One alternative option that can work in some cases is to not even shard
the dataset and thus return `[num_examples()] * nprocess.
Raises:
RuntimeError: if not implemented by the dataset, but called.
"""
raise RuntimeError("not implemented for {self.__class__.__name__}")
def get(name, **kw):
if name.startswith("bv:"):
mod = importlib.import_module(f"big_vision.datasets.{name[3:]}")
return mod.DataSource(**kw)
else:
mod = importlib.import_module("big_vision.datasets.tfds")
return mod.DataSource(name, **kw)