"""A common dataset reader.""" |
|
|
|
from typing import Any, Callable, List, Optional |
|
|
|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
|
|
from official.modeling.hyperparams import config_definitions as cfg |
|
|
|
|
|
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  def __init__(self,
               params: cfg.DataConfig,
               shards: Optional[List[str]] = None,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
                                                       tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
    """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      shards: A list of files to be read. If given, read from these files;
        otherwise, read from params.input_path.
      dataset_fn: A `tf.data.Dataset` source class or factory that consumes
        the input files, for example `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes a serialized data string
        and decodes it into a dictionary of raw tensors.
      parser_fn: An optional `callable` that takes the decoded raw tensor
        dictionary and parses it into a dictionary of tensors consumable by
        the model. It is executed after decoder_fn.
      dataset_transform_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and returns a `tf.data.Dataset`. It is
        executed after parser_fn.
      postprocess_fn: An optional `callable` that processes batched tensors.
        It is executed after batching.
    """
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))
    self._shards = shards
    self._tfds_builder = None
    if self._shards:
      self._num_files = len(self._shards)
    elif not params.tfds_name:
      # `input_path` holds one or more comma-separated glob patterns; count
      # the files they match.
      self._input_patterns = params.input_path.strip().split(',')
      self._num_files = 0
      for input_pattern in self._input_patterns:
        input_pattern = input_pattern.strip()
        if not input_pattern:
          continue
        matched_files = tf.io.gfile.glob(input_pattern)
        if not matched_files:
          raise ValueError('%s does not match any files.' % input_pattern)
        self._num_files += len(matched_files)
      if self._num_files == 0:
        raise ValueError('%s does not match any files.' % params.input_path)
    else:
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._sharding = params.sharding
    self._examples_consume = params.examples_consume
    self._tfds_split = params.tfds_split
    self._tfds_download = params.tfds_download
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._parser_fn = parser_fn
    self._dataset_transform_fn = dataset_transform_fn
    self._postprocess_fn = postprocess_fn

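  # The reader operates in one of two modes, chosen by the DataConfig: either
  # file-based input (`input_path` glob patterns, or explicit `shards` passed
  # to the constructor) or TFDS-based input (`tfds_name` plus `tfds_split`).
  # The two modes are mutually exclusive, as validated above; a usage sketch
  # for the file-based mode is given at the end of this module.
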
  def _read_sharded_files(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset from sharded files."""
    # Build a dataset of file names and, if requested, shard it across input
    # pipelines so that each worker reads a disjoint subset of files.
    if self._shards:
      dataset = tf.data.Dataset.from_tensor_slices(self._shards)
    else:
      dataset = tf.data.Dataset.list_files(
          self._input_patterns, shuffle=self._is_training)
    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    if self._is_training:
      dataset = dataset.repeat()

    # Read records from multiple files in parallel.
    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=self._cycle_length,
        block_length=self._block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def _read_single_file(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset from a single file."""
    dataset = self._dataset_fn(self._shards or self._input_patterns)

    # With a single input file, tf.data's automatic sharding cannot split the
    # work by file, so disable it and shard the records explicitly below when
    # multiple input pipelines are used.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    dataset = dataset.with_options(options)
    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    if self._is_training:
      dataset = dataset.repeat()
    return dataset

  def _read_tfds(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    if self._tfds_download:
      self._tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(
        interleave_cycle_length=self._cycle_length,
        interleave_block_length=self._block_length,
        input_context=input_context)
    decoders = {}
    if self._tfds_skip_decoding_feature:
      for skip_feature in self._tfds_skip_decoding_feature.split(','):
        decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = self._tfds_builder.as_dataset(
        split=self._tfds_split,
        shuffle_files=self._is_training,
        as_supervised=self._tfds_as_supervised,
        decoders=decoders,
        read_config=read_config)
    return dataset

  @property
  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available."""
    if self._tfds_builder:
      return self._tfds_builder.info
    else:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')

  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if self._tfds_builder:
      dataset = self._read_tfds(input_context)
    elif self._num_files > 1:
      dataset = self._read_sharded_files(input_context)
    else:
      assert self._num_files == 1
      dataset = self._read_single_file(input_context)

    if self._cache:
      dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)

    if self._examples_consume > 0:
      dataset = dataset.take(self._examples_consume)

    def maybe_map_fn(dataset, fn):
      return dataset if fn is None else dataset.map(
          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = maybe_map_fn(dataset, self._decoder_fn)
    dataset = maybe_map_fn(dataset, self._parser_fn)

    if self._dataset_transform_fn is not None:
      dataset = self._dataset_transform_fn(dataset)

    # Each input pipeline batches to the per-replica batch size; across all
    # replicas the batches add up to `global_batch_size`.
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size

    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._drop_remainder)
    dataset = maybe_map_fn(dataset, self._postprocess_fn)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
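

if __name__ == '__main__':
  # Minimal usage sketch (illustrative only): write a few hypothetical records
  # to a temporary TFRecord file and read them back through InputReader. It
  # assumes cfg.DataConfig accepts these fields as keyword arguments, as the
  # attribute reads in __init__ suggest.
  import os
  import tempfile

  tmp_file = os.path.join(tempfile.mkdtemp(), 'examples.tfrecord')
  with tf.io.TFRecordWriter(tmp_file) as writer:
    for i in range(4):
      writer.write(('example-%d' % i).encode('utf-8'))

  example_params = cfg.DataConfig(
      input_path=tmp_file, global_batch_size=2, is_training=False)
  reader = InputReader(example_params)
  for serialized_batch in reader.read():
    # Each batch is a [2]-shaped string tensor of serialized records. A real
    # pipeline would pass decoder_fn/parser_fn to turn records into model
    # inputs; a TFDS-backed pipeline would set `tfds_name` and `tfds_split`
    # instead of `input_path`.
    print(serialized_batch)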