# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline for TFDS datasets."""

import functools
import os
from typing import Dict, List, Tuple

from clu import deterministic_data
from clu import preprocess_spec
import jax
import jax.numpy as jnp
import ml_collections
import sunds
import tensorflow as tf
import tensorflow_datasets as tfds

from invariant_slot_attention.lib import preprocessing

Array = jnp.ndarray
PRNGKey = Array


PATH_CLEVR_WITH_MASKS = "gs://multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords"

FEATURES_CLEVR_WITH_MASKS = {
    "image": tf.io.FixedLenFeature([240, 320, 3], tf.string),
    "mask": tf.io.FixedLenFeature([11, 240, 320, 1], tf.string),
    "x": tf.io.FixedLenFeature([11], tf.float32),
    "y": tf.io.FixedLenFeature([11], tf.float32),
    "z": tf.io.FixedLenFeature([11], tf.float32),
    "pixel_coords": tf.io.FixedLenFeature([11, 3], tf.float32),
    "rotation": tf.io.FixedLenFeature([11], tf.float32),
    "size": tf.io.FixedLenFeature([11], tf.string),
    "material": tf.io.FixedLenFeature([11], tf.string),
    "shape": tf.io.FixedLenFeature([11], tf.string),
    "color": tf.io.FixedLenFeature([11], tf.string),
    "visibility": tf.io.FixedLenFeature([11], tf.float32),
}

PATH_TETROMINOES = "gs://multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords"

FEATURES_TETROMINOES = {
    "image": tf.io.FixedLenFeature([35, 35, 3], tf.string),
    "mask": tf.io.FixedLenFeature([4, 35, 35, 1], tf.string),
    "x": tf.io.FixedLenFeature([4], tf.float32),
    "y": tf.io.FixedLenFeature([4], tf.float32),
    "shape": tf.io.FixedLenFeature([4], tf.float32),
    "color": tf.io.FixedLenFeature([4, 3], tf.float32),
    "visibility": tf.io.FixedLenFeature([4], tf.float32),
}

PATH_OBJECTS_ROOM = "gs://multi-object-datasets/objects_room/objects_room_train.tfrecords"

FEATURES_OBJECTS_ROOM = {
    "image": tf.io.FixedLenFeature([64, 64, 3], tf.string),
    "mask": tf.io.FixedLenFeature([7, 64, 64, 1], tf.string),
}

PATH_WAYMO_OPEN = "datasets/waymo_v_1_4_0_images/tfrecords"

FEATURES_WAYMO_OPEN = {
    "image": tf.io.FixedLenFeature([128, 192, 3], tf.string),
    "segmentations": tf.io.FixedLenFeature([128, 192], tf.string),
    "depth": tf.io.FixedLenFeature([128, 192], tf.float32),
    "num_objects": tf.io.FixedLenFeature([1], tf.int64),
    "has_mask": tf.io.FixedLenFeature([1], tf.int64),
    "camera": tf.io.FixedLenFeature([1], tf.int64),
}


def _decode_tetrominoes(example_proto):
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_TETROMINOES)
  for k in ["mask", "image"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  return single_example


def _decode_objects_room(example_proto):
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_OBJECTS_ROOM)
  for k in ["mask", "image"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  return single_example


def _decode_clevr_with_masks(example_proto):
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_CLEVR_WITH_MASKS)
  for k in ["mask", "image", "color", "material", "shape", "size"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  return single_example


def _decode_waymo_open(example_proto):
  """Deserializes a serialized tf.train.Example sample."""
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_WAYMO_OPEN)
  for k in ["image", "segmentations"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  single_example["segmentations"] = tf.expand_dims(
      single_example["segmentations"], axis=-1)
  single_example["depth"] = tf.expand_dims(
      single_example["depth"], axis=-1)
  return single_example


def _preprocess_minimal(example):
  """Keeps the image and converts per-object masks into a segmentation map."""
  return {
      "image": example["image"],
      "segmentations": tf.cast(tf.argmax(example["mask"], axis=0), tf.uint8),
  }


def _sunds_create_task():
  """Create a sunds task to return images and instance segmentation."""
  return sunds.tasks.Nerf(
      yield_mode=sunds.tasks.YieldMode.IMAGE,
      additional_camera_specs={
          "depth_image": False,  # Not available in the dataset.
          "category_image": False,  # Not available in the dataset.
          "instance_image": True,
          "extrinsics": True,
      },
      additional_frame_specs={"pose": True},
      add_name=True)


def preprocess_example(features, preprocess_strs):
  """Processes a single data example.

  Args:
    features: A dictionary containing the tensors of a single data example.
    preprocess_strs: List of strings, describing one preprocessing operation
      each, in clu.preprocess_spec format.

  Returns:
    Dictionary containing the preprocessed tensors of a single data example.
  """
  all_ops = preprocessing.all_ops()
  preprocess_fn = preprocess_spec.parse("|".join(preprocess_strs), all_ops)
  return preprocess_fn(features)  # pytype: disable=bad-return-type  # allow-recursive-types


def get_batch_dims(global_batch_size):
  """Gets the first two axis sizes for data batches.

  Args:
    global_batch_size: Integer, the global batch size (across all devices).

  Returns:
    List of batch dimensions.

  Raises:
    ValueError: If the requested dimensions don't make sense with the number
      of devices.
  """
  num_local_devices = jax.local_device_count()
  if global_batch_size % jax.host_count() != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible by {jax.host_count()} hosts.")
  per_host_batch_size = global_batch_size // jax.host_count()
  if per_host_batch_size % num_local_devices != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible across {jax.host_count()} hosts with a "
                     f"per-host batch size of {per_host_batch_size} and "
                     f"{num_local_devices} local devices.")
  return [num_local_devices, per_host_batch_size // num_local_devices]


def create_datasets(config, data_rng):
  """Create datasets for training and evaluation.

  For the same data_rng and config this will return the same datasets. The
  datasets only contain stateless operations.

  Args:
    config: Configuration to use.
    data_rng: JAX PRNGKey for dataset pipeline.

  Returns:
    A tuple with the training dataset and the evaluation dataset.
  """
  if config.data.dataset_name == "tetrominoes":
    ds = tf.data.TFRecordDataset(
        PATH_TETROMINOES,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_tetrominoes,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class TetrominoesBuilder:
      """Builder for tetrominoes dataset."""

      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        """Simple function to conform to the builder API."""
        if split == "train":
          # We use 512 training examples.
          # The first 100 records are reserved for the validation split below.
          ds = ds.skip(100)
          ds = ds.take(512)
          return tf.data.experimental.assert_cardinality(512)(ds)
        elif split == "validation":
          # 100 validation examples.
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = TetrominoesBuilder()
  elif config.data.dataset_name == "objects_room":
    ds = tf.data.TFRecordDataset(
        PATH_OBJECTS_ROOM,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_objects_room,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class ObjectsRoomBuilder:
      """Builder for objects room dataset."""

      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        """Simple function to conform to the builder API."""
        if split == "train":
          # 1M - 100 training examples.
          ds = ds.skip(100)
          return tf.data.experimental.assert_cardinality(999900)(ds)
        elif split == "validation":
          # 100 validation examples.
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = ObjectsRoomBuilder()
  elif config.data.dataset_name == "clevr_with_masks":
    ds = tf.data.TFRecordDataset(
        PATH_CLEVR_WITH_MASKS,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_clevr_with_masks,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class CLEVRWithMasksBuilder:
      """Builder for clevr_with_masks dataset."""

      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        if split == "train":
          ds = ds.skip(100)
          return tf.data.experimental.assert_cardinality(99900)(ds)
        elif split == "validation":
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = CLEVRWithMasksBuilder()
  elif config.data.dataset_name == "waymo_open":
    train_path = os.path.join(
        PATH_WAYMO_OPEN, "training/camera_1/*tfrecords*")
    eval_path = os.path.join(
        PATH_WAYMO_OPEN, "validation/camera_1/*tfrecords*")
    train_files = tf.data.Dataset.list_files(train_path)
    eval_files = tf.data.Dataset.list_files(eval_path)

    train_data_reader = functools.partial(
        tf.data.TFRecordDataset,
        compression_type="ZLIB", buffer_size=2*(2**20))
    eval_data_reader = functools.partial(
        tf.data.TFRecordDataset,
        compression_type="ZLIB", buffer_size=2*(2**20))

    train_dataset = train_files.interleave(
        train_data_reader,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_dataset = eval_files.interleave(
        eval_data_reader,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    train_dataset = train_dataset.map(
        _decode_waymo_open,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_dataset = eval_dataset.map(
        _decode_waymo_open,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # We need to set the dataset cardinality. We assume we have
    # the full dataset.
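    # Note: assert_cardinality only records the expected size; tf.data raises
    # an error at iteration time if the files yield a different number of
    # examples.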
    train_dataset = train_dataset.apply(
        tf.data.experimental.assert_cardinality(158081))

    class WaymoOpenBuilder:
      """Builder for Waymo Open dataset."""

      def as_dataset(self, split, *unused_args, **unused_kwargs):
        if split == "train":
          return train_dataset
        elif split == "validation":
          return eval_dataset
        else:
          raise ValueError("Invalid split.")

    dataset_builder = WaymoOpenBuilder()
  elif config.data.dataset_name == "multishapenet_easy":
    dataset_builder = sunds.builder(
        name=config.get("tfds_name", "msn_easy"),
        data_dir=config.get(
            "data_dir", "gs://kubric-public/tfds"),
        try_gcs=True)
    dataset_builder.as_dataset = functools.partial(
        dataset_builder.as_dataset, task=_sunds_create_task())
  elif config.data.dataset_name == "tfds":
    dataset_builder = tfds.builder(
        config.data.tfds_name, data_dir=config.data.data_dir)
  else:
    raise ValueError("Please specify a valid dataset name.")

  batch_dims = get_batch_dims(config.batch_size)

  train_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_train)
  eval_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_eval)

  train_split_name = config.get("train_split", "train")
  eval_split_name = config.get("validation_split", "validation")

  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split_name,
      rng=data_rng,
      preprocess_fn=train_preprocess_fn,
      cache=False,
      shuffle_buffer_size=config.data.shuffle_buffer_size,
      batch_dims=batch_dims,
      num_epochs=None,
      shuffle=True)

  if config.data.dataset_name == "waymo_open":
    # We filter Waymo Open for empty segmentation masks.
    def filter_fn(features):
      unique_instances = tf.unique(
          tf.reshape(features[preprocessing.SEGMENTATIONS], (-1,)))[0]
      n_instances = tf.size(unique_instances, tf.int32)
      # n_instances == 1 means we only have the background.
      return 2 <= n_instances
  else:
    filter_fn = None

  eval_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split_name,
      rng=None,
      preprocess_fn=eval_preprocess_fn,
      filter_fn=filter_fn,
      cache=False,
      batch_dims=batch_dims,
      num_epochs=1,
      shuffle=False,
      pad_up_to_batches=None)

  if config.data.dataset_name == "waymo_open":
    # We filter Waymo Open for empty segmentation masks after preprocessing.
    # For the full dataset, we know how many examples we will end up with.
    eval_batch_size = batch_dims[0] * batch_dims[1]
    # We don't pad the last batch => floor.
    eval_num_batches = int(
        jnp.floor(1872 / eval_batch_size / jax.host_count()))
    eval_ds = eval_ds.apply(
        tf.data.experimental.assert_cardinality(
            eval_num_batches))

  return train_ds, eval_ds
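

# Example usage (a minimal sketch kept in comments so that importing this
# module stays side-effect free): the ConfigDict values below are hypothetical
# and only mirror the attributes this module reads (config.batch_size,
# config.preproc_train, config.preproc_eval and config.data.*); real
# experiment configs may define more fields and different values.
#
#   import jax
#   import ml_collections
#   from invariant_slot_attention.lib import input_pipeline
#
#   config = ml_collections.ConfigDict({
#       "batch_size": 64,
#       # clu.preprocess_spec strings naming ops registered in
#       # preprocessing.all_ops(); left as placeholders here.
#       "preproc_train": ["<train preprocessing ops>"],
#       "preproc_eval": ["<eval preprocessing ops>"],
#       "data": {
#           "dataset_name": "tetrominoes",
#           "shuffle_buffer_size": 512,
#       },
#   })
#   train_ds, eval_ds = input_pipeline.create_datasets(
#       config, data_rng=jax.random.PRNGKey(42))
#   # Batches are shaped [num_local_devices, per_device_batch, ...].
#   first_batch = next(iter(train_ds))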