diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d816213244015ba9afbe6a25f1cf5c5c89500aa6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+/venv
+/flagged
+/clevr_isa_ts
+*.pyc
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..78bf1a8d3d8c4665a7135205a3561cf10e097031
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2023 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2702ea3c08264ef4ff90458fa33e001e8269934d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,82 @@
+import functools
+import os
+
+import gradio as gr
+import jax
+import jax.numpy as jnp
+
+# Checkpoint restoration comes from clu; the optimizer and learning-rate
+# helpers are assumed to come from scenic's train_lib, as in the trainer.
+from clu import checkpoint
+from scenic.train_lib import lr_schedules
+from scenic.train_lib import optimizers
+
+from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
+from invariant_slot_attention.lib import input_pipeline
+from invariant_slot_attention.lib import losses
+from invariant_slot_attention.lib import utils
+
+
+def load_model(config, workdir):
+  """Builds the model in `config` and restores the checkpoint under `workdir`."""
+  rng = jax.random.PRNGKey(config.seed)
+  rng, data_rng = jax.random.split(rng)
+
+  # Build the datasets only to recover the shape of the video inputs.
+  train_ds, _ = input_pipeline.create_datasets(config, data_rng)
+
+  # Initialize model
+  model = utils.build_model_from_config(config.model)
+
+  def init_model(rng):
+    rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
+
+    init_conditioning = None
+    init_inputs = jnp.ones(
+        [1] + list(train_ds.element_spec["video"].shape)[2:],
+        jnp.float32)
+    initial_vars = model.init(
+        {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
+        video=init_inputs, conditioning=init_conditioning,
+        padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
+
+    # Split into state variables (e.g. for batchnorm stats) and model params.
+    # Note that `pop()` on a FrozenDict performs a deep copy.
+    state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error
+
+    # Filter out intermediates (we don't want to store these in the TrainState).
+    state_vars = utils.filter_key_from_frozen_dict(
+        state_vars, key="intermediates")
+    return state_vars, initial_params
+
+  state_vars, initial_params = init_model(rng)
+
+  learning_rate_fn = lr_schedules.get_learning_rate_fn(config)
+  tx = optimizers.get_optimizer(
+      config.optimizer_configs, learning_rate_fn, params=initial_params)
+
+  opt_state = tx.init(initial_params)
+
+  state = utils.TrainState(
+      step=1, opt_state=opt_state, params=initial_params, rng=rng,
+      variables=state_vars)
+
+  loss_fn = functools.partial(
+      losses.compute_full_loss, loss_config=config.losses)
+
+  checkpoint_dir = os.path.join(workdir, "checkpoints")
+  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
+  state = ckpt.restore_or_initialize(state)
+  return model, state, loss_fn
+
+
+def greet(name):
+  return "Hello " + name + "!"
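+# NOTE: `load_model` builds the ISA model and restores its checkpoint, but it
+# is not yet wired into the Gradio demo below. Hypothetical usage (the workdir
+# path is a placeholder):
+#   model, state, loss_fn = load_model(get_config(), workdir="clevr_isa_ts")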
+ + +demo = gr.Interface(fn=greet, inputs="text", outputs="text") +demo.launch() diff --git a/invariant_slot_attention/configs/__init__.py b/invariant_slot_attention/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78bf1a8d3d8c4665a7135205a3561cf10e097031 --- /dev/null +++ b/invariant_slot_attention/configs/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/invariant_slot_attention/configs/clevr_with_masks/baseline.py b/invariant_slot_attention/configs/clevr_with_masks/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..1266d3e9922986b7169442c6b37aedb185f0ef03 --- /dev/null +++ b/invariant_slot_attention/configs/clevr_with_masks/baseline.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVR.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "clevr_with_masks", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttention", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
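+      # The baseline initializer just learns one embedding per slot; the
+      # equivariant configs below swap in random-position initializers.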
+ "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInit", + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets + ], + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py b/invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py new file mode 100644 index 0000000000000000000000000000000000000000..419c9f81177741c849cab8525078b6c9b3f2d7a8 --- /dev/null +++ b/invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVR.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "clevr_with_masks", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
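+      # Slots start from learned embeddings plus randomly sampled 2-D
+      # positions, which the translation-equivariant corrector refines.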
+ "initializer": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.ParamStateInitRandomPositions", + "shape": + (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets + ], + }), + "relative_positions": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevr_with_masks/equiv_transl_rot_scale.py b/invariant_slot_attention/configs/clevr_with_masks/equiv_transl_rot_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..0b57031427a8de952f684150b1e3f46cc8832b77 --- /dev/null +++ b/invariant_slot_attention/configs/clevr_with_masks/equiv_transl_rot_scale.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVR.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "clevr_with_masks", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. 
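+      # On video data this would be a learned transition model over slots;
+      # for single frames the Identity module below suffices.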
+ "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_rotations_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale.py b/invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..b34bb470710ca59cd13a7a8c5c18e58426c9c138 --- /dev/null +++ b/invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +r"""Config for unsupervised training on CLEVR.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "clevr_with_masks", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "top_left_crop(top=29, left=64, height=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. 
+ "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/resnet/baseline.py b/invariant_slot_attention/configs/clevrtex/resnet/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..fab9b2f23244de5ca43185481da8eabb4ee130e3 --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/resnet/baseline.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. 
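+      # CLEVRTex configs use a ResNet-34 backbone (group norm, small-inputs
+      # stem) in place of the SimpleCNN used for the CLEVR configs.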
+ "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttention", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInit", + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl.py b/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl.py new file mode 100644 index 0000000000000000000000000000000000000000..46920638650ed9d9fbf42a77d3a3434943ebe37f --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. 
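+      # As in the CLEVR equivariant configs, the absolute position embedding
+      # is concatenated here; relative positions are handled per slot inside
+      # the corrector.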
+ "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.ParamStateInitRandomPositions", + "shape": + (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_rot_scale.py b/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_rot_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce6f5158816741970d700135a537cc389912b39 --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_rot_scale.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. 
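+      # Unlike the translation-only variant, this config keeps the
+      # "project_add" position embedding with an MLP output transform in the
+      # encoder; equivariance comes from the corrector and decoder below.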
+ "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_rotations_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). 
+ config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_scale.py b/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..7919d29c22ab144ab87cc86589eadcfd25ef7134 --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_scale.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. 
+ config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. 
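+      # The spatial-broadcast decoder runs per slot; the relative position
+      # embedding below shifts and scales its coordinate grid by each slot's
+      # estimated position and scale.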
+ "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/simplecnn/baseline.py b/invariant_slot_attention/configs/clevrtex/simplecnn/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..8da9c66d0e09049220534ac3ebf75aa7e6cd4615 --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/simplecnn/baseline.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttention", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
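+ # NOTE: ParamStateInit provides a learned (11, 64) set of slot embeddings.
+ # The equivariant configs in this change instead use the
+ # ParamStateInitRandomPositions* variants, which additionally sample random
+ # initial slot positions (and scales/rotations where applicable).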
+ "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInit", + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl.py b/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl.py new file mode 100644 index 0000000000000000000000000000000000000000..7149ca57cc69ba64e289876216c990a1fa507290 --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl.py @@ -0,0 +1,205 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + }), + + # Predictor. + # Removed since we are running a single frame. 
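+ # (Concretely, the predictor is the Identity module: with a single frame
+ # per example there is no temporal transition to model.)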
+ "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.ParamStateInitRandomPositions", + "shape": + (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_rot_scale.py b/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_rot_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..27eeab8df435ecc2c8bd7923302633e8b63f87e0 --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_rot_scale.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. 
+ "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_rotations_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_scale.py b/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..4025bd0ffc76bbdd862d8b56c822a85de76d8122 --- /dev/null +++ b/invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_scale.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +r"""Config for unsupervised training on CLEVRTex.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tfds", + # The TFDS dataset will be created in the directory below + # if you follow the README in datasets/clevrtex. + "data_dir": "~/tensorflow_datasets", + "tfds_name": "clevr_tex", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "central_crop(height=192,width=192)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. 
+ "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/multishapenet_easy/baseline.py b/invariant_slot_attention/configs/multishapenet_easy/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..66ae09df2b0296f00ec6e578fd7e6f916e083632 --- /dev/null +++ b/invariant_slot_attention/configs/multishapenet_easy/baseline.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on MultiShapeNet-Easy.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "multishapenet_easy", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. 
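+ # NOTE: three stride-2 convolutions reduce the 128x128 input to a 16x16
+ # feature map (128 -> 64 -> 32 -> 16); "spatial_flatten" then turns this
+ # into 256 tokens for the corrector.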
+ "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttention", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInit", + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/multishapenet_easy/equiv_transl.py b/invariant_slot_attention/configs/multishapenet_easy/equiv_transl.py new file mode 100644 index 0000000000000000000000000000000000000000..f5658b121974dc0c68cfb00b09b658eea066f17d --- /dev/null +++ b/invariant_slot_attention/configs/multishapenet_easy/equiv_transl.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on MultiShapeNet-Easy.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "multishapenet_easy", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. 
+ "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.ParamStateInitRandomPositions", + "shape": + (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). 
+ config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/multishapenet_easy/equiv_transl_rot_scale.py b/invariant_slot_attention/configs/multishapenet_easy/equiv_transl_rot_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..e149b03948ffed8e00319b74b235d92ebd4dfab7 --- /dev/null +++ b/invariant_slot_attention/configs/multishapenet_easy/equiv_transl_rot_scale.py @@ -0,0 +1,205 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on MultiShapeNet-Easy.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. + config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "multishapenet_easy", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. 
+ config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. 
+ "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_rotations_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/multishapenet_easy/equiv_transl_scale.py b/invariant_slot_attention/configs/multishapenet_easy/equiv_transl_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..18cf7c72a79c8685bee8a6074198468ebfc02dfc --- /dev/null +++ b/invariant_slot_attention/configs/multishapenet_easy/equiv_transl_scale.py @@ -0,0 +1,205 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on MultiShapeNet-Easy.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "multishapenet_easy", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 128) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.preproc_eval = [ + "sunds_to_tfds_video", + "video_from_tfds", + "subtract_one_from_segmentations", + "central_crop(height=240, width=240)", + "resize_small({size})".format(size=min(*config.data.resolution)) + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. 
+ "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/objects_room/baseline.py b/invariant_slot_attention/configs/objects_room/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..620fd517e7c698e237faf489267bfc2de7e9d390 --- /dev/null +++ b/invariant_slot_attention/configs/objects_room/baseline.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on objects_room.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + # TODO(obvis): Implement masked evaluation. + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "objects_room", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (64, 64) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttention", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInit", + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. 
+ "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, False, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/objects_room/equiv_transl.py b/invariant_slot_attention/configs/objects_room/equiv_transl.py new file mode 100644 index 0000000000000000000000000000000000000000..43300d39d46dc20e409e22d1ade3ed0f94a52715 --- /dev/null +++ b/invariant_slot_attention/configs/objects_room/equiv_transl.py @@ -0,0 +1,200 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on objects_room.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + # TODO(obvis): Implement masked evaluation. + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "objects_room", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (64, 64) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
+ "initializer": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.ParamStateInitRandomPositions", + "shape": + (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, False, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/objects_room/equiv_transl_rot_scale.py b/invariant_slot_attention/configs/objects_room/equiv_transl_rot_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..eec4057424de0d62672fd16bae64d0179412c387 --- /dev/null +++ b/invariant_slot_attention/configs/objects_room/equiv_transl_rot_scale.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on objects_room.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + # TODO(obvis): Implement masked evaluation. + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "objects_room", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (64, 64) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
+ "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, False, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_rotations_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/objects_room/equiv_transl_scale.py b/invariant_slot_attention/configs/objects_room/equiv_transl_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..da1a093a0f8cb9bfeec4df54cefba1e3d5e9081d --- /dev/null +++ b/invariant_slot_attention/configs/objects_room/equiv_transl_scale.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on objects_room.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + # TODO(obvis): Implement masked evaluation. + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "objects_room", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (64, 64) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=10)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
+ "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 16), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, False, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/tetrominoes/baseline.py b/invariant_slot_attention/configs/tetrominoes/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..73da5ef606d108ec7deaead86edc20a89fc77448 --- /dev/null +++ b/invariant_slot_attention/configs/tetrominoes/baseline.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on Tetrominoes with 512 train samples.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 20000 + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 5000 + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + # TODO(obvis): Implement masked evaluation. + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 1000 + config.checkpoint_every_steps = 1000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tetrominoes", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (35, 35) + }) + + config.max_instances = 4 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=3)" + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=3)" + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(1, 1), (1, 1), (1, 1), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttention", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInit", + "shape": (4, 64), # (num_slots, slot_size) + }), + + # Decoder. 
+ "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (35, 35), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 256, + "output_size": 256, + "num_hidden_layers": 5, + "activate_output": True + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 35, + } + + return config + + diff --git a/invariant_slot_attention/configs/tetrominoes/equiv_transl.py b/invariant_slot_attention/configs/tetrominoes/equiv_transl.py new file mode 100644 index 0000000000000000000000000000000000000000..4dbda3bab3383218e562c6446a5a484e016f5723 --- /dev/null +++ b/invariant_slot_attention/configs/tetrominoes/equiv_transl.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on Tetrominoes with 512 train samples.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 20000 + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 5000 + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + # TODO(obvis): Implement masked evaluation. + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 1000 + config.checkpoint_every_steps = 1000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "tetrominoes", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (35, 35) + }) + + config.max_instances = 4 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=3)" + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "sparse_to_dense_annotation(max_instances=3)" + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SimpleCNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(1, 1), (1, 1), (1, 1), (1, 1)] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.ParamStateInitRandomPositions", + "shape": + (4, 64), # (num_slots, slot_size) + }), + + # Decoder. 
+ "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (35, 35), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 256, + "output_size": 256, + "num_hidden_layers": 5, + "activate_output": True + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 35, + } + + return config + + diff --git a/invariant_slot_attention/configs/waymo_open/baseline.py b/invariant_slot_attention/configs/waymo_open/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..bedc948423c472a8bd1c91dc711cd6436792dedc --- /dev/null +++ b/invariant_slot_attention/configs/waymo_open/baseline.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on Waymo Open.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "waymo_open", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 192) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "delete_small_masks(threshold=0.01, max_instances_after=11)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttention", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInit", + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. 
+ "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 24), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/waymo_open/equiv_transl.py b/invariant_slot_attention/configs/waymo_open/equiv_transl.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1f59114b0e642d562ade9066b968f7c70ac42d --- /dev/null +++ b/invariant_slot_attention/configs/waymo_open/equiv_transl.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on Waymo Open.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "waymo_open", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 192) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "delete_small_masks(threshold=0.01, max_instances_after=11)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "concat" + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv", + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.ParamStateInitRandomPositions", + "shape": + (11, 64), # (num_slots, slot_size) + }), + + # Decoder. 
+ "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 24), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/waymo_open/equiv_transl_rot_scale.py b/invariant_slot_attention/configs/waymo_open/equiv_transl_rot_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..aaaa9fdcf5a0c815dc528308b730b729cfdce7cf --- /dev/null +++ b/invariant_slot_attention/configs/waymo_open/equiv_transl_rot_scale.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on Waymo Open.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "waymo_open", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 192) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "delete_small_masks(threshold=0.01, max_instances_after=11)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
+ "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 24), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_rotations_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/configs/waymo_open/equiv_transl_scale.py b/invariant_slot_attention/configs/waymo_open/equiv_transl_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..0680cec146181076afff3c4ed278c5af6d5b1348 --- /dev/null +++ b/invariant_slot_attention/configs/waymo_open/equiv_transl_scale.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for unsupervised training on Waymo Open.""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 # from the original Slot Attention + config.init_checkpoint = ml_collections.ConfigDict() + config.init_checkpoint.xid = 0 # Disabled by default. 
+ config.init_checkpoint.wid = 1 + + config.optimizer_configs = ml_collections.ConfigDict() + config.optimizer_configs.optimizer = "adam" + + config.optimizer_configs.grad_clip = ml_collections.ConfigDict() + config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm" + config.optimizer_configs.grad_clip.clip_value = 0.05 + + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = "compound" + config.lr_configs.factors = "constant * cosine_decay * linear_warmup" + config.lr_configs.warmup_steps = 10000 # from the original Slot Attention + config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") + # from the original Slot Attention + config.lr_configs.base_learning_rate = 4e-4 + + config.eval_pad_last_batch = False # True + config.log_loss_every_steps = 50 + config.eval_every_steps = 5000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "dataset_name": "waymo_open", + "shuffle_buffer_size": config.batch_size * 8, + "resolution": (128, 192) + }) + + config.max_instances = 11 + config.num_slots = config.max_instances # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + ] + + config.preproc_eval = [ + "tfds_image_to_tfds_video", + "video_from_tfds", + "delete_small_masks(threshold=0.01, max_instances_after=11)", + ] + + config.eval_slice_size = 1 + config.eval_slice_keys = ["video", "segmentations_video"] + + # Dictionary of targets and corresponding channels. Losses need to match. + targets = {"video": 3} + config.losses = {"recon": {"targets": list(targets)}} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in targets}) + + config.model = ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SAVi", + + # Encoder. + "encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long + "num_iterations": 3, + "qkv_size": 64, + "mlp_size": 128, + "grid_encoder": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.MLP", + "hidden_size": 128, + "layernorm": "pre" + }), + "add_rel_pos_to_values": True, # V3 + "zero_position_init": False, # Random positions. + "init_with_fixed_scale": None, # Random scales. + "scales_factor": 5.0, + }), + + # Predictor. + # Removed since we are running a single frame. + "predictor": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Identity" + }), + + # Initializer. 
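+      # Same idea as the translation/rotation/scale variant, except the random
+      # per-slot reference frames here carry only positions and scales.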
+ "initializer": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long + "shape": (11, 64), # (num_slots, slot_size) + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder", + "resolution": (16, 24), # Update if data resolution or strides change + "backbone": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.CNN", + "features": [64, 64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)], + "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "layer_transpose": [True, True, True, False, False] + }), + "target_readout": ml_collections.ConfigDict({ + "module": "invariant_slot_attention.modules.Readout", + "keys": list(targets), + "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension + "module": "invariant_slot_attention.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, + "output_size": targets[k]}) for k in targets], + }), + "relative_positions_and_scales": True, + "pos_emb": ml_collections.ConfigDict({ + "module": + "invariant_slot_attention.modules.RelativePositionEmbedding", + "embedding_type": + "linear", + "update_type": + "project_add", + "scales_factor": + 5.0, + }), + }), + "decode_corrected": True, + "decode_predicted": False, + }) + + # Which video-shaped variables to visualize. + config.debug_var_video_paths = { + "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long + } + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/invariant_slot_attention/lib/__init__.py b/invariant_slot_attention/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78bf1a8d3d8c4665a7135205a3561cf10e097031 --- /dev/null +++ b/invariant_slot_attention/lib/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/invariant_slot_attention/lib/evaluator.py b/invariant_slot_attention/lib/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..4b25f4915e47297635dc087edea9be21c6bdee33 --- /dev/null +++ b/invariant_slot_attention/lib/evaluator.py @@ -0,0 +1,326 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model evaluation.""" + +import functools +from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union + +from absl import logging +from clu import metrics +import flax +from flax import linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow as tf + +from invariant_slot_attention.lib import losses +from invariant_slot_attention.lib import utils + + +Array = jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +PRNGKey = Array + + +def get_eval_metrics( + preds, + batch, + loss_fn, + eval_metrics_cls, + predicted_max_num_instances, + ground_truth_max_num_instances, + ): + """Compute the metrics for the model predictions in inference mode. + + The metrics are averaged across *all* devices (of all hosts). + + Args: + preds: Model predictions. + batch: Inputs that should be evaluated. + loss_fn: Loss function that takes model predictions and a batch of data. + eval_metrics_cls: Evaluation metrics collection. + predicted_max_num_instances: Maximum number of instances in prediction. + ground_truth_max_num_instances: Maximum number of instances in ground truth, + including background (which counts as a separate instance). + + Returns: + The evaluation metrics. + """ + loss, loss_aux = loss_fn(preds, batch) + metrics_update = eval_metrics_cls.gather_from_model_output( + loss=loss, + **loss_aux, + predicted_segmentations=utils.remove_singleton_dim( + preds["outputs"].get("segmentations")), # pytype: disable=attribute-error + ground_truth_segmentations=batch.get("segmentations"), + predicted_max_num_instances=predicted_max_num_instances, + ground_truth_max_num_instances=ground_truth_max_num_instances, + padding_mask=batch.get("padding_mask"), + mask=batch.get("mask")) + return metrics_update + + +def eval_first_step( + model, + state_variables, + params, + batch, + rng, + conditioning_key = None +): + """Get the model predictions with a freshly initialized recurrent state. + + The model is applied to the inputs using all devices on the host. + + Args: + model: Model used in eval step. + state_variables: State variables for the model. + params: Params for the model. + batch: Inputs that should be evaluated. + rng: PRNGKey for model forward pass. + conditioning_key: Optional string. If provided, defines the batch key to be + used as conditioning signal for the model. Otherwise this is inferred from + the available keys in the batch. + Returns: + The model's predictions. 
+ """ + logging.info("eval_first_step(batch=%s)", batch) + + conditioning = None + if conditioning_key: + conditioning = batch[conditioning_key] + preds, mutable_vars = model.apply( + {"params": params, **state_variables}, video=batch["video"], + conditioning=conditioning, mutable="intermediates", + rngs={"state_init": rng}, train=False, + padding_mask=batch.get("padding_mask")) + + if "intermediates" in mutable_vars: + preds["intermediates"] = flax.core.unfreeze(mutable_vars["intermediates"]) + + return preds + + +def eval_continued_step( + model, + state_variables, + params, + batch, + rng, + recurrent_states + ): + """Get the model predictions, continuing from a provided recurrent state. + + The model is applied to the inputs using all devices on the host. + + Args: + model: Model used in eval step. + state_variables: State variables for the model. + params: The model parameters. + batch: Inputs that should be evaluated. + rng: PRNGKey for model forward pass. + recurrent_states: Recurrent internal model state from which to continue. + Returns: + The model's predictions. + """ + logging.info("eval_continued_step(batch=%s, recurrent_states=%s)", batch, + recurrent_states) + + preds, mutable_vars = model.apply( + {"params": params, **state_variables}, video=batch["video"], + conditioning=recurrent_states, continue_from_previous_state=True, + mutable="intermediates", rngs={"state_init": rng}, train=False, + padding_mask=batch.get("padding_mask")) + + if "intermediates" in mutable_vars: + preds["intermediates"] = flax.core.unfreeze(mutable_vars["intermediates"]) + + return preds + + +def eval_step( + model, + state, + batch, + rng, + p_eval_first_step, + p_eval_continued_step, + slice_size = None, + slice_keys = None, + conditioning_key = None, + remove_from_predictions = None +): + """Compute the metrics for the given model in inference mode. + + The model is applied to the inputs using all devices on the host. Afterwards + metrics are averaged across *all* devices (of all hosts). + + Args: + model: Model used in eval step. + state: Replicated model state. + batch: Inputs that should be evaluated. + rng: PRNGKey for model forward pass. + p_eval_first_step: A parallel version of the function eval_first_step. + p_eval_continued_step: A parallel version of the function + eval_continued_step. + slice_size: Optional integer, if provided, evaluate the model on temporal + slices of this size instead of on the full sequence length at once. + slice_keys: Optional list of strings, the keys of the tensors which will be + sliced if slice_size is provided. + conditioning_key: Optional string. If provided, defines the batch key to be + used as conditioning signal for the model. Otherwise this is inferred from + the available keys in the batch. + remove_from_predictions: Remove the provided keys. The default None removes + "states" and "states_pred" from model output to save memory. Disable this + if either of these are required in the loss function or for visualization. + Returns: + Model predictions. + """ + if remove_from_predictions is None: + remove_from_predictions = ["states", "states_pred"] + + seq_len = batch["video"].shape[2] + # Sliced evaluation (i.e. on smaller temporal slices of the video). + if slice_size is not None and slice_size < seq_len: + num_slices = int(np.ceil(seq_len / slice_size)) + + assert slice_keys is not None, ( + "Slice keys need to be provided for sliced evaluation.") + + preds_per_slice = [] + # Get predictions for first slice (with fresh recurrent state). 
+ batch_slice = utils.get_slices_along_axis( + batch, slice_keys=slice_keys, start_idx=0, end_idx=slice_size) + preds_slice = p_eval_first_step(model, state.variables, + state.params, batch_slice, rng, + conditioning_key) + preds_slice = jax.tree_map(np.asarray, preds_slice) # Copy to CPU. + preds_per_slice.append(preds_slice) + + # Iterate over remaining slices (re-using the previous recurrent state). + for slice_idx in range(1, num_slices): + recurrent_states = preds_per_slice[-1]["states_pred"] + batch_slice = utils.get_slices_along_axis( + batch, slice_keys=slice_keys, start_idx=slice_idx * slice_size, + end_idx=(slice_idx + 1) * slice_size) + preds_slice = p_eval_continued_step( + model, state.variables, state.params, + batch_slice, rng, recurrent_states) + preds_slice = jax.tree_map(np.asarray, preds_slice) # Copy to CPU. + preds_per_slice.append(preds_slice) + + # Remove states from predictions before concat to save memory. + for k in remove_from_predictions: + for i in range(num_slices): + _ = preds_per_slice[i].pop(k, None) + + # Join predictions along sequence dimension. + concat_fn = lambda _, *x: functools.partial(np.concatenate, axis=2)([*x]) + preds = jax.tree_map(concat_fn, preds_per_slice[0], *preds_per_slice) + + # Truncate to original sequence length. + # NOTE: This op assumes that all predictions have a (complete) time axis. + preds = jax.tree_map(lambda x: x[:, :, :seq_len], preds) + + # Evaluate on full sequence if no (or too large) slice size is provided. + else: + preds = p_eval_first_step(model, state.variables, + state.params, batch, rng, + conditioning_key) + for k in remove_from_predictions: + _ = preds.pop(k, None) + + return preds + + +def evaluate( + model, + state, + eval_ds, + loss_fn, + eval_metrics_cls, + predicted_max_num_instances, + ground_truth_max_num_instances, + slice_size = None, + slice_keys = None, + conditioning_key = None, + remove_from_predictions = None, + metrics_on_cpu = False, + ): + """Evaluate the model on the given dataset.""" + eval_metrics = None + batch = None + preds = None + rng = state.rng[0] # Get training state PRNGKey from first replica. + + if metrics_on_cpu and jax.process_count() > 1: + raise NotImplementedError( + "metrics_on_cpu feature cannot be used in a multi-host setup." + " This experiment is using {} hosts.".format(jax.process_count())) + metric_devices = jax.devices("cpu") if metrics_on_cpu else jax.devices() + + p_eval_first_step = jax.pmap( + eval_first_step, + axis_name="batch", + static_broadcasted_argnums=(0, 5), + devices=jax.devices()) + p_eval_continued_step = jax.pmap( + eval_continued_step, + axis_name="batch", + static_broadcasted_argnums=(0), + devices=jax.devices()) + p_get_eval_metrics = jax.pmap( + get_eval_metrics, + axis_name="batch", + static_broadcasted_argnums=(2, 3, 4, 5), + devices=metric_devices, + backend="cpu" if metrics_on_cpu else None) + + def reshape_fn(x): + """Function to reshape preds and batch before calling p_get_eval_metrics.""" + return np.reshape(x, [len(metric_devices), -1] + list(x.shape[2:])) + + for batch in eval_ds: + rng, eval_rng = jax.random.split(rng) + eval_rng = jax.random.fold_in(eval_rng, jax.host_id()) # Bind to host. 
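+    # One key per local device: the eval step functions are pmapped, so their
+    # rng argument needs a leading axis of size jax.local_device_count().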
+ eval_rngs = jax.random.split(eval_rng, jax.local_device_count()) + batch = jax.tree_map(np.asarray, batch) + preds = eval_step( + model=model, + state=state, + batch=batch, + rng=eval_rngs, + p_eval_first_step=p_eval_first_step, + p_eval_continued_step=p_eval_continued_step, + slice_size=slice_size, + slice_keys=slice_keys, + conditioning_key=conditioning_key, + remove_from_predictions=remove_from_predictions) + + if metrics_on_cpu: + # Reshape replica dim and batch-dims to work with metric_devices. + preds = jax.tree_map(reshape_fn, preds) + batch = jax.tree_map(reshape_fn, batch) + # Get metric updates. + update = p_get_eval_metrics(preds, batch, loss_fn, eval_metrics_cls, + predicted_max_num_instances, + ground_truth_max_num_instances) + update = flax.jax_utils.unreplicate(update) + eval_metrics = ( + update if eval_metrics is None else eval_metrics.merge(update)) + assert eval_metrics is not None + return eval_metrics, batch, preds diff --git a/invariant_slot_attention/lib/input_pipeline.py b/invariant_slot_attention/lib/input_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0ab91367ac5c3b9cbcb3b5a0396ff6a223d1f4 --- /dev/null +++ b/invariant_slot_attention/lib/input_pipeline.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Input pipeline for TFDS datasets.""" + +import functools +import os +from typing import Dict, List, Tuple + +from clu import deterministic_data +from clu import preprocess_spec + +import jax +import jax.numpy as jnp +import ml_collections + +import sunds +import tensorflow as tf +import tensorflow_datasets as tfds + +from invariant_slot_attention.lib import preprocessing + +Array = jnp.ndarray +PRNGKey = Array + + +PATH_CLEVR_WITH_MASKS = "gs://multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords" +FEATURES_CLEVR_WITH_MASKS = { + "image": tf.io.FixedLenFeature([240, 320, 3], tf.string), + "mask": tf.io.FixedLenFeature([11, 240, 320, 1], tf.string), + "x": tf.io.FixedLenFeature([11], tf.float32), + "y": tf.io.FixedLenFeature([11], tf.float32), + "z": tf.io.FixedLenFeature([11], tf.float32), + "pixel_coords": tf.io.FixedLenFeature([11, 3], tf.float32), + "rotation": tf.io.FixedLenFeature([11], tf.float32), + "size": tf.io.FixedLenFeature([11], tf.string), + "material": tf.io.FixedLenFeature([11], tf.string), + "shape": tf.io.FixedLenFeature([11], tf.string), + "color": tf.io.FixedLenFeature([11], tf.string), + "visibility": tf.io.FixedLenFeature([11], tf.float32), +} + +PATH_TETROMINOES = "gs://multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords" +FEATURES_TETROMINOES = { + "image": tf.io.FixedLenFeature([35, 35, 3], tf.string), + "mask": tf.io.FixedLenFeature([4, 35, 35, 1], tf.string), + "x": tf.io.FixedLenFeature([4], tf.float32), + "y": tf.io.FixedLenFeature([4], tf.float32), + "shape": tf.io.FixedLenFeature([4], tf.float32), + "color": tf.io.FixedLenFeature([4, 3], tf.float32), + "visibility": tf.io.FixedLenFeature([4], tf.float32), +} + +PATH_OBJECTS_ROOM = "gs://multi-object-datasets/objects_room/objects_room_train.tfrecords" +FEATURES_OBJECTS_ROOM = { + "image": tf.io.FixedLenFeature([64, 64, 3], tf.string), + "mask": tf.io.FixedLenFeature([7, 64, 64, 1], tf.string), +} + +PATH_WAYMO_OPEN = "datasets/waymo_v_1_4_0_images/tfrecords" + +FEATURES_WAYMO_OPEN = { + "image": tf.io.FixedLenFeature([128, 192, 3], tf.string), + "segmentations": tf.io.FixedLenFeature([128, 192], tf.string), + "depth": tf.io.FixedLenFeature([128, 192], tf.float32), + "num_objects": tf.io.FixedLenFeature([1], tf.int64), + "has_mask": tf.io.FixedLenFeature([1], tf.int64), + "camera": tf.io.FixedLenFeature([1], tf.int64), +} + + +def _decode_tetrominoes(example_proto): + single_example = tf.io.parse_single_example( + example_proto, FEATURES_TETROMINOES) + for k in ["mask", "image"]: + single_example[k] = tf.squeeze( + tf.io.decode_raw(single_example[k], tf.uint8), axis=-1) + return single_example + + +def _decode_objects_room(example_proto): + single_example = tf.io.parse_single_example( + example_proto, FEATURES_OBJECTS_ROOM) + for k in ["mask", "image"]: + single_example[k] = tf.squeeze( + tf.io.decode_raw(single_example[k], tf.uint8), axis=-1) + return single_example + + +def _decode_clevr_with_masks(example_proto): + single_example = tf.io.parse_single_example( + example_proto, FEATURES_CLEVR_WITH_MASKS) + for k in ["mask", "image", "color", "material", "shape", "size"]: + single_example[k] = tf.squeeze( + tf.io.decode_raw(single_example[k], tf.uint8), axis=-1) + return single_example + + +def _decode_waymo_open(example_proto): + """Unserializes a serialized tf.train.Example sample.""" + single_example = tf.io.parse_single_example( + example_proto, FEATURES_WAYMO_OPEN) + for k in ["image", "segmentations"]: + single_example[k] = tf.squeeze( + 
tf.io.decode_raw(single_example[k], tf.uint8), axis=-1) + single_example["segmentations"] = tf.expand_dims( + single_example["segmentations"], axis=-1) + single_example["depth"] = tf.expand_dims( + single_example["depth"], axis=-1) + return single_example + + +def _preprocess_minimal(example): + return { + "image": example["image"], + "segmentations": tf.cast(tf.argmax(example["mask"], axis=0), tf.uint8), + } + + +def _sunds_create_task(): + """Create a sunds task to return images and instance segmentation.""" + return sunds.tasks.Nerf( + yield_mode=sunds.tasks.YieldMode.IMAGE, + additional_camera_specs={ + "depth_image": False, # Not available in the dataset. + "category_image": False, # Not available in the dataset. + "instance_image": True, + "extrinsics": True, + }, + additional_frame_specs={"pose": True}, + add_name=True + ) + + +def preprocess_example(features, + preprocess_strs): + """Processes a single data example. + + Args: + features: A dictionary containing the tensors of a single data example. + preprocess_strs: List of strings, describing one preprocessing operation + each, in clu.preprocess_spec format. + + Returns: + Dictionary containing the preprocessed tensors of a single data example. + """ + all_ops = preprocessing.all_ops() + preprocess_fn = preprocess_spec.parse("|".join(preprocess_strs), all_ops) + return preprocess_fn(features) # pytype: disable=bad-return-type # allow-recursive-types + + +def get_batch_dims(global_batch_size): + """Gets the first two axis sizes for data batches. + + Args: + global_batch_size: Integer, the global batch size (across all devices). + + Returns: + List of batch dimensions + + Raises: + ValueError if the requested dimensions don't make sense with the + number of devices. + """ + num_local_devices = jax.local_device_count() + if global_batch_size % jax.host_count() != 0: + raise ValueError(f"Global batch size {global_batch_size} not evenly " + f"divisble with {jax.host_count()}.") + per_host_batch_size = global_batch_size // jax.host_count() + if per_host_batch_size % num_local_devices != 0: + raise ValueError(f"Global batch size {global_batch_size} not evenly " + f"divisible with {jax.host_count()} hosts with a per host " + f"batch size of {per_host_batch_size} and " + f"{num_local_devices} local devices. ") + return [num_local_devices, per_host_batch_size // num_local_devices] + + +def create_datasets( + config, + data_rng): + """Create datasets for training and evaluation. + + For the same data_rng and config this will return the same datasets. The + datasets only contain stateless operations. + + Args: + config: Configuration to use. + data_rng: JAX PRNGKey for dataset pipeline. + + Returns: + A tuple with the training dataset and the evaluation dataset. + """ + + if config.data.dataset_name == "tetrominoes": + ds = tf.data.TFRecordDataset( + PATH_TETROMINOES, + compression_type="GZIP", buffer_size=2*(2**20)) + ds = ds.map(_decode_tetrominoes, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.map(_preprocess_minimal, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + class TetrominoesBuilder: + """Builder for tentrominoes dataset.""" + + def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs): + """Simple function to conform to the builder api.""" + if split == "train": + # We use 512 training examples. + ds = ds.skip(100) + ds = ds.take(512) + return tf.data.experimental.assert_cardinality(512)(ds) + elif split == "validation": + # 100 validation examples. 
+ ds = ds.take(100) + return tf.data.experimental.assert_cardinality(100)(ds) + else: + raise ValueError("Invalid split.") + + dataset_builder = TetrominoesBuilder() + elif config.data.dataset_name == "objects_room": + ds = tf.data.TFRecordDataset( + PATH_OBJECTS_ROOM, + compression_type="GZIP", buffer_size=2*(2**20)) + ds = ds.map(_decode_objects_room, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.map(_preprocess_minimal, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + class ObjectsRoomBuilder: + """Builder for objects room dataset.""" + + def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs): + """Simple function to conform to the builder api.""" + if split == "train": + # 1M - 100 training examples. + ds = ds.skip(100) + return tf.data.experimental.assert_cardinality(999900)(ds) + elif split == "validation": + # 100 validation examples. + ds = ds.take(100) + return tf.data.experimental.assert_cardinality(100)(ds) + else: + raise ValueError("Invalid split.") + + dataset_builder = ObjectsRoomBuilder() + elif config.data.dataset_name == "clevr_with_masks": + ds = tf.data.TFRecordDataset( + PATH_CLEVR_WITH_MASKS, + compression_type="GZIP", buffer_size=2*(2**20)) + ds = ds.map(_decode_clevr_with_masks, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.map(_preprocess_minimal, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + class CLEVRWithMasksBuilder: + def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs): + if split == "train": + ds = ds.skip(100) + return tf.data.experimental.assert_cardinality(99900)(ds) + elif split == "validation": + ds = ds.take(100) + return tf.data.experimental.assert_cardinality(100)(ds) + else: + raise ValueError("Invalid split.") + + dataset_builder = CLEVRWithMasksBuilder() + elif config.data.dataset_name == "waymo_open": + train_path = os.path.join( + PATH_WAYMO_OPEN, "training/camera_1/*tfrecords*") + eval_path = os.path.join( + PATH_WAYMO_OPEN, "validation/camera_1/*tfrecords*") + + train_files = tf.data.Dataset.list_files(train_path) + eval_files = tf.data.Dataset.list_files(eval_path) + + train_data_reader = functools.partial( + tf.data.TFRecordDataset, + compression_type="ZLIB", buffer_size=2*(2**20)) + eval_data_reader = functools.partial( + tf.data.TFRecordDataset, + compression_type="ZLIB", buffer_size=2*(2**20)) + + train_dataset = train_files.interleave( + train_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE) + eval_dataset = eval_files.interleave( + eval_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + train_dataset = train_dataset.map( + _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE) + eval_dataset = eval_dataset.map( + _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + # We need to set the dataset cardinality. We assume we have + # the full dataset. 
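+    # assert_cardinality only declares the size (no extra pass over the data);
+    # iteration fails if the files on disk do not hold exactly this many records.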
+ train_dataset = train_dataset.apply( + tf.data.experimental.assert_cardinality(158081)) + + class WaymoOpenBuilder: + def as_dataset(self, split, *unused_args, **unused_kwargs): + if split == "train": + return train_dataset + elif split == "validation": + return eval_dataset + else: + raise ValueError("Invalid split.") + + dataset_builder = WaymoOpenBuilder() + elif config.data.dataset_name == "multishapenet_easy": + dataset_builder = sunds.builder( + name=config.get("tfds_name", "msn_easy"), + data_dir=config.get( + "data_dir", "gs://kubric-public/tfds"), + try_gcs=True) + dataset_builder.as_dataset = functools.partial( + dataset_builder.as_dataset, task=_sunds_create_task()) + elif config.data.dataset_name == "tfds": + dataset_builder = tfds.builder( + config.data.tfds_name, data_dir=config.data.data_dir) + else: + raise ValueError("Please specify a valid dataset name.") + + batch_dims = get_batch_dims(config.batch_size) + + train_preprocess_fn = functools.partial( + preprocess_example, preprocess_strs=config.preproc_train) + eval_preprocess_fn = functools.partial( + preprocess_example, preprocess_strs=config.preproc_eval) + + train_split_name = config.get("train_split", "train") + eval_split_name = config.get("validation_split", "validation") + + train_ds = deterministic_data.create_dataset( + dataset_builder, + split=train_split_name, + rng=data_rng, + preprocess_fn=train_preprocess_fn, + cache=False, + shuffle_buffer_size=config.data.shuffle_buffer_size, + batch_dims=batch_dims, + num_epochs=None, + shuffle=True) + + if config.data.dataset_name == "waymo_open": + # We filter Waymo Open for empty segmentation masks. + def filter_fn(features): + unique_instances = tf.unique( + tf.reshape(features[preprocessing.SEGMENTATIONS], (-1,)))[0] + n_instances = tf.size(unique_instances, tf.int32) + # n_instances == 1 means we only have the background. + return 2 <= n_instances + else: + filter_fn = None + + eval_ds = deterministic_data.create_dataset( + dataset_builder, + split=eval_split_name, + rng=None, + preprocess_fn=eval_preprocess_fn, + filter_fn=filter_fn, + cache=False, + batch_dims=batch_dims, + num_epochs=1, + shuffle=False, + pad_up_to_batches=None) + + if config.data.dataset_name == "waymo_open": + # We filter Waymo Open for empty segmentation masks after preprocessing. + # For the full dataset, we know how many we will end up with. + eval_batch_size = batch_dims[0] * batch_dims[1] + # We don't pad the last batch => floor. + eval_num_batches = int( + jnp.floor(1872 / eval_batch_size / jax.host_count())) + eval_ds = eval_ds.apply( + tf.data.experimental.assert_cardinality( + eval_num_batches)) + + return train_ds, eval_ds diff --git a/invariant_slot_attention/lib/losses.py b/invariant_slot_attention/lib/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..d90eeee848ddf810112e1f88e8ccddcea3178bf9 --- /dev/null +++ b/invariant_slot_attention/lib/losses.py @@ -0,0 +1,295 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loss functions.""" + +import functools +import inspect +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp +import ml_collections + +_LOSS_FUNCTIONS = {} + +Array = Any # jnp.ndarray somehow doesn't work anymore for pytype. +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ArrayDict = Dict[str, Array] +DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet +PRNGKey = Array +LossFn = Callable[[Dict[str, ArrayTree], Dict[str, ArrayTree]], + Tuple[Array, ArrayTree]] +ConfigAttr = Any +MetricSpec = Dict[str, str] + + +def standardize_loss_config( + loss_config +): + """Standardize loss configs into a common ConfigDict format. + + Args: + loss_config: List of strings or ConfigDict specifying loss configuration. + Valid input formats are: - Option 1 (list of strings), for example, + `loss_config = ["box", "presence"]` - Option 2 (losses with weights + only), for example, + `loss_config = ConfigDict({"box": 5, "presence": 2})` - Option 3 + (losses with weights and other parameters), for example, + `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"}, + "presence": {"weight": 2}})` + + Returns: + Standardized ConfigDict containing the loss configuration. + + Raises: + ValueError: If loss_config is a list that contains non-string entries. + """ + + if isinstance(loss_config, Sequence): # Option 1 + if not all(isinstance(loss_type, str) for loss_type in loss_config): + raise ValueError(f"Loss types all need to be str but got {loss_config}") + return ml_collections.FrozenConfigDict({k: {} for k in loss_config}) + + # Convert all option-2-style weights to option-3-style dictionaries. + loss_config = { + k: { + "weight": v + } if isinstance(v, (float, int)) else v for k, v in loss_config.items() + } + return ml_collections.FrozenConfigDict(loss_config) + + +def update_loss_aux(loss_aux, update): + existing_keys = set(update.keys()).intersection(loss_aux.keys()) + if existing_keys: + raise KeyError( + f"Can't overwrite existing keys in loss_aux: {existing_keys}") + loss_aux.update(update) + + +def compute_full_loss( + preds, targets, + loss_config +): + """Loss function that parses and combines weighted loss terms. + + Args: + preds: Dictionary of tensors containing model predictions. + targets: Dictionary of tensors containing prediction targets. + loss_config: List of strings or ConfigDict specifying loss configuration. + See @register_loss decorated functions below for valid loss names. + Valid losses formats are: - Option 1 (list of strings), for example, + `loss_config = ["box", "presence"]` - Option 2 (losses with weights + only), for example, + `loss_config = ConfigDict({"box": 5, "presence": 2})` - Option 3 (losses + with weights and other parameters), for example, + `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"}, + "presence": {"weight": 2}})` - Option 4 (like + 3 but decoupling name and loss_type), for + example, + `loss_config = ConfigDict({"recon_flow": {"loss_type": "recon", + "key": "flow"}, + "recon_video": {"loss_type": "recon", + "key": "video"}})` + + Returns: + A 2-tuple of the sum of all individual loss terms and a dictionary of + auxiliary losses and metrics. 
+ """ + + loss = jnp.zeros([], jnp.float32) + loss_aux = {} + loss_config = standardize_loss_config(loss_config) + for loss_name, cfg in loss_config.items(): + context_kwargs = {"preds": preds, "targets": targets} + weight, loss_term, loss_aux_update = compute_loss_term( + loss_name=loss_name, context_kwargs=context_kwargs, config_kwargs=cfg) + + unweighted_loss = jnp.mean(loss_term) + loss += weight * unweighted_loss + loss_aux_update[loss_name + "_value"] = unweighted_loss + loss_aux_update[loss_name + "_weight"] = jnp.ones_like(unweighted_loss) + update_loss_aux(loss_aux, loss_aux_update) + return loss, loss_aux + + +def register_loss(func=None, + *, + name = None, + check_unused_kwargs = True): + """Decorator for registering a loss function. + + Can be used without arguments: + ``` + @register_loss + def my_loss(**_): + return 0 + ``` + or with keyword arguments: + ``` + @register_loss(name="my_renamed_loss") + def my_loss(**_): + return 0 + ``` + + Loss functions may accept + - context kwargs: `preds` and `targets` + - config kwargs: any argument specified in the config + - the special `config_kwargs` parameter that contains the entire loss config + Loss functions also _need_ to accept a **kwarg argument to support extending + the interface. + They should return either: + - just the computed loss (pre-reduction) + - or a tuple of the computed loss and a loss_aux_updates dict + + Args: + func: the decorated function + name (str): Optional name to be used for this loss in the config. Defaults + to the name of the function. + check_unused_kwargs (bool): By default compute_loss_term raises an error if + there are any unused config kwargs. If this flag is set to False that step + is skipped. This is useful if the config_kwargs should be passed onward to + another function. + + Returns: + The decorated function (or a partial of the decorator) + """ + # If this decorator has been called with parameters but no function, then we + # return the decorator again (but with partially filled parameters). + # This allows using both @register_loss and @register_loss(name="foo") + if func is None: + return functools.partial( + register_loss, name=name, check_unused_kwargs=check_unused_kwargs) + + # No (further) arguments: this is the actual decorator + # ensure that the loss function includes a **kwargs argument + loss_name = name if name is not None else func.__name__ + if not any(v.kind == inspect.Parameter.VAR_KEYWORD + for k, v in inspect.signature(func).parameters.items()): + raise TypeError( + f"Loss function '{loss_name}' needs to include a **kwargs argument") + func.name = loss_name + func.check_unused_kwargs = check_unused_kwargs + _LOSS_FUNCTIONS[loss_name] = func + return func + + +def compute_loss_term( + loss_name, context_kwargs, + config_kwargs): + """Compute a loss function given its config and context parameters. + + Takes care of: + - finding the correct loss function based on "loss_type" or name + - the optional "weight" parameter + - checking for typos and collisions in config parameters + - adding the optional loss_aux_updates if omitted by the loss_fn + + Args: + loss_name: Name of the loss, i.e. its key in the config.losses dict. + context_kwargs: Dictionary of context variables (`preds` and `targets`) + config_kwargs: The config dict for this loss. + + Returns: + 1. the loss weight (float) + 2. loss term (Array) + 3. loss aux updates (Dict[str, Array]) + + Raises: + KeyError: + Unknown loss_type + KeyError: + Unused config entries, i.e. not used by the loss function. 
+ Not raised if using @register_loss(check_unused_kwargs=False) + KeyError: Config entry with a name that conflicts with a context_kwarg + ValueError: Non-numerical weight in config_kwargs + + """ + + # Make a dict copy of config_kwargs + kwargs = {k: v for k, v in config_kwargs.items()} + + # Get the loss function + loss_type = kwargs.pop("loss_type", loss_name) + if loss_type not in _LOSS_FUNCTIONS: + raise KeyError(f"Unknown loss_type '{loss_type}'.") + loss_fn = _LOSS_FUNCTIONS[loss_type] + + # Take care of "weight" term + weight = kwargs.pop("weight", 1.0) + if not isinstance(weight, (int, float)): + raise ValueError(f"Weight for loss {loss_name} should be a number, " + f"but was {weight}.") + + # Check for unused config entries (to prevent typos etc.) + config_keys = set(kwargs) + if loss_fn.check_unused_kwargs: + param_names = set(inspect.signature(loss_fn).parameters) + unused_config_keys = config_keys - param_names + if unused_config_keys: + raise KeyError(f"Unrecognized config entries {unused_config_keys} " + f"for loss {loss_name}.") + + # Check for key collisions between context and config + conflicting_config_keys = config_keys.intersection(context_kwargs) + if conflicting_config_keys: + raise KeyError(f"The config keys {conflicting_config_keys} conflict " + f"with the context parameters ({context_kwargs.keys()}) " + f"for loss {loss_name}.") + + # Construct the arguments for the loss function + kwargs.update(context_kwargs) + kwargs["config_kwargs"] = config_kwargs + + # Call loss + results = loss_fn(**kwargs) + + # Add empty loss_aux_updates if necessary + if isinstance(results, Tuple): + loss, loss_aux_update = results + else: + loss, loss_aux_update = results, {} + + return weight, loss, loss_aux_update + + +# -------- Loss functions -------- +@register_loss +def recon(preds, + targets, + key = "video", + reduction_type = "sum", + **_): + """Reconstruction loss (MSE).""" + squared_l2_norm_fn = jax.vmap(functools.partial( + squared_l2_norm, reduction_type=reduction_type)) + targets = targets[key] + loss = squared_l2_norm_fn(preds["outputs"][key], targets) + if reduction_type == "mean": + # This rescaling reflects taking the sum over feature axis & + # mean over space/time axes. + loss *= targets.shape[-1] # pytype: disable=attribute-error # allow-recursive-types + return jnp.mean(loss) + + +def squared_l2_norm(preds, targets, + reduction_type = "sum"): + if reduction_type == "sum": + return jnp.sum(jnp.square(preds - targets)) + elif reduction_type == "mean": + return jnp.mean(jnp.square(preds - targets)) + else: + raise ValueError(f"Unsupported reduction_type: {reduction_type}") diff --git a/invariant_slot_attention/lib/metrics.py b/invariant_slot_attention/lib/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0a0d4c159cd2e4a2c78845fcfa3c378ac8a7d2 --- /dev/null +++ b/invariant_slot_attention/lib/metrics.py @@ -0,0 +1,263 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Clustering metrics.""" + +from typing import Optional, Sequence, Union + +from clu import metrics +import flax +import jax +import jax.numpy as jnp +import numpy as np + +Ndarray = Union[np.ndarray, jnp.ndarray] + + +def check_shape(x, expected_shape, name): + """Check whether shape x is as expected. + + Args: + x: Any data type with `shape` attribute. If `shape` attribute is not present + it is assumed to be a scalar with shape (). + expected_shape: The shape that is expected of x. For example, + [None, None, 3] can be the `expected_shape` for a color image, + [4, None, None, 3] if we know that batch size is 4. + name: Name of `x` to provide informative error messages. + + Raises: ValueError if x's shape does not match expected_shape. Also raises + ValueError if expected_shape is not a list or tuple. + """ + if not isinstance(expected_shape, (list, tuple)): + raise ValueError( + "expected_shape should be a list or tuple of ints but got " + f"{expected_shape}.") + + # Scalars have shape () by definition. + shape = getattr(x, "shape", ()) + + if (len(shape) != len(expected_shape) or + any(j is not None and i != j for i, j in zip(shape, expected_shape))): + raise ValueError( + f"Input {name} had shape {shape} but {expected_shape} was expected.") + + +def _validate_inputs(predicted_segmentations, + ground_truth_segmentations, + padding_mask, + mask = None): + """Checks that all inputs have the expected shapes. + + Args: + predicted_segmentations: An array of integers of shape [bs, seq_len, H, W] + containing model segmentation predictions. + ground_truth_segmentations: An array of integers of shape [bs, seq_len, H, + W] containing ground truth segmentations. + padding_mask: An array of integers of shape [bs, seq_len, H, W] defining + regions where the ground truth is meaningless, for example because this + corresponds to regions which were padded during data augmentation. Value 0 + corresponds to padded regions, 1 corresponds to valid regions to be used + for metric calculation. + mask: An optional array of boolean mask values of shape [bs]. `True` + corresponds to actual batch examples whereas `False` corresponds to + padding. + + Raises: + ValueError if the inputs are not valid. + """ + + check_shape( + predicted_segmentations, [None, None, None, None], + "predicted_segmentations [bs, seq_len, h, w]") + check_shape( + ground_truth_segmentations, [None, None, None, None], + "ground_truth_segmentations [bs, seq_len, h, w]") + check_shape( + predicted_segmentations, ground_truth_segmentations.shape, + "predicted_segmentations [should match ground_truth_segmentations]") + check_shape( + padding_mask, ground_truth_segmentations.shape, + "padding_mask [should match ground_truth_segmentations]") + + if not jnp.issubdtype(predicted_segmentations.dtype, jnp.integer): + raise ValueError("predicted_segmentations has to be integer-valued. " + "Got {}".format(predicted_segmentations.dtype)) + + if not jnp.issubdtype(ground_truth_segmentations.dtype, jnp.integer): + raise ValueError("ground_truth_segmentations has to be integer-valued. " + "Got {}".format(ground_truth_segmentations.dtype)) + + if not jnp.issubdtype(padding_mask.dtype, jnp.integer): + raise ValueError("padding_mask has to be integer-valued. " + "Got {}".format(padding_mask.dtype)) + + if mask is not None: + check_shape(mask, [None], "mask [bs]") + if not jnp.issubdtype(mask.dtype, jnp.bool_): + raise ValueError("mask has to be boolean. 
Got {}".format(mask.dtype)) + + +def adjusted_rand_index(true_ids, pred_ids, + num_instances_true, num_instances_pred, + padding_mask = None, + ignore_background = False): + """Computes the adjusted Rand index (ARI), a clustering similarity score. + + Args: + true_ids: An integer-valued array of shape + [batch_size, seq_len, H, W]. The true cluster assignment encoded + as integer ids. + pred_ids: An integer-valued array of shape + [batch_size, seq_len, H, W]. The predicted cluster assignment + encoded as integer ids. + num_instances_true: An integer, the number of instances in true_ids + (i.e. max(true_ids) + 1). + num_instances_pred: An integer, the number of instances in true_ids + (i.e. max(pred_ids) + 1). + padding_mask: An array of integers of shape [batch_size, seq_len, H, W] + defining regions where the ground truth is meaningless, for example + because this corresponds to regions which were padded during data + augmentation. Value 0 corresponds to padded regions, 1 corresponds to + valid regions to be used for metric calculation. + ignore_background: Boolean, if True, then ignore all pixels where + true_ids == 0 (default: False). + + Returns: + ARI scores as a float32 array of shape [batch_size]. + + References: + Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions" + https://link.springer.com/article/10.1007/BF01908075 + Wikipedia + https://en.wikipedia.org/wiki/Rand_index + Scikit Learn + http://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html + """ + # pylint: disable=invalid-name + true_oh = jax.nn.one_hot(true_ids, num_instances_true) + pred_oh = jax.nn.one_hot(pred_ids, num_instances_pred) + if padding_mask is not None: + true_oh = true_oh * padding_mask[Ellipsis, None] + # pred_oh = pred_oh * padding_mask[..., None] # <-- not needed + + if ignore_background: + true_oh = true_oh[Ellipsis, 1:] # Remove the background row. + + N = jnp.einsum("bthwc,bthwk->bck", true_oh, pred_oh) + A = jnp.sum(N, axis=-1) # row-sum (batch_size, c) + B = jnp.sum(N, axis=-2) # col-sum (batch_size, k) + num_points = jnp.sum(A, axis=1) + + rindex = jnp.sum(N * (N - 1), axis=[1, 2]) + aindex = jnp.sum(A * (A - 1), axis=1) + bindex = jnp.sum(B * (B - 1), axis=1) + expected_rindex = aindex * bindex / jnp.clip(num_points * (num_points-1), 1) + max_rindex = (aindex + bindex) / 2 + denominator = max_rindex - expected_rindex + ari = (rindex - expected_rindex) / denominator + + # There are two cases for which the denominator can be zero: + # 1. If both label_pred and label_true assign all pixels to a single cluster. + # (max_rindex == expected_rindex == rindex == num_points * (num_points-1)) + # 2. If both label_pred and label_true assign max 1 point to each cluster. + # (max_rindex == expected_rindex == rindex == 0) + # In both cases, we want the ARI score to be 1.0: + return jnp.where(denominator, ari, 1.0) + + +@flax.struct.dataclass +class Ari(metrics.Average): + """Adjusted Rand Index (ARI) computed from predictions and labels. + + ARI is a similarity score to compare two clusterings. ARI returns values in + the range [-1, 1], where 1 corresponds to two identical clusterings (up to + permutation), i.e. a perfect match between the predicted clustering and the + ground-truth clustering. A value of (close to) 0 corresponds to chance. + Negative values corresponds to cases where the agreement between the + clusterings is less than expected from a random assignment. 
+ + In this implementation, we use ARI to compare predicted instance segmentation + masks (including background prediction) with ground-truth segmentation + annotations. + """ + + @classmethod + def from_model_output(cls, + predicted_segmentations, + ground_truth_segmentations, + padding_mask, + ground_truth_max_num_instances, + predicted_max_num_instances, + ignore_background = False, + mask = None, + **_): + """Computation of the ARI clustering metric. + + NOTE: This implementation does not currently support padding masks. + + Args: + predicted_segmentations: An array of integers of shape + [bs, seq_len, H, W] containing model segmentation predictions. + ground_truth_segmentations: An array of integers of shape + [bs, seq_len, H, W] containing ground truth segmentations. + padding_mask: An array of integers of shape [bs, seq_len, H, W] + defining regions where the ground truth is meaningless, for example + because this corresponds to regions which were padded during data + augmentation. Value 0 corresponds to padded regions, 1 corresponds to + valid regions to be used for metric calculation. + ground_truth_max_num_instances: Maximum number of instances (incl. + background, which counts as the 0-th instance) possible in the dataset. + predicted_max_num_instances: Maximum number of predicted instances (incl. + background). + ignore_background: If True, then ignore all pixels where + ground_truth_segmentations == 0 (default: False). + mask: An optional array of boolean mask values of shape [bs]. `True` + corresponds to actual batch examples whereas `False` corresponds to + padding. + + Returns: + Object of Ari with computed intermediate values. + """ + _validate_inputs( + predicted_segmentations=predicted_segmentations, + ground_truth_segmentations=ground_truth_segmentations, + padding_mask=padding_mask, + mask=mask) + + batch_size = predicted_segmentations.shape[0] + + if mask is None: + mask = jnp.ones(batch_size, dtype=padding_mask.dtype) + else: + mask = jnp.asarray(mask, dtype=padding_mask.dtype) + + ari_batch = adjusted_rand_index( + pred_ids=predicted_segmentations, + true_ids=ground_truth_segmentations, + num_instances_true=ground_truth_max_num_instances, + num_instances_pred=predicted_max_num_instances, + padding_mask=padding_mask, + ignore_background=ignore_background) + return cls(total=jnp.sum(ari_batch * mask), count=jnp.sum(mask)) # pylint: disable=unexpected-keyword-arg + + +@flax.struct.dataclass +class AriNoBg(Ari): + """Adjusted Rand Index (ARI), ignoring the ground-truth background label.""" + + @classmethod + def from_model_output(cls, **kwargs): + """See `Ari` docstring for allowed keyword arguments.""" + return super().from_model_output(**kwargs, ignore_background=True) diff --git a/invariant_slot_attention/lib/preprocessing.py b/invariant_slot_attention/lib/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..51f40efbba4090f9c8154d852c72acf529c43ff7 --- /dev/null +++ b/invariant_slot_attention/lib/preprocessing.py @@ -0,0 +1,1236 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Video preprocessing ops.""" + +import abc +import dataclasses +import functools +from typing import Optional, Sequence, Tuple, Union + +from absl import logging +from clu import preprocess_spec + +import numpy as np +import tensorflow as tf + +from invariant_slot_attention.lib import transforms + +Features = preprocess_spec.Features +all_ops = lambda: preprocess_spec.get_all_ops(__name__) +SEED_KEY = preprocess_spec.SEED_KEY +NOTRACK_BOX = (0., 0., 0., 0.) # No-track bounding box for padding. +NOTRACK_LABEL = -1 + +IMAGE = "image" +VIDEO = "video" +SEGMENTATIONS = "segmentations" +RAGGED_SEGMENTATIONS = "ragged_segmentations" +SPARSE_SEGMENTATIONS = "sparse_segmentations" +SHAPE = "shape" +PADDING_MASK = "padding_mask" +RAGGED_BOXES = "ragged_boxes" +BOXES = "boxes" +FRAMES = "frames" +FLOW = "flow" +DEPTH = "depth" +ORIGINAL_SIZE = "original_size" +INSTANCE_LABELS = "instance_labels" +INSTANCE_MULTI_LABELS = "instance_multi_labels" +BOXES_VIDEO = "boxes_video" +IMAGE_PADDING_MASK = "image_padding_mask" +VIDEO_PADDING_MASK = "video_padding_mask" + + +def convert_uint16_to_float(array, min_val, max_val): + return tf.cast(array, tf.float32) / 65535. * (max_val - min_val) + min_val + + +def get_resize_small_shape(original_size, + small_size): + h, w = original_size + ratio = ( + tf.cast(small_size, tf.float32) / tf.cast(tf.minimum(h, w), tf.float32)) + h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32) + w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32) + return h, w + + +def adjust_small_size(original_size, + small_size, max_size): + """Computes the adjusted small size to ensure large side < max_size.""" + h, w = original_size + min_original_size = tf.cast(tf.minimum(w, h), tf.float32) + max_original_size = tf.cast(tf.maximum(w, h), tf.float32) + if max_original_size / min_original_size * small_size > max_size: + small_size = tf.cast(tf.floor( + max_size * min_original_size / max_original_size), tf.int32) + return small_size + + +def crop_or_pad_boxes(boxes, top, left, height, + width, h_orig, w_orig): + """Transforms the relative box coordinates according to the frame crop. + + Note that, if height/width are larger than h_orig/w_orig, this function + implements the equivalent of padding. + + Args: + boxes: Tensor of bounding boxes with shape (..., 4). + top: Top of crop box in absolute pixel coordinates. + left: Left of crop box in absolute pixel coordinates. + height: Height of crop box in absolute pixel coordinates. + width: Width of crop box in absolute pixel coordinates. + h_orig: Original image height in absolute pixel coordinates. + w_orig: Original image width in absolute pixel coordinates. + Returns: + Boxes tensor with same shape as input boxes but updated values. + """ + # Video track bound boxes: [num_instances, num_tracks, 4] + # Image bounding boxes: [num_instances, 4] + assert boxes.shape[-1] == 4 + seq_len = tf.shape(boxes)[0] + has_tracks = len(boxes.shape) == 3 + if has_tracks: + num_tracks = boxes.shape[1] + else: + assert len(boxes.shape) == 2 + num_tracks = 1 + + # Transform the box coordinates. + a = tf.cast(tf.stack([h_orig, w_orig]), tf.float32) + b = tf.cast(tf.stack([top, left]), tf.float32) + c = tf.cast(tf.stack([height, width]), tf.float32) + boxes = tf.reshape( + (tf.reshape(boxes, (seq_len, num_tracks, 2, 2)) * a - b) / c, + (seq_len, num_tracks, len(NOTRACK_BOX))) + + # Filter the valid boxes. 
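+  # Clip to the unit square; for tracked boxes, any box that collapses to a
+  # non-positive extent after the crop is reset to NOTRACK_BOX.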
+ boxes = tf.minimum(tf.maximum(boxes, 0.0), 1.0) + if has_tracks: + cond = tf.reduce_all((boxes[:, :, 2:] - boxes[:, :, :2]) > 0.0, axis=-1) + boxes = tf.where(cond[:, :, tf.newaxis], boxes, NOTRACK_BOX) + else: + boxes = tf.reshape(boxes, (seq_len, 4)) + + return boxes + + +def flow_tensor_to_rgb_tensor(motion_image, flow_scaling_factor=50.): + """Visualizes flow motion image as an RGB image. + + Similar as the flow_to_rgb function, but with tensors. + + Args: + motion_image: A tensor either of shape [batch_sz, height, width, 2] or of + shape [height, width, 2]. motion_image[..., 0] is flow in x and + motion_image[..., 1] is flow in y. + flow_scaling_factor: How much to scale flow for visualization. + + Returns: + A visualization tensor with same shape as motion_image, except with three + channels. The dtype of the output is tf.uint8. + """ + + hypot = lambda a, b: (a ** 2.0 + b ** 2.0) ** 0.5 # sqrt(a^2 + b^2) + + height, width = motion_image.get_shape().as_list()[-3:-1] # pytype: disable=attribute-error # allow-recursive-types + scaling = flow_scaling_factor / hypot(height, width) + x, y = motion_image[Ellipsis, 0], motion_image[Ellipsis, 1] + motion_angle = tf.atan2(y, x) + motion_angle = (motion_angle / np.math.pi + 1.0) / 2.0 + motion_magnitude = hypot(y, x) + motion_magnitude = tf.clip_by_value(motion_magnitude * scaling, 0.0, 1.0) + value_channel = tf.ones_like(motion_angle) + flow_hsv = tf.stack([motion_angle, motion_magnitude, value_channel], axis=-1) + flow_rgb = tf.image.convert_image_dtype( + tf.image.hsv_to_rgb(flow_hsv), tf.uint8) + return flow_rgb + + +def get_paddings(image_shape, + size, + pre_spatial_dim = None, + allow_crop = True): + """Returns paddings tensors for tf.pad operation. + + Args: + image_shape: The shape of the Tensor to be padded. The shape can be + [..., N, H, W, C] or [..., H, W, C]. The paddings are computed for H, W + and optionally N dimensions. + size: The total size for the H and W dimensions to pad to. + pre_spatial_dim: Optional, additional padding dimension before the spatial + dimensions. It is only used if given and if len(shape) > 3. + allow_crop: If size is bigger than requested max size, padding will be + negative. If allow_crop is true, negative padding values will be set to 0. + + Returns: + Paddings the given tensor shape. + """ + assert image_shape.shape.rank == 1 + if isinstance(size, int): + size = (size, size) + h, w = image_shape[-3], image_shape[-2] + # Spatial padding. + paddings = [ + tf.stack([0, size[0] - h]), + tf.stack([0, size[1] - w]), + tf.stack([0, 0]) + ] + ndims = len(image_shape) # pytype: disable=wrong-arg-types + # Prepend padding for temporal dimension or number of instances. + if pre_spatial_dim is not None and ndims > 3: + paddings = [[0, pre_spatial_dim - image_shape[-4]]] + paddings + # Prepend with non-padded dimensions if available. 
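# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original patch): the paddings assembled
# above are per-dimension [pad_before, pad_after] amounts, with all padding
# appended at the end. Standalone example of padding a single
# [H, W, C] = [100, 120, 3] frame up to 128x128; with allow_crop=True any
# negative amount (input larger than the target) is clamped to zero and the
# caller crops instead.
import tensorflow as tf

example_image = tf.zeros([100, 120, 3])
target_size = 128
pad_h = target_size - tf.shape(example_image)[0]
pad_w = target_size - tf.shape(example_image)[1]
example_paddings = tf.stack([
    tf.stack([0, pad_h]),   # -> [0, 28]
    tf.stack([0, pad_w]),   # -> [0, 8]
    tf.constant([0, 0]),    # channels are never padded
])
padded = tf.pad(example_image, tf.maximum(example_paddings, 0))
# padded.shape == (128, 128, 3)
# ---------------------------------------------------------------------------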
+ if ndims > len(paddings): + paddings = [[0, 0]] * (ndims - len(paddings)) + paddings + if allow_crop: + paddings = tf.maximum(paddings, 0) + return tf.stack(paddings) + + +@dataclasses.dataclass +class VideoFromTfds: + """Standardize features coming from TFDS video datasets.""" + + video_key: str = VIDEO + segmentations_key: str = SEGMENTATIONS + ragged_segmentations_key: str = RAGGED_SEGMENTATIONS + shape_key: str = SHAPE + padding_mask_key: str = PADDING_MASK + ragged_boxes_key: str = RAGGED_BOXES + boxes_key: str = BOXES + frames_key: str = FRAMES + instance_multi_labels_key: str = INSTANCE_MULTI_LABELS + flow_key: str = FLOW + depth_key: str = DEPTH + + def __call__(self, features): + + features_new = {} + + if "rng" in features: + features_new[SEED_KEY] = features.pop("rng") + + if "instances" in features: + features_new[self.ragged_boxes_key] = features["instances"]["bboxes"] + features_new[self.frames_key] = features["instances"]["bbox_frames"] + if "segmentations" in features["instances"]: + features_new[self.ragged_segmentations_key] = tf.cast( + features["instances"]["segmentations"][Ellipsis, 0], tf.int32) + + # Special handling of CLEVR (https://arxiv.org/abs/1612.06890) objects. + if ("color" in features["instances"] and + "shape" in features["instances"] and + "material" in features["instances"]): + color = tf.cast(features["instances"]["color"], tf.int32) + shape = tf.cast(features["instances"]["shape"], tf.int32) + material = tf.cast(features["instances"]["material"], tf.int32) + features_new[self.instance_multi_labels_key] = tf.stack( + (color, shape, material), axis=-1) + + if "segmentations" in features: + features_new[self.segmentations_key] = tf.cast( + features["segmentations"][Ellipsis, 0], tf.int32) + + if "depth" in features: + # Undo float to uint16 scaling + if "metadata" in features and "depth_range" in features["metadata"]: + depth_range = features["metadata"]["depth_range"] + features_new[self.depth_key] = convert_uint16_to_float( + features["depth"], depth_range[0], depth_range[1]) + + if "flows" in features: + # Some datasets use "flows" instead of "flow" for optical flow. + features["flow"] = features["flows"] + if "backward_flow" in features: + # By default, use "backward_flow" if available. + features["flow"] = features["backward_flow"] + features["metadata"]["flow_range"] = features["metadata"][ + "backward_flow_range"] + if "flow" in features: + # Undo float to uint16 scaling + flow_range = features["metadata"].get("flow_range", (-255, 255)) + features_new[self.flow_key] = convert_uint16_to_float( + features["flow"], flow_range[0], flow_range[1]) + + # Convert video to float and normalize. + video = features["video"] + assert video.dtype == tf.uint8 # pytype: disable=attribute-error # allow-recursive-types + video = tf.image.convert_image_dtype(video, tf.float32) + features_new[self.video_key] = video + + # Store original video shape (e.g. for correct evaluation metrics). + features_new[self.shape_key] = tf.shape(video) + + # Store padding mask + features_new[self.padding_mask_key] = tf.cast( + tf.ones_like(video)[Ellipsis, 0], tf.uint8) + + return features_new + + +@dataclasses.dataclass +class AddTemporalAxis: + """Lift images to videos by adding a temporal axis at the beginning. + + We need to distinguish two cases because `image_ops.py` uses + ORIGINAL_SIZE = [H,W] and `video_ops.py` uses SHAPE = [T,H,W,C]: + a) The features are fed from image ops: ORIGINAL_SIZE is converted + to SHAPE ([H,W] -> [1,H,W,C]) and removed from the features. 
+ Typical use case: Evaluation of GV image tasks in a video setting. This op + is added after the image preprocessing in order not to change the standard + image preprocessing. + b) The features are fed from video ops: The image SHAPE is lifted to a video + SHAPE ([H,W,C] -> [1,H,W,C]). + Typical use case: Training using images in a video setting. This op is added + before the video preprocessing in order not to change the standard video + preprocessing. + """ + + image_key: str = IMAGE + video_key: str = VIDEO + boxes_key: str = BOXES + padding_mask_key: str = PADDING_MASK + segmentations_key: str = SEGMENTATIONS + sparse_segmentations_key: str = SPARSE_SEGMENTATIONS + shape_key: str = SHAPE + original_size_key: str = ORIGINAL_SIZE + + def __call__(self, features): + assert self.image_key in features + + features_new = {} + for k, v in features.items(): + if k == self.image_key: + features_new[self.video_key] = v[tf.newaxis] + elif k in (self.padding_mask_key, self.boxes_key, self.segmentations_key, + self.sparse_segmentations_key): + features_new[k] = v[tf.newaxis] + elif k == self.original_size_key: + pass # See comment in the docstring of the class. + else: + features_new[k] = v + + if self.original_size_key in features: + # The features come from an image preprocessing pipeline. + shape = tf.concat([[1], features[self.original_size_key], + [features[self.image_key].shape[-1]]], # pytype: disable=attribute-error # allow-recursive-types + axis=0) + elif self.shape_key in features: + # The features come from a video preprocessing pipeline. + shape = tf.concat([[1], features[self.shape_key]], axis=0) + else: + shape = tf.shape(features_new[self.video_key]) + features_new[self.shape_key] = shape + + if self.padding_mask_key not in features_new: + features_new[self.padding_mask_key] = tf.cast( + tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8) + + return features_new + + +@dataclasses.dataclass +class SparseToDenseAnnotation: + """Converts the sparse to a dense representation.""" + + max_instances: int = 10 + segmentations_key: str = SEGMENTATIONS + + def __call__(self, features): + + features_new = {} + + for k, v in features.items(): + + if k == self.segmentations_key: + # Dense segmentations are available for this dataset. It may be that + # max_instances < max(features_new[self.segmentations_key]). + # We prune out extra objects here. + segmentations = v + segmentations = tf.where( + tf.less_equal(segmentations, self.max_instances), segmentations, 0) + features_new[self.segmentations_key] = segmentations + else: + features_new[k] = v + + return features_new + + +class VideoPreprocessOp(abc.ABC): + """Base class for all video preprocess ops.""" + + video_key: str = VIDEO + segmentations_key: str = SEGMENTATIONS + padding_mask_key: str = PADDING_MASK + boxes_key: str = BOXES + flow_key: str = FLOW + depth_key: str = DEPTH + sparse_segmentations_key: str = SPARSE_SEGMENTATIONS + + def __call__(self, features): + # Get current video shape. + video_shape = tf.shape(features[self.video_key]) + # Assemble all feature keys that the op should be applied on. + all_keys = [ + self.video_key, self.segmentations_key, self.padding_mask_key, + self.flow_key, self.depth_key, self.sparse_segmentations_key, + self.boxes_key + ] + # Apply the op to all features. 
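# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original patch): the loop below hands
# every video-like feature to `apply` together with its key, so a subclass can
# treat e.g. boxes differently from dense per-pixel features. A minimal
# hypothetical subclass (assumes this package is importable as
# `invariant_slot_attention`):
import dataclasses

import tensorflow as tf

from invariant_slot_attention.lib import preprocessing


@dataclasses.dataclass
class HorizontalFlip(preprocessing.VideoPreprocessOp):
  """Flips dense features left-to-right (box remapping omitted for brevity)."""

  def apply(self, tensor, key=None, video_shape=None):
    if key == self.boxes_key:
      return tensor  # A complete op would also mirror the box x-coordinates.
    return tensor[:, :, ::-1]  # [T, H, W, ...]: reverse the width axis.


example_features = {
    preprocessing.VIDEO: tf.zeros([1, 64, 64, 3]),
    preprocessing.SEGMENTATIONS: tf.zeros([1, 64, 64], tf.int32),
}
example_features = HorizontalFlip()(example_features)
# ---------------------------------------------------------------------------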
+ for key in all_keys: + if key in features: + features[key] = self.apply(features[key], key, video_shape) + return features + + @abc.abstractmethod + def apply(self, tensor, key, + video_shape): + """Returns the transformed tensor. + + Args: + tensor: Any of a set of different video modalites, e.g video, flow, + bounding boxes, etc. + key: a string that indicates what feature the tensor represents so that + the apply function can take that into account. + video_shape: The shape of the video (which is necessary for some + transformations). + """ + + +class RandomVideoPreprocessOp(VideoPreprocessOp): + """Base class for all random video preprocess ops.""" + + def __call__(self, features): + if features.get(SEED_KEY) is None: + logging.warning( + "Using random operation without seed. To avoid this " + "please provide a seed in feature %s.", SEED_KEY) + op_seed = tf.random.uniform(shape=(2,), maxval=2**32, dtype=tf.int64) + else: + features[SEED_KEY], op_seed = tf.unstack( + tf.random.experimental.stateless_split(features[SEED_KEY])) + # Get current video shape. + video_shape = tf.shape(features[self.video_key]) + # Assemble all feature keys that the op should be applied on. + all_keys = [ + self.video_key, self.segmentations_key, self.padding_mask_key, + self.flow_key, self.depth_key, self.sparse_segmentations_key, + self.boxes_key + ] + # Apply the op to all features. + for key in all_keys: + if key in features: + features[key] = self.apply(features[key], op_seed, key, video_shape) + return features + + @abc.abstractmethod + def apply(self, tensor, seed, key, + video_shape): + """Returns the transformed tensor. + + Args: + tensor: Any of a set of different video modalites, e.g video, flow, + bounding boxes, etc. + seed: A random seed. + key: a string that indicates what feature the tensor represents so that + the apply function can take that into account. + video_shape: The shape of the video (which is necessary for some + transformations). + """ + + +@dataclasses.dataclass +class ResizeSmall(VideoPreprocessOp): + """Resizes the smaller (spatial) side to `size` keeping aspect ratio. + + Attr: + size: An integer representing the new size of the smaller side of the input. + max_size: If set, an integer representing the maximum size in terms of the + largest side of the input. + """ + + size: int + max_size: Optional[int] = None + + def apply(self, tensor, key=None, video_shape=None): + """See base class.""" + + # Boxes are defined in normalized image coordinates and are not affected. + if key == self.boxes_key: + return tensor + + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, tf.newaxis] + elif key == self.sparse_segmentations_key: + tensor = tf.reshape(tensor, + (-1, tf.shape(tensor)[2], tf.shape(tensor)[3], 1)) + + h, w = tf.shape(tensor)[1], tf.shape(tensor)[2] + + # Determine resize method based on dtype (e.g. segmentations are int). + if tensor.dtype.is_integer: + resize_method = "nearest" + else: + resize_method = "bilinear" + + # Clip size to max_size if needed. + small_size = self.size + if self.max_size is not None: + small_size = adjust_small_size( + original_size=(h, w), small_size=small_size, max_size=self.max_size) + new_h, new_w = get_resize_small_shape( + original_size=(h, w), small_size=small_size) + tensor = tf.image.resize(tensor, [new_h, new_w], method=resize_method) + + # Flow needs to be rescaled according to the new size to stay valid. 
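# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original patch): optical flow stores
# per-pixel displacements measured in pixels of the frame it belongs to, so
# resizing the frame changes the unit. A displacement of 10 pixels along each
# axis of a 240x320 frame becomes (10 * 128/240, 10 * 128/320) = (5.33, 4.0)
# after resizing to 128x128, which is what the per-channel scaling below
# applies (channel order as assumed by the [scale_h, scale_w] stacking).
import tensorflow as tf

example_flow = tf.constant([[[[10.0, 10.0]]]])    # [T, H, W, (flow_y, flow_x)]
example_scale = tf.constant([[[[128.0 / 240.0, 128.0 / 320.0]]]])
rescaled_flow = example_flow * example_scale      # -> [[[[5.33..., 4.0]]]]
# ---------------------------------------------------------------------------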
+ if key == self.flow_key: + scale_h = tf.cast(new_h, tf.float32) / tf.cast(h, tf.float32) + scale_w = tf.cast(new_w, tf.float32) / tf.cast(w, tf.float32) + scale = tf.reshape(tf.stack([scale_h, scale_w], axis=0), (1, 2)) + # Optionally repeat scale in case both forward and backward flow are + # stacked in the last dimension. + scale = tf.repeat(scale, tf.shape(tensor)[-1] // 2, axis=0) + scale = tf.reshape(scale, (1, 1, 1, tf.shape(tensor)[-1])) + tensor *= scale + + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, 0] + elif key == self.sparse_segmentations_key: + tensor = tf.reshape(tensor, (video_shape[0], -1, new_h, new_w)) + + return tensor + + +@dataclasses.dataclass +class CentralCrop(VideoPreprocessOp): + """Makes central (spatial) crop of a given size. + + Attr: + height: An integer representing the height of the crop. + width: An (optional) integer representing the width of the crop. Make square + crop if width is not provided. + """ + + height: int + width: Optional[int] = None + + def apply(self, tensor, key=None, video_shape=None): + """See base class.""" + if key == self.boxes_key: + width = self.width or self.height + h_orig, w_orig = video_shape[1], video_shape[2] + top = (h_orig - self.height) // 2 + left = (w_orig - width) // 2 + tensor = crop_or_pad_boxes(tensor, top, left, self.height, + width, h_orig, w_orig) + return tensor + else: + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, tf.newaxis] + seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3] + h_orig, w_orig = tf.shape(tensor)[1], tf.shape(tensor)[2] + width = self.width or self.height + crop_size = (seq_len, self.height, width, n_channels) + top = (h_orig - self.height) // 2 + left = (w_orig - width) // 2 + tensor = tf.image.crop_to_bounding_box(tensor, top, left, self.height, + width) + tensor = tf.ensure_shape(tensor, crop_size) + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, 0] + return tensor + + +@dataclasses.dataclass +class CropOrPad(VideoPreprocessOp): + """Spatially crops or pads a video to a specified size. + + Attr: + height: An integer representing the new height of the video. + width: An integer representing the new width of the video. + allow_crop: A boolean indicating if cropping is allowed. + """ + + height: int + width: int + allow_crop: bool = True + + def apply(self, tensor, key=None, video_shape=None): + """See base class.""" + if key == self.boxes_key: + # Pad and crop the spatial dimensions. + h_orig, w_orig = video_shape[1], video_shape[2] + if self.allow_crop: + # After cropping, the frame shape is always [self.height, self.width]. + height, width = self.height, self.width + else: + # If only padding is performed, the frame size is at least + # [self.height, self.width]. 
+ height = tf.maximum(h_orig, self.height) + width = tf.maximum(w_orig, self.width) + tensor = crop_or_pad_boxes( + tensor, + top=0, + left=0, + height=height, + width=width, + h_orig=h_orig, + w_orig=w_orig) + return tensor + elif key == self.sparse_segmentations_key: + seq_len = tensor.get_shape()[0] + paddings = get_paddings( + tf.shape(tensor[Ellipsis, tf.newaxis]), (self.height, self.width), + allow_crop=self.allow_crop)[:-1] + tensor = tf.pad(tensor, paddings, constant_values=0) + if self.allow_crop: + tensor = tensor[Ellipsis, :self.height, :self.width] + tensor = tf.ensure_shape( + tensor, (seq_len, None, self.height, self.width)) + return tensor + else: + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, tf.newaxis] + seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3] + paddings = get_paddings( + tf.shape(tensor), (self.height, self.width), + allow_crop=self.allow_crop) + tensor = tf.pad(tensor, paddings, constant_values=0) + if self.allow_crop: + tensor = tensor[:, :self.height, :self.width, :] + tensor = tf.ensure_shape(tensor, + (seq_len, self.height, self.width, n_channels)) + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, 0] + return tensor + + +@dataclasses.dataclass +class RandomCrop(RandomVideoPreprocessOp): + """Gets a random (width, height) crop of input video. + + Assumption: Height and width are the same for all video-like modalities. + + Attr: + height: An integer representing the height of the crop. + width: An integer representing the width of the crop. + """ + + height: int + width: int + + def apply(self, tensor, seed, key=None, video_shape=None): + """See base class.""" + if key == self.boxes_key: + # We copy the random generation part from tf.image.stateless_random_crop + # to generate exactly the same offset as for the video. + crop_size = (video_shape[0], self.height, self.width, video_shape[-1]) + size = tf.convert_to_tensor(crop_size, tf.int32) + limit = video_shape - size + 1 + offset = tf.random.stateless_uniform( + tf.shape(video_shape), dtype=tf.int32, maxval=tf.int32.max, + seed=seed) % limit + tensor = crop_or_pad_boxes(tensor, offset[1], offset[2], self.height, + self.width, video_shape[1], video_shape[2]) + return tensor + elif key == self.sparse_segmentations_key: + raise NotImplementedError("Sparse segmentations aren't supported yet") + else: + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, tf.newaxis] + seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3] + crop_size = (seq_len, self.height, self.width, n_channels) + tensor = tf.image.stateless_random_crop(tensor, size=crop_size, seed=seed) + tensor = tf.ensure_shape(tensor, crop_size) + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, 0] + return tensor + + +@dataclasses.dataclass +class DropFrames(VideoPreprocessOp): + """Subsamples a video by skipping frames. + + Attr: + frame_skip: An integer representing the subsampling frequency of the video, + where 1 means no frames are skipped, 2 means every other frame is skipped, + and so forth. 
+ """ + + frame_skip: int + + def apply(self, tensor, key=None, video_shape=None): + """See base class.""" + del key + del video_shape + tensor = tensor[::self.frame_skip] + new_length = tensor.get_shape()[0] + tensor = tf.ensure_shape(tensor, [new_length] + tensor.get_shape()[1:]) + return tensor + + +@dataclasses.dataclass +class TemporalCropOrPad(VideoPreprocessOp): + """Crops or pads a video in time to a specified length. + + Attr: + length: An integer representing the new length of the video. + allow_crop: A boolean, specifying whether temporal cropping is allowed. If + False, will throw an error if length of the video is more than "length" + """ + + length: int + allow_crop: bool = True + + def _apply(self, tensor, constant_values): + frames_to_pad = self.length - tf.shape(tensor)[0] + if self.allow_crop: + frames_to_pad = tf.maximum(frames_to_pad, 0) + tensor = tf.pad( + tensor, ((0, frames_to_pad),) + ((0, 0),) * (len(tensor.shape) - 1), + constant_values=constant_values) + tensor = tensor[:self.length] + tensor = tf.ensure_shape(tensor, [self.length] + tensor.get_shape()[1:]) + return tensor + + def apply(self, tensor, key=None, video_shape=None): + """See base class.""" + del video_shape + if key == self.boxes_key: + constant_values = NOTRACK_BOX[0] + else: + constant_values = 0 + return self._apply(tensor, constant_values=constant_values) + + +@dataclasses.dataclass +class TemporalRandomWindow(RandomVideoPreprocessOp): + """Gets a random slice (window) along 0-th axis of input tensor. + + Pads the video if the video length is shorter than the provided length. + + Assumption: The number of frames is the same for all video-like modalities. + + Attr: + length: An integer representing the new length of the video. + """ + + length: int + + def _apply(self, tensor, seed, constant_values): + length = tf.minimum(self.length, tf.shape(tensor)[0]) + frames_to_pad = tf.maximum(self.length - tf.shape(tensor)[0], 0) + window_size = tf.concat(([length], tf.shape(tensor)[1:]), axis=0) + tensor = tf.image.stateless_random_crop(tensor, size=window_size, seed=seed) + tensor = tf.pad( + tensor, ((0, frames_to_pad),) + ((0, 0),) * (len(tensor.shape) - 1), + constant_values=constant_values) + tensor = tf.ensure_shape(tensor, [self.length] + tensor.get_shape()[1:]) + return tensor + + def apply(self, tensor, seed, key=None, video_shape=None): + """See base class.""" + del video_shape + if key == self.boxes_key: + constant_values = NOTRACK_BOX[0] + else: + constant_values = 0 + return self._apply(tensor, seed, constant_values=constant_values) + + +@dataclasses.dataclass +class TemporalRandomStridedWindow(RandomVideoPreprocessOp): + """Gets a random strided slice (window) along 0-th axis of input tensor. + + This op is like TemporalRandomWindow but it samples from one of a set of + strides of the video, whereas TemporalRandomWindow will densely sample from + all possible slices of `length` frames from the video. + + For the following video and `length=3`: [1, 2, 3, 4, 5, 6, 7, 8, 9] + + This op will return one of [1, 2, 3], [4, 5, 6], or [7, 8, 9] + + This pads the video if the video length is shorter than the provided length. + + Assumption: The number of frames is the same for all video-like modalities. + + Attr: + length: An integer representing the new length of the video and the sampling + stride width. 
+ """ + + length: int + + def _apply(self, tensor, seed, + constant_values): + """Applies the strided crop operation to the video tensor.""" + num_frames = tf.shape(tensor)[0] + num_crop_points = tf.cast(tf.math.ceil(num_frames / self.length), tf.int32) + crop_point = tf.random.stateless_uniform( + shape=(), minval=0, maxval=num_crop_points, dtype=tf.int32, seed=seed) + crop_point *= self.length + frames_sample = tensor[crop_point:crop_point + self.length] + frames_to_pad = tf.maximum(self.length - tf.shape(frames_sample)[0], 0) + frames_sample = tf.pad( + frames_sample, + ((0, frames_to_pad),) + ((0, 0),) * (len(frames_sample.shape) - 1), + constant_values=constant_values) + frames_sample = tf.ensure_shape(frames_sample, [self.length] + + frames_sample.get_shape()[1:]) + return frames_sample + + def apply(self, tensor, seed, key=None, video_shape=None): + """See base class.""" + del video_shape + if key == self.boxes_key: + constant_values = NOTRACK_BOX[0] + else: + constant_values = 0 + return self._apply(tensor, seed, constant_values=constant_values) + + +@dataclasses.dataclass +class FlowToRgb: + """Converts flow to an RGB image. + + NOTE: This operation requires a statically known shape for the input flow, + i.e. it is best to place it as final operation into the preprocessing + pipeline after all shapes are statically known (e.g. after cropping / + padding). + """ + flow_key: str = FLOW + + def __call__(self, features): + if self.flow_key in features: + flow_rgb = flow_tensor_to_rgb_tensor(features[self.flow_key]) + assert flow_rgb.dtype == tf.uint8 + features[self.flow_key] = tf.image.convert_image_dtype( + flow_rgb, tf.float32) + return features + + +@dataclasses.dataclass +class TransformDepth: + """Applies one of several possible transformations to depth features.""" + transform: str + depth_key: str = DEPTH + + def __call__(self, features): + if self.depth_key in features: + if self.transform == "log": + depth_norm = tf.math.log(features[self.depth_key]) + elif self.transform == "log_plus": + depth_norm = tf.math.log(1. + features[self.depth_key]) + elif self.transform == "invert_plus": + depth_norm = 1. / (1. + features[self.depth_key]) + else: + raise ValueError(f"Unknown depth transformation {self.transform}") + + features[self.depth_key] = depth_norm + return features + + +@dataclasses.dataclass +class RandomResizedCrop(RandomVideoPreprocessOp): + """Random-resized crop for each of the two views. + + Assumption: Height and width are the same for all video-like modalities. + + We randomly crop the input and record the transformation this crop corresponds + to as a new feature. Croped images are resized to (height, width). Boxes are + corrected adjusted and boxes outside the crop are discarded. Flow is rescaled + so as to be pixel accurate after the operation. lidar_points_2d are + transformed using the computed transformation. These points may lie outside + the image after the operation. + + Attr: + height: An integer representing the height to resize to. + width: An integer representing the width to resize to. + min_object_covered, aspect_ratio_range, area_range, max_attempts: See + docstring of `stateless_sample_distorted_bounding_box`. Aspect ratio range + has not been scaled by target aspect ratio. This differs from other + implementations of this data augmentation. + relative_box_area_threshold: If ratio of areas before and after cropping are + lower than this threshold, then the box is discarded (set to NOTRACK_BOX). + """ + # Target size. 
+ height: int + width: int + + # Crop sampling attributes. + min_object_covered: float = 0.1 + aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.) + area_range: Tuple[float, float] = (0.08, 1.0) + max_attempts: int = 100 + + # Box retention attributes + relative_box_area_threshold: float = 0.0 + + def apply(self, tensor, seed, key, + video_shape): + """Applies the crop operation on tensor.""" + param = self.sample_augmentation_params(video_shape, seed) + si, sj = param[0], param[1] + crop_h, crop_w = param[2], param[3] + + to_float32 = lambda x: tf.cast(x, tf.float32) + + if key == self.boxes_key: + # First crop the boxes. + cropped_boxes = crop_or_pad_boxes( + tensor, si, sj, + crop_h, crop_w, + video_shape[1], video_shape[2]) + # We do not need to scale the boxes because they are in normalized coords. + resized_boxes = cropped_boxes + # Lastly detects NOTRACK_BOX boxes and avoid manipulating those. + no_track_boxes = tf.convert_to_tensor(NOTRACK_BOX) + no_track_boxes = tf.reshape(no_track_boxes, [1, 4]) + resized_boxes = tf.where( + tf.reduce_all(tensor == no_track_boxes, axis=-1, keepdims=True), + tensor, resized_boxes) + + if self.relative_box_area_threshold > 0: + # Thresholds boxes that have been cropped too much, as in their area is + # lower, in relative terms, than `relative_box_area_threshold`. + area_before_crop = tf.reduce_prod(tensor[Ellipsis, 2:] - tensor[Ellipsis, :2], + axis=-1) + # Sets minimum area_before_crop to 1e-8 we avoid divisions by 0. + area_before_crop = tf.maximum(area_before_crop, + tf.zeros_like(area_before_crop) + 1e-8) + area_after_crop = tf.reduce_prod( + resized_boxes[Ellipsis, 2:] - resized_boxes[Ellipsis, :2], axis=-1) + # As the boxes have normalized coordinates, they need to be rescaled to + # be compared against the original uncropped boxes. + scale_x = to_float32(crop_w) / to_float32(self.width) + scale_y = to_float32(crop_h) / to_float32(self.height) + area_after_crop *= scale_x * scale_y + + ratio = area_after_crop / area_before_crop + return tf.where( + tf.expand_dims(ratio > self.relative_box_area_threshold, -1), + resized_boxes, no_track_boxes) + + else: + return resized_boxes + + else: + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, tf.newaxis] + + # Crop. + seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3] + crop_size = (seq_len, crop_h, crop_w, n_channels) + tensor = tf.slice(tensor, tf.stack([0, si, sj, 0]), crop_size) + + # Resize. + resize_method = tf.image.ResizeMethod.BILINEAR + if (tensor.dtype == tf.int32 or tensor.dtype == tf.int64 or + tensor.dtype == tf.uint8): + resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR + tensor = tf.image.resize(tensor, [self.height, self.width], + method=resize_method) + out_size = (seq_len, self.height, self.width, n_channels) + tensor = tf.ensure_shape(tensor, out_size) + + if key == self.flow_key: + # Rescales optical flow. 
+ scale_x = to_float32(self.width) / to_float32(crop_w) + scale_y = to_float32(self.height) / to_float32(crop_h) + tensor = tf.stack( + [tensor[Ellipsis, 0] * scale_y, tensor[Ellipsis, 1] * scale_x], axis=-1) + + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, 0] + return tensor + + def sample_augmentation_params(self, video_shape, rng): + """Sample a random bounding box for the crop.""" + sample_bbox = tf.image.stateless_sample_distorted_bounding_box( + video_shape[1:], + bounding_boxes=tf.constant([0.0, 0.0, 1.0, 1.0], + dtype=tf.float32, shape=[1, 1, 4]), + seed=rng, + min_object_covered=self.min_object_covered, + aspect_ratio_range=self.aspect_ratio_range, + area_range=self.area_range, + max_attempts=self.max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_bbox + + # The specified bounding box provides crop coordinates. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + + return tf.stack([offset_y, offset_x, target_height, target_width]) + + def estimate_transformation(self, param, video_shape + ): + """Computes the affine transformation for crop params. + + Args: + param: Crop parameters in the [y, x, h, w] format of shape [4,]. + video_shape: Unused. + + Returns: + Affine transformation of shape [3, 3] corresponding to cropping the image + at [y, x] of size [h, w] and resizing it into [self.height, self.width]. + """ + del video_shape + crop = tf.cast(param, tf.float32) + si, sj = crop[0], crop[1] + crop_h, crop_w = crop[2], crop[3] + ei, ej = si + crop_h - 1.0, sj + crop_w - 1.0 + h, w = float(self.height), float(self.width) + + a1 = (ei - si + 1.)/h + a2 = 0. + a3 = si - 0.5 + a1 / 2. + a4 = 0. + a5 = (ej - sj + 1.)/w + a6 = sj - 0.5 + a5 / 2. + affine = tf.stack([a1, a2, a3, a4, a5, a6, 0., 0., 1.]) + return tf.reshape(affine, [3, 3]) + + +@dataclasses.dataclass +class TfdsImageToTfdsVideo: + """Lift TFDS image format to TFDS video format by adding a temporal axis. + + This op is intended to be called directly before VideoFromTfds. + """ + + TFDS_SEGMENTATIONS_KEY = "segmentations" + TFDS_INSTANCES_KEY = "instances" + TFDS_BOXES_KEY = "bboxes" + TFDS_BOXES_FRAMES_KEY = "bbox_frames" + + image_key: str = IMAGE + video_key: str = VIDEO + boxes_image_key: str = BOXES + boxes_key: str = BOXES_VIDEO + image_padding_mask_key: str = IMAGE_PADDING_MASK + video_padding_mask_key: str = VIDEO_PADDING_MASK + depth_key: str = DEPTH + depth_mask_key: str = "depth_mask" + force_overwrite: bool = False + + def __call__(self, features): + if self.video_key in features and not self.force_overwrite: + return features + + features_new = {} + for k, v in features.items(): + if k == self.image_key: + features_new[self.video_key] = v[tf.newaxis] + elif k == self.image_padding_mask_key: + features_new[self.video_padding_mask_key] = v[tf.newaxis] + elif k == self.boxes_image_key: + features_new[self.boxes_key] = v[tf.newaxis] + elif k == self.TFDS_SEGMENTATIONS_KEY: + features_new[self.TFDS_SEGMENTATIONS_KEY] = v[tf.newaxis] + elif k == self.TFDS_INSTANCES_KEY and self.TFDS_BOXES_KEY in v: + # Add sequence dimension to boxes and create boxes frames for indexing. 
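# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original patch): basic behaviour of the
# op defined here for an image-format example without instance annotations
# (the ragged bounding-box handling continues below). Assumes this package is
# importable as `invariant_slot_attention`.
import tensorflow as tf

from invariant_slot_attention.lib import preprocessing

example = {
    "image": tf.zeros([128, 128, 3], tf.uint8),
    "segmentations": tf.zeros([128, 128, 1], tf.uint8),
}
example = preprocessing.TfdsImageToTfdsVideo()(example)
# example["video"].shape == (1, 128, 128, 3) and
# example["segmentations"].shape == (1, 128, 128, 1); a default
# "video_padding_mask" of ones is added because the image had none.
# ---------------------------------------------------------------------------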
+ features_new[k] = v + + # Create dummy ragged tensor (1, None) and broadcast + dummy = tf.ragged.constant([[0]], dtype=tf.int32) + boxes_frames_value = tf.zeros_like( + v[self.TFDS_BOXES_KEY][Ellipsis, 0], dtype=tf.int32)[Ellipsis, tf.newaxis] + features_new[k][self.TFDS_BOXES_FRAMES_KEY] = boxes_frames_value + dummy + # Create dummy ragged tensor (1, None, 1) and broadcast + dummy = tf.ragged.constant([[0]], dtype=tf.float32)[Ellipsis, tf.newaxis] + boxes_value = v[self.TFDS_BOXES_KEY][Ellipsis, tf.newaxis, :] + features_new[k][self.TFDS_BOXES_KEY] = boxes_value + dummy + elif k == self.depth_key: + features_new[self.depth_key] = v[tf.newaxis] + elif k == self.depth_mask_key: + features_new[self.depth_mask_key] = v[tf.newaxis] + else: + features_new[k] = v + + if self.video_padding_mask_key not in features_new: + logging.warning("Adding default video_padding_mask") + features_new[self.video_padding_mask_key] = tf.cast( + tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8) + + return features_new + + +@dataclasses.dataclass +class TopLeftCrop(VideoPreprocessOp): + """Makes an arbitrary crop in all video frames. + + Attr: + top: An integer representing the horizontal coordinate of the crop start. + left: An integer representing the vertical coordinate of the crop start. + height: An integer representing the height of the crop. + width: An (optional) integer representing the width of the crop. Make square + crop if width is not provided. + """ + + top: int + left: int + height: int + width: Optional[int] = None + + def apply(self, tensor, key=None, video_shape=None): + """See base class.""" + if key in (self.boxes_key,): + width = self.width or self.height + h_orig, w_orig = video_shape[1], video_shape[2] + tensor = transforms.crop_or_pad_boxes( + tensor, self.top, self.left, self.height, width, h_orig, w_orig) + return tensor + else: + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, tf.newaxis] + seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3] + h_orig, w_orig = tf.shape(tensor)[1], tf.shape(tensor)[2] + width = self.width or self.height + crop_size = (seq_len, self.height, width, n_channels) + tensor = tf.image.crop_to_bounding_box( + tensor, self.top, self.left, self.height, width) + tensor = tf.ensure_shape(tensor, crop_size) + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[Ellipsis, 0] + return tensor + + +@dataclasses.dataclass +class DeleteSmallMasks: + """Delete masks smaller than a selected fraction of pixels.""" + threshold: float = 0.05 + max_instances: int = 50 + max_instances_after: int = 11 + + def __call__(self, features): + + features_new = {} + + for key in features.keys(): + + if key == SEGMENTATIONS: + seg = features[key] + size = tf.shape(seg) + + assert_op = tf.Assert( + tf.equal(size[0], 1), ["Implemented only for a single frame."]) + + with tf.control_dependencies([assert_op]): + # Delete time dimension. + seg = seg[0] + + # Get the minimum number of pixels a masks needs to have. + max_pixels = size[1] * size[2] + threshold_pixels = tf.cast( + tf.cast(max_pixels, tf.float32) * self.threshold, tf.int32) + + # Decompose the segmentation map as a single image for each instance. + dec_seg = tf.stack( + tf.map_fn(functools.partial(self._decompose, seg=seg), + tf.range(self.max_instances)), axis=0) + + # Count the pixels and find segmentation masks that are big enough. + sums = tf.reduce_sum(dec_seg, axis=(1, 2)) + # We want the background to always be slot zero. 
+ # We can accomplish that be pretending it has the maximum + # number of pixels. + sums = tf.concat( + [tf.ones_like(sums[0: 1]) * max_pixels, sums[1:]], + axis=0) + + sort = tf.argsort(sums, axis=0, direction="DESCENDING") + sums_s = tf.gather(sums, sort, axis=0) + mask_s = tf.cast(tf.greater_equal(sums_s, threshold_pixels), tf.int32) + + dec_seg_plus = tf.stack( + tf.map_fn(functools.partial( + self._compose_sort, seg=seg, sort=sort, mask_s=mask_s), + tf.range(self.max_instances_after)), axis=0) + new_seg = tf.reduce_sum(dec_seg_plus, axis=0) + + features_new[key] = tf.cast(new_seg[None], tf.int32) + + else: + # keep all other features + features_new[key] = features[key] + + return features_new + + @classmethod + def _decompose(cls, i, seg): + return tf.cast(tf.equal(seg, i), tf.int32) + + @classmethod + def _compose_sort(cls, i, seg, sort, mask_s): + return tf.cast(tf.equal(seg, sort[i]), tf.int32) * i * mask_s[i] + + +@dataclasses.dataclass +class SundsToTfdsVideo: + """Lift Sunds format to TFDS video format. + + Renames fields and adds a temporal axis. + This op is intended to be called directly before VideoFromTfds. + """ + + SUNDS_IMAGE_KEY = "color_image" + SUNDS_SEGMENTATIONS_KEY = "instance_image" + SUNDS_DEPTH_KEY = "depth_image" + + image_key: str = SUNDS_IMAGE_KEY + image_segmentations_key = SUNDS_SEGMENTATIONS_KEY + video_key: str = VIDEO + video_segmentations_key = SEGMENTATIONS + image_depths_key: str = SUNDS_DEPTH_KEY + depths_key = DEPTH + video_padding_mask_key: str = VIDEO_PADDING_MASK + force_overwrite: bool = False + + def __call__(self, features): + if self.video_key in features and not self.force_overwrite: + return features + + features_new = {} + for k, v in features.items(): + if k == self.image_key: + features_new[self.video_key] = v[tf.newaxis] + elif k == self.image_segmentations_key: + features_new[self.video_segmentations_key] = v[tf.newaxis] + elif k == self.image_depths_key: + features_new[self.depths_key] = v[tf.newaxis] + else: + features_new[k] = v + + if self.video_padding_mask_key not in features_new: + logging.warning("Adding default video_padding_mask") + features_new[self.video_padding_mask_key] = tf.cast( + tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8) + + return features_new + + +@dataclasses.dataclass +class SubtractOneFromSegmentations: + """Subtract one from segmentation masks. Used for MultiShapeNet-Easy.""" + + segmentations_key: str = SEGMENTATIONS + + def __call__(self, features): + features[self.segmentations_key] = features[self.segmentations_key] - 1 + return features diff --git a/invariant_slot_attention/lib/trainer.py b/invariant_slot_attention/lib/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e70a33c9bf193ba8bd46433695e02dd7f07c7473 --- /dev/null +++ b/invariant_slot_attention/lib/trainer.py @@ -0,0 +1,328 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
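# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original patch): before the training
# loop below, a minimal example of chaining the preprocessing ops defined in
# preprocessing.py above. In the configs the same ops are assembled from
# preprocess_spec strings; here they are simply called in sequence, since
# every op maps a feature dict to a feature dict. Assumes the package is
# importable as `invariant_slot_attention`.
import tensorflow as tf

from invariant_slot_attention.lib import preprocessing

preprocess_ops = [
    preprocessing.VideoFromTfds(),
    preprocessing.SparseToDenseAnnotation(max_instances=10),
    preprocessing.ResizeSmall(size=64),
    preprocessing.CentralCrop(height=64),
]

raw_example = {"video": tf.zeros([8, 96, 128, 3], tf.uint8)}
processed = raw_example
for op in preprocess_ops:
  processed = op(processed)
# processed["video"].shape == (8, 64, 64, 3); "shape" and "padding_mask"
# entries were added by VideoFromTfds.
# ---------------------------------------------------------------------------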
+ +"""The main model training loop.""" + +import functools +import os +import time +from typing import Dict, Iterable, Mapping, Optional, Tuple, Type, Union + +from absl import logging +from clu import checkpoint +from clu import metric_writers +from clu import metrics +from clu import parameter_overview +from clu import periodic_actions +import flax +from flax import linen as nn + +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np +import optax + +from scenic.train_lib import lr_schedules +from scenic.train_lib import optimizers + +import tensorflow as tf + +from invariant_slot_attention.lib import evaluator +from invariant_slot_attention.lib import input_pipeline +from invariant_slot_attention.lib import losses +from invariant_slot_attention.lib import utils + +Array = jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +PRNGKey = Array + + +def train_step( + model, + tx, + rng, + step, + state_vars, + opt_state, + params, + batch, + loss_fn, + train_metrics_cls, + predicted_max_num_instances, + ground_truth_max_num_instances, + conditioning_key = None, + ): + """Perform a single training step. + + Args: + model: Model used in training step. + tx: The optimizer to use to minimize loss_fn. + rng: Random number key + step: Which training step we are on. + state_vars: Accessory variables. + opt_state: The state of the optimizer. + params: The current parameters to be updated. + batch: Training inputs for this step. + loss_fn: Loss function that takes model predictions and a batch of data. + train_metrics_cls: The metrics collection for computing training metrics. + predicted_max_num_instances: Maximum number of instances in prediction. + ground_truth_max_num_instances: Maximum number of instances in ground truth, + including background (which counts as a separate instance). + conditioning_key: Optional string. If provided, defines the batch key to be + used as conditioning signal for the model. Otherwise this is inferred from + the available keys in the batch. + + Returns: + Tuple of the updated opt, state_vars, new random number key, + metrics update, and step + 1. Note that some of this info is stored in + TrainState, but here it is unpacked. + """ + + # Split PRNGKey and bind to host / device. + new_rng, rng = jax.random.split(rng) + rng = jax.random.fold_in(rng, jax.host_id()) + rng = jax.random.fold_in(rng, jax.lax.axis_index("batch")) + init_rng, dropout_rng = jax.random.split(rng, 2) + + mutable_var_keys = list(state_vars.keys()) + ["intermediates"] + + conditioning = batch[conditioning_key] if conditioning_key else None + + def train_loss_fn(params, state_vars): + preds, mutable_vars = model.apply( + {"params": params, **state_vars}, video=batch["video"], + conditioning=conditioning, mutable=mutable_var_keys, + rngs={"state_init": init_rng, "dropout": dropout_rng}, train=True, + padding_mask=batch.get("padding_mask")) + # Filter intermediates, as we do not want to store them in the TrainState. + state_vars = utils.filter_key_from_frozen_dict( + mutable_vars, key="intermediates") + loss, loss_aux = loss_fn(preds, batch) + return loss, (state_vars, preds, loss_aux) + + grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True) + (loss, (state_vars, preds, loss_aux)), grad = grad_fn(params, state_vars) + + # Compute average gradient across multiple workers. 
+ grad = jax.lax.pmean(grad, axis_name="batch") + + updates, new_opt_state = tx.update(grad, opt_state, params) + new_params = optax.apply_updates(params, updates) + + # Compute metrics. + metrics_update = train_metrics_cls.gather_from_model_output( + loss=loss, + **loss_aux, + predicted_segmentations=utils.remove_singleton_dim( + preds["outputs"].get("segmentations")), # pytype: disable=attribute-error + ground_truth_segmentations=batch.get("segmentations"), + predicted_max_num_instances=predicted_max_num_instances, + ground_truth_max_num_instances=ground_truth_max_num_instances, + padding_mask=batch.get("padding_mask"), + mask=batch.get("mask")) + return ( + new_opt_state, new_params, state_vars, new_rng, metrics_update, step + 1) + + +def train_and_evaluate(config, + workdir): + """Runs a training and evaluation loop. + + Args: + config: Configuration to use. + workdir: Working directory for checkpoints and TF summaries. If this + contains checkpoint training will be resumed from the latest checkpoint. + """ + rng = jax.random.PRNGKey(config.seed) + + tf.io.gfile.makedirs(workdir) + + # Input pipeline. + rng, data_rng = jax.random.split(rng) + # Make sure each host uses a different RNG for the training data. + if config.get("seed_data", True): # Default to seeding data if not specified. + data_rng = jax.random.fold_in(data_rng, jax.host_id()) + else: + data_rng = None + train_ds, eval_ds = input_pipeline.create_datasets(config, data_rng) + train_iter = iter(train_ds) # pytype: disable=wrong-arg-types + + # Initialize model + model = utils.build_model_from_config(config.model) + + # Construct TrainMetrics and EvalMetrics, metrics collections. + train_metrics_cls = utils.make_metrics_collection("TrainMetrics", + config.train_metrics_spec) + eval_metrics_cls = utils.make_metrics_collection("EvalMetrics", + config.eval_metrics_spec) + + def init_model(rng): + rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4) + + init_conditioning = None + if config.get("conditioning_key"): + init_conditioning = jnp.ones( + [1] + list(train_ds.element_spec[config.conditioning_key].shape)[2:], + jnp.int32) + init_inputs = jnp.ones( + [1] + list(train_ds.element_spec["video"].shape)[2:], + jnp.float32) + initial_vars = model.init( + {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng}, + video=init_inputs, conditioning=init_conditioning, + padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32)) + + # Split into state variables (e.g. for batchnorm stats) and model params. + # Note that `pop()` on a FrozenDict performs a deep copy. + state_vars, initial_params = initial_vars.pop("params") # pytype: disable=attribute-error + + # Filter out intermediates (we don't want to store these in the TrainState). 
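# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original patch): the
# `initial_vars.pop("params")` call a few lines above separates trainable
# parameters from the other variable collections (e.g. batch norm statistics),
# which the TrainState stores separately. Minimal standalone version, assuming
# Flax returns a FrozenDict here as the comment above expects:
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyModel(nn.Module):

  @nn.compact
  def __call__(self, x, train=False):
    x = nn.Dense(features=4)(x)
    return nn.BatchNorm(use_running_average=not train)(x)


variables = TinyModel().init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
# `variables` holds the collections {"params", "batch_stats"}; `pop` returns
# the remaining collections plus the extracted entry without mutating the
# input.
state_vars_example, params_example = variables.pop("params")
# ---------------------------------------------------------------------------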
+ state_vars = utils.filter_key_from_frozen_dict( + state_vars, key="intermediates") + return state_vars, initial_params + + state_vars, initial_params = init_model(rng) + parameter_overview.log_parameter_overview(initial_params) # pytype: disable=wrong-arg-types + + learning_rate_fn = lr_schedules.get_learning_rate_fn(config) + tx = optimizers.get_optimizer( + config.optimizer_configs, learning_rate_fn, params=initial_params) + + opt_state = tx.init(initial_params) + + state = utils.TrainState( + step=1, opt_state=opt_state, params=initial_params, rng=rng, + variables=state_vars) + + loss_fn = functools.partial( + losses.compute_full_loss, loss_config=config.losses) + + checkpoint_dir = os.path.join(workdir, "checkpoints") + ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir) + state = ckpt.restore_or_initialize(state) + initial_step = int(state.step) + + # Replicate our parameters. + state = flax.jax_utils.replicate(state, devices=jax.local_devices()) + del rng # rng is stored in the state. + + # Only write metrics on host 0, write to logs on all other hosts. + writer = metric_writers.create_default_writer( + workdir, just_logging=jax.host_id() > 0) + writer.write_hparams(utils.prepare_dict_for_logging(config.to_dict())) + + logging.info("Starting training loop at step %d.", initial_step) + report_progress = periodic_actions.ReportProgress( + num_train_steps=config.num_train_steps, writer=writer) + if jax.process_index() == 0: + profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir) + p_train_step = jax.pmap( + train_step, + axis_name="batch", + donate_argnums=(2, 3, 4, 5, 6, 7), + static_broadcasted_argnums=(0, 1, 8, 9, 10, 11, 12)) + + train_metrics = None + with metric_writers.ensure_flushes(writer): + if config.num_train_steps == 0: + with report_progress.timed("eval"): + evaluate(model, state, eval_ds, loss_fn, eval_metrics_cls, config, + writer, step=0) + with report_progress.timed("checkpoint"): + ckpt.save(flax.jax_utils.unreplicate(state)) + return + + for step in range(initial_step, config.num_train_steps + 1): + # `step` is a Python integer. `state.step` is JAX integer on GPU/TPU. + is_last_step = step == config.num_train_steps + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + batch = jax.tree_map(np.asarray, next(train_iter)) + (opt_state, params, state_vars, rng, metrics_update, p_step + ) = p_train_step( + model, tx, state.rng, state.step, state.variables, + state.opt_state, state.params, batch, loss_fn, + train_metrics_cls, + config.num_slots, + config.max_instances + 1, # Incl. background. + config.get("conditioning_key")) + + state = state.replace( # pytype: disable=attribute-error + opt_state=opt_state, + params=params, + step=p_step, + variables=state_vars, + rng=rng, + ) + + metric_update = flax.jax_utils.unreplicate(metrics_update) + train_metrics = ( + metric_update + if train_metrics is None else train_metrics.merge(metric_update)) + + # Quick indication that training is happening. 
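# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original patch): in the `jax.pmap` call
# above, `static_broadcasted_argnums` marks hashable Python arguments (the
# model, optimizer, loss function, metrics class and instance counts) as
# compile-time constants, while `donate_argnums` lets XLA reuse the buffers of
# the old rng/state/params/batch for the outputs. Minimal standalone version:
import jax
import jax.numpy as jnp


def scaled_step(scale, x):  # `scale` is static, `x` is a mapped array.
  return x * scale


n_dev = jax.local_device_count()
p_scaled_step = jax.pmap(
    scaled_step,
    static_broadcasted_argnums=(0,),  # Recompiles if `scale` changes.
    donate_argnums=(1,))              # `x`'s buffer may back the output.
out = p_scaled_step(2.0, jnp.ones((n_dev, 4)))
# ---------------------------------------------------------------------------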
+ logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) + report_progress(step, time.time()) + + if jax.process_index() == 0: + profiler(step) + + if step % config.log_loss_every_steps == 0 or is_last_step: + metrics_res = train_metrics.compute() + writer.write_scalars(step, jax.tree_map(np.array, metrics_res)) + train_metrics = None + + if step % config.eval_every_steps == 0 or is_last_step: + with report_progress.timed("eval"): + evaluate(model, state, eval_ds, loss_fn, eval_metrics_cls, + config, writer, step=step) + + if step % config.checkpoint_every_steps == 0 or is_last_step: + with report_progress.timed("checkpoint"): + ckpt.save(flax.jax_utils.unreplicate(state)) + + +def evaluate(model, state, eval_ds, loss_fn_eval, eval_metrics_cls, config, + writer, step): + """Evaluate the model.""" + eval_metrics, eval_batch, eval_preds = evaluator.evaluate( + model, + state, + eval_ds, + loss_fn_eval, + eval_metrics_cls, + predicted_max_num_instances=config.num_slots, + ground_truth_max_num_instances=config.max_instances + 1, # Incl. bg. + slice_size=config.get("eval_slice_size"), + slice_keys=config.get("eval_slice_keys"), + conditioning_key=config.get("conditioning_key"), + remove_from_predictions=config.get("remove_from_predictions"), + metrics_on_cpu=config.get("metrics_on_cpu", False)) + + metrics_res = eval_metrics.compute() + writer.write_scalars( + step, jax.tree_map(np.array, utils.flatten_named_dicttree(metrics_res))) + writer.write_images( + step, + jax.tree_map( + np.array, + utils.prepare_images_for_logging( + config, + eval_batch, + eval_preds, + n_samples=config.get("n_samples", 5), + n_frames=config.get("n_frames", 1), + min_n_colors=config.get("logging_min_n_colors", 1)))) diff --git a/invariant_slot_attention/lib/transforms.py b/invariant_slot_attention/lib/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..db5f2f5b750f8a9a78df26396fca443959b7a781 --- /dev/null +++ b/invariant_slot_attention/lib/transforms.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transform functions for preprocessing.""" +from typing import Any, Optional, Tuple + +import tensorflow as tf + + +SizeTuple = Tuple[tf.Tensor, tf.Tensor] # (height, width). +Self = Any + +PADDING_VALUE = -1 +PADDING_VALUE_STR = b"" + +NOTRACK_BOX = (0., 0., 0., 0.) # No-track bounding box for padding. +NOTRACK_RBOX = (0., 0., 0., 0., 0.) # No-track bounding rbox for padding. + + +def crop_or_pad_boxes(boxes, top, left, height, + width, h_orig, w_orig, + min_cropped_area = None): + """Transforms the relative box coordinates according to the frame crop. + + Note that, if height/width are larger than h_orig/w_orig, this function + implements the equivalent of padding. + + Args: + boxes: Tensor of bounding boxes with shape (..., 4). + top: Top of crop box in absolute pixel coordinates. + left: Left of crop box in absolute pixel coordinates. 
+ height: Height of crop box in absolute pixel coordinates. + width: Width of crop box in absolute pixel coordinates. + h_orig: Original image height in absolute pixel coordinates. + w_orig: Original image width in absolute pixel coordinates. + min_cropped_area: If set, remove cropped boxes whose area relative to the + original box is less than min_cropped_area or that covers the entire + image. + + Returns: + Boxes tensor with same shape as input boxes but updated values. + """ + # Video track bound boxes: [num_instances, num_tracks, 4] + # Image bounding boxes: [num_instances, 4] + assert boxes.shape[-1] == 4 + seq_len = tf.shape(boxes)[0] + not_padding = tf.reduce_any(tf.not_equal(boxes, PADDING_VALUE), axis=-1) + has_tracks = len(boxes.shape) == 3 + if has_tracks: + num_tracks = tf.shape(boxes)[1] + else: + assert len(boxes.shape) == 2 + num_tracks = 1 + + # Transform the box coordinates. + a = tf.cast(tf.stack([h_orig, w_orig]), tf.float32) + b = tf.cast(tf.stack([top, left]), tf.float32) + c = tf.cast(tf.stack([height, width]), tf.float32) + boxes = tf.reshape( + (tf.reshape(boxes, (seq_len, num_tracks, 2, 2)) * a - b) / c, + (seq_len, num_tracks, len(NOTRACK_BOX)), + ) + + # Filter the valid boxes. + areas_uncropped = tf.reduce_prod( + tf.maximum(boxes[Ellipsis, 2:] - boxes[Ellipsis, :2], 0), axis=-1 + ) + boxes = tf.minimum(tf.maximum(boxes, 0.0), 1.0) + if has_tracks: + cond = tf.reduce_all((boxes[:, :, 2:] - boxes[:, :, :2]) > 0.0, axis=-1) + boxes = tf.where(cond[:, :, tf.newaxis], boxes, NOTRACK_BOX) + if min_cropped_area is not None: + areas_cropped = tf.reduce_prod( + tf.maximum(boxes[Ellipsis, 2:] - boxes[Ellipsis, :2], 0), axis=-1 + ) + boxes = tf.where( + tf.logical_and( + tf.reduce_max(areas_cropped, axis=0, keepdims=True) + > min_cropped_area * areas_uncropped, + tf.reduce_min(areas_cropped, axis=0, keepdims=True) < 1, + )[Ellipsis, tf.newaxis], + boxes, + tf.constant(NOTRACK_BOX)[tf.newaxis, tf.newaxis], + ) + else: + boxes = tf.reshape(boxes, (seq_len, 4)) + # Image ops use `-1``, whereas video ops above use `NOTRACK_BOX`. + boxes = tf.where(not_padding[Ellipsis, tf.newaxis], boxes, PADDING_VALUE) + + return boxes + + +def cxcywha_to_corners(cxcywha): + """Convert [cx, cy, w, h, a] to four corners of [x, y]. + + TF version of cxcywha_to_corners in + third_party/py/scenic/model_lib/base_models/box_utils.py. + + Args: + cxcywha: [..., 5]-tf.Tensor of [center-x, center-y, width, height, angle] + representation of rotated boxes. Angle is in radians and center of rotation + is defined by [center-x, center-y] point. + + Returns: + [..., 4, 2]-tf.Tensor of four corners of the rotated box as [x, y] points. + """ + assert cxcywha.shape[-1] == 5, "Expected [..., [cx, cy, w, h, a] input." + bs = cxcywha.shape[:-1] + cx, cy, w, h, a = tf.split(cxcywha, num_or_size_splits=5, axis=-1) + xs = tf.constant([.5, .5, -.5, -.5]) * w + ys = tf.constant([-.5, .5, .5, -.5]) * h + pts = tf.stack([xs, ys], axis=-1) + sin = tf.sin(a) + cos = tf.cos(a) + rot = tf.reshape(tf.concat([cos, -sin, sin, cos], axis=-1), (*bs, 2, 2)) + offset = tf.reshape(tf.concat([cx, cy], -1), (*bs, 1, 2)) + corners = pts @ rot + offset + return corners + + +def corners_to_cxcywha(corners): + """Convert four corners of [x, y] to [cx, cy, w, h, a]. + + Args: + corners: [..., 4, 2]-tf.Tensor of four corners of the rotated box as [x, y] + points. + + Returns: + [..., 5]-tf.Tensor of [center-x, center-y, width, height, angle] + representation of rotated boxes. 
Angle is in radians and center of rotation + is defined by [center-x, center-y] point. + """ + assert corners.shape[-2] == 4 and corners.shape[-1] == 2, ( + "Expected [..., [cx, cy, w, h, a] input.") + + cornersx, cornersy = tf.unstack(corners, axis=-1) + cx = tf.reduce_mean(cornersx, axis=-1) + cy = tf.reduce_mean(cornersy, axis=-1) + wcornersx = ( + cornersx[Ellipsis, 0] + cornersx[Ellipsis, 1] - cornersx[Ellipsis, 2] - cornersx[Ellipsis, 3]) + wcornersy = ( + cornersy[Ellipsis, 0] + cornersy[Ellipsis, 1] - cornersy[Ellipsis, 2] - cornersy[Ellipsis, 3]) + hcornersy = (-cornersy[Ellipsis, 0,] + cornersy[Ellipsis, 1] + cornersy[Ellipsis, 2] - + cornersy[Ellipsis, 3]) + a = -tf.atan2(wcornersy, wcornersx) + cos = tf.cos(a) + w = wcornersx / (2 * cos) + h = hcornersy / (2 * cos) + cxcywha = tf.stack([cx, cy, w, h, a], axis=-1) + + return cxcywha diff --git a/invariant_slot_attention/lib/utils.py b/invariant_slot_attention/lib/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..29797f86146e2c997047ea9d324c34e02b895d30 --- /dev/null +++ b/invariant_slot_attention/lib/utils.py @@ -0,0 +1,625 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utils.""" + +import functools +import importlib +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Type, Union + +from absl import logging +from clu import metrics as base_metrics + +import flax +from flax import linen as nn +from flax import traverse_util + +import jax +import jax.numpy as jnp +import jax.ops + +import matplotlib +import matplotlib.pyplot as plt +import ml_collections +import numpy as np +import optax + +import skimage.transform +import tensorflow as tf + +from invariant_slot_attention.lib import metrics + + +Array = Any # Union[np.ndarray, jnp.ndarray] +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet +PRNGKey = Array +ConfigAttr = Any +MetricSpec = Dict[str, str] + + +@flax.struct.dataclass +class TrainState: + """Data structure for checkpointing the model.""" + step: int + opt_state: optax.OptState + params: ArrayTree + variables: flax.core.FrozenDict + rng: PRNGKey + + +METRIC_TYPE_TO_CLS = { + "loss": base_metrics.Average.from_output(name="loss"), + "ari": metrics.Ari, + "ari_nobg": metrics.AriNoBg, +} + + +def make_metrics_collection( + class_name, + metrics_spec): + """Create class inhering from metrics.Collection based on spec.""" + metrics_dict = {} + if metrics_spec: + for m_name, m_type in metrics_spec.items(): + metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type] + + return flax.struct.dataclass( + type(class_name, + (base_metrics.Collection,), + {"__annotations__": metrics_dict})) + + +def flatten_named_dicttree(metrics_res, sep = "/"): + """Flatten dictionary.""" + metrics_res_flat = {} + for k, v in traverse_util.flatten_dict(metrics_res).items(): + 
metrics_res_flat[(sep.join(k)).strip(sep)] = v + return metrics_res_flat + + +def spatial_broadcast(x, resolution): + """Broadcast flat inputs to a 2D grid of a given resolution.""" + # x.shape = (batch_size, features). + x = x[:, jnp.newaxis, jnp.newaxis, :] + return jnp.tile(x, [1, resolution[0], resolution[1], 1]) + + +def time_distributed(cls, in_axes=1, axis=1): + """Wrapper for time-distributed (vmapped) application of a module.""" + return nn.vmap( + cls, in_axes=in_axes, out_axes=axis, axis_name="time", + # Stack debug vars along sequence dim and broadcast params. + variable_axes={ + "params": None, "intermediates": axis, "batch_stats": None}, + split_rngs={"params": False, "dropout": True, "state_init": True}) + + +def broadcast_across_batch(inputs, batch_size): + """Broadcasts inputs across a batch of examples (creates new axis).""" + return jnp.broadcast_to( + array=jnp.expand_dims(inputs, axis=0), + shape=(batch_size,) + inputs.shape) + + +def create_gradient_grid( + samples_per_dim, value_range = (-1.0, 1.0) + ): + """Creates a tensor with equidistant entries from -1 to +1 in each dim. + + Args: + samples_per_dim: Number of points to have along each dimension. + value_range: In each dimension, points will go from range[0] to range[1] + + Returns: + A tensor of shape [samples_per_dim] + [len(samples_per_dim)]. + """ + s = [jnp.linspace(value_range[0], value_range[1], n) for n in samples_per_dim] + pe = jnp.stack(jnp.meshgrid(*s, sparse=False, indexing="ij"), axis=-1) + return jnp.array(pe) + + +def convert_to_fourier_features(inputs, basis_degree): + """Convert inputs to Fourier features, e.g. for positional encoding.""" + + # inputs.shape = (..., n_dims). + # inputs should be in range [-pi, pi] or [0, 2pi]. + n_dims = inputs.shape[-1] + + # Generate frequency basis. + freq_basis = jnp.concatenate( # shape = (n_dims, n_dims * basis_degree) + [2**i * jnp.eye(n_dims) for i in range(basis_degree)], 1) + + # x.shape = (..., n_dims * basis_degree) + x = inputs @ freq_basis # Project inputs onto frequency basis. + + # Obtain Fourier features as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)]. + return jnp.sin(jnp.concatenate([x, x + 0.5 * jnp.pi], axis=-1)) + + +def prepare_images_for_logging( + config, + batch = None, + preds = None, + n_samples = 5, + n_frames = 5, + min_n_colors = 1, + epsilon = 1e-6, + first_replica_only = False): + """Prepare images from batch and/or model predictions for logging.""" + + images = dict() + # Converts all tensors to numpy arrays to run everything on CPU as JAX + # eager mode is inefficient and because memory usage from these ops may + # lead to OOM errors. + batch = jax.tree_map(np.array, batch) + preds = jax.tree_map(np.array, preds) + + if n_samples <= 0: + return images + + if not first_replica_only: + # Move the two leading batch dimensions into a single dimension. We do this + # to plot the same number of examples regardless of the data parallelism. + batch = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), batch) + preds = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), preds) + else: + batch = jax.tree_map(lambda x: x[0], batch) + preds = jax.tree_map(lambda x: x[0], preds) + + # Limit the tensors to n_samples and n_frames. + batch = jax.tree_map( + lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples], + batch) + preds = jax.tree_map( + lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples], + preds) + + # Log input data. 
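+ # At this point `batch` and `preds` are NumPy pytrees whose leading axes are
+ # at most (n_samples, n_frames, ...). Everything below only assembles image
+ # grids for logging; e.g. a video tensor of shape (5, 5, 128, 128, 3) becomes
+ # a single (5, 128, 640, 3) grid via `video_to_image_grid` (illustrative
+ # shapes, assuming 128x128 frames and the default n_samples = n_frames = 5).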
+ if batch is not None: + images["video"] = video_to_image_grid(batch["video"]) + if "segmentations" in batch: + images["mask"] = video_to_image_grid(convert_categories_to_color( + batch["segmentations"], min_n_colors=min_n_colors)) + if "flow" in batch: + images["flow"] = video_to_image_grid(batch["flow"]) + if "boxes" in batch: + images["boxes"] = draw_bounding_boxes( + batch["video"], + batch["boxes"], + min_n_colors=min_n_colors) + + # Log model predictions. + if preds is not None and preds.get("outputs") is not None: + if "segmentations" in preds["outputs"]: # pytype: disable=attribute-error + images["segmentations"] = video_to_image_grid( + convert_categories_to_color( + preds["outputs"]["segmentations"], min_n_colors=min_n_colors)) + + def shape_fn(x): + if isinstance(x, (np.ndarray, jnp.ndarray)): + return x.shape + + # Log intermediate variables. + if preds is not None and "intermediates" in preds: + + logging.info("intermediates: %s", + jax.tree_map(shape_fn, preds["intermediates"])) + + for key, path in config.debug_var_video_paths.items(): + log_vars = retrieve_from_collection(preds["intermediates"], path) + if log_vars is not None: + if not isinstance(log_vars, Sequence): + log_vars = [log_vars] + for i, log_var in enumerate(log_vars): + log_var = np.array(log_var) # Moves log_var to CPU. + images[key + "_" + str(i)] = video_to_image_grid(log_var) + else: + logging.warning("%s not found in intermediates", path) + + # Log attention weights. + for key, path in config.debug_var_attn_paths.items(): + log_vars = retrieve_from_collection(preds["intermediates"], path) + if log_vars is not None: + if not isinstance(log_vars, Sequence): + log_vars = [log_vars] + for i, log_var in enumerate(log_vars): + log_var = np.array(log_var) # Moves log_var to CPU. + images.update( + prepare_attention_maps_for_logging( + attn_maps=log_var, + key=key + "_" + str(i), + map_width=config.debug_var_attn_widths.get(key), + video=batch["video"], + epsilon=epsilon, + n_samples=n_samples, + n_frames=n_frames)) + else: + logging.warning("%s not found in intermediates", path) + + # Crop each image to a maximum of 3 channels for RGB visualization. + for key, image in images.items(): + if image.shape[-1] > 3: + logging.warning("Truncating channels of %s for visualization.", key) + images[key] = image[Ellipsis, :3] + + return images + + +def prepare_attention_maps_for_logging(attn_maps, key, + map_width, epsilon, + n_samples, n_frames, + video): + """Visualize (overlayed) attention maps as an image grid.""" + images = {} # Results dictionary. + attn_maps = unflatten_image(attn_maps[Ellipsis, None], width=map_width) + + num_heads = attn_maps.shape[2] + for head_idx in range(num_heads): + attn = attn_maps[:n_samples, :n_frames, head_idx] + attn /= attn.max() + epsilon # Standardizes scale for visualization. + # attn.shape: [bs, seq_len, 11, h', w', 1] + + bs, seq_len, _, h_attn, w_attn, _ = attn.shape + images[f"{key}_head_{head_idx}"] = video_to_image_grid(attn) + + # Attention maps are interpretable when they align with object boundaries. + # However, if they are overly smooth then the following visualization which + # overlays attention maps on video is helpful. 
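+ # The overlay below resizes every frame to the attention resolution and
+ # multiplies it with the per-slot attention maps:
+ #   attn: (bs, seq, n_slots, h_attn, w_attn, 1)
+ #   video_resized (broadcast over slots): (bs, seq, 1, h_attn, w_attn, 3)
+ #   -> attn_overlayed: (bs, seq, n_slots, h_attn, w_attn, 3),
+ # i.e. each slot's attention map acts as a soft mask over the RGB frame.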
+ video = video[:n_samples, :n_frames] + # video.shape: [bs, seq_len, h, w, 3] + video_resized = [] + for i in range(n_samples): + for j in range(n_frames): + video_resized.append( + skimage.transform.resize(video[i, j], (h_attn, w_attn), order=1)) + video_resized = np.array(video_resized).reshape( + (bs, seq_len, h_attn, w_attn, 3)) + attn_overlayed = attn * np.expand_dims(video_resized, 2) + images[f"{key}_head_{head_idx}_overlayed"] = video_to_image_grid( + attn_overlayed) + + return images + + +def convert_categories_to_color( + inputs, min_n_colors = 1, include_black = True): + """Converts int-valued categories to color in last axis of input tensor. + + Args: + inputs: `np.ndarray` of arbitrary shape with integer entries, encoding the + categories. + min_n_colors: Minimum number of colors (excl. black) to encode categories. + include_black: Include black as 0-th entry in the color palette. Increases + `min_n_colors` by 1 if True. + + Returns: + `np.ndarray` with RGB colors in last axis. + """ + if inputs.shape[-1] == 1: # Strip category axis. + inputs = np.squeeze(inputs, axis=-1) + inputs = np.array(inputs, dtype=np.int32) # Convert to int. + + # Infer number of colors from inputs. + n_colors = int(inputs.max()) + 1 # One color per category incl. 0. + if include_black: + n_colors -= 1 # If we include black, we need one color less. + + if min_n_colors > n_colors: # Use more colors in color palette if requested. + n_colors = min_n_colors + + rgb_colors = get_uniform_colors(n_colors) + + if include_black: # Add black as color for zero-th index. + rgb_colors = np.concatenate((np.zeros((1, 3)), rgb_colors), axis=0) + return rgb_colors[inputs] + + +def get_uniform_colors(n_colors): + """Get n_colors with uniformly spaced hues.""" + hues = np.linspace(0, 1, n_colors, endpoint=False) + hsv_colors = np.concatenate( + (np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1) + rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors) + return rgb_colors # rgb_colors.shape = (n_colors, 3) + + +def unflatten_image(image, width = None): + """Unflatten image array of shape [batch_dims..., height*width, channels].""" + n_channels = image.shape[-1] + # If width is not provided, we assume that the image is square. + if width is None: + width = int(np.floor(np.sqrt(image.shape[-2]))) + height = width + assert width * height == image.shape[-2], "Image is not square." + else: + height = image.shape[-2] // width + return image.reshape(image.shape[:-2] + (height, width, n_channels)) + + +def video_to_image_grid(video): + """Transform video to image grid by folding sequence dim along width.""" + if len(video.shape) == 5: + n_samples, n_frames, height, width, n_channels = video.shape + video = np.transpose(video, (0, 2, 1, 3, 4)) # Swap n_frames and height. + image_grid = np.reshape( + video, (n_samples, height, n_frames * width, n_channels)) + elif len(video.shape) == 6: + n_samples, n_frames, n_slots, height, width, n_channels = video.shape + # Put n_frames next to width. 
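+ # Illustrative shapes: (n_samples=2, n_frames=4, n_slots=11, 64, 64, 3) is
+ # transposed to (2, 11, 64, 4, 64, 3) and reshaped to (2, 11 * 64, 4 * 64, 3),
+ # so slots are stacked vertically and frames are stacked horizontally.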
+ video = np.transpose(video, (0, 2, 3, 1, 4, 5)) + image_grid = np.reshape( + video, (n_samples, n_slots * height, n_frames * width, n_channels)) + else: + raise ValueError("Unsupported video shape for visualization.") + return image_grid + + +def draw_bounding_boxes(video, + boxes, + min_n_colors = 1, + include_black = True): + """Draw bounding boxes in videos.""" + colors = get_uniform_colors(min_n_colors - include_black) + + b, t, h, w, c = video.shape + n = boxes.shape[2] + image_grid = tf.image.draw_bounding_boxes( + np.reshape(video, (b * t, h, w, c)), + np.reshape(boxes, (b * t, n, 4)), + colors).numpy() + image_grid = np.reshape( + np.transpose(np.reshape(image_grid, (b, t, h, w, c)), + (0, 2, 1, 3, 4)), + (b, h, t * w, c)) + return image_grid + + +def plot_image(ax, image): + """Add an image visualization to a provided `plt.Axes` instance.""" + num_channels = image.shape[-1] + if num_channels == 1: + image = image.reshape(image.shape[:2]) + ax.imshow(image, cmap="viridis") + ax.grid(False) + plt.axis("off") + + +def visualize_image_dict(images, plot_scale = 10): + """Visualize a dictionary of images in colab using maptlotlib.""" + + for key in images.keys(): + logging.info("Visualizing key: %s", key) + n_images = len(images[key]) + fig = plt.figure(figsize=(n_images * plot_scale, plot_scale)) + for idx, image in enumerate(images[key]): + ax = fig.add_subplot(1, n_images, idx+1) + plot_image(ax, image) + plt.show() + + +def filter_key_from_frozen_dict( + frozen_dict, key): + """Filters (removes) an item by key from a flax.core.FrozenDict.""" + if key in frozen_dict: + frozen_dict, _ = frozen_dict.pop(key) + return frozen_dict + + +def prepare_dict_for_logging(nested_dict, parent_key = "", + sep = "_"): + """Prepare a nested dictionary for logging with `clu.metric_writers`. + + Args: + nested_dict: A nested dictionary, e.g. obtained from a + `ml_collections.ConfigDict` via `.to_dict()`. + parent_key: String used in recursion. + sep: String used to separate parent and child keys. + + Returns: + Flattened dict. + """ + items = [] + for k, v in nested_dict.items(): + # Flatten keys of nested elements. + new_key = parent_key + sep + k if parent_key else k + + # Convert None values, lists and tuples to strings. + if v is None: + v = "None" + if isinstance(v, list) or isinstance(v, tuple): + v = str(v) + + # Recursively flatten the dict. + if isinstance(v, dict): + items.extend(prepare_dict_for_logging(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def retrieve_from_collection( + variable_collection, path): + """Finds variables by their path by recursively searching the collection. + + Args: + variable_collection: Nested dict containing the variables (or tuples/lists + of variables). + path: Path to variable in module tree, similar to Unix file names (e.g. + '/module/dense/0/bias'). + + Returns: + The requested variable, variable collection or None (in case the variable + could not be found). + """ + key, _, rpath = path.strip("/").partition("/") + + # In case the variable is not found, we return None. 
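+ # For example (hypothetical path), path = "/decoder/alphas/0" walks
+ # variable_collection["decoder"]["alphas"][0]: digit components index into
+ # lists/tuples, all other components are treated as dict keys.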
+ if (key.isdigit() and not isinstance(variable_collection, Sequence)) or ( + key.isdigit() and int(key) >= len(variable_collection)) or ( + not key.isdigit() and key not in variable_collection): + return None + + if key.isdigit(): + key = int(key) + + if not rpath: + return variable_collection[key] + else: + return retrieve_from_collection(variable_collection[key], rpath) + + +def build_model_from_config(config): + """Build a Flax model from a (nested) ConfigDict.""" + model_constructor = _parse_config(config) + if callable(model_constructor): + return model_constructor() + else: + raise ValueError("Provided config does not contain module constructors.") + + +def _parse_config(config + ): + """Recursively parses a nested ConfigDict and resolves module constructors.""" + + if isinstance(config, list): + return [_parse_config(c) for c in config] + elif isinstance(config, tuple): + return tuple([_parse_config(c) for c in config]) + elif not isinstance(config, ml_collections.ConfigDict): + return config + elif "module" in config: + module_constructor = _resolve_module_constructor(config.module) + kwargs = {k: _parse_config(v) for k, v in config.items() if k != "module"} + return functools.partial(module_constructor, **kwargs) + else: + return {k: _parse_config(v) for k, v in config.items()} + + +def _resolve_module_constructor( + constructor_str): + import_str, _, module_name = constructor_str.rpartition(".") + py_module = importlib.import_module(import_str) + return getattr(py_module, module_name) + + +def get_slices_along_axis( + inputs, + slice_keys, + start_idx = 0, + end_idx = -1, + axis = 2, + pad_value = 0): + """Extracts slices from a dictionary of tensors along the specified axis. + + The slice operation is only applied to `slice_keys` dictionary keys. If + `end_idx` is larger than the actual size of the specified axis, padding is + added (with values provided in `pad_value`). + + Args: + inputs: Dictionary of tensors. + slice_keys: Iterable of strings, the keys for the inputs dictionary for + which to apply the slice operation. + start_idx: Integer, defining the first index to be part of the slice. + end_idx: Integer, defining the end of the slice interval (exclusive). If set + to `-1`, the end index is set to the size of the axis. If a value is + provided that is larger than the size of the axis, zero-padding is added + for the remaining elements. + axis: Integer, the axis along which to slice. + pad_value: Integer, value to be used in padding. + + Returns: + Dictionary of tensors where elements described in `slice_keys` are sliced, + and all other elements are returned as original. + """ + + max_size = None + pad_size = 0 + + # Check shapes and get maximum size of requested axis. + for key in slice_keys: + curr_size = inputs[key].shape[axis] + if max_size is None: + max_size = curr_size + elif max_size != curr_size: + raise ValueError( + "For specified tensors the requested axis needs to be of equal size.") + + # Infer end index if not provided. + if end_idx == -1: + end_idx = max_size + + # Set padding size if end index is larger than maximum size of requested axis. + elif end_idx > max_size: + pad_size = end_idx - max_size + end_idx = max_size + + outputs = {} + for key in slice_keys: + outputs[key] = np.take( + inputs[key], indices=np.arange(start_idx, end_idx), axis=axis) + + # Add padding if necessary. + if pad_size > 0: + pad_shape = np.array(outputs[key].shape) + np.put(pad_shape, axis, pad_size) # In-place op. 
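+ # Illustrative shapes: slicing axis=2 of a (B, T, 6, H, W, C) tensor with
+ # start_idx=0 and end_idx=8 keeps all 6 entries and appends 2 slices filled
+ # with pad_value, yielding (B, T, 8, H, W, C).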
+ padding = pad_value * np.ones(pad_shape, dtype=outputs[key].dtype) + outputs[key] = np.concatenate((outputs[key], padding), axis=axis) + + return outputs + + +def get_element_by_str( + dictionary, multilevel_key, separator = "/" + ): + """Gets element in a dictionary with multilevel key (e.g., "key1/key2").""" + keys = multilevel_key.split(separator) + if len(keys) == 1: + return dictionary[keys[0]] + return get_element_by_str( + dictionary[keys[0]], separator.join(keys[1:]), separator=separator) + + +def set_element_by_str( + dictionary, multilevel_key, new_value, + separator = "/"): + """Sets element in a dictionary with multilevel key (e.g., "key1/key2").""" + keys = multilevel_key.split(separator) + if len(keys) == 1: + if keys[0] not in dictionary: + key_error = ( + "Pretrained {key} was not found in trained model. " + "Make sure you are loading the correct pretrained model " + "or consider adding {key} to exceptions.") + raise KeyError(key_error.format(type="parameter", key=keys[0])) + dictionary[keys[0]] = new_value + else: + set_element_by_str( + dictionary[keys[0]], + separator.join(keys[1:]), + new_value, + separator=separator) + + +def remove_singleton_dim(inputs): + """Removes the final dimension if it is singleton (i.e. of size 1).""" + if inputs is None: + return None + if inputs.shape[-1] != 1: + logging.warning("Expected final dimension of inputs to be 1, " + "received inputs of shape %s: ", str(inputs.shape)) + return inputs + return inputs[Ellipsis, 0] + diff --git a/invariant_slot_attention/modules/__init__.py b/invariant_slot_attention/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4de3565ef241189eac33dea9d2f901f5305be01f --- /dev/null +++ b/invariant_slot_attention/modules/__init__.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
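+ # The classes re-exported below are usually not instantiated directly but
+ # referenced by name from a ConfigDict and built via
+ # utils.build_model_from_config; roughly (hypothetical config, for
+ # illustration only):
+ #   cfg = ml_collections.ConfigDict(
+ #       {"module": "invariant_slot_attention.modules.MLP",
+ #        "hidden_size": 128})
+ #   mlp = utils.build_model_from_config(cfg)  # -> MLP(hidden_size=128)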
+ +"""Module library.""" +# pylint: disable=g-multiple-import +# pylint: disable=g-bad-import-order +# Re-export commonly used modules and functions + +from .attention import (GeneralizedDotProductAttention, + InvertedDotProductAttention, SlotAttention, + TransformerBlock, Transformer) +from .convolution import (SimpleCNN, CNN) +from .decoders import (SpatialBroadcastDecoder, SiameseSpatialBroadcastDecoder) +from .initializers import (GaussianStateInit, ParamStateInit, + SegmentationEncoderStateInit, + CoordinateEncoderStateInit) +from .misc import (Dense, GRU, Identity, MLP, PositionEmbedding, Readout, + RelativePositionEmbedding) +from .video import (CorrectorPredictorTuple, FrameEncoder, Processor, SAVi) +from .resnet import (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, + ResNet200) +from .invariant_attention import (InvertedDotProductAttentionKeyPerQuery, + SlotAttentionExplicitStats, + SlotAttentionPosKeysValues, + SlotAttentionTranslEquiv, + SlotAttentionTranslScaleEquiv, + SlotAttentionTranslRotScaleEquiv) +from .invariant_initializers import ( + ParamStateInitRandomPositions, + ParamStateInitRandomPositionsScales, + ParamStateInitRandomPositionsRotationsScales, + ParamStateInitLearnablePositions, + ParamStateInitLearnablePositionsScales, + ParamStateInitLearnablePositionsRotationsScales) + + +# pylint: enable=g-multiple-import diff --git a/invariant_slot_attention/modules/attention.py b/invariant_slot_attention/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..88cf1bc2a1f5cde5b49b81c9606b8439c5e70a35 --- /dev/null +++ b/invariant_slot_attention/modules/attention.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attention module library.""" + +import functools +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union + +from flax import linen as nn +import jax +import jax.numpy as jnp +from invariant_slot_attention.modules import misc + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ProcessorState = ArrayTree +PRNGKey = Array +NestedDict = Dict[str, Any] + + +class SlotAttention(nn.Module): + """Slot Attention module. + + Note: This module uses pre-normalization by default. + """ + + num_iterations: int = 1 + qkv_size: Optional[int] = None + mlp_size: Optional[int] = None + epsilon: float = 1e-8 + num_heads: int = 1 + + @nn.compact + def __call__(self, slots, inputs, + padding_mask = None, + train = False): + """Slot Attention module forward pass.""" + del padding_mask, train # Unused. + + qkv_size = self.qkv_size or slots.shape[-1] + head_dim = qkv_size // self.num_heads + dense = functools.partial(nn.DenseGeneral, + axis=-1, features=(self.num_heads, head_dim), + use_bias=False) + + # Shared modules. 
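+ # The q projection, layer norm, inverted attention, GRU and (optional) MLP
+ # are created once here and reused in every one of the `num_iterations`
+ # rounds below, so all rounds share the same parameters. A typical call
+ # looks roughly like (hypothetical sizes):
+ #   SlotAttention(num_iterations=3, qkv_size=64, mlp_size=128)(slots, inputs)
+ # with slots: (batch, n_slots, slot_size) and inputs: (batch, n_inputs, dim).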
+ dense_q = dense(name="general_dense_q_0") + layernorm_q = nn.LayerNorm() + inverted_attention = InvertedDotProductAttention( + norm_type="mean", multi_head=self.num_heads > 1) + gru = misc.GRU() + + if self.mlp_size is not None: + mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore + + # inputs.shape = (..., n_inputs, inputs_size). + inputs = nn.LayerNorm()(inputs) + # k.shape = (..., n_inputs, slot_size). + k = dense(name="general_dense_k_0")(inputs) + # v.shape = (..., n_inputs, slot_size). + v = dense(name="general_dense_v_0")(inputs) + + # Multiple rounds of attention. + for _ in range(self.num_iterations): + + # Inverted dot-product attention. + slots_n = layernorm_q(slots) + q = dense_q(slots_n) # q.shape = (..., n_inputs, slot_size). + updates = inverted_attention(query=q, key=k, value=v) + + # Recurrent update. + slots = gru(slots, updates) + + # Feedforward block with pre-normalization. + if self.mlp_size is not None: + slots = mlp(slots) + + return slots + + +class InvertedDotProductAttention(nn.Module): + """Inverted version of dot-product attention (softmax over query axis).""" + + norm_type: Optional[str] = "mean" # mean, layernorm, or None + multi_head: bool = False + epsilon: float = 1e-8 + dtype: DType = jnp.float32 + precision: Optional[jax.lax.Precision] = None + return_attn_weights: bool = False + + @nn.compact + def __call__(self, query, key, value, + train = False): + """Computes inverted dot-product attention. + + Args: + query: Queries with shape of `[batch..., q_num, qk_features]`. + key: Keys with shape of `[batch..., kv_num, qk_features]`. + value: Values with shape of `[batch..., kv_num, v_features]`. + train: Indicating whether we're training or evaluating. + + Returns: + Output of shape `[batch_size..., n_queries, v_features]` + """ + del train # Unused. + + attn = GeneralizedDotProductAttention( + inverted_attn=True, + renormalize_keys=True if self.norm_type == "mean" else False, + epsilon=self.epsilon, + dtype=self.dtype, + precision=self.precision, + return_attn_weights=True) + + # Apply attention mechanism. + output, attn = attn(query=query, key=key, value=value) + + if self.multi_head: + # Multi-head aggregation. Equivalent to concat + dense layer. + output = nn.DenseGeneral(features=output.shape[-1], axis=(-2, -1))(output) + else: + # Remove head dimension. + output = jnp.squeeze(output, axis=-2) + attn = jnp.squeeze(attn, axis=-3) + + if self.norm_type == "layernorm": + output = nn.LayerNorm()(output) + + if self.return_attn_weights: + return output, attn + + return output + + +class GeneralizedDotProductAttention(nn.Module): + """Multi-head dot-product attention with customizable normalization axis. + + This module supports logging of attention weights in a variable collection. + """ + + dtype: DType = jnp.float32 + precision: Optional[jax.lax.Precision] = None + epsilon: float = 1e-8 + inverted_attn: bool = False + renormalize_keys: bool = False + attn_weights_only: bool = False + return_attn_weights: bool = False + + @nn.compact + def __call__(self, query, key, value, + train = False, **kwargs + ): + """Computes multi-head dot-product attention given query, key, and value. + + Args: + query: Queries with shape of `[batch..., q_num, num_heads, qk_features]`. + key: Keys with shape of `[batch..., kv_num, num_heads, qk_features]`. + value: Values with shape of `[batch..., kv_num, num_heads, v_features]`. + train: Indicating whether we're training or evaluating. 
+ **kwargs: Additional keyword arguments are required when used as attention + function in nn.MultiHeadDotProductAttention, but they will be ignored + here. + + Returns: + Output of shape `[batch..., q_num, num_heads, v_features]`. + """ + + assert query.ndim == key.ndim == value.ndim, ( + "Queries, keys, and values must have the same rank.") + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( + "Query, key, and value batch dimensions must match.") + assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( + "Query, key, and value num_heads dimensions must match.") + assert key.shape[-3] == value.shape[-3], ( + "Key and value cardinality dimensions must match.") + assert query.shape[-1] == key.shape[-1], ( + "Query and key feature dimensions must match.") + + if kwargs.get("bias") is not None: + raise NotImplementedError( + "Support for masked attention is not yet implemented.") + + if "dropout_rate" in kwargs: + if kwargs["dropout_rate"] > 0.: + raise NotImplementedError("Support for dropout is not yet implemented.") + + # Temperature normalization. + qk_features = query.shape[-1] + query = query / jnp.sqrt(qk_features).astype(self.dtype) + + # attn.shape = (batch..., num_heads, q_num, kv_num) + attn = jnp.einsum("...qhd,...khd->...hqk", query, key, + precision=self.precision) + + if self.inverted_attn: + attention_axis = -2 # Query axis. + else: + attention_axis = -1 # Key axis. + + # Softmax normalization (by default over key axis). + attn = jax.nn.softmax(attn, axis=attention_axis).astype(self.dtype) + + # Defines intermediate for logging. + if not train: + self.sow("intermediates", "attn", attn) + + if self.renormalize_keys: + # Corresponds to value aggregation via weighted mean (as opposed to sum). + normalizer = jnp.sum(attn, axis=-1, keepdims=True) + self.epsilon + attn = attn / normalizer + + if self.attn_weights_only: + return attn + + # Aggregate values using a weighted sum with weights provided by `attn`. + output = jnp.einsum( + "...hqk,...khd->...qhd", attn, value, precision=self.precision) + + if self.return_attn_weights: + return output, attn + + return output + + +class Transformer(nn.Module): + """Transformer with multiple blocks.""" + + num_heads: int + qkv_size: int + mlp_size: int + num_layers: int + pre_norm: bool = False + + @nn.compact + def __call__(self, queries, inputs = None, + padding_mask = None, + train = False): + x = queries + for lyr in range(self.num_layers): + x = TransformerBlock( + num_heads=self.num_heads, qkv_size=self.qkv_size, + mlp_size=self.mlp_size, pre_norm=self.pre_norm, + name=f"TransformerBlock{lyr}")( # pytype: disable=wrong-arg-types + x, inputs, padding_mask, train) + return x + + +class TransformerBlock(nn.Module): + """Transformer decoder block.""" + + num_heads: int + qkv_size: int + mlp_size: int + pre_norm: bool = False + + @nn.compact + def __call__(self, queries, inputs = None, + padding_mask = None, + train = False): + del padding_mask # Unused. + assert queries.ndim == 3 + + attention_fn = GeneralizedDotProductAttention() + + attn = functools.partial( + nn.MultiHeadDotProductAttention, + num_heads=self.num_heads, + qkv_features=self.qkv_size, + attention_fn=attention_fn) + + mlp = misc.MLP(hidden_size=self.mlp_size) # type: ignore + + if self.pre_norm: + # Self-attention on queries. + x = nn.LayerNorm()(queries) + x = attn()(inputs_q=x, inputs_kv=x, deterministic=not train) + x = x + queries + + # Cross-attention on inputs. 
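+ # Pre-norm ordering: LayerNorm is applied to the input of each sub-block
+ # (self-attention, cross-attention, MLP) and the residual sums are left
+ # unnormalized; the post-norm branch below instead normalizes after each
+ # residual connection.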
+ if inputs is not None: + assert inputs.ndim == 3 + y = nn.LayerNorm()(x) + y = attn()(inputs_q=y, inputs_kv=inputs, deterministic=not train) + y = y + x + else: + y = x + + # MLP + z = nn.LayerNorm()(y) + z = mlp(z, train) + z = z + y + else: + # Self-attention on queries. + x = queries + x = attn()(inputs_q=x, inputs_kv=x, deterministic=not train) + x = x + queries + x = nn.LayerNorm()(x) + + # Cross-attention on inputs. + if inputs is not None: + assert inputs.ndim == 3 + y = attn()(inputs_q=x, inputs_kv=inputs, deterministic=not train) + y = y + x + y = nn.LayerNorm()(y) + else: + y = x + + # MLP. + z = mlp(y, train) + z = z + y + z = nn.LayerNorm()(z) + return z diff --git a/invariant_slot_attention/modules/convolution.py b/invariant_slot_attention/modules/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..8448e659617a75715429a2e311c7e30e3710ede6 --- /dev/null +++ b/invariant_slot_attention/modules/convolution.py @@ -0,0 +1,164 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convolutional module library.""" + +import functools +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union + +from flax import linen as nn +import jax + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ProcessorState = ArrayTree +PRNGKey = Array +NestedDict = Dict[str, Any] + + +class SimpleCNN(nn.Module): + """Simple CNN encoder with multiple Conv+ReLU layers.""" + + features: Sequence[int] + kernel_size: Sequence[Tuple[int, int]] + strides: Sequence[Tuple[int, int]] + transpose: bool = False + use_batch_norm: bool = False + axis_name: Optional[str] = None # Over which axis to aggregate batch stats. + padding: Union[str, Iterable[Tuple[int, int]]] = "SAME" + resize_output: Optional[Iterable[int]] = None + + @nn.compact + def __call__(self, inputs, train = False): + num_layers = len(self.features) + assert len(self.kernel_size) == num_layers, ( + "len(kernel_size) and len(features) must match.") + assert len(self.strides) == num_layers, ( + "len(strides) and len(features) must match.") + assert num_layers >= 1, "Need to have at least one layer." 
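+ # Example configuration (hypothetical): a four-layer encoder that
+ # downsamples twice,
+ #   SimpleCNN(features=[32, 32, 64, 64],
+ #             kernel_size=[(5, 5)] * 4,
+ #             strides=[(1, 1), (2, 2), (2, 2), (1, 1)])
+ # With transpose=True the same arguments build a matching transposed
+ # convolution (upsampling) stack.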
+ + if self.transpose: + conv_module = nn.ConvTranspose + else: + conv_module = nn.Conv + + x = conv_module( + name="conv_simple_0", + features=self.features[0], + kernel_size=self.kernel_size[0], + strides=self.strides[0], + use_bias=False if self.use_batch_norm else True, + padding=self.padding)(inputs) + + for i in range(1, num_layers): + if self.use_batch_norm: + x = nn.BatchNorm( + momentum=0.9, use_running_average=not train, + axis_name=self.axis_name, name=f"bn_simple_{i-1}")(x) + + x = nn.relu(x) + x = conv_module( + name=f"conv_simple_{i}", + features=self.features[i], + kernel_size=self.kernel_size[i], + strides=self.strides[i], + use_bias=False if ( + self.use_batch_norm and i < (num_layers-1)) else True, + padding=self.padding)(x) + + if self.resize_output: + x = jax.image.resize( + x, list(x.shape[:-3]) + list(self.resize_output) + [x.shape[-1]], + method=jax.image.ResizeMethod.LINEAR) + return x + + +class CNN(nn.Module): + """Flexible CNN model with Conv/Normalization/Pooling layers.""" + + features: Sequence[int] + kernel_size: Sequence[Tuple[int, int]] + strides: Sequence[Tuple[int, int]] + max_pool_strides: Sequence[Tuple[int, int]] + layer_transpose: Sequence[bool] + activation_fn: Callable[[Array], Array] = nn.relu + norm_type: Optional[str] = None + axis_name: Optional[str] = None # Over which axis to aggregate batch stats. + output_size: Optional[int] = None + + @nn.compact + def __call__(self, inputs, train = False): + num_layers = len(self.features) + + assert num_layers >= 1, "Need to have at least one layer." + assert len(self.kernel_size) == num_layers, ( + "len(kernel_size) and len(features) must match.") + assert len(self.strides) == num_layers, ( + "len(strides) and len(features) must match.") + assert len(self.max_pool_strides) == num_layers, ( + "len(max_pool_strides) and len(features) must match.") + assert len(self.layer_transpose) == num_layers, ( + "len(layer_transpose) and len(features) must match.") + + if self.norm_type: + assert self.norm_type in {"batch", "group", "instance", "layer"}, ( + f"{self.norm_type} is unrecognizaed normalization") + + # Whether transpose conv or regular conv. + conv_module = {False: nn.Conv, True: nn.ConvTranspose} + + if self.norm_type == "batch": + norm_module = functools.partial( + nn.BatchNorm, momentum=0.9, use_running_average=not train, + axis_name=self.axis_name) + elif self.norm_type == "group": + norm_module = functools.partial( + nn.GroupNorm, num_groups=32) + elif self.norm_type == "layer": + norm_module = nn.LayerNorm + + x = inputs + for i in range(num_layers): + x = conv_module[self.layer_transpose[i]]( + name=f"conv_{i}", + features=self.features[i], + kernel_size=self.kernel_size[i], + strides=self.strides[i], + use_bias=False if self.norm_type else True)(x) + + # Normalization layer. + if self.norm_type: + if self.norm_type == "instance": + x = nn.GroupNorm( + num_groups=self.features[i], + name=f"{self.norm_type}_norm_{i}")(x) + else: + norm_module(name=f"{self.norm_type}_norm_{i}")(x) + + # Activation layer. + x = self.activation_fn(x) + + # Max pooling layer. + x = x if self.max_pool_strides[i] == (1, 1) else nn.max_pool( + x, self.max_pool_strides[i], strides=self.max_pool_strides[i], + padding="SAME") + + # Final dense layer. 
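+ # Per layer the order is: conv (optionally transposed) -> normalization
+ # (batch/group/instance/layer, if configured) -> activation -> optional
+ # max pooling. The Dense layer below then maps the channel dimension to
+ # `output_size` independently at every spatial position.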
+ if self.output_size: + x = nn.Dense(self.output_size, name="output_layer", use_bias=True)(x) + return x diff --git a/invariant_slot_attention/modules/decoders.py b/invariant_slot_attention/modules/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b48080b2cf39f915f2d1633382f1078867445f --- /dev/null +++ b/invariant_slot_attention/modules/decoders.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decoder module library.""" +import functools +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union + +from flax import linen as nn + +import jax.numpy as jnp + +from invariant_slot_attention.lib import utils +from invariant_slot_attention.modules import misc + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ProcessorState = ArrayTree +PRNGKey = Array +NestedDict = Dict[str, Any] + + +class SpatialBroadcastDecoder(nn.Module): + """Spatial broadcast decoder for a set of slots (per frame).""" + + resolution: Sequence[int] + backbone: Callable[[], nn.Module] + pos_emb: Callable[[], nn.Module] + early_fusion: bool = False # Fuse slot features before constructing targets. + target_readout: Optional[Callable[[], nn.Module]] = None + + # Vmapped application of module, consumes time axis (axis=1). + @functools.partial(utils.time_distributed, in_axes=(1, None)) + @nn.compact + def __call__(self, slots, train = False): + + batch_size, n_slots, n_features = slots.shape + + # Fold slot dim into batch dim. + x = jnp.reshape(slots, (batch_size * n_slots, n_features)) + + # Spatial broadcast with position embedding. + x = utils.spatial_broadcast(x, self.resolution) + x = self.pos_emb()(x) + + # bb_features.shape = (batch_size * n_slots, h, w, c) + bb_features = self.backbone()(x, train=train) + spatial_dims = bb_features.shape[-3:-1] + + alpha_logits = nn.Dense( + features=1, use_bias=True, name="alpha_logits")(bb_features) + alpha_logits = jnp.reshape( + alpha_logits, (batch_size, n_slots) + spatial_dims + (-1,)) + + alphas = nn.softmax(alpha_logits, axis=1) + if not train: + # Define intermediates for logging / visualization. + self.sow("intermediates", "alphas", alphas) + + if self.early_fusion: + # To save memory, fuse the slot features before predicting targets. + # The final target output should be equivalent to the late fusion when + # using linear prediction. + bb_features = jnp.reshape( + bb_features, (batch_size, n_slots) + spatial_dims + (-1,)) + # Combine backbone features by alpha masks. + bb_features = jnp.sum(bb_features * alphas, axis=1) + + targets_dict = self.target_readout()(bb_features, train) # pylint: disable=not-callable + + preds_dict = dict() + for target_key, channels in targets_dict.items(): + if self.early_fusion: + # decoded_target.shape = (batch_size, h, w, c) after next line. 
+ decoded_target = channels + else: + # channels.shape = (batch_size, n_slots, h, w, c) + channels = jnp.reshape( + channels, (batch_size, n_slots) + (spatial_dims) + (-1,)) + + # masked_channels.shape = (batch_size, n_slots, h, w, c) + masked_channels = channels * alphas + + # decoded_target.shape = (batch_size, h, w, c) + decoded_target = jnp.sum(masked_channels, axis=1) # Combine target. + preds_dict[target_key] = decoded_target + + if not train: + # Define intermediates for logging / visualization. + self.sow("intermediates", f"{target_key}_slots", channels) + if not self.early_fusion: + self.sow("intermediates", f"{target_key}_masked", masked_channels) + self.sow("intermediates", f"{target_key}_combined", decoded_target) + + preds_dict["segmentations"] = jnp.argmax(alpha_logits, axis=1) + + return preds_dict + + +class SiameseSpatialBroadcastDecoder(nn.Module): + """Siamese spatial broadcast decoder for a set of slots (per frame). + + Similar to the decoders used in IODINE: https://arxiv.org/abs/1903.00450 + and in Slot Attention: https://arxiv.org/abs/2006.15055. + """ + + resolution: Sequence[int] + backbone: Callable[[], nn.Module] + pos_emb: Callable[[], nn.Module] + pass_intermediates: bool = False + alpha_only: bool = False # Predict only alpha masks. + concat_attn: bool = False + # Readout after backbone. + target_readout_from_slots: bool = False + target_readout: Optional[Callable[[], nn.Module]] = None + early_fusion: bool = False # Fuse slot features before constructing targets. + # Readout on slots. + attribute_readout: Optional[Callable[[], nn.Module]] = None + remove_background_attribute: bool = False + attn_key: Optional[str] = None + attn_width: Optional[int] = None + # If True, expects slot embeddings to contain slot positions. + relative_positions: bool = False + # Slot positions and scales. + relative_positions_and_scales: bool = False + relative_positions_rotations_and_scales: bool = False + + # Vmapped application of module, consumes time axis (axis=1). + @functools.partial(utils.time_distributed, in_axes=(1, None)) + @nn.compact + def __call__(self, + slots, + train = False): + + if self.remove_background_attribute and self.attribute_readout is None: + raise NotImplementedError( + "Background removal is only supported for attribute readout.") + + if self.relative_positions: + # Assume slot positions were concatenated to slot embeddings. + # E.g. an output of SlotAttentionTranslEquiv. + slots, positions = slots[Ellipsis, :-2], slots[Ellipsis, -2:] + # Reshape positions to [B * num_slots, 2] + positions = positions.reshape( + (positions.shape[0] * positions.shape[1], positions.shape[2])) + elif self.relative_positions_and_scales: + # Assume slot positions and scales were concatenated to slot embeddings. + # E.g. an output of SlotAttentionTranslScaleEquiv. 
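+ # Layout assumed here: slots[..., :-4] are the slot features,
+ # slots[..., -4:-2] the 2D slot positions and slots[..., -2:] the per-axis
+ # slot scales, all expressed in the normalized coordinate frame of the
+ # position grid.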
+ slots, positions, scales = (slots[Ellipsis, :-4], + slots[Ellipsis, -4: -2], + slots[Ellipsis, -2:]) + positions = positions.reshape( + (positions.shape[0] * positions.shape[1], positions.shape[2])) + scales = scales.reshape( + (scales.shape[0] * scales.shape[1], scales.shape[2])) + elif self.relative_positions_rotations_and_scales: + slots, positions, scales, rotm = (slots[Ellipsis, :-8], + slots[Ellipsis, -8: -6], + slots[Ellipsis, -6: -4], + slots[Ellipsis, -4:]) + positions = positions.reshape( + (positions.shape[0] * positions.shape[1], positions.shape[2])) + scales = scales.reshape( + (scales.shape[0] * scales.shape[1], scales.shape[2])) + rotm = rotm.reshape( + rotm.shape[0] * rotm.shape[1], 2, 2) + + batch_size, n_slots, n_features = slots.shape + + preds_dict = {} + # Fold slot dim into batch dim. + x = jnp.reshape(slots, (batch_size * n_slots, n_features)) + + # Attribute readout. + if self.attribute_readout is not None: + if self.remove_background_attribute: + slots = slots[:, 1:] + attributes_dict = self.attribute_readout()(slots, train) # pylint: disable=not-callable + preds_dict.update(attributes_dict) + + # Spatial broadcast with position embedding. + # See https://arxiv.org/abs/1901.07017. + x = utils.spatial_broadcast(x, self.resolution) + + if self.relative_positions: + x = self.pos_emb()(inputs=x, slot_positions=positions) + elif self.relative_positions_and_scales: + x = self.pos_emb()(inputs=x, slot_positions=positions, slot_scales=scales) + elif self.relative_positions_rotations_and_scales: + x = self.pos_emb()( + inputs=x, slot_positions=positions, slot_scales=scales, + slot_rotm=rotm) + else: + x = self.pos_emb()(x) + + # bb_features.shape = (batch_size*n_slots, h, w, c) + bb_features = self.backbone()(x, train=train) + spatial_dims = bb_features.shape[-3:-1] + alphas = nn.Dense(features=1, use_bias=True, name="alphas")(bb_features) + alphas = jnp.reshape( + alphas, (batch_size, n_slots) + spatial_dims + (-1,)) + alphas_softmaxed = nn.softmax(alphas, axis=1) + preds_dict["segmentation_logits"] = alphas + preds_dict["segmentations"] = jnp.argmax(alphas, axis=1) + # Define intermediates for logging. + _ = misc.Identity(name="alphas_softmaxed")(alphas_softmaxed) + if self.alpha_only or self.target_readout is None: + assert alphas.shape[-1] == 1, "Alpha masks need to be one-dimensional." + return preds_dict, {"segmentation_logits": alphas} + + if self.early_fusion: + # To save memory, fuse the slot features before predicting targets. + # The final target output should be equivalent to the late fusion when + # using linear prediction. + bb_features = jnp.reshape( + bb_features, (batch_size, n_slots) + spatial_dims + (-1,)) + # Combine backbone features by alpha masks. + bb_features = jnp.sum(bb_features * alphas_softmaxed, axis=1) + + if self.target_readout_from_slots: + targets_dict = self.target_readout()(slots, train) # pylint: disable=not-callable + else: + targets_dict = self.target_readout()(bb_features, train) # pylint: disable=not-callable + + targets_dict_new = dict() + targets_dict_new["targets_masks"] = alphas_softmaxed + targets_dict_new["targets_logits_masks"] = alphas + + for target_key, channels in targets_dict.items(): + if self.early_fusion: + # decoded_target.shape = (batch_size, h, w, c) after next line. + decoded_target = channels + else: + # channels.shape = (batch_size, n_slots, h, w, c) after next line. 
+ channels = jnp.reshape( + channels, (batch_size, n_slots) + + (spatial_dims if not self.target_readout_from_slots else + (1, 1)) + (-1,)) + # masked_channels.shape = (batch_size, n_slots, h, w, c) at next line. + masked_channels = channels * alphas_softmaxed + # decoded_target.shape = (batch_size, h, w, c) after next line. + decoded_target = jnp.sum(masked_channels, axis=1) # Combine target. + targets_dict_new[target_key + "_channels"] = channels + # Define intermediates for logging. + _ = misc.Identity(name=f"{target_key}_channels")(channels) + _ = misc.Identity(name=f"{target_key}_masked_channels")(masked_channels) + + targets_dict_new[target_key] = decoded_target + # Define intermediates for logging. + _ = misc.Identity(name=f"decoded_{target_key}")(decoded_target) + + preds_dict.update(targets_dict_new) + return preds_dict diff --git a/invariant_slot_attention/modules/initializers.py b/invariant_slot_attention/modules/initializers.py new file mode 100644 index 0000000000000000000000000000000000000000..faa45a6826ece7b4a3fbd19c3ca4c2c89e98bc0f --- /dev/null +++ b/invariant_slot_attention/modules/initializers.py @@ -0,0 +1,173 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Initializers module library.""" + +import functools +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union + +from flax import linen as nn + +import jax +import jax.numpy as jnp + +from invariant_slot_attention.lib import utils +from invariant_slot_attention.modules import misc +from invariant_slot_attention.modules import video + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ProcessorState = ArrayTree +PRNGKey = Array +NestedDict = Dict[str, Any] + + +class ParamStateInit(nn.Module): + """Fixed, learnable state initalization. + + Note: This module ignores any conditional input (by design). + """ + + shape: Sequence[int] + init_fn: str = "normal" # Default init with unit variance. + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. + + if self.init_fn == "normal": + init_fn = functools.partial(nn.initializers.normal, stddev=1.) + elif self.init_fn == "zeros": + init_fn = lambda: nn.initializers.zeros + else: + raise ValueError("Unknown init_fn: {}.".format(self.init_fn)) + + param = self.param("state_init", init_fn(), self.shape) + return utils.broadcast_across_batch(param, batch_size=batch_size) + + +class GaussianStateInit(nn.Module): + """Random state initialization with zero-mean, unit-variance Gaussian. + + Note: This module does not contain any trainable parameters and requires + providing a jax.PRNGKey both at training and at test time. Note: This module + also ignores any conditional input (by design). 
+ """ + + shape: Sequence[int] + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. + rng = self.make_rng("state_init") + return jax.random.normal(rng, shape=[batch_size] + list(self.shape)) + + +class SegmentationEncoderStateInit(nn.Module): + """State init that encodes segmentation masks as conditional input.""" + + max_num_slots: int + backbone: Callable[[], nn.Module] + pos_emb: Callable[[], nn.Module] = misc.Identity + reduction: Optional[str] = "all_flatten" # Reduce spatial dim by default. + output_transform: Callable[[], nn.Module] = misc.Identity + zero_background: bool = False + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del batch_size # Unused. + + # inputs.shape = (batch_size, seq_len, height, width) + inputs = inputs[:, 0] # Only condition on first time step. + + # Convert mask index to one-hot. + inputs_oh = jax.nn.one_hot(inputs, self.max_num_slots) + # inputs_oh.shape = (batch_size, height, width, n_slots) + # NOTE: 0th entry inputs_oh[..., 0] will typically correspond to background. + + # Set background slot to all-zeros. + if self.zero_background: + inputs_oh = inputs_oh.at[:, :, :, 0].set(0) + + # Switch one-hot axis into 1st position (i.e. sequence axis). + inputs_oh = jnp.transpose(inputs_oh, (0, 3, 1, 2)) + # inputs_oh.shape = (batch_size, max_num_slots, height, width) + + # Append dummy feature axis. + inputs_oh = jnp.expand_dims(inputs_oh, axis=-1) + + # Vmapped encoder over seq. axis (i.e. we process each slot independently). + encoder = video.FrameEncoder( + backbone=self.backbone, + pos_emb=self.pos_emb, + reduction=self.reduction, + output_transform=self.output_transform) # type: ignore + + # encoder(inputs_oh).shape = (batch_size, n_slots, n_features) + slots = encoder(inputs_oh, None, train) + + return slots + + +class CoordinateEncoderStateInit(nn.Module): + """State init that encodes bounding box coordinates as conditional input. + + Attributes: + embedding_transform: A nn.Module that is applied on inputs (bounding boxes). + prepend_background: Boolean flag; whether to prepend a special, zero-valued + background bounding box to the input. Default: false. + center_of_mass: Boolean flag; whether to convert bounding boxes to center + of mass coordinates. Default: false. + background_value: Default value to fill in the background. + """ + + embedding_transform: Callable[[], nn.Module] + prepend_background: bool = False + center_of_mass: bool = False + background_value: float = 0. + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del batch_size # Unused. + + # inputs.shape = (batch_size, seq_len, bboxes, 4) + inputs = inputs[:, 0] # Only condition on first time step. + # inputs.shape = (batch_size, bboxes, 4) + + if self.prepend_background: + # Adds a fake background box [0, 0, 0, 0] at the beginning. + batch_size = inputs.shape[0] + + # Encode the background as specified by background_value. 
+ background = jnp.full( + (batch_size, 1, 4), self.background_value, dtype=inputs.dtype) + + inputs = jnp.concatenate((background, inputs), axis=1) + + if self.center_of_mass: + y_pos = (inputs[:, :, 0] + inputs[:, :, 2]) / 2 + x_pos = (inputs[:, :, 1] + inputs[:, :, 3]) / 2 + inputs = jnp.stack((y_pos, x_pos), axis=-1) + + slots = self.embedding_transform()(inputs, train=train) # pytype: disable=not-callable + + return slots diff --git a/invariant_slot_attention/modules/invariant_attention.py b/invariant_slot_attention/modules/invariant_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd69db49feb771ec770e654be4d854910e1e872 --- /dev/null +++ b/invariant_slot_attention/modules/invariant_attention.py @@ -0,0 +1,963 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Equivariant attention module library.""" +import functools +from typing import Any, Optional, Tuple + +from flax import linen as nn +import jax +import jax.numpy as jnp +from invariant_slot_attention.modules import attention +from invariant_slot_attention.modules import misc + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +PRNGKey = Array + + +class InvertedDotProductAttentionKeyPerQuery(nn.Module): + """Inverted dot-product attention with a different set of keys per query. + + Used in SlotAttentionTranslEquiv, where each slot has a position. + The positions are used to create relative coordinate grids, + which result in a different set of inputs (keys) for each slot. + """ + + dtype: DType = jnp.float32 + precision: Optional[jax.lax.Precision] = None + epsilon: float = 1e-8 + renormalize_keys: bool = False + attn_weights_only: bool = False + softmax_temperature: float = 1.0 + value_per_query: bool = False + + @nn.compact + def __call__(self, query, key, value, train): + """Computes inverted dot-product attention with key per query. + + Args: + query: Queries with shape of `[batch..., q_num, qk_features]`. + key: Keys with shape of `[batch..., q_num, kv_num, qk_features]`. + value: Values with shape of `[batch..., kv_num, v_features]`. + train: Indicating whether we're training or evaluating. + + Returns: + Tuple of two elements: (1) output of shape + `[batch_size..., q_num, v_features]` and (2) attention mask of shape + `[batch_size..., q_num, kv_num]`. + """ + qk_features = query.shape[-1] + query = query / jnp.sqrt(qk_features).astype(self.dtype) + + # Each query is multiplied with its own set of keys. + attn = jnp.einsum( + "...qd,...qkd->...qk", query, key, precision=self.precision + ) + + # axis=-2 for a softmax over query axis (inverted attention). + attn = jax.nn.softmax( + attn / self.softmax_temperature, axis=-2 + ).astype(self.dtype) + + # We expand dims because the logger expect a #heads dimension. 
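+ # Illustrative shapes: query is (..., q_num, d) and key is
+ # (..., q_num, kv_num, d), so "...qd,...qkd->...qk" gives every slot its own
+ # attention logits over its private key set. The softmax over axis -2 (the
+ # query/slot axis) is what makes the attention "inverted": input locations
+ # compete for slots rather than slots competing for inputs.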
+ self.sow("intermediates", "attn", jnp.expand_dims(attn, -3)) + + if self.renormalize_keys: + normalizer = jnp.sum(attn, axis=-1, keepdims=True) + self.epsilon + attn = attn / normalizer + + if self.attn_weights_only: + return attn + + output = jnp.einsum( + "...qk,...qkd->...qd" if self.value_per_query else "...qk,...kd->...qd", + attn, + value, + precision=self.precision + ) + + return output, attn + + +class SlotAttentionExplicitStats(nn.Module): + """Slot Attention module with explicit slot statistics. + + Slot statistics, such as position and scale, are appended to the + output slot representations. + + Note: This module expects a 2D coordinate grid to be appended + at the end of inputs. + + Note: This module uses pre-normalization by default. + """ + grid_encoder: nn.Module + num_iterations: int = 1 + qkv_size: Optional[int] = None + mlp_size: Optional[int] = None + epsilon: float = 1e-8 + softmax_temperature: float = 1.0 + gumbel_softmax: bool = False + gumbel_softmax_straight_through: bool = False + num_heads: int = 1 + min_scale: float = 0.01 + max_scale: float = 5. + return_slot_positions: bool = True + return_slot_scales: bool = True + + @nn.compact + def __call__(self, slots, inputs, + padding_mask = None, + train = False): + """Slot Attention with explicit slot statistics module forward pass.""" + del padding_mask # Unused. + # Slot scales require slot positions. + assert self.return_slot_positions or not self.return_slot_scales + + # Separate a concatenated linear coordinate grid from the inputs. + inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:] + + # Hack so that the input and output slot dimensions are the same. + to_remove = 0 + if self.return_slot_positions: + to_remove += 2 + if self.return_slot_scales: + to_remove += 2 + if to_remove > 0: + slots = slots[Ellipsis, :-to_remove] + + # Add position encodings to inputs + n_features = inputs.shape[-1] + grid_projector = nn.Dense(n_features, name="dense_pe_0") + inputs = self.grid_encoder()(inputs + grid_projector(grid)) + + qkv_size = self.qkv_size or slots.shape[-1] + head_dim = qkv_size // self.num_heads + dense = functools.partial(nn.DenseGeneral, + axis=-1, features=(self.num_heads, head_dim), + use_bias=False) + + # Shared modules. + dense_q = dense(name="general_dense_q_0") + layernorm_q = nn.LayerNorm() + inverted_attention = attention.InvertedDotProductAttention( + norm_type="mean", + multi_head=self.num_heads > 1, + return_attn_weights=True) + gru = misc.GRU() + + if self.mlp_size is not None: + mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore + + # inputs.shape = (..., n_inputs, inputs_size). + inputs = nn.LayerNorm()(inputs) + # k.shape = (..., n_inputs, slot_size). + k = dense(name="general_dense_k_0")(inputs) + # v.shape = (..., n_inputs, slot_size). + v = dense(name="general_dense_v_0")(inputs) + + # Multiple rounds of attention. + for _ in range(self.num_iterations): + + # Inverted dot-product attention. + slots_n = layernorm_q(slots) + q = dense_q(slots_n) # q.shape = (..., n_inputs, slot_size). + updates, attn = inverted_attention(query=q, key=k, value=v, train=train) + + # Recurrent update. + slots = gru(slots, updates) + + # Feedforward block with pre-normalization. + if self.mlp_size is not None: + slots = mlp(slots) + + if self.return_slot_positions: + # Compute the center of mass of each slot attention mask. 
+ positions = jnp.einsum("...qk,...kd->...qd", attn, grid) + slots = jnp.concatenate([slots, positions], axis=-1) + + if self.return_slot_scales: + # Compute slot scales. Take the square root to make the operation + # analogous to normalizing data drawn from a Gaussian. + spread = jnp.square( + jnp.expand_dims(grid, axis=-3) - jnp.expand_dims(positions, axis=-2)) + scales = jnp.sqrt( + jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread)) + scales = jnp.clip(scales, self.min_scale, self.max_scale) + slots = jnp.concatenate([slots, scales], axis=-1) + + return slots + + +class SlotAttentionPosKeysValues(nn.Module): + """Slot Attention module with positional encodings in keys and values. + + Feature position encodings are added to keys and values instead + of the inputs. + + Note: This module expects a 2D coordinate grid to be appended + at the end of inputs. + + Note: This module uses pre-normalization by default. + """ + grid_encoder: nn.Module + num_iterations: int = 1 + qkv_size: Optional[int] = None + mlp_size: Optional[int] = None + epsilon: float = 1e-8 + softmax_temperature: float = 1.0 + gumbel_softmax: bool = False + gumbel_softmax_straight_through: bool = False + num_heads: int = 1 + + @nn.compact + def __call__(self, slots, inputs, + padding_mask = None, + train = False): + """Slot Attention with explicit slot statistics module forward pass.""" + del padding_mask # Unused. + + # Separate a concatenated linear coordinate grid from the inputs. + inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:] + + qkv_size = self.qkv_size or slots.shape[-1] + head_dim = qkv_size // self.num_heads + dense = functools.partial(nn.DenseGeneral, + axis=-1, features=(self.num_heads, head_dim), + use_bias=False) + + # Shared modules. + dense_q = dense(name="general_dense_q_0") + layernorm_q = nn.LayerNorm() + inverted_attention = attention.InvertedDotProductAttention( + norm_type="mean", + multi_head=self.num_heads > 1) + gru = misc.GRU() + + if self.mlp_size is not None: + mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore + + # inputs.shape = (..., n_inputs, inputs_size). + inputs = nn.LayerNorm()(inputs) + # k.shape = (..., n_inputs, slot_size). + k = dense(name="general_dense_k_0")(inputs) + # v.shape = (..., n_inputs, slot_size). + v = dense(name="general_dense_v_0")(inputs) + + # Add position encodings to keys and values. + grid_projector = dense(name="general_dense_p_0") + grid_encoder = self.grid_encoder() + k = grid_encoder(k + grid_projector(grid)) + v = grid_encoder(v + grid_projector(grid)) + + # Multiple rounds of attention. + for _ in range(self.num_iterations): + + # Inverted dot-product attention. + slots_n = layernorm_q(slots) + q = dense_q(slots_n) # q.shape = (..., n_inputs, slot_size). + updates = inverted_attention(query=q, key=k, value=v, train=train) + + # Recurrent update. + slots = gru(slots, updates) + + # Feedforward block with pre-normalization. + if self.mlp_size is not None: + slots = mlp(slots) + + return slots + + +class SlotAttentionTranslEquiv(nn.Module): + """Slot Attention module with slot positions. + + A position is computed for each slot. Slot positions are used to create + relative coordinate grids, which are used as position embeddings reapplied + in each iteration of slot attention. The last two channels in inputs + must contain the flattened position grid. + + Note: This module uses pre-normalization by default. 
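+
+ Note: The last two channels of the incoming slots hold the slot positions;
+ they are split off on entry and re-appended to the returned slots.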
+ """ + + grid_encoder: nn.Module + num_iterations: int = 1 + qkv_size: Optional[int] = None + mlp_size: Optional[int] = None + epsilon: float = 1e-8 + softmax_temperature: float = 1.0 + gumbel_softmax: bool = False + gumbel_softmax_straight_through: bool = False + num_heads: int = 1 + zero_position_init: bool = True + ablate_non_equivariant: bool = False + stop_grad_positions: bool = False + mix_slots: bool = False + add_rel_pos_to_values: bool = False + append_statistics: bool = False + + @nn.compact + def __call__(self, slots, inputs, + padding_mask = None, + train = False): + """Slot Attention translation equiv. module forward pass.""" + del padding_mask # Unused. + + if self.num_heads > 1: + raise NotImplementedError("This prototype only uses one attn. head.") + + # Separate a concatenated linear coordinate grid from the inputs. + inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:] + + # Separate position (x,y) from slot embeddings. + slots, positions = slots[Ellipsis, :-2], slots[Ellipsis, -2:] + qkv_size = self.qkv_size or slots.shape[-1] + num_slots = slots.shape[-2] + + # Prepare initial slot positions. + if self.zero_position_init: + # All slots start in the middle of the image. + positions *= 0. + + # Learnable initial positions might deviate from the allowed range. + positions = jnp.clip(positions, -1., 1.) + + # Pre-normalization. + inputs = nn.LayerNorm()(inputs) + + grid_per_slot = jnp.repeat( + jnp.expand_dims(grid, axis=-3), num_slots, axis=-3) + + # Shared modules. + dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0") + dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0") + dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0") + grid_proj = nn.Dense(qkv_size, name="dense_gp_0") + grid_enc = self.grid_encoder() + layernorm_q = nn.LayerNorm() + inverted_attention = InvertedDotProductAttentionKeyPerQuery( + epsilon=self.epsilon, + renormalize_keys=True, + softmax_temperature=self.softmax_temperature, + value_per_query=self.add_rel_pos_to_values + ) + gru = misc.GRU() + + if self.mlp_size is not None: + mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore + + if self.append_statistics: + embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0") + + # k.shape and v.shape = (..., n_inputs, slot_size). + v = dense_v(inputs) + k = dense_k(inputs) + k_expand = jnp.expand_dims(k, axis=-3) + v_expand = jnp.expand_dims(v, axis=-3) + + # Multiple rounds of attention. Last iteration updates positions only. + for attn_round in range(self.num_iterations + 1): + + if self.ablate_non_equivariant: + # Add an encoded coordinate grid with absolute positions. + grid_emb_per_slot = grid_proj(grid_per_slot) + k_rel_pos = grid_enc(k_expand + grid_emb_per_slot) + if self.add_rel_pos_to_values: + v_rel_pos = grid_enc(v_expand + grid_emb_per_slot) + else: + # Relativize positions, encode them and add them to the keys + # and optionally to values. + relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2) + grid_emb_per_slot = grid_proj(relative_grid) + k_rel_pos = grid_enc(k_expand + grid_emb_per_slot) + if self.add_rel_pos_to_values: + v_rel_pos = grid_enc(v_expand + grid_emb_per_slot) + + # Inverted dot-product attention. + slots_n = layernorm_q(slots) + q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size). 
+ updates, attn = inverted_attention( + query=q, + key=k_rel_pos, + value=v_rel_pos if self.add_rel_pos_to_values else v, + train=train) + + # Compute the center of mass of each slot attention mask. + # Guaranteed to be in [-1, 1]. + positions = jnp.einsum("...qk,...kd->...qd", attn, grid) + + if self.stop_grad_positions: + # Do not backprop through positions and scales. + positions = jax.lax.stop_gradient(positions) + + if attn_round < self.num_iterations: + if self.append_statistics: + # Projects and add 2D slot positions into slot latents. + tmp = jnp.concatenate([slots, positions], axis=-1) + slots = embed_statistics(tmp) + + # Recurrent update. + slots = gru(slots, updates) + + # Feedforward block with pre-normalization. + if self.mlp_size is not None: + slots = mlp(slots) + + # Concatenate position information to slots. + output = jnp.concatenate([slots, positions], axis=-1) + + if self.mix_slots: + output = misc.MLP(hidden_size=128, layernorm="pre")(output) + + return output + + +class SlotAttentionTranslScaleEquiv(nn.Module): + """Slot Attention module with slot positions and scales. + + A position and scale is computed for each slot. Slot positions and scales + are used to create relative coordinate grids, which are used as position + embeddings reapplied in each iteration of slot attention. The last two + channels in input must contain the flattened position grid. + + Note: This module uses pre-normalization by default. + """ + + grid_encoder: nn.Module + num_iterations: int = 1 + qkv_size: Optional[int] = None + mlp_size: Optional[int] = None + epsilon: float = 1e-8 + softmax_temperature: float = 1.0 + gumbel_softmax: bool = False + gumbel_softmax_straight_through: bool = False + num_heads: int = 1 + zero_position_init: bool = True + # Scale of 0.1 corresponds to fairly small objects. + init_with_fixed_scale: Optional[float] = 0.1 + ablate_non_equivariant: bool = False + stop_grad_positions_and_scales: bool = False + mix_slots: bool = False + add_rel_pos_to_values: bool = False + scales_factor: float = 1. + # Slot scales cannot be negative and should not be too close to zero + # or too large. + min_scale: float = 0.001 + max_scale: float = 2. + append_statistics: bool = False + + @nn.compact + def __call__(self, slots, inputs, + padding_mask = None, + train = False): + """Slot Attention translation and scale equiv. module forward pass.""" + del padding_mask # Unused. + + if self.num_heads > 1: + raise NotImplementedError("This prototype only uses one attn. head.") + + # Separate a concatenated linear coordinate grid from the inputs. + inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:] + + # Separate position (x,y) and scale from slot embeddings. + slots, positions, scales = (slots[Ellipsis, :-4], + slots[Ellipsis, -4: -2], + slots[Ellipsis, -2:]) + qkv_size = self.qkv_size or slots.shape[-1] + num_slots = slots.shape[-2] + + # Prepare initial slot positions. + if self.zero_position_init: + # All slots start in the middle of the image. + positions *= 0. + + if self.init_with_fixed_scale is not None: + scales = scales * 0. + self.init_with_fixed_scale + + # Learnable initial positions and scales could have arbitrary values. + positions = jnp.clip(positions, -1., 1.) + scales = jnp.clip(scales, self.min_scale, self.max_scale) + + # Pre-normalization. + inputs = nn.LayerNorm()(inputs) + + grid_per_slot = jnp.repeat( + jnp.expand_dims(grid, axis=-3), num_slots, axis=-3) + + # Shared modules. 
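+ # All projections, the grid encoder, the GRU and the optional MLP are
+ # created once here and reused in every attention round, so their
+ # parameters are shared across iterations as in the original Slot Attention.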
+ dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0") + dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0") + dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0") + grid_proj = nn.Dense(qkv_size, name="dense_gp_0") + grid_enc = self.grid_encoder() + layernorm_q = nn.LayerNorm() + inverted_attention = InvertedDotProductAttentionKeyPerQuery( + epsilon=self.epsilon, + renormalize_keys=True, + softmax_temperature=self.softmax_temperature, + value_per_query=self.add_rel_pos_to_values + ) + gru = misc.GRU() + + if self.mlp_size is not None: + mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore + + if self.append_statistics: + embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0") + + # k.shape and v.shape = (..., n_inputs, slot_size). + v = dense_v(inputs) + k = dense_k(inputs) + k_expand = jnp.expand_dims(k, axis=-3) + v_expand = jnp.expand_dims(v, axis=-3) + + # Multiple rounds of attention. + # Last iteration updates positions and scales only. + for attn_round in range(self.num_iterations + 1): + + if self.ablate_non_equivariant: + # Add an encoded coordinate grid with absolute positions. + tmp_grid = grid_proj(grid_per_slot) + k_rel_pos = grid_enc(k_expand + tmp_grid) + if self.add_rel_pos_to_values: + v_rel_pos = grid_enc(v_expand + tmp_grid) + else: + # Relativize and scale positions, encode them and add them to inputs. + relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2) + # Scales are usually small so the grid might get too large. + relative_grid = relative_grid / self.scales_factor + relative_grid = relative_grid / jnp.expand_dims(scales, axis=-2) + tmp_grid = grid_proj(relative_grid) + k_rel_pos = grid_enc(k_expand + tmp_grid) + if self.add_rel_pos_to_values: + v_rel_pos = grid_enc(v_expand + tmp_grid) + + # Inverted dot-product attention. + slots_n = layernorm_q(slots) + q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size). + updates, attn = inverted_attention( + query=q, + key=k_rel_pos, + value=v_rel_pos if self.add_rel_pos_to_values else v, + train=train) + + # Compute the center of mass of each slot attention mask. + positions = jnp.einsum("...qk,...kd->...qd", attn, grid) + + # Compute slot scales. Take the square root to make the operation + # analogous to normalizing data drawn from a Gaussian. + spread = jnp.square(grid_per_slot - jnp.expand_dims(positions, axis=-2)) + scales = jnp.sqrt( + jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread)) + + # Computed positions are guaranteed to be in [-1, 1]. + # Scales are unbounded. + scales = jnp.clip(scales, self.min_scale, self.max_scale) + + if self.stop_grad_positions_and_scales: + # Do not backprop through positions and scales. + positions = jax.lax.stop_gradient(positions) + scales = jax.lax.stop_gradient(scales) + + if attn_round < self.num_iterations: + if self.append_statistics: + # Project and add 2D slot positions and scales into slot latents. + tmp = jnp.concatenate([slots, positions, scales], axis=-1) + slots = embed_statistics(tmp) + + # Recurrent update. + slots = gru(slots, updates) + + # Feedforward block with pre-normalization. + if self.mlp_size is not None: + slots = mlp(slots) + + # Concatenate position and scale information to slots. 
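+ # The per-slot statistics are appended as the last four channels
+ # (two for position, two for scale), mirroring how they are split off
+ # from the slot vectors at the start of __call__.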
+ output = jnp.concatenate([slots, positions, scales], axis=-1) + + if self.mix_slots: + output = misc.MLP(hidden_size=128, layernorm="pre")(output) + + return output + + +class SlotAttentionTranslRotScaleEquiv(nn.Module): + """Slot Attention module with slot positions, rotations and scales. + + A position, rotation and scale is computed for each slot. + Slot positions, rotations and scales are used to create relative + coordinate grids, which are used as position embeddings reapplied in each + iteration of slot attention. The last two channels in input must contain + the flattened position grid. + + Note: This module uses pre-normalization by default. + """ + + grid_encoder: nn.Module + num_iterations: int = 1 + qkv_size: Optional[int] = None + mlp_size: Optional[int] = None + epsilon: float = 1e-8 + softmax_temperature: float = 1.0 + gumbel_softmax: bool = False + gumbel_softmax_straight_through: bool = False + num_heads: int = 1 + zero_position_init: bool = True + # Scale of 0.1 corresponds to fairly small objects. + init_with_fixed_scale: Optional[float] = 0.1 + ablate_non_equivariant: bool = False + stop_grad_positions: bool = False + stop_grad_scales: bool = False + stop_grad_rotations: bool = False + mix_slots: bool = False + add_rel_pos_to_values: bool = False + scales_factor: float = 1. + # Slot scales cannot be negative and should not be too close to zero + # or too large. + min_scale: float = 0.001 + max_scale: float = 2. + limit_rot_to_45_deg: bool = True + append_statistics: bool = False + + @nn.compact + def __call__(self, slots, inputs, + padding_mask = None, + train = False): + """Slot Attention translation and scale equiv. module forward pass.""" + del padding_mask # Unused. + + if self.num_heads > 1: + raise NotImplementedError("This prototype only uses one attn. head.") + + # Separate a concatenated linear coordinate grid from the inputs. + inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:] + + # Separate position (x,y) and scale from slot embeddings. + slots, positions, scales, rotm = (slots[Ellipsis, :-8], + slots[Ellipsis, -8: -6], + slots[Ellipsis, -6: -4], + slots[Ellipsis, -4:]) + rotm = jnp.reshape(rotm, (*rotm.shape[:-1], 2, 2)) + qkv_size = self.qkv_size or slots.shape[-1] + num_slots = slots.shape[-2] + + # Prepare initial slot positions. + if self.zero_position_init: + # All slots start in the middle of the image. + positions *= 0. + + if self.init_with_fixed_scale is not None: + scales = scales * 0. + self.init_with_fixed_scale + + # Learnable initial positions and scales could have arbitrary values. + positions = jnp.clip(positions, -1., 1.) + scales = jnp.clip(scales, self.min_scale, self.max_scale) + + # Pre-normalization. + inputs = nn.LayerNorm()(inputs) + + grid_per_slot = jnp.repeat( + jnp.expand_dims(grid, axis=-3), num_slots, axis=-3) + + # Shared modules. 
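+ # Besides positions and scales, each slot carries a flattened 2x2 rotation
+ # matrix (reshaped above); it is used below to rotate the slot's relative
+ # coordinate grid before the grid is scaled.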
+ dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0") + dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0") + dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0") + grid_proj = nn.Dense(qkv_size, name="dense_gp_0") + grid_enc = self.grid_encoder() + layernorm_q = nn.LayerNorm() + inverted_attention = InvertedDotProductAttentionKeyPerQuery( + epsilon=self.epsilon, + renormalize_keys=True, + softmax_temperature=self.softmax_temperature, + value_per_query=self.add_rel_pos_to_values + ) + gru = misc.GRU() + + if self.mlp_size is not None: + mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore + + if self.append_statistics: + embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0") + + # k.shape and v.shape = (..., n_inputs, slot_size). + v = dense_v(inputs) + k = dense_k(inputs) + k_expand = jnp.expand_dims(k, axis=-3) + v_expand = jnp.expand_dims(v, axis=-3) + + # Multiple rounds of attention. + # Last iteration updates positions and scales only. + for attn_round in range(self.num_iterations + 1): + + if self.ablate_non_equivariant: + # Add an encoded coordinate grid with absolute positions. + tmp_grid = grid_proj(grid_per_slot) + k_rel_pos = grid_enc(k_expand + tmp_grid) + if self.add_rel_pos_to_values: + v_rel_pos = grid_enc(v_expand + tmp_grid) + else: + # Relativize and scale positions, encode them and add them to inputs. + relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2) + + # Rotation. + relative_grid = self.transform(rotm, relative_grid) + + # Scales are usually small so the grid might get too large. + relative_grid = relative_grid / self.scales_factor + relative_grid = relative_grid / jnp.expand_dims(scales, axis=-2) + tmp_grid = grid_proj(relative_grid) + k_rel_pos = grid_enc(k_expand + tmp_grid) + if self.add_rel_pos_to_values: + v_rel_pos = grid_enc(v_expand + tmp_grid) + + # Inverted dot-product attention. + slots_n = layernorm_q(slots) + q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size). + updates, attn = inverted_attention( + query=q, + key=k_rel_pos, + value=v_rel_pos if self.add_rel_pos_to_values else v, + train=train) + + # Compute the center of mass of each slot attention mask. + positions = jnp.einsum("...qk,...kd->...qd", attn, grid) + + # Find the axis with the highest spread. + relp = grid_per_slot - jnp.expand_dims(positions, axis=-2) + if self.limit_rot_to_45_deg: + rotm = self.compute_rotation_matrix_45_deg(relp, attn) + else: + rotm = self.compute_rotation_matrix_90_deg(relp, attn) + + # Compute slot scales. Take the square root to make the operation + # analogous to normalizing data drawn from a Gaussian. + relp = self.transform(rotm, relp) + + spread = jnp.square(relp) + scales = jnp.sqrt( + jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread)) + + # Computed positions are guaranteed to be in [-1, 1]. + # Scales are unbounded. + scales = jnp.clip(scales, self.min_scale, self.max_scale) + + if self.stop_grad_positions: + positions = jax.lax.stop_gradient(positions) + if self.stop_grad_scales: + scales = jax.lax.stop_gradient(scales) + if self.stop_grad_rotations: + rotm = jax.lax.stop_gradient(rotm) + + if attn_round < self.num_iterations: + if self.append_statistics: + # For the slot rotations, we append both the 2D rotation matrix + # and the angle by which we rotate. + # We can compute the angle using atan2(R[0, 0], R[1, 0]). 
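+ # Together this appends 2 + 2 + 4 + 1 = 9 statistics channels (position,
+ # scale, flattened rotation matrix and angle), which embed_statistics
+ # projects back to the slot dimension.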
+ tmp = jnp.concatenate( + [slots, positions, scales, + rotm.reshape(*rotm.shape[:-2], 4), + jnp.arctan2(rotm[Ellipsis, 0, 0], rotm[Ellipsis, 1, 0])[Ellipsis, None]], + axis=-1) + slots = embed_statistics(tmp) + + # Recurrent update. + slots = gru(slots, updates) + + # Feedforward block with pre-normalization. + if self.mlp_size is not None: + slots = mlp(slots) + + # Concatenate position and scale information to slots. + output = jnp.concatenate( + [slots, positions, scales, rotm.reshape(*rotm.shape[:-2], 4)], axis=-1) + + if self.mix_slots: + output = misc.MLP(hidden_size=128, layernorm="pre")(output) + + return output + + @classmethod + def compute_weighted_covariance(cls, x, w): + # The coordinate grid is (y, x), we want (x, y). + x = jnp.stack([x[Ellipsis, 1], x[Ellipsis, 0]], axis=-1) + + # Pixel coordinates weighted by attention mask. + cov = x * w[Ellipsis, None] + cov = jnp.einsum( + "...ji,...jk->...ik", cov, x, precision=jax.lax.Precision.HIGHEST) + + return cov + + @classmethod + def compute_reference_frame_45_deg(cls, x, w): + cov = cls.compute_weighted_covariance(x, w) + + # Compute eigenvalues. + pm = jnp.sqrt(4. * jnp.square(cov[Ellipsis, 0, 1]) + + jnp.square(cov[Ellipsis, 0, 0] - cov[Ellipsis, 1, 1]) + 1e-16) + + eig1 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] + pm) / 2. + eig2 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] - pm) / 2. + + # Compute eigenvectors, note that both have a positive y-axis. + # This means we have eliminated half of the possible rotations. + div = cov[Ellipsis, 0, 1] + 1e-16 + + v1 = (eig1 - cov[Ellipsis, 1, 1]) / div + v2 = (eig2 - cov[Ellipsis, 1, 1]) / div + + v1 = jnp.stack([v1, jnp.ones_like(v1)], axis=-1) + v2 = jnp.stack([v2, jnp.ones_like(v2)], axis=-1) + + # RULE 1: + # We catch two failure modes here. + # 1. If all attention weights are zero the covariance is also zero. + # Then the above computation is meaningless. + # 2. If the attention pattern is exactly aligned with the axes + # (e.g. a horizontal/vertical bar), the off-diagonal covariance + # values are going to be very low. If we use float32, we get + # basis vectors that are not orthogonal. + # Solution: use the default reference frame if the off-diagonal + # covariance value is too low. + default_1 = jnp.stack([jnp.ones_like(div), jnp.zeros_like(div)], axis=-1) + default_2 = jnp.stack([jnp.zeros_like(div), jnp.ones_like(div)], axis=-1) + + mask = (jnp.abs(div) < 1e-6).astype(jnp.float32)[Ellipsis, None] + v1 = (1. - mask) * v1 + mask * default_1 + v2 = (1. - mask) * v2 + mask * default_2 + + # Turn eigenvectors into unit vectors, so that we can construct + # a basis of a new reference frame. + norm1 = jnp.sqrt(jnp.sum(jnp.square(v1), axis=-1, keepdims=True)) + norm2 = jnp.sqrt(jnp.sum(jnp.square(v2), axis=-1, keepdims=True)) + + v1 = v1 / norm1 + v2 = v2 / norm2 + + # RULE 2: + # If the first basis vector is "pointing up" we assume the object + # is vertical (e.g. we say a door is vertical, whereas a car is horizontal). + # In the case of vertical objects, we swap the two basis vectors. + # This limits the possible rotations to +- 45deg instead of +- 90deg. + # We define "pointing up" as the first coordinate of the first basis vector + # being between +- sin(pi/4). The second coordinate is always positive. + mask = (jnp.logical_and(v1[Ellipsis, 0] < 0.707, v1[Ellipsis, 0] > -0.707) + ).astype(jnp.float32)[Ellipsis, None] + v1_ = (1. - mask) * v1 + mask * v2 + v2_ = (1. 
- mask) * v2 + mask * v1 + v1 = v1_ + v2 = v2_ + + # RULE 3: + # Mirror the first basis vector if the first coordinate is negative. + # Here, we ensure that our coordinate system is always left-handed. + # Otherwise, we would sometimes unintentionally mirror the grid. + mask = (v1[Ellipsis, 0] < 0).astype(jnp.float32)[Ellipsis, None] + v1 = (1. - mask) * v1 - mask * v1 + + return v1, v2 + + @classmethod + def compute_reference_frame_90_deg(cls, x, w): + cov = cls.compute_weighted_covariance(x, w) + + # Compute eigenvalues. + pm = jnp.sqrt(4. * jnp.square(cov[Ellipsis, 0, 1]) + + jnp.square(cov[Ellipsis, 0, 0] - cov[Ellipsis, 1, 1]) + 1e-16) + + eig1 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] + pm) / 2. + eig2 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] - pm) / 2. + + # Compute eigenvectors, note that both have a positive y-axis. + # This means we have eliminated half of the possible rotations. + div = cov[Ellipsis, 0, 1] + 1e-16 + + v1 = (eig1 - cov[Ellipsis, 1, 1]) / div + v2 = (eig2 - cov[Ellipsis, 1, 1]) / div + + v1 = jnp.stack([v1, jnp.ones_like(v1)], axis=-1) + v2 = jnp.stack([v2, jnp.ones_like(v2)], axis=-1) + + # RULE 1: + # We catch two failure modes here. + # 1. If all attention weights are zero the covariance is also zero. + # Then the above computation is meaningless. + # 2. If the attention pattern is exactly aligned with the axes + # (e.g. a horizontal/vertical bar), the off-diagonal covariance + # values are going to be very low. If we use float32, we get + # basis vectors that are not orthogonal. + # Solution: use the default reference frame if the off-diagonal + # covariance value is too low. + default_1 = jnp.stack([jnp.ones_like(div), jnp.zeros_like(div)], axis=-1) + default_2 = jnp.stack([jnp.zeros_like(div), jnp.ones_like(div)], axis=-1) + + # RULE 1.5: + # RULE 1 is activated if we see a vertical or a horizontal bar. + # We make sure that the coordinate grid for a horizontal bar is not rotated, + # whereas the coordinate grid for a vertical bar is rotated by 90deg. + # If cov[0, 0] > cov[1, 1], the bar is vertical. + mask = (cov[Ellipsis, 0, 0] <= cov[Ellipsis, 1, 1]).astype(jnp.float32)[Ellipsis, None] + # Furthermore, we have to mirror one of the basis vectors (if mask==1) + # so that we always have a left-handed coordinate grid. + default_v1 = (1. - mask) * default_1 - mask * default_2 + default_v2 = (1. - mask) * default_2 + mask * default_1 + + # Continuation of RULE 1. + mask = (jnp.abs(div) < 1e-6).astype(jnp.float32)[Ellipsis, None] + v1 = mask * default_v1 + (1. - mask) * v1 + v2 = mask * default_v2 + (1. - mask) * v2 + + # Turn eigenvectors into unit vectors, so that we can construct + # a basis of a new reference frame. + norm1 = jnp.sqrt(jnp.sum(jnp.square(v1), axis=-1, keepdims=True)) + norm2 = jnp.sqrt(jnp.sum(jnp.square(v2), axis=-1, keepdims=True)) + + v1 = v1 / norm1 + v2 = v2 / norm2 + + # RULE 2: + # Mirror the first basis vector if the first coordinate is negative. + # Here, we ensure that the our coordinate system is always left-handed. + # Otherwise, we would sometimes unintentionally mirror the grid. + mask = (v1[Ellipsis, 0] < 0).astype(jnp.float32)[Ellipsis, None] + v1 = (1. 
- mask) * v1 - mask * v1 + + return v1, v2 + + @classmethod + def compute_rotation_matrix_45_deg(cls, x, w): + v1, v2 = cls.compute_reference_frame_45_deg(x, w) + return jnp.stack([v1, v2], axis=-1) + + @classmethod + def compute_rotation_matrix_90_deg(cls, x, w): + v1, v2 = cls.compute_reference_frame_90_deg(x, w) + return jnp.stack([v1, v2], axis=-1) + + @classmethod + def transform(cls, rotm, x): + # The coordinate grid x is in the (y, x) format, so we need to swap + # the coordinates on the input and output. + x = jnp.stack([x[Ellipsis, 1], x[Ellipsis, 0]], axis=-1) + # Equivalent to inv(R) * x^T = R^T * x^T = (x * R)^T. + # We are multiplying by the inverse of the rotation matrix because + # we are rotating the coordinate grid *against* the rotation of the object. + # y = jnp.matmul(x, R) + y = jnp.einsum("...ij,...jk->...ik", x, rotm) + # Swap coordinates again. + y = jnp.stack([y[Ellipsis, 1], y[Ellipsis, 0]], axis=-1) + return y diff --git a/invariant_slot_attention/modules/invariant_initializers.py b/invariant_slot_attention/modules/invariant_initializers.py new file mode 100644 index 0000000000000000000000000000000000000000..1715e21cf3c7f3b0b243a65dc3c70762cc7ae1c0 --- /dev/null +++ b/invariant_slot_attention/modules/invariant_initializers.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Initializers module library for equivariant slot attention.""" + +import functools +from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union + +from flax import linen as nn +import jax +import jax.numpy as jnp +from invariant_slot_attention.lib import utils + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ProcessorState = ArrayTree +PRNGKey = Array +NestedDict = Dict[str, Any] + + +def get_uniform_initializer(vmin, vmax): + """Get an uniform initializer with an arbitrary range.""" + init = nn.initializers.uniform(scale=vmax - vmin) + + def fn(*args, **kwargs): + return init(*args, **kwargs) + vmin + + return fn + + +def get_normal_initializer(mean, sd): + """Get a normal initializer with an arbitrary mean.""" + init = nn.initializers.normal(stddev=sd) + + def fn(*args, **kwargs): + return init(*args, **kwargs) + mean + + return fn + + +class ParamStateInitRandomPositions(nn.Module): + """Fixed, learnable state initalization with random positions. + + Random slot positions sampled from U[-1, 1] are concatenated + as the last two dimensions. + Note: This module ignores any conditional input (by design). + """ + + shape: Sequence[int] + init_fn: str = "normal" # Default init with unit variance. + conditioning_key: Optional[str] = None + slot_positions_min: float = -1. + slot_positions_max: float = 1. + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. 
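+ # The learned slot embeddings are broadcast across the batch below, and
+ # random slot positions drawn from U[slot_positions_min, slot_positions_max]
+ # via the "state_init" RNG stream are appended as the last two channels.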
+ + if self.init_fn == "normal": + init_fn = functools.partial(nn.initializers.normal, stddev=1.) + elif self.init_fn == "zeros": + init_fn = lambda: nn.initializers.zeros + else: + raise ValueError("Unknown init_fn: {}.".format(self.init_fn)) + + param = self.param("state_init", init_fn(), self.shape) + + out = utils.broadcast_across_batch(param, batch_size=batch_size) + shape = out.shape[:-1] + rng = self.make_rng("state_init") + slot_positions = jax.random.uniform( + rng, shape=[*shape, 2], minval=self.slot_positions_min, + maxval=self.slot_positions_max) + out = jnp.concatenate((out, slot_positions), axis=-1) + return out + + +class ParamStateInitLearnablePositions(nn.Module): + """Fixed, learnable state initalization with learnable positions. + + Learnable initial positions are concatenated at the end of slots. + Note: This module ignores any conditional input (by design). + """ + + shape: Sequence[int] + init_fn: str = "normal" # Default init with unit variance. + conditioning_key: Optional[str] = None + slot_positions_min: float = -1. + slot_positions_max: float = 1. + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. + + if self.init_fn == "normal": + init_fn_state = functools.partial(nn.initializers.normal, stddev=1.) + elif self.init_fn == "zeros": + init_fn_state = lambda: nn.initializers.zeros + else: + raise ValueError("Unknown init_fn: {}.".format(self.init_fn)) + + init_fn_state = init_fn_state() + init_fn_pos = get_uniform_initializer( + self.slot_positions_min, self.slot_positions_max) + + param_state = self.param("state_init", init_fn_state, self.shape) + param_pos = self.param( + "state_init_position", init_fn_pos, (*self.shape[:-1], 2)) + + param = jnp.concatenate((param_state, param_pos), axis=-1) + + return utils.broadcast_across_batch(param, batch_size=batch_size) # pytype: disable=bad-return-type # jax-ndarray + + +class ParamStateInitRandomPositionsScales(nn.Module): + """Fixed, learnable state initalization with random positions and scales. + + Random slot positions and scales sampled from U[-1, 1] and N(0.1, 0.1) + are concatenated as the last four dimensions. + Note: This module ignores any conditional input (by design). + """ + + shape: Sequence[int] + init_fn: str = "normal" # Default init with unit variance. + conditioning_key: Optional[str] = None + slot_positions_min: float = -1. + slot_positions_max: float = 1. + slot_scales_mean: float = 0.1 + slot_scales_sd: float = 0.1 + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. + + if self.init_fn == "normal": + init_fn = functools.partial(nn.initializers.normal, stddev=1.) + elif self.init_fn == "zeros": + init_fn = lambda: nn.initializers.zeros + else: + raise ValueError("Unknown init_fn: {}.".format(self.init_fn)) + + param = self.param("state_init", init_fn(), self.shape) + + out = utils.broadcast_across_batch(param, batch_size=batch_size) + shape = out.shape[:-1] + rng = self.make_rng("state_init") + slot_positions = jax.random.uniform( + rng, shape=[*shape, 2], minval=self.slot_positions_min, + maxval=self.slot_positions_max) + slot_scales = jax.random.normal(rng, shape=[*shape, 2]) + slot_scales = self.slot_scales_mean + self.slot_scales_sd * slot_scales + out = jnp.concatenate((out, slot_positions, slot_scales), axis=-1) + return out + + +class ParamStateInitLearnablePositionsScales(nn.Module): + """Fixed, learnable state initalization with random positions and scales. 
+ + Lernable initial positions and scales are concatenated at the end of slots. + Note: This module ignores any conditional input (by design). + """ + + shape: Sequence[int] + init_fn: str = "normal" # Default init with unit variance. + conditioning_key: Optional[str] = None + slot_positions_min: float = -1. + slot_positions_max: float = 1. + slot_scales_mean: float = 0.1 + slot_scales_sd: float = 0.01 + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. + + if self.init_fn == "normal": + init_fn_state = functools.partial(nn.initializers.normal, stddev=1.) + elif self.init_fn == "zeros": + init_fn_state = lambda: nn.initializers.zeros + else: + raise ValueError("Unknown init_fn: {}.".format(self.init_fn)) + + init_fn_state = init_fn_state() + init_fn_pos = get_uniform_initializer( + self.slot_positions_min, self.slot_positions_max) + init_fn_scales = get_normal_initializer( + self.slot_scales_mean, self.slot_scales_sd) + + param_state = self.param("state_init", init_fn_state, self.shape) + param_pos = self.param( + "state_init_position", init_fn_pos, (*self.shape[:-1], 2)) + param_scales = self.param( + "state_init_scale", init_fn_scales, (*self.shape[:-1], 2)) + + param = jnp.concatenate((param_state, param_pos, param_scales), axis=-1) + + return utils.broadcast_across_batch(param, batch_size=batch_size) # pytype: disable=bad-return-type # jax-ndarray + + +class ParamStateInitLearnablePositionsRotationsScales(nn.Module): + """Fixed, learnable state initalization. + + Learnable initial positions, rotations and scales are concatenated + at the end of slots. The rotation matrix is flattened. + Note: This module ignores any conditional input (by design). + """ + + shape: Sequence[int] + init_fn: str = "normal" # Default init with unit variance. + conditioning_key: Optional[str] = None + slot_positions_min: float = -1. + slot_positions_max: float = 1. + slot_scales_mean: float = 0.1 + slot_scales_sd: float = 0.01 + slot_angles_mean: float = 0. + slot_angles_sd: float = 0.1 + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. + + if self.init_fn == "normal": + init_fn_state = functools.partial(nn.initializers.normal, stddev=1.) + elif self.init_fn == "zeros": + init_fn_state = lambda: nn.initializers.zeros + else: + raise ValueError("Unknown init_fn: {}.".format(self.init_fn)) + + init_fn_state = init_fn_state() + init_fn_pos = get_uniform_initializer( + self.slot_positions_min, self.slot_positions_max) + init_fn_scales = get_normal_initializer( + self.slot_scales_mean, self.slot_scales_sd) + init_fn_angles = get_normal_initializer( + self.slot_angles_mean, self.slot_angles_sd) + + param_state = self.param("state_init", init_fn_state, self.shape) + param_pos = self.param( + "state_init_position", init_fn_pos, (*self.shape[:-1], 2)) + param_scales = self.param( + "state_init_scale", init_fn_scales, (*self.shape[:-1], 2)) + param_angles = self.param( + "state_init_angles", init_fn_angles, (*self.shape[:-1], 1)) + + # Initial angles in the range of (-pi / 4, pi / 4) <=> (-45, 45) degrees. 
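+ # The rotation is stored as a flattened 2x2 matrix
+ # [cos(a), sin(a), -sin(a), cos(a)], which SlotAttentionTranslRotScaleEquiv
+ # reshapes back to (..., 2, 2).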
+ angles = jnp.tanh(param_angles) * (jnp.pi / 4) + rotm = jnp.concatenate( + [jnp.cos(angles), jnp.sin(angles), + -jnp.sin(angles), jnp.cos(angles)], axis=-1) + + param = jnp.concatenate( + (param_state, param_pos, param_scales, rotm), axis=-1) + + return utils.broadcast_across_batch(param, batch_size=batch_size) # pytype: disable=bad-return-type # jax-ndarray + + +class ParamStateInitRandomPositionsRotationsScales(nn.Module): + """Fixed, learnable state initialization with random pos., rot. and scales. + + Random slot positions and scales sampled from U[-1, 1] and N(0.1, 0.1) + are concatenated as the last four dimensions. Rotations are sampled + from +- 45 degrees. + Note: This module ignores any conditional input (by design). + """ + + shape: Sequence[int] + init_fn: str = "normal" # Default init with unit variance. + conditioning_key: Optional[str] = None + slot_positions_min: float = -1. + slot_positions_max: float = 1. + slot_scales_mean: float = 0.1 + slot_scales_sd: float = 0.1 + slot_angles_min: float = -jnp.pi / 4. + slot_angles_max: float = jnp.pi / 4. + + @nn.compact + def __call__(self, inputs, batch_size, + train = False): + del inputs, train # Unused. + + if self.init_fn == "normal": + init_fn = functools.partial(nn.initializers.normal, stddev=1.) + elif self.init_fn == "zeros": + init_fn = lambda: nn.initializers.zeros + else: + raise ValueError("Unknown init_fn: {}.".format(self.init_fn)) + + param = self.param("state_init", init_fn(), self.shape) + + out = utils.broadcast_across_batch(param, batch_size=batch_size) + shape = out.shape[:-1] + rng = self.make_rng("state_init") + slot_positions = jax.random.uniform( + rng, shape=[*shape, 2], minval=self.slot_positions_min, + maxval=self.slot_positions_max) + rng = self.make_rng("state_init") + slot_scales = jax.random.normal(rng, shape=[*shape, 2]) + slot_scales = self.slot_scales_mean + self.slot_scales_sd * slot_scales + rng = self.make_rng("state_init") + slot_angles = jax.random.uniform(rng, shape=[*shape, 1]) + slot_angles = (slot_angles * (self.slot_angles_max - self.slot_angles_min) + ) + self.slot_angles_min + slot_rotm = jnp.concatenate( + [jnp.cos(slot_angles), jnp.sin(slot_angles), + -jnp.sin(slot_angles), jnp.cos(slot_angles)], axis=-1) + out = jnp.concatenate( + (out, slot_positions, slot_scales, slot_rotm), axis=-1) + return out diff --git a/invariant_slot_attention/modules/misc.py b/invariant_slot_attention/modules/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e296f881160c0f4cff86e4d86d3cb729978e206c --- /dev/null +++ b/invariant_slot_attention/modules/misc.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Miscellaneous modules.""" + +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union + +from flax import linen as nn +import jax +import jax.numpy as jnp + +from invariant_slot_attention.lib import utils + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ProcessorState = ArrayTree +PRNGKey = Array +NestedDict = Dict[str, Any] + + +class Identity(nn.Module): + """Module that applies the identity function, ignoring any additional args.""" + + @nn.compact + def __call__(self, inputs, **args): + return inputs + + +class Readout(nn.Module): + """Module for reading out multiple targets from an embedding.""" + + keys: Sequence[str] + readout_modules: Sequence[Callable[[], nn.Module]] + stop_gradient: Optional[Sequence[bool]] = None + + @nn.compact + def __call__(self, inputs, train = False): + num_targets = len(self.keys) + assert num_targets >= 1, "Need to have at least one target." + assert len(self.readout_modules) == num_targets, ( + "len(modules) and len(keys) must match.") + if self.stop_gradient is not None: + assert len(self.stop_gradient) == num_targets, ( + "len(stop_gradient) and len(keys) must match.") + outputs = {} + for i in range(num_targets): + if self.stop_gradient is not None and self.stop_gradient[i]: + x = jax.lax.stop_gradient(inputs) + else: + x = inputs + outputs[self.keys[i]] = self.readout_modules[i]()(x, train) # pytype: disable=not-callable + return outputs + + +class MLP(nn.Module): + """Simple MLP with one hidden layer and optional pre-/post-layernorm.""" + + hidden_size: int + output_size: Optional[int] = None + num_hidden_layers: int = 1 + activation_fn: Callable[[Array], Array] = nn.relu + layernorm: Optional[str] = None + activate_output: bool = False + residual: bool = False + + @nn.compact + def __call__(self, inputs, train = False): + del train # Unused. + + output_size = self.output_size or inputs.shape[-1] + + x = inputs + + if self.layernorm == "pre": + x = nn.LayerNorm()(x) + + for i in range(self.num_hidden_layers): + x = nn.Dense(self.hidden_size, name=f"dense_mlp_{i}")(x) + x = self.activation_fn(x) + x = nn.Dense(output_size, name=f"dense_mlp_{self.num_hidden_layers}")(x) + + if self.activate_output: + x = self.activation_fn(x) + + if self.residual: + x = x + inputs + + if self.layernorm == "post": + x = nn.LayerNorm()(x) + + return x + + +class GRU(nn.Module): + """GRU cell as nn.Module.""" + + @nn.compact + def __call__(self, carry, inputs, + train = False): + del train # Unused. + carry, _ = nn.GRUCell()(carry, inputs) + return carry + + +class Dense(nn.Module): + """Dense layer as nn.Module accepting "train" flag.""" + + features: int + use_bias: bool = True + + @nn.compact + def __call__(self, inputs, train = False): + del train # Unused. + return nn.Dense(features=self.features, use_bias=self.use_bias)(inputs) + + +class PositionEmbedding(nn.Module): + """A module for applying N-dimensional position embedding. + + Attr: + embedding_type: A string defining the type of position embedding to use. One + of ["linear", "discrete_1d", "fourier", "gaussian_fourier"]. + update_type: A string defining how the input is updated with the position + embedding. One of ["proj_add", "concat"]. + num_fourier_bases: The number of Fourier bases to use. For embedding_type == + "fourier", the embedding dimensionality is 2 x number of position + dimensions x num_fourier_bases. 
For embedding_type == "gaussian_fourier", + the embedding dimensionality is 2 x num_fourier_bases. For embedding_type + == "linear", this parameter is ignored. + gaussian_sigma: Standard deviation of sampled Gaussians. + pos_transform: Optional transform for the embedding. + output_transform: Optional transform for the combined input and embedding. + trainable_pos_embedding: Boolean flag for allowing gradients to flow into + the position embedding, so that the optimizer can update it. + """ + + embedding_type: str + update_type: str + num_fourier_bases: int = 0 + gaussian_sigma: float = 1.0 + pos_transform: Callable[[], nn.Module] = Identity + output_transform: Callable[[], nn.Module] = Identity + trainable_pos_embedding: bool = False + + def _make_pos_embedding_tensor(self, rng, input_shape): + if self.embedding_type == "discrete_1d": + # An integer tensor in [0, input_shape[-2]-1] reflecting + # 1D discrete position encoding (encode the second-to-last axis). + pos_embedding = jnp.broadcast_to( + jnp.arange(input_shape[-2]), input_shape[1:-1]) + else: + # A tensor grid in [-1, +1] for each input dimension. + pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0]) + + if self.embedding_type == "linear": + pass + elif self.embedding_type == "discrete_1d": + pos_embedding = jax.nn.one_hot(pos_embedding, input_shape[-2]) + elif self.embedding_type == "fourier": + # NeRF-style Fourier/sinusoidal position encoding. + pos_embedding = utils.convert_to_fourier_features( + pos_embedding * jnp.pi, basis_degree=self.num_fourier_bases) + elif self.embedding_type == "gaussian_fourier": + # Gaussian Fourier features. Reference: https://arxiv.org/abs/2006.10739 + num_dims = pos_embedding.shape[-1] + projection = jax.random.normal( + rng, [num_dims, self.num_fourier_bases]) * self.gaussian_sigma + pos_embedding = jnp.pi * pos_embedding.dot(projection) + # A slightly faster implementation of sin and cos. + pos_embedding = jnp.sin( + jnp.concatenate([pos_embedding, pos_embedding + 0.5 * jnp.pi], + axis=-1)) + else: + raise ValueError("Invalid embedding type provided.") + + # Add batch dimension. + pos_embedding = jnp.expand_dims(pos_embedding, axis=0) + + return pos_embedding + + @nn.compact + def __call__(self, inputs): + + # Compute the position embedding only in the initial call use the same rng + # as is used for initializing learnable parameters. + pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor, + inputs.shape) + + if not self.trainable_pos_embedding: + pos_embedding = jax.lax.stop_gradient(pos_embedding) + + # Apply optional transformation on the position embedding. + pos_embedding = self.pos_transform()(pos_embedding) # pytype: disable=not-callable + + # Apply position encoding to inputs. + if self.update_type == "project_add": + # Here, we project the position encodings to the same dimensionality as + # the inputs and add them to the inputs (broadcast along batch dimension). + # This is roughly equivalent to concatenation of position encodings to the + # inputs (if followed by a Dense layer), but is slightly more efficient. + n_features = inputs.shape[-1] + x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding) + elif self.update_type == "concat": + # Repeat the position embedding along the first (batch) dimension. + pos_embedding = jnp.broadcast_to( + pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:]) + # concatenate along the channel dimension. 
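+ # The result keeps the spatial shape of the inputs and has
+ # inputs.shape[-1] + pos_embedding.shape[-1] channels.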
+ x = jnp.concatenate((inputs, pos_embedding), axis=-1) + else: + raise ValueError("Invalid update type provided.") + + # Apply optional output transformation. + x = self.output_transform()(x) # pytype: disable=not-callable + return x + + +class RelativePositionEmbedding(nn.Module): + """A module for applying embedding of input position relative to slots. + + Attr + update_type: A string defining how the input is updated with the position + embedding. One of ["proj_add", "concat"]. + embedding_type: A string defining the type of position embedding to use. + Currently only "linear" is supported. + num_fourier_bases: The number of Fourier bases to use. For embedding_type == + "fourier", the embedding dimensionality is 2 x number of position + dimensions x num_fourier_bases. For embedding_type == "gaussian_fourier", + the embedding dimensionality is 2 x num_fourier_bases. For embedding_type + == "linear", this parameter is ignored. + gaussian_sigma: Standard deviation of sampled Gaussians. + pos_transform: Optional transform for the embedding. + output_transform: Optional transform for the combined input and embedding. + trainable_pos_embedding: Boolean flag for allowing gradients to flow into + the position embedding, so that the optimizer can update it. + """ + + update_type: str + embedding_type: str = "linear" + num_fourier_bases: int = 0 + gaussian_sigma: float = 1.0 + pos_transform: Callable[[], nn.Module] = Identity + output_transform: Callable[[], nn.Module] = Identity + trainable_pos_embedding: bool = False + scales_factor: float = 1.0 + + def _make_pos_embedding_tensor(self, rng, input_shape): + + # A tensor grid in [-1, +1] for each input dimension. + pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0]) + + # Add batch dimension. + pos_embedding = jnp.expand_dims(pos_embedding, axis=0) + + return pos_embedding + + @nn.compact + def __call__(self, inputs, slot_positions, + slot_scales = None, + slot_rotm = None): + + # Compute the position embedding only in the initial call use the same rng + # as is used for initializing learnable parameters. + pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor, + inputs.shape) + + if not self.trainable_pos_embedding: + pos_embedding = jax.lax.stop_gradient(pos_embedding) + + # Relativize pos_embedding with respect to slot positions + # and optionally slot scales. + slot_positions = jnp.expand_dims( + jnp.expand_dims(slot_positions, axis=-2), axis=-2) + if slot_scales is not None: + slot_scales = jnp.expand_dims( + jnp.expand_dims(slot_scales, axis=-2), axis=-2) + + if self.embedding_type == "linear": + pos_embedding = pos_embedding - slot_positions + if slot_rotm is not None: + pos_embedding = self.transform(slot_rotm, pos_embedding) + if slot_scales is not None: + # Scales are usually small so the grid might get too large. + pos_embedding = pos_embedding / self.scales_factor + pos_embedding = pos_embedding / slot_scales + else: + raise ValueError("Invalid embedding type provided.") + + # Apply optional transformation on the position embedding. + pos_embedding = self.pos_transform()(pos_embedding) # pytype: disable=not-callable + + # Define intermediate for logging. + pos_embedding = Identity(name="pos_emb")(pos_embedding) + + # Apply position encoding to inputs. + if self.update_type == "project_add": + # Here, we project the position encodings to the same dimensionality as + # the inputs and add them to the inputs (broadcast along batch dimension). 
+ # This is roughly equivalent to concatenation of position encodings to the + # inputs (if followed by a Dense layer), but is slightly more efficient. + n_features = inputs.shape[-1] + x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding) + elif self.update_type == "concat": + # Repeat the position embedding along the first (batch) dimension. + pos_embedding = jnp.broadcast_to( + pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:]) + # concatenate along the channel dimension. + x = jnp.concatenate((inputs, pos_embedding), axis=-1) + else: + raise ValueError("Invalid update type provided.") + + # Apply optional output transformation. + x = self.output_transform()(x) # pytype: disable=not-callable + return x + + @classmethod + def transform(cls, rot, coords): + # The coordinate grid coords is in the (y, x) format, so we need to swap + # the coordinates on the input and output. + coords = jnp.stack([coords[Ellipsis, 1], coords[Ellipsis, 0]], axis=-1) + # Equivalent to inv(R) * coords^T = R^T * coords^T = (coords * R)^T. + # We are multiplying by the inverse of the rotation matrix because + # we are rotating the coordinate grid *against* the rotation of the object. + new_coords = jnp.einsum("...hij,...jk->...hik", coords, rot) + # Swap coordinates again. + return jnp.stack([new_coords[Ellipsis, 1], new_coords[Ellipsis, 0]], axis=-1) diff --git a/invariant_slot_attention/modules/resnet.py b/invariant_slot_attention/modules/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..08088c59c5c5eb3a63b67bc5ee782dddd36813c6 --- /dev/null +++ b/invariant_slot_attention/modules/resnet.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of ResNet V1 in Flax. + +"Deep Residual Learning for Image Recognition" +He et al., 2015, [https://arxiv.org/abs/1512.03385] +""" + +import functools + +from typing import Any, Tuple, Type, List, Optional, Callable, Sequence +import flax.linen as nn +import jax.numpy as jnp + + +Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1), use_bias=False) +Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3), use_bias=False) + + +class ResNetBlock(nn.Module): + """ResNet block without bottleneck used in ResNet-18 and ResNet-34.""" + + filters: int + norm: Any + kernel_dilation: Tuple[int, int] = (1, 1) + strides: Tuple[int, int] = (1, 1) + + @nn.compact + def __call__(self, x): + residual = x + + x = Conv3x3( + self.filters, + strides=self.strides, + kernel_dilation=self.kernel_dilation, + name="conv1")(x) + x = self.norm(name="bn1")(x) + x = nn.relu(x) + x = Conv3x3(self.filters, name="conv2")(x) + # Initializing the scale to 0 has been common practice since "Fixup + # Initialization: Residual Learning Without Normalization" Tengyu et al, + # 2019, [https://openreview.net/forum?id=H1gsz30cKX]. 
+ x = self.norm(scale_init=nn.initializers.zeros, name="bn2")(x) + + if residual.shape != x.shape: + residual = Conv1x1( + self.filters, strides=self.strides, name="proj_conv")( + residual) + residual = self.norm(name="proj_bn")(residual) + + x = nn.relu(residual + x) + return x + + +class BottleneckResNetBlock(ResNetBlock): + """Bottleneck ResNet block used in ResNet-50 and larger.""" + + @nn.compact + def __call__(self, x): + residual = x + + x = Conv1x1(self.filters, name="conv1")(x) + x = self.norm(name="bn1")(x) + x = nn.relu(x) + x = Conv3x3( + self.filters, + strides=self.strides, + kernel_dilation=self.kernel_dilation, + name="conv2")(x) + x = self.norm(name="bn2")(x) + x = nn.relu(x) + x = Conv1x1(4 * self.filters, name="conv3")(x) + # Initializing the scale to 0 has been common practice since "Fixup + # Initialization: Residual Learning Without Normalization" Tengyu et al, + # 2019, [https://openreview.net/forum?id=H1gsz30cKX]. + x = self.norm(name="bn3")(x) + + if residual.shape != x.shape: + residual = Conv1x1( + 4 * self.filters, strides=self.strides, name="proj_conv")( + residual) + residual = self.norm(name="proj_bn")(residual) + + x = nn.relu(residual + x) + return x + + +class ResNetStage(nn.Module): + """ResNet stage consistent of multiple ResNet blocks.""" + + stage_size: int + filters: int + block_cls: Type[ResNetBlock] + norm: Any + first_block_strides: Tuple[int, int] + + @nn.compact + def __call__(self, x): + for i in range(self.stage_size): + x = self.block_cls( + filters=self.filters, + norm=self.norm, + strides=self.first_block_strides if i == 0 else (1, 1), + name=f"block{i + 1}")( + x) + return x + + +class ResNet(nn.Module): + """Construct ResNet V1 with `num_classes` outputs. + + Attributes: + num_classes: Number of nodes in the final layer. + block_cls: Class for the blocks. ResNet-50 and larger use + `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and + ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions). + stage_sizes: List with the number of ResNet blocks in each stage. Number of + stages can be varied. + norm_type: Which type of normalization layer to apply. Options are: + "batch": BatchNorm, "group": GroupNorm, "layer": LayerNorm. Defaults to + BatchNorm. + width_factor: Factor applied to the number of filters. The 64 * width_factor + is the number of filters in the first stage, every consecutive stage + doubles the number of filters. + small_inputs: Bool, if True, ignore strides and skip max pooling in the root + block and use smaller filter size. + stage_strides: Stride per stage. This overrides all other arguments. + include_top: Whether to include the fully-connected layer at the top + of the network. + axis_name: Axis name over which to aggregate batchnorm statistics. + """ + num_classes: int + block_cls: Type[ResNetBlock] + stage_sizes: List[int] + norm_type: str = "batch" + width_factor: int = 1 + small_inputs: bool = False + stage_strides: Optional[List[Tuple[int, int]]] = None + include_top: bool = False + axis_name: Optional[str] = None + output_initializer: Callable[[Any, Sequence[int], Any], Any] = ( + nn.initializers.zeros) + + @nn.compact + def __call__(self, x, *, train): + """Apply the ResNet to the inputs `x`. + + Args: + x: Inputs. + train: Whether to use BatchNorm in training or inference mode. + + Returns: + The output head with `num_classes` entries. 
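+ If `include_top` is False, the spatial feature map of the final stage
+ is returned instead (no pooling or classification head is applied).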
+ """ + width = 64 * self.width_factor + + if self.norm_type == "batch": + norm = functools.partial( + nn.BatchNorm, use_running_average=not train, momentum=0.9, + axis_name=self.axis_name) + elif self.norm_type == "layer": + norm = nn.LayerNorm + elif self.norm_type == "group": + norm = nn.GroupNorm + else: + raise ValueError(f"Invalid norm_type: {self.norm_type}") + + # Root block. + x = nn.Conv( + features=width, + kernel_size=(7, 7) if not self.small_inputs else (3, 3), + strides=(2, 2) if not self.small_inputs else (1, 1), + use_bias=False, + name="init_conv")( + x) + x = norm(name="init_bn")(x) + + if not self.small_inputs: + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") + + # Stages. + for i, stage_size in enumerate(self.stage_sizes): + if i == 0: + first_block_strides = ( + 1, 1) if self.stage_strides is None else self.stage_strides[i] + else: + first_block_strides = ( + 2, 2) if self.stage_strides is None else self.stage_strides[i] + + x = ResNetStage( + stage_size, + filters=width * 2**i, + block_cls=self.block_cls, + norm=norm, + first_block_strides=first_block_strides, + name=f"stage{i + 1}")(x) + + # Head. + if self.include_top: + x = jnp.mean(x, axis=(1, 2)) + x = nn.Dense( + self.num_classes, kernel_init=self.output_initializer, name="head")(x) + return x + + +ResNetWithBasicBlk = functools.partial(ResNet, block_cls=ResNetBlock) +ResNetWithBottleneckBlk = functools.partial(ResNet, + block_cls=BottleneckResNetBlock) + +ResNet18 = functools.partial(ResNetWithBasicBlk, stage_sizes=[2, 2, 2, 2]) +ResNet34 = functools.partial(ResNetWithBasicBlk, stage_sizes=[3, 4, 6, 3]) +ResNet50 = functools.partial(ResNetWithBottleneckBlk, stage_sizes=[3, 4, 6, 3]) +ResNet101 = functools.partial(ResNetWithBottleneckBlk, + stage_sizes=[3, 4, 23, 3]) +ResNet152 = functools.partial(ResNetWithBottleneckBlk, + stage_sizes=[3, 8, 36, 3]) +ResNet200 = functools.partial(ResNetWithBottleneckBlk, + stage_sizes=[3, 24, 36, 3]) diff --git a/invariant_slot_attention/modules/video.py b/invariant_slot_attention/modules/video.py new file mode 100644 index 0000000000000000000000000000000000000000..94c4aba90b02c7cebbb44e7646232cb341de0f56 --- /dev/null +++ b/invariant_slot_attention/modules/video.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Video module library.""" + +import functools +from typing import Any, Callable, Dict, Iterable, Mapping, NamedTuple, Optional, Tuple, Union + +from flax import linen as nn +import jax.numpy as jnp +from invariant_slot_attention.lib import utils +from invariant_slot_attention.modules import misc + +Shape = Tuple[int] + +DType = Any +Array = Any # jnp.ndarray +ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet +ProcessorState = ArrayTree +PRNGKey = Array +NestedDict = Dict[str, Any] + + +class CorrectorPredictorTuple(NamedTuple): + corrected: ProcessorState + predicted: ProcessorState + + +class Processor(nn.Module): + """Recurrent processor module. + + This module is scanned (applied recurrently) over the sequence dimension of + the input and applies a corrector and a predictor module. The corrector is + only applied if new inputs (such as a new image/frame) are received and uses + the new input to correct its internal state. + + The predictor is equivalent to a latent transition model and produces a + prediction for the state at the next time step, given the current (corrected) + state. + """ + corrector: Callable[[ProcessorState, Array], ProcessorState] + predictor: Callable[[ProcessorState], ProcessorState] + + @functools.partial( + nn.scan, # Scan (recurrently apply) over time axis. + in_axes=(1, 1, nn.broadcast), # (inputs, padding_mask, train). + out_axes=1, + variable_axes={"intermediates": 1}, # Stack intermediates along seq. dim. + variable_broadcast="params", + split_rngs={"params": False, "dropout": True}) + @nn.compact + def __call__(self, state, inputs, + padding_mask, + train): + + # Only apply corrector if we receive new inputs. + if inputs is not None: + corrected_state = self.corrector(state, inputs, padding_mask, train=train) + # Otherwise simply use previous state as input for predictor. + else: + corrected_state = state + + # Always apply predictor (i.e. transition model). + predicted_state = self.predictor(corrected_state, train=train) + + # Prepare outputs in a format compatible with nn.scan. + new_state = predicted_state + outputs = CorrectorPredictorTuple( + corrected=corrected_state, predicted=predicted_state) + return new_state, outputs + + +class SAVi(nn.Module): + """Video model consisting of encoder, recurrent processor, and decoder.""" + + encoder: Callable[[], nn.Module] + decoder: Callable[[], nn.Module] + corrector: Callable[[], nn.Module] + predictor: Callable[[], nn.Module] + initializer: Callable[[], nn.Module] + decode_corrected: bool = True + decode_predicted: bool = True + + @nn.compact + def __call__(self, video, conditioning = None, + continue_from_previous_state = False, + padding_mask = None, + train = False): + """Performs a forward pass on a video. + + Args: + video: Video of shape `[batch_size, n_frames, height, width, n_channels]`. + conditioning: Optional jnp.ndarray used for conditioning the initial state + of the recurrent processor. + continue_from_previous_state: Boolean, whether to continue from a previous + state or not. If True, the conditioning variable is used directly as + initial state. + padding_mask: Binary mask for padding video inputs (e.g. for videos of + different sizes/lengths). Zero corresponds to padding. + train: Indicating whether we're training or evaluating. + + Returns: + A dictionary of model predictions. 
+ """ + processor = Processor( + corrector=self.corrector(), predictor=self.predictor()) # pytype: disable=wrong-arg-types + + if padding_mask is None: + padding_mask = jnp.ones(video.shape[:-1], jnp.int32) + + # video.shape = (batch_size, n_frames, height, width, n_channels) + # Vmapped over sequence dim. + encoded_inputs = self.encoder()(video, padding_mask, train) # pytype: disable=not-callable + if continue_from_previous_state: + assert conditioning is not None, ( + "When continuing from a previous state, the state has to be passed " + "via the `conditioning` variable, which cannot be `None`.") + init_state = conditioning[:, -1] # We currently only use last state. + else: + # Same as above but without encoded inputs. + init_state = self.initializer()( + conditioning, batch_size=video.shape[0], train=train) # pytype: disable=not-callable + + # Scan recurrent processor over encoded inputs along sequence dimension. + _, states = processor(init_state, encoded_inputs, padding_mask, train) + # type(states) = CorrectorPredictorTuple. + # states.corrected.shape = (batch_size, n_frames, ..., n_features). + # states.predicted.shape = (batch_size, n_frames, ..., n_features). + + # Decode latent states. + decoder = self.decoder() # Vmapped over sequence dim. + outputs = decoder(states.corrected, + train) if self.decode_corrected else None # pytype: disable=not-callable + outputs_pred = decoder(states.predicted, + train) if self.decode_predicted else None # pytype: disable=not-callable + + return { + "states": states.corrected, + "states_pred": states.predicted, + "outputs": outputs, + "outputs_pred": outputs_pred, + } + + +class FrameEncoder(nn.Module): + """Encoder for single video frame, vmapped over time axis.""" + + backbone: Callable[[], nn.Module] + pos_emb: Callable[[], nn.Module] = misc.Identity + reduction: Optional[str] = None + output_transform: Callable[[], nn.Module] = misc.Identity + + # Vmapped application of module, consumes time axis (axis=1). + @functools.partial(utils.time_distributed, in_axes=(1, 1, None)) + @nn.compact + def __call__(self, inputs, padding_mask = None, + train = False): + del padding_mask # Unused. + + # inputs.shape = (batch_size, height, width, n_channels) + x = self.backbone()(inputs, train=train) + + x = self.pos_emb()(x) + + if self.reduction == "spatial_flatten": + batch_size, height, width, n_features = x.shape + x = jnp.reshape(x, (batch_size, height * width, n_features)) + elif self.reduction == "spatial_average": + x = jnp.mean(x, axis=(1, 2)) + elif self.reduction == "all_flatten": + batch_size, height, width, n_features = x.shape + x = jnp.reshape(x, (batch_size, height * width * n_features)) + elif self.reduction is not None: + raise ValueError("Unknown reduction type: {}.".format(self.reduction)) + + output_block = self.output_transform() + + if hasattr(output_block, "qkv_size"): + # Project to qkv_size if used transformer. + x = nn.relu(nn.Dense(output_block.qkv_size)(x)) + + x = output_block(x, train=train) + return x diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e7cd9d94bcc1229b0ae2efccf4ac4000fbfb03a6 --- /dev/null +++ b/main.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main file for running the model trainer.""" + +from absl import app +from absl import flags +from absl import logging + +from clu import platform +import jax +from ml_collections import config_flags + +import tensorflow as tf + + +from invariant_slot_attention.lib import trainer + +FLAGS = flags.FLAGS + +config_flags.DEFINE_config_file( + "config", None, "Config file.") +flags.DEFINE_string("workdir", None, "Work unit directory.") +flags.DEFINE_string("jax_backend_target", None, "JAX backend target to use.") +flags.mark_flags_as_required(["config", "workdir"]) + + +def main(argv): + del argv + + # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make + # it unavailable to JAX. + tf.config.experimental.set_visible_devices([], "GPU") + + if FLAGS.jax_backend_target: + logging.info("Using JAX backend target %s", FLAGS.jax_backend_target) + jax.config.update("jax_xla_backend", "tpu_driver") + jax.config.update("jax_backend_target", FLAGS.jax_backend_target) + + logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count()) + logging.info("JAX devices: %r", jax.devices()) + + # Add a note so that we can tell which task is which JAX host. + platform.work_unit().set_task_status( + f"host_id: {jax.host_id()}, host_count: {jax.host_count()}") + platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, + FLAGS.workdir, "workdir") + + trainer.train_and_evaluate(FLAGS.config, FLAGS.workdir) + + +if __name__ == "__main__": + app.run(main) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2e18b3ed7f077ce62d631fcb134ff98a91aa7a3e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +absl-py>=0.12.0 +numpy>=1.21.5 +tensorflow-datasets>=4.4.0 +matplotlib>=3.5.0 +clu>=0.0.3 +flax>=0.3.5 +chex>=0.0.7 +optax>=0.1.0 +ml-collections>=0.1.0 +scikit-image +sunds \ No newline at end of file
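For reference, here is a standalone sketch (not part of the patch) of the coordinate-grid rotation performed by the `transform` classmethod in the position-embedding hunk near the top of this section: the grid is stored as (y, x), so it is swapped to (x, y), right-multiplied by the rotation matrix (which applies R^-1 = R^T to the column vectors, i.e. rotates the grid against the object's rotation), and swapped back. The grid size and angle below are arbitrary.

```python
import jax.numpy as jnp

# A 90-degree rotation matrix R (angle chosen arbitrarily for illustration).
angle = jnp.pi / 2
rot = jnp.array([[jnp.cos(angle), -jnp.sin(angle)],
                 [jnp.sin(angle),  jnp.cos(angle)]])

# A tiny 2x2 grid of (y, x) coordinates in [-1, 1].
coords = jnp.array([[[-1., -1.], [-1., 1.]],
                    [[1., -1.], [1., 1.]]])

# Swap (y, x) -> (x, y), rotate by right-multiplication with R, swap back.
xy = jnp.stack([coords[..., 1], coords[..., 0]], axis=-1)
new_xy = jnp.einsum("...ij,...jk->...ik", xy, rot)   # rows are (x, y) vectors
new_coords = jnp.stack([new_xy[..., 1], new_xy[..., 0]], axis=-1)
print(new_coords.shape)  # (2, 2, 2), same layout as the input grid
```

The production code carries an extra leading axis (the `h` dimension), hence the `"...hij,...jk->...hik"` einsum, but the per-point arithmetic is the same.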
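The ResNet backbone in `invariant_slot_attention/modules/resnet.py` can be smoke-tested on its own. Below is a minimal sketch, assuming GroupNorm (so no `batch_stats` collection needs to be threaded through) and a 128x128 input; the input shape and `num_classes=1` (unused when `include_top=False`) are illustrative only.

```python
import jax
import jax.numpy as jnp

from invariant_slot_attention.modules.resnet import ResNet34

# ResNet-34 feature extractor: no classification head, GroupNorm instead of
# BatchNorm, and `small_inputs=True` (3x3 stem, stride 1, no max pooling).
backbone = ResNet34(num_classes=1, include_top=False, norm_type="group",
                    small_inputs=True)

images = jnp.ones((2, 128, 128, 3), jnp.float32)  # (batch, height, width, 3)
variables = backbone.init(jax.random.PRNGKey(0), images, train=False)
features = backbone.apply(variables, images, train=False)
print(features.shape)  # (2, 16, 16, 512): stages 2-4 each halve the resolution
```

With the default `norm_type="batch"`, a training-mode call (`train=True`) would additionally need `mutable=["batch_stats"]` in `apply` so the running statistics can be updated.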
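The `Processor` in `invariant_slot_attention/modules/video.py` only assumes that the corrector accepts `(state, inputs, padding_mask, train=...)` and the predictor accepts `(state, train=...)`; `nn.scan` then rolls them over the time axis. The following toy sketch exercises that contract; `ToyCorrector`, `ToyPredictor`, and all sizes are made up for illustration and are not part of the codebase.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

from invariant_slot_attention.modules.video import Processor


class ToyCorrector(nn.Module):
  """Folds one frame of encoded inputs into the slot state."""

  @nn.compact
  def __call__(self, state, inputs, padding_mask, train=False):
    del padding_mask, train  # Unused in this toy example.
    update = nn.Dense(state.shape[-1])(inputs.mean(axis=1, keepdims=True))
    return state + update  # Broadcasts over the slot dimension.


class ToyPredictor(nn.Module):
  """Linear transition model for the slot state."""

  @nn.compact
  def __call__(self, state, train=False):
    del train
    return nn.Dense(state.shape[-1])(state)


processor = Processor(corrector=ToyCorrector(), predictor=ToyPredictor())

batch, n_frames, n_tokens, n_slots, n_features = 2, 4, 16, 5, 32
init_state = jnp.zeros((batch, n_slots, n_features))
encoded = jnp.ones((batch, n_frames, n_tokens, n_features))
padding_mask = jnp.ones((batch, n_frames, n_tokens), jnp.int32)

variables = processor.init(
    jax.random.PRNGKey(0), init_state, encoded, padding_mask, False)
final_state, states = processor.apply(
    variables, init_state, encoded, padding_mask, False)
print(states.corrected.shape)  # (2, 4, 5, 32): outputs stacked along time
print(states.predicted.shape)  # (2, 4, 5, 32)
```

`SAVi.__call__` drives the scanned processor in the same way: the frame encoder's output plays the role of `encoded`, and the initializer (or the `conditioning` state) provides `init_state`.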