# invariant_slot_attention/lib/input_pipeline.py
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input pipeline for TFDS datasets."""
import functools
import os
from typing import Dict, List, Tuple
from clu import deterministic_data
from clu import preprocess_spec
import jax
import jax.numpy as jnp
import ml_collections
import sunds
import tensorflow as tf
import tensorflow_datasets as tfds
from invariant_slot_attention.lib import preprocessing
# Type aliases used throughout the library.
Array = jnp.ndarray
PRNGKey = Array  # A JAX PRNG key is itself an ndarray.

# CLEVR-with-masks: raw tfrecords from the multi-object-datasets GCS bucket.
PATH_CLEVR_WITH_MASKS = "gs://multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords"
# Feature spec, padded to at most 11 objects per scene. Note: "image", "mask",
# and the categorical fields are stored as raw bytes (tf.string) and are
# decoded to uint8 with tf.io.decode_raw in _decode_clevr_with_masks below.
FEATURES_CLEVR_WITH_MASKS = {
    "image": tf.io.FixedLenFeature([240, 320, 3], tf.string),
    "mask": tf.io.FixedLenFeature([11, 240, 320, 1], tf.string),
    "x": tf.io.FixedLenFeature([11], tf.float32),
    "y": tf.io.FixedLenFeature([11], tf.float32),
    "z": tf.io.FixedLenFeature([11], tf.float32),
    "pixel_coords": tf.io.FixedLenFeature([11, 3], tf.float32),
    "rotation": tf.io.FixedLenFeature([11], tf.float32),
    "size": tf.io.FixedLenFeature([11], tf.string),
    "material": tf.io.FixedLenFeature([11], tf.string),
    "shape": tf.io.FixedLenFeature([11], tf.string),
    "color": tf.io.FixedLenFeature([11], tf.string),
    "visibility": tf.io.FixedLenFeature([11], tf.float32),
}

# Tetrominoes: 35x35 images with up to 4 objects (3 pieces + background mask).
PATH_TETROMINOES = "gs://multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords"
FEATURES_TETROMINOES = {
    "image": tf.io.FixedLenFeature([35, 35, 3], tf.string),
    "mask": tf.io.FixedLenFeature([4, 35, 35, 1], tf.string),
    "x": tf.io.FixedLenFeature([4], tf.float32),
    "y": tf.io.FixedLenFeature([4], tf.float32),
    "shape": tf.io.FixedLenFeature([4], tf.float32),
    "color": tf.io.FixedLenFeature([4, 3], tf.float32),
    "visibility": tf.io.FixedLenFeature([4], tf.float32),
}

# Objects Room: 64x64 images with a 7-slot mask stack; only image + mask used.
PATH_OBJECTS_ROOM = "gs://multi-object-datasets/objects_room/objects_room_train.tfrecords"
FEATURES_OBJECTS_ROOM = {
    "image": tf.io.FixedLenFeature([64, 64, 3], tf.string),
    "mask": tf.io.FixedLenFeature([7, 64, 64, 1], tf.string),
}

# Waymo Open (v1.4.0 image export). NOTE(review): this is a relative path —
# presumably resolved against a project data root; confirm before running.
PATH_WAYMO_OPEN = "datasets/waymo_v_1_4_0_images/tfrecords"
FEATURES_WAYMO_OPEN = {
    "image": tf.io.FixedLenFeature([128, 192, 3], tf.string),
    "segmentations": tf.io.FixedLenFeature([128, 192], tf.string),
    "depth": tf.io.FixedLenFeature([128, 192], tf.float32),
    "num_objects": tf.io.FixedLenFeature([1], tf.int64),
    "has_mask": tf.io.FixedLenFeature([1], tf.int64),
    "camera": tf.io.FixedLenFeature([1], tf.int64),
}
def _decode_tetrominoes(example_proto):
  """Parses a serialized tetrominoes example and decodes raw byte features.

  Args:
    example_proto: A serialized tf.train.Example proto.

  Returns:
    Dict of tensors; "image" and "mask" are decoded to uint8 with the
    trailing length-1 byte axis squeezed away.
  """
  parsed = tf.io.parse_single_example(example_proto, FEATURES_TETROMINOES)
  for key in ("mask", "image"):
    decoded = tf.io.decode_raw(parsed[key], tf.uint8)
    parsed[key] = tf.squeeze(decoded, axis=-1)
  return parsed
def _decode_objects_room(example_proto):
  """Parses a serialized Objects Room example and decodes raw byte features.

  Args:
    example_proto: A serialized tf.train.Example proto.

  Returns:
    Dict of tensors; "image" and "mask" are decoded to uint8 with the
    trailing length-1 byte axis squeezed away.
  """
  parsed = tf.io.parse_single_example(example_proto, FEATURES_OBJECTS_ROOM)
  for key in ("mask", "image"):
    decoded = tf.io.decode_raw(parsed[key], tf.uint8)
    parsed[key] = tf.squeeze(decoded, axis=-1)
  return parsed
def _decode_clevr_with_masks(example_proto):
  """Parses a serialized CLEVR-with-masks example and decodes byte features.

  Args:
    example_proto: A serialized tf.train.Example proto.

  Returns:
    Dict of tensors; image, mask, and the categorical object attributes are
    decoded from raw bytes to uint8 with the trailing byte axis squeezed.
  """
  parsed = tf.io.parse_single_example(example_proto, FEATURES_CLEVR_WITH_MASKS)
  byte_keys = ("mask", "image", "color", "material", "shape", "size")
  for key in byte_keys:
    decoded = tf.io.decode_raw(parsed[key], tf.uint8)
    parsed[key] = tf.squeeze(decoded, axis=-1)
  return parsed
def _decode_waymo_open(example_proto):
  """Unserializes a serialized tf.train.Example sample.

  Args:
    example_proto: A serialized tf.train.Example proto.

  Returns:
    Dict of tensors. "image" and "segmentations" are decoded from raw bytes
    to uint8; "segmentations" and "depth" get a trailing channel axis so
    they are rank-3 like the image.
  """
  parsed = tf.io.parse_single_example(example_proto, FEATURES_WAYMO_OPEN)
  for key in ("image", "segmentations"):
    decoded = tf.io.decode_raw(parsed[key], tf.uint8)
    parsed[key] = tf.squeeze(decoded, axis=-1)
  # Re-add the channel dimension dropped during serialization.
  parsed["segmentations"] = tf.expand_dims(parsed["segmentations"], axis=-1)
  parsed["depth"] = tf.expand_dims(parsed["depth"], axis=-1)
  return parsed
def _preprocess_minimal(example):
  """Reduces an example to an image plus a single-channel segmentation map.

  The per-object mask stack (objects on axis 0) is collapsed via argmax to a
  dense map of object indices, cast to uint8.

  Args:
    example: Dict with at least "image" and "mask" tensors.

  Returns:
    Dict with "image" and "segmentations" entries only.
  """
  segmentations = tf.argmax(example["mask"], axis=0)
  return {
      "image": example["image"],
      "segmentations": tf.cast(segmentations, tf.uint8),
  }
def _sunds_create_task():
  """Create a sunds task to return images and instance segmentation."""
  camera_specs = {
      "depth_image": False,  # Not available in the dataset.
      "category_image": False,  # Not available in the dataset.
      "instance_image": True,
      "extrinsics": True,
  }
  return sunds.tasks.Nerf(
      yield_mode=sunds.tasks.YieldMode.IMAGE,
      additional_camera_specs=camera_specs,
      additional_frame_specs={"pose": True},
      add_name=True,
  )
def preprocess_example(features: Dict[str, tf.Tensor],
                       preprocess_strs: List[str]):
  """Processes a single data example.

  Args:
    features: A dictionary containing the tensors of a single data example.
    preprocess_strs: List of strings, describing one preprocessing operation
      each, in clu.preprocess_spec format.

  Returns:
    Dictionary containing the preprocessed tensors of a single data example.
  """
  spec = "|".join(preprocess_strs)
  preprocess_fn = preprocess_spec.parse(spec, preprocessing.all_ops())
  return preprocess_fn(features)  # pytype: disable=bad-return-type  # allow-recursive-types
def get_batch_dims(global_batch_size: int) -> List[int]:
  """Gets the first two axis sizes for data batches.

  Args:
    global_batch_size: Integer, the global batch size (across all devices).

  Returns:
    List of batch dimensions: [num_local_devices, per_device_batch_size].

  Raises:
    ValueError: if the requested dimensions don't make sense with the
      number of hosts or devices.
  """
  num_local_devices = jax.local_device_count()
  # Hoist the host count; it is queried several times below.
  num_hosts = jax.host_count()
  if global_batch_size % num_hosts != 0:
    # Fixed typo in the original message ("divisble").
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible with {num_hosts}.")
  per_host_batch_size = global_batch_size // num_hosts
  if per_host_batch_size % num_local_devices != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible with {num_hosts} hosts with a per host "
                     f"batch size of {per_host_batch_size} and "
                     f"{num_local_devices} local devices. ")
  return [num_local_devices, per_host_batch_size // num_local_devices]
def create_datasets(
    config: ml_collections.ConfigDict,
    data_rng: PRNGKey) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
  """Create datasets for training and evaluation.

  For the same data_rng and config this will return the same datasets. The
  datasets only contain stateless operations.

  Args:
    config: Configuration to use.
    data_rng: JAX PRNGKey for dataset pipeline.

  Returns:
    A tuple with the training dataset and the evaluation dataset.
  """
  # Each branch builds a `dataset_builder` object exposing `as_dataset(split)`
  # so that clu's deterministic_data.create_dataset can consume any of them.
  if config.data.dataset_name == "tetrominoes":
    ds = tf.data.TFRecordDataset(
        PATH_TETROMINOES,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_tetrominoes,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class TetrominoesBuilder:
      """Builder for the tetrominoes dataset."""

      # `ds=ds` captures the pipeline built above at definition time.
      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        """Simple function to conform to the builder api."""
        if split == "train":
          # We use 512 training examples; the first 100 are the eval split.
          ds = ds.skip(100)
          ds = ds.take(512)
          return tf.data.experimental.assert_cardinality(512)(ds)
        elif split == "validation":
          # 100 validation examples.
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = TetrominoesBuilder()
  elif config.data.dataset_name == "objects_room":
    ds = tf.data.TFRecordDataset(
        PATH_OBJECTS_ROOM,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_objects_room,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class ObjectsRoomBuilder:
      """Builder for objects room dataset."""

      # `ds=ds` captures the pipeline built above at definition time.
      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        """Simple function to conform to the builder api."""
        if split == "train":
          # 1M - 100 training examples.
          ds = ds.skip(100)
          return tf.data.experimental.assert_cardinality(999900)(ds)
        elif split == "validation":
          # 100 validation examples.
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = ObjectsRoomBuilder()
  elif config.data.dataset_name == "clevr_with_masks":
    ds = tf.data.TFRecordDataset(
        PATH_CLEVR_WITH_MASKS,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_clevr_with_masks,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class CLEVRWithMasksBuilder:
      """Builder for the CLEVR-with-masks dataset."""

      # First 100 examples form the validation split, rest is training.
      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        if split == "train":
          ds = ds.skip(100)
          return tf.data.experimental.assert_cardinality(99900)(ds)
        elif split == "validation":
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = CLEVRWithMasksBuilder()
  elif config.data.dataset_name == "waymo_open":
    # Waymo Open uses separate sharded tfrecord directories per split.
    train_path = os.path.join(
        PATH_WAYMO_OPEN, "training/camera_1/*tfrecords*")
    eval_path = os.path.join(
        PATH_WAYMO_OPEN, "validation/camera_1/*tfrecords*")
    train_files = tf.data.Dataset.list_files(train_path)
    eval_files = tf.data.Dataset.list_files(eval_path)
    train_data_reader = functools.partial(
        tf.data.TFRecordDataset,
        compression_type="ZLIB", buffer_size=2*(2**20))
    eval_data_reader = functools.partial(
        tf.data.TFRecordDataset,
        compression_type="ZLIB", buffer_size=2*(2**20))
    train_dataset = train_files.interleave(
        train_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_dataset = eval_files.interleave(
        eval_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_dataset = train_dataset.map(
        _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_dataset = eval_dataset.map(
        _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # We need to set the dataset cardinality. We assume we have
    # the full dataset.
    train_dataset = train_dataset.apply(
        tf.data.experimental.assert_cardinality(158081))

    class WaymoOpenBuilder:
      """Builder for the Waymo Open dataset."""

      def as_dataset(self, split, *unused_args, **unused_kwargs):
        if split == "train":
          return train_dataset
        elif split == "validation":
          return eval_dataset
        else:
          raise ValueError("Invalid split.")

    dataset_builder = WaymoOpenBuilder()
  elif config.data.dataset_name == "multishapenet_easy":
    dataset_builder = sunds.builder(
        name=config.get("tfds_name", "msn_easy"),
        data_dir=config.get(
            "data_dir", "gs://kubric-public/tfds"),
        try_gcs=True)
    # Wrap as_dataset so the sunds Nerf task (image + instance segmentation)
    # is always supplied.
    dataset_builder.as_dataset = functools.partial(
        dataset_builder.as_dataset, task=_sunds_create_task())
  elif config.data.dataset_name == "tfds":
    dataset_builder = tfds.builder(
        config.data.tfds_name, data_dir=config.data.data_dir)
  else:
    raise ValueError("Please specify a valid dataset name.")
  batch_dims = get_batch_dims(config.batch_size)
  train_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_train)
  eval_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_eval)
  train_split_name = config.get("train_split", "train")
  eval_split_name = config.get("validation_split", "validation")
  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split_name,
      rng=data_rng,
      preprocess_fn=train_preprocess_fn,
      cache=False,
      shuffle_buffer_size=config.data.shuffle_buffer_size,
      batch_dims=batch_dims,
      num_epochs=None,
      shuffle=True)
  if config.data.dataset_name == "waymo_open":
    # We filter Waymo Open for empty segmentation masks. Note the filter is
    # applied to the eval dataset only (see filter_fn below).
    def filter_fn(features):
      unique_instances = tf.unique(
          tf.reshape(features[preprocessing.SEGMENTATIONS], (-1,)))[0]
      n_instances = tf.size(unique_instances, tf.int32)
      # n_instances == 1 means we only have the background.
      return 2 <= n_instances
  else:
    filter_fn = None
  eval_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split_name,
      rng=None,
      preprocess_fn=eval_preprocess_fn,
      filter_fn=filter_fn,
      cache=False,
      batch_dims=batch_dims,
      num_epochs=1,
      shuffle=False,
      pad_up_to_batches=None)
  if config.data.dataset_name == "waymo_open":
    # We filter Waymo Open for empty segmentation masks after preprocessing.
    # For the full dataset, we know how many we will end up with.
    eval_batch_size = batch_dims[0] * batch_dims[1]
    # We don't pad the last batch => floor.
    # 1872 is the number of eval examples surviving the filter above.
    eval_num_batches = int(
        jnp.floor(1872 / eval_batch_size / jax.host_count()))
    eval_ds = eval_ds.apply(
        tf.data.experimental.assert_cardinality(
            eval_num_batches))
  return train_ds, eval_ds