ondrejbiza committed on
Commit a560c26
1 Parent(s): db5cc89

Working on isa demo.

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +4 -0
  2. __init__.py +15 -0
  3. app.py +66 -0
  4. invariant_slot_attention/configs/__init__.py +15 -0
  5. invariant_slot_attention/configs/clevr_with_masks/baseline.py +194 -0
  6. invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py +202 -0
  7. invariant_slot_attention/configs/clevr_with_masks/equiv_transl_rot_scale.py +203 -0
  8. invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale.py +203 -0
  9. invariant_slot_attention/configs/clevrtex/resnet/baseline.py +198 -0
  10. invariant_slot_attention/configs/clevrtex/resnet/equiv_transl.py +206 -0
  11. invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_rot_scale.py +213 -0
  12. invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_scale.py +213 -0
  13. invariant_slot_attention/configs/clevrtex/simplecnn/baseline.py +197 -0
  14. invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl.py +205 -0
  15. invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_rot_scale.py +207 -0
  16. invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_scale.py +207 -0
  17. invariant_slot_attention/configs/multishapenet_easy/baseline.py +195 -0
  18. invariant_slot_attention/configs/multishapenet_easy/equiv_transl.py +203 -0
  19. invariant_slot_attention/configs/multishapenet_easy/equiv_transl_rot_scale.py +205 -0
  20. invariant_slot_attention/configs/multishapenet_easy/equiv_transl_scale.py +205 -0
  21. invariant_slot_attention/configs/objects_room/baseline.py +192 -0
  22. invariant_slot_attention/configs/objects_room/equiv_transl.py +200 -0
  23. invariant_slot_attention/configs/objects_room/equiv_transl_rot_scale.py +202 -0
  24. invariant_slot_attention/configs/objects_room/equiv_transl_scale.py +202 -0
  25. invariant_slot_attention/configs/tetrominoes/baseline.py +191 -0
  26. invariant_slot_attention/configs/tetrominoes/equiv_transl.py +199 -0
  27. invariant_slot_attention/configs/waymo_open/baseline.py +191 -0
  28. invariant_slot_attention/configs/waymo_open/equiv_transl.py +199 -0
  29. invariant_slot_attention/configs/waymo_open/equiv_transl_rot_scale.py +206 -0
  30. invariant_slot_attention/configs/waymo_open/equiv_transl_scale.py +206 -0
  31. invariant_slot_attention/lib/__init__.py +15 -0
  32. invariant_slot_attention/lib/evaluator.py +326 -0
  33. invariant_slot_attention/lib/input_pipeline.py +390 -0
  34. invariant_slot_attention/lib/losses.py +295 -0
  35. invariant_slot_attention/lib/metrics.py +263 -0
  36. invariant_slot_attention/lib/preprocessing.py +1236 -0
  37. invariant_slot_attention/lib/trainer.py +328 -0
  38. invariant_slot_attention/lib/transforms.py +163 -0
  39. invariant_slot_attention/lib/utils.py +625 -0
  40. invariant_slot_attention/modules/__init__.py +49 -0
  41. invariant_slot_attention/modules/attention.py +327 -0
  42. invariant_slot_attention/modules/convolution.py +164 -0
  43. invariant_slot_attention/modules/decoders.py +267 -0
  44. invariant_slot_attention/modules/initializers.py +173 -0
  45. invariant_slot_attention/modules/invariant_attention.py +963 -0
  46. invariant_slot_attention/modules/invariant_initializers.py +327 -0
  47. invariant_slot_attention/modules/misc.py +340 -0
  48. invariant_slot_attention/modules/resnet.py +231 -0
  49. invariant_slot_attention/modules/video.py +195 -0
  50. main.py +65 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ /venv
+ /flagged
+ /clevr_isa_ts
+ *.pyc
__init__.py ADDED
@@ -0,0 +1,15 @@
+ # coding=utf-8
+ # Copyright 2023 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
app.py ADDED
@@ -0,0 +1,66 @@
+ import functools
+ import os
+
+ from absl import flags
+ import gradio as gr
+ import jax
+ import jax.numpy as jnp
+
+ from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
+ from invariant_slot_attention.lib import input_pipeline
+ from invariant_slot_attention.lib import utils
+
+
+ def load_model(config):
+   # NOTE: work in progress -- `rng`, `train_ds`, `workdir`, `lr_schedules`,
+   # `optimizers`, `losses` and `checkpoint` are not yet defined or imported
+   # in this file, so this function does not run as committed.
+   rng, data_rng = jax.random.split(rng)
+
+   # Initialize model
+   model = utils.build_model_from_config(config.model)
+
+   def init_model(rng):
+     rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
+
+     init_conditioning = None
+     init_inputs = jnp.ones(
+         [1] + list(train_ds.element_spec["video"].shape)[2:],
+         jnp.float32)
+     initial_vars = model.init(
+         {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
+         video=init_inputs, conditioning=init_conditioning,
+         padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
+
+     # Split into state variables (e.g. for batchnorm stats) and model params.
+     # Note that `pop()` on a FrozenDict performs a deep copy.
+     state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error
+
+     # Filter out intermediates (we don't want to store these in the TrainState).
+     state_vars = utils.filter_key_from_frozen_dict(
+         state_vars, key="intermediates")
+     return state_vars, initial_params
+
+   state_vars, initial_params = init_model(rng)
+
+   learning_rate_fn = lr_schedules.get_learning_rate_fn(config)
+   tx = optimizers.get_optimizer(
+       config.optimizer_configs, learning_rate_fn, params=initial_params)
+
+   opt_state = tx.init(initial_params)
+
+   state = utils.TrainState(
+       step=1, opt_state=opt_state, params=initial_params, rng=rng,
+       variables=state_vars)
+
+   loss_fn = functools.partial(
+       losses.compute_full_loss, loss_config=config.losses)
+
+   checkpoint_dir = os.path.join(workdir, "checkpoints")
+   ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
+   state = ckpt.restore_or_initialize(state)
+
+
+ def greet(name):
+   return "Hello " + name + "!"
+
+
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ demo.launch()
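
The demo wiring after `load_model` is still the stock Gradio text example. Below is a minimal sketch of how the restored checkpoint might feed an image interface instead; it assumes `model`, `state` and `state_vars` are available from a fixed-up `load_model`, that `model.apply` accepts the same `video`/`conditioning`/`padding_mask` keywords as `model.init` above, and `predict_fn` is a hypothetical name -- none of this is fixed by the commit.

# Hypothetical sketch, not part of this commit.
import numpy as np

def predict_fn(image):
  # Gradio provides an HxWxC uint8 array; add batch and time axes to match
  # the (batch, time, height, width, channels) "video" layout used above.
  video = jnp.asarray(image, jnp.float32)[None, None] / 255.0
  outputs = model.apply(
      {"params": state.params, **state_vars},
      video=video,
      conditioning=None,
      padding_mask=jnp.ones(video.shape[:-1], jnp.int32))
  # Depending on the initializer, an rngs={"state_init": ...} argument may
  # also be required. The output structure (reconstructions, per-slot masks)
  # depends on the SAVi module; inspect `outputs` to choose what to display.
  return np.asarray(video[0, 0])  # placeholder: echo the preprocessed frame

demo = gr.Interface(fn=predict_fn, inputs="image", outputs="image")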
invariant_slot_attention/configs/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # coding=utf-8
+ # Copyright 2023 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
invariant_slot_attention/configs/clevr_with_masks/baseline.py ADDED
@@ -0,0 +1,194 @@
+ # coding=utf-8
+ # Copyright 2023 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ r"""Config for unsupervised training on CLEVR."""
+
+ import ml_collections
+
+
+ def get_config():
+   """Get the default hyperparameter configuration."""
+   config = ml_collections.ConfigDict()
+
+   config.seed = 42
+   config.seed_data = True
+
+   config.batch_size = 64
+   config.num_train_steps = 500000  # from the original Slot Attention
+   config.init_checkpoint = ml_collections.ConfigDict()
+   config.init_checkpoint.xid = 0  # Disabled by default.
+   config.init_checkpoint.wid = 1
+
+   config.optimizer_configs = ml_collections.ConfigDict()
+   config.optimizer_configs.optimizer = "adam"
+
+   config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
+   config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
+   config.optimizer_configs.grad_clip.clip_value = 0.05
+
+   config.lr_configs = ml_collections.ConfigDict()
+   config.lr_configs.learning_rate_schedule = "compound"
+   config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
+   config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
+   config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
+   # from the original Slot Attention
+   config.lr_configs.base_learning_rate = 4e-4
+
+   config.eval_pad_last_batch = False  # True
+   config.log_loss_every_steps = 50
+   config.eval_every_steps = 5000
+   config.checkpoint_every_steps = 5000
+
+   config.train_metrics_spec = {
+       "loss": "loss",
+       "ari": "ari",
+       "ari_nobg": "ari_nobg",
+   }
+   config.eval_metrics_spec = {
+       "eval_loss": "loss",
+       "eval_ari": "ari",
+       "eval_ari_nobg": "ari_nobg",
+   }
+
+   config.data = ml_collections.ConfigDict({
+       "dataset_name": "clevr_with_masks",
+       "shuffle_buffer_size": config.batch_size * 8,
+       "resolution": (128, 128)
+   })
+
+   config.max_instances = 11
+   config.num_slots = config.max_instances  # Only used for metrics.
+   config.logging_min_n_colors = config.max_instances
+
+   config.preproc_train = [
+       "tfds_image_to_tfds_video",
+       "video_from_tfds",
+       "top_left_crop(top=29, left=64, height=192)",
+       "resize_small({size})".format(size=min(*config.data.resolution))
+   ]
+
+   config.preproc_eval = [
+       "tfds_image_to_tfds_video",
+       "video_from_tfds",
+       "top_left_crop(top=29, left=64, height=192)",
+       "resize_small({size})".format(size=min(*config.data.resolution))
+   ]
+
+   config.eval_slice_size = 1
+   config.eval_slice_keys = ["video", "segmentations_video"]
+
+   # Dictionary of targets and corresponding channels. Losses need to match.
+   targets = {"video": 3}
+   config.losses = {"recon": {"targets": list(targets)}}
+   config.losses = ml_collections.ConfigDict({
+       f"recon_{target}": {"loss_type": "recon", "key": target}
+       for target in targets})
+
+   config.model = ml_collections.ConfigDict({
+       "module": "invariant_slot_attention.modules.SAVi",
+
+       # Encoder.
+       "encoder": ml_collections.ConfigDict({
+           "module": "invariant_slot_attention.modules.FrameEncoder",
+           "reduction": "spatial_flatten",
+           "backbone": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.SimpleCNN",
+               "features": [64, 64, 64, 64],
+               "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
+               "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
+           }),
+           "pos_emb": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.PositionEmbedding",
+               "embedding_type": "linear",
+               "update_type": "project_add",
+               "output_transform": ml_collections.ConfigDict({
+                   "module": "invariant_slot_attention.modules.MLP",
+                   "hidden_size": 128,
+                   "layernorm": "pre"
+               }),
+           }),
+       }),
+
+       # Corrector.
+       "corrector": ml_collections.ConfigDict({
+           "module": "invariant_slot_attention.modules.SlotAttention",
+           "num_iterations": 3,
+           "qkv_size": 64,
+           "mlp_size": 128,
+       }),
+
+       # Predictor.
+       # Removed since we are running a single frame.
+       "predictor": ml_collections.ConfigDict({
+           "module": "invariant_slot_attention.modules.Identity"
+       }),
+
+       # Initializer.
+       "initializer": ml_collections.ConfigDict({
+           "module": "invariant_slot_attention.modules.ParamStateInit",
+           "shape": (11, 64),  # (num_slots, slot_size)
+       }),
+
+       # Decoder.
+       "decoder": ml_collections.ConfigDict({
+           "module":
+               "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
+           "resolution": (16, 16),  # Update if data resolution or strides change
+           "backbone": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.CNN",
+               "features": [64, 64, 64, 64, 64],
+               "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
+               "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
+               "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
+               "layer_transpose": [True, True, True, False, False]
+           }),
+           "target_readout": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.Readout",
+               "keys": list(targets),
+               "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
+                   "module": "invariant_slot_attention.modules.MLP",
+                   "num_hidden_layers": 0,
+                   "hidden_size": 0,
+                   "output_size": targets[k]}) for k in targets
+               ],
+           }),
+           "pos_emb": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.PositionEmbedding",
+               "embedding_type": "linear",
+               "update_type": "project_add"
+           }),
+       }),
+       "decode_corrected": True,
+       "decode_predicted": False,
+   })
+
+   # Which video-shaped variables to visualize.
+   config.debug_var_video_paths = {
+       "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
+   }
+
+   # Define which attention matrices to log/visualize.
+   config.debug_var_attn_paths = {
+       "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
+   }
+
+   # Widths of attention matrices (for reshaping to image grid).
+   config.debug_var_attn_widths = {
+       "corrector_attn": 16,
+   }
+
+   return config
+
+
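
Every config in this commit follows the same pattern: a `get_config()` function returning an `ml_collections.ConfigDict` whose `model` sub-config names modules by import path, with `get_ref` tying the cosine schedule length to `num_train_steps`. A short usage sketch (the override values are illustrative only, not part of this commit):

# Usage sketch: load the baseline CLEVR config and override a few fields.
from invariant_slot_attention.configs.clevr_with_masks import baseline

config = baseline.get_config()
config.batch_size = 8                     # e.g. to fit a single demo GPU
print(config.model.corrector.module)      # invariant_slot_attention.modules.SlotAttention

# get_ref() links steps_per_cycle to num_train_steps, so overriding one
# is reflected in the other when read back.
config.num_train_steps = 100_000
assert config.lr_configs.steps_per_cycle == 100_000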
invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py ADDED
@@ -0,0 +1,202 @@
+ # coding=utf-8
+ # Copyright 2023 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ r"""Config for unsupervised training on CLEVR."""
+
+ import ml_collections
+
+
+ def get_config():
+   """Get the default hyperparameter configuration."""
+   config = ml_collections.ConfigDict()
+
+   config.seed = 42
+   config.seed_data = True
+
+   config.batch_size = 64
+   config.num_train_steps = 500000  # from the original Slot Attention
+   config.init_checkpoint = ml_collections.ConfigDict()
+   config.init_checkpoint.xid = 0  # Disabled by default.
+   config.init_checkpoint.wid = 1
+
+   config.optimizer_configs = ml_collections.ConfigDict()
+   config.optimizer_configs.optimizer = "adam"
+
+   config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
+   config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
+   config.optimizer_configs.grad_clip.clip_value = 0.05
+
+   config.lr_configs = ml_collections.ConfigDict()
+   config.lr_configs.learning_rate_schedule = "compound"
+   config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
+   config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
+   config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
+   # from the original Slot Attention
+   config.lr_configs.base_learning_rate = 4e-4
+
+   config.eval_pad_last_batch = False  # True
+   config.log_loss_every_steps = 50
+   config.eval_every_steps = 5000
+   config.checkpoint_every_steps = 5000
+
+   config.train_metrics_spec = {
+       "loss": "loss",
+       "ari": "ari",
+       "ari_nobg": "ari_nobg",
+   }
+   config.eval_metrics_spec = {
+       "eval_loss": "loss",
+       "eval_ari": "ari",
+       "eval_ari_nobg": "ari_nobg",
+   }
+
+   config.data = ml_collections.ConfigDict({
+       "dataset_name": "clevr_with_masks",
+       "shuffle_buffer_size": config.batch_size * 8,
+       "resolution": (128, 128)
+   })
+
+   config.max_instances = 11
+   config.num_slots = config.max_instances  # Only used for metrics.
+   config.logging_min_n_colors = config.max_instances
+
+   config.preproc_train = [
+       "tfds_image_to_tfds_video",
+       "video_from_tfds",
+       "top_left_crop(top=29, left=64, height=192)",
+       "resize_small({size})".format(size=min(*config.data.resolution))
+   ]
+
+   config.preproc_eval = [
+       "tfds_image_to_tfds_video",
+       "video_from_tfds",
+       "top_left_crop(top=29, left=64, height=192)",
+       "resize_small({size})".format(size=min(*config.data.resolution))
+   ]
+
+   config.eval_slice_size = 1
+   config.eval_slice_keys = ["video", "segmentations_video"]
+
+   # Dictionary of targets and corresponding channels. Losses need to match.
+   targets = {"video": 3}
+   config.losses = {"recon": {"targets": list(targets)}}
+   config.losses = ml_collections.ConfigDict({
+       f"recon_{target}": {"loss_type": "recon", "key": target}
+       for target in targets})
+
+   config.model = ml_collections.ConfigDict({
+       "module": "invariant_slot_attention.modules.SAVi",
+
+       # Encoder.
+       "encoder": ml_collections.ConfigDict({
+           "module": "invariant_slot_attention.modules.FrameEncoder",
+           "reduction": "spatial_flatten",
+           "backbone": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.SimpleCNN",
+               "features": [64, 64, 64, 64],
+               "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
+               "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
+           }),
+           "pos_emb": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.PositionEmbedding",
+               "embedding_type": "linear",
+               "update_type": "concat"
+           }),
+       }),
+
+       # Corrector.
+       "corrector": ml_collections.ConfigDict({
+           "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
+           "num_iterations": 3,
+           "qkv_size": 64,
+           "mlp_size": 128,
+           "grid_encoder": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.MLP",
+               "hidden_size": 128,
+               "layernorm": "pre"
+           }),
+           "add_rel_pos_to_values": True,  # V3
+           "zero_position_init": False,  # Random positions.
+       }),
+
+       # Predictor.
+       # Removed since we are running a single frame.
+       "predictor": ml_collections.ConfigDict({
+           "module": "invariant_slot_attention.modules.Identity"
+       }),
+
+       # Initializer.
+       "initializer": ml_collections.ConfigDict({
+           "module":
+               "invariant_slot_attention.modules.ParamStateInitRandomPositions",
+           "shape":
+               (11, 64),  # (num_slots, slot_size)
+       }),
+
+       # Decoder.
+       "decoder": ml_collections.ConfigDict({
+           "module":
+               "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
+           "resolution": (16, 16),  # Update if data resolution or strides change
+           "backbone": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.CNN",
+               "features": [64, 64, 64, 64, 64],
+               "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
+               "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
+               "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
+               "layer_transpose": [True, True, True, False, False]
+           }),
+           "target_readout": ml_collections.ConfigDict({
+               "module": "invariant_slot_attention.modules.Readout",
+               "keys": list(targets),
+               "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
+                   "module": "invariant_slot_attention.modules.MLP",
+                   "num_hidden_layers": 0,
+                   "hidden_size": 0,
+                   "output_size": targets[k]}) for k in targets
+               ],
+           }),
+           "relative_positions": True,
+           "pos_emb": ml_collections.ConfigDict({
+               "module":
+                   "invariant_slot_attention.modules.RelativePositionEmbedding",
+               "embedding_type":
+                   "linear",
+               "update_type":
+                   "project_add",
+           }),
+       }),
+       "decode_corrected": True,
+       "decode_predicted": False,
+   })
+
+   # Which video-shaped variables to visualize.
+   config.debug_var_video_paths = {
+       "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
+   }
+
+   # Define which attention matrices to log/visualize.
+   config.debug_var_attn_paths = {
+       "corrector_attn": "corrector/InvertedDotProductAttention_0/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
+   }
+
+   # Widths of attention matrices (for reshaping to image grid).
+   config.debug_var_attn_widths = {
+       "corrector_attn": 16,
+   }
+
+   return config
+
+
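
The equivariant variant above differs from `baseline.py` only in the encoder's `pos_emb.update_type`, the corrector module, the initializer, the decoder's relative-position settings, and the logged attention path. For reviewing several near-identical configs like the ones in this commit, a programmatic comparison can be easier than eyeballing 200-line files; a rough sketch (the `flatten` helper is ours, not part of the repo):

# Rough sketch: flatten two ConfigDicts and print the keys whose values differ.
from invariant_slot_attention.configs.clevr_with_masks import baseline, equiv_transl

def flatten(d, prefix=""):
  out = {}
  for k, v in d.items():
    key = f"{prefix}{k}"
    if hasattr(v, "items"):
      out.update(flatten(v, key + "."))
    else:
      out[key] = v
  return out

a = flatten(baseline.get_config().to_dict())
b = flatten(equiv_transl.get_config().to_dict())
for key in sorted(set(a) | set(b)):
  if a.get(key) != b.get(key):
    print(key, a.get(key), "->", b.get(key))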
invariant_slot_attention/configs/clevr_with_masks/equiv_transl_rot_scale.py ADDED
@@ -0,0 +1,203 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVR."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "clevr_with_masks",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 128)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "tfds_image_to_tfds_video",
77
+ "video_from_tfds",
78
+ "top_left_crop(top=29, left=64, height=192)",
79
+ "resize_small({size})".format(size=min(*config.data.resolution))
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "top_left_crop(top=29, left=64, height=192)",
86
+ "resize_small({size})".format(size=min(*config.data.resolution))
87
+ ]
88
+
89
+ config.eval_slice_size = 1
90
+ config.eval_slice_keys = ["video", "segmentations_video"]
91
+
92
+ # Dictionary of targets and corresponding channels. Losses need to match.
93
+ targets = {"video": 3}
94
+ config.losses = {"recon": {"targets": list(targets)}}
95
+ config.losses = ml_collections.ConfigDict({
96
+ f"recon_{target}": {"loss_type": "recon", "key": target}
97
+ for target in targets})
98
+
99
+ config.model = ml_collections.ConfigDict({
100
+ "module": "invariant_slot_attention.modules.SAVi",
101
+
102
+ # Encoder.
103
+ "encoder": ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.FrameEncoder",
105
+ "reduction": "spatial_flatten",
106
+ "backbone": ml_collections.ConfigDict({
107
+ "module": "invariant_slot_attention.modules.SimpleCNN",
108
+ "features": [64, 64, 64, 64],
109
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
110
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
111
+ }),
112
+ "pos_emb": ml_collections.ConfigDict({
113
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
114
+ "embedding_type": "linear",
115
+ "update_type": "concat"
116
+ }),
117
+ }),
118
+
119
+ # Corrector.
120
+ "corrector": ml_collections.ConfigDict({
121
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long
122
+ "num_iterations": 3,
123
+ "qkv_size": 64,
124
+ "mlp_size": 128,
125
+ "grid_encoder": ml_collections.ConfigDict({
126
+ "module": "invariant_slot_attention.modules.MLP",
127
+ "hidden_size": 128,
128
+ "layernorm": "pre"
129
+ }),
130
+ "add_rel_pos_to_values": True, # V3
131
+ "zero_position_init": False, # Random positions.
132
+ "init_with_fixed_scale": None, # Random scales.
133
+ "scales_factor": 5.0,
134
+ }),
135
+
136
+ # Predictor.
137
+ # Removed since we are running a single frame.
138
+ "predictor": ml_collections.ConfigDict({
139
+ "module": "invariant_slot_attention.modules.Identity"
140
+ }),
141
+
142
+ # Initializer.
143
+ "initializer": ml_collections.ConfigDict({
144
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long
145
+ "shape": (11, 64), # (num_slots, slot_size)
146
+ }),
147
+
148
+ # Decoder.
149
+ "decoder": ml_collections.ConfigDict({
150
+ "module":
151
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
152
+ "resolution": (16, 16), # Update if data resolution or strides change
153
+ "backbone": ml_collections.ConfigDict({
154
+ "module": "invariant_slot_attention.modules.CNN",
155
+ "features": [64, 64, 64, 64, 64],
156
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
157
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
158
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
159
+ "layer_transpose": [True, True, True, False, False]
160
+ }),
161
+ "target_readout": ml_collections.ConfigDict({
162
+ "module": "invariant_slot_attention.modules.Readout",
163
+ "keys": list(targets),
164
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
165
+ "module": "invariant_slot_attention.modules.MLP",
166
+ "num_hidden_layers": 0,
167
+ "hidden_size": 0,
168
+ "output_size": targets[k]}) for k in targets],
169
+ }),
170
+ "relative_positions_rotations_and_scales": True,
171
+ "pos_emb": ml_collections.ConfigDict({
172
+ "module":
173
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
174
+ "embedding_type":
175
+ "linear",
176
+ "update_type":
177
+ "project_add",
178
+ "scales_factor":
179
+ 5.0,
180
+ }),
181
+ }),
182
+ "decode_corrected": True,
183
+ "decode_predicted": False,
184
+ })
185
+
186
+ # Which video-shaped variables to visualize.
187
+ config.debug_var_video_paths = {
188
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
189
+ }
190
+
191
+ # Define which attention matrices to log/visualize.
192
+ config.debug_var_attn_paths = {
193
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
194
+ }
195
+
196
+ # Widths of attention matrices (for reshaping to image grid).
197
+ config.debug_var_attn_widths = {
198
+ "corrector_attn": 16,
199
+ }
200
+
201
+ return config
202
+
203
+
invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale.py ADDED
@@ -0,0 +1,203 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVR."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "clevr_with_masks",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 128)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "tfds_image_to_tfds_video",
77
+ "video_from_tfds",
78
+ "top_left_crop(top=29, left=64, height=192)",
79
+ "resize_small({size})".format(size=min(*config.data.resolution))
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "top_left_crop(top=29, left=64, height=192)",
86
+ "resize_small({size})".format(size=min(*config.data.resolution))
87
+ ]
88
+
89
+ config.eval_slice_size = 1
90
+ config.eval_slice_keys = ["video", "segmentations_video"]
91
+
92
+ # Dictionary of targets and corresponding channels. Losses need to match.
93
+ targets = {"video": 3}
94
+ config.losses = {"recon": {"targets": list(targets)}}
95
+ config.losses = ml_collections.ConfigDict({
96
+ f"recon_{target}": {"loss_type": "recon", "key": target}
97
+ for target in targets})
98
+
99
+ config.model = ml_collections.ConfigDict({
100
+ "module": "invariant_slot_attention.modules.SAVi",
101
+
102
+ # Encoder.
103
+ "encoder": ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.FrameEncoder",
105
+ "reduction": "spatial_flatten",
106
+ "backbone": ml_collections.ConfigDict({
107
+ "module": "invariant_slot_attention.modules.SimpleCNN",
108
+ "features": [64, 64, 64, 64],
109
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
110
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
111
+ }),
112
+ "pos_emb": ml_collections.ConfigDict({
113
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
114
+ "embedding_type": "linear",
115
+ "update_type": "concat"
116
+ }),
117
+ }),
118
+
119
+ # Corrector.
120
+ "corrector": ml_collections.ConfigDict({
121
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long
122
+ "num_iterations": 3,
123
+ "qkv_size": 64,
124
+ "mlp_size": 128,
125
+ "grid_encoder": ml_collections.ConfigDict({
126
+ "module": "invariant_slot_attention.modules.MLP",
127
+ "hidden_size": 128,
128
+ "layernorm": "pre"
129
+ }),
130
+ "add_rel_pos_to_values": True, # V3
131
+ "zero_position_init": False, # Random positions.
132
+ "init_with_fixed_scale": None, # Random scales.
133
+ "scales_factor": 5.0,
134
+ }),
135
+
136
+ # Predictor.
137
+ # Removed since we are running a single frame.
138
+ "predictor": ml_collections.ConfigDict({
139
+ "module": "invariant_slot_attention.modules.Identity"
140
+ }),
141
+
142
+ # Initializer.
143
+ "initializer": ml_collections.ConfigDict({
144
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long
145
+ "shape": (11, 64), # (num_slots, slot_size)
146
+ }),
147
+
148
+ # Decoder.
149
+ "decoder": ml_collections.ConfigDict({
150
+ "module":
151
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
152
+ "resolution": (16, 16), # Update if data resolution or strides change
153
+ "backbone": ml_collections.ConfigDict({
154
+ "module": "invariant_slot_attention.modules.CNN",
155
+ "features": [64, 64, 64, 64, 64],
156
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
157
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
158
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
159
+ "layer_transpose": [True, True, True, False, False]
160
+ }),
161
+ "target_readout": ml_collections.ConfigDict({
162
+ "module": "invariant_slot_attention.modules.Readout",
163
+ "keys": list(targets),
164
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
165
+ "module": "invariant_slot_attention.modules.MLP",
166
+ "num_hidden_layers": 0,
167
+ "hidden_size": 0,
168
+ "output_size": targets[k]}) for k in targets],
169
+ }),
170
+ "relative_positions_and_scales": True,
171
+ "pos_emb": ml_collections.ConfigDict({
172
+ "module":
173
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
174
+ "embedding_type":
175
+ "linear",
176
+ "update_type":
177
+ "project_add",
178
+ "scales_factor":
179
+ 5.0,
180
+ }),
181
+ }),
182
+ "decode_corrected": True,
183
+ "decode_predicted": False,
184
+ })
185
+
186
+ # Which video-shaped variables to visualize.
187
+ config.debug_var_video_paths = {
188
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
189
+ }
190
+
191
+ # Define which attention matrices to log/visualize.
192
+ config.debug_var_attn_paths = {
193
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
194
+ }
195
+
196
+ # Widths of attention matrices (for reshaping to image grid).
197
+ config.debug_var_attn_widths = {
198
+ "corrector_attn": 16,
199
+ }
200
+
201
+ return config
202
+
203
+
invariant_slot_attention/configs/clevrtex/resnet/baseline.py ADDED
@@ -0,0 +1,198 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.ResNet34",
112
+ "num_classes": None,
113
+ "axis_name": "time",
114
+ "norm_type": "group",
115
+ "small_inputs": True
116
+ }),
117
+ "pos_emb": ml_collections.ConfigDict({
118
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
119
+ "embedding_type": "linear",
120
+ "update_type": "project_add",
121
+ "output_transform": ml_collections.ConfigDict({
122
+ "module": "invariant_slot_attention.modules.MLP",
123
+ "hidden_size": 128,
124
+ "layernorm": "pre"
125
+ }),
126
+ }),
127
+ }),
128
+
129
+ # Corrector.
130
+ "corrector": ml_collections.ConfigDict({
131
+ "module": "invariant_slot_attention.modules.SlotAttention",
132
+ "num_iterations": 3,
133
+ "qkv_size": 64,
134
+ "mlp_size": 128,
135
+ }),
136
+
137
+ # Predictor.
138
+ # Removed since we are running a single frame.
139
+ "predictor": ml_collections.ConfigDict({
140
+ "module": "invariant_slot_attention.modules.Identity"
141
+ }),
142
+
143
+ # Initializer.
144
+ "initializer": ml_collections.ConfigDict({
145
+ "module": "invariant_slot_attention.modules.ParamStateInit",
146
+ "shape": (11, 64), # (num_slots, slot_size)
147
+ }),
148
+
149
+ # Decoder.
150
+ "decoder": ml_collections.ConfigDict({
151
+ "module":
152
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
153
+ "resolution": (16, 16), # Update if data resolution or strides change
154
+ "backbone": ml_collections.ConfigDict({
155
+ "module": "invariant_slot_attention.modules.CNN",
156
+ "features": [64, 64, 64, 64, 64],
157
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
158
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
159
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
160
+ "layer_transpose": [True, True, True, False, False]
161
+ }),
162
+ "target_readout": ml_collections.ConfigDict({
163
+ "module": "invariant_slot_attention.modules.Readout",
164
+ "keys": list(targets),
165
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
166
+ "module": "invariant_slot_attention.modules.MLP",
167
+ "num_hidden_layers": 0,
168
+ "hidden_size": 0,
169
+ "output_size": targets[k]}) for k in targets],
170
+ }),
171
+ "pos_emb": ml_collections.ConfigDict({
172
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
173
+ "embedding_type": "linear",
174
+ "update_type": "project_add"
175
+ }),
176
+ }),
177
+ "decode_corrected": True,
178
+ "decode_predicted": False,
179
+ })
180
+
181
+ # Which video-shaped variables to visualize.
182
+ config.debug_var_video_paths = {
183
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
184
+ }
185
+
186
+ # Define which attention matrices to log/visualize.
187
+ config.debug_var_attn_paths = {
188
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
189
+ }
190
+
191
+ # Widths of attention matrices (for reshaping to image grid).
192
+ config.debug_var_attn_widths = {
193
+ "corrector_attn": 16,
194
+ }
195
+
196
+ return config
197
+
198
+
invariant_slot_attention/configs/clevrtex/resnet/equiv_transl.py ADDED
@@ -0,0 +1,206 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.ResNet34",
112
+ "num_classes": None,
113
+ "axis_name": "time",
114
+ "norm_type": "group",
115
+ "small_inputs": True
116
+ }),
117
+ "pos_emb": ml_collections.ConfigDict({
118
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
119
+ "embedding_type": "linear",
120
+ "update_type": "concat"
121
+ }),
122
+ }),
123
+
124
+ # Corrector.
125
+ "corrector": ml_collections.ConfigDict({
126
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
127
+ "num_iterations": 3,
128
+ "qkv_size": 64,
129
+ "mlp_size": 128,
130
+ "grid_encoder": ml_collections.ConfigDict({
131
+ "module": "invariant_slot_attention.modules.MLP",
132
+ "hidden_size": 128,
133
+ "layernorm": "pre"
134
+ }),
135
+ "add_rel_pos_to_values": True, # V3
136
+ "zero_position_init": False, # Random positions.
137
+ }),
138
+
139
+ # Predictor.
140
+ # Removed since we are running a single frame.
141
+ "predictor": ml_collections.ConfigDict({
142
+ "module": "invariant_slot_attention.modules.Identity"
143
+ }),
144
+
145
+ # Initializer.
146
+ "initializer": ml_collections.ConfigDict({
147
+ "module":
148
+ "invariant_slot_attention.modules.ParamStateInitRandomPositions",
149
+ "shape":
150
+ (11, 64), # (num_slots, slot_size)
151
+ }),
152
+
153
+ # Decoder.
154
+ "decoder": ml_collections.ConfigDict({
155
+ "module":
156
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
157
+ "resolution": (16, 16), # Update if data resolution or strides change
158
+ "backbone": ml_collections.ConfigDict({
159
+ "module": "invariant_slot_attention.modules.CNN",
160
+ "features": [64, 64, 64, 64, 64],
161
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
162
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
163
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
164
+ "layer_transpose": [True, True, True, False, False]
165
+ }),
166
+ "target_readout": ml_collections.ConfigDict({
167
+ "module": "invariant_slot_attention.modules.Readout",
168
+ "keys": list(targets),
169
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
170
+ "module": "invariant_slot_attention.modules.MLP",
171
+ "num_hidden_layers": 0,
172
+ "hidden_size": 0,
173
+ "output_size": targets[k]}) for k in targets],
174
+ }),
175
+ "relative_positions": True,
176
+ "pos_emb": ml_collections.ConfigDict({
177
+ "module":
178
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
179
+ "embedding_type":
180
+ "linear",
181
+ "update_type":
182
+ "project_add",
183
+ }),
184
+ }),
185
+ "decode_corrected": True,
186
+ "decode_predicted": False,
187
+ })
188
+
189
+ # Which video-shaped variables to visualize.
190
+ config.debug_var_video_paths = {
191
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
192
+ }
193
+
194
+ # Define which attention matrices to log/visualize.
195
+ config.debug_var_attn_paths = {
196
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
197
+ }
198
+
199
+ # Widths of attention matrices (for reshaping to image grid).
200
+ config.debug_var_attn_widths = {
201
+ "corrector_attn": 16,
202
+ }
203
+
204
+ return config
205
+
206
+
invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_rot_scale.py ADDED
@@ -0,0 +1,213 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.ResNet34",
112
+ "num_classes": None,
113
+ "axis_name": "time",
114
+ "norm_type": "group",
115
+ "small_inputs": True
116
+ }),
117
+ "pos_emb": ml_collections.ConfigDict({
118
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
119
+ "embedding_type": "linear",
120
+ "update_type": "project_add",
121
+ "output_transform": ml_collections.ConfigDict({
122
+ "module": "invariant_slot_attention.modules.MLP",
123
+ "hidden_size": 128,
124
+ "layernorm": "pre"
125
+ }),
126
+ }),
127
+ }),
128
+
129
+ # Corrector.
130
+ "corrector": ml_collections.ConfigDict({
131
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long
132
+ "num_iterations": 3,
133
+ "qkv_size": 64,
134
+ "mlp_size": 128,
135
+ "grid_encoder": ml_collections.ConfigDict({
136
+ "module": "invariant_slot_attention.modules.MLP",
137
+ "hidden_size": 128,
138
+ "layernorm": "pre"
139
+ }),
140
+ "add_rel_pos_to_values": True, # V3
141
+ "zero_position_init": False, # Random positions.
142
+ "init_with_fixed_scale": None, # Random scales.
143
+ "scales_factor": 5.0,
144
+ }),
145
+
146
+ # Predictor.
147
+ # Removed since we are running a single frame.
148
+ "predictor": ml_collections.ConfigDict({
149
+ "module": "invariant_slot_attention.modules.Identity"
150
+ }),
151
+
152
+ # Initializer.
153
+ "initializer": ml_collections.ConfigDict({
154
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long
155
+ "shape": (11, 64), # (num_slots, slot_size)
156
+ }),
157
+
158
+ # Decoder.
159
+ "decoder": ml_collections.ConfigDict({
160
+ "module":
161
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
162
+ "resolution": (16, 16), # Update if data resolution or strides change
163
+ "backbone": ml_collections.ConfigDict({
164
+ "module": "invariant_slot_attention.modules.CNN",
165
+ "features": [64, 64, 64, 64, 64],
166
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
167
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
168
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
169
+ "layer_transpose": [True, True, True, False, False]
170
+ }),
171
+ "target_readout": ml_collections.ConfigDict({
172
+ "module": "invariant_slot_attention.modules.Readout",
173
+ "keys": list(targets),
174
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
175
+ "module": "invariant_slot_attention.modules.MLP",
176
+ "num_hidden_layers": 0,
177
+ "hidden_size": 0,
178
+ "output_size": targets[k]}) for k in targets],
179
+ }),
180
+ "relative_positions_rotations_and_scales": True,
181
+ "pos_emb": ml_collections.ConfigDict({
182
+ "module":
183
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
184
+ "embedding_type":
185
+ "linear",
186
+ "update_type":
187
+ "project_add",
188
+ "scales_factor":
189
+ 5.0,
190
+ }),
191
+ }),
192
+ "decode_corrected": True,
193
+ "decode_predicted": False,
194
+ })
195
+
196
+ # Which video-shaped variables to visualize.
197
+ config.debug_var_video_paths = {
198
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
199
+ }
200
+
201
+ # Define which attention matrices to log/visualize.
202
+ config.debug_var_attn_paths = {
203
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
204
+ }
205
+
206
+ # Widths of attention matrices (for reshaping to image grid).
207
+ config.debug_var_attn_widths = {
208
+ "corrector_attn": 16,
209
+ }
210
+
211
+ return config
212
+
213
+
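Note: each of these config files exposes a standalone get_config(), so a config can be imported, inspected, and overridden like any ml_collections.ConfigDict before training. A minimal sketch, assuming the package is importable as laid out in this commit (the override values are arbitrary examples):

    from invariant_slot_attention.configs.clevrtex.resnet import equiv_transl_rot_scale

    config = equiv_transl_rot_scale.get_config()
    print(config.model.corrector.module)        # ...SlotAttentionTranslRotScaleEquiv
    config.batch_size = 32                      # override before handing off to the trainer
    config.model.corrector.num_iterations = 5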
invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_scale.py ADDED
@@ -0,0 +1,213 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.ResNet34",
112
+ "num_classes": None,
113
+ "axis_name": "time",
114
+ "norm_type": "group",
115
+ "small_inputs": True
116
+ }),
117
+ "pos_emb": ml_collections.ConfigDict({
118
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
119
+ "embedding_type": "linear",
120
+ "update_type": "project_add",
121
+ "output_transform": ml_collections.ConfigDict({
122
+ "module": "invariant_slot_attention.modules.MLP",
123
+ "hidden_size": 128,
124
+ "layernorm": "pre"
125
+ }),
126
+ }),
127
+ }),
128
+
129
+ # Corrector.
130
+ "corrector": ml_collections.ConfigDict({
131
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long
132
+ "num_iterations": 3,
133
+ "qkv_size": 64,
134
+ "mlp_size": 128,
135
+ "grid_encoder": ml_collections.ConfigDict({
136
+ "module": "invariant_slot_attention.modules.MLP",
137
+ "hidden_size": 128,
138
+ "layernorm": "pre"
139
+ }),
140
+ "add_rel_pos_to_values": True, # V3
141
+ "zero_position_init": False, # Random positions.
142
+ "init_with_fixed_scale": None, # Random scales.
143
+ "scales_factor": 5.0,
144
+ }),
145
+
146
+ # Predictor.
147
+ # Removed since we are running a single frame.
148
+ "predictor": ml_collections.ConfigDict({
149
+ "module": "invariant_slot_attention.modules.Identity"
150
+ }),
151
+
152
+ # Initializer.
153
+ "initializer": ml_collections.ConfigDict({
154
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long
155
+ "shape": (11, 64), # (num_slots, slot_size)
156
+ }),
157
+
158
+ # Decoder.
159
+ "decoder": ml_collections.ConfigDict({
160
+ "module":
161
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
162
+ "resolution": (16, 16), # Update if data resolution or strides change
163
+ "backbone": ml_collections.ConfigDict({
164
+ "module": "invariant_slot_attention.modules.CNN",
165
+ "features": [64, 64, 64, 64, 64],
166
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
167
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
168
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
169
+ "layer_transpose": [True, True, True, False, False]
170
+ }),
171
+ "target_readout": ml_collections.ConfigDict({
172
+ "module": "invariant_slot_attention.modules.Readout",
173
+ "keys": list(targets),
174
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
175
+ "module": "invariant_slot_attention.modules.MLP",
176
+ "num_hidden_layers": 0,
177
+ "hidden_size": 0,
178
+ "output_size": targets[k]}) for k in targets],
179
+ }),
180
+ "relative_positions_and_scales": True,
181
+ "pos_emb": ml_collections.ConfigDict({
182
+ "module":
183
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
184
+ "embedding_type":
185
+ "linear",
186
+ "update_type":
187
+ "project_add",
188
+ "scales_factor":
189
+ 5.0,
190
+ }),
191
+ }),
192
+ "decode_corrected": True,
193
+ "decode_predicted": False,
194
+ })
195
+
196
+ # Which video-shaped variables to visualize.
197
+ config.debug_var_video_paths = {
198
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
199
+ }
200
+
201
+ # Define which attention matrices to log/visualize.
202
+ config.debug_var_attn_paths = {
203
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
204
+ }
205
+
206
+ # Widths of attention matrices (for reshaping to image grid).
207
+ config.debug_var_attn_widths = {
208
+ "corrector_attn": 16,
209
+ }
210
+
211
+ return config
212
+
213
+
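Note: the line config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps") stores a lazy field reference rather than a copy, so overriding num_train_steps later also updates the schedule length. A small self-contained illustration of that ml_collections behaviour:

    import ml_collections

    config = ml_collections.ConfigDict()
    config.num_train_steps = 500_000
    config.lr_configs = ml_collections.ConfigDict()
    config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")

    config.num_train_steps = 100_000
    print(config.lr_configs.steps_per_cycle)  # 100000 -- the reference tracks the override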
invariant_slot_attention/configs/clevrtex/simplecnn/baseline.py ADDED
@@ -0,0 +1,197 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.SimpleCNN",
112
+ "features": [64, 64, 64, 64],
113
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
114
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
115
+ }),
116
+ "pos_emb": ml_collections.ConfigDict({
117
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
118
+ "embedding_type": "linear",
119
+ "update_type": "project_add",
120
+ "output_transform": ml_collections.ConfigDict({
121
+ "module": "invariant_slot_attention.modules.MLP",
122
+ "hidden_size": 128,
123
+ "layernorm": "pre"
124
+ }),
125
+ }),
126
+ }),
127
+
128
+ # Corrector.
129
+ "corrector": ml_collections.ConfigDict({
130
+ "module": "invariant_slot_attention.modules.SlotAttention",
131
+ "num_iterations": 3,
132
+ "qkv_size": 64,
133
+ "mlp_size": 128,
134
+ }),
135
+
136
+ # Predictor.
137
+ # Removed since we are running a single frame.
138
+ "predictor": ml_collections.ConfigDict({
139
+ "module": "invariant_slot_attention.modules.Identity"
140
+ }),
141
+
142
+ # Initializer.
143
+ "initializer": ml_collections.ConfigDict({
144
+ "module": "invariant_slot_attention.modules.ParamStateInit",
145
+ "shape": (11, 64), # (num_slots, slot_size)
146
+ }),
147
+
148
+ # Decoder.
149
+ "decoder": ml_collections.ConfigDict({
150
+ "module":
151
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
152
+ "resolution": (16, 16), # Update if data resolution or strides change
153
+ "backbone": ml_collections.ConfigDict({
154
+ "module": "invariant_slot_attention.modules.CNN",
155
+ "features": [64, 64, 64, 64, 64],
156
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
157
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
158
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
159
+ "layer_transpose": [True, True, True, False, False]
160
+ }),
161
+ "target_readout": ml_collections.ConfigDict({
162
+ "module": "invariant_slot_attention.modules.Readout",
163
+ "keys": list(targets),
164
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
165
+ "module": "invariant_slot_attention.modules.MLP",
166
+ "num_hidden_layers": 0,
167
+ "hidden_size": 0,
168
+ "output_size": targets[k]}) for k in targets],
169
+ }),
170
+ "pos_emb": ml_collections.ConfigDict({
171
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
172
+ "embedding_type": "linear",
173
+ "update_type": "project_add"
174
+ }),
175
+ }),
176
+ "decode_corrected": True,
177
+ "decode_predicted": False,
178
+ })
179
+
180
+ # Which video-shaped variables to visualize.
181
+ config.debug_var_video_paths = {
182
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
183
+ }
184
+
185
+ # Define which attention matrices to log/visualize.
186
+ config.debug_var_attn_paths = {
187
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
188
+ }
189
+
190
+ # Widths of attention matrices (for reshaping to image grid).
191
+ config.debug_var_attn_widths = {
192
+ "corrector_attn": 16,
193
+ }
194
+
195
+ return config
196
+
197
+
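Note: the decoder's "resolution": (16, 16) is coupled to the encoder: a 128x128 input run through the three stride-2 layers above comes out as a 16x16 feature grid, and the decoder broadcasts each slot onto that grid before its three transposed stride-2 layers upsample back to 128x128. A hypothetical helper (not part of the repo) that checks the arithmetic:

    def spatial_size(input_size, strides):
        """Side length after a stack of 'SAME'-padded strided convolutions."""
        size = input_size
        for s in strides:
            size = -(-size // s)  # ceil division
        return size

    assert spatial_size(128, [2, 2, 2, 1]) == 16   # encoder strides: 128 -> 16
    assert 16 * 2 * 2 * 2 == 128                   # decoder's transposed convs: 16 -> 128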
invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl.py ADDED
@@ -0,0 +1,205 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.SimpleCNN",
112
+ "features": [64, 64, 64, 64],
113
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
114
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
115
+ }),
116
+ "pos_emb": ml_collections.ConfigDict({
117
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
118
+ "embedding_type": "linear",
119
+ "update_type": "concat"
120
+ }),
121
+ }),
122
+
123
+ # Corrector.
124
+ "corrector": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
126
+ "num_iterations": 3,
127
+ "qkv_size": 64,
128
+ "mlp_size": 128,
129
+ "grid_encoder": ml_collections.ConfigDict({
130
+ "module": "invariant_slot_attention.modules.MLP",
131
+ "hidden_size": 128,
132
+ "layernorm": "pre"
133
+ }),
134
+ "add_rel_pos_to_values": True, # V3
135
+ "zero_position_init": False, # Random positions.
136
+ }),
137
+
138
+ # Predictor.
139
+ # Removed since we are running a single frame.
140
+ "predictor": ml_collections.ConfigDict({
141
+ "module": "invariant_slot_attention.modules.Identity"
142
+ }),
143
+
144
+ # Initializer.
145
+ "initializer": ml_collections.ConfigDict({
146
+ "module":
147
+ "invariant_slot_attention.modules.ParamStateInitRandomPositions",
148
+ "shape":
149
+ (11, 64), # (num_slots, slot_size)
150
+ }),
151
+
152
+ # Decoder.
153
+ "decoder": ml_collections.ConfigDict({
154
+ "module":
155
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
156
+ "resolution": (16, 16), # Update if data resolution or strides change
157
+ "backbone": ml_collections.ConfigDict({
158
+ "module": "invariant_slot_attention.modules.CNN",
159
+ "features": [64, 64, 64, 64, 64],
160
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
161
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
162
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
163
+ "layer_transpose": [True, True, True, False, False]
164
+ }),
165
+ "target_readout": ml_collections.ConfigDict({
166
+ "module": "invariant_slot_attention.modules.Readout",
167
+ "keys": list(targets),
168
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
169
+ "module": "invariant_slot_attention.modules.MLP",
170
+ "num_hidden_layers": 0,
171
+ "hidden_size": 0,
172
+ "output_size": targets[k]}) for k in targets],
173
+ }),
174
+ "relative_positions": True,
175
+ "pos_emb": ml_collections.ConfigDict({
176
+ "module":
177
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
178
+ "embedding_type":
179
+ "linear",
180
+ "update_type":
181
+ "project_add",
182
+ }),
183
+ }),
184
+ "decode_corrected": True,
185
+ "decode_predicted": False,
186
+ })
187
+
188
+ # Which video-shaped variables to visualize.
189
+ config.debug_var_video_paths = {
190
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
191
+ }
192
+
193
+ # Define which attention matrices to log/visualize.
194
+ config.debug_var_attn_paths = {
195
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
196
+ }
197
+
198
+ # Widths of attention matrices (for reshaping to image grid).
199
+ config.debug_var_attn_widths = {
200
+ "corrector_attn": 16,
201
+ }
202
+
203
+ return config
204
+
205
+
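Note: of the two consecutive assignments to config.losses above, only the second survives; the comprehension expands the targets dictionary into one reconstruction loss per target key. For targets = {"video": 3} the result is simply the following (a plain-Python check, with a dict standing in for the ConfigDict):

    targets = {"video": 3}
    losses = {f"recon_{target}": {"loss_type": "recon", "key": target}
              for target in targets}
    print(losses)  # {'recon_video': {'loss_type': 'recon', 'key': 'video'}}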
invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_rot_scale.py ADDED
@@ -0,0 +1,207 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.SimpleCNN",
112
+ "features": [64, 64, 64, 64],
113
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
114
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
115
+ }),
116
+ "pos_emb": ml_collections.ConfigDict({
117
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
118
+ "embedding_type": "linear",
119
+ "update_type": "concat"
120
+ }),
121
+ }),
122
+
123
+ # Corrector.
124
+ "corrector": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long
126
+ "num_iterations": 3,
127
+ "qkv_size": 64,
128
+ "mlp_size": 128,
129
+ "grid_encoder": ml_collections.ConfigDict({
130
+ "module": "invariant_slot_attention.modules.MLP",
131
+ "hidden_size": 128,
132
+ "layernorm": "pre"
133
+ }),
134
+ "add_rel_pos_to_values": True, # V3
135
+ "zero_position_init": False, # Random positions.
136
+ "init_with_fixed_scale": None, # Random scales.
137
+ "scales_factor": 5.0,
138
+ }),
139
+
140
+ # Predictor.
141
+ # Removed since we are running a single frame.
142
+ "predictor": ml_collections.ConfigDict({
143
+ "module": "invariant_slot_attention.modules.Identity"
144
+ }),
145
+
146
+ # Initializer.
147
+ "initializer": ml_collections.ConfigDict({
148
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long
149
+ "shape": (11, 64), # (num_slots, slot_size)
150
+ }),
151
+
152
+ # Decoder.
153
+ "decoder": ml_collections.ConfigDict({
154
+ "module":
155
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
156
+ "resolution": (16, 16), # Update if data resolution or strides change
157
+ "backbone": ml_collections.ConfigDict({
158
+ "module": "invariant_slot_attention.modules.CNN",
159
+ "features": [64, 64, 64, 64, 64],
160
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
161
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
162
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
163
+ "layer_transpose": [True, True, True, False, False]
164
+ }),
165
+ "target_readout": ml_collections.ConfigDict({
166
+ "module": "invariant_slot_attention.modules.Readout",
167
+ "keys": list(targets),
168
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
169
+ "module": "invariant_slot_attention.modules.MLP",
170
+ "num_hidden_layers": 0,
171
+ "hidden_size": 0,
172
+ "output_size": targets[k]}) for k in targets],
173
+ }),
174
+ "relative_positions_rotations_and_scales": True,
175
+ "pos_emb": ml_collections.ConfigDict({
176
+ "module":
177
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
178
+ "embedding_type":
179
+ "linear",
180
+ "update_type":
181
+ "project_add",
182
+ "scales_factor":
183
+ 5.0,
184
+ }),
185
+ }),
186
+ "decode_corrected": True,
187
+ "decode_predicted": False,
188
+ })
189
+
190
+ # Which video-shaped variables to visualize.
191
+ config.debug_var_video_paths = {
192
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
193
+ }
194
+
195
+ # Define which attention matrices to log/visualize.
196
+ config.debug_var_attn_paths = {
197
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
198
+ }
199
+
200
+ # Widths of attention matrices (for reshaping to image grid).
201
+ config.debug_var_attn_widths = {
202
+ "corrector_attn": 16,
203
+ }
204
+
205
+ return config
206
+
207
+
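Note: the preproc_train and preproc_eval entries are plain strings parsed by the input pipeline; the only dynamic piece is the resize_small size, filled in from the configured resolution. A quick check of what that string evaluates to for the 128x128 setting used here:

    resolution = (128, 128)
    op = "resize_small({size})".format(size=min(*resolution))
    print(op)  # resize_small(128)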
invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_scale.py ADDED
@@ -0,0 +1,207 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on CLEVRTex."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "tfds",
67
+ # The TFDS dataset will be created in the directory below
68
+ # if you follow the README in datasets/clevrtex.
69
+ "data_dir": "~/tensorflow_datasets",
70
+ "tfds_name": "clevr_tex",
71
+ "shuffle_buffer_size": config.batch_size * 8,
72
+ "resolution": (128, 128)
73
+ })
74
+
75
+ config.max_instances = 11
76
+ config.num_slots = config.max_instances # Only used for metrics.
77
+ config.logging_min_n_colors = config.max_instances
78
+
79
+ config.preproc_train = [
80
+ "tfds_image_to_tfds_video",
81
+ "video_from_tfds",
82
+ "central_crop(height=192,width=192)",
83
+ "resize_small({size})".format(size=min(*config.data.resolution))
84
+ ]
85
+
86
+ config.preproc_eval = [
87
+ "tfds_image_to_tfds_video",
88
+ "video_from_tfds",
89
+ "central_crop(height=192,width=192)",
90
+ "resize_small({size})".format(size=min(*config.data.resolution))
91
+ ]
92
+
93
+ config.eval_slice_size = 1
94
+ config.eval_slice_keys = ["video", "segmentations_video"]
95
+
96
+ # Dictionary of targets and corresponding channels. Losses need to match.
97
+ targets = {"video": 3}
98
+ config.losses = {"recon": {"targets": list(targets)}}
99
+ config.losses = ml_collections.ConfigDict({
100
+ f"recon_{target}": {"loss_type": "recon", "key": target}
101
+ for target in targets})
102
+
103
+ config.model = ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.SAVi",
105
+
106
+ # Encoder.
107
+ "encoder": ml_collections.ConfigDict({
108
+ "module": "invariant_slot_attention.modules.FrameEncoder",
109
+ "reduction": "spatial_flatten",
110
+ "backbone": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.SimpleCNN",
112
+ "features": [64, 64, 64, 64],
113
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
114
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
115
+ }),
116
+ "pos_emb": ml_collections.ConfigDict({
117
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
118
+ "embedding_type": "linear",
119
+ "update_type": "concat"
120
+ }),
121
+ }),
122
+
123
+ # Corrector.
124
+ "corrector": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long
126
+ "num_iterations": 3,
127
+ "qkv_size": 64,
128
+ "mlp_size": 128,
129
+ "grid_encoder": ml_collections.ConfigDict({
130
+ "module": "invariant_slot_attention.modules.MLP",
131
+ "hidden_size": 128,
132
+ "layernorm": "pre"
133
+ }),
134
+ "add_rel_pos_to_values": True, # V3
135
+ "zero_position_init": False, # Random positions.
136
+ "init_with_fixed_scale": None, # Random scales.
137
+ "scales_factor": 5.0,
138
+ }),
139
+
140
+ # Predictor.
141
+ # Removed since we are running a single frame.
142
+ "predictor": ml_collections.ConfigDict({
143
+ "module": "invariant_slot_attention.modules.Identity"
144
+ }),
145
+
146
+ # Initializer.
147
+ "initializer": ml_collections.ConfigDict({
148
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long
149
+ "shape": (11, 64), # (num_slots, slot_size)
150
+ }),
151
+
152
+ # Decoder.
153
+ "decoder": ml_collections.ConfigDict({
154
+ "module":
155
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
156
+ "resolution": (16, 16), # Update if data resolution or strides change
157
+ "backbone": ml_collections.ConfigDict({
158
+ "module": "invariant_slot_attention.modules.CNN",
159
+ "features": [64, 64, 64, 64, 64],
160
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
161
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
162
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
163
+ "layer_transpose": [True, True, True, False, False]
164
+ }),
165
+ "target_readout": ml_collections.ConfigDict({
166
+ "module": "invariant_slot_attention.modules.Readout",
167
+ "keys": list(targets),
168
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
169
+ "module": "invariant_slot_attention.modules.MLP",
170
+ "num_hidden_layers": 0,
171
+ "hidden_size": 0,
172
+ "output_size": targets[k]}) for k in targets],
173
+ }),
174
+ "relative_positions_and_scales": True,
175
+ "pos_emb": ml_collections.ConfigDict({
176
+ "module":
177
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
178
+ "embedding_type":
179
+ "linear",
180
+ "update_type":
181
+ "project_add",
182
+ "scales_factor":
183
+ 5.0,
184
+ }),
185
+ }),
186
+ "decode_corrected": True,
187
+ "decode_predicted": False,
188
+ })
189
+
190
+ # Which video-shaped variables to visualize.
191
+ config.debug_var_video_paths = {
192
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
193
+ }
194
+
195
+ # Define which attention matrices to log/visualize.
196
+ config.debug_var_attn_paths = {
197
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
198
+ }
199
+
200
+ # Widths of attention matrices (for reshaping to image grid).
201
+ config.debug_var_attn_widths = {
202
+ "corrector_attn": 16,
203
+ }
204
+
205
+ return config
206
+
207
+
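Note: the optimizer_configs block asks for Adam with gradients clipped by global norm at 0.05; the trainer assembles the actual optimizer from these fields. In optax terms it corresponds roughly to the chain below (an illustrative sketch, not the repo's exact construction; in training the constant rate would be replaced by the learning-rate schedule):

    import optax

    optimizer = optax.chain(
        optax.clip_by_global_norm(0.05),  # grad_clip.clip_value
        optax.adam(4e-4),                 # optimizer_configs.optimizer = "adam"
    )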
invariant_slot_attention/configs/multishapenet_easy/baseline.py ADDED
@@ -0,0 +1,195 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on MultiShapeNet-Easy."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "multishapenet_easy",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 128)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "sunds_to_tfds_video",
77
+ "video_from_tfds",
78
+ "subtract_one_from_segmentations",
79
+ "central_crop(height=240, width=240)",
80
+ "resize_small({size})".format(size=min(*config.data.resolution))
81
+ ]
82
+
83
+ config.preproc_eval = [
84
+ "sunds_to_tfds_video",
85
+ "video_from_tfds",
86
+ "subtract_one_from_segmentations",
87
+ "central_crop(height=240, width=240)",
88
+ "resize_small({size})".format(size=min(*config.data.resolution))
89
+ ]
90
+
91
+ config.eval_slice_size = 1
92
+ config.eval_slice_keys = ["video", "segmentations_video"]
93
+
94
+ # Dictionary of targets and corresponding channels. Losses need to match.
95
+ targets = {"video": 3}
96
+ config.losses = {"recon": {"targets": list(targets)}}
97
+ config.losses = ml_collections.ConfigDict({
98
+ f"recon_{target}": {"loss_type": "recon", "key": target}
99
+ for target in targets})
100
+
101
+ config.model = ml_collections.ConfigDict({
102
+ "module": "invariant_slot_attention.modules.SAVi",
103
+
104
+ # Encoder.
105
+ "encoder": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.FrameEncoder",
107
+ "reduction": "spatial_flatten",
108
+ "backbone": ml_collections.ConfigDict({
109
+ "module": "invariant_slot_attention.modules.SimpleCNN",
110
+ "features": [64, 64, 64, 64],
111
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
112
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
113
+ }),
114
+ "pos_emb": ml_collections.ConfigDict({
115
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
116
+ "embedding_type": "linear",
117
+ "update_type": "project_add",
118
+ "output_transform": ml_collections.ConfigDict({
119
+ "module": "invariant_slot_attention.modules.MLP",
120
+ "hidden_size": 128,
121
+ "layernorm": "pre"
122
+ }),
123
+ }),
124
+ }),
125
+
126
+ # Corrector.
127
+ "corrector": ml_collections.ConfigDict({
128
+ "module": "invariant_slot_attention.modules.SlotAttention",
129
+ "num_iterations": 3,
130
+ "qkv_size": 64,
131
+ "mlp_size": 128,
132
+ }),
133
+
134
+ # Predictor.
135
+ # Removed since we are running a single frame.
136
+ "predictor": ml_collections.ConfigDict({
137
+ "module": "invariant_slot_attention.modules.Identity"
138
+ }),
139
+
140
+ # Initializer.
141
+ "initializer": ml_collections.ConfigDict({
142
+ "module": "invariant_slot_attention.modules.ParamStateInit",
143
+ "shape": (11, 64), # (num_slots, slot_size)
144
+ }),
145
+
146
+ # Decoder.
147
+ "decoder": ml_collections.ConfigDict({
148
+ "module":
149
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
150
+ "resolution": (16, 16), # Update if data resolution or strides change
151
+ "backbone": ml_collections.ConfigDict({
152
+ "module": "invariant_slot_attention.modules.CNN",
153
+ "features": [64, 64, 64, 64, 64],
154
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
155
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
156
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
157
+ "layer_transpose": [True, True, True, False, False]
158
+ }),
159
+ "target_readout": ml_collections.ConfigDict({
160
+ "module": "invariant_slot_attention.modules.Readout",
161
+ "keys": list(targets),
162
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
163
+ "module": "invariant_slot_attention.modules.MLP",
164
+ "num_hidden_layers": 0,
165
+ "hidden_size": 0,
166
+ "output_size": targets[k]}) for k in targets],
167
+ }),
168
+ "pos_emb": ml_collections.ConfigDict({
169
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
170
+ "embedding_type": "linear",
171
+ "update_type": "project_add"
172
+ }),
173
+ }),
174
+ "decode_corrected": True,
175
+ "decode_predicted": False,
176
+ })
177
+
178
+ # Which video-shaped variables to visualize.
179
+ config.debug_var_video_paths = {
180
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
181
+ }
182
+
183
+ # Define which attention matrices to log/visualize.
184
+ config.debug_var_attn_paths = {
185
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
186
+ }
187
+
188
+ # Widths of attention matrices (for reshaping to image grid).
189
+ config.debug_var_attn_widths = {
190
+ "corrector_attn": 16,
191
+ }
192
+
193
+ return config
194
+
195
+
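Note: lr_configs describes a compound schedule, "constant * cosine_decay * linear_warmup", with a 10k-step linear warmup and a cosine cycle spanning all 500k training steps. One common reading of those factors, written as a standalone function (the exact definition lives in the training library, so treat this purely as an illustration):

    import math

    def compound_lr(step, base_lr=4e-4, warmup_steps=10_000, steps_per_cycle=500_000):
        warmup = min(step / warmup_steps, 1.0)                # linear_warmup
        progress = min(step / steps_per_cycle, 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # cosine_decay
        return base_lr * warmup * cosine                      # constant = base_learning_rate

    print(compound_lr(0), compound_lr(10_000), compound_lr(500_000))  # 0.0, ~4e-4, 0.0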
invariant_slot_attention/configs/multishapenet_easy/equiv_transl.py ADDED
@@ -0,0 +1,203 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on MultiShapeNet-Easy."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "multishapenet_easy",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 128)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "sunds_to_tfds_video",
77
+ "video_from_tfds",
78
+ "subtract_one_from_segmentations",
79
+ "central_crop(height=240, width=240)",
80
+ "resize_small({size})".format(size=min(*config.data.resolution))
81
+ ]
82
+
83
+ config.preproc_eval = [
84
+ "sunds_to_tfds_video",
85
+ "video_from_tfds",
86
+ "subtract_one_from_segmentations",
87
+ "central_crop(height=240, width=240)",
88
+ "resize_small({size})".format(size=min(*config.data.resolution))
89
+ ]
90
+
91
+ config.eval_slice_size = 1
92
+ config.eval_slice_keys = ["video", "segmentations_video"]
93
+
94
+ # Dictionary of targets and corresponding channels. Losses need to match.
95
+ targets = {"video": 3}
96
+ config.losses = {"recon": {"targets": list(targets)}}
97
+ config.losses = ml_collections.ConfigDict({
98
+ f"recon_{target}": {"loss_type": "recon", "key": target}
99
+ for target in targets})
100
+
101
+ config.model = ml_collections.ConfigDict({
102
+ "module": "invariant_slot_attention.modules.SAVi",
103
+
104
+ # Encoder.
105
+ "encoder": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.FrameEncoder",
107
+ "reduction": "spatial_flatten",
108
+ "backbone": ml_collections.ConfigDict({
109
+ "module": "invariant_slot_attention.modules.SimpleCNN",
110
+ "features": [64, 64, 64, 64],
111
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
112
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
113
+ }),
114
+ "pos_emb": ml_collections.ConfigDict({
115
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
116
+ "embedding_type": "linear",
117
+ "update_type": "concat"
118
+ }),
119
+ }),
120
+
121
+ # Corrector.
122
+ "corrector": ml_collections.ConfigDict({
123
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
124
+ "num_iterations": 3,
125
+ "qkv_size": 64,
126
+ "mlp_size": 128,
127
+ "grid_encoder": ml_collections.ConfigDict({
128
+ "module": "invariant_slot_attention.modules.MLP",
129
+ "hidden_size": 128,
130
+ "layernorm": "pre"
131
+ }),
132
+ "add_rel_pos_to_values": True, # V3
133
+ "zero_position_init": False, # Random positions.
134
+ }),
135
+
136
+ # Predictor.
137
+ # Removed since we are running a single frame.
138
+ "predictor": ml_collections.ConfigDict({
139
+ "module": "invariant_slot_attention.modules.Identity"
140
+ }),
141
+
142
+ # Initializer.
143
+ "initializer": ml_collections.ConfigDict({
144
+ "module":
145
+ "invariant_slot_attention.modules.ParamStateInitRandomPositions",
146
+ "shape":
147
+ (11, 64), # (num_slots, slot_size)
148
+ }),
149
+
150
+ # Decoder.
151
+ "decoder": ml_collections.ConfigDict({
152
+ "module":
153
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
154
+ "resolution": (16, 16), # Update if data resolution or strides change
155
+ "backbone": ml_collections.ConfigDict({
156
+ "module": "invariant_slot_attention.modules.CNN",
157
+ "features": [64, 64, 64, 64, 64],
158
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
159
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
160
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
161
+ "layer_transpose": [True, True, True, False, False]
162
+ }),
163
+ "target_readout": ml_collections.ConfigDict({
164
+ "module": "invariant_slot_attention.modules.Readout",
165
+ "keys": list(targets),
166
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
167
+ "module": "invariant_slot_attention.modules.MLP",
168
+ "num_hidden_layers": 0,
169
+ "hidden_size": 0,
170
+ "output_size": targets[k]}) for k in targets],
171
+ }),
172
+ "relative_positions": True,
173
+ "pos_emb": ml_collections.ConfigDict({
174
+ "module":
175
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
176
+ "embedding_type":
177
+ "linear",
178
+ "update_type":
179
+ "project_add",
180
+ }),
181
+ }),
182
+ "decode_corrected": True,
183
+ "decode_predicted": False,
184
+ })
185
+
186
+ # Which video-shaped variables to visualize.
187
+ config.debug_var_video_paths = {
188
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
189
+ }
190
+
191
+ # Define which attention matrices to log/visualize.
192
+ config.debug_var_attn_paths = {
193
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
194
+ }
195
+
196
+ # Widths of attention matrices (for reshaping to image grid).
197
+ config.debug_var_attn_widths = {
198
+ "corrector_attn": 16,
199
+ }
200
+
201
+ return config
202
+
203
+
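The file above (by its position in this commit, the translation-equivariant MultiShapeNet-Easy config, equiv_transl.py) is a plain ml_collections factory, so it can be loaded and inspected outside the training pipeline. A minimal sketch, assuming the repository root is on PYTHONPATH and that the import path mirrors the file layout in this diff:

# Hypothetical usage sketch; the import path is inferred from the file layout above.
from invariant_slot_attention.configs.multishapenet_easy import equiv_transl

config = equiv_transl.get_config()
print(config.model.corrector.module)  # invariant_slot_attention.modules.SlotAttentionTranslEquiv
print(config.data.resolution)         # (128, 128)

# steps_per_cycle was assigned via config.get_ref, so it should track later edits
# to num_train_steps instead of freezing the original 500000.
config.num_train_steps = 250000
print(config.lr_configs.steps_per_cycle)  # 250000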
invariant_slot_attention/configs/multishapenet_easy/equiv_transl_rot_scale.py ADDED
@@ -0,0 +1,205 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on MultiShapeNet-Easy."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "multishapenet_easy",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 128)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "sunds_to_tfds_video",
77
+ "video_from_tfds",
78
+ "subtract_one_from_segmentations",
79
+ "central_crop(height=240, width=240)",
80
+ "resize_small({size})".format(size=min(*config.data.resolution))
81
+ ]
82
+
83
+ config.preproc_eval = [
84
+ "sunds_to_tfds_video",
85
+ "video_from_tfds",
86
+ "subtract_one_from_segmentations",
87
+ "central_crop(height=240, width=240)",
88
+ "resize_small({size})".format(size=min(*config.data.resolution))
89
+ ]
90
+
91
+ config.eval_slice_size = 1
92
+ config.eval_slice_keys = ["video", "segmentations_video"]
93
+
94
+ # Dictionary of targets and corresponding channels. Losses need to match.
95
+ targets = {"video": 3}
96
+ config.losses = {"recon": {"targets": list(targets)}}
97
+ config.losses = ml_collections.ConfigDict({
98
+ f"recon_{target}": {"loss_type": "recon", "key": target}
99
+ for target in targets})
100
+
101
+ config.model = ml_collections.ConfigDict({
102
+ "module": "invariant_slot_attention.modules.SAVi",
103
+
104
+ # Encoder.
105
+ "encoder": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.FrameEncoder",
107
+ "reduction": "spatial_flatten",
108
+ "backbone": ml_collections.ConfigDict({
109
+ "module": "invariant_slot_attention.modules.SimpleCNN",
110
+ "features": [64, 64, 64, 64],
111
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
112
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
113
+ }),
114
+ "pos_emb": ml_collections.ConfigDict({
115
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
116
+ "embedding_type": "linear",
117
+ "update_type": "concat"
118
+ }),
119
+ }),
120
+
121
+ # Corrector.
122
+ "corrector": ml_collections.ConfigDict({
123
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long
124
+ "num_iterations": 3,
125
+ "qkv_size": 64,
126
+ "mlp_size": 128,
127
+ "grid_encoder": ml_collections.ConfigDict({
128
+ "module": "invariant_slot_attention.modules.MLP",
129
+ "hidden_size": 128,
130
+ "layernorm": "pre"
131
+ }),
132
+ "add_rel_pos_to_values": True, # V3
133
+ "zero_position_init": False, # Random positions.
134
+ "init_with_fixed_scale": None, # Random scales.
135
+ "scales_factor": 5.0,
136
+ }),
137
+
138
+ # Predictor.
139
+ # Removed since we are running a single frame.
140
+ "predictor": ml_collections.ConfigDict({
141
+ "module": "invariant_slot_attention.modules.Identity"
142
+ }),
143
+
144
+ # Initializer.
145
+ "initializer": ml_collections.ConfigDict({
146
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long
147
+ "shape": (11, 64), # (num_slots, slot_size)
148
+ }),
149
+
150
+ # Decoder.
151
+ "decoder": ml_collections.ConfigDict({
152
+ "module":
153
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
154
+ "resolution": (16, 16), # Update if data resolution or strides change
155
+ "backbone": ml_collections.ConfigDict({
156
+ "module": "invariant_slot_attention.modules.CNN",
157
+ "features": [64, 64, 64, 64, 64],
158
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
159
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
160
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
161
+ "layer_transpose": [True, True, True, False, False]
162
+ }),
163
+ "target_readout": ml_collections.ConfigDict({
164
+ "module": "invariant_slot_attention.modules.Readout",
165
+ "keys": list(targets),
166
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
167
+ "module": "invariant_slot_attention.modules.MLP",
168
+ "num_hidden_layers": 0,
169
+ "hidden_size": 0,
170
+ "output_size": targets[k]}) for k in targets],
171
+ }),
172
+ "relative_positions_rotations_and_scales": True,
173
+ "pos_emb": ml_collections.ConfigDict({
174
+ "module":
175
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
176
+ "embedding_type":
177
+ "linear",
178
+ "update_type":
179
+ "project_add",
180
+ "scales_factor":
181
+ 5.0,
182
+ }),
183
+ }),
184
+ "decode_corrected": True,
185
+ "decode_predicted": False,
186
+ })
187
+
188
+ # Which video-shaped variables to visualize.
189
+ config.debug_var_video_paths = {
190
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
191
+ }
192
+
193
+ # Define which attention matrices to log/visualize.
194
+ config.debug_var_attn_paths = {
195
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
196
+ }
197
+
198
+ # Widths of attention matrices (for reshaping to image grid).
199
+ config.debug_var_attn_widths = {
200
+ "corrector_attn": 16,
201
+ }
202
+
203
+ return config
204
+
205
+
invariant_slot_attention/configs/multishapenet_easy/equiv_transl_scale.py ADDED
@@ -0,0 +1,205 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on MultiShapeNet-Easy."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "multishapenet_easy",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 128)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "sunds_to_tfds_video",
77
+ "video_from_tfds",
78
+ "subtract_one_from_segmentations",
79
+ "central_crop(height=240, width=240)",
80
+ "resize_small({size})".format(size=min(*config.data.resolution))
81
+ ]
82
+
83
+ config.preproc_eval = [
84
+ "sunds_to_tfds_video",
85
+ "video_from_tfds",
86
+ "subtract_one_from_segmentations",
87
+ "central_crop(height=240, width=240)",
88
+ "resize_small({size})".format(size=min(*config.data.resolution))
89
+ ]
90
+
91
+ config.eval_slice_size = 1
92
+ config.eval_slice_keys = ["video", "segmentations_video"]
93
+
94
+ # Dictionary of targets and corresponding channels. Losses need to match.
95
+ targets = {"video": 3}
96
+ config.losses = {"recon": {"targets": list(targets)}}
97
+ config.losses = ml_collections.ConfigDict({
98
+ f"recon_{target}": {"loss_type": "recon", "key": target}
99
+ for target in targets})
100
+
101
+ config.model = ml_collections.ConfigDict({
102
+ "module": "invariant_slot_attention.modules.SAVi",
103
+
104
+ # Encoder.
105
+ "encoder": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.FrameEncoder",
107
+ "reduction": "spatial_flatten",
108
+ "backbone": ml_collections.ConfigDict({
109
+ "module": "invariant_slot_attention.modules.SimpleCNN",
110
+ "features": [64, 64, 64, 64],
111
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
112
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
113
+ }),
114
+ "pos_emb": ml_collections.ConfigDict({
115
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
116
+ "embedding_type": "linear",
117
+ "update_type": "concat"
118
+ }),
119
+ }),
120
+
121
+ # Corrector.
122
+ "corrector": ml_collections.ConfigDict({
123
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long
124
+ "num_iterations": 3,
125
+ "qkv_size": 64,
126
+ "mlp_size": 128,
127
+ "grid_encoder": ml_collections.ConfigDict({
128
+ "module": "invariant_slot_attention.modules.MLP",
129
+ "hidden_size": 128,
130
+ "layernorm": "pre"
131
+ }),
132
+ "add_rel_pos_to_values": True, # V3
133
+ "zero_position_init": False, # Random positions.
134
+ "init_with_fixed_scale": None, # Random scales.
135
+ "scales_factor": 5.0,
136
+ }),
137
+
138
+ # Predictor.
139
+ # Removed since we are running a single frame.
140
+ "predictor": ml_collections.ConfigDict({
141
+ "module": "invariant_slot_attention.modules.Identity"
142
+ }),
143
+
144
+ # Initializer.
145
+ "initializer": ml_collections.ConfigDict({
146
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long
147
+ "shape": (11, 64), # (num_slots, slot_size)
148
+ }),
149
+
150
+ # Decoder.
151
+ "decoder": ml_collections.ConfigDict({
152
+ "module":
153
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
154
+ "resolution": (16, 16), # Update if data resolution or strides change
155
+ "backbone": ml_collections.ConfigDict({
156
+ "module": "invariant_slot_attention.modules.CNN",
157
+ "features": [64, 64, 64, 64, 64],
158
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
159
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
160
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
161
+ "layer_transpose": [True, True, True, False, False]
162
+ }),
163
+ "target_readout": ml_collections.ConfigDict({
164
+ "module": "invariant_slot_attention.modules.Readout",
165
+ "keys": list(targets),
166
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
167
+ "module": "invariant_slot_attention.modules.MLP",
168
+ "num_hidden_layers": 0,
169
+ "hidden_size": 0,
170
+ "output_size": targets[k]}) for k in targets],
171
+ }),
172
+ "relative_positions_and_scales": True,
173
+ "pos_emb": ml_collections.ConfigDict({
174
+ "module":
175
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
176
+ "embedding_type":
177
+ "linear",
178
+ "update_type":
179
+ "project_add",
180
+ "scales_factor":
181
+ 5.0,
182
+ }),
183
+ }),
184
+ "decode_corrected": True,
185
+ "decode_predicted": False,
186
+ })
187
+
188
+ # Which video-shaped variables to visualize.
189
+ config.debug_var_video_paths = {
190
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
191
+ }
192
+
193
+ # Define which attention matrices to log/visualize.
194
+ config.debug_var_attn_paths = {
195
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
196
+ }
197
+
198
+ # Widths of attention matrices (for reshaping to image grid).
199
+ config.debug_var_attn_widths = {
200
+ "corrector_attn": 16,
201
+ }
202
+
203
+ return config
204
+
205
+
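The three MultiShapeNet-Easy variants added above share the optimizer, data pipeline, encoder, and decoder backbone verbatim; they differ only in the corrector module, the slot initializer, the decoder's relative-position flag, and the extra scale fields (init_with_fixed_scale, scales_factor) carried by the scale-aware variants. A condensed summary, with module names copied from the configs and the grouping being mine:

# Summary sketch of the fields that distinguish the three equivariant configs above.
MSN_EASY_VARIANTS = {
    "equiv_transl": dict(
        corrector="SlotAttentionTranslEquiv",
        initializer="ParamStateInitRandomPositions",
        decoder_flag="relative_positions"),
    "equiv_transl_scale": dict(
        corrector="SlotAttentionTranslScaleEquiv",
        initializer="ParamStateInitRandomPositionsScales",
        decoder_flag="relative_positions_and_scales"),
    "equiv_transl_rot_scale": dict(
        corrector="SlotAttentionTranslRotScaleEquiv",
        initializer="ParamStateInitRandomPositionsRotationsScales",
        decoder_flag="relative_positions_rotations_and_scales"),
}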
invariant_slot_attention/configs/objects_room/baseline.py ADDED
@@ -0,0 +1,192 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on objects_room."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ # TODO(obvis): Implement masked evaluation.
50
+ config.eval_pad_last_batch = False # True
51
+ config.log_loss_every_steps = 50
52
+ config.eval_every_steps = 5000
53
+ config.checkpoint_every_steps = 5000
54
+
55
+ config.train_metrics_spec = {
56
+ "loss": "loss",
57
+ "ari": "ari",
58
+ "ari_nobg": "ari_nobg",
59
+ }
60
+ config.eval_metrics_spec = {
61
+ "eval_loss": "loss",
62
+ "eval_ari": "ari",
63
+ "eval_ari_nobg": "ari_nobg",
64
+ }
65
+
66
+ config.data = ml_collections.ConfigDict({
67
+ "dataset_name": "objects_room",
68
+ "shuffle_buffer_size": config.batch_size * 8,
69
+ "resolution": (64, 64)
70
+ })
71
+
72
+ config.max_instances = 11
73
+ config.num_slots = config.max_instances # Only used for metrics.
74
+ config.logging_min_n_colors = config.max_instances
75
+
76
+ config.preproc_train = [
77
+ "tfds_image_to_tfds_video",
78
+ "video_from_tfds",
79
+ "sparse_to_dense_annotation(max_instances=10)",
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "sparse_to_dense_annotation(max_instances=10)",
86
+ ]
87
+
88
+ config.eval_slice_size = 1
89
+ config.eval_slice_keys = ["video", "segmentations_video"]
90
+
91
+ # Dictionary of targets and corresponding channels. Losses need to match.
92
+ targets = {"video": 3}
93
+ config.losses = {"recon": {"targets": list(targets)}}
94
+ config.losses = ml_collections.ConfigDict({
95
+ f"recon_{target}": {"loss_type": "recon", "key": target}
96
+ for target in targets})
97
+
98
+ config.model = ml_collections.ConfigDict({
99
+ "module": "invariant_slot_attention.modules.SAVi",
100
+
101
+ # Encoder.
102
+ "encoder": ml_collections.ConfigDict({
103
+ "module": "invariant_slot_attention.modules.FrameEncoder",
104
+ "reduction": "spatial_flatten",
105
+ "backbone": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.SimpleCNN",
107
+ "features": [64, 64, 64, 64],
108
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
109
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
110
+ }),
111
+ "pos_emb": ml_collections.ConfigDict({
112
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
113
+ "embedding_type": "linear",
114
+ "update_type": "project_add",
115
+ "output_transform": ml_collections.ConfigDict({
116
+ "module": "invariant_slot_attention.modules.MLP",
117
+ "hidden_size": 128,
118
+ "layernorm": "pre"
119
+ }),
120
+ }),
121
+ }),
122
+
123
+ # Corrector.
124
+ "corrector": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.SlotAttention",
126
+ "num_iterations": 3,
127
+ "qkv_size": 64,
128
+ "mlp_size": 128,
129
+ }),
130
+
131
+ # Predictor.
132
+ # Removed since we are running a single frame.
133
+ "predictor": ml_collections.ConfigDict({
134
+ "module": "invariant_slot_attention.modules.Identity"
135
+ }),
136
+
137
+ # Initializer.
138
+ "initializer": ml_collections.ConfigDict({
139
+ "module": "invariant_slot_attention.modules.ParamStateInit",
140
+ "shape": (11, 64), # (num_slots, slot_size)
141
+ }),
142
+
143
+ # Decoder.
144
+ "decoder": ml_collections.ConfigDict({
145
+ "module":
146
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
147
+ "resolution": (16, 16), # Update if data resolution or strides change
148
+ "backbone": ml_collections.ConfigDict({
149
+ "module": "invariant_slot_attention.modules.CNN",
150
+ "features": [64, 64, 64, 64, 64],
151
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
152
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
153
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
154
+ "layer_transpose": [True, True, False, False, False]
155
+ }),
156
+ "target_readout": ml_collections.ConfigDict({
157
+ "module": "invariant_slot_attention.modules.Readout",
158
+ "keys": list(targets),
159
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
160
+ "module": "invariant_slot_attention.modules.MLP",
161
+ "num_hidden_layers": 0,
162
+ "hidden_size": 0,
163
+ "output_size": targets[k]}) for k in targets],
164
+ }),
165
+ "pos_emb": ml_collections.ConfigDict({
166
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
167
+ "embedding_type": "linear",
168
+ "update_type": "project_add"
169
+ }),
170
+ }),
171
+ "decode_corrected": True,
172
+ "decode_predicted": False,
173
+ })
174
+
175
+ # Which video-shaped variables to visualize.
176
+ config.debug_var_video_paths = {
177
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
178
+ }
179
+
180
+ # Define which attention matrices to log/visualize.
181
+ config.debug_var_attn_paths = {
182
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
183
+ }
184
+
185
+ # Widths of attention matrices (for reshaping to image grid).
186
+ config.debug_var_attn_widths = {
187
+ "corrector_attn": 16,
188
+ }
189
+
190
+ return config
191
+
192
+
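A plausible reading of the "(16, 16)  # Update if data resolution or strides change" comment in the decoder above: the value is the data resolution divided by the product of the encoder strides, and the same number feeds debug_var_attn_widths (the decoder's two stride-2 transposed layers then upsample 16 back to 64). A small arithmetic sketch for this objects_room baseline:

# Sketch: derive the 16x16 grid from the 64x64 input and the encoder strides above.
data_resolution = (64, 64)
encoder_strides = [(2, 2), (2, 2), (1, 1), (1, 1)]

total_stride = 1
for s_h, s_w in encoder_strides:
    total_stride *= s_h                      # strides are square here

feature_grid = tuple(r // total_stride for r in data_resolution)
print(feature_grid)                          # (16, 16) -> decoder resolution and attention width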
invariant_slot_attention/configs/objects_room/equiv_transl.py ADDED
@@ -0,0 +1,200 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on objects_room."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ # TODO(obvis): Implement masked evaluation.
50
+ config.eval_pad_last_batch = False # True
51
+ config.log_loss_every_steps = 50
52
+ config.eval_every_steps = 5000
53
+ config.checkpoint_every_steps = 5000
54
+
55
+ config.train_metrics_spec = {
56
+ "loss": "loss",
57
+ "ari": "ari",
58
+ "ari_nobg": "ari_nobg",
59
+ }
60
+ config.eval_metrics_spec = {
61
+ "eval_loss": "loss",
62
+ "eval_ari": "ari",
63
+ "eval_ari_nobg": "ari_nobg",
64
+ }
65
+
66
+ config.data = ml_collections.ConfigDict({
67
+ "dataset_name": "objects_room",
68
+ "shuffle_buffer_size": config.batch_size * 8,
69
+ "resolution": (64, 64)
70
+ })
71
+
72
+ config.max_instances = 11
73
+ config.num_slots = config.max_instances # Only used for metrics.
74
+ config.logging_min_n_colors = config.max_instances
75
+
76
+ config.preproc_train = [
77
+ "tfds_image_to_tfds_video",
78
+ "video_from_tfds",
79
+ "sparse_to_dense_annotation(max_instances=10)",
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "sparse_to_dense_annotation(max_instances=10)",
86
+ ]
87
+
88
+ config.eval_slice_size = 1
89
+ config.eval_slice_keys = ["video", "segmentations_video"]
90
+
91
+ # Dictionary of targets and corresponding channels. Losses need to match.
92
+ targets = {"video": 3}
93
+ config.losses = {"recon": {"targets": list(targets)}}
94
+ config.losses = ml_collections.ConfigDict({
95
+ f"recon_{target}": {"loss_type": "recon", "key": target}
96
+ for target in targets})
97
+
98
+ config.model = ml_collections.ConfigDict({
99
+ "module": "invariant_slot_attention.modules.SAVi",
100
+
101
+ # Encoder.
102
+ "encoder": ml_collections.ConfigDict({
103
+ "module": "invariant_slot_attention.modules.FrameEncoder",
104
+ "reduction": "spatial_flatten",
105
+ "backbone": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.SimpleCNN",
107
+ "features": [64, 64, 64, 64],
108
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
109
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
110
+ }),
111
+ "pos_emb": ml_collections.ConfigDict({
112
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
113
+ "embedding_type": "linear",
114
+ "update_type": "concat"
115
+ }),
116
+ }),
117
+
118
+ # Corrector.
119
+ "corrector": ml_collections.ConfigDict({
120
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
121
+ "num_iterations": 3,
122
+ "qkv_size": 64,
123
+ "mlp_size": 128,
124
+ "grid_encoder": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.MLP",
126
+ "hidden_size": 128,
127
+ "layernorm": "pre"
128
+ }),
129
+ "add_rel_pos_to_values": True, # V3
130
+ "zero_position_init": False, # Random positions.
131
+ }),
132
+
133
+ # Predictor.
134
+ # Removed since we are running a single frame.
135
+ "predictor": ml_collections.ConfigDict({
136
+ "module": "invariant_slot_attention.modules.Identity"
137
+ }),
138
+
139
+ # Initializer.
140
+ "initializer": ml_collections.ConfigDict({
141
+ "module":
142
+ "invariant_slot_attention.modules.ParamStateInitRandomPositions",
143
+ "shape":
144
+ (11, 64), # (num_slots, slot_size)
145
+ }),
146
+
147
+ # Decoder.
148
+ "decoder": ml_collections.ConfigDict({
149
+ "module":
150
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
151
+ "resolution": (16, 16), # Update if data resolution or strides change
152
+ "backbone": ml_collections.ConfigDict({
153
+ "module": "invariant_slot_attention.modules.CNN",
154
+ "features": [64, 64, 64, 64, 64],
155
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
156
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
157
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
158
+ "layer_transpose": [True, True, False, False, False]
159
+ }),
160
+ "target_readout": ml_collections.ConfigDict({
161
+ "module": "invariant_slot_attention.modules.Readout",
162
+ "keys": list(targets),
163
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
164
+ "module": "invariant_slot_attention.modules.MLP",
165
+ "num_hidden_layers": 0,
166
+ "hidden_size": 0,
167
+ "output_size": targets[k]}) for k in targets],
168
+ }),
169
+ "relative_positions": True,
170
+ "pos_emb": ml_collections.ConfigDict({
171
+ "module":
172
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
173
+ "embedding_type":
174
+ "linear",
175
+ "update_type":
176
+ "project_add",
177
+ }),
178
+ }),
179
+ "decode_corrected": True,
180
+ "decode_predicted": False,
181
+ })
182
+
183
+ # Which video-shaped variables to visualize.
184
+ config.debug_var_video_paths = {
185
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
186
+ }
187
+
188
+ # Define which attention matrices to log/visualize.
189
+ config.debug_var_attn_paths = {
190
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
191
+ }
192
+
193
+ # Widths of attention matrices (for reshaping to image grid).
194
+ config.debug_var_attn_widths = {
195
+ "corrector_attn": 16,
196
+ }
197
+
198
+ return config
199
+
200
+
invariant_slot_attention/configs/objects_room/equiv_transl_rot_scale.py ADDED
@@ -0,0 +1,202 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on objects_room."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ # TODO(obvis): Implement masked evaluation.
50
+ config.eval_pad_last_batch = False # True
51
+ config.log_loss_every_steps = 50
52
+ config.eval_every_steps = 5000
53
+ config.checkpoint_every_steps = 5000
54
+
55
+ config.train_metrics_spec = {
56
+ "loss": "loss",
57
+ "ari": "ari",
58
+ "ari_nobg": "ari_nobg",
59
+ }
60
+ config.eval_metrics_spec = {
61
+ "eval_loss": "loss",
62
+ "eval_ari": "ari",
63
+ "eval_ari_nobg": "ari_nobg",
64
+ }
65
+
66
+ config.data = ml_collections.ConfigDict({
67
+ "dataset_name": "objects_room",
68
+ "shuffle_buffer_size": config.batch_size * 8,
69
+ "resolution": (64, 64)
70
+ })
71
+
72
+ config.max_instances = 11
73
+ config.num_slots = config.max_instances # Only used for metrics.
74
+ config.logging_min_n_colors = config.max_instances
75
+
76
+ config.preproc_train = [
77
+ "tfds_image_to_tfds_video",
78
+ "video_from_tfds",
79
+ "sparse_to_dense_annotation(max_instances=10)",
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "sparse_to_dense_annotation(max_instances=10)",
86
+ ]
87
+
88
+ config.eval_slice_size = 1
89
+ config.eval_slice_keys = ["video", "segmentations_video"]
90
+
91
+ # Dictionary of targets and corresponding channels. Losses need to match.
92
+ targets = {"video": 3}
93
+ config.losses = {"recon": {"targets": list(targets)}}
94
+ config.losses = ml_collections.ConfigDict({
95
+ f"recon_{target}": {"loss_type": "recon", "key": target}
96
+ for target in targets})
97
+
98
+ config.model = ml_collections.ConfigDict({
99
+ "module": "invariant_slot_attention.modules.SAVi",
100
+
101
+ # Encoder.
102
+ "encoder": ml_collections.ConfigDict({
103
+ "module": "invariant_slot_attention.modules.FrameEncoder",
104
+ "reduction": "spatial_flatten",
105
+ "backbone": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.SimpleCNN",
107
+ "features": [64, 64, 64, 64],
108
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
109
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
110
+ }),
111
+ "pos_emb": ml_collections.ConfigDict({
112
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
113
+ "embedding_type": "linear",
114
+ "update_type": "concat"
115
+ }),
116
+ }),
117
+
118
+ # Corrector.
119
+ "corrector": ml_collections.ConfigDict({
120
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long
121
+ "num_iterations": 3,
122
+ "qkv_size": 64,
123
+ "mlp_size": 128,
124
+ "grid_encoder": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.MLP",
126
+ "hidden_size": 128,
127
+ "layernorm": "pre"
128
+ }),
129
+ "add_rel_pos_to_values": True, # V3
130
+ "zero_position_init": False, # Random positions.
131
+ "init_with_fixed_scale": None, # Random scales.
132
+ "scales_factor": 5.0,
133
+ }),
134
+
135
+ # Predictor.
136
+ # Removed since we are running a single frame.
137
+ "predictor": ml_collections.ConfigDict({
138
+ "module": "invariant_slot_attention.modules.Identity"
139
+ }),
140
+
141
+ # Initializer.
142
+ "initializer": ml_collections.ConfigDict({
143
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long
144
+ "shape": (11, 64), # (num_slots, slot_size)
145
+ }),
146
+
147
+ # Decoder.
148
+ "decoder": ml_collections.ConfigDict({
149
+ "module":
150
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
151
+ "resolution": (16, 16), # Update if data resolution or strides change
152
+ "backbone": ml_collections.ConfigDict({
153
+ "module": "invariant_slot_attention.modules.CNN",
154
+ "features": [64, 64, 64, 64, 64],
155
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
156
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
157
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
158
+ "layer_transpose": [True, True, False, False, False]
159
+ }),
160
+ "target_readout": ml_collections.ConfigDict({
161
+ "module": "invariant_slot_attention.modules.Readout",
162
+ "keys": list(targets),
163
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
164
+ "module": "invariant_slot_attention.modules.MLP",
165
+ "num_hidden_layers": 0,
166
+ "hidden_size": 0,
167
+ "output_size": targets[k]}) for k in targets],
168
+ }),
169
+ "relative_positions_rotations_and_scales": True,
170
+ "pos_emb": ml_collections.ConfigDict({
171
+ "module":
172
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
173
+ "embedding_type":
174
+ "linear",
175
+ "update_type":
176
+ "project_add",
177
+ "scales_factor":
178
+ 5.0,
179
+ }),
180
+ }),
181
+ "decode_corrected": True,
182
+ "decode_predicted": False,
183
+ })
184
+
185
+ # Which video-shaped variables to visualize.
186
+ config.debug_var_video_paths = {
187
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
188
+ }
189
+
190
+ # Define which attention matrices to log/visualize.
191
+ config.debug_var_attn_paths = {
192
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
193
+ }
194
+
195
+ # Widths of attention matrices (for reshaping to image grid).
196
+ config.debug_var_attn_widths = {
197
+ "corrector_attn": 16,
198
+ }
199
+
200
+ return config
201
+
202
+
invariant_slot_attention/configs/objects_room/equiv_transl_scale.py ADDED
@@ -0,0 +1,202 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on objects_room."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ # TODO(obvis): Implement masked evaluation.
50
+ config.eval_pad_last_batch = False # True
51
+ config.log_loss_every_steps = 50
52
+ config.eval_every_steps = 5000
53
+ config.checkpoint_every_steps = 5000
54
+
55
+ config.train_metrics_spec = {
56
+ "loss": "loss",
57
+ "ari": "ari",
58
+ "ari_nobg": "ari_nobg",
59
+ }
60
+ config.eval_metrics_spec = {
61
+ "eval_loss": "loss",
62
+ "eval_ari": "ari",
63
+ "eval_ari_nobg": "ari_nobg",
64
+ }
65
+
66
+ config.data = ml_collections.ConfigDict({
67
+ "dataset_name": "objects_room",
68
+ "shuffle_buffer_size": config.batch_size * 8,
69
+ "resolution": (64, 64)
70
+ })
71
+
72
+ config.max_instances = 11
73
+ config.num_slots = config.max_instances # Only used for metrics.
74
+ config.logging_min_n_colors = config.max_instances
75
+
76
+ config.preproc_train = [
77
+ "tfds_image_to_tfds_video",
78
+ "video_from_tfds",
79
+ "sparse_to_dense_annotation(max_instances=10)",
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "sparse_to_dense_annotation(max_instances=10)",
86
+ ]
87
+
88
+ config.eval_slice_size = 1
89
+ config.eval_slice_keys = ["video", "segmentations_video"]
90
+
91
+ # Dictionary of targets and corresponding channels. Losses need to match.
92
+ targets = {"video": 3}
93
+ config.losses = {"recon": {"targets": list(targets)}}
94
+ config.losses = ml_collections.ConfigDict({
95
+ f"recon_{target}": {"loss_type": "recon", "key": target}
96
+ for target in targets})
97
+
98
+ config.model = ml_collections.ConfigDict({
99
+ "module": "invariant_slot_attention.modules.SAVi",
100
+
101
+ # Encoder.
102
+ "encoder": ml_collections.ConfigDict({
103
+ "module": "invariant_slot_attention.modules.FrameEncoder",
104
+ "reduction": "spatial_flatten",
105
+ "backbone": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.SimpleCNN",
107
+ "features": [64, 64, 64, 64],
108
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
109
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
110
+ }),
111
+ "pos_emb": ml_collections.ConfigDict({
112
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
113
+ "embedding_type": "linear",
114
+ "update_type": "concat"
115
+ }),
116
+ }),
117
+
118
+ # Corrector.
119
+ "corrector": ml_collections.ConfigDict({
120
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long
121
+ "num_iterations": 3,
122
+ "qkv_size": 64,
123
+ "mlp_size": 128,
124
+ "grid_encoder": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.MLP",
126
+ "hidden_size": 128,
127
+ "layernorm": "pre"
128
+ }),
129
+ "add_rel_pos_to_values": True, # V3
130
+ "zero_position_init": False, # Random positions.
131
+ "init_with_fixed_scale": None, # Random scales.
132
+ "scales_factor": 5.0,
133
+ }),
134
+
135
+ # Predictor.
136
+ # Removed since we are running a single frame.
137
+ "predictor": ml_collections.ConfigDict({
138
+ "module": "invariant_slot_attention.modules.Identity"
139
+ }),
140
+
141
+ # Initializer.
142
+ "initializer": ml_collections.ConfigDict({
143
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long
144
+ "shape": (11, 64), # (num_slots, slot_size)
145
+ }),
146
+
147
+ # Decoder.
148
+ "decoder": ml_collections.ConfigDict({
149
+ "module":
150
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
151
+ "resolution": (16, 16), # Update if data resolution or strides change
152
+ "backbone": ml_collections.ConfigDict({
153
+ "module": "invariant_slot_attention.modules.CNN",
154
+ "features": [64, 64, 64, 64, 64],
155
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
156
+ "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
157
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
158
+ "layer_transpose": [True, True, False, False, False]
159
+ }),
160
+ "target_readout": ml_collections.ConfigDict({
161
+ "module": "invariant_slot_attention.modules.Readout",
162
+ "keys": list(targets),
163
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
164
+ "module": "invariant_slot_attention.modules.MLP",
165
+ "num_hidden_layers": 0,
166
+ "hidden_size": 0,
167
+ "output_size": targets[k]}) for k in targets],
168
+ }),
169
+ "relative_positions_and_scales": True,
170
+ "pos_emb": ml_collections.ConfigDict({
171
+ "module":
172
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
173
+ "embedding_type":
174
+ "linear",
175
+ "update_type":
176
+ "project_add",
177
+ "scales_factor":
178
+ 5.0,
179
+ }),
180
+ }),
181
+ "decode_corrected": True,
182
+ "decode_predicted": False,
183
+ })
184
+
185
+ # Which video-shaped variables to visualize.
186
+ config.debug_var_video_paths = {
187
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
188
+ }
189
+
190
+ # Define which attention matrices to log/visualize.
191
+ config.debug_var_attn_paths = {
192
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
193
+ }
194
+
195
+ # Widths of attention matrices (for reshaping to image grid).
196
+ config.debug_var_attn_widths = {
197
+ "corrector_attn": 16,
198
+ }
199
+
200
+ return config
201
+
202
+
invariant_slot_attention/configs/tetrominoes/baseline.py ADDED
@@ -0,0 +1,191 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on Tetrominoes with 512 train samples."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 20000
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 5000
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ # TODO(obvis): Implement masked evaluation.
50
+ config.eval_pad_last_batch = False # True
51
+ config.log_loss_every_steps = 50
52
+ config.eval_every_steps = 1000
53
+ config.checkpoint_every_steps = 1000
54
+
55
+ config.train_metrics_spec = {
56
+ "loss": "loss",
57
+ "ari": "ari",
58
+ "ari_nobg": "ari_nobg",
59
+ }
60
+ config.eval_metrics_spec = {
61
+ "eval_loss": "loss",
62
+ "eval_ari": "ari",
63
+ "eval_ari_nobg": "ari_nobg",
64
+ }
65
+
66
+ config.data = ml_collections.ConfigDict({
67
+ "dataset_name": "tetrominoes",
68
+ "shuffle_buffer_size": config.batch_size * 8,
69
+ "resolution": (35, 35)
70
+ })
71
+
72
+ config.max_instances = 4
73
+ config.num_slots = config.max_instances # Only used for metrics.
74
+ config.logging_min_n_colors = config.max_instances
75
+
76
+ config.preproc_train = [
77
+ "tfds_image_to_tfds_video",
78
+ "video_from_tfds",
79
+ "sparse_to_dense_annotation(max_instances=3)"
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "sparse_to_dense_annotation(max_instances=3)"
86
+ ]
87
+
88
+ config.eval_slice_size = 1
89
+ config.eval_slice_keys = ["video", "segmentations_video"]
90
+
91
+ # Dictionary of targets and corresponding channels. Losses need to match.
92
+ targets = {"video": 3}
93
+ config.losses = {"recon": {"targets": list(targets)}}
94
+ config.losses = ml_collections.ConfigDict({
95
+ f"recon_{target}": {"loss_type": "recon", "key": target}
96
+ for target in targets})
97
+
98
+ config.model = ml_collections.ConfigDict({
99
+ "module": "invariant_slot_attention.modules.SAVi",
100
+
101
+ # Encoder.
102
+ "encoder": ml_collections.ConfigDict({
103
+ "module": "invariant_slot_attention.modules.FrameEncoder",
104
+ "reduction": "spatial_flatten",
105
+ "backbone": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.SimpleCNN",
107
+ "features": [64, 64, 64, 64],
108
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
109
+ "strides": [(1, 1), (1, 1), (1, 1), (1, 1)]
110
+ }),
111
+ "pos_emb": ml_collections.ConfigDict({
112
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
113
+ "embedding_type": "linear",
114
+ "update_type": "project_add",
115
+ "output_transform": ml_collections.ConfigDict({
116
+ "module": "invariant_slot_attention.modules.MLP",
117
+ "hidden_size": 128,
118
+ "layernorm": "pre"
119
+ }),
120
+ }),
121
+ }),
122
+
123
+ # Corrector.
124
+ "corrector": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.SlotAttention",
126
+ "num_iterations": 3,
127
+ "qkv_size": 64,
128
+ "mlp_size": 128,
129
+ }),
130
+
131
+ # Predictor.
132
+ # Removed since we are running a single frame.
133
+ "predictor": ml_collections.ConfigDict({
134
+ "module": "invariant_slot_attention.modules.Identity"
135
+ }),
136
+
137
+ # Initializer.
138
+ "initializer": ml_collections.ConfigDict({
139
+ "module": "invariant_slot_attention.modules.ParamStateInit",
140
+ "shape": (4, 64), # (num_slots, slot_size)
141
+ }),
142
+
143
+ # Decoder.
144
+ "decoder": ml_collections.ConfigDict({
145
+ "module":
146
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
147
+ "resolution": (35, 35), # Update if data resolution or strides change
148
+ "backbone": ml_collections.ConfigDict({
149
+ "module": "invariant_slot_attention.modules.MLP",
150
+ "hidden_size": 256,
151
+ "output_size": 256,
152
+ "num_hidden_layers": 5,
153
+ "activate_output": True
154
+ }),
155
+ "target_readout": ml_collections.ConfigDict({
156
+ "module": "invariant_slot_attention.modules.Readout",
157
+ "keys": list(targets),
158
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
159
+ "module": "invariant_slot_attention.modules.MLP",
160
+ "num_hidden_layers": 0,
161
+ "hidden_size": 0,
162
+ "output_size": targets[k]}) for k in targets],
163
+ }),
164
+ "pos_emb": ml_collections.ConfigDict({
165
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
166
+ "embedding_type": "linear",
167
+ "update_type": "project_add"
168
+ }),
169
+ }),
170
+ "decode_corrected": True,
171
+ "decode_predicted": False,
172
+ })
173
+
174
+ # Which video-shaped variables to visualize.
175
+ config.debug_var_video_paths = {
176
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
177
+ }
178
+
179
+ # Define which attention matrices to log/visualize.
180
+ config.debug_var_attn_paths = {
181
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
182
+ }
183
+
184
+ # Widths of attention matrices (for reshaping to image grid).
185
+ config.debug_var_attn_widths = {
186
+ "corrector_attn": 35,
187
+ }
188
+
189
+ return config
190
+
191
+
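
For reference, the `config.losses` comprehension used throughout these configs overrides the preceding `{"recon": ...}` assignment and expands to one reconstruction entry per target. A minimal sketch of the expansion, using plain Python dicts for illustration (the configs wrap the result in ml_collections.ConfigDict):

targets = {"video": 3}
losses = {f"recon_{target}": {"loss_type": "recon", "key": target}
          for target in targets}
assert losses == {"recon_video": {"loss_type": "recon", "key": "video"}}
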
invariant_slot_attention/configs/tetrominoes/equiv_transl.py ADDED
@@ -0,0 +1,199 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on Tetrominoes with 512 train samples."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 20000
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 5000
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ # TODO(obvis): Implement masked evaluation.
50
+ config.eval_pad_last_batch = False # True
51
+ config.log_loss_every_steps = 50
52
+ config.eval_every_steps = 1000
53
+ config.checkpoint_every_steps = 1000
54
+
55
+ config.train_metrics_spec = {
56
+ "loss": "loss",
57
+ "ari": "ari",
58
+ "ari_nobg": "ari_nobg",
59
+ }
60
+ config.eval_metrics_spec = {
61
+ "eval_loss": "loss",
62
+ "eval_ari": "ari",
63
+ "eval_ari_nobg": "ari_nobg",
64
+ }
65
+
66
+ config.data = ml_collections.ConfigDict({
67
+ "dataset_name": "tetrominoes",
68
+ "shuffle_buffer_size": config.batch_size * 8,
69
+ "resolution": (35, 35)
70
+ })
71
+
72
+ config.max_instances = 4
73
+ config.num_slots = config.max_instances # Only used for metrics.
74
+ config.logging_min_n_colors = config.max_instances
75
+
76
+ config.preproc_train = [
77
+ "tfds_image_to_tfds_video",
78
+ "video_from_tfds",
79
+ "sparse_to_dense_annotation(max_instances=3)"
80
+ ]
81
+
82
+ config.preproc_eval = [
83
+ "tfds_image_to_tfds_video",
84
+ "video_from_tfds",
85
+ "sparse_to_dense_annotation(max_instances=3)"
86
+ ]
87
+
88
+ config.eval_slice_size = 1
89
+ config.eval_slice_keys = ["video", "segmentations_video"]
90
+
91
+ # Dictionary of targets and corresponding channels. Losses need to match.
92
+ targets = {"video": 3}
93
+ config.losses = {"recon": {"targets": list(targets)}}
94
+ config.losses = ml_collections.ConfigDict({
95
+ f"recon_{target}": {"loss_type": "recon", "key": target}
96
+ for target in targets})
97
+
98
+ config.model = ml_collections.ConfigDict({
99
+ "module": "invariant_slot_attention.modules.SAVi",
100
+
101
+ # Encoder.
102
+ "encoder": ml_collections.ConfigDict({
103
+ "module": "invariant_slot_attention.modules.FrameEncoder",
104
+ "reduction": "spatial_flatten",
105
+ "backbone": ml_collections.ConfigDict({
106
+ "module": "invariant_slot_attention.modules.SimpleCNN",
107
+ "features": [64, 64, 64, 64],
108
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
109
+ "strides": [(1, 1), (1, 1), (1, 1), (1, 1)]
110
+ }),
111
+ "pos_emb": ml_collections.ConfigDict({
112
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
113
+ "embedding_type": "linear",
114
+ "update_type": "concat"
115
+ }),
116
+ }),
117
+
118
+ # Corrector.
119
+ "corrector": ml_collections.ConfigDict({
120
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
121
+ "num_iterations": 3,
122
+ "qkv_size": 64,
123
+ "mlp_size": 128,
124
+ "grid_encoder": ml_collections.ConfigDict({
125
+ "module": "invariant_slot_attention.modules.MLP",
126
+ "hidden_size": 128,
127
+ "layernorm": "pre"
128
+ }),
129
+ "add_rel_pos_to_values": True, # V3
130
+ "zero_position_init": False, # Random positions.
131
+ }),
132
+
133
+ # Predictor.
134
+ # Removed since we are running a single frame.
135
+ "predictor": ml_collections.ConfigDict({
136
+ "module": "invariant_slot_attention.modules.Identity"
137
+ }),
138
+
139
+ # Initializer.
140
+ "initializer": ml_collections.ConfigDict({
141
+ "module":
142
+ "invariant_slot_attention.modules.ParamStateInitRandomPositions",
143
+ "shape":
144
+ (4, 64), # (num_slots, slot_size)
145
+ }),
146
+
147
+ # Decoder.
148
+ "decoder": ml_collections.ConfigDict({
149
+ "module":
150
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
151
+ "resolution": (35, 35), # Update if data resolution or strides change
152
+ "backbone": ml_collections.ConfigDict({
153
+ "module": "invariant_slot_attention.modules.MLP",
154
+ "hidden_size": 256,
155
+ "output_size": 256,
156
+ "num_hidden_layers": 5,
157
+ "activate_output": True
158
+ }),
159
+ "target_readout": ml_collections.ConfigDict({
160
+ "module": "invariant_slot_attention.modules.Readout",
161
+ "keys": list(targets),
162
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
163
+ "module": "invariant_slot_attention.modules.MLP",
164
+ "num_hidden_layers": 0,
165
+ "hidden_size": 0,
166
+ "output_size": targets[k]}) for k in targets],
167
+ }),
168
+ "relative_positions": True,
169
+ "pos_emb": ml_collections.ConfigDict({
170
+ "module":
171
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
172
+ "embedding_type":
173
+ "linear",
174
+ "update_type":
175
+ "project_add",
176
+ }),
177
+ }),
178
+ "decode_corrected": True,
179
+ "decode_predicted": False,
180
+ })
181
+
182
+ # Which video-shaped variables to visualize.
183
+ config.debug_var_video_paths = {
184
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
185
+ }
186
+
187
+ # Define which attention matrices to log/visualize.
188
+ config.debug_var_attn_paths = {
189
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
190
+ }
191
+
192
+ # Widths of attention matrices (for reshaping to image grid).
193
+ config.debug_var_attn_widths = {
194
+ "corrector_attn": 35,
195
+ }
196
+
197
+ return config
198
+
199
+
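
One detail worth noting in the schedule setup above: `config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")` stores an ml_collections FieldReference rather than a copy, so a later override of `num_train_steps` also updates the schedule. A minimal, self-contained sketch of that behaviour (illustrative values only):

import ml_collections

config = ml_collections.ConfigDict()
config.num_train_steps = 20000
config.lr_configs = ml_collections.ConfigDict()
# get_ref returns a FieldReference, keeping the two fields in sync.
config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")

config.num_train_steps = 100000  # e.g. a command-line override
assert config.lr_configs.steps_per_cycle == 100000
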
invariant_slot_attention/configs/waymo_open/baseline.py ADDED
@@ -0,0 +1,191 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on Waymo Open."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "waymo_open",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 192)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "tfds_image_to_tfds_video",
77
+ "video_from_tfds",
78
+ ]
79
+
80
+ config.preproc_eval = [
81
+ "tfds_image_to_tfds_video",
82
+ "video_from_tfds",
83
+ "delete_small_masks(threshold=0.01, max_instances_after=11)",
84
+ ]
85
+
86
+ config.eval_slice_size = 1
87
+ config.eval_slice_keys = ["video", "segmentations_video"]
88
+
89
+ # Dictionary of targets and corresponding channels. Losses need to match.
90
+ targets = {"video": 3}
91
+ config.losses = {"recon": {"targets": list(targets)}}
92
+ config.losses = ml_collections.ConfigDict({
93
+ f"recon_{target}": {"loss_type": "recon", "key": target}
94
+ for target in targets})
95
+
96
+ config.model = ml_collections.ConfigDict({
97
+ "module": "invariant_slot_attention.modules.SAVi",
98
+
99
+ # Encoder.
100
+ "encoder": ml_collections.ConfigDict({
101
+ "module": "invariant_slot_attention.modules.FrameEncoder",
102
+ "reduction": "spatial_flatten",
103
+ "backbone": ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.ResNet34",
105
+ "num_classes": None,
106
+ "axis_name": "time",
107
+ "norm_type": "group",
108
+ "small_inputs": True
109
+ }),
110
+ "pos_emb": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
112
+ "embedding_type": "linear",
113
+ "update_type": "project_add",
114
+ "output_transform": ml_collections.ConfigDict({
115
+ "module": "invariant_slot_attention.modules.MLP",
116
+ "hidden_size": 128,
117
+ "layernorm": "pre"
118
+ }),
119
+ }),
120
+ }),
121
+
122
+ # Corrector.
123
+ "corrector": ml_collections.ConfigDict({
124
+ "module": "invariant_slot_attention.modules.SlotAttention",
125
+ "num_iterations": 3,
126
+ "qkv_size": 64,
127
+ "mlp_size": 128,
128
+ }),
129
+
130
+ # Predictor.
131
+ # Removed since we are running a single frame.
132
+ "predictor": ml_collections.ConfigDict({
133
+ "module": "invariant_slot_attention.modules.Identity"
134
+ }),
135
+
136
+ # Initializer.
137
+ "initializer": ml_collections.ConfigDict({
138
+ "module": "invariant_slot_attention.modules.ParamStateInit",
139
+ "shape": (11, 64), # (num_slots, slot_size)
140
+ }),
141
+
142
+ # Decoder.
143
+ "decoder": ml_collections.ConfigDict({
144
+ "module":
145
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
146
+ "resolution": (16, 24), # Update if data resolution or strides change
147
+ "backbone": ml_collections.ConfigDict({
148
+ "module": "invariant_slot_attention.modules.CNN",
149
+ "features": [64, 64, 64, 64, 64],
150
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
151
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
152
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
153
+ "layer_transpose": [True, True, True, False, False]
154
+ }),
155
+ "target_readout": ml_collections.ConfigDict({
156
+ "module": "invariant_slot_attention.modules.Readout",
157
+ "keys": list(targets),
158
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
159
+ "module": "invariant_slot_attention.modules.MLP",
160
+ "num_hidden_layers": 0,
161
+ "hidden_size": 0,
162
+ "output_size": targets[k]}) for k in targets],
163
+ }),
164
+ "pos_emb": ml_collections.ConfigDict({
165
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
166
+ "embedding_type": "linear",
167
+ "update_type": "project_add"
168
+ }),
169
+ }),
170
+ "decode_corrected": True,
171
+ "decode_predicted": False,
172
+ })
173
+
174
+ # Which video-shaped variables to visualize.
175
+ config.debug_var_video_paths = {
176
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
177
+ }
178
+
179
+ # Define which attention matrices to log/visualize.
180
+ config.debug_var_attn_paths = {
181
+ "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" # pylint: disable=line-too-long
182
+ }
183
+
184
+ # Widths of attention matrices (for reshaping to image grid).
185
+ config.debug_var_attn_widths = {
186
+ "corrector_attn": 16,
187
+ }
188
+
189
+ return config
190
+
191
+
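
The decoder's `"resolution": (16, 24)` pairs with the data resolution of (128, 192): the three transposed-convolution layers with stride (2, 2) upsample the broadcast grid by a factor of 8 per dimension, which is what the "Update if data resolution or strides change" comment refers to. A quick sanity check of that arithmetic:

# Transposed layers upsample by their stride; the last two layers keep the size.
height, width = 16, 24
strides = [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)]
layer_transpose = [True, True, True, False, False]
for (sh, sw), transpose in zip(strides, layer_transpose):
    if transpose:
        height, width = height * sh, width * sw
assert (height, width) == (128, 192)
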
invariant_slot_attention/configs/waymo_open/equiv_transl.py ADDED
@@ -0,0 +1,199 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on Waymo Open."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "waymo_open",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 192)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "tfds_image_to_tfds_video",
77
+ "video_from_tfds",
78
+ ]
79
+
80
+ config.preproc_eval = [
81
+ "tfds_image_to_tfds_video",
82
+ "video_from_tfds",
83
+ "delete_small_masks(threshold=0.01, max_instances_after=11)",
84
+ ]
85
+
86
+ config.eval_slice_size = 1
87
+ config.eval_slice_keys = ["video", "segmentations_video"]
88
+
89
+ # Dictionary of targets and corresponding channels. Losses need to match.
90
+ targets = {"video": 3}
91
+ config.losses = {"recon": {"targets": list(targets)}}
92
+ config.losses = ml_collections.ConfigDict({
93
+ f"recon_{target}": {"loss_type": "recon", "key": target}
94
+ for target in targets})
95
+
96
+ config.model = ml_collections.ConfigDict({
97
+ "module": "invariant_slot_attention.modules.SAVi",
98
+
99
+ # Encoder.
100
+ "encoder": ml_collections.ConfigDict({
101
+ "module": "invariant_slot_attention.modules.FrameEncoder",
102
+ "reduction": "spatial_flatten",
103
+ "backbone": ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.ResNet34",
105
+ "num_classes": None,
106
+ "axis_name": "time",
107
+ "norm_type": "group",
108
+ "small_inputs": True
109
+ }),
110
+ "pos_emb": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
112
+ "embedding_type": "linear",
113
+ "update_type": "concat"
114
+ }),
115
+ }),
116
+
117
+ # Corrector.
118
+ "corrector": ml_collections.ConfigDict({
119
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
120
+ "num_iterations": 3,
121
+ "qkv_size": 64,
122
+ "mlp_size": 128,
123
+ "grid_encoder": ml_collections.ConfigDict({
124
+ "module": "invariant_slot_attention.modules.MLP",
125
+ "hidden_size": 128,
126
+ "layernorm": "pre"
127
+ }),
128
+ "add_rel_pos_to_values": True, # V3
129
+ "zero_position_init": False, # Random positions.
130
+ }),
131
+
132
+ # Predictor.
133
+ # Removed since we are running a single frame.
134
+ "predictor": ml_collections.ConfigDict({
135
+ "module": "invariant_slot_attention.modules.Identity"
136
+ }),
137
+
138
+ # Initializer.
139
+ "initializer": ml_collections.ConfigDict({
140
+ "module":
141
+ "invariant_slot_attention.modules.ParamStateInitRandomPositions",
142
+ "shape":
143
+ (11, 64), # (num_slots, slot_size)
144
+ }),
145
+
146
+ # Decoder.
147
+ "decoder": ml_collections.ConfigDict({
148
+ "module":
149
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
150
+ "resolution": (16, 24), # Update if data resolution or strides change
151
+ "backbone": ml_collections.ConfigDict({
152
+ "module": "invariant_slot_attention.modules.CNN",
153
+ "features": [64, 64, 64, 64, 64],
154
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
155
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
156
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
157
+ "layer_transpose": [True, True, True, False, False]
158
+ }),
159
+ "target_readout": ml_collections.ConfigDict({
160
+ "module": "invariant_slot_attention.modules.Readout",
161
+ "keys": list(targets),
162
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
163
+ "module": "invariant_slot_attention.modules.MLP",
164
+ "num_hidden_layers": 0,
165
+ "hidden_size": 0,
166
+ "output_size": targets[k]}) for k in targets],
167
+ }),
168
+ "relative_positions": True,
169
+ "pos_emb": ml_collections.ConfigDict({
170
+ "module":
171
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
172
+ "embedding_type":
173
+ "linear",
174
+ "update_type":
175
+ "project_add",
176
+ }),
177
+ }),
178
+ "decode_corrected": True,
179
+ "decode_predicted": False,
180
+ })
181
+
182
+ # Which video-shaped variables to visualize.
183
+ config.debug_var_video_paths = {
184
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
185
+ }
186
+
187
+ # Define which attention matrices to log/visualize.
188
+ config.debug_var_attn_paths = {
189
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
190
+ }
191
+
192
+ # Widths of attention matrices (for reshaping to image grid).
193
+ config.debug_var_attn_widths = {
194
+ "corrector_attn": 16,
195
+ }
196
+
197
+ return config
198
+
199
+
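
Config files exposing a `get_config()` function, like the ones above, are commonly consumed through `ml_collections.config_flags` in an absl program. The snippet below is a generic sketch of that pattern, not code from this repository; the flag name and the printed field are illustrative:

from absl import app
from ml_collections import config_flags

_CONFIG = config_flags.DEFINE_config_file("config")

def main(argv):
    del argv
    config = _CONFIG.value  # e.g. --config=invariant_slot_attention/configs/waymo_open/equiv_transl.py
    print(config.model.corrector.module)

if __name__ == "__main__":
    app.run(main)
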
invariant_slot_attention/configs/waymo_open/equiv_transl_rot_scale.py ADDED
@@ -0,0 +1,206 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on Waymo Open."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "waymo_open",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 192)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "tfds_image_to_tfds_video",
77
+ "video_from_tfds",
78
+ ]
79
+
80
+ config.preproc_eval = [
81
+ "tfds_image_to_tfds_video",
82
+ "video_from_tfds",
83
+ "delete_small_masks(threshold=0.01, max_instances_after=11)",
84
+ ]
85
+
86
+ config.eval_slice_size = 1
87
+ config.eval_slice_keys = ["video", "segmentations_video"]
88
+
89
+ # Dictionary of targets and corresponding channels. Losses need to match.
90
+ targets = {"video": 3}
91
+ config.losses = {"recon": {"targets": list(targets)}}
92
+ config.losses = ml_collections.ConfigDict({
93
+ f"recon_{target}": {"loss_type": "recon", "key": target}
94
+ for target in targets})
95
+
96
+ config.model = ml_collections.ConfigDict({
97
+ "module": "invariant_slot_attention.modules.SAVi",
98
+
99
+ # Encoder.
100
+ "encoder": ml_collections.ConfigDict({
101
+ "module": "invariant_slot_attention.modules.FrameEncoder",
102
+ "reduction": "spatial_flatten",
103
+ "backbone": ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.ResNet34",
105
+ "num_classes": None,
106
+ "axis_name": "time",
107
+ "norm_type": "group",
108
+ "small_inputs": True
109
+ }),
110
+ "pos_emb": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
112
+ "embedding_type": "linear",
113
+ "update_type": "project_add",
114
+ "output_transform": ml_collections.ConfigDict({
115
+ "module": "invariant_slot_attention.modules.MLP",
116
+ "hidden_size": 128,
117
+ "layernorm": "pre"
118
+ }),
119
+ }),
120
+ }),
121
+
122
+ # Corrector.
123
+ "corrector": ml_collections.ConfigDict({
124
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv", # pylint: disable=line-too-long
125
+ "num_iterations": 3,
126
+ "qkv_size": 64,
127
+ "mlp_size": 128,
128
+ "grid_encoder": ml_collections.ConfigDict({
129
+ "module": "invariant_slot_attention.modules.MLP",
130
+ "hidden_size": 128,
131
+ "layernorm": "pre"
132
+ }),
133
+ "add_rel_pos_to_values": True, # V3
134
+ "zero_position_init": False, # Random positions.
135
+ "init_with_fixed_scale": None, # Random scales.
136
+ "scales_factor": 5.0,
137
+ }),
138
+
139
+ # Predictor.
140
+ # Removed since we are running a single frame.
141
+ "predictor": ml_collections.ConfigDict({
142
+ "module": "invariant_slot_attention.modules.Identity"
143
+ }),
144
+
145
+ # Initializer.
146
+ "initializer": ml_collections.ConfigDict({
147
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales", # pylint: disable=line-too-long
148
+ "shape": (11, 64), # (num_slots, slot_size)
149
+ }),
150
+
151
+ # Decoder.
152
+ "decoder": ml_collections.ConfigDict({
153
+ "module":
154
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
155
+ "resolution": (16, 24), # Update if data resolution or strides change
156
+ "backbone": ml_collections.ConfigDict({
157
+ "module": "invariant_slot_attention.modules.CNN",
158
+ "features": [64, 64, 64, 64, 64],
159
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
160
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
161
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
162
+ "layer_transpose": [True, True, True, False, False]
163
+ }),
164
+ "target_readout": ml_collections.ConfigDict({
165
+ "module": "invariant_slot_attention.modules.Readout",
166
+ "keys": list(targets),
167
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
168
+ "module": "invariant_slot_attention.modules.MLP",
169
+ "num_hidden_layers": 0,
170
+ "hidden_size": 0,
171
+ "output_size": targets[k]}) for k in targets],
172
+ }),
173
+ "relative_positions_rotations_and_scales": True,
174
+ "pos_emb": ml_collections.ConfigDict({
175
+ "module":
176
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
177
+ "embedding_type":
178
+ "linear",
179
+ "update_type":
180
+ "project_add",
181
+ "scales_factor":
182
+ 5.0,
183
+ }),
184
+ }),
185
+ "decode_corrected": True,
186
+ "decode_predicted": False,
187
+ })
188
+
189
+ # Which video-shaped variables to visualize.
190
+ config.debug_var_video_paths = {
191
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
192
+ }
193
+
194
+ # Define which attention matrices to log/visualize.
195
+ config.debug_var_attn_paths = {
196
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
197
+ }
198
+
199
+ # Widths of attention matrices (for reshaping to image grid).
200
+ config.debug_var_attn_widths = {
201
+ "corrector_attn": 16,
202
+ }
203
+
204
+ return config
205
+
206
+
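
`clip_method = "clip_by_global_norm"` with `clip_value = 0.05` describes global-norm gradient clipping applied before the Adam update. Assuming the trainer maps these fields onto the corresponding optax transformations (an assumption, not verified against trainer.py), the optimizer would look roughly like this:

import jax.numpy as jnp
import optax

# Hypothetical mapping of the optimizer_configs above onto optax.
tx = optax.chain(
    optax.clip_by_global_norm(0.05),  # clip_value
    optax.adam(learning_rate=4e-4),   # base_learning_rate
)
params = {"w": jnp.ones((3,))}
opt_state = tx.init(params)
grads = {"w": jnp.full((3,), 10.0)}
updates, opt_state = tx.update(grads, opt_state, params)  # gradients rescaled to global norm 0.05
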
invariant_slot_attention/configs/waymo_open/equiv_transl_scale.py ADDED
@@ -0,0 +1,206 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""Config for unsupervised training on Waymo Open."""
17
+
18
+ import ml_collections
19
+
20
+
21
+ def get_config():
22
+ """Get the default hyperparameter configuration."""
23
+ config = ml_collections.ConfigDict()
24
+
25
+ config.seed = 42
26
+ config.seed_data = True
27
+
28
+ config.batch_size = 64
29
+ config.num_train_steps = 500000 # from the original Slot Attention
30
+ config.init_checkpoint = ml_collections.ConfigDict()
31
+ config.init_checkpoint.xid = 0 # Disabled by default.
32
+ config.init_checkpoint.wid = 1
33
+
34
+ config.optimizer_configs = ml_collections.ConfigDict()
35
+ config.optimizer_configs.optimizer = "adam"
36
+
37
+ config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
38
+ config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
39
+ config.optimizer_configs.grad_clip.clip_value = 0.05
40
+
41
+ config.lr_configs = ml_collections.ConfigDict()
42
+ config.lr_configs.learning_rate_schedule = "compound"
43
+ config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
44
+ config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
45
+ config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
46
+ # from the original Slot Attention
47
+ config.lr_configs.base_learning_rate = 4e-4
48
+
49
+ config.eval_pad_last_batch = False # True
50
+ config.log_loss_every_steps = 50
51
+ config.eval_every_steps = 5000
52
+ config.checkpoint_every_steps = 5000
53
+
54
+ config.train_metrics_spec = {
55
+ "loss": "loss",
56
+ "ari": "ari",
57
+ "ari_nobg": "ari_nobg",
58
+ }
59
+ config.eval_metrics_spec = {
60
+ "eval_loss": "loss",
61
+ "eval_ari": "ari",
62
+ "eval_ari_nobg": "ari_nobg",
63
+ }
64
+
65
+ config.data = ml_collections.ConfigDict({
66
+ "dataset_name": "waymo_open",
67
+ "shuffle_buffer_size": config.batch_size * 8,
68
+ "resolution": (128, 192)
69
+ })
70
+
71
+ config.max_instances = 11
72
+ config.num_slots = config.max_instances # Only used for metrics.
73
+ config.logging_min_n_colors = config.max_instances
74
+
75
+ config.preproc_train = [
76
+ "tfds_image_to_tfds_video",
77
+ "video_from_tfds",
78
+ ]
79
+
80
+ config.preproc_eval = [
81
+ "tfds_image_to_tfds_video",
82
+ "video_from_tfds",
83
+ "delete_small_masks(threshold=0.01, max_instances_after=11)",
84
+ ]
85
+
86
+ config.eval_slice_size = 1
87
+ config.eval_slice_keys = ["video", "segmentations_video"]
88
+
89
+ # Dictionary of targets and corresponding channels. Losses need to match.
90
+ targets = {"video": 3}
91
+ config.losses = {"recon": {"targets": list(targets)}}
92
+ config.losses = ml_collections.ConfigDict({
93
+ f"recon_{target}": {"loss_type": "recon", "key": target}
94
+ for target in targets})
95
+
96
+ config.model = ml_collections.ConfigDict({
97
+ "module": "invariant_slot_attention.modules.SAVi",
98
+
99
+ # Encoder.
100
+ "encoder": ml_collections.ConfigDict({
101
+ "module": "invariant_slot_attention.modules.FrameEncoder",
102
+ "reduction": "spatial_flatten",
103
+ "backbone": ml_collections.ConfigDict({
104
+ "module": "invariant_slot_attention.modules.ResNet34",
105
+ "num_classes": None,
106
+ "axis_name": "time",
107
+ "norm_type": "group",
108
+ "small_inputs": True
109
+ }),
110
+ "pos_emb": ml_collections.ConfigDict({
111
+ "module": "invariant_slot_attention.modules.PositionEmbedding",
112
+ "embedding_type": "linear",
113
+ "update_type": "project_add",
114
+ "output_transform": ml_collections.ConfigDict({
115
+ "module": "invariant_slot_attention.modules.MLP",
116
+ "hidden_size": 128,
117
+ "layernorm": "pre"
118
+ }),
119
+ }),
120
+ }),
121
+
122
+ # Corrector.
123
+ "corrector": ml_collections.ConfigDict({
124
+ "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv", # pylint: disable=line-too-long
125
+ "num_iterations": 3,
126
+ "qkv_size": 64,
127
+ "mlp_size": 128,
128
+ "grid_encoder": ml_collections.ConfigDict({
129
+ "module": "invariant_slot_attention.modules.MLP",
130
+ "hidden_size": 128,
131
+ "layernorm": "pre"
132
+ }),
133
+ "add_rel_pos_to_values": True, # V3
134
+ "zero_position_init": False, # Random positions.
135
+ "init_with_fixed_scale": None, # Random scales.
136
+ "scales_factor": 5.0,
137
+ }),
138
+
139
+ # Predictor.
140
+ # Removed since we are running a single frame.
141
+ "predictor": ml_collections.ConfigDict({
142
+ "module": "invariant_slot_attention.modules.Identity"
143
+ }),
144
+
145
+ # Initializer.
146
+ "initializer": ml_collections.ConfigDict({
147
+ "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales", # pylint: disable=line-too-long
148
+ "shape": (11, 64), # (num_slots, slot_size)
149
+ }),
150
+
151
+ # Decoder.
152
+ "decoder": ml_collections.ConfigDict({
153
+ "module":
154
+ "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
155
+ "resolution": (16, 24), # Update if data resolution or strides change
156
+ "backbone": ml_collections.ConfigDict({
157
+ "module": "invariant_slot_attention.modules.CNN",
158
+ "features": [64, 64, 64, 64, 64],
159
+ "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
160
+ "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
161
+ "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
162
+ "layer_transpose": [True, True, True, False, False]
163
+ }),
164
+ "target_readout": ml_collections.ConfigDict({
165
+ "module": "invariant_slot_attention.modules.Readout",
166
+ "keys": list(targets),
167
+ "readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
168
+ "module": "invariant_slot_attention.modules.MLP",
169
+ "num_hidden_layers": 0,
170
+ "hidden_size": 0,
171
+ "output_size": targets[k]}) for k in targets],
172
+ }),
173
+ "relative_positions_and_scales": True,
174
+ "pos_emb": ml_collections.ConfigDict({
175
+ "module":
176
+ "invariant_slot_attention.modules.RelativePositionEmbedding",
177
+ "embedding_type":
178
+ "linear",
179
+ "update_type":
180
+ "project_add",
181
+ "scales_factor":
182
+ 5.0,
183
+ }),
184
+ }),
185
+ "decode_corrected": True,
186
+ "decode_predicted": False,
187
+ })
188
+
189
+ # Which video-shaped variables to visualize.
190
+ config.debug_var_video_paths = {
191
+ "recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
192
+ }
193
+
194
+ # Define which attention matrices to log/visualize.
195
+ config.debug_var_attn_paths = {
196
+ "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
197
+ }
198
+
199
+ # Widths of attention matrices (for reshaping to image grid).
200
+ config.debug_var_attn_widths = {
201
+ "corrector_attn": 16,
202
+ }
203
+
204
+ return config
205
+
206
+
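
The schedule string "constant * cosine_decay * linear_warmup" multiplies a constant base rate by a cosine decay over `steps_per_cycle` and a linear warmup over `warmup_steps`. A rough sketch assuming the usual semantics of those factors (not taken from the training library itself):

import math

def learning_rate(step, base=4e-4, warmup_steps=10_000, steps_per_cycle=500_000):
    warmup = min(1.0, step / warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(step, steps_per_cycle) / steps_per_cycle))
    return base * cosine * warmup

assert learning_rate(0) == 0.0                       # warmup starts at zero
assert learning_rate(5_000) < learning_rate(10_000)  # still warming up
assert learning_rate(500_000) < 1e-9                 # cosine decays to ~zero at the end of the cycle
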
invariant_slot_attention/lib/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
invariant_slot_attention/lib/evaluator.py ADDED
@@ -0,0 +1,326 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Model evaluation."""
17
+
18
+ import functools
19
+ from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union
20
+
21
+ from absl import logging
22
+ from clu import metrics
23
+ import flax
24
+ from flax import linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import numpy as np
28
+ import tensorflow as tf
29
+
30
+ from invariant_slot_attention.lib import losses
31
+ from invariant_slot_attention.lib import utils
32
+
33
+
34
+ Array = jnp.ndarray
35
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
36
+ PRNGKey = Array
37
+
38
+
39
+ def get_eval_metrics(
40
+ preds,
41
+ batch,
42
+ loss_fn,
43
+ eval_metrics_cls,
44
+ predicted_max_num_instances,
45
+ ground_truth_max_num_instances,
46
+ ):
47
+ """Compute the metrics for the model predictions in inference mode.
48
+
49
+ The metrics are averaged across *all* devices (of all hosts).
50
+
51
+ Args:
52
+ preds: Model predictions.
53
+ batch: Inputs that should be evaluated.
54
+ loss_fn: Loss function that takes model predictions and a batch of data.
55
+ eval_metrics_cls: Evaluation metrics collection.
56
+ predicted_max_num_instances: Maximum number of instances in prediction.
57
+ ground_truth_max_num_instances: Maximum number of instances in ground truth,
58
+ including background (which counts as a separate instance).
59
+
60
+ Returns:
61
+ The evaluation metrics.
62
+ """
63
+ loss, loss_aux = loss_fn(preds, batch)
64
+ metrics_update = eval_metrics_cls.gather_from_model_output(
65
+ loss=loss,
66
+ **loss_aux,
67
+ predicted_segmentations=utils.remove_singleton_dim(
68
+ preds["outputs"].get("segmentations")), # pytype: disable=attribute-error
69
+ ground_truth_segmentations=batch.get("segmentations"),
70
+ predicted_max_num_instances=predicted_max_num_instances,
71
+ ground_truth_max_num_instances=ground_truth_max_num_instances,
72
+ padding_mask=batch.get("padding_mask"),
73
+ mask=batch.get("mask"))
74
+ return metrics_update
75
+
76
+
77
+ def eval_first_step(
78
+ model,
79
+ state_variables,
80
+ params,
81
+ batch,
82
+ rng,
83
+ conditioning_key = None
84
+ ):
85
+ """Get the model predictions with a freshly initialized recurrent state.
86
+
87
+ The model is applied to the inputs using all devices on the host.
88
+
89
+ Args:
90
+ model: Model used in eval step.
91
+ state_variables: State variables for the model.
92
+ params: Params for the model.
93
+ batch: Inputs that should be evaluated.
94
+ rng: PRNGKey for model forward pass.
95
+ conditioning_key: Optional string. If provided, defines the batch key to be
96
+ used as conditioning signal for the model. Otherwise this is inferred from
97
+ the available keys in the batch.
98
+ Returns:
99
+ The model's predictions.
100
+ """
101
+ logging.info("eval_first_step(batch=%s)", batch)
102
+
103
+ conditioning = None
104
+ if conditioning_key:
105
+ conditioning = batch[conditioning_key]
106
+ preds, mutable_vars = model.apply(
107
+ {"params": params, **state_variables}, video=batch["video"],
108
+ conditioning=conditioning, mutable="intermediates",
109
+ rngs={"state_init": rng}, train=False,
110
+ padding_mask=batch.get("padding_mask"))
111
+
112
+ if "intermediates" in mutable_vars:
113
+ preds["intermediates"] = flax.core.unfreeze(mutable_vars["intermediates"])
114
+
115
+ return preds
116
+
117
+
118
+ def eval_continued_step(
119
+ model,
120
+ state_variables,
121
+ params,
122
+ batch,
123
+ rng,
124
+ recurrent_states
125
+ ):
126
+ """Get the model predictions, continuing from a provided recurrent state.
127
+
128
+ The model is applied to the inputs using all devices on the host.
129
+
130
+ Args:
131
+ model: Model used in eval step.
132
+ state_variables: State variables for the model.
133
+ params: The model parameters.
134
+ batch: Inputs that should be evaluated.
135
+ rng: PRNGKey for model forward pass.
136
+ recurrent_states: Recurrent internal model state from which to continue.
137
+ Returns:
138
+ The model's predictions.
139
+ """
140
+ logging.info("eval_continued_step(batch=%s, recurrent_states=%s)", batch,
141
+ recurrent_states)
142
+
143
+ preds, mutable_vars = model.apply(
144
+ {"params": params, **state_variables}, video=batch["video"],
145
+ conditioning=recurrent_states, continue_from_previous_state=True,
146
+ mutable="intermediates", rngs={"state_init": rng}, train=False,
147
+ padding_mask=batch.get("padding_mask"))
148
+
149
+ if "intermediates" in mutable_vars:
150
+ preds["intermediates"] = flax.core.unfreeze(mutable_vars["intermediates"])
151
+
152
+ return preds
153
+
154
+
155
+ def eval_step(
156
+ model,
157
+ state,
158
+ batch,
159
+ rng,
160
+ p_eval_first_step,
161
+ p_eval_continued_step,
162
+ slice_size = None,
163
+ slice_keys = None,
164
+ conditioning_key = None,
165
+ remove_from_predictions = None
166
+ ):
167
+ """Compute the metrics for the given model in inference mode.
168
+
169
+ The model is applied to the inputs using all devices on the host. Afterwards
170
+ metrics are averaged across *all* devices (of all hosts).
171
+
172
+ Args:
173
+ model: Model used in eval step.
174
+ state: Replicated model state.
175
+ batch: Inputs that should be evaluated.
176
+ rng: PRNGKey for model forward pass.
177
+ p_eval_first_step: A parallel version of the function eval_first_step.
178
+ p_eval_continued_step: A parallel version of the function
179
+ eval_continued_step.
180
+ slice_size: Optional integer, if provided, evaluate the model on temporal
181
+ slices of this size instead of on the full sequence length at once.
182
+ slice_keys: Optional list of strings, the keys of the tensors which will be
183
+ sliced if slice_size is provided.
184
+ conditioning_key: Optional string. If provided, defines the batch key to be
185
+ used as conditioning signal for the model. Otherwise this is inferred from
186
+ the available keys in the batch.
187
+ remove_from_predictions: Remove the provided keys. The default None removes
188
+ "states" and "states_pred" from model output to save memory. Disable this
189
+ if either of these are required in the loss function or for visualization.
190
+ Returns:
191
+ Model predictions.
192
+ """
193
+ if remove_from_predictions is None:
194
+ remove_from_predictions = ["states", "states_pred"]
195
+
196
+ seq_len = batch["video"].shape[2]
197
+ # Sliced evaluation (i.e. on smaller temporal slices of the video).
198
+ if slice_size is not None and slice_size < seq_len:
199
+ num_slices = int(np.ceil(seq_len / slice_size))
200
+
201
+ assert slice_keys is not None, (
202
+ "Slice keys need to be provided for sliced evaluation.")
203
+
204
+ preds_per_slice = []
205
+ # Get predictions for first slice (with fresh recurrent state).
206
+ batch_slice = utils.get_slices_along_axis(
207
+ batch, slice_keys=slice_keys, start_idx=0, end_idx=slice_size)
208
+ preds_slice = p_eval_first_step(model, state.variables,
209
+ state.params, batch_slice, rng,
210
+ conditioning_key)
211
+ preds_slice = jax.tree_map(np.asarray, preds_slice) # Copy to CPU.
212
+ preds_per_slice.append(preds_slice)
213
+
214
+ # Iterate over remaining slices (re-using the previous recurrent state).
215
+ for slice_idx in range(1, num_slices):
216
+ recurrent_states = preds_per_slice[-1]["states_pred"]
217
+ batch_slice = utils.get_slices_along_axis(
218
+ batch, slice_keys=slice_keys, start_idx=slice_idx * slice_size,
219
+ end_idx=(slice_idx + 1) * slice_size)
220
+ preds_slice = p_eval_continued_step(
221
+ model, state.variables, state.params,
222
+ batch_slice, rng, recurrent_states)
223
+ preds_slice = jax.tree_map(np.asarray, preds_slice) # Copy to CPU.
224
+ preds_per_slice.append(preds_slice)
225
+
226
+ # Remove states from predictions before concat to save memory.
227
+ for k in remove_from_predictions:
228
+ for i in range(num_slices):
229
+ _ = preds_per_slice[i].pop(k, None)
230
+
231
+ # Join predictions along sequence dimension.
232
+ concat_fn = lambda _, *x: functools.partial(np.concatenate, axis=2)([*x])
233
+ preds = jax.tree_map(concat_fn, preds_per_slice[0], *preds_per_slice)
234
+
235
+ # Truncate to original sequence length.
236
+ # NOTE: This op assumes that all predictions have a (complete) time axis.
237
+ preds = jax.tree_map(lambda x: x[:, :, :seq_len], preds)
238
+
239
+ # Evaluate on full sequence if no (or too large) slice size is provided.
240
+ else:
241
+ preds = p_eval_first_step(model, state.variables,
242
+ state.params, batch, rng,
243
+ conditioning_key)
244
+ for k in remove_from_predictions:
245
+ _ = preds.pop(k, None)
246
+
247
+ return preds
248
+
249
+
250
+ def evaluate(
251
+ model,
252
+ state,
253
+ eval_ds,
254
+ loss_fn,
255
+ eval_metrics_cls,
256
+ predicted_max_num_instances,
257
+ ground_truth_max_num_instances,
258
+ slice_size = None,
259
+ slice_keys = None,
260
+ conditioning_key = None,
261
+ remove_from_predictions = None,
262
+ metrics_on_cpu = False,
263
+ ):
264
+ """Evaluate the model on the given dataset."""
265
+ eval_metrics = None
266
+ batch = None
267
+ preds = None
268
+ rng = state.rng[0] # Get training state PRNGKey from first replica.
269
+
270
+ if metrics_on_cpu and jax.process_count() > 1:
271
+ raise NotImplementedError(
272
+ "metrics_on_cpu feature cannot be used in a multi-host setup."
273
+ " This experiment is using {} hosts.".format(jax.process_count()))
274
+ metric_devices = jax.devices("cpu") if metrics_on_cpu else jax.devices()
275
+
276
+ p_eval_first_step = jax.pmap(
277
+ eval_first_step,
278
+ axis_name="batch",
279
+ static_broadcasted_argnums=(0, 5),
280
+ devices=jax.devices())
281
+ p_eval_continued_step = jax.pmap(
282
+ eval_continued_step,
283
+ axis_name="batch",
284
+ static_broadcasted_argnums=(0),
285
+ devices=jax.devices())
286
+ p_get_eval_metrics = jax.pmap(
287
+ get_eval_metrics,
288
+ axis_name="batch",
289
+ static_broadcasted_argnums=(2, 3, 4, 5),
290
+ devices=metric_devices,
291
+ backend="cpu" if metrics_on_cpu else None)
292
+
293
+ def reshape_fn(x):
294
+ """Function to reshape preds and batch before calling p_get_eval_metrics."""
295
+ return np.reshape(x, [len(metric_devices), -1] + list(x.shape[2:]))
296
+
297
+ for batch in eval_ds:
298
+ rng, eval_rng = jax.random.split(rng)
299
+ eval_rng = jax.random.fold_in(eval_rng, jax.host_id()) # Bind to host.
300
+ eval_rngs = jax.random.split(eval_rng, jax.local_device_count())
301
+ batch = jax.tree_map(np.asarray, batch)
302
+ preds = eval_step(
303
+ model=model,
304
+ state=state,
305
+ batch=batch,
306
+ rng=eval_rngs,
307
+ p_eval_first_step=p_eval_first_step,
308
+ p_eval_continued_step=p_eval_continued_step,
309
+ slice_size=slice_size,
310
+ slice_keys=slice_keys,
311
+ conditioning_key=conditioning_key,
312
+ remove_from_predictions=remove_from_predictions)
313
+
314
+ if metrics_on_cpu:
315
+ # Reshape replica dim and batch-dims to work with metric_devices.
316
+ preds = jax.tree_map(reshape_fn, preds)
317
+ batch = jax.tree_map(reshape_fn, batch)
318
+ # Get metric updates.
319
+ update = p_get_eval_metrics(preds, batch, loss_fn, eval_metrics_cls,
320
+ predicted_max_num_instances,
321
+ ground_truth_max_num_instances)
322
+ update = flax.jax_utils.unreplicate(update)
323
+ eval_metrics = (
324
+ update if eval_metrics is None else eval_metrics.merge(update))
325
+ assert eval_metrics is not None
326
+ return eval_metrics, batch, preds
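
In the sliced-evaluation branch of `eval_step`, per-slice prediction pytrees are joined with `jax.tree_map` over `np.concatenate(axis=2)`; axis 2 is the time axis of the `[replica, batch, time, ...]`-shaped arrays (the same axis `seq_len` is read from). A small self-contained illustration of that trick, with made-up shapes:

import functools
import jax
import numpy as np

preds_per_slice = [
    {"outputs": {"segmentations": np.zeros((1, 2, 4, 8, 8))}},  # first temporal slice
    {"outputs": {"segmentations": np.ones((1, 2, 4, 8, 8))}},   # second temporal slice
]
# The first tree is passed twice; "_" drops the duplicate leaf, *x holds one leaf per slice.
concat_fn = lambda _, *x: functools.partial(np.concatenate, axis=2)([*x])
preds = jax.tree_map(concat_fn, preds_per_slice[0], *preds_per_slice)
assert preds["outputs"]["segmentations"].shape == (1, 2, 8, 8, 8)
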
invariant_slot_attention/lib/input_pipeline.py ADDED
@@ -0,0 +1,390 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Input pipeline for TFDS datasets."""
17
+
18
+ import functools
19
+ import os
20
+ from typing import Dict, List, Tuple
21
+
22
+ from clu import deterministic_data
23
+ from clu import preprocess_spec
24
+
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import ml_collections
28
+
29
+ import sunds
30
+ import tensorflow as tf
31
+ import tensorflow_datasets as tfds
32
+
33
+ from invariant_slot_attention.lib import preprocessing
34
+
35
+ Array = jnp.ndarray
36
+ PRNGKey = Array
37
+
38
+
39
+ PATH_CLEVR_WITH_MASKS = "gs://multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords"
40
+ FEATURES_CLEVR_WITH_MASKS = {
41
+ "image": tf.io.FixedLenFeature([240, 320, 3], tf.string),
42
+ "mask": tf.io.FixedLenFeature([11, 240, 320, 1], tf.string),
43
+ "x": tf.io.FixedLenFeature([11], tf.float32),
44
+ "y": tf.io.FixedLenFeature([11], tf.float32),
45
+ "z": tf.io.FixedLenFeature([11], tf.float32),
46
+ "pixel_coords": tf.io.FixedLenFeature([11, 3], tf.float32),
47
+ "rotation": tf.io.FixedLenFeature([11], tf.float32),
48
+ "size": tf.io.FixedLenFeature([11], tf.string),
49
+ "material": tf.io.FixedLenFeature([11], tf.string),
50
+ "shape": tf.io.FixedLenFeature([11], tf.string),
51
+ "color": tf.io.FixedLenFeature([11], tf.string),
52
+ "visibility": tf.io.FixedLenFeature([11], tf.float32),
53
+ }
54
+
55
+ PATH_TETROMINOES = "gs://multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords"
56
+ FEATURES_TETROMINOES = {
57
+ "image": tf.io.FixedLenFeature([35, 35, 3], tf.string),
58
+ "mask": tf.io.FixedLenFeature([4, 35, 35, 1], tf.string),
59
+ "x": tf.io.FixedLenFeature([4], tf.float32),
60
+ "y": tf.io.FixedLenFeature([4], tf.float32),
61
+ "shape": tf.io.FixedLenFeature([4], tf.float32),
62
+ "color": tf.io.FixedLenFeature([4, 3], tf.float32),
63
+ "visibility": tf.io.FixedLenFeature([4], tf.float32),
64
+ }
65
+
66
+ PATH_OBJECTS_ROOM = "gs://multi-object-datasets/objects_room/objects_room_train.tfrecords"
67
+ FEATURES_OBJECTS_ROOM = {
68
+ "image": tf.io.FixedLenFeature([64, 64, 3], tf.string),
69
+ "mask": tf.io.FixedLenFeature([7, 64, 64, 1], tf.string),
70
+ }
71
+
72
+ PATH_WAYMO_OPEN = "datasets/waymo_v_1_4_0_images/tfrecords"
73
+
74
+ FEATURES_WAYMO_OPEN = {
75
+ "image": tf.io.FixedLenFeature([128, 192, 3], tf.string),
76
+ "segmentations": tf.io.FixedLenFeature([128, 192], tf.string),
77
+ "depth": tf.io.FixedLenFeature([128, 192], tf.float32),
78
+ "num_objects": tf.io.FixedLenFeature([1], tf.int64),
79
+ "has_mask": tf.io.FixedLenFeature([1], tf.int64),
80
+ "camera": tf.io.FixedLenFeature([1], tf.int64),
81
+ }
82
+
83
+
84
+ def _decode_tetrominoes(example_proto):
85
+ single_example = tf.io.parse_single_example(
86
+ example_proto, FEATURES_TETROMINOES)
87
+ for k in ["mask", "image"]:
88
+ single_example[k] = tf.squeeze(
89
+ tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
90
+ return single_example
91
+
92
+
93
+ def _decode_objects_room(example_proto):
94
+ single_example = tf.io.parse_single_example(
95
+ example_proto, FEATURES_OBJECTS_ROOM)
96
+ for k in ["mask", "image"]:
97
+ single_example[k] = tf.squeeze(
98
+ tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
99
+ return single_example
100
+
101
+
102
+ def _decode_clevr_with_masks(example_proto):
103
+ single_example = tf.io.parse_single_example(
104
+ example_proto, FEATURES_CLEVR_WITH_MASKS)
105
+ for k in ["mask", "image", "color", "material", "shape", "size"]:
106
+ single_example[k] = tf.squeeze(
107
+ tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
108
+ return single_example
109
+
110
+
111
+ def _decode_waymo_open(example_proto):
112
+ """Unserializes a serialized tf.train.Example sample."""
113
+ single_example = tf.io.parse_single_example(
114
+ example_proto, FEATURES_WAYMO_OPEN)
115
+ for k in ["image", "segmentations"]:
116
+ single_example[k] = tf.squeeze(
117
+ tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
118
+ single_example["segmentations"] = tf.expand_dims(
119
+ single_example["segmentations"], axis=-1)
120
+ single_example["depth"] = tf.expand_dims(
121
+ single_example["depth"], axis=-1)
122
+ return single_example
123
+
124
+
125
+ def _preprocess_minimal(example):
126
+ return {
127
+ "image": example["image"],
128
+ "segmentations": tf.cast(tf.argmax(example["mask"], axis=0), tf.uint8),
129
+ }
130
+
131
+
132
+ def _sunds_create_task():
133
+ """Create a sunds task to return images and instance segmentation."""
134
+ return sunds.tasks.Nerf(
135
+ yield_mode=sunds.tasks.YieldMode.IMAGE,
136
+ additional_camera_specs={
137
+ "depth_image": False, # Not available in the dataset.
138
+ "category_image": False, # Not available in the dataset.
139
+ "instance_image": True,
140
+ "extrinsics": True,
141
+ },
142
+ additional_frame_specs={"pose": True},
143
+ add_name=True
144
+ )
145
+
146
+
147
+ def preprocess_example(features,
148
+ preprocess_strs):
149
+ """Processes a single data example.
150
+
151
+ Args:
152
+ features: A dictionary containing the tensors of a single data example.
153
+ preprocess_strs: List of strings, describing one preprocessing operation
154
+ each, in clu.preprocess_spec format.
155
+
156
+ Returns:
157
+ Dictionary containing the preprocessed tensors of a single data example.
158
+ """
159
+ all_ops = preprocessing.all_ops()
160
+ preprocess_fn = preprocess_spec.parse("|".join(preprocess_strs), all_ops)
161
+ return preprocess_fn(features) # pytype: disable=bad-return-type # allow-recursive-types
162
+
163
+
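For illustration, each entry in `preprocess_strs` names one of the ops defined in preprocessing.py (the snake_case form of the class name) plus its arguments, in clu.preprocess_spec syntax. The op names and values below are a sketch, not copied from a shipped config; the real specs live under invariant_slot_attention/configs/.

# Hypothetical preprocess spec in clu.preprocess_spec format.
preproc_train = [
    "video_from_tfds",
    "sparse_to_dense_annotation(max_instances=10)",
    "resize_small(64)",
    "central_crop(height=64, width=64)",
]
# preprocess_example joins the list with "|" and parses it into a single
# callable, which the input pipeline then applies to every example.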
164
+ def get_batch_dims(global_batch_size):
165
+ """Gets the first two axis sizes for data batches.
166
+
167
+ Args:
168
+ global_batch_size: Integer, the global batch size (across all devices).
169
+
170
+ Returns:
171
+ List of batch dimensions
172
+
173
+ Raises:
174
+ ValueError if the requested dimensions don't make sense with the
175
+ number of devices.
176
+ """
177
+ num_local_devices = jax.local_device_count()
178
+ if global_batch_size % jax.host_count() != 0:
179
+ raise ValueError(f"Global batch size {global_batch_size} not evenly "
180
+ f"divisible by {jax.host_count()} hosts.")
181
+ per_host_batch_size = global_batch_size // jax.host_count()
182
+ if per_host_batch_size % num_local_devices != 0:
183
+ raise ValueError(f"Global batch size {global_batch_size} not evenly "
184
+ f"divisible by {jax.host_count()} hosts with a per host "
185
+ f"batch size of {per_host_batch_size} and "
186
+ f"{num_local_devices} local devices. ")
187
+ return [num_local_devices, per_host_batch_size // num_local_devices]
188
+
189
+
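A quick worked example of the returned shape, assuming 2 hosts with 8 local devices each:

# get_batch_dims(64):
#   per_host_batch_size = 64 // 2 = 32
#   -> [8, 4], i.e. [jax.local_device_count(), per-device batch size],
#   so that axis 0 can be mapped over local devices (e.g. with jax.pmap).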
190
+ def create_datasets(
191
+ config,
192
+ data_rng):
193
+ """Create datasets for training and evaluation.
194
+
195
+ For the same data_rng and config this will return the same datasets. The
196
+ datasets only contain stateless operations.
197
+
198
+ Args:
199
+ config: Configuration to use.
200
+ data_rng: JAX PRNGKey for dataset pipeline.
201
+
202
+ Returns:
203
+ A tuple with the training dataset and the evaluation dataset.
204
+ """
205
+
206
+ if config.data.dataset_name == "tetrominoes":
207
+ ds = tf.data.TFRecordDataset(
208
+ PATH_TETROMINOES,
209
+ compression_type="GZIP", buffer_size=2*(2**20))
210
+ ds = ds.map(_decode_tetrominoes,
211
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
212
+ ds = ds.map(_preprocess_minimal,
213
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
214
+
215
+ class TetrominoesBuilder:
216
+ """Builder for the tetrominoes dataset."""
217
+
218
+ def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
219
+ """Simple function to conform to the builder api."""
220
+ if split == "train":
221
+ # We use 512 training examples.
222
+ ds = ds.skip(100)
223
+ ds = ds.take(512)
224
+ return tf.data.experimental.assert_cardinality(512)(ds)
225
+ elif split == "validation":
226
+ # 100 validation examples.
227
+ ds = ds.take(100)
228
+ return tf.data.experimental.assert_cardinality(100)(ds)
229
+ else:
230
+ raise ValueError("Invalid split.")
231
+
232
+ dataset_builder = TetrominoesBuilder()
233
+ elif config.data.dataset_name == "objects_room":
234
+ ds = tf.data.TFRecordDataset(
235
+ PATH_OBJECTS_ROOM,
236
+ compression_type="GZIP", buffer_size=2*(2**20))
237
+ ds = ds.map(_decode_objects_room,
238
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
239
+ ds = ds.map(_preprocess_minimal,
240
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
241
+
242
+ class ObjectsRoomBuilder:
243
+ """Builder for objects room dataset."""
244
+
245
+ def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
246
+ """Simple function to conform to the builder api."""
247
+ if split == "train":
248
+ # 1M - 100 training examples.
249
+ ds = ds.skip(100)
250
+ return tf.data.experimental.assert_cardinality(999900)(ds)
251
+ elif split == "validation":
252
+ # 100 validation examples.
253
+ ds = ds.take(100)
254
+ return tf.data.experimental.assert_cardinality(100)(ds)
255
+ else:
256
+ raise ValueError("Invalid split.")
257
+
258
+ dataset_builder = ObjectsRoomBuilder()
259
+ elif config.data.dataset_name == "clevr_with_masks":
260
+ ds = tf.data.TFRecordDataset(
261
+ PATH_CLEVR_WITH_MASKS,
262
+ compression_type="GZIP", buffer_size=2*(2**20))
263
+ ds = ds.map(_decode_clevr_with_masks,
264
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
265
+ ds = ds.map(_preprocess_minimal,
266
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
267
+
268
+ class CLEVRWithMasksBuilder:
269
+ def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
270
+ if split == "train":
271
+ ds = ds.skip(100)
272
+ return tf.data.experimental.assert_cardinality(99900)(ds)
273
+ elif split == "validation":
274
+ ds = ds.take(100)
275
+ return tf.data.experimental.assert_cardinality(100)(ds)
276
+ else:
277
+ raise ValueError("Invalid split.")
278
+
279
+ dataset_builder = CLEVRWithMasksBuilder()
280
+ elif config.data.dataset_name == "waymo_open":
281
+ train_path = os.path.join(
282
+ PATH_WAYMO_OPEN, "training/camera_1/*tfrecords*")
283
+ eval_path = os.path.join(
284
+ PATH_WAYMO_OPEN, "validation/camera_1/*tfrecords*")
285
+
286
+ train_files = tf.data.Dataset.list_files(train_path)
287
+ eval_files = tf.data.Dataset.list_files(eval_path)
288
+
289
+ train_data_reader = functools.partial(
290
+ tf.data.TFRecordDataset,
291
+ compression_type="ZLIB", buffer_size=2*(2**20))
292
+ eval_data_reader = functools.partial(
293
+ tf.data.TFRecordDataset,
294
+ compression_type="ZLIB", buffer_size=2*(2**20))
295
+
296
+ train_dataset = train_files.interleave(
297
+ train_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)
298
+ eval_dataset = eval_files.interleave(
299
+ eval_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)
300
+
301
+ train_dataset = train_dataset.map(
302
+ _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE)
303
+ eval_dataset = eval_dataset.map(
304
+ _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE)
305
+
306
+ # We need to set the dataset cardinality. We assume we have
307
+ # the full dataset.
308
+ train_dataset = train_dataset.apply(
309
+ tf.data.experimental.assert_cardinality(158081))
310
+
311
+ class WaymoOpenBuilder:
312
+ def as_dataset(self, split, *unused_args, **unused_kwargs):
313
+ if split == "train":
314
+ return train_dataset
315
+ elif split == "validation":
316
+ return eval_dataset
317
+ else:
318
+ raise ValueError("Invalid split.")
319
+
320
+ dataset_builder = WaymoOpenBuilder()
321
+ elif config.data.dataset_name == "multishapenet_easy":
322
+ dataset_builder = sunds.builder(
323
+ name=config.get("tfds_name", "msn_easy"),
324
+ data_dir=config.get(
325
+ "data_dir", "gs://kubric-public/tfds"),
326
+ try_gcs=True)
327
+ dataset_builder.as_dataset = functools.partial(
328
+ dataset_builder.as_dataset, task=_sunds_create_task())
329
+ elif config.data.dataset_name == "tfds":
330
+ dataset_builder = tfds.builder(
331
+ config.data.tfds_name, data_dir=config.data.data_dir)
332
+ else:
333
+ raise ValueError("Please specify a valid dataset name.")
334
+
335
+ batch_dims = get_batch_dims(config.batch_size)
336
+
337
+ train_preprocess_fn = functools.partial(
338
+ preprocess_example, preprocess_strs=config.preproc_train)
339
+ eval_preprocess_fn = functools.partial(
340
+ preprocess_example, preprocess_strs=config.preproc_eval)
341
+
342
+ train_split_name = config.get("train_split", "train")
343
+ eval_split_name = config.get("validation_split", "validation")
344
+
345
+ train_ds = deterministic_data.create_dataset(
346
+ dataset_builder,
347
+ split=train_split_name,
348
+ rng=data_rng,
349
+ preprocess_fn=train_preprocess_fn,
350
+ cache=False,
351
+ shuffle_buffer_size=config.data.shuffle_buffer_size,
352
+ batch_dims=batch_dims,
353
+ num_epochs=None,
354
+ shuffle=True)
355
+
356
+ if config.data.dataset_name == "waymo_open":
357
+ # We filter Waymo Open for empty segmentation masks.
358
+ def filter_fn(features):
359
+ unique_instances = tf.unique(
360
+ tf.reshape(features[preprocessing.SEGMENTATIONS], (-1,)))[0]
361
+ n_instances = tf.size(unique_instances, tf.int32)
362
+ # n_instances == 1 means we only have the background.
363
+ return 2 <= n_instances
364
+ else:
365
+ filter_fn = None
366
+
367
+ eval_ds = deterministic_data.create_dataset(
368
+ dataset_builder,
369
+ split=eval_split_name,
370
+ rng=None,
371
+ preprocess_fn=eval_preprocess_fn,
372
+ filter_fn=filter_fn,
373
+ cache=False,
374
+ batch_dims=batch_dims,
375
+ num_epochs=1,
376
+ shuffle=False,
377
+ pad_up_to_batches=None)
378
+
379
+ if config.data.dataset_name == "waymo_open":
380
+ # We filter Waymo Open for empty segmentation masks after preprocessing.
381
+ # For the full dataset, we know how many we will end up with.
382
+ eval_batch_size = batch_dims[0] * batch_dims[1]
383
+ # We don't pad the last batch => floor.
384
+ eval_num_batches = int(
385
+ jnp.floor(1872 / eval_batch_size / jax.host_count()))
386
+ eval_ds = eval_ds.apply(
387
+ tf.data.experimental.assert_cardinality(
388
+ eval_num_batches))
389
+
390
+ return train_ds, eval_ds
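A minimal usage sketch for the generic "tfds" branch. The field names follow the reads in create_datasets above; the dataset name, sizes and preprocess strings are placeholder assumptions, not a tested configuration.

import jax
import ml_collections

config = ml_collections.ConfigDict({
    "batch_size": 64,
    "preproc_train": ["video_from_tfds"],   # placeholder spec strings
    "preproc_eval": ["video_from_tfds"],
    "data": {
        "dataset_name": "tfds",
        "tfds_name": "clevr",                # placeholder TFDS dataset
        "data_dir": None,                    # default TFDS data dir
        "shuffle_buffer_size": 8 * 64,
    },
})
train_ds, eval_ds = create_datasets(config, data_rng=jax.random.PRNGKey(0))
# train_ds repeats indefinitely (num_epochs=None); eval_ds is a single pass.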
invariant_slot_attention/lib/losses.py ADDED
@@ -0,0 +1,295 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Loss functions."""
17
+
18
+ import functools
19
+ import inspect
20
+ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
21
+
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import ml_collections
25
+
26
+ _LOSS_FUNCTIONS = {}
27
+
28
+ Array = Any # jnp.ndarray somehow doesn't work anymore for pytype.
29
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
30
+ ArrayDict = Dict[str, Array]
31
+ DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet
32
+ PRNGKey = Array
33
+ LossFn = Callable[[Dict[str, ArrayTree], Dict[str, ArrayTree]],
34
+ Tuple[Array, ArrayTree]]
35
+ ConfigAttr = Any
36
+ MetricSpec = Dict[str, str]
37
+
38
+
39
+ def standardize_loss_config(
40
+ loss_config
41
+ ):
42
+ """Standardize loss configs into a common ConfigDict format.
43
+
44
+ Args:
45
+ loss_config: List of strings or ConfigDict specifying loss configuration.
46
+ Valid input formats are: - Option 1 (list of strings), for example,
47
+ `loss_config = ["box", "presence"]` - Option 2 (losses with weights
48
+ only), for example,
49
+ `loss_config = ConfigDict({"box": 5, "presence": 2})` - Option 3
50
+ (losses with weights and other parameters), for example,
51
+ `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"},
52
+ "presence": {"weight": 2}})`
53
+
54
+ Returns:
55
+ Standardized ConfigDict containing the loss configuration.
56
+
57
+ Raises:
58
+ ValueError: If loss_config is a list that contains non-string entries.
59
+ """
60
+
61
+ if isinstance(loss_config, Sequence): # Option 1
62
+ if not all(isinstance(loss_type, str) for loss_type in loss_config):
63
+ raise ValueError(f"Loss types all need to be str but got {loss_config}")
64
+ return ml_collections.FrozenConfigDict({k: {} for k in loss_config})
65
+
66
+ # Convert all option-2-style weights to option-3-style dictionaries.
67
+ loss_config = {
68
+ k: {
69
+ "weight": v
70
+ } if isinstance(v, (float, int)) else v for k, v in loss_config.items()
71
+ }
72
+ return ml_collections.FrozenConfigDict(loss_config)
73
+
74
+
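For reference, the three accepted formats all normalize to the same frozen structure (expected results shown as comments):

import ml_collections

standardize_loss_config(["recon"])
# -> FrozenConfigDict({"recon": {}})

standardize_loss_config(ml_collections.ConfigDict({"recon": 5}))
# -> FrozenConfigDict({"recon": {"weight": 5}})

standardize_loss_config(
    ml_collections.ConfigDict(
        {"recon": {"weight": 5, "reduction_type": "mean"}}))
# -> the same dictionary, frozen.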
75
+ def update_loss_aux(loss_aux, update):
76
+ existing_keys = set(update.keys()).intersection(loss_aux.keys())
77
+ if existing_keys:
78
+ raise KeyError(
79
+ f"Can't overwrite existing keys in loss_aux: {existing_keys}")
80
+ loss_aux.update(update)
81
+
82
+
83
+ def compute_full_loss(
84
+ preds, targets,
85
+ loss_config
86
+ ):
87
+ """Loss function that parses and combines weighted loss terms.
88
+
89
+ Args:
90
+ preds: Dictionary of tensors containing model predictions.
91
+ targets: Dictionary of tensors containing prediction targets.
92
+ loss_config: List of strings or ConfigDict specifying loss configuration.
93
+ See @register_loss decorated functions below for valid loss names.
94
+ Valid losses formats are: - Option 1 (list of strings), for example,
95
+ `loss_config = ["box", "presence"]` - Option 2 (losses with weights
96
+ only), for example,
97
+ `loss_config = ConfigDict({"box": 5, "presence": 2})` - Option 3 (losses
98
+ with weights and other parameters), for example,
99
+ `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"},
100
+ "presence": {"weight": 2}})` - Option 4 (like
101
+ 3 but decoupling name and loss_type), for
102
+ example,
103
+ `loss_config = ConfigDict({"recon_flow": {"loss_type": "recon",
104
+ "key": "flow"},
105
+ "recon_video": {"loss_type": "recon",
106
+ "key": "video"}})`
107
+
108
+ Returns:
109
+ A 2-tuple of the sum of all individual loss terms and a dictionary of
110
+ auxiliary losses and metrics.
111
+ """
112
+
113
+ loss = jnp.zeros([], jnp.float32)
114
+ loss_aux = {}
115
+ loss_config = standardize_loss_config(loss_config)
116
+ for loss_name, cfg in loss_config.items():
117
+ context_kwargs = {"preds": preds, "targets": targets}
118
+ weight, loss_term, loss_aux_update = compute_loss_term(
119
+ loss_name=loss_name, context_kwargs=context_kwargs, config_kwargs=cfg)
120
+
121
+ unweighted_loss = jnp.mean(loss_term)
122
+ loss += weight * unweighted_loss
123
+ loss_aux_update[loss_name + "_value"] = unweighted_loss
124
+ loss_aux_update[loss_name + "_weight"] = jnp.ones_like(unweighted_loss)
125
+ update_loss_aux(loss_aux, loss_aux_update)
126
+ return loss, loss_aux
127
+
128
+
129
+ def register_loss(func=None,
130
+ *,
131
+ name = None,
132
+ check_unused_kwargs = True):
133
+ """Decorator for registering a loss function.
134
+
135
+ Can be used without arguments:
136
+ ```
137
+ @register_loss
138
+ def my_loss(**_):
139
+ return 0
140
+ ```
141
+ or with keyword arguments:
142
+ ```
143
+ @register_loss(name="my_renamed_loss")
144
+ def my_loss(**_):
145
+ return 0
146
+ ```
147
+
148
+ Loss functions may accept
149
+ - context kwargs: `preds` and `targets`
150
+ - config kwargs: any argument specified in the config
151
+ - the special `config_kwargs` parameter that contains the entire loss config
152
+ Loss functions also _need_ to accept a **kwarg argument to support extending
153
+ the interface.
154
+ They should return either:
155
+ - just the computed loss (pre-reduction)
156
+ - or a tuple of the computed loss and a loss_aux_updates dict
157
+
158
+ Args:
159
+ func: the decorated function
160
+ name (str): Optional name to be used for this loss in the config. Defaults
161
+ to the name of the function.
162
+ check_unused_kwargs (bool): By default compute_loss_term raises an error if
163
+ there are any unused config kwargs. If this flag is set to False that step
164
+ is skipped. This is useful if the config_kwargs should be passed onward to
165
+ another function.
166
+
167
+ Returns:
168
+ The decorated function (or a partial of the decorator)
169
+ """
170
+ # If this decorator has been called with parameters but no function, then we
171
+ # return the decorator again (but with partially filled parameters).
172
+ # This allows using both @register_loss and @register_loss(name="foo")
173
+ if func is None:
174
+ return functools.partial(
175
+ register_loss, name=name, check_unused_kwargs=check_unused_kwargs)
176
+
177
+ # No (further) arguments: this is the actual decorator
178
+ # ensure that the loss function includes a **kwargs argument
179
+ loss_name = name if name is not None else func.__name__
180
+ if not any(v.kind == inspect.Parameter.VAR_KEYWORD
181
+ for k, v in inspect.signature(func).parameters.items()):
182
+ raise TypeError(
183
+ f"Loss function '{loss_name}' needs to include a **kwargs argument")
184
+ func.name = loss_name
185
+ func.check_unused_kwargs = check_unused_kwargs
186
+ _LOSS_FUNCTIONS[loss_name] = func
187
+ return func
188
+
189
+
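As a sketch of the extension point, a hypothetical L1 reconstruction loss could be registered under its own config name as follows; the name "recon_l1", its aux key, and the preds/targets layout mirror the `recon` loss defined later in this file.

@register_loss(name="recon_l1")
def l1_recon(preds, targets, key="video", **_):
  """Hypothetical mean absolute error between prediction and target."""
  loss = jnp.mean(jnp.abs(preds["outputs"][key] - targets[key]))
  # Returning a (loss, aux_updates) tuple is optional; a bare loss also works.
  return loss, {"recon_l1_abs_error": loss}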
190
+ def compute_loss_term(
191
+ loss_name, context_kwargs,
192
+ config_kwargs):
193
+ """Compute a loss function given its config and context parameters.
194
+
195
+ Takes care of:
196
+ - finding the correct loss function based on "loss_type" or name
197
+ - the optional "weight" parameter
198
+ - checking for typos and collisions in config parameters
199
+ - adding the optional loss_aux_updates if omitted by the loss_fn
200
+
201
+ Args:
202
+ loss_name: Name of the loss, i.e. its key in the config.losses dict.
203
+ context_kwargs: Dictionary of context variables (`preds` and `targets`)
204
+ config_kwargs: The config dict for this loss.
205
+
206
+ Returns:
207
+ 1. the loss weight (float)
208
+ 2. loss term (Array)
209
+ 3. loss aux updates (Dict[str, Array])
210
+
211
+ Raises:
212
+ KeyError:
213
+ Unknown loss_type
214
+ KeyError:
215
+ Unused config entries, i.e. not used by the loss function.
216
+ Not raised if using @register_loss(check_unused_kwargs=False)
217
+ KeyError: Config entry with a name that conflicts with a context_kwarg
218
+ ValueError: Non-numerical weight in config_kwargs
219
+
220
+ """
221
+
222
+ # Make a dict copy of config_kwargs
223
+ kwargs = {k: v for k, v in config_kwargs.items()}
224
+
225
+ # Get the loss function
226
+ loss_type = kwargs.pop("loss_type", loss_name)
227
+ if loss_type not in _LOSS_FUNCTIONS:
228
+ raise KeyError(f"Unknown loss_type '{loss_type}'.")
229
+ loss_fn = _LOSS_FUNCTIONS[loss_type]
230
+
231
+ # Take care of "weight" term
232
+ weight = kwargs.pop("weight", 1.0)
233
+ if not isinstance(weight, (int, float)):
234
+ raise ValueError(f"Weight for loss {loss_name} should be a number, "
235
+ f"but was {weight}.")
236
+
237
+ # Check for unused config entries (to prevent typos etc.)
238
+ config_keys = set(kwargs)
239
+ if loss_fn.check_unused_kwargs:
240
+ param_names = set(inspect.signature(loss_fn).parameters)
241
+ unused_config_keys = config_keys - param_names
242
+ if unused_config_keys:
243
+ raise KeyError(f"Unrecognized config entries {unused_config_keys} "
244
+ f"for loss {loss_name}.")
245
+
246
+ # Check for key collisions between context and config
247
+ conflicting_config_keys = config_keys.intersection(context_kwargs)
248
+ if conflicting_config_keys:
249
+ raise KeyError(f"The config keys {conflicting_config_keys} conflict "
250
+ f"with the context parameters ({context_kwargs.keys()}) "
251
+ f"for loss {loss_name}.")
252
+
253
+ # Construct the arguments for the loss function
254
+ kwargs.update(context_kwargs)
255
+ kwargs["config_kwargs"] = config_kwargs
256
+
257
+ # Call loss
258
+ results = loss_fn(**kwargs)
259
+
260
+ # Add empty loss_aux_updates if necessary
261
+ if isinstance(results, Tuple):
262
+ loss, loss_aux_update = results
263
+ else:
264
+ loss, loss_aux_update = results, {}
265
+
266
+ return weight, loss, loss_aux_update
267
+
268
+
269
+ # -------- Loss functions --------
270
+ @register_loss
271
+ def recon(preds,
272
+ targets,
273
+ key = "video",
274
+ reduction_type = "sum",
275
+ **_):
276
+ """Reconstruction loss (MSE)."""
277
+ squared_l2_norm_fn = jax.vmap(functools.partial(
278
+ squared_l2_norm, reduction_type=reduction_type))
279
+ targets = targets[key]
280
+ loss = squared_l2_norm_fn(preds["outputs"][key], targets)
281
+ if reduction_type == "mean":
282
+ # This rescaling reflects taking the sum over feature axis &
283
+ # mean over space/time axes.
284
+ loss *= targets.shape[-1] # pytype: disable=attribute-error # allow-recursive-types
285
+ return jnp.mean(loss)
286
+
287
+
288
+ def squared_l2_norm(preds, targets,
289
+ reduction_type = "sum"):
290
+ if reduction_type == "sum":
291
+ return jnp.sum(jnp.square(preds - targets))
292
+ elif reduction_type == "mean":
293
+ return jnp.mean(jnp.square(preds - targets))
294
+ else:
295
+ raise ValueError(f"Unsupported reduction_type: {reduction_type}")
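Putting the pieces together on toy tensors (shapes and values are purely illustrative; the preds/targets layout follows the `recon` loss above):

import jax.numpy as jnp
import ml_collections

targets = {"video": jnp.zeros((2, 4, 8, 8, 3))}
preds = {"outputs": {"video": jnp.ones((2, 4, 8, 8, 3))}}
loss_config = ml_collections.ConfigDict(
    {"recon": {"weight": 1.0, "reduction_type": "sum"}})
total, aux = compute_full_loss(preds, targets, loss_config)
# total == 4 * 8 * 8 * 3: the per-example sum of squared errors, averaged
# over the batch and multiplied by the weight. aux holds "recon_value" and
# "recon_weight".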
invariant_slot_attention/lib/metrics.py ADDED
@@ -0,0 +1,263 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Clustering metrics."""
17
+
18
+ from typing import Optional, Sequence, Union
19
+
20
+ from clu import metrics
21
+ import flax
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import numpy as np
25
+
26
+ Ndarray = Union[np.ndarray, jnp.ndarray]
27
+
28
+
29
+ def check_shape(x, expected_shape, name):
30
+ """Check whether shape x is as expected.
31
+
32
+ Args:
33
+ x: Any data type with `shape` attribute. If `shape` attribute is not present
34
+ it is assumed to be a scalar with shape ().
35
+ expected_shape: The shape that is expected of x. For example,
36
+ [None, None, 3] can be the `expected_shape` for a color image,
37
+ [4, None, None, 3] if we know that batch size is 4.
38
+ name: Name of `x` to provide informative error messages.
39
+
40
+ Raises: ValueError if x's shape does not match expected_shape. Also raises
41
+ ValueError if expected_shape is not a list or tuple.
42
+ """
43
+ if not isinstance(expected_shape, (list, tuple)):
44
+ raise ValueError(
45
+ "expected_shape should be a list or tuple of ints but got "
46
+ f"{expected_shape}.")
47
+
48
+ # Scalars have shape () by definition.
49
+ shape = getattr(x, "shape", ())
50
+
51
+ if (len(shape) != len(expected_shape) or
52
+ any(j is not None and i != j for i, j in zip(shape, expected_shape))):
53
+ raise ValueError(
54
+ f"Input {name} had shape {shape} but {expected_shape} was expected.")
55
+
56
+
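A quick illustration of the wildcard semantics (None matches any size, fixed entries must match exactly):

import numpy as np

# Passes: rank 4 and the fixed channel dimension (3) both match.
check_shape(np.zeros((4, 64, 64, 3)), [None, None, None, 3], "rgb_batch")
# Would raise ValueError (rank mismatch):
# check_shape(np.zeros((64, 64, 3)), [None, None, None, 3], "rgb_batch")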
57
+ def _validate_inputs(predicted_segmentations,
58
+ ground_truth_segmentations,
59
+ padding_mask,
60
+ mask = None):
61
+ """Checks that all inputs have the expected shapes.
62
+
63
+ Args:
64
+ predicted_segmentations: An array of integers of shape [bs, seq_len, H, W]
65
+ containing model segmentation predictions.
66
+ ground_truth_segmentations: An array of integers of shape [bs, seq_len, H,
67
+ W] containing ground truth segmentations.
68
+ padding_mask: An array of integers of shape [bs, seq_len, H, W] defining
69
+ regions where the ground truth is meaningless, for example because this
70
+ corresponds to regions which were padded during data augmentation. Value 0
71
+ corresponds to padded regions, 1 corresponds to valid regions to be used
72
+ for metric calculation.
73
+ mask: An optional array of boolean mask values of shape [bs]. `True`
74
+ corresponds to actual batch examples whereas `False` corresponds to
75
+ padding.
76
+
77
+ Raises:
78
+ ValueError if the inputs are not valid.
79
+ """
80
+
81
+ check_shape(
82
+ predicted_segmentations, [None, None, None, None],
83
+ "predicted_segmentations [bs, seq_len, h, w]")
84
+ check_shape(
85
+ ground_truth_segmentations, [None, None, None, None],
86
+ "ground_truth_segmentations [bs, seq_len, h, w]")
87
+ check_shape(
88
+ predicted_segmentations, ground_truth_segmentations.shape,
89
+ "predicted_segmentations [should match ground_truth_segmentations]")
90
+ check_shape(
91
+ padding_mask, ground_truth_segmentations.shape,
92
+ "padding_mask [should match ground_truth_segmentations]")
93
+
94
+ if not jnp.issubdtype(predicted_segmentations.dtype, jnp.integer):
95
+ raise ValueError("predicted_segmentations has to be integer-valued. "
96
+ "Got {}".format(predicted_segmentations.dtype))
97
+
98
+ if not jnp.issubdtype(ground_truth_segmentations.dtype, jnp.integer):
99
+ raise ValueError("ground_truth_segmentations has to be integer-valued. "
100
+ "Got {}".format(ground_truth_segmentations.dtype))
101
+
102
+ if not jnp.issubdtype(padding_mask.dtype, jnp.integer):
103
+ raise ValueError("padding_mask has to be integer-valued. "
104
+ "Got {}".format(padding_mask.dtype))
105
+
106
+ if mask is not None:
107
+ check_shape(mask, [None], "mask [bs]")
108
+ if not jnp.issubdtype(mask.dtype, jnp.bool_):
109
+ raise ValueError("mask has to be boolean. Got {}".format(mask.dtype))
110
+
111
+
112
+ def adjusted_rand_index(true_ids, pred_ids,
113
+ num_instances_true, num_instances_pred,
114
+ padding_mask = None,
115
+ ignore_background = False):
116
+ """Computes the adjusted Rand index (ARI), a clustering similarity score.
117
+
118
+ Args:
119
+ true_ids: An integer-valued array of shape
120
+ [batch_size, seq_len, H, W]. The true cluster assignment encoded
121
+ as integer ids.
122
+ pred_ids: An integer-valued array of shape
123
+ [batch_size, seq_len, H, W]. The predicted cluster assignment
124
+ encoded as integer ids.
125
+ num_instances_true: An integer, the number of instances in true_ids
126
+ (i.e. max(true_ids) + 1).
127
+ num_instances_pred: An integer, the number of instances in true_ids
128
+ (i.e. max(pred_ids) + 1).
129
+ padding_mask: An array of integers of shape [batch_size, seq_len, H, W]
130
+ defining regions where the ground truth is meaningless, for example
131
+ because this corresponds to regions which were padded during data
132
+ augmentation. Value 0 corresponds to padded regions, 1 corresponds to
133
+ valid regions to be used for metric calculation.
134
+ ignore_background: Boolean, if True, then ignore all pixels where
135
+ true_ids == 0 (default: False).
136
+
137
+ Returns:
138
+ ARI scores as a float32 array of shape [batch_size].
139
+
140
+ References:
141
+ Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions"
142
+ https://link.springer.com/article/10.1007/BF01908075
143
+ Wikipedia
144
+ https://en.wikipedia.org/wiki/Rand_index
145
+ Scikit Learn
146
+ http://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
147
+ """
148
+ # pylint: disable=invalid-name
149
+ true_oh = jax.nn.one_hot(true_ids, num_instances_true)
150
+ pred_oh = jax.nn.one_hot(pred_ids, num_instances_pred)
151
+ if padding_mask is not None:
152
+ true_oh = true_oh * padding_mask[Ellipsis, None]
153
+ # pred_oh = pred_oh * padding_mask[..., None] # <-- not needed
154
+
155
+ if ignore_background:
156
+ true_oh = true_oh[Ellipsis, 1:] # Remove the background row.
157
+
158
+ N = jnp.einsum("bthwc,bthwk->bck", true_oh, pred_oh)
159
+ A = jnp.sum(N, axis=-1) # row-sum (batch_size, c)
160
+ B = jnp.sum(N, axis=-2) # col-sum (batch_size, k)
161
+ num_points = jnp.sum(A, axis=1)
162
+
163
+ rindex = jnp.sum(N * (N - 1), axis=[1, 2])
164
+ aindex = jnp.sum(A * (A - 1), axis=1)
165
+ bindex = jnp.sum(B * (B - 1), axis=1)
166
+ expected_rindex = aindex * bindex / jnp.clip(num_points * (num_points-1), 1)
167
+ max_rindex = (aindex + bindex) / 2
168
+ denominator = max_rindex - expected_rindex
169
+ ari = (rindex - expected_rindex) / denominator
170
+
171
+ # There are two cases for which the denominator can be zero:
172
+ # 1. If both label_pred and label_true assign all pixels to a single cluster.
173
+ # (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
174
+ # 2. If both label_pred and label_true assign max 1 point to each cluster.
175
+ # (max_rindex == expected_rindex == rindex == 0)
176
+ # In both cases, we want the ARI score to be 1.0:
177
+ return jnp.where(denominator, ari, 1.0)
178
+
179
+
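A toy sanity check on a single 1x4 "frame": the prediction is just a relabelling of the two ground-truth clusters, so the score is 1.0 despite the permutation.

import jax.numpy as jnp

true_ids = jnp.array([[[[0, 0, 1, 1]]]])  # [batch=1, seq_len=1, H=1, W=4]
pred_ids = jnp.array([[[[1, 1, 0, 0]]]])
ari = adjusted_rand_index(true_ids, pred_ids,
                          num_instances_true=2, num_instances_pred=2)
# ari -> float32 array of shape [1] containing 1.0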
180
+ @flax.struct.dataclass
181
+ class Ari(metrics.Average):
182
+ """Adjusted Rand Index (ARI) computed from predictions and labels.
183
+
184
+ ARI is a similarity score to compare two clusterings. ARI returns values in
185
+ the range [-1, 1], where 1 corresponds to two identical clusterings (up to
186
+ permutation), i.e. a perfect match between the predicted clustering and the
187
+ ground-truth clustering. A value of (close to) 0 corresponds to chance.
188
+ Negative values corresponds to cases where the agreement between the
189
+ clusterings is less than expected from a random assignment.
190
+
191
+ In this implementation, we use ARI to compare predicted instance segmentation
192
+ masks (including background prediction) with ground-truth segmentation
193
+ annotations.
194
+ """
195
+
196
+ @classmethod
197
+ def from_model_output(cls,
198
+ predicted_segmentations,
199
+ ground_truth_segmentations,
200
+ padding_mask,
201
+ ground_truth_max_num_instances,
202
+ predicted_max_num_instances,
203
+ ignore_background = False,
204
+ mask = None,
205
+ **_):
206
+ """Computation of the ARI clustering metric.
207
+
208
+ NOTE: This implementation does not currently support padding masks.
209
+
210
+ Args:
211
+ predicted_segmentations: An array of integers of shape
212
+ [bs, seq_len, H, W] containing model segmentation predictions.
213
+ ground_truth_segmentations: An array of integers of shape
214
+ [bs, seq_len, H, W] containing ground truth segmentations.
215
+ padding_mask: An array of integers of shape [bs, seq_len, H, W]
216
+ defining regions where the ground truth is meaningless, for example
217
+ because this corresponds to regions which were padded during data
218
+ augmentation. Value 0 corresponds to padded regions, 1 corresponds to
219
+ valid regions to be used for metric calculation.
220
+ ground_truth_max_num_instances: Maximum number of instances (incl.
221
+ background, which counts as the 0-th instance) possible in the dataset.
222
+ predicted_max_num_instances: Maximum number of predicted instances (incl.
223
+ background).
224
+ ignore_background: If True, then ignore all pixels where
225
+ ground_truth_segmentations == 0 (default: False).
226
+ mask: An optional array of boolean mask values of shape [bs]. `True`
227
+ corresponds to actual batch examples whereas `False` corresponds to
228
+ padding.
229
+
230
+ Returns:
231
+ Object of Ari with computed intermediate values.
232
+ """
233
+ _validate_inputs(
234
+ predicted_segmentations=predicted_segmentations,
235
+ ground_truth_segmentations=ground_truth_segmentations,
236
+ padding_mask=padding_mask,
237
+ mask=mask)
238
+
239
+ batch_size = predicted_segmentations.shape[0]
240
+
241
+ if mask is None:
242
+ mask = jnp.ones(batch_size, dtype=padding_mask.dtype)
243
+ else:
244
+ mask = jnp.asarray(mask, dtype=padding_mask.dtype)
245
+
246
+ ari_batch = adjusted_rand_index(
247
+ pred_ids=predicted_segmentations,
248
+ true_ids=ground_truth_segmentations,
249
+ num_instances_true=ground_truth_max_num_instances,
250
+ num_instances_pred=predicted_max_num_instances,
251
+ padding_mask=padding_mask,
252
+ ignore_background=ignore_background)
253
+ return cls(total=jnp.sum(ari_batch * mask), count=jnp.sum(mask)) # pylint: disable=unexpected-keyword-arg
254
+
255
+
256
+ @flax.struct.dataclass
257
+ class AriNoBg(Ari):
258
+ """Adjusted Rand Index (ARI), ignoring the ground-truth background label."""
259
+
260
+ @classmethod
261
+ def from_model_output(cls, **kwargs):
262
+ """See `Ari` docstring for allowed keyword arguments."""
263
+ return super().from_model_output(**kwargs, ignore_background=True)
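A usage sketch with dummy segmentations; both assign every pixel to the background, so each example hits the degenerate case described above and the averaged score is 1.0.

import jax.numpy as jnp

seg_pred = jnp.zeros((2, 1, 8, 8), jnp.int32)
seg_true = jnp.zeros((2, 1, 8, 8), jnp.int32)
padding = jnp.ones((2, 1, 8, 8), jnp.int32)
metric = Ari.from_model_output(
    predicted_segmentations=seg_pred,
    ground_truth_segmentations=seg_true,
    padding_mask=padding,
    ground_truth_max_num_instances=11,
    predicted_max_num_instances=11)
metric.compute()  # -> 1.0 (clu.metrics.Average divides total by count)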
invariant_slot_attention/lib/preprocessing.py ADDED
@@ -0,0 +1,1236 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Video preprocessing ops."""
17
+
18
+ import abc
19
+ import dataclasses
20
+ import functools
21
+ from typing import Optional, Sequence, Tuple, Union
22
+
23
+ from absl import logging
24
+ from clu import preprocess_spec
25
+
26
+ import numpy as np
27
+ import tensorflow as tf
28
+
29
+ from invariant_slot_attention.lib import transforms
30
+
31
+ Features = preprocess_spec.Features
32
+ all_ops = lambda: preprocess_spec.get_all_ops(__name__)
33
+ SEED_KEY = preprocess_spec.SEED_KEY
34
+ NOTRACK_BOX = (0., 0., 0., 0.) # No-track bounding box for padding.
35
+ NOTRACK_LABEL = -1
36
+
37
+ IMAGE = "image"
38
+ VIDEO = "video"
39
+ SEGMENTATIONS = "segmentations"
40
+ RAGGED_SEGMENTATIONS = "ragged_segmentations"
41
+ SPARSE_SEGMENTATIONS = "sparse_segmentations"
42
+ SHAPE = "shape"
43
+ PADDING_MASK = "padding_mask"
44
+ RAGGED_BOXES = "ragged_boxes"
45
+ BOXES = "boxes"
46
+ FRAMES = "frames"
47
+ FLOW = "flow"
48
+ DEPTH = "depth"
49
+ ORIGINAL_SIZE = "original_size"
50
+ INSTANCE_LABELS = "instance_labels"
51
+ INSTANCE_MULTI_LABELS = "instance_multi_labels"
52
+ BOXES_VIDEO = "boxes_video"
53
+ IMAGE_PADDING_MASK = "image_padding_mask"
54
+ VIDEO_PADDING_MASK = "video_padding_mask"
55
+
56
+
57
+ def convert_uint16_to_float(array, min_val, max_val):
58
+ return tf.cast(array, tf.float32) / 65535. * (max_val - min_val) + min_val
59
+
60
+
61
+ def get_resize_small_shape(original_size,
62
+ small_size):
63
+ h, w = original_size
64
+ ratio = (
65
+ tf.cast(small_size, tf.float32) / tf.cast(tf.minimum(h, w), tf.float32))
66
+ h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32)
67
+ w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32)
68
+ return h, w
69
+
70
+
71
+ def adjust_small_size(original_size,
72
+ small_size, max_size):
73
+ """Computes the adjusted small size so the large side is at most max_size."""
74
+ h, w = original_size
75
+ min_original_size = tf.cast(tf.minimum(w, h), tf.float32)
76
+ max_original_size = tf.cast(tf.maximum(w, h), tf.float32)
77
+ if max_original_size / min_original_size * small_size > max_size:
78
+ small_size = tf.cast(tf.floor(
79
+ max_size * min_original_size / max_original_size), tf.int32)
80
+ return small_size
81
+
82
+
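A worked example on a 240x320 frame (results shown as comments; both helpers return scalar tensors):

h, w = get_resize_small_shape((240, 320), small_size=64)
# -> 64, 85: the short side becomes 64 and the aspect ratio is preserved.
small = adjust_small_size((240, 320), small_size=64, max_size=64)
# -> 48: the target for the short side is shrunk so that, after resizing,
#    the long side also stays within max_size.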
83
+ def crop_or_pad_boxes(boxes, top, left, height,
84
+ width, h_orig, w_orig):
85
+ """Transforms the relative box coordinates according to the frame crop.
86
+
87
+ Note that, if height/width are larger than h_orig/w_orig, this function
88
+ implements the equivalent of padding.
89
+
90
+ Args:
91
+ boxes: Tensor of bounding boxes with shape (..., 4).
92
+ top: Top of crop box in absolute pixel coordinates.
93
+ left: Left of crop box in absolute pixel coordinates.
94
+ height: Height of crop box in absolute pixel coordinates.
95
+ width: Width of crop box in absolute pixel coordinates.
96
+ h_orig: Original image height in absolute pixel coordinates.
97
+ w_orig: Original image width in absolute pixel coordinates.
98
+ Returns:
99
+ Boxes tensor with same shape as input boxes but updated values.
100
+ """
101
+ # Video track bound boxes: [num_instances, num_tracks, 4]
102
+ # Image bounding boxes: [num_instances, 4]
103
+ assert boxes.shape[-1] == 4
104
+ seq_len = tf.shape(boxes)[0]
105
+ has_tracks = len(boxes.shape) == 3
106
+ if has_tracks:
107
+ num_tracks = boxes.shape[1]
108
+ else:
109
+ assert len(boxes.shape) == 2
110
+ num_tracks = 1
111
+
112
+ # Transform the box coordinates.
113
+ a = tf.cast(tf.stack([h_orig, w_orig]), tf.float32)
114
+ b = tf.cast(tf.stack([top, left]), tf.float32)
115
+ c = tf.cast(tf.stack([height, width]), tf.float32)
116
+ boxes = tf.reshape(
117
+ (tf.reshape(boxes, (seq_len, num_tracks, 2, 2)) * a - b) / c,
118
+ (seq_len, num_tracks, len(NOTRACK_BOX)))
119
+
120
+ # Filter the valid boxes.
121
+ boxes = tf.minimum(tf.maximum(boxes, 0.0), 1.0)
122
+ if has_tracks:
123
+ cond = tf.reduce_all((boxes[:, :, 2:] - boxes[:, :, :2]) > 0.0, axis=-1)
124
+ boxes = tf.where(cond[:, :, tf.newaxis], boxes, NOTRACK_BOX)
125
+ else:
126
+ boxes = tf.reshape(boxes, (seq_len, 4))
127
+
128
+ return boxes
129
+
130
+
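A worked example: a box covering the central half of a 100x100 frame spans the whole image after cropping that same central 50x50 window. Boxes are [y_min, x_min, y_max, x_max] in coordinates normalized to the original frame, with a leading sequence axis.

import tensorflow as tf

boxes = tf.constant([[0.25, 0.25, 0.75, 0.75]])  # shape [seq_len=1, 4]
crop_or_pad_boxes(boxes, top=25, left=25, height=50, width=50,
                  h_orig=100, w_orig=100)
# -> [[0.0, 0.0, 1.0, 1.0]]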
131
+ def flow_tensor_to_rgb_tensor(motion_image, flow_scaling_factor=50.):
132
+ """Visualizes flow motion image as an RGB image.
133
+
134
+ Similar as the flow_to_rgb function, but with tensors.
135
+
136
+ Args:
137
+ motion_image: A tensor either of shape [batch_sz, height, width, 2] or of
138
+ shape [height, width, 2]. motion_image[..., 0] is flow in x and
139
+ motion_image[..., 1] is flow in y.
140
+ flow_scaling_factor: How much to scale flow for visualization.
141
+
142
+ Returns:
143
+ A visualization tensor with same shape as motion_image, except with three
144
+ channels. The dtype of the output is tf.uint8.
145
+ """
146
+
147
+ hypot = lambda a, b: (a ** 2.0 + b ** 2.0) ** 0.5 # sqrt(a^2 + b^2)
148
+
149
+ height, width = motion_image.get_shape().as_list()[-3:-1] # pytype: disable=attribute-error # allow-recursive-types
150
+ scaling = flow_scaling_factor / hypot(height, width)
151
+ x, y = motion_image[Ellipsis, 0], motion_image[Ellipsis, 1]
152
+ motion_angle = tf.atan2(y, x)
153
+ motion_angle = (motion_angle / np.pi + 1.0) / 2.0
154
+ motion_magnitude = hypot(y, x)
155
+ motion_magnitude = tf.clip_by_value(motion_magnitude * scaling, 0.0, 1.0)
156
+ value_channel = tf.ones_like(motion_angle)
157
+ flow_hsv = tf.stack([motion_angle, motion_magnitude, value_channel], axis=-1)
158
+ flow_rgb = tf.image.convert_image_dtype(
159
+ tf.image.hsv_to_rgb(flow_hsv), tf.uint8)
160
+ return flow_rgb
161
+
162
+
163
+ def get_paddings(image_shape,
164
+ size,
165
+ pre_spatial_dim = None,
166
+ allow_crop = True):
167
+ """Returns paddings tensors for tf.pad operation.
168
+
169
+ Args:
170
+ image_shape: The shape of the Tensor to be padded. The shape can be
171
+ [..., N, H, W, C] or [..., H, W, C]. The paddings are computed for H, W
172
+ and optionally N dimensions.
173
+ size: The total size for the H and W dimensions to pad to.
174
+ pre_spatial_dim: Optional, additional padding dimension before the spatial
175
+ dimensions. It is only used if given and if len(shape) > 3.
176
+ allow_crop: If size is bigger than requested max size, padding will be
177
+ negative. If allow_crop is true, negative padding values will be set to 0.
178
+
179
+ Returns:
180
+ Paddings the given tensor shape.
181
+ """
182
+ assert image_shape.shape.rank == 1
183
+ if isinstance(size, int):
184
+ size = (size, size)
185
+ h, w = image_shape[-3], image_shape[-2]
186
+ # Spatial padding.
187
+ paddings = [
188
+ tf.stack([0, size[0] - h]),
189
+ tf.stack([0, size[1] - w]),
190
+ tf.stack([0, 0])
191
+ ]
192
+ ndims = len(image_shape) # pytype: disable=wrong-arg-types
193
+ # Prepend padding for temporal dimension or number of instances.
194
+ if pre_spatial_dim is not None and ndims > 3:
195
+ paddings = [[0, pre_spatial_dim - image_shape[-4]]] + paddings
196
+ # Prepend with non-padded dimensions if available.
197
+ if ndims > len(paddings):
198
+ paddings = [[0, 0]] * (ndims - len(paddings)) + paddings
199
+ if allow_crop:
200
+ paddings = tf.maximum(paddings, 0)
201
+ return tf.stack(paddings)
202
+
203
+
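For example, padding a single 35x35 RGB frame up to 64x64 pads only the bottom and right and leaves the channel axis untouched; the result feeds straight into tf.pad.

import tensorflow as tf

image = tf.zeros((35, 35, 3))
paddings = get_paddings(tf.shape(image), 64)
# -> [[0, 29], [0, 29], [0, 0]]
padded = tf.pad(image, paddings)  # shape (64, 64, 3)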
204
+ @dataclasses.dataclass
205
+ class VideoFromTfds:
206
+ """Standardize features coming from TFDS video datasets."""
207
+
208
+ video_key: str = VIDEO
209
+ segmentations_key: str = SEGMENTATIONS
210
+ ragged_segmentations_key: str = RAGGED_SEGMENTATIONS
211
+ shape_key: str = SHAPE
212
+ padding_mask_key: str = PADDING_MASK
213
+ ragged_boxes_key: str = RAGGED_BOXES
214
+ boxes_key: str = BOXES
215
+ frames_key: str = FRAMES
216
+ instance_multi_labels_key: str = INSTANCE_MULTI_LABELS
217
+ flow_key: str = FLOW
218
+ depth_key: str = DEPTH
219
+
220
+ def __call__(self, features):
221
+
222
+ features_new = {}
223
+
224
+ if "rng" in features:
225
+ features_new[SEED_KEY] = features.pop("rng")
226
+
227
+ if "instances" in features:
228
+ features_new[self.ragged_boxes_key] = features["instances"]["bboxes"]
229
+ features_new[self.frames_key] = features["instances"]["bbox_frames"]
230
+ if "segmentations" in features["instances"]:
231
+ features_new[self.ragged_segmentations_key] = tf.cast(
232
+ features["instances"]["segmentations"][Ellipsis, 0], tf.int32)
233
+
234
+ # Special handling of CLEVR (https://arxiv.org/abs/1612.06890) objects.
235
+ if ("color" in features["instances"] and
236
+ "shape" in features["instances"] and
237
+ "material" in features["instances"]):
238
+ color = tf.cast(features["instances"]["color"], tf.int32)
239
+ shape = tf.cast(features["instances"]["shape"], tf.int32)
240
+ material = tf.cast(features["instances"]["material"], tf.int32)
241
+ features_new[self.instance_multi_labels_key] = tf.stack(
242
+ (color, shape, material), axis=-1)
243
+
244
+ if "segmentations" in features:
245
+ features_new[self.segmentations_key] = tf.cast(
246
+ features["segmentations"][Ellipsis, 0], tf.int32)
247
+
248
+ if "depth" in features:
249
+ # Undo float to uint16 scaling
250
+ if "metadata" in features and "depth_range" in features["metadata"]:
251
+ depth_range = features["metadata"]["depth_range"]
252
+ features_new[self.depth_key] = convert_uint16_to_float(
253
+ features["depth"], depth_range[0], depth_range[1])
254
+
255
+ if "flows" in features:
256
+ # Some datasets use "flows" instead of "flow" for optical flow.
257
+ features["flow"] = features["flows"]
258
+ if "backward_flow" in features:
259
+ # By default, use "backward_flow" if available.
260
+ features["flow"] = features["backward_flow"]
261
+ features["metadata"]["flow_range"] = features["metadata"][
262
+ "backward_flow_range"]
263
+ if "flow" in features:
264
+ # Undo float to uint16 scaling
265
+ flow_range = features["metadata"].get("flow_range", (-255, 255))
266
+ features_new[self.flow_key] = convert_uint16_to_float(
267
+ features["flow"], flow_range[0], flow_range[1])
268
+
269
+ # Convert video to float and normalize.
270
+ video = features["video"]
271
+ assert video.dtype == tf.uint8 # pytype: disable=attribute-error # allow-recursive-types
272
+ video = tf.image.convert_image_dtype(video, tf.float32)
273
+ features_new[self.video_key] = video
274
+
275
+ # Store original video shape (e.g. for correct evaluation metrics).
276
+ features_new[self.shape_key] = tf.shape(video)
277
+
278
+ # Store padding mask
279
+ features_new[self.padding_mask_key] = tf.cast(
280
+ tf.ones_like(video)[Ellipsis, 0], tf.uint8)
281
+
282
+ return features_new
283
+
284
+
285
+ @dataclasses.dataclass
286
+ class AddTemporalAxis:
287
+ """Lift images to videos by adding a temporal axis at the beginning.
288
+
289
+ We need to distinguish two cases because `image_ops.py` uses
290
+ ORIGINAL_SIZE = [H,W] and `video_ops.py` uses SHAPE = [T,H,W,C]:
291
+ a) The features are fed from image ops: ORIGINAL_SIZE is converted
292
+ to SHAPE ([H,W] -> [1,H,W,C]) and removed from the features.
293
+ Typical use case: Evaluation of GV image tasks in a video setting. This op
294
+ is added after the image preprocessing in order not to change the standard
295
+ image preprocessing.
296
+ b) The features are fed from video ops: The image SHAPE is lifted to a video
297
+ SHAPE ([H,W,C] -> [1,H,W,C]).
298
+ Typical use case: Training using images in a video setting. This op is added
299
+ before the video preprocessing in order not to change the standard video
300
+ preprocessing.
301
+ """
302
+
303
+ image_key: str = IMAGE
304
+ video_key: str = VIDEO
305
+ boxes_key: str = BOXES
306
+ padding_mask_key: str = PADDING_MASK
307
+ segmentations_key: str = SEGMENTATIONS
308
+ sparse_segmentations_key: str = SPARSE_SEGMENTATIONS
309
+ shape_key: str = SHAPE
310
+ original_size_key: str = ORIGINAL_SIZE
311
+
312
+ def __call__(self, features):
313
+ assert self.image_key in features
314
+
315
+ features_new = {}
316
+ for k, v in features.items():
317
+ if k == self.image_key:
318
+ features_new[self.video_key] = v[tf.newaxis]
319
+ elif k in (self.padding_mask_key, self.boxes_key, self.segmentations_key,
320
+ self.sparse_segmentations_key):
321
+ features_new[k] = v[tf.newaxis]
322
+ elif k == self.original_size_key:
323
+ pass # See comment in the docstring of the class.
324
+ else:
325
+ features_new[k] = v
326
+
327
+ if self.original_size_key in features:
328
+ # The features come from an image preprocessing pipeline.
329
+ shape = tf.concat([[1], features[self.original_size_key],
330
+ [features[self.image_key].shape[-1]]], # pytype: disable=attribute-error # allow-recursive-types
331
+ axis=0)
332
+ elif self.shape_key in features:
333
+ # The features come from a video preprocessing pipeline.
334
+ shape = tf.concat([[1], features[self.shape_key]], axis=0)
335
+ else:
336
+ shape = tf.shape(features_new[self.video_key])
337
+ features_new[self.shape_key] = shape
338
+
339
+ if self.padding_mask_key not in features_new:
340
+ features_new[self.padding_mask_key] = tf.cast(
341
+ tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8)
342
+
343
+ return features_new
344
+
345
+
346
+ @dataclasses.dataclass
347
+ class SparseToDenseAnnotation:
348
+ """Converts the sparse to a dense representation."""
349
+
350
+ max_instances: int = 10
351
+ segmentations_key: str = SEGMENTATIONS
352
+
353
+ def __call__(self, features):
354
+
355
+ features_new = {}
356
+
357
+ for k, v in features.items():
358
+
359
+ if k == self.segmentations_key:
360
+ # Dense segmentations are available for this dataset. It may be that
361
+ # max_instances < max(features_new[self.segmentations_key]).
362
+ # We prune out extra objects here.
363
+ segmentations = v
364
+ segmentations = tf.where(
365
+ tf.less_equal(segmentations, self.max_instances), segmentations, 0)
366
+ features_new[self.segmentations_key] = segmentations
367
+ else:
368
+ features_new[k] = v
369
+
370
+ return features_new
371
+
372
+
373
+ class VideoPreprocessOp(abc.ABC):
374
+ """Base class for all video preprocess ops."""
375
+
376
+ video_key: str = VIDEO
377
+ segmentations_key: str = SEGMENTATIONS
378
+ padding_mask_key: str = PADDING_MASK
379
+ boxes_key: str = BOXES
380
+ flow_key: str = FLOW
381
+ depth_key: str = DEPTH
382
+ sparse_segmentations_key: str = SPARSE_SEGMENTATIONS
383
+
384
+ def __call__(self, features):
385
+ # Get current video shape.
386
+ video_shape = tf.shape(features[self.video_key])
387
+ # Assemble all feature keys that the op should be applied on.
388
+ all_keys = [
389
+ self.video_key, self.segmentations_key, self.padding_mask_key,
390
+ self.flow_key, self.depth_key, self.sparse_segmentations_key,
391
+ self.boxes_key
392
+ ]
393
+ # Apply the op to all features.
394
+ for key in all_keys:
395
+ if key in features:
396
+ features[key] = self.apply(features[key], key, video_shape)
397
+ return features
398
+
399
+ @abc.abstractmethod
400
+ def apply(self, tensor, key,
401
+ video_shape):
402
+ """Returns the transformed tensor.
403
+
404
+ Args:
405
+ tensor: Any of a set of different video modalites, e.g video, flow,
406
+ bounding boxes, etc.
407
+ key: a string that indicates what feature the tensor represents so that
408
+ the apply function can take that into account.
409
+ video_shape: The shape of the video (which is necessary for some
410
+ transformations).
411
+ """
412
+
413
+
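A minimal hypothetical subclass for reference; it follows the apply() contract above and only touches the video itself. The class name and behavior are assumptions for illustration, not an op shipped with this library.

@dataclasses.dataclass
class InvertVideoColors(VideoPreprocessOp):
  """Hypothetical op: inverts the colors of a video already scaled to [0, 1]."""

  def apply(self, tensor, key=None, video_shape=None):
    """See base class."""
    if key == self.video_key:
      return 1.0 - tensor
    return tensor  # Leave segmentations, masks, boxes, etc. unchanged.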
414
+ class RandomVideoPreprocessOp(VideoPreprocessOp):
415
+ """Base class for all random video preprocess ops."""
416
+
417
+ def __call__(self, features):
418
+ if features.get(SEED_KEY) is None:
419
+ logging.warning(
420
+ "Using random operation without seed. To avoid this "
421
+ "please provide a seed in feature %s.", SEED_KEY)
422
+ op_seed = tf.random.uniform(shape=(2,), maxval=2**32, dtype=tf.int64)
423
+ else:
424
+ features[SEED_KEY], op_seed = tf.unstack(
425
+ tf.random.experimental.stateless_split(features[SEED_KEY]))
426
+ # Get current video shape.
427
+ video_shape = tf.shape(features[self.video_key])
428
+ # Assemble all feature keys that the op should be applied on.
429
+ all_keys = [
430
+ self.video_key, self.segmentations_key, self.padding_mask_key,
431
+ self.flow_key, self.depth_key, self.sparse_segmentations_key,
432
+ self.boxes_key
433
+ ]
434
+ # Apply the op to all features.
435
+ for key in all_keys:
436
+ if key in features:
437
+ features[key] = self.apply(features[key], op_seed, key, video_shape)
438
+ return features
439
+
440
+ @abc.abstractmethod
441
+ def apply(self, tensor, seed, key,
442
+ video_shape):
443
+ """Returns the transformed tensor.
444
+
445
+ Args:
446
+ tensor: Any of a set of different video modalites, e.g video, flow,
447
+ bounding boxes, etc.
448
+ seed: A random seed.
449
+ key: a string that indicates what feature the tensor represents so that
450
+ the apply function can take that into account.
451
+ video_shape: The shape of the video (which is necessary for some
452
+ transformations).
453
+ """
454
+
455
+
456
+ @dataclasses.dataclass
457
+ class ResizeSmall(VideoPreprocessOp):
458
+ """Resizes the smaller (spatial) side to `size` keeping aspect ratio.
459
+
460
+ Attr:
461
+ size: An integer representing the new size of the smaller side of the input.
462
+ max_size: If set, an integer representing the maximum size in terms of the
463
+ largest side of the input.
464
+ """
465
+
466
+ size: int
467
+ max_size: Optional[int] = None
468
+
469
+ def apply(self, tensor, key=None, video_shape=None):
470
+ """See base class."""
471
+
472
+ # Boxes are defined in normalized image coordinates and are not affected.
473
+ if key == self.boxes_key:
474
+ return tensor
475
+
476
+ if key in (self.padding_mask_key, self.segmentations_key):
477
+ tensor = tensor[Ellipsis, tf.newaxis]
478
+ elif key == self.sparse_segmentations_key:
479
+ tensor = tf.reshape(tensor,
480
+ (-1, tf.shape(tensor)[2], tf.shape(tensor)[3], 1))
481
+
482
+ h, w = tf.shape(tensor)[1], tf.shape(tensor)[2]
483
+
484
+ # Determine resize method based on dtype (e.g. segmentations are int).
485
+ if tensor.dtype.is_integer:
486
+ resize_method = "nearest"
487
+ else:
488
+ resize_method = "bilinear"
489
+
490
+ # Clip size to max_size if needed.
491
+ small_size = self.size
492
+ if self.max_size is not None:
493
+ small_size = adjust_small_size(
494
+ original_size=(h, w), small_size=small_size, max_size=self.max_size)
495
+ new_h, new_w = get_resize_small_shape(
496
+ original_size=(h, w), small_size=small_size)
497
+ tensor = tf.image.resize(tensor, [new_h, new_w], method=resize_method)
498
+
499
+ # Flow needs to be rescaled according to the new size to stay valid.
500
+ if key == self.flow_key:
501
+ scale_h = tf.cast(new_h, tf.float32) / tf.cast(h, tf.float32)
502
+ scale_w = tf.cast(new_w, tf.float32) / tf.cast(w, tf.float32)
503
+ scale = tf.reshape(tf.stack([scale_h, scale_w], axis=0), (1, 2))
504
+ # Optionally repeat scale in case both forward and backward flow are
505
+ # stacked in the last dimension.
506
+ scale = tf.repeat(scale, tf.shape(tensor)[-1] // 2, axis=0)
507
+ scale = tf.reshape(scale, (1, 1, 1, tf.shape(tensor)[-1]))
508
+ tensor *= scale
509
+
510
+ if key in (self.padding_mask_key, self.segmentations_key):
511
+ tensor = tensor[Ellipsis, 0]
512
+ elif key == self.sparse_segmentations_key:
513
+ tensor = tf.reshape(tensor, (video_shape[0], -1, new_h, new_w))
514
+
515
+ return tensor
516
+
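+ # Usage sketch (illustrative). Assuming the default feature keys and that the
+ # deterministic base class dispatches `__call__` to `apply` per present
+ # feature (as the random variant above does), a 128x192 clip is resized so
+ # its shorter side becomes 64 pixels while keeping the aspect ratio; the
+ # int32 segmentations automatically use nearest-neighbor resizing.
+ def _resize_small_sketch():
+   features = {
+       VIDEO: tf.zeros((2, 128, 192, 3), tf.float32),
+       SEGMENTATIONS: tf.zeros((2, 128, 192), tf.int32),
+   }
+   features = ResizeSmall(size=64)(features)
+   # Both modalities go from 128x192 to 64x96.
+   return features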
517
+
518
+ @dataclasses.dataclass
519
+ class CentralCrop(VideoPreprocessOp):
520
+ """Makes central (spatial) crop of a given size.
521
+
522
+ Attr:
523
+ height: An integer representing the height of the crop.
524
+ width: An (optional) integer representing the width of the crop. A square
525
+ crop is made if width is not provided.
526
+ """
527
+
528
+ height: int
529
+ width: Optional[int] = None
530
+
531
+ def apply(self, tensor, key=None, video_shape=None):
532
+ """See base class."""
533
+ if key == self.boxes_key:
534
+ width = self.width or self.height
535
+ h_orig, w_orig = video_shape[1], video_shape[2]
536
+ top = (h_orig - self.height) // 2
537
+ left = (w_orig - width) // 2
538
+ tensor = crop_or_pad_boxes(tensor, top, left, self.height,
539
+ width, h_orig, w_orig)
540
+ return tensor
541
+ else:
542
+ if key in (self.padding_mask_key, self.segmentations_key):
543
+ tensor = tensor[Ellipsis, tf.newaxis]
544
+ seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
545
+ h_orig, w_orig = tf.shape(tensor)[1], tf.shape(tensor)[2]
546
+ width = self.width or self.height
547
+ crop_size = (seq_len, self.height, width, n_channels)
548
+ top = (h_orig - self.height) // 2
549
+ left = (w_orig - width) // 2
550
+ tensor = tf.image.crop_to_bounding_box(tensor, top, left, self.height,
551
+ width)
552
+ tensor = tf.ensure_shape(tensor, crop_size)
553
+ if key in (self.padding_mask_key, self.segmentations_key):
554
+ tensor = tensor[Ellipsis, 0]
555
+ return tensor
556
+
557
+
558
+ @dataclasses.dataclass
559
+ class CropOrPad(VideoPreprocessOp):
560
+ """Spatially crops or pads a video to a specified size.
561
+
562
+ Attr:
563
+ height: An integer representing the new height of the video.
564
+ width: An integer representing the new width of the video.
565
+ allow_crop: A boolean indicating if cropping is allowed.
566
+ """
567
+
568
+ height: int
569
+ width: int
570
+ allow_crop: bool = True
571
+
572
+ def apply(self, tensor, key=None, video_shape=None):
573
+ """See base class."""
574
+ if key == self.boxes_key:
575
+ # Pad and crop the spatial dimensions.
576
+ h_orig, w_orig = video_shape[1], video_shape[2]
577
+ if self.allow_crop:
578
+ # After cropping, the frame shape is always [self.height, self.width].
579
+ height, width = self.height, self.width
580
+ else:
581
+ # If only padding is performed, the frame size is at least
582
+ # [self.height, self.width].
583
+ height = tf.maximum(h_orig, self.height)
584
+ width = tf.maximum(w_orig, self.width)
585
+ tensor = crop_or_pad_boxes(
586
+ tensor,
587
+ top=0,
588
+ left=0,
589
+ height=height,
590
+ width=width,
591
+ h_orig=h_orig,
592
+ w_orig=w_orig)
593
+ return tensor
594
+ elif key == self.sparse_segmentations_key:
595
+ seq_len = tensor.get_shape()[0]
596
+ paddings = get_paddings(
597
+ tf.shape(tensor[Ellipsis, tf.newaxis]), (self.height, self.width),
598
+ allow_crop=self.allow_crop)[:-1]
599
+ tensor = tf.pad(tensor, paddings, constant_values=0)
600
+ if self.allow_crop:
601
+ tensor = tensor[Ellipsis, :self.height, :self.width]
602
+ tensor = tf.ensure_shape(
603
+ tensor, (seq_len, None, self.height, self.width))
604
+ return tensor
605
+ else:
606
+ if key in (self.padding_mask_key, self.segmentations_key):
607
+ tensor = tensor[Ellipsis, tf.newaxis]
608
+ seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
609
+ paddings = get_paddings(
610
+ tf.shape(tensor), (self.height, self.width),
611
+ allow_crop=self.allow_crop)
612
+ tensor = tf.pad(tensor, paddings, constant_values=0)
613
+ if self.allow_crop:
614
+ tensor = tensor[:, :self.height, :self.width, :]
615
+ tensor = tf.ensure_shape(tensor,
616
+ (seq_len, self.height, self.width, n_channels))
617
+ if key in (self.padding_mask_key, self.segmentations_key):
618
+ tensor = tensor[Ellipsis, 0]
619
+ return tensor
620
+
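+ # Sketch of chaining the two deterministic spatial ops above (illustrative;
+ # the op list is an example, not the library's configured pipeline, and it
+ # assumes the base class applies each op to every present feature key):
+ # a square central crop followed by padding to the final training resolution.
+ def _central_crop_then_pad_sketch(features):
+   for op in (CentralCrop(height=192), CropOrPad(height=224, width=224)):
+     features = op(features)
+   return features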
621
+
622
+ @dataclasses.dataclass
623
+ class RandomCrop(RandomVideoPreprocessOp):
624
+ """Gets a random (width, height) crop of input video.
625
+
626
+ Assumption: Height and width are the same for all video-like modalities.
627
+
628
+ Attr:
629
+ height: An integer representing the height of the crop.
630
+ width: An integer representing the width of the crop.
631
+ """
632
+
633
+ height: int
634
+ width: int
635
+
636
+ def apply(self, tensor, seed, key=None, video_shape=None):
637
+ """See base class."""
638
+ if key == self.boxes_key:
639
+ # We copy the random generation part from tf.image.stateless_random_crop
640
+ # to generate exactly the same offset as for the video.
641
+ crop_size = (video_shape[0], self.height, self.width, video_shape[-1])
642
+ size = tf.convert_to_tensor(crop_size, tf.int32)
643
+ limit = video_shape - size + 1
644
+ offset = tf.random.stateless_uniform(
645
+ tf.shape(video_shape), dtype=tf.int32, maxval=tf.int32.max,
646
+ seed=seed) % limit
647
+ tensor = crop_or_pad_boxes(tensor, offset[1], offset[2], self.height,
648
+ self.width, video_shape[1], video_shape[2])
649
+ return tensor
650
+ elif key == self.sparse_segmentations_key:
651
+ raise NotImplementedError("Sparse segmentations aren't supported yet")
652
+ else:
653
+ if key in (self.padding_mask_key, self.segmentations_key):
654
+ tensor = tensor[Ellipsis, tf.newaxis]
655
+ seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
656
+ crop_size = (seq_len, self.height, self.width, n_channels)
657
+ tensor = tf.image.stateless_random_crop(tensor, size=crop_size, seed=seed)
658
+ tensor = tf.ensure_shape(tensor, crop_size)
659
+ if key in (self.padding_mask_key, self.segmentations_key):
660
+ tensor = tensor[Ellipsis, 0]
661
+ return tensor
662
+
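+ # Sketch: random ops read the per-example seed stored under SEED_KEY (see
+ # RandomVideoPreprocessOp above), so a deterministic pipeline seeds the
+ # features before applying them. The literal seed value here is illustrative.
+ def _seeded_random_crop_sketch(features):
+   features[SEED_KEY] = tf.constant([7, 13], dtype=tf.int64)  # Stateless [2] seed.
+   return RandomCrop(height=96, width=96)(features)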
663
+
664
+ @dataclasses.dataclass
665
+ class DropFrames(VideoPreprocessOp):
666
+ """Subsamples a video by skipping frames.
667
+
668
+ Attr:
669
+ frame_skip: An integer representing the subsampling frequency of the video,
670
+ where 1 means no frames are skipped, 2 means every other frame is skipped,
671
+ and so forth.
672
+ """
673
+
674
+ frame_skip: int
675
+
676
+ def apply(self, tensor, key=None, video_shape=None):
677
+ """See base class."""
678
+ del key
679
+ del video_shape
680
+ tensor = tensor[::self.frame_skip]
681
+ new_length = tensor.get_shape()[0]
682
+ tensor = tf.ensure_shape(tensor, [new_length] + tensor.get_shape()[1:])
683
+ return tensor
684
+
685
+
686
+ @dataclasses.dataclass
687
+ class TemporalCropOrPad(VideoPreprocessOp):
688
+ """Crops or pads a video in time to a specified length.
689
+
690
+ Attr:
691
+ length: An integer representing the new length of the video.
692
+ allow_crop: A boolean specifying whether temporal cropping is allowed. If
693
+ False, an error is thrown if the length of the video exceeds `length`.
694
+ """
695
+
696
+ length: int
697
+ allow_crop: bool = True
698
+
699
+ def _apply(self, tensor, constant_values):
700
+ frames_to_pad = self.length - tf.shape(tensor)[0]
701
+ if self.allow_crop:
702
+ frames_to_pad = tf.maximum(frames_to_pad, 0)
703
+ tensor = tf.pad(
704
+ tensor, ((0, frames_to_pad),) + ((0, 0),) * (len(tensor.shape) - 1),
705
+ constant_values=constant_values)
706
+ tensor = tensor[:self.length]
707
+ tensor = tf.ensure_shape(tensor, [self.length] + tensor.get_shape()[1:])
708
+ return tensor
709
+
710
+ def apply(self, tensor, key=None, video_shape=None):
711
+ """See base class."""
712
+ del video_shape
713
+ if key == self.boxes_key:
714
+ constant_values = NOTRACK_BOX[0]
715
+ else:
716
+ constant_values = 0
717
+ return self._apply(tensor, constant_values=constant_values)
718
+
719
+
720
+ @dataclasses.dataclass
721
+ class TemporalRandomWindow(RandomVideoPreprocessOp):
722
+ """Gets a random slice (window) along 0-th axis of input tensor.
723
+
724
+ Pads the video if the video length is shorter than the provided length.
725
+
726
+ Assumption: The number of frames is the same for all video-like modalities.
727
+
728
+ Attr:
729
+ length: An integer representing the new length of the video.
730
+ """
731
+
732
+ length: int
733
+
734
+ def _apply(self, tensor, seed, constant_values):
735
+ length = tf.minimum(self.length, tf.shape(tensor)[0])
736
+ frames_to_pad = tf.maximum(self.length - tf.shape(tensor)[0], 0)
737
+ window_size = tf.concat(([length], tf.shape(tensor)[1:]), axis=0)
738
+ tensor = tf.image.stateless_random_crop(tensor, size=window_size, seed=seed)
739
+ tensor = tf.pad(
740
+ tensor, ((0, frames_to_pad),) + ((0, 0),) * (len(tensor.shape) - 1),
741
+ constant_values=constant_values)
742
+ tensor = tf.ensure_shape(tensor, [self.length] + tensor.get_shape()[1:])
743
+ return tensor
744
+
745
+ def apply(self, tensor, seed, key=None, video_shape=None):
746
+ """See base class."""
747
+ del video_shape
748
+ if key == self.boxes_key:
749
+ constant_values = NOTRACK_BOX[0]
750
+ else:
751
+ constant_values = 0
752
+ return self._apply(tensor, seed, constant_values=constant_values)
753
+
754
+
755
+ @dataclasses.dataclass
756
+ class TemporalRandomStridedWindow(RandomVideoPreprocessOp):
757
+ """Gets a random strided slice (window) along 0-th axis of input tensor.
758
+
759
+ This op is like TemporalRandomWindow but it samples from one of a set of
760
+ strides of the video, whereas TemporalRandomWindow will densely sample from
761
+ all possible slices of `length` frames from the video.
762
+
763
+ For the following video and `length=3`: [1, 2, 3, 4, 5, 6, 7, 8, 9]
764
+
765
+ This op will return one of [1, 2, 3], [4, 5, 6], or [7, 8, 9]
766
+
767
+ This pads the video if the video length is shorter than the provided length.
768
+
769
+ Assumption: The number of frames is the same for all video-like modalities.
770
+
771
+ Attr:
772
+ length: An integer representing the new length of the video and the sampling
773
+ stride width.
774
+ """
775
+
776
+ length: int
777
+
778
+ def _apply(self, tensor, seed,
779
+ constant_values):
780
+ """Applies the strided crop operation to the video tensor."""
781
+ num_frames = tf.shape(tensor)[0]
782
+ num_crop_points = tf.cast(tf.math.ceil(num_frames / self.length), tf.int32)
783
+ crop_point = tf.random.stateless_uniform(
784
+ shape=(), minval=0, maxval=num_crop_points, dtype=tf.int32, seed=seed)
785
+ crop_point *= self.length
786
+ frames_sample = tensor[crop_point:crop_point + self.length]
787
+ frames_to_pad = tf.maximum(self.length - tf.shape(frames_sample)[0], 0)
788
+ frames_sample = tf.pad(
789
+ frames_sample,
790
+ ((0, frames_to_pad),) + ((0, 0),) * (len(frames_sample.shape) - 1),
791
+ constant_values=constant_values)
792
+ frames_sample = tf.ensure_shape(frames_sample, [self.length] +
793
+ frames_sample.get_shape()[1:])
794
+ return frames_sample
795
+
796
+ def apply(self, tensor, seed, key=None, video_shape=None):
797
+ """See base class."""
798
+ del video_shape
799
+ if key == self.boxes_key:
800
+ constant_values = NOTRACK_BOX[0]
801
+ else:
802
+ constant_values = 0
803
+ return self._apply(tensor, seed, constant_values=constant_values)
804
+
805
+
806
+ @dataclasses.dataclass
807
+ class FlowToRgb:
808
+ """Converts flow to an RGB image.
809
+
810
+ NOTE: This operation requires a statically known shape for the input flow,
811
+ i.e. it is best to place it as the final operation in the preprocessing
812
+ pipeline, after all shapes are statically known (e.g. after cropping /
813
+ padding).
814
+ """
815
+ flow_key: str = FLOW
816
+
817
+ def __call__(self, features):
818
+ if self.flow_key in features:
819
+ flow_rgb = flow_tensor_to_rgb_tensor(features[self.flow_key])
820
+ assert flow_rgb.dtype == tf.uint8
821
+ features[self.flow_key] = tf.image.convert_image_dtype(
822
+ flow_rgb, tf.float32)
823
+ return features
824
+
825
+
826
+ @dataclasses.dataclass
827
+ class TransformDepth:
828
+ """Applies one of several possible transformations to depth features."""
829
+ transform: str
830
+ depth_key: str = DEPTH
831
+
832
+ def __call__(self, features):
833
+ if self.depth_key in features:
834
+ if self.transform == "log":
835
+ depth_norm = tf.math.log(features[self.depth_key])
836
+ elif self.transform == "log_plus":
837
+ depth_norm = tf.math.log(1. + features[self.depth_key])
838
+ elif self.transform == "invert_plus":
839
+ depth_norm = 1. / (1. + features[self.depth_key])
840
+ else:
841
+ raise ValueError(f"Unknown depth transformation {self.transform}")
842
+
843
+ features[self.depth_key] = depth_norm
844
+ return features
845
+
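+ # Sketch of the three depth transforms above on a toy value d = 3.0
+ # (illustrative numbers only):
+ #   "log"         -> log(3.0)      ~= 1.10
+ #   "log_plus"    -> log(1 + 3.0)  ~= 1.39
+ #   "invert_plus" -> 1 / (1 + 3.0)  = 0.25
+ def _transform_depth_sketch():
+   features = {DEPTH: tf.fill((1, 8, 8, 1), 3.0)}
+   return TransformDepth(transform="invert_plus")(features)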
846
+
847
+ @dataclasses.dataclass
848
+ class RandomResizedCrop(RandomVideoPreprocessOp):
849
+ """Random-resized crop for each of the two views.
850
+
851
+ Assumption: Height and width are the same for all video-like modalities.
852
+
853
+ We randomly crop the input and record the transformation this crop corresponds
854
+ to as a new feature. Cropped images are resized to (height, width). Boxes are
855
+ adjusted accordingly and boxes outside the crop are discarded. Flow is rescaled
856
+ so as to remain pixel accurate after the operation. lidar_points_2d are
857
+ transformed using the computed transformation. These points may lie outside
858
+ the image after the operation.
859
+
860
+ Attr:
861
+ height: An integer representing the height to resize to.
862
+ width: An integer representing the width to resize to.
863
+ min_object_covered, aspect_ratio_range, area_range, max_attempts: See
864
+ docstring of `stateless_sample_distorted_bounding_box`. Aspect ratio range
865
+ has not been scaled by target aspect ratio. This differs from other
866
+ implementations of this data augmentation.
867
+ relative_box_area_threshold: If the ratio of box areas before and after
868
+ cropping is lower than this threshold, the box is discarded (set to NOTRACK_BOX).
869
+ """
870
+ # Target size.
871
+ height: int
872
+ width: int
873
+
874
+ # Crop sampling attributes.
875
+ min_object_covered: float = 0.1
876
+ aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.)
877
+ area_range: Tuple[float, float] = (0.08, 1.0)
878
+ max_attempts: int = 100
879
+
880
+ # Box retention attributes
881
+ relative_box_area_threshold: float = 0.0
882
+
883
+ def apply(self, tensor, seed, key,
884
+ video_shape):
885
+ """Applies the crop operation on tensor."""
886
+ param = self.sample_augmentation_params(video_shape, seed)
887
+ si, sj = param[0], param[1]
888
+ crop_h, crop_w = param[2], param[3]
889
+
890
+ to_float32 = lambda x: tf.cast(x, tf.float32)
891
+
892
+ if key == self.boxes_key:
893
+ # First crop the boxes.
894
+ cropped_boxes = crop_or_pad_boxes(
895
+ tensor, si, sj,
896
+ crop_h, crop_w,
897
+ video_shape[1], video_shape[2])
898
+ # We do not need to scale the boxes because they are in normalized coords.
899
+ resized_boxes = cropped_boxes
900
+ # Lastly, detect NOTRACK_BOX boxes and avoid manipulating those.
901
+ no_track_boxes = tf.convert_to_tensor(NOTRACK_BOX)
902
+ no_track_boxes = tf.reshape(no_track_boxes, [1, 4])
903
+ resized_boxes = tf.where(
904
+ tf.reduce_all(tensor == no_track_boxes, axis=-1, keepdims=True),
905
+ tensor, resized_boxes)
906
+
907
+ if self.relative_box_area_threshold > 0:
908
+ # Thresholds boxes that have been cropped too much, i.e. whose area is
909
+ # lower, in relative terms, than `relative_box_area_threshold`.
910
+ area_before_crop = tf.reduce_prod(tensor[Ellipsis, 2:] - tensor[Ellipsis, :2],
911
+ axis=-1)
912
+ # Sets minimum area_before_crop to 1e-8 so we avoid divisions by 0.
913
+ area_before_crop = tf.maximum(area_before_crop,
914
+ tf.zeros_like(area_before_crop) + 1e-8)
915
+ area_after_crop = tf.reduce_prod(
916
+ resized_boxes[Ellipsis, 2:] - resized_boxes[Ellipsis, :2], axis=-1)
917
+ # As the boxes have normalized coordinates, they need to be rescaled to
918
+ # be compared against the original uncropped boxes.
919
+ scale_x = to_float32(crop_w) / to_float32(self.width)
920
+ scale_y = to_float32(crop_h) / to_float32(self.height)
921
+ area_after_crop *= scale_x * scale_y
922
+
923
+ ratio = area_after_crop / area_before_crop
924
+ return tf.where(
925
+ tf.expand_dims(ratio > self.relative_box_area_threshold, -1),
926
+ resized_boxes, no_track_boxes)
927
+
928
+ else:
929
+ return resized_boxes
930
+
931
+ else:
932
+ if key in (self.padding_mask_key, self.segmentations_key):
933
+ tensor = tensor[Ellipsis, tf.newaxis]
934
+
935
+ # Crop.
936
+ seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
937
+ crop_size = (seq_len, crop_h, crop_w, n_channels)
938
+ tensor = tf.slice(tensor, tf.stack([0, si, sj, 0]), crop_size)
939
+
940
+ # Resize.
941
+ resize_method = tf.image.ResizeMethod.BILINEAR
942
+ if (tensor.dtype == tf.int32 or tensor.dtype == tf.int64 or
943
+ tensor.dtype == tf.uint8):
944
+ resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
945
+ tensor = tf.image.resize(tensor, [self.height, self.width],
946
+ method=resize_method)
947
+ out_size = (seq_len, self.height, self.width, n_channels)
948
+ tensor = tf.ensure_shape(tensor, out_size)
949
+
950
+ if key == self.flow_key:
951
+ # Rescales optical flow.
952
+ scale_x = to_float32(self.width) / to_float32(crop_w)
953
+ scale_y = to_float32(self.height) / to_float32(crop_h)
954
+ tensor = tf.stack(
955
+ [tensor[Ellipsis, 0] * scale_y, tensor[Ellipsis, 1] * scale_x], axis=-1)
956
+
957
+ if key in (self.padding_mask_key, self.segmentations_key):
958
+ tensor = tensor[Ellipsis, 0]
959
+ return tensor
960
+
961
+ def sample_augmentation_params(self, video_shape, rng):
962
+ """Sample a random bounding box for the crop."""
963
+ sample_bbox = tf.image.stateless_sample_distorted_bounding_box(
964
+ video_shape[1:],
965
+ bounding_boxes=tf.constant([0.0, 0.0, 1.0, 1.0],
966
+ dtype=tf.float32, shape=[1, 1, 4]),
967
+ seed=rng,
968
+ min_object_covered=self.min_object_covered,
969
+ aspect_ratio_range=self.aspect_ratio_range,
970
+ area_range=self.area_range,
971
+ max_attempts=self.max_attempts,
972
+ use_image_if_no_bounding_boxes=True)
973
+ bbox_begin, bbox_size, _ = sample_bbox
974
+
975
+ # The specified bounding box provides crop coordinates.
976
+ offset_y, offset_x, _ = tf.unstack(bbox_begin)
977
+ target_height, target_width, _ = tf.unstack(bbox_size)
978
+
979
+ return tf.stack([offset_y, offset_x, target_height, target_width])
980
+
981
+ def estimate_transformation(self, param, video_shape
982
+ ):
983
+ """Computes the affine transformation for crop params.
984
+
985
+ Args:
986
+ param: Crop parameters in the [y, x, h, w] format of shape [4,].
987
+ video_shape: Unused.
988
+
989
+ Returns:
990
+ Affine transformation of shape [3, 3] corresponding to cropping the image
991
+ at [y, x] of size [h, w] and resizing it into [self.height, self.width].
992
+ """
993
+ del video_shape
994
+ crop = tf.cast(param, tf.float32)
995
+ si, sj = crop[0], crop[1]
996
+ crop_h, crop_w = crop[2], crop[3]
997
+ ei, ej = si + crop_h - 1.0, sj + crop_w - 1.0
998
+ h, w = float(self.height), float(self.width)
999
+
1000
+ a1 = (ei - si + 1.)/h
1001
+ a2 = 0.
1002
+ a3 = si - 0.5 + a1 / 2.
1003
+ a4 = 0.
1004
+ a5 = (ej - sj + 1.)/w
1005
+ a6 = sj - 0.5 + a5 / 2.
1006
+ affine = tf.stack([a1, a2, a3, a4, a5, a6, 0., 0., 1.])
1007
+ return tf.reshape(affine, [3, 3])
1008
+
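+ # Sketch (illustrative helper, not used by the pipeline): the [3, 3] affine
+ # returned above maps homogeneous output-pixel coordinates [i, j, 1] of the
+ # resized crop back into the original frame, i.e. roughly y_in = a1 * i + a3
+ # and x_in = a5 * j + a6 under the half-pixel convention encoded in a3 / a6.
+ def _apply_crop_affine_sketch(affine, i, j):
+   coords = tf.stack(
+       [tf.cast(i, tf.float32), tf.cast(j, tf.float32), tf.constant(1.0)])
+   return tf.linalg.matvec(affine, coords)[:2]  # [y_in, x_in]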
1009
+
1010
+ @dataclasses.dataclass
1011
+ class TfdsImageToTfdsVideo:
1012
+ """Lift TFDS image format to TFDS video format by adding a temporal axis.
1013
+
1014
+ This op is intended to be called directly before VideoFromTfds.
1015
+ """
1016
+
1017
+ TFDS_SEGMENTATIONS_KEY = "segmentations"
1018
+ TFDS_INSTANCES_KEY = "instances"
1019
+ TFDS_BOXES_KEY = "bboxes"
1020
+ TFDS_BOXES_FRAMES_KEY = "bbox_frames"
1021
+
1022
+ image_key: str = IMAGE
1023
+ video_key: str = VIDEO
1024
+ boxes_image_key: str = BOXES
1025
+ boxes_key: str = BOXES_VIDEO
1026
+ image_padding_mask_key: str = IMAGE_PADDING_MASK
1027
+ video_padding_mask_key: str = VIDEO_PADDING_MASK
1028
+ depth_key: str = DEPTH
1029
+ depth_mask_key: str = "depth_mask"
1030
+ force_overwrite: bool = False
1031
+
1032
+ def __call__(self, features):
1033
+ if self.video_key in features and not self.force_overwrite:
1034
+ return features
1035
+
1036
+ features_new = {}
1037
+ for k, v in features.items():
1038
+ if k == self.image_key:
1039
+ features_new[self.video_key] = v[tf.newaxis]
1040
+ elif k == self.image_padding_mask_key:
1041
+ features_new[self.video_padding_mask_key] = v[tf.newaxis]
1042
+ elif k == self.boxes_image_key:
1043
+ features_new[self.boxes_key] = v[tf.newaxis]
1044
+ elif k == self.TFDS_SEGMENTATIONS_KEY:
1045
+ features_new[self.TFDS_SEGMENTATIONS_KEY] = v[tf.newaxis]
1046
+ elif k == self.TFDS_INSTANCES_KEY and self.TFDS_BOXES_KEY in v:
1047
+ # Add sequence dimension to boxes and create boxes frames for indexing.
1048
+ features_new[k] = v
1049
+
1050
+ # Create dummy ragged tensor (1, None) and broadcast
1051
+ dummy = tf.ragged.constant([[0]], dtype=tf.int32)
1052
+ boxes_frames_value = tf.zeros_like(
1053
+ v[self.TFDS_BOXES_KEY][Ellipsis, 0], dtype=tf.int32)[Ellipsis, tf.newaxis]
1054
+ features_new[k][self.TFDS_BOXES_FRAMES_KEY] = boxes_frames_value + dummy
1055
+ # Create dummy ragged tensor (1, None, 1) and broadcast
1056
+ dummy = tf.ragged.constant([[0]], dtype=tf.float32)[Ellipsis, tf.newaxis]
1057
+ boxes_value = v[self.TFDS_BOXES_KEY][Ellipsis, tf.newaxis, :]
1058
+ features_new[k][self.TFDS_BOXES_KEY] = boxes_value + dummy
1059
+ elif k == self.depth_key:
1060
+ features_new[self.depth_key] = v[tf.newaxis]
1061
+ elif k == self.depth_mask_key:
1062
+ features_new[self.depth_mask_key] = v[tf.newaxis]
1063
+ else:
1064
+ features_new[k] = v
1065
+
1066
+ if self.video_padding_mask_key not in features_new:
1067
+ logging.warning("Adding default video_padding_mask")
1068
+ features_new[self.video_padding_mask_key] = tf.cast(
1069
+ tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8)
1070
+
1071
+ return features_new
1072
+
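+ # Sketch (illustrative): lifting a single TFDS image example to the video
+ # format simply adds a leading temporal axis of length one, so downstream
+ # video ops can treat images and clips uniformly.
+ def _image_to_video_sketch():
+   features = {IMAGE: tf.zeros((128, 128, 3), tf.uint8)}
+   features = TfdsImageToTfdsVideo()(features)
+   # features[VIDEO].shape == (1, 128, 128, 3); a default video_padding_mask
+   # of ones is added as well.
+   return features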
1073
+
1074
+ @dataclasses.dataclass
1075
+ class TopLeftCrop(VideoPreprocessOp):
1076
+ """Makes an arbitrary crop in all video frames.
1077
+
1078
+ Attr:
1079
+ top: An integer representing the vertical coordinate of the crop start.
1080
+ left: An integer representing the horizontal coordinate of the crop start.
1081
+ height: An integer representing the height of the crop.
1082
+ width: An (optional) integer representing the width of the crop. A square
1083
+ crop is made if width is not provided.
1084
+ """
1085
+
1086
+ top: int
1087
+ left: int
1088
+ height: int
1089
+ width: Optional[int] = None
1090
+
1091
+ def apply(self, tensor, key=None, video_shape=None):
1092
+ """See base class."""
1093
+ if key in (self.boxes_key,):
1094
+ width = self.width or self.height
1095
+ h_orig, w_orig = video_shape[1], video_shape[2]
1096
+ tensor = transforms.crop_or_pad_boxes(
1097
+ tensor, self.top, self.left, self.height, width, h_orig, w_orig)
1098
+ return tensor
1099
+ else:
1100
+ if key in (self.padding_mask_key, self.segmentations_key):
1101
+ tensor = tensor[Ellipsis, tf.newaxis]
1102
+ seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
1103
+ h_orig, w_orig = tf.shape(tensor)[1], tf.shape(tensor)[2]
1104
+ width = self.width or self.height
1105
+ crop_size = (seq_len, self.height, width, n_channels)
1106
+ tensor = tf.image.crop_to_bounding_box(
1107
+ tensor, self.top, self.left, self.height, width)
1108
+ tensor = tf.ensure_shape(tensor, crop_size)
1109
+ if key in (self.padding_mask_key, self.segmentations_key):
1110
+ tensor = tensor[Ellipsis, 0]
1111
+ return tensor
1112
+
1113
+
1114
+ @dataclasses.dataclass
1115
+ class DeleteSmallMasks:
1116
+ """Delete masks smaller than a selected fraction of pixels."""
1117
+ threshold: float = 0.05
1118
+ max_instances: int = 50
1119
+ max_instances_after: int = 11
1120
+
1121
+ def __call__(self, features):
1122
+
1123
+ features_new = {}
1124
+
1125
+ for key in features.keys():
1126
+
1127
+ if key == SEGMENTATIONS:
1128
+ seg = features[key]
1129
+ size = tf.shape(seg)
1130
+
1131
+ assert_op = tf.Assert(
1132
+ tf.equal(size[0], 1), ["Implemented only for a single frame."])
1133
+
1134
+ with tf.control_dependencies([assert_op]):
1135
+ # Delete time dimension.
1136
+ seg = seg[0]
1137
+
1138
+ # Get the minimum number of pixels a mask needs to have.
1139
+ max_pixels = size[1] * size[2]
1140
+ threshold_pixels = tf.cast(
1141
+ tf.cast(max_pixels, tf.float32) * self.threshold, tf.int32)
1142
+
1143
+ # Decompose the segmentation map as a single image for each instance.
1144
+ dec_seg = tf.stack(
1145
+ tf.map_fn(functools.partial(self._decompose, seg=seg),
1146
+ tf.range(self.max_instances)), axis=0)
1147
+
1148
+ # Count the pixels and find segmentation masks that are big enough.
1149
+ sums = tf.reduce_sum(dec_seg, axis=(1, 2))
1150
+ # We want the background to always be slot zero.
1151
+ # We can accomplish that by pretending it has the maximum
1152
+ # number of pixels.
1153
+ sums = tf.concat(
1154
+ [tf.ones_like(sums[0: 1]) * max_pixels, sums[1:]],
1155
+ axis=0)
1156
+
1157
+ sort = tf.argsort(sums, axis=0, direction="DESCENDING")
1158
+ sums_s = tf.gather(sums, sort, axis=0)
1159
+ mask_s = tf.cast(tf.greater_equal(sums_s, threshold_pixels), tf.int32)
1160
+
1161
+ dec_seg_plus = tf.stack(
1162
+ tf.map_fn(functools.partial(
1163
+ self._compose_sort, seg=seg, sort=sort, mask_s=mask_s),
1164
+ tf.range(self.max_instances_after)), axis=0)
1165
+ new_seg = tf.reduce_sum(dec_seg_plus, axis=0)
1166
+
1167
+ features_new[key] = tf.cast(new_seg[None], tf.int32)
1168
+
1169
+ else:
1170
+ # keep all other features
1171
+ features_new[key] = features[key]
1172
+
1173
+ return features_new
1174
+
1175
+ @classmethod
1176
+ def _decompose(cls, i, seg):
1177
+ return tf.cast(tf.equal(seg, i), tf.int32)
1178
+
1179
+ @classmethod
1180
+ def _compose_sort(cls, i, seg, sort, mask_s):
1181
+ return tf.cast(tf.equal(seg, sort[i]), tf.int32) * i * mask_s[i]
1182
+
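+ # Behaviour sketch (illustrative numbers): with threshold=0.05 on a 10x10
+ # frame, any instance mask covering fewer than 5 pixels is dropped, the
+ # background keeps id 0, and the surviving instances are relabelled in
+ # decreasing order of size, using at most max_instances_after ids.
+ def _delete_small_masks_sketch(segmentations):
+   # segmentations: int32 tensor of shape [1, height, width] (single frame).
+   return DeleteSmallMasks(threshold=0.05)({SEGMENTATIONS: segmentations})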
1183
+
1184
+ @dataclasses.dataclass
1185
+ class SundsToTfdsVideo:
1186
+ """Lift Sunds format to TFDS video format.
1187
+
1188
+ Renames fields and adds a temporal axis.
1189
+ This op is intended to be called directly before VideoFromTfds.
1190
+ """
1191
+
1192
+ SUNDS_IMAGE_KEY = "color_image"
1193
+ SUNDS_SEGMENTATIONS_KEY = "instance_image"
1194
+ SUNDS_DEPTH_KEY = "depth_image"
1195
+
1196
+ image_key: str = SUNDS_IMAGE_KEY
1197
+ image_segmentations_key = SUNDS_SEGMENTATIONS_KEY
1198
+ video_key: str = VIDEO
1199
+ video_segmentations_key = SEGMENTATIONS
1200
+ image_depths_key: str = SUNDS_DEPTH_KEY
1201
+ depths_key = DEPTH
1202
+ video_padding_mask_key: str = VIDEO_PADDING_MASK
1203
+ force_overwrite: bool = False
1204
+
1205
+ def __call__(self, features):
1206
+ if self.video_key in features and not self.force_overwrite:
1207
+ return features
1208
+
1209
+ features_new = {}
1210
+ for k, v in features.items():
1211
+ if k == self.image_key:
1212
+ features_new[self.video_key] = v[tf.newaxis]
1213
+ elif k == self.image_segmentations_key:
1214
+ features_new[self.video_segmentations_key] = v[tf.newaxis]
1215
+ elif k == self.image_depths_key:
1216
+ features_new[self.depths_key] = v[tf.newaxis]
1217
+ else:
1218
+ features_new[k] = v
1219
+
1220
+ if self.video_padding_mask_key not in features_new:
1221
+ logging.warning("Adding default video_padding_mask")
1222
+ features_new[self.video_padding_mask_key] = tf.cast(
1223
+ tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8)
1224
+
1225
+ return features_new
1226
+
1227
+
1228
+ @dataclasses.dataclass
1229
+ class SubtractOneFromSegmentations:
1230
+ """Subtract one from segmentation masks. Used for MultiShapeNet-Easy."""
1231
+
1232
+ segmentations_key: str = SEGMENTATIONS
1233
+
1234
+ def __call__(self, features):
1235
+ features[self.segmentations_key] = features[self.segmentations_key] - 1
1236
+ return features
invariant_slot_attention/lib/trainer.py ADDED
@@ -0,0 +1,328 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The main model training loop."""
17
+
18
+ import functools
19
+ import os
20
+ import time
21
+ from typing import Dict, Iterable, Mapping, Optional, Tuple, Type, Union
22
+
23
+ from absl import logging
24
+ from clu import checkpoint
25
+ from clu import metric_writers
26
+ from clu import metrics
27
+ from clu import parameter_overview
28
+ from clu import periodic_actions
29
+ import flax
30
+ from flax import linen as nn
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
+ import ml_collections
35
+ import numpy as np
36
+ import optax
37
+
38
+ from scenic.train_lib import lr_schedules
39
+ from scenic.train_lib import optimizers
40
+
41
+ import tensorflow as tf
42
+
43
+ from invariant_slot_attention.lib import evaluator
44
+ from invariant_slot_attention.lib import input_pipeline
45
+ from invariant_slot_attention.lib import losses
46
+ from invariant_slot_attention.lib import utils
47
+
48
+ Array = jnp.ndarray
49
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
50
+ PRNGKey = Array
51
+
52
+
53
+ def train_step(
54
+ model,
55
+ tx,
56
+ rng,
57
+ step,
58
+ state_vars,
59
+ opt_state,
60
+ params,
61
+ batch,
62
+ loss_fn,
63
+ train_metrics_cls,
64
+ predicted_max_num_instances,
65
+ ground_truth_max_num_instances,
66
+ conditioning_key = None,
67
+ ):
68
+ """Perform a single training step.
69
+
70
+ Args:
71
+ model: Model used in training step.
72
+ tx: The optimizer to use to minimize loss_fn.
73
+ rng: Random number key
74
+ step: Which training step we are on.
75
+ state_vars: Accessory variables.
76
+ opt_state: The state of the optimizer.
77
+ params: The current parameters to be updated.
78
+ batch: Training inputs for this step.
79
+ loss_fn: Loss function that takes model predictions and a batch of data.
80
+ train_metrics_cls: The metrics collection for computing training metrics.
81
+ predicted_max_num_instances: Maximum number of instances in prediction.
82
+ ground_truth_max_num_instances: Maximum number of instances in ground truth,
83
+ including background (which counts as a separate instance).
84
+ conditioning_key: Optional string. If provided, defines the batch key to be
85
+ used as conditioning signal for the model. Otherwise this is inferred from
86
+ the available keys in the batch.
87
+
88
+ Returns:
89
+ Tuple of the updated opt_state, params, state_vars, new random number key,
90
+ metrics update, and step + 1. Note that some of this info is stored in
91
+ TrainState, but here it is unpacked.
92
+ """
93
+
94
+ # Split PRNGKey and bind to host / device.
95
+ new_rng, rng = jax.random.split(rng)
96
+ rng = jax.random.fold_in(rng, jax.host_id())
97
+ rng = jax.random.fold_in(rng, jax.lax.axis_index("batch"))
98
+ init_rng, dropout_rng = jax.random.split(rng, 2)
99
+
100
+ mutable_var_keys = list(state_vars.keys()) + ["intermediates"]
101
+
102
+ conditioning = batch[conditioning_key] if conditioning_key else None
103
+
104
+ def train_loss_fn(params, state_vars):
105
+ preds, mutable_vars = model.apply(
106
+ {"params": params, **state_vars}, video=batch["video"],
107
+ conditioning=conditioning, mutable=mutable_var_keys,
108
+ rngs={"state_init": init_rng, "dropout": dropout_rng}, train=True,
109
+ padding_mask=batch.get("padding_mask"))
110
+ # Filter intermediates, as we do not want to store them in the TrainState.
111
+ state_vars = utils.filter_key_from_frozen_dict(
112
+ mutable_vars, key="intermediates")
113
+ loss, loss_aux = loss_fn(preds, batch)
114
+ return loss, (state_vars, preds, loss_aux)
115
+
116
+ grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
117
+ (loss, (state_vars, preds, loss_aux)), grad = grad_fn(params, state_vars)
118
+
119
+ # Compute average gradient across multiple workers.
120
+ grad = jax.lax.pmean(grad, axis_name="batch")
121
+
122
+ updates, new_opt_state = tx.update(grad, opt_state, params)
123
+ new_params = optax.apply_updates(params, updates)
124
+
125
+ # Compute metrics.
126
+ metrics_update = train_metrics_cls.gather_from_model_output(
127
+ loss=loss,
128
+ **loss_aux,
129
+ predicted_segmentations=utils.remove_singleton_dim(
130
+ preds["outputs"].get("segmentations")), # pytype: disable=attribute-error
131
+ ground_truth_segmentations=batch.get("segmentations"),
132
+ predicted_max_num_instances=predicted_max_num_instances,
133
+ ground_truth_max_num_instances=ground_truth_max_num_instances,
134
+ padding_mask=batch.get("padding_mask"),
135
+ mask=batch.get("mask"))
136
+ return (
137
+ new_opt_state, new_params, state_vars, new_rng, metrics_update, step + 1)
138
+
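+ # Sketch of the `loss_fn` contract assumed by `train_step`: it receives the
+ # model predictions and the batch and returns (scalar_loss, aux_dict), where
+ # the aux entries feed the metrics collection. The prediction key below is a
+ # placeholder, not necessarily what the configured model emits.
+ def _example_loss_fn(preds, batch):
+   recon = preds["outputs"]["video"]  # Hypothetical reconstruction output.
+   loss = jnp.mean((recon - batch["video"]) ** 2)
+   return loss, {"recon_loss": loss}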
139
+
140
+ def train_and_evaluate(config,
141
+ workdir):
142
+ """Runs a training and evaluation loop.
143
+
144
+ Args:
145
+ config: Configuration to use.
146
+ workdir: Working directory for checkpoints and TF summaries. If this
147
+ contains checkpoint training will be resumed from the latest checkpoint.
148
+ """
149
+ rng = jax.random.PRNGKey(config.seed)
150
+
151
+ tf.io.gfile.makedirs(workdir)
152
+
153
+ # Input pipeline.
154
+ rng, data_rng = jax.random.split(rng)
155
+ # Make sure each host uses a different RNG for the training data.
156
+ if config.get("seed_data", True): # Default to seeding data if not specified.
157
+ data_rng = jax.random.fold_in(data_rng, jax.host_id())
158
+ else:
159
+ data_rng = None
160
+ train_ds, eval_ds = input_pipeline.create_datasets(config, data_rng)
161
+ train_iter = iter(train_ds) # pytype: disable=wrong-arg-types
162
+
163
+ # Initialize model
164
+ model = utils.build_model_from_config(config.model)
165
+
166
+ # Construct TrainMetrics and EvalMetrics, metrics collections.
167
+ train_metrics_cls = utils.make_metrics_collection("TrainMetrics",
168
+ config.train_metrics_spec)
169
+ eval_metrics_cls = utils.make_metrics_collection("EvalMetrics",
170
+ config.eval_metrics_spec)
171
+
172
+ def init_model(rng):
173
+ rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
174
+
175
+ init_conditioning = None
176
+ if config.get("conditioning_key"):
177
+ init_conditioning = jnp.ones(
178
+ [1] + list(train_ds.element_spec[config.conditioning_key].shape)[2:],
179
+ jnp.int32)
180
+ init_inputs = jnp.ones(
181
+ [1] + list(train_ds.element_spec["video"].shape)[2:],
182
+ jnp.float32)
183
+ initial_vars = model.init(
184
+ {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
185
+ video=init_inputs, conditioning=init_conditioning,
186
+ padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
187
+
188
+ # Split into state variables (e.g. for batchnorm stats) and model params.
189
+ # Note that `pop()` on a FrozenDict performs a deep copy.
190
+ state_vars, initial_params = initial_vars.pop("params") # pytype: disable=attribute-error
191
+
192
+ # Filter out intermediates (we don't want to store these in the TrainState).
193
+ state_vars = utils.filter_key_from_frozen_dict(
194
+ state_vars, key="intermediates")
195
+ return state_vars, initial_params
196
+
197
+ state_vars, initial_params = init_model(rng)
198
+ parameter_overview.log_parameter_overview(initial_params) # pytype: disable=wrong-arg-types
199
+
200
+ learning_rate_fn = lr_schedules.get_learning_rate_fn(config)
201
+ tx = optimizers.get_optimizer(
202
+ config.optimizer_configs, learning_rate_fn, params=initial_params)
203
+
204
+ opt_state = tx.init(initial_params)
205
+
206
+ state = utils.TrainState(
207
+ step=1, opt_state=opt_state, params=initial_params, rng=rng,
208
+ variables=state_vars)
209
+
210
+ loss_fn = functools.partial(
211
+ losses.compute_full_loss, loss_config=config.losses)
212
+
213
+ checkpoint_dir = os.path.join(workdir, "checkpoints")
214
+ ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
215
+ state = ckpt.restore_or_initialize(state)
216
+ initial_step = int(state.step)
217
+
218
+ # Replicate our parameters.
219
+ state = flax.jax_utils.replicate(state, devices=jax.local_devices())
220
+ del rng # rng is stored in the state.
221
+
222
+ # Only write metrics on host 0, write to logs on all other hosts.
223
+ writer = metric_writers.create_default_writer(
224
+ workdir, just_logging=jax.host_id() > 0)
225
+ writer.write_hparams(utils.prepare_dict_for_logging(config.to_dict()))
226
+
227
+ logging.info("Starting training loop at step %d.", initial_step)
228
+ report_progress = periodic_actions.ReportProgress(
229
+ num_train_steps=config.num_train_steps, writer=writer)
230
+ if jax.process_index() == 0:
231
+ profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
232
+ p_train_step = jax.pmap(
233
+ train_step,
234
+ axis_name="batch",
235
+ donate_argnums=(2, 3, 4, 5, 6, 7),
236
+ static_broadcasted_argnums=(0, 1, 8, 9, 10, 11, 12))
237
+
238
+ train_metrics = None
239
+ with metric_writers.ensure_flushes(writer):
240
+ if config.num_train_steps == 0:
241
+ with report_progress.timed("eval"):
242
+ evaluate(model, state, eval_ds, loss_fn, eval_metrics_cls, config,
243
+ writer, step=0)
244
+ with report_progress.timed("checkpoint"):
245
+ ckpt.save(flax.jax_utils.unreplicate(state))
246
+ return
247
+
248
+ for step in range(initial_step, config.num_train_steps + 1):
249
+ # `step` is a Python integer. `state.step` is JAX integer on GPU/TPU.
250
+ is_last_step = step == config.num_train_steps
251
+
252
+ with jax.profiler.StepTraceAnnotation("train", step_num=step):
253
+ batch = jax.tree_map(np.asarray, next(train_iter))
254
+ (opt_state, params, state_vars, rng, metrics_update, p_step
255
+ ) = p_train_step(
256
+ model, tx, state.rng, state.step, state.variables,
257
+ state.opt_state, state.params, batch, loss_fn,
258
+ train_metrics_cls,
259
+ config.num_slots,
260
+ config.max_instances + 1, # Incl. background.
261
+ config.get("conditioning_key"))
262
+
263
+ state = state.replace( # pytype: disable=attribute-error
264
+ opt_state=opt_state,
265
+ params=params,
266
+ step=p_step,
267
+ variables=state_vars,
268
+ rng=rng,
269
+ )
270
+
271
+ metric_update = flax.jax_utils.unreplicate(metrics_update)
272
+ train_metrics = (
273
+ metric_update
274
+ if train_metrics is None else train_metrics.merge(metric_update))
275
+
276
+ # Quick indication that training is happening.
277
+ logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
278
+ report_progress(step, time.time())
279
+
280
+ if jax.process_index() == 0:
281
+ profiler(step)
282
+
283
+ if step % config.log_loss_every_steps == 0 or is_last_step:
284
+ metrics_res = train_metrics.compute()
285
+ writer.write_scalars(step, jax.tree_map(np.array, metrics_res))
286
+ train_metrics = None
287
+
288
+ if step % config.eval_every_steps == 0 or is_last_step:
289
+ with report_progress.timed("eval"):
290
+ evaluate(model, state, eval_ds, loss_fn, eval_metrics_cls,
291
+ config, writer, step=step)
292
+
293
+ if step % config.checkpoint_every_steps == 0 or is_last_step:
294
+ with report_progress.timed("checkpoint"):
295
+ ckpt.save(flax.jax_utils.unreplicate(state))
296
+
297
+
298
+ def evaluate(model, state, eval_ds, loss_fn_eval, eval_metrics_cls, config,
299
+ writer, step):
300
+ """Evaluate the model."""
301
+ eval_metrics, eval_batch, eval_preds = evaluator.evaluate(
302
+ model,
303
+ state,
304
+ eval_ds,
305
+ loss_fn_eval,
306
+ eval_metrics_cls,
307
+ predicted_max_num_instances=config.num_slots,
308
+ ground_truth_max_num_instances=config.max_instances + 1, # Incl. bg.
309
+ slice_size=config.get("eval_slice_size"),
310
+ slice_keys=config.get("eval_slice_keys"),
311
+ conditioning_key=config.get("conditioning_key"),
312
+ remove_from_predictions=config.get("remove_from_predictions"),
313
+ metrics_on_cpu=config.get("metrics_on_cpu", False))
314
+
315
+ metrics_res = eval_metrics.compute()
316
+ writer.write_scalars(
317
+ step, jax.tree_map(np.array, utils.flatten_named_dicttree(metrics_res)))
318
+ writer.write_images(
319
+ step,
320
+ jax.tree_map(
321
+ np.array,
322
+ utils.prepare_images_for_logging(
323
+ config,
324
+ eval_batch,
325
+ eval_preds,
326
+ n_samples=config.get("n_samples", 5),
327
+ n_frames=config.get("n_frames", 1),
328
+ min_n_colors=config.get("logging_min_n_colors", 1))))
invariant_slot_attention/lib/transforms.py ADDED
@@ -0,0 +1,163 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Transform functions for preprocessing."""
17
+ from typing import Any, Optional, Tuple
18
+
19
+ import tensorflow as tf
20
+
21
+
22
+ SizeTuple = Tuple[tf.Tensor, tf.Tensor] # (height, width).
23
+ Self = Any
24
+
25
+ PADDING_VALUE = -1
26
+ PADDING_VALUE_STR = b""
27
+
28
+ NOTRACK_BOX = (0., 0., 0., 0.) # No-track bounding box for padding.
29
+ NOTRACK_RBOX = (0., 0., 0., 0., 0.) # No-track bounding rbox for padding.
30
+
31
+
32
+ def crop_or_pad_boxes(boxes, top, left, height,
33
+ width, h_orig, w_orig,
34
+ min_cropped_area = None):
35
+ """Transforms the relative box coordinates according to the frame crop.
36
+
37
+ Note that, if height/width are larger than h_orig/w_orig, this function
38
+ implements the equivalent of padding.
39
+
40
+ Args:
41
+ boxes: Tensor of bounding boxes with shape (..., 4).
42
+ top: Top of crop box in absolute pixel coordinates.
43
+ left: Left of crop box in absolute pixel coordinates.
44
+ height: Height of crop box in absolute pixel coordinates.
45
+ width: Width of crop box in absolute pixel coordinates.
46
+ h_orig: Original image height in absolute pixel coordinates.
47
+ w_orig: Original image width in absolute pixel coordinates.
48
+ min_cropped_area: If set, remove cropped boxes whose area relative to the
49
+ original box is less than min_cropped_area or that cover the entire
50
+ image.
51
+
52
+ Returns:
53
+ Boxes tensor with same shape as input boxes but updated values.
54
+ """
55
+ # Video track bounding boxes: [num_instances, num_tracks, 4]
56
+ # Image bounding boxes: [num_instances, 4]
57
+ assert boxes.shape[-1] == 4
58
+ seq_len = tf.shape(boxes)[0]
59
+ not_padding = tf.reduce_any(tf.not_equal(boxes, PADDING_VALUE), axis=-1)
60
+ has_tracks = len(boxes.shape) == 3
61
+ if has_tracks:
62
+ num_tracks = tf.shape(boxes)[1]
63
+ else:
64
+ assert len(boxes.shape) == 2
65
+ num_tracks = 1
66
+
67
+ # Transform the box coordinates.
68
+ a = tf.cast(tf.stack([h_orig, w_orig]), tf.float32)
69
+ b = tf.cast(tf.stack([top, left]), tf.float32)
70
+ c = tf.cast(tf.stack([height, width]), tf.float32)
71
+ boxes = tf.reshape(
72
+ (tf.reshape(boxes, (seq_len, num_tracks, 2, 2)) * a - b) / c,
73
+ (seq_len, num_tracks, len(NOTRACK_BOX)),
74
+ )
75
+
76
+ # Filter the valid boxes.
77
+ areas_uncropped = tf.reduce_prod(
78
+ tf.maximum(boxes[Ellipsis, 2:] - boxes[Ellipsis, :2], 0), axis=-1
79
+ )
80
+ boxes = tf.minimum(tf.maximum(boxes, 0.0), 1.0)
81
+ if has_tracks:
82
+ cond = tf.reduce_all((boxes[:, :, 2:] - boxes[:, :, :2]) > 0.0, axis=-1)
83
+ boxes = tf.where(cond[:, :, tf.newaxis], boxes, NOTRACK_BOX)
84
+ if min_cropped_area is not None:
85
+ areas_cropped = tf.reduce_prod(
86
+ tf.maximum(boxes[Ellipsis, 2:] - boxes[Ellipsis, :2], 0), axis=-1
87
+ )
88
+ boxes = tf.where(
89
+ tf.logical_and(
90
+ tf.reduce_max(areas_cropped, axis=0, keepdims=True)
91
+ > min_cropped_area * areas_uncropped,
92
+ tf.reduce_min(areas_cropped, axis=0, keepdims=True) < 1,
93
+ )[Ellipsis, tf.newaxis],
94
+ boxes,
95
+ tf.constant(NOTRACK_BOX)[tf.newaxis, tf.newaxis],
96
+ )
97
+ else:
98
+ boxes = tf.reshape(boxes, (seq_len, 4))
99
+ # Image ops use `-1`, whereas video ops above use `NOTRACK_BOX`.
100
+ boxes = tf.where(not_padding[Ellipsis, tf.newaxis], boxes, PADDING_VALUE)
101
+
102
+ return boxes
103
+
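+ # Worked sketch (illustrative numbers): a normalized image box that exactly
+ # fills a 50x50 crop taken at (top=25, left=25) of a 100x100 frame maps to
+ # the full unit box.
+ def _crop_or_pad_boxes_sketch():
+   boxes = tf.constant([[0.25, 0.25, 0.75, 0.75]], tf.float32)  # [y1, x1, y2, x2]
+   out = crop_or_pad_boxes(
+       boxes, top=25, left=25, height=50, width=50, h_orig=100, w_orig=100)
+   # out == [[0., 0., 1., 1.]]
+   return out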
104
+
105
+ def cxcywha_to_corners(cxcywha):
106
+ """Convert [cx, cy, w, h, a] to four corners of [x, y].
107
+
108
+ TF version of cxcywha_to_corners in
109
+ third_party/py/scenic/model_lib/base_models/box_utils.py.
110
+
111
+ Args:
112
+ cxcywha: [..., 5]-tf.Tensor of [center-x, center-y, width, height, angle]
113
+ representation of rotated boxes. Angle is in radians and center of rotation
114
+ is defined by [center-x, center-y] point.
115
+
116
+ Returns:
117
+ [..., 4, 2]-tf.Tensor of four corners of the rotated box as [x, y] points.
118
+ """
119
+ assert cxcywha.shape[-1] == 5, "Expected [..., [cx, cy, w, h, a]] input."
120
+ bs = cxcywha.shape[:-1]
121
+ cx, cy, w, h, a = tf.split(cxcywha, num_or_size_splits=5, axis=-1)
122
+ xs = tf.constant([.5, .5, -.5, -.5]) * w
123
+ ys = tf.constant([-.5, .5, .5, -.5]) * h
124
+ pts = tf.stack([xs, ys], axis=-1)
125
+ sin = tf.sin(a)
126
+ cos = tf.cos(a)
127
+ rot = tf.reshape(tf.concat([cos, -sin, sin, cos], axis=-1), (*bs, 2, 2))
128
+ offset = tf.reshape(tf.concat([cx, cy], -1), (*bs, 1, 2))
129
+ corners = pts @ rot + offset
130
+ return corners
131
+
132
+
133
+ def corners_to_cxcywha(corners):
134
+ """Convert four corners of [x, y] to [cx, cy, w, h, a].
135
+
136
+ Args:
137
+ corners: [..., 4, 2]-tf.Tensor of four corners of the rotated box as [x, y]
138
+ points.
139
+
140
+ Returns:
141
+ [..., 5]-tf.Tensor of [center-x, center-y, width, height, angle]
142
+ representation of rotated boxes. Angle is in radians and center of rotation
143
+ is defined by [center-x, center-y] point.
144
+ """
145
+ assert corners.shape[-2] == 4 and corners.shape[-1] == 2, (
146
+ "Expected [..., 4, 2] corners input.")
147
+
148
+ cornersx, cornersy = tf.unstack(corners, axis=-1)
149
+ cx = tf.reduce_mean(cornersx, axis=-1)
150
+ cy = tf.reduce_mean(cornersy, axis=-1)
151
+ wcornersx = (
152
+ cornersx[Ellipsis, 0] + cornersx[Ellipsis, 1] - cornersx[Ellipsis, 2] - cornersx[Ellipsis, 3])
153
+ wcornersy = (
154
+ cornersy[Ellipsis, 0] + cornersy[Ellipsis, 1] - cornersy[Ellipsis, 2] - cornersy[Ellipsis, 3])
155
+ hcornersy = (-cornersy[Ellipsis, 0,] + cornersy[Ellipsis, 1] + cornersy[Ellipsis, 2] -
156
+ cornersy[Ellipsis, 3])
157
+ a = -tf.atan2(wcornersy, wcornersx)
158
+ cos = tf.cos(a)
159
+ w = wcornersx / (2 * cos)
160
+ h = hcornersy / (2 * cos)
161
+ cxcywha = tf.stack([cx, cy, w, h, a], axis=-1)
162
+
163
+ return cxcywha
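+
+
+ # Round-trip sketch (illustrative): converting a rotated box to its corners
+ # and back recovers the original [cx, cy, w, h, a] parameterization up to
+ # floating-point error, which is a handy self-check for the two functions
+ # above.
+ def _rbox_round_trip_sketch():
+   rbox = tf.constant([[10., 20., 4., 2., 0.3]])  # [cx, cy, w, h, angle]
+   corners = cxcywha_to_corners(rbox)             # shape [1, 4, 2]
+   return corners_to_cxcywha(corners)             # ~= rbox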
invariant_slot_attention/lib/utils.py ADDED
@@ -0,0 +1,625 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Common utils."""
17
+
18
+ import functools
19
+ import importlib
20
+ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Type, Union
21
+
22
+ from absl import logging
23
+ from clu import metrics as base_metrics
24
+
25
+ import flax
26
+ from flax import linen as nn
27
+ from flax import traverse_util
28
+
29
+ import jax
30
+ import jax.numpy as jnp
31
+ import jax.ops
32
+
33
+ import matplotlib
34
+ import matplotlib.pyplot as plt
35
+ import ml_collections
36
+ import numpy as np
37
+ import optax
38
+
39
+ import skimage.transform
40
+ import tensorflow as tf
41
+
42
+ from invariant_slot_attention.lib import metrics
43
+
44
+
45
+ Array = Any # Union[np.ndarray, jnp.ndarray]
46
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
47
+ DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet
48
+ PRNGKey = Array
49
+ ConfigAttr = Any
50
+ MetricSpec = Dict[str, str]
51
+
52
+
53
+ @flax.struct.dataclass
54
+ class TrainState:
55
+ """Data structure for checkpointing the model."""
56
+ step: int
57
+ opt_state: optax.OptState
58
+ params: ArrayTree
59
+ variables: flax.core.FrozenDict
60
+ rng: PRNGKey
61
+
62
+
63
+ METRIC_TYPE_TO_CLS = {
64
+ "loss": base_metrics.Average.from_output(name="loss"),
65
+ "ari": metrics.Ari,
66
+ "ari_nobg": metrics.AriNoBg,
67
+ }
68
+
69
+
70
+ def make_metrics_collection(
71
+ class_name,
72
+ metrics_spec):
73
+ """Create a class inheriting from metrics.Collection based on the spec."""
74
+ metrics_dict = {}
75
+ if metrics_spec:
76
+ for m_name, m_type in metrics_spec.items():
77
+ metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type]
78
+
79
+ return flax.struct.dataclass(
80
+ type(class_name,
81
+ (base_metrics.Collection,),
82
+ {"__annotations__": metrics_dict}))
83
+
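+ # Usage sketch (illustrative): a metrics spec maps a metric name to one of
+ # the registered types above; the result is a clu metrics Collection class.
+ TrainMetricsSketch = make_metrics_collection(
+     "TrainMetricsSketch", {"loss": "loss", "ari": "ari", "ari_nobg": "ari_nobg"})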
84
+
85
+ def flatten_named_dicttree(metrics_res, sep = "/"):
86
+ """Flatten dictionary."""
87
+ metrics_res_flat = {}
88
+ for k, v in traverse_util.flatten_dict(metrics_res).items():
89
+ metrics_res_flat[(sep.join(k)).strip(sep)] = v
90
+ return metrics_res_flat
91
+
92
+
93
+ def spatial_broadcast(x, resolution):
94
+ """Broadcast flat inputs to a 2D grid of a given resolution."""
95
+ # x.shape = (batch_size, features).
96
+ x = x[:, jnp.newaxis, jnp.newaxis, :]
97
+ return jnp.tile(x, [1, resolution[0], resolution[1], 1])
98
+
99
+
100
+ def time_distributed(cls, in_axes=1, axis=1):
101
+ """Wrapper for time-distributed (vmapped) application of a module."""
102
+ return nn.vmap(
103
+ cls, in_axes=in_axes, out_axes=axis, axis_name="time",
104
+ # Stack debug vars along sequence dim and broadcast params.
105
+ variable_axes={
106
+ "params": None, "intermediates": axis, "batch_stats": None},
107
+ split_rngs={"params": False, "dropout": True, "state_init": True})
108
+
109
+
110
+ def broadcast_across_batch(inputs, batch_size):
111
+ """Broadcasts inputs across a batch of examples (creates new axis)."""
112
+ return jnp.broadcast_to(
113
+ array=jnp.expand_dims(inputs, axis=0),
114
+ shape=(batch_size,) + inputs.shape)
115
+
116
+
117
+ def create_gradient_grid(
118
+ samples_per_dim, value_range = (-1.0, 1.0)
119
+ ):
120
+ """Creates a tensor with equidistant entries from -1 to +1 in each dim.
121
+
122
+ Args:
123
+ samples_per_dim: Number of points to have along each dimension.
124
+ value_range: In each dimension, points will go from value_range[0] to value_range[1].
125
+
126
+ Returns:
127
+ A tensor of shape [samples_per_dim] + [len(samples_per_dim)].
128
+ """
129
+ s = [jnp.linspace(value_range[0], value_range[1], n) for n in samples_per_dim]
130
+ pe = jnp.stack(jnp.meshgrid(*s, sparse=False, indexing="ij"), axis=-1)
131
+ return jnp.array(pe)
132
+
133
+
134
+ def convert_to_fourier_features(inputs, basis_degree):
135
+ """Convert inputs to Fourier features, e.g. for positional encoding."""
136
+
137
+ # inputs.shape = (..., n_dims).
138
+ # inputs should be in range [-pi, pi] or [0, 2pi].
139
+ n_dims = inputs.shape[-1]
140
+
141
+ # Generate frequency basis.
142
+ freq_basis = jnp.concatenate( # shape = (n_dims, n_dims * basis_degree)
143
+ [2**i * jnp.eye(n_dims) for i in range(basis_degree)], 1)
144
+
145
+ # x.shape = (..., n_dims * basis_degree)
146
+ x = inputs @ freq_basis # Project inputs onto frequency basis.
147
+
148
+ # Obtain Fourier features as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)].
149
+ return jnp.sin(jnp.concatenate([x, x + 0.5 * jnp.pi], axis=-1))
150
+
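+ # Shape sketch for the two helpers above (illustrative): an 8x8 coordinate
+ # grid has 2 coordinate channels, and with basis_degree=3 the Fourier
+ # encoding yields 2 * n_dims * basis_degree = 12 channels ([sin, cos] per
+ # frequency and dimension).
+ def _fourier_grid_sketch():
+   grid = create_gradient_grid((8, 8)) * jnp.pi   # Scale [-1, 1] to [-pi, pi].
+   feats = convert_to_fourier_features(grid, basis_degree=3)
+   return grid.shape, feats.shape  # (8, 8, 2), (8, 8, 12)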
151
+
152
+ def prepare_images_for_logging(
153
+ config,
154
+ batch = None,
155
+ preds = None,
156
+ n_samples = 5,
157
+ n_frames = 5,
158
+ min_n_colors = 1,
159
+ epsilon = 1e-6,
160
+ first_replica_only = False):
161
+ """Prepare images from batch and/or model predictions for logging."""
162
+
163
+ images = dict()
164
+ # Converts all tensors to numpy arrays to run everything on CPU as JAX
165
+ # eager mode is inefficient and because memory usage from these ops may
166
+ # lead to OOM errors.
167
+ batch = jax.tree_map(np.array, batch)
168
+ preds = jax.tree_map(np.array, preds)
169
+
170
+ if n_samples <= 0:
171
+ return images
172
+
173
+ if not first_replica_only:
174
+ # Move the two leading batch dimensions into a single dimension. We do this
175
+ # to plot the same number of examples regardless of the data parallelism.
176
+ batch = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), batch)
177
+ preds = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), preds)
178
+ else:
179
+ batch = jax.tree_map(lambda x: x[0], batch)
180
+ preds = jax.tree_map(lambda x: x[0], preds)
181
+
182
+ # Limit the tensors to n_samples and n_frames.
183
+ batch = jax.tree_map(
184
+ lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
185
+ batch)
186
+ preds = jax.tree_map(
187
+ lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
188
+ preds)
189
+
190
+ # Log input data.
191
+ if batch is not None:
192
+ images["video"] = video_to_image_grid(batch["video"])
193
+ if "segmentations" in batch:
194
+ images["mask"] = video_to_image_grid(convert_categories_to_color(
195
+ batch["segmentations"], min_n_colors=min_n_colors))
196
+ if "flow" in batch:
197
+ images["flow"] = video_to_image_grid(batch["flow"])
198
+ if "boxes" in batch:
199
+ images["boxes"] = draw_bounding_boxes(
200
+ batch["video"],
201
+ batch["boxes"],
202
+ min_n_colors=min_n_colors)
203
+
204
+ # Log model predictions.
205
+ if preds is not None and preds.get("outputs") is not None:
206
+ if "segmentations" in preds["outputs"]: # pytype: disable=attribute-error
207
+ images["segmentations"] = video_to_image_grid(
208
+ convert_categories_to_color(
209
+ preds["outputs"]["segmentations"], min_n_colors=min_n_colors))
210
+
211
+ def shape_fn(x):
212
+ if isinstance(x, (np.ndarray, jnp.ndarray)):
213
+ return x.shape
214
+
215
+ # Log intermediate variables.
216
+ if preds is not None and "intermediates" in preds:
217
+
218
+ logging.info("intermediates: %s",
219
+ jax.tree_map(shape_fn, preds["intermediates"]))
220
+
221
+ for key, path in config.debug_var_video_paths.items():
222
+ log_vars = retrieve_from_collection(preds["intermediates"], path)
223
+ if log_vars is not None:
224
+ if not isinstance(log_vars, Sequence):
225
+ log_vars = [log_vars]
226
+ for i, log_var in enumerate(log_vars):
227
+ log_var = np.array(log_var) # Moves log_var to CPU.
228
+ images[key + "_" + str(i)] = video_to_image_grid(log_var)
229
+ else:
230
+ logging.warning("%s not found in intermediates", path)
231
+
232
+ # Log attention weights.
233
+ for key, path in config.debug_var_attn_paths.items():
234
+ log_vars = retrieve_from_collection(preds["intermediates"], path)
235
+ if log_vars is not None:
236
+ if not isinstance(log_vars, Sequence):
237
+ log_vars = [log_vars]
238
+ for i, log_var in enumerate(log_vars):
239
+ log_var = np.array(log_var) # Moves log_var to CPU.
240
+ images.update(
241
+ prepare_attention_maps_for_logging(
242
+ attn_maps=log_var,
243
+ key=key + "_" + str(i),
244
+ map_width=config.debug_var_attn_widths.get(key),
245
+ video=batch["video"],
246
+ epsilon=epsilon,
247
+ n_samples=n_samples,
248
+ n_frames=n_frames))
249
+ else:
250
+ logging.warning("%s not found in intermediates", path)
251
+
252
+ # Crop each image to a maximum of 3 channels for RGB visualization.
253
+ for key, image in images.items():
254
+ if image.shape[-1] > 3:
255
+ logging.warning("Truncating channels of %s for visualization.", key)
256
+ images[key] = image[Ellipsis, :3]
257
+
258
+ return images
259
+
260
+
261
+ def prepare_attention_maps_for_logging(attn_maps, key,
262
+ map_width, epsilon,
263
+ n_samples, n_frames,
264
+ video):
265
+ """Visualize (overlaid) attention maps as an image grid."""
266
+ images = {} # Results dictionary.
267
+ attn_maps = unflatten_image(attn_maps[Ellipsis, None], width=map_width)
268
+
269
+ num_heads = attn_maps.shape[2]
270
+ for head_idx in range(num_heads):
271
+ attn = attn_maps[:n_samples, :n_frames, head_idx]
272
+ attn /= attn.max() + epsilon # Standardizes scale for visualization.
273
+ # attn.shape: [bs, seq_len, 11, h', w', 1]
274
+
275
+ bs, seq_len, _, h_attn, w_attn, _ = attn.shape
276
+ images[f"{key}_head_{head_idx}"] = video_to_image_grid(attn)
277
+
278
+ # Attention maps are interpretable when they align with object boundaries.
279
+ # However, if they are overly smooth, the following visualization, which
280
+ # overlays attention maps on the video, is helpful.
281
+ video = video[:n_samples, :n_frames]
282
+ # video.shape: [bs, seq_len, h, w, 3]
283
+ video_resized = []
284
+ for i in range(n_samples):
285
+ for j in range(n_frames):
286
+ video_resized.append(
287
+ skimage.transform.resize(video[i, j], (h_attn, w_attn), order=1))
288
+ video_resized = np.array(video_resized).reshape(
289
+ (bs, seq_len, h_attn, w_attn, 3))
290
+ attn_overlayed = attn * np.expand_dims(video_resized, 2)
291
+ images[f"{key}_head_{head_idx}_overlayed"] = video_to_image_grid(
292
+ attn_overlayed)
293
+
294
+ return images
295
+
296
+
297
+ def convert_categories_to_color(
298
+ inputs, min_n_colors = 1, include_black = True):
299
+ """Converts int-valued categories to color in last axis of input tensor.
300
+
301
+ Args:
302
+ inputs: `np.ndarray` of arbitrary shape with integer entries, encoding the
303
+ categories.
304
+ min_n_colors: Minimum number of colors (excl. black) to encode categories.
305
+ include_black: Include black as 0-th entry in the color palette. Increases
306
+ `min_n_colors` by 1 if True.
307
+
308
+ Returns:
309
+ `np.ndarray` with RGB colors in last axis.
310
+ """
311
+ if inputs.shape[-1] == 1: # Strip category axis.
312
+ inputs = np.squeeze(inputs, axis=-1)
313
+ inputs = np.array(inputs, dtype=np.int32) # Convert to int.
314
+
315
+ # Infer number of colors from inputs.
316
+ n_colors = int(inputs.max()) + 1 # One color per category incl. 0.
317
+ if include_black:
318
+ n_colors -= 1 # If we include black, we need one color less.
319
+
320
+ if min_n_colors > n_colors: # Use more colors in color palette if requested.
321
+ n_colors = min_n_colors
322
+
323
+ rgb_colors = get_uniform_colors(n_colors)
324
+
325
+ if include_black: # Add black as color for zero-th index.
326
+ rgb_colors = np.concatenate((np.zeros((1, 3)), rgb_colors), axis=0)
327
+ return rgb_colors[inputs]
328
+
329
+
330
+ def get_uniform_colors(n_colors):
331
+ """Get n_colors with uniformly spaced hues."""
332
+ hues = np.linspace(0, 1, n_colors, endpoint=False)
333
+ hsv_colors = np.concatenate(
334
+ (np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1)
335
+ rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors)
336
+ return rgb_colors # rgb_colors.shape = (n_colors, 3)
337
+
338
+
339
+ def unflatten_image(image, width = None):
340
+ """Unflatten image array of shape [batch_dims..., height*width, channels]."""
341
+ n_channels = image.shape[-1]
342
+ # If width is not provided, we assume that the image is square.
343
+ if width is None:
344
+ width = int(np.floor(np.sqrt(image.shape[-2])))
345
+ height = width
346
+ assert width * height == image.shape[-2], "Image is not square."
347
+ else:
348
+ height = image.shape[-2] // width
349
+ return image.reshape(image.shape[:-2] + (height, width, n_channels))
350
+
351
+
352
+ def video_to_image_grid(video):
353
+ """Transform video to image grid by folding sequence dim along width."""
354
+ if len(video.shape) == 5:
355
+ n_samples, n_frames, height, width, n_channels = video.shape
356
+ video = np.transpose(video, (0, 2, 1, 3, 4)) # Swap n_frames and height.
357
+ image_grid = np.reshape(
358
+ video, (n_samples, height, n_frames * width, n_channels))
359
+ elif len(video.shape) == 6:
360
+ n_samples, n_frames, n_slots, height, width, n_channels = video.shape
361
+ # Put n_frames next to width.
362
+ video = np.transpose(video, (0, 2, 3, 1, 4, 5))
363
+ image_grid = np.reshape(
364
+ video, (n_samples, n_slots * height, n_frames * width, n_channels))
365
+ else:
366
+ raise ValueError("Unsupported video shape for visualization.")
367
+ return image_grid
368
+
369
+
370
+ def draw_bounding_boxes(video,
371
+ boxes,
372
+ min_n_colors = 1,
373
+ include_black = True):
374
+ """Draw bounding boxes in videos."""
375
+ colors = get_uniform_colors(min_n_colors - include_black)
376
+
377
+ b, t, h, w, c = video.shape
378
+ n = boxes.shape[2]
379
+ image_grid = tf.image.draw_bounding_boxes(
380
+ np.reshape(video, (b * t, h, w, c)),
381
+ np.reshape(boxes, (b * t, n, 4)),
382
+ colors).numpy()
383
+ image_grid = np.reshape(
384
+ np.transpose(np.reshape(image_grid, (b, t, h, w, c)),
385
+ (0, 2, 1, 3, 4)),
386
+ (b, h, t * w, c))
387
+ return image_grid
388
+
389
+
390
+ def plot_image(ax, image):
391
+ """Add an image visualization to a provided `plt.Axes` instance."""
392
+ num_channels = image.shape[-1]
393
+ if num_channels == 1:
394
+ image = image.reshape(image.shape[:2])
395
+ ax.imshow(image, cmap="viridis")
396
+ ax.grid(False)
397
+ plt.axis("off")
398
+
399
+
400
+ def visualize_image_dict(images, plot_scale = 10):
401
+ """Visualize a dictionary of images in Colab using matplotlib."""
402
+
403
+ for key in images.keys():
404
+ logging.info("Visualizing key: %s", key)
405
+ n_images = len(images[key])
406
+ fig = plt.figure(figsize=(n_images * plot_scale, plot_scale))
407
+ for idx, image in enumerate(images[key]):
408
+ ax = fig.add_subplot(1, n_images, idx+1)
409
+ plot_image(ax, image)
410
+ plt.show()
411
+
412
+
413
+ def filter_key_from_frozen_dict(
414
+ frozen_dict, key):
415
+ """Filters (removes) an item by key from a flax.core.FrozenDict."""
416
+ if key in frozen_dict:
417
+ frozen_dict, _ = frozen_dict.pop(key)
418
+ return frozen_dict
419
+
420
+
421
+ def prepare_dict_for_logging(nested_dict, parent_key = "",
422
+ sep = "_"):
423
+ """Prepare a nested dictionary for logging with `clu.metric_writers`.
424
+
425
+ Args:
426
+ nested_dict: A nested dictionary, e.g. obtained from a
427
+ `ml_collections.ConfigDict` via `.to_dict()`.
428
+ parent_key: String used in recursion.
429
+ sep: String used to separate parent and child keys.
430
+
431
+ Returns:
432
+ Flattened dict.
433
+ """
434
+ items = []
435
+ for k, v in nested_dict.items():
436
+ # Flatten keys of nested elements.
437
+ new_key = parent_key + sep + k if parent_key else k
438
+
439
+ # Convert None values, lists and tuples to strings.
440
+ if v is None:
441
+ v = "None"
442
+ if isinstance(v, list) or isinstance(v, tuple):
443
+ v = str(v)
444
+
445
+ # Recursively flatten the dict.
446
+ if isinstance(v, dict):
447
+ items.extend(prepare_dict_for_logging(v, new_key, sep=sep).items())
448
+ else:
449
+ items.append((new_key, v))
450
+ return dict(items)
451
+
452
+
453
+ def retrieve_from_collection(
454
+ variable_collection, path):
455
+ """Finds variables by their path by recursively searching the collection.
456
+
457
+ Args:
458
+ variable_collection: Nested dict containing the variables (or tuples/lists
459
+ of variables).
460
+ path: Path to variable in module tree, similar to Unix file names (e.g.
461
+ '/module/dense/0/bias').
462
+
463
+ Returns:
464
+ The requested variable, variable collection or None (in case the variable
465
+ could not be found).
466
+ """
467
+ key, _, rpath = path.strip("/").partition("/")
468
+
469
+ # In case the variable is not found, we return None.
470
+ if (key.isdigit() and not isinstance(variable_collection, Sequence)) or (
471
+ key.isdigit() and int(key) >= len(variable_collection)) or (
472
+ not key.isdigit() and key not in variable_collection):
473
+ return None
474
+
475
+ if key.isdigit():
476
+ key = int(key)
477
+
478
+ if not rpath:
479
+ return variable_collection[key]
480
+ else:
481
+ return retrieve_from_collection(variable_collection[key], rpath)
482
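A small worked example (not part of the diff) of the path-based lookup above; the nested dict and its values are hypothetical:

from invariant_slot_attention.lib import utils

# Hypothetical nested collection; lists can be indexed with digit path segments.
variables = {"intermediates": {"encoder": {"attn": [0.1, 0.9]}}}
utils.retrieve_from_collection(variables, "intermediates/encoder/attn/1")  # -> 0.9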
+
483
+
484
+ def build_model_from_config(config):
485
+ """Build a Flax model from a (nested) ConfigDict."""
486
+ model_constructor = _parse_config(config)
487
+ if callable(model_constructor):
488
+ return model_constructor()
489
+ else:
490
+ raise ValueError("Provided config does not contain module constructors.")
491
+
492
+
493
+ def _parse_config(config
494
+ ):
495
+ """Recursively parses a nested ConfigDict and resolves module constructors."""
496
+
497
+ if isinstance(config, list):
498
+ return [_parse_config(c) for c in config]
499
+ elif isinstance(config, tuple):
500
+ return tuple([_parse_config(c) for c in config])
501
+ elif not isinstance(config, ml_collections.ConfigDict):
502
+ return config
503
+ elif "module" in config:
504
+ module_constructor = _resolve_module_constructor(config.module)
505
+ kwargs = {k: _parse_config(v) for k, v in config.items() if k != "module"}
506
+ return functools.partial(module_constructor, **kwargs)
507
+ else:
508
+ return {k: _parse_config(v) for k, v in config.items()}
509
+
510
+
511
+ def _resolve_module_constructor(
512
+ constructor_str):
513
+ import_str, _, module_name = constructor_str.rpartition(".")
514
+ py_module = importlib.import_module(import_str)
515
+ return getattr(py_module, module_name)
516
+
517
+
518
+ def get_slices_along_axis(
519
+ inputs,
520
+ slice_keys,
521
+ start_idx = 0,
522
+ end_idx = -1,
523
+ axis = 2,
524
+ pad_value = 0):
525
+ """Extracts slices from a dictionary of tensors along the specified axis.
526
+
527
+ The slice operation is only applied to `slice_keys` dictionary keys. If
528
+ `end_idx` is larger than the actual size of the specified axis, padding is
529
+ added (with values provided in `pad_value`).
530
+
531
+ Args:
532
+ inputs: Dictionary of tensors.
533
+ slice_keys: Iterable of strings, the keys for the inputs dictionary for
534
+ which to apply the slice operation.
535
+ start_idx: Integer, defining the first index to be part of the slice.
536
+ end_idx: Integer, defining the end of the slice interval (exclusive). If set
537
+ to `-1`, the end index is set to the size of the axis. If a value is
538
+ provided that is larger than the size of the axis, zero-padding is added
539
+ for the remaining elements.
540
+ axis: Integer, the axis along which to slice.
541
+ pad_value: Integer, value to be used in padding.
542
+
543
+ Returns:
544
+ Dictionary of tensors where elements described in `slice_keys` are sliced,
545
+ and all other elements are returned as original.
546
+ """
547
+
548
+ max_size = None
549
+ pad_size = 0
550
+
551
+ # Check shapes and get maximum size of requested axis.
552
+ for key in slice_keys:
553
+ curr_size = inputs[key].shape[axis]
554
+ if max_size is None:
555
+ max_size = curr_size
556
+ elif max_size != curr_size:
557
+ raise ValueError(
558
+ "For specified tensors the requested axis needs to be of equal size.")
559
+
560
+ # Infer end index if not provided.
561
+ if end_idx == -1:
562
+ end_idx = max_size
563
+
564
+ # Set padding size if end index is larger than maximum size of requested axis.
565
+ elif end_idx > max_size:
566
+ pad_size = end_idx - max_size
567
+ end_idx = max_size
568
+
569
+ outputs = {}
570
+ for key in slice_keys:
571
+ outputs[key] = np.take(
572
+ inputs[key], indices=np.arange(start_idx, end_idx), axis=axis)
573
+
574
+ # Add padding if necessary.
575
+ if pad_size > 0:
576
+ pad_shape = np.array(outputs[key].shape)
577
+ np.put(pad_shape, axis, pad_size) # In-place op.
578
+ padding = pad_value * np.ones(pad_shape, dtype=outputs[key].dtype)
579
+ outputs[key] = np.concatenate((outputs[key], padding), axis=axis)
580
+
581
+ return outputs
582
+
583
+
584
+ def get_element_by_str(
585
+ dictionary, multilevel_key, separator = "/"
586
+ ):
587
+ """Gets element in a dictionary with multilevel key (e.g., "key1/key2")."""
588
+ keys = multilevel_key.split(separator)
589
+ if len(keys) == 1:
590
+ return dictionary[keys[0]]
591
+ return get_element_by_str(
592
+ dictionary[keys[0]], separator.join(keys[1:]), separator=separator)
593
+
594
+
595
+ def set_element_by_str(
596
+ dictionary, multilevel_key, new_value,
597
+ separator = "/"):
598
+ """Sets element in a dictionary with multilevel key (e.g., "key1/key2")."""
599
+ keys = multilevel_key.split(separator)
600
+ if len(keys) == 1:
601
+ if keys[0] not in dictionary:
602
+ key_error = (
603
+ "Pretrained {key} was not found in trained model. "
604
+ "Make sure you are loading the correct pretrained model "
605
+ "or consider adding {key} to exceptions.")
606
+ raise KeyError(key_error.format(type="parameter", key=keys[0]))
607
+ dictionary[keys[0]] = new_value
608
+ else:
609
+ set_element_by_str(
610
+ dictionary[keys[0]],
611
+ separator.join(keys[1:]),
612
+ new_value,
613
+ separator=separator)
614
+
615
+
616
+ def remove_singleton_dim(inputs):
617
+ """Removes the final dimension if it is singleton (i.e. of size 1)."""
618
+ if inputs is None:
619
+ return None
620
+ if inputs.shape[-1] != 1:
621
+ logging.warning("Expected final dimension of inputs to be 1, "
622
+ "received inputs of shape %s: ", str(inputs.shape))
623
+ return inputs
624
+ return inputs[Ellipsis, 0]
625
+
invariant_slot_attention/modules/__init__.py ADDED
@@ -0,0 +1,49 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Module library."""
17
+ # pylint: disable=g-multiple-import
18
+ # pylint: disable=g-bad-import-order
19
+ # Re-export commonly used modules and functions
20
+
21
+ from .attention import (GeneralizedDotProductAttention,
22
+ InvertedDotProductAttention, SlotAttention,
23
+ TransformerBlock, Transformer)
24
+ from .convolution import (SimpleCNN, CNN)
25
+ from .decoders import (SpatialBroadcastDecoder, SiameseSpatialBroadcastDecoder)
26
+ from .initializers import (GaussianStateInit, ParamStateInit,
27
+ SegmentationEncoderStateInit,
28
+ CoordinateEncoderStateInit)
29
+ from .misc import (Dense, GRU, Identity, MLP, PositionEmbedding, Readout,
30
+ RelativePositionEmbedding)
31
+ from .video import (CorrectorPredictorTuple, FrameEncoder, Processor, SAVi)
32
+ from .resnet import (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
33
+ ResNet200)
34
+ from .invariant_attention import (InvertedDotProductAttentionKeyPerQuery,
35
+ SlotAttentionExplicitStats,
36
+ SlotAttentionPosKeysValues,
37
+ SlotAttentionTranslEquiv,
38
+ SlotAttentionTranslScaleEquiv,
39
+ SlotAttentionTranslRotScaleEquiv)
40
+ from .invariant_initializers import (
41
+ ParamStateInitRandomPositions,
42
+ ParamStateInitRandomPositionsScales,
43
+ ParamStateInitRandomPositionsRotationsScales,
44
+ ParamStateInitLearnablePositions,
45
+ ParamStateInitLearnablePositionsScales,
46
+ ParamStateInitLearnablePositionsRotationsScales)
47
+
48
+
49
+ # pylint: enable=g-multiple-import
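A brief usage note (not part of the diff): thanks to the re-exports above, downstream code can import modules directly from the package namespace, for example:

from invariant_slot_attention.modules import SlotAttention, SpatialBroadcastDecoder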
invariant_slot_attention/modules/attention.py ADDED
@@ -0,0 +1,327 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Attention module library."""
17
+
18
+ import functools
19
+ from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
20
+
21
+ from flax import linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from invariant_slot_attention.modules import misc
25
+
26
+ Shape = Tuple[int]
27
+
28
+ DType = Any
29
+ Array = Any # jnp.ndarray
30
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
31
+ ProcessorState = ArrayTree
32
+ PRNGKey = Array
33
+ NestedDict = Dict[str, Any]
34
+
35
+
36
+ class SlotAttention(nn.Module):
37
+ """Slot Attention module.
38
+
39
+ Note: This module uses pre-normalization by default.
40
+ """
41
+
42
+ num_iterations: int = 1
43
+ qkv_size: Optional[int] = None
44
+ mlp_size: Optional[int] = None
45
+ epsilon: float = 1e-8
46
+ num_heads: int = 1
47
+
48
+ @nn.compact
49
+ def __call__(self, slots, inputs,
50
+ padding_mask = None,
51
+ train = False):
52
+ """Slot Attention module forward pass."""
53
+ del padding_mask, train # Unused.
54
+
55
+ qkv_size = self.qkv_size or slots.shape[-1]
56
+ head_dim = qkv_size // self.num_heads
57
+ dense = functools.partial(nn.DenseGeneral,
58
+ axis=-1, features=(self.num_heads, head_dim),
59
+ use_bias=False)
60
+
61
+ # Shared modules.
62
+ dense_q = dense(name="general_dense_q_0")
63
+ layernorm_q = nn.LayerNorm()
64
+ inverted_attention = InvertedDotProductAttention(
65
+ norm_type="mean", multi_head=self.num_heads > 1)
66
+ gru = misc.GRU()
67
+
68
+ if self.mlp_size is not None:
69
+ mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore
70
+
71
+ # inputs.shape = (..., n_inputs, inputs_size).
72
+ inputs = nn.LayerNorm()(inputs)
73
+ # k.shape = (..., n_inputs, slot_size).
74
+ k = dense(name="general_dense_k_0")(inputs)
75
+ # v.shape = (..., n_inputs, slot_size).
76
+ v = dense(name="general_dense_v_0")(inputs)
77
+
78
+ # Multiple rounds of attention.
79
+ for _ in range(self.num_iterations):
80
+
81
+ # Inverted dot-product attention.
82
+ slots_n = layernorm_q(slots)
83
+ q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size).
84
+ updates = inverted_attention(query=q, key=k, value=v)
85
+
86
+ # Recurrent update.
87
+ slots = gru(slots, updates)
88
+
89
+ # Feedforward block with pre-normalization.
90
+ if self.mlp_size is not None:
91
+ slots = mlp(slots)
92
+
93
+ return slots
94
+
95
+
96
+ class InvertedDotProductAttention(nn.Module):
97
+ """Inverted version of dot-product attention (softmax over query axis)."""
98
+
99
+ norm_type: Optional[str] = "mean" # mean, layernorm, or None
100
+ multi_head: bool = False
101
+ epsilon: float = 1e-8
102
+ dtype: DType = jnp.float32
103
+ precision: Optional[jax.lax.Precision] = None
104
+ return_attn_weights: bool = False
105
+
106
+ @nn.compact
107
+ def __call__(self, query, key, value,
108
+ train = False):
109
+ """Computes inverted dot-product attention.
110
+
111
+ Args:
112
+ query: Queries with shape of `[batch..., q_num, qk_features]`.
113
+ key: Keys with shape of `[batch..., kv_num, qk_features]`.
114
+ value: Values with shape of `[batch..., kv_num, v_features]`.
115
+ train: Indicating whether we're training or evaluating.
116
+
117
+ Returns:
118
+ Output of shape `[batch_size..., n_queries, v_features]`
119
+ """
120
+ del train # Unused.
121
+
122
+ attn = GeneralizedDotProductAttention(
123
+ inverted_attn=True,
124
+ renormalize_keys=True if self.norm_type == "mean" else False,
125
+ epsilon=self.epsilon,
126
+ dtype=self.dtype,
127
+ precision=self.precision,
128
+ return_attn_weights=True)
129
+
130
+ # Apply attention mechanism.
131
+ output, attn = attn(query=query, key=key, value=value)
132
+
133
+ if self.multi_head:
134
+ # Multi-head aggregation. Equivalent to concat + dense layer.
135
+ output = nn.DenseGeneral(features=output.shape[-1], axis=(-2, -1))(output)
136
+ else:
137
+ # Remove head dimension.
138
+ output = jnp.squeeze(output, axis=-2)
139
+ attn = jnp.squeeze(attn, axis=-3)
140
+
141
+ if self.norm_type == "layernorm":
142
+ output = nn.LayerNorm()(output)
143
+
144
+ if self.return_attn_weights:
145
+ return output, attn
146
+
147
+ return output
148
+
149
+
150
+ class GeneralizedDotProductAttention(nn.Module):
151
+ """Multi-head dot-product attention with customizable normalization axis.
152
+
153
+ This module supports logging of attention weights in a variable collection.
154
+ """
155
+
156
+ dtype: DType = jnp.float32
157
+ precision: Optional[jax.lax.Precision] = None
158
+ epsilon: float = 1e-8
159
+ inverted_attn: bool = False
160
+ renormalize_keys: bool = False
161
+ attn_weights_only: bool = False
162
+ return_attn_weights: bool = False
163
+
164
+ @nn.compact
165
+ def __call__(self, query, key, value,
166
+ train = False, **kwargs
167
+ ):
168
+ """Computes multi-head dot-product attention given query, key, and value.
169
+
170
+ Args:
171
+ query: Queries with shape of `[batch..., q_num, num_heads, qk_features]`.
172
+ key: Keys with shape of `[batch..., kv_num, num_heads, qk_features]`.
173
+ value: Values with shape of `[batch..., kv_num, num_heads, v_features]`.
174
+ train: Indicating whether we're training or evaluating.
175
+ **kwargs: Additional keyword arguments are required when used as attention
176
+ function in nn.MultiHeadDotProductAttention, but they will be ignored
177
+ here.
178
+
179
+ Returns:
180
+ Output of shape `[batch..., q_num, num_heads, v_features]`.
181
+ """
182
+
183
+ assert query.ndim == key.ndim == value.ndim, (
184
+ "Queries, keys, and values must have the same rank.")
185
+ assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
186
+ "Query, key, and value batch dimensions must match.")
187
+ assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
188
+ "Query, key, and value num_heads dimensions must match.")
189
+ assert key.shape[-3] == value.shape[-3], (
190
+ "Key and value cardinality dimensions must match.")
191
+ assert query.shape[-1] == key.shape[-1], (
192
+ "Query and key feature dimensions must match.")
193
+
194
+ if kwargs.get("bias") is not None:
195
+ raise NotImplementedError(
196
+ "Support for masked attention is not yet implemented.")
197
+
198
+ if "dropout_rate" in kwargs:
199
+ if kwargs["dropout_rate"] > 0.:
200
+ raise NotImplementedError("Support for dropout is not yet implemented.")
201
+
202
+ # Temperature normalization.
203
+ qk_features = query.shape[-1]
204
+ query = query / jnp.sqrt(qk_features).astype(self.dtype)
205
+
206
+ # attn.shape = (batch..., num_heads, q_num, kv_num)
207
+ attn = jnp.einsum("...qhd,...khd->...hqk", query, key,
208
+ precision=self.precision)
209
+
210
+ if self.inverted_attn:
211
+ attention_axis = -2 # Query axis.
212
+ else:
213
+ attention_axis = -1 # Key axis.
214
+
215
+ # Softmax normalization (by default over key axis).
216
+ attn = jax.nn.softmax(attn, axis=attention_axis).astype(self.dtype)
217
+
218
+ # Defines intermediate for logging.
219
+ if not train:
220
+ self.sow("intermediates", "attn", attn)
221
+
222
+ if self.renormalize_keys:
223
+ # Corresponds to value aggregation via weighted mean (as opposed to sum).
224
+ normalizer = jnp.sum(attn, axis=-1, keepdims=True) + self.epsilon
225
+ attn = attn / normalizer
226
+
227
+ if self.attn_weights_only:
228
+ return attn
229
+
230
+ # Aggregate values using a weighted sum with weights provided by `attn`.
231
+ output = jnp.einsum(
232
+ "...hqk,...khd->...qhd", attn, value, precision=self.precision)
233
+
234
+ if self.return_attn_weights:
235
+ return output, attn
236
+
237
+ return output
238
+
239
+
240
+ class Transformer(nn.Module):
241
+ """Transformer with multiple blocks."""
242
+
243
+ num_heads: int
244
+ qkv_size: int
245
+ mlp_size: int
246
+ num_layers: int
247
+ pre_norm: bool = False
248
+
249
+ @nn.compact
250
+ def __call__(self, queries, inputs = None,
251
+ padding_mask = None,
252
+ train = False):
253
+ x = queries
254
+ for lyr in range(self.num_layers):
255
+ x = TransformerBlock(
256
+ num_heads=self.num_heads, qkv_size=self.qkv_size,
257
+ mlp_size=self.mlp_size, pre_norm=self.pre_norm,
258
+ name=f"TransformerBlock{lyr}")( # pytype: disable=wrong-arg-types
259
+ x, inputs, padding_mask, train)
260
+ return x
261
+
262
+
263
+ class TransformerBlock(nn.Module):
264
+ """Transformer decoder block."""
265
+
266
+ num_heads: int
267
+ qkv_size: int
268
+ mlp_size: int
269
+ pre_norm: bool = False
270
+
271
+ @nn.compact
272
+ def __call__(self, queries, inputs = None,
273
+ padding_mask = None,
274
+ train = False):
275
+ del padding_mask # Unused.
276
+ assert queries.ndim == 3
277
+
278
+ attention_fn = GeneralizedDotProductAttention()
279
+
280
+ attn = functools.partial(
281
+ nn.MultiHeadDotProductAttention,
282
+ num_heads=self.num_heads,
283
+ qkv_features=self.qkv_size,
284
+ attention_fn=attention_fn)
285
+
286
+ mlp = misc.MLP(hidden_size=self.mlp_size) # type: ignore
287
+
288
+ if self.pre_norm:
289
+ # Self-attention on queries.
290
+ x = nn.LayerNorm()(queries)
291
+ x = attn()(inputs_q=x, inputs_kv=x, deterministic=not train)
292
+ x = x + queries
293
+
294
+ # Cross-attention on inputs.
295
+ if inputs is not None:
296
+ assert inputs.ndim == 3
297
+ y = nn.LayerNorm()(x)
298
+ y = attn()(inputs_q=y, inputs_kv=inputs, deterministic=not train)
299
+ y = y + x
300
+ else:
301
+ y = x
302
+
303
+ # MLP
304
+ z = nn.LayerNorm()(y)
305
+ z = mlp(z, train)
306
+ z = z + y
307
+ else:
308
+ # Self-attention on queries.
309
+ x = queries
310
+ x = attn()(inputs_q=x, inputs_kv=x, deterministic=not train)
311
+ x = x + queries
312
+ x = nn.LayerNorm()(x)
313
+
314
+ # Cross-attention on inputs.
315
+ if inputs is not None:
316
+ assert inputs.ndim == 3
317
+ y = attn()(inputs_q=x, inputs_kv=inputs, deterministic=not train)
318
+ y = y + x
319
+ y = nn.LayerNorm()(y)
320
+ else:
321
+ y = x
322
+
323
+ # MLP.
324
+ z = mlp(y, train)
325
+ z = z + y
326
+ z = nn.LayerNorm()(z)
327
+ return z
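A minimal sketch (not part of the diff) of initializing and applying the SlotAttention module above; batch size, slot count, and feature sizes are illustrative:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules.attention import SlotAttention

slot_attn = SlotAttention(num_iterations=3, qkv_size=64, mlp_size=128)
slots = jnp.zeros((2, 7, 64))          # (batch, num_slots, slot_size), illustrative.
inputs = jnp.zeros((2, 16 * 16, 64))   # (batch, num_inputs, feature_size), illustrative.
params = slot_attn.init(jax.random.PRNGKey(0), slots, inputs)
updated_slots = slot_attn.apply(params, slots, inputs)  # Same shape as slots.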
invariant_slot_attention/modules/convolution.py ADDED
@@ -0,0 +1,164 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Convolutional module library."""
17
+
18
+ import functools
19
+ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
20
+
21
+ from flax import linen as nn
22
+ import jax
23
+
24
+ Shape = Tuple[int]
25
+
26
+ DType = Any
27
+ Array = Any # jnp.ndarray
28
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
29
+ ProcessorState = ArrayTree
30
+ PRNGKey = Array
31
+ NestedDict = Dict[str, Any]
32
+
33
+
34
+ class SimpleCNN(nn.Module):
35
+ """Simple CNN encoder with multiple Conv+ReLU layers."""
36
+
37
+ features: Sequence[int]
38
+ kernel_size: Sequence[Tuple[int, int]]
39
+ strides: Sequence[Tuple[int, int]]
40
+ transpose: bool = False
41
+ use_batch_norm: bool = False
42
+ axis_name: Optional[str] = None # Over which axis to aggregate batch stats.
43
+ padding: Union[str, Iterable[Tuple[int, int]]] = "SAME"
44
+ resize_output: Optional[Iterable[int]] = None
45
+
46
+ @nn.compact
47
+ def __call__(self, inputs, train = False):
48
+ num_layers = len(self.features)
49
+ assert len(self.kernel_size) == num_layers, (
50
+ "len(kernel_size) and len(features) must match.")
51
+ assert len(self.strides) == num_layers, (
52
+ "len(strides) and len(features) must match.")
53
+ assert num_layers >= 1, "Need to have at least one layer."
54
+
55
+ if self.transpose:
56
+ conv_module = nn.ConvTranspose
57
+ else:
58
+ conv_module = nn.Conv
59
+
60
+ x = conv_module(
61
+ name="conv_simple_0",
62
+ features=self.features[0],
63
+ kernel_size=self.kernel_size[0],
64
+ strides=self.strides[0],
65
+ use_bias=False if self.use_batch_norm else True,
66
+ padding=self.padding)(inputs)
67
+
68
+ for i in range(1, num_layers):
69
+ if self.use_batch_norm:
70
+ x = nn.BatchNorm(
71
+ momentum=0.9, use_running_average=not train,
72
+ axis_name=self.axis_name, name=f"bn_simple_{i-1}")(x)
73
+
74
+ x = nn.relu(x)
75
+ x = conv_module(
76
+ name=f"conv_simple_{i}",
77
+ features=self.features[i],
78
+ kernel_size=self.kernel_size[i],
79
+ strides=self.strides[i],
80
+ use_bias=False if (
81
+ self.use_batch_norm and i < (num_layers-1)) else True,
82
+ padding=self.padding)(x)
83
+
84
+ if self.resize_output:
85
+ x = jax.image.resize(
86
+ x, list(x.shape[:-3]) + list(self.resize_output) + [x.shape[-1]],
87
+ method=jax.image.ResizeMethod.LINEAR)
88
+ return x
89
+
90
+
91
+ class CNN(nn.Module):
92
+ """Flexible CNN model with Conv/Normalization/Pooling layers."""
93
+
94
+ features: Sequence[int]
95
+ kernel_size: Sequence[Tuple[int, int]]
96
+ strides: Sequence[Tuple[int, int]]
97
+ max_pool_strides: Sequence[Tuple[int, int]]
98
+ layer_transpose: Sequence[bool]
99
+ activation_fn: Callable[[Array], Array] = nn.relu
100
+ norm_type: Optional[str] = None
101
+ axis_name: Optional[str] = None # Over which axis to aggregate batch stats.
102
+ output_size: Optional[int] = None
103
+
104
+ @nn.compact
105
+ def __call__(self, inputs, train = False):
106
+ num_layers = len(self.features)
107
+
108
+ assert num_layers >= 1, "Need to have at least one layer."
109
+ assert len(self.kernel_size) == num_layers, (
110
+ "len(kernel_size) and len(features) must match.")
111
+ assert len(self.strides) == num_layers, (
112
+ "len(strides) and len(features) must match.")
113
+ assert len(self.max_pool_strides) == num_layers, (
114
+ "len(max_pool_strides) and len(features) must match.")
115
+ assert len(self.layer_transpose) == num_layers, (
116
+ "len(layer_transpose) and len(features) must match.")
117
+
118
+ if self.norm_type:
119
+ assert self.norm_type in {"batch", "group", "instance", "layer"}, (
120
+ f"{self.norm_type} is unrecognizaed normalization")
121
+
122
+ # Whether transpose conv or regular conv.
123
+ conv_module = {False: nn.Conv, True: nn.ConvTranspose}
124
+
125
+ if self.norm_type == "batch":
126
+ norm_module = functools.partial(
127
+ nn.BatchNorm, momentum=0.9, use_running_average=not train,
128
+ axis_name=self.axis_name)
129
+ elif self.norm_type == "group":
130
+ norm_module = functools.partial(
131
+ nn.GroupNorm, num_groups=32)
132
+ elif self.norm_type == "layer":
133
+ norm_module = nn.LayerNorm
134
+
135
+ x = inputs
136
+ for i in range(num_layers):
137
+ x = conv_module[self.layer_transpose[i]](
138
+ name=f"conv_{i}",
139
+ features=self.features[i],
140
+ kernel_size=self.kernel_size[i],
141
+ strides=self.strides[i],
142
+ use_bias=False if self.norm_type else True)(x)
143
+
144
+ # Normalization layer.
145
+ if self.norm_type:
146
+ if self.norm_type == "instance":
147
+ x = nn.GroupNorm(
148
+ num_groups=self.features[i],
149
+ name=f"{self.norm_type}_norm_{i}")(x)
150
+ else:
151
+ x = norm_module(name=f"{self.norm_type}_norm_{i}")(x)  # Assign so the norm takes effect.
152
+
153
+ # Activation layer.
154
+ x = self.activation_fn(x)
155
+
156
+ # Max pooling layer.
157
+ x = x if self.max_pool_strides[i] == (1, 1) else nn.max_pool(
158
+ x, self.max_pool_strides[i], strides=self.max_pool_strides[i],
159
+ padding="SAME")
160
+
161
+ # Final dense layer.
162
+ if self.output_size:
163
+ x = nn.Dense(self.output_size, name="output_layer", use_bias=True)(x)
164
+ return x
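A minimal sketch (not part of the diff) of the SimpleCNN encoder above on a dummy image batch; the layer sizes are illustrative:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules.convolution import SimpleCNN

encoder = SimpleCNN(
    features=[32, 32, 32],
    kernel_size=[(5, 5), (5, 5), (5, 5)],
    strides=[(1, 1), (1, 1), (1, 1)])
images = jnp.zeros((4, 64, 64, 3))  # Dummy RGB batch.
params = encoder.init(jax.random.PRNGKey(0), images)
features = encoder.apply(params, images)  # (4, 64, 64, 32) with SAME padding.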
invariant_slot_attention/modules/decoders.py ADDED
@@ -0,0 +1,267 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Decoder module library."""
17
+ import functools
18
+ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
19
+
20
+ from flax import linen as nn
21
+
22
+ import jax.numpy as jnp
23
+
24
+ from invariant_slot_attention.lib import utils
25
+ from invariant_slot_attention.modules import misc
26
+
27
+ Shape = Tuple[int]
28
+
29
+ DType = Any
30
+ Array = Any # jnp.ndarray
31
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
32
+ ProcessorState = ArrayTree
33
+ PRNGKey = Array
34
+ NestedDict = Dict[str, Any]
35
+
36
+
37
+ class SpatialBroadcastDecoder(nn.Module):
38
+ """Spatial broadcast decoder for a set of slots (per frame)."""
39
+
40
+ resolution: Sequence[int]
41
+ backbone: Callable[[], nn.Module]
42
+ pos_emb: Callable[[], nn.Module]
43
+ early_fusion: bool = False # Fuse slot features before constructing targets.
44
+ target_readout: Optional[Callable[[], nn.Module]] = None
45
+
46
+ # Vmapped application of module, consumes time axis (axis=1).
47
+ @functools.partial(utils.time_distributed, in_axes=(1, None))
48
+ @nn.compact
49
+ def __call__(self, slots, train = False):
50
+
51
+ batch_size, n_slots, n_features = slots.shape
52
+
53
+ # Fold slot dim into batch dim.
54
+ x = jnp.reshape(slots, (batch_size * n_slots, n_features))
55
+
56
+ # Spatial broadcast with position embedding.
57
+ x = utils.spatial_broadcast(x, self.resolution)
58
+ x = self.pos_emb()(x)
59
+
60
+ # bb_features.shape = (batch_size * n_slots, h, w, c)
61
+ bb_features = self.backbone()(x, train=train)
62
+ spatial_dims = bb_features.shape[-3:-1]
63
+
64
+ alpha_logits = nn.Dense(
65
+ features=1, use_bias=True, name="alpha_logits")(bb_features)
66
+ alpha_logits = jnp.reshape(
67
+ alpha_logits, (batch_size, n_slots) + spatial_dims + (-1,))
68
+
69
+ alphas = nn.softmax(alpha_logits, axis=1)
70
+ if not train:
71
+ # Define intermediates for logging / visualization.
72
+ self.sow("intermediates", "alphas", alphas)
73
+
74
+ if self.early_fusion:
75
+ # To save memory, fuse the slot features before predicting targets.
76
+ # The final target output should be equivalent to the late fusion when
77
+ # using linear prediction.
78
+ bb_features = jnp.reshape(
79
+ bb_features, (batch_size, n_slots) + spatial_dims + (-1,))
80
+ # Combine backbone features by alpha masks.
81
+ bb_features = jnp.sum(bb_features * alphas, axis=1)
82
+
83
+ targets_dict = self.target_readout()(bb_features, train) # pylint: disable=not-callable
84
+
85
+ preds_dict = dict()
86
+ for target_key, channels in targets_dict.items():
87
+ if self.early_fusion:
88
+ # decoded_target.shape = (batch_size, h, w, c) after next line.
89
+ decoded_target = channels
90
+ else:
91
+ # channels.shape = (batch_size, n_slots, h, w, c)
92
+ channels = jnp.reshape(
93
+ channels, (batch_size, n_slots) + (spatial_dims) + (-1,))
94
+
95
+ # masked_channels.shape = (batch_size, n_slots, h, w, c)
96
+ masked_channels = channels * alphas
97
+
98
+ # decoded_target.shape = (batch_size, h, w, c)
99
+ decoded_target = jnp.sum(masked_channels, axis=1) # Combine target.
100
+ preds_dict[target_key] = decoded_target
101
+
102
+ if not train:
103
+ # Define intermediates for logging / visualization.
104
+ self.sow("intermediates", f"{target_key}_slots", channels)
105
+ if not self.early_fusion:
106
+ self.sow("intermediates", f"{target_key}_masked", masked_channels)
107
+ self.sow("intermediates", f"{target_key}_combined", decoded_target)
108
+
109
+ preds_dict["segmentations"] = jnp.argmax(alpha_logits, axis=1)
110
+
111
+ return preds_dict
112
+
113
+
114
+ class SiameseSpatialBroadcastDecoder(nn.Module):
115
+ """Siamese spatial broadcast decoder for a set of slots (per frame).
116
+
117
+ Similar to the decoders used in IODINE: https://arxiv.org/abs/1903.00450
118
+ and in Slot Attention: https://arxiv.org/abs/2006.15055.
119
+ """
120
+
121
+ resolution: Sequence[int]
122
+ backbone: Callable[[], nn.Module]
123
+ pos_emb: Callable[[], nn.Module]
124
+ pass_intermediates: bool = False
125
+ alpha_only: bool = False # Predict only alpha masks.
126
+ concat_attn: bool = False
127
+ # Readout after backbone.
128
+ target_readout_from_slots: bool = False
129
+ target_readout: Optional[Callable[[], nn.Module]] = None
130
+ early_fusion: bool = False # Fuse slot features before constructing targets.
131
+ # Readout on slots.
132
+ attribute_readout: Optional[Callable[[], nn.Module]] = None
133
+ remove_background_attribute: bool = False
134
+ attn_key: Optional[str] = None
135
+ attn_width: Optional[int] = None
136
+ # If True, expects slot embeddings to contain slot positions.
137
+ relative_positions: bool = False
138
+ # Slot positions and scales.
139
+ relative_positions_and_scales: bool = False
140
+ relative_positions_rotations_and_scales: bool = False
141
+
142
+ # Vmapped application of module, consumes time axis (axis=1).
143
+ @functools.partial(utils.time_distributed, in_axes=(1, None))
144
+ @nn.compact
145
+ def __call__(self,
146
+ slots,
147
+ train = False):
148
+
149
+ if self.remove_background_attribute and self.attribute_readout is None:
150
+ raise NotImplementedError(
151
+ "Background removal is only supported for attribute readout.")
152
+
153
+ if self.relative_positions:
154
+ # Assume slot positions were concatenated to slot embeddings.
155
+ # E.g. an output of SlotAttentionTranslEquiv.
156
+ slots, positions = slots[Ellipsis, :-2], slots[Ellipsis, -2:]
157
+ # Reshape positions to [B * num_slots, 2]
158
+ positions = positions.reshape(
159
+ (positions.shape[0] * positions.shape[1], positions.shape[2]))
160
+ elif self.relative_positions_and_scales:
161
+ # Assume slot positions and scales were concatenated to slot embeddings.
162
+ # E.g. an output of SlotAttentionTranslScaleEquiv.
163
+ slots, positions, scales = (slots[Ellipsis, :-4],
164
+ slots[Ellipsis, -4: -2],
165
+ slots[Ellipsis, -2:])
166
+ positions = positions.reshape(
167
+ (positions.shape[0] * positions.shape[1], positions.shape[2]))
168
+ scales = scales.reshape(
169
+ (scales.shape[0] * scales.shape[1], scales.shape[2]))
170
+ elif self.relative_positions_rotations_and_scales:
171
+ slots, positions, scales, rotm = (slots[Ellipsis, :-8],
172
+ slots[Ellipsis, -8: -6],
173
+ slots[Ellipsis, -6: -4],
174
+ slots[Ellipsis, -4:])
175
+ positions = positions.reshape(
176
+ (positions.shape[0] * positions.shape[1], positions.shape[2]))
177
+ scales = scales.reshape(
178
+ (scales.shape[0] * scales.shape[1], scales.shape[2]))
179
+ rotm = rotm.reshape(
180
+ rotm.shape[0] * rotm.shape[1], 2, 2)
181
+
182
+ batch_size, n_slots, n_features = slots.shape
183
+
184
+ preds_dict = {}
185
+ # Fold slot dim into batch dim.
186
+ x = jnp.reshape(slots, (batch_size * n_slots, n_features))
187
+
188
+ # Attribute readout.
189
+ if self.attribute_readout is not None:
190
+ if self.remove_background_attribute:
191
+ slots = slots[:, 1:]
192
+ attributes_dict = self.attribute_readout()(slots, train) # pylint: disable=not-callable
193
+ preds_dict.update(attributes_dict)
194
+
195
+ # Spatial broadcast with position embedding.
196
+ # See https://arxiv.org/abs/1901.07017.
197
+ x = utils.spatial_broadcast(x, self.resolution)
198
+
199
+ if self.relative_positions:
200
+ x = self.pos_emb()(inputs=x, slot_positions=positions)
201
+ elif self.relative_positions_and_scales:
202
+ x = self.pos_emb()(inputs=x, slot_positions=positions, slot_scales=scales)
203
+ elif self.relative_positions_rotations_and_scales:
204
+ x = self.pos_emb()(
205
+ inputs=x, slot_positions=positions, slot_scales=scales,
206
+ slot_rotm=rotm)
207
+ else:
208
+ x = self.pos_emb()(x)
209
+
210
+ # bb_features.shape = (batch_size*n_slots, h, w, c)
211
+ bb_features = self.backbone()(x, train=train)
212
+ spatial_dims = bb_features.shape[-3:-1]
213
+ alphas = nn.Dense(features=1, use_bias=True, name="alphas")(bb_features)
214
+ alphas = jnp.reshape(
215
+ alphas, (batch_size, n_slots) + spatial_dims + (-1,))
216
+ alphas_softmaxed = nn.softmax(alphas, axis=1)
217
+ preds_dict["segmentation_logits"] = alphas
218
+ preds_dict["segmentations"] = jnp.argmax(alphas, axis=1)
219
+ # Define intermediates for logging.
220
+ _ = misc.Identity(name="alphas_softmaxed")(alphas_softmaxed)
221
+ if self.alpha_only or self.target_readout is None:
222
+ assert alphas.shape[-1] == 1, "Alpha masks need to be one-dimensional."
223
+ return preds_dict, {"segmentation_logits": alphas}
224
+
225
+ if self.early_fusion:
226
+ # To save memory, fuse the slot features before predicting targets.
227
+ # The final target output should be equivalent to the late fusion when
228
+ # using linear prediction.
229
+ bb_features = jnp.reshape(
230
+ bb_features, (batch_size, n_slots) + spatial_dims + (-1,))
231
+ # Combine backbone features by alpha masks.
232
+ bb_features = jnp.sum(bb_features * alphas_softmaxed, axis=1)
233
+
234
+ if self.target_readout_from_slots:
235
+ targets_dict = self.target_readout()(slots, train) # pylint: disable=not-callable
236
+ else:
237
+ targets_dict = self.target_readout()(bb_features, train) # pylint: disable=not-callable
238
+
239
+ targets_dict_new = dict()
240
+ targets_dict_new["targets_masks"] = alphas_softmaxed
241
+ targets_dict_new["targets_logits_masks"] = alphas
242
+
243
+ for target_key, channels in targets_dict.items():
244
+ if self.early_fusion:
245
+ # decoded_target.shape = (batch_size, h, w, c) after next line.
246
+ decoded_target = channels
247
+ else:
248
+ # channels.shape = (batch_size, n_slots, h, w, c) after next line.
249
+ channels = jnp.reshape(
250
+ channels, (batch_size, n_slots) +
251
+ (spatial_dims if not self.target_readout_from_slots else
252
+ (1, 1)) + (-1,))
253
+ # masked_channels.shape = (batch_size, n_slots, h, w, c) at next line.
254
+ masked_channels = channels * alphas_softmaxed
255
+ # decoded_target.shape = (batch_size, h, w, c) after next line.
256
+ decoded_target = jnp.sum(masked_channels, axis=1) # Combine target.
257
+ targets_dict_new[target_key + "_channels"] = channels
258
+ # Define intermediates for logging.
259
+ _ = misc.Identity(name=f"{target_key}_channels")(channels)
260
+ _ = misc.Identity(name=f"{target_key}_masked_channels")(masked_channels)
261
+
262
+ targets_dict_new[target_key] = decoded_target
263
+ # Define intermediates for logging.
264
+ _ = misc.Identity(name=f"decoded_{target_key}")(decoded_target)
265
+
266
+ preds_dict.update(targets_dict_new)
267
+ return preds_dict
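A minimal sketch (not part of the diff) of the alpha compositing both decoders rely on, written out with plain jnp ops on dummy tensors:

import jax
import jax.numpy as jnp

# Illustrative shapes only.
batch_size, n_slots, h, w, c = 2, 7, 8, 8, 3
channels = jnp.zeros((batch_size, n_slots, h, w, c))      # Per-slot RGB targets.
alpha_logits = jnp.zeros((batch_size, n_slots, h, w, 1))  # Per-slot alpha logits.
alphas = jax.nn.softmax(alpha_logits, axis=1)             # Normalize over slots.
decoded = jnp.sum(channels * alphas, axis=1)              # (batch, h, w, c).
segmentations = jnp.argmax(alpha_logits, axis=1)          # (batch, h, w, 1).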
invariant_slot_attention/modules/initializers.py ADDED
@@ -0,0 +1,173 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Initializers module library."""
17
+
18
+ import functools
19
+ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
20
+
21
+ from flax import linen as nn
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+
26
+ from invariant_slot_attention.lib import utils
27
+ from invariant_slot_attention.modules import misc
28
+ from invariant_slot_attention.modules import video
29
+
30
+ Shape = Tuple[int]
31
+
32
+ DType = Any
33
+ Array = Any # jnp.ndarray
34
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
35
+ ProcessorState = ArrayTree
36
+ PRNGKey = Array
37
+ NestedDict = Dict[str, Any]
38
+
39
+
40
+ class ParamStateInit(nn.Module):
41
+ """Fixed, learnable state initialization.
42
+
43
+ Note: This module ignores any conditional input (by design).
44
+ """
45
+
46
+ shape: Sequence[int]
47
+ init_fn: str = "normal" # Default init with unit variance.
48
+
49
+ @nn.compact
50
+ def __call__(self, inputs, batch_size,
51
+ train = False):
52
+ del inputs, train # Unused.
53
+
54
+ if self.init_fn == "normal":
55
+ init_fn = functools.partial(nn.initializers.normal, stddev=1.)
56
+ elif self.init_fn == "zeros":
57
+ init_fn = lambda: nn.initializers.zeros
58
+ else:
59
+ raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
60
+
61
+ param = self.param("state_init", init_fn(), self.shape)
62
+ return utils.broadcast_across_batch(param, batch_size=batch_size)
63
+
64
+
65
+ class GaussianStateInit(nn.Module):
66
+ """Random state initialization with zero-mean, unit-variance Gaussian.
67
+
68
+ Note: This module does not contain any trainable parameters and requires
69
+ providing a jax.PRNGKey both at training and at test time. Note: This module
70
+ also ignores any conditional input (by design).
71
+ """
72
+
73
+ shape: Sequence[int]
74
+
75
+ @nn.compact
76
+ def __call__(self, inputs, batch_size,
77
+ train = False):
78
+ del inputs, train # Unused.
79
+ rng = self.make_rng("state_init")
80
+ return jax.random.normal(rng, shape=[batch_size] + list(self.shape))
81
+
82
+
83
+ class SegmentationEncoderStateInit(nn.Module):
84
+ """State init that encodes segmentation masks as conditional input."""
85
+
86
+ max_num_slots: int
87
+ backbone: Callable[[], nn.Module]
88
+ pos_emb: Callable[[], nn.Module] = misc.Identity
89
+ reduction: Optional[str] = "all_flatten" # Reduce spatial dim by default.
90
+ output_transform: Callable[[], nn.Module] = misc.Identity
91
+ zero_background: bool = False
92
+
93
+ @nn.compact
94
+ def __call__(self, inputs, batch_size,
95
+ train = False):
96
+ del batch_size # Unused.
97
+
98
+ # inputs.shape = (batch_size, seq_len, height, width)
99
+ inputs = inputs[:, 0] # Only condition on first time step.
100
+
101
+ # Convert mask index to one-hot.
102
+ inputs_oh = jax.nn.one_hot(inputs, self.max_num_slots)
103
+ # inputs_oh.shape = (batch_size, height, width, n_slots)
104
+ # NOTE: 0th entry inputs_oh[..., 0] will typically correspond to background.
105
+
106
+ # Set background slot to all-zeros.
107
+ if self.zero_background:
108
+ inputs_oh = inputs_oh.at[:, :, :, 0].set(0)
109
+
110
+ # Switch one-hot axis into 1st position (i.e. sequence axis).
111
+ inputs_oh = jnp.transpose(inputs_oh, (0, 3, 1, 2))
112
+ # inputs_oh.shape = (batch_size, max_num_slots, height, width)
113
+
114
+ # Append dummy feature axis.
115
+ inputs_oh = jnp.expand_dims(inputs_oh, axis=-1)
116
+
117
+ # Vmapped encoder over seq. axis (i.e. we process each slot independently).
118
+ encoder = video.FrameEncoder(
119
+ backbone=self.backbone,
120
+ pos_emb=self.pos_emb,
121
+ reduction=self.reduction,
122
+ output_transform=self.output_transform) # type: ignore
123
+
124
+ # encoder(inputs_oh).shape = (batch_size, n_slots, n_features)
125
+ slots = encoder(inputs_oh, None, train)
126
+
127
+ return slots
128
+
129
+
130
+ class CoordinateEncoderStateInit(nn.Module):
131
+ """State init that encodes bounding box coordinates as conditional input.
132
+
133
+ Attributes:
134
+ embedding_transform: A nn.Module that is applied on inputs (bounding boxes).
135
+ prepend_background: Boolean flag; whether to prepend a special, zero-valued
136
+ background bounding box to the input. Default: false.
137
+ center_of_mass: Boolean flag; whether to convert bounding boxes to center
138
+ of mass coordinates. Default: false.
139
+ background_value: Default value to fill in the background.
140
+ """
141
+
142
+ embedding_transform: Callable[[], nn.Module]
143
+ prepend_background: bool = False
144
+ center_of_mass: bool = False
145
+ background_value: float = 0.
146
+
147
+ @nn.compact
148
+ def __call__(self, inputs, batch_size,
149
+ train = False):
150
+ del batch_size # Unused.
151
+
152
+ # inputs.shape = (batch_size, seq_len, bboxes, 4)
153
+ inputs = inputs[:, 0] # Only condition on first time step.
154
+ # inputs.shape = (batch_size, bboxes, 4)
155
+
156
+ if self.prepend_background:
157
+ # Adds a fake background box [0, 0, 0, 0] at the beginning.
158
+ batch_size = inputs.shape[0]
159
+
160
+ # Encode the background as specified by background_value.
161
+ background = jnp.full(
162
+ (batch_size, 1, 4), self.background_value, dtype=inputs.dtype)
163
+
164
+ inputs = jnp.concatenate((background, inputs), axis=1)
165
+
166
+ if self.center_of_mass:
167
+ y_pos = (inputs[:, :, 0] + inputs[:, :, 2]) / 2
168
+ x_pos = (inputs[:, :, 1] + inputs[:, :, 3]) / 2
169
+ inputs = jnp.stack((y_pos, x_pos), axis=-1)
170
+
171
+ slots = self.embedding_transform()(inputs, train=train) # pytype: disable=not-callable
172
+
173
+ return slots
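A minimal sketch (not part of the diff) of ParamStateInit, which returns a learnable (num_slots, slot_size) tensor broadcast across the batch; sizes are illustrative:

import jax
from invariant_slot_attention.modules.initializers import ParamStateInit

state_init = ParamStateInit(shape=(7, 64))  # 7 slots, 64 features each.
params = state_init.init(jax.random.PRNGKey(0), None, batch_size=4)
slots = state_init.apply(params, None, batch_size=4)  # Shape (4, 7, 64).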
invariant_slot_attention/modules/invariant_attention.py ADDED
@@ -0,0 +1,963 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Equivariant attention module library."""
17
+ import functools
18
+ from typing import Any, Optional, Tuple
19
+
20
+ from flax import linen as nn
21
+ import jax
22
+ import jax.numpy as jnp
23
+ from invariant_slot_attention.modules import attention
24
+ from invariant_slot_attention.modules import misc
25
+
26
+ Shape = Tuple[int]
27
+
28
+ DType = Any
29
+ Array = Any # jnp.ndarray
30
+ PRNGKey = Array
31
+
32
+
33
+ class InvertedDotProductAttentionKeyPerQuery(nn.Module):
34
+ """Inverted dot-product attention with a different set of keys per query.
35
+
36
+ Used in SlotAttentionTranslEquiv, where each slot has a position.
37
+ The positions are used to create relative coordinate grids,
38
+ which result in a different set of inputs (keys) for each slot.
39
+ """
40
+
41
+ dtype: DType = jnp.float32
42
+ precision: Optional[jax.lax.Precision] = None
43
+ epsilon: float = 1e-8
44
+ renormalize_keys: bool = False
45
+ attn_weights_only: bool = False
46
+ softmax_temperature: float = 1.0
47
+ value_per_query: bool = False
48
+
49
+ @nn.compact
50
+ def __call__(self, query, key, value, train):
51
+ """Computes inverted dot-product attention with key per query.
52
+
53
+ Args:
54
+ query: Queries with shape of `[batch..., q_num, qk_features]`.
55
+ key: Keys with shape of `[batch..., q_num, kv_num, qk_features]`.
56
+ value: Values with shape of `[batch..., kv_num, v_features]`.
57
+ train: Indicating whether we're training or evaluating.
58
+
59
+ Returns:
60
+ Tuple of two elements: (1) output of shape
61
+ `[batch_size..., q_num, v_features]` and (2) attention mask of shape
62
+ `[batch_size..., q_num, kv_num]`.
63
+ """
64
+ qk_features = query.shape[-1]
65
+ query = query / jnp.sqrt(qk_features).astype(self.dtype)
66
+
67
+ # Each query is multiplied with its own set of keys.
68
+ attn = jnp.einsum(
69
+ "...qd,...qkd->...qk", query, key, precision=self.precision
70
+ )
71
+
72
+ # axis=-2 for a softmax over query axis (inverted attention).
73
+ attn = jax.nn.softmax(
74
+ attn / self.softmax_temperature, axis=-2
75
+ ).astype(self.dtype)
76
+
77
+ # We expand dims because the logger expects a #heads dimension.
78
+ self.sow("intermediates", "attn", jnp.expand_dims(attn, -3))
79
+
80
+ if self.renormalize_keys:
81
+ normalizer = jnp.sum(attn, axis=-1, keepdims=True) + self.epsilon
82
+ attn = attn / normalizer
83
+
84
+ if self.attn_weights_only:
85
+ return attn
86
+
87
+ output = jnp.einsum(
88
+ "...qk,...qkd->...qd" if self.value_per_query else "...qk,...kd->...qd",
89
+ attn,
90
+ value,
91
+ precision=self.precision
92
+ )
93
+
94
+ return output, attn
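A small, self-contained shape check (not part of the committed code) for the key-per-query inverted attention above: each slot gets its own key set, the softmax runs over the slot axis, and renormalize_keys rescales each slot's weights to sum to one. The toy shapes are assumptions.

import jax
import jax.numpy as jnp

q = jnp.ones((2, 3, 4))      # [batch, q_num, qk_features]
k = jnp.ones((2, 3, 5, 4))   # [batch, q_num, kv_num, qk_features]: keys differ per slot
v = jnp.ones((2, 5, 4))      # [batch, kv_num, v_features]

attn = jnp.einsum("...qd,...qkd->...qk", q / jnp.sqrt(4.0), k)
attn = jax.nn.softmax(attn, axis=-2)                       # softmax over slots (inverted)
attn = attn / (attn.sum(axis=-1, keepdims=True) + 1e-8)    # renormalize_keys=True
out = jnp.einsum("...qk,...kd->...qd", attn, v)            # value_per_query=False path
print(out.shape, attn.shape)                               # (2, 3, 4) (2, 3, 5)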
95
+
96
+
97
+ class SlotAttentionExplicitStats(nn.Module):
98
+ """Slot Attention module with explicit slot statistics.
99
+
100
+ Slot statistics, such as position and scale, are appended to the
101
+ output slot representations.
102
+
103
+ Note: This module expects a 2D coordinate grid to be appended
104
+ at the end of inputs.
105
+
106
+ Note: This module uses pre-normalization by default.
107
+ """
108
+ grid_encoder: nn.Module
109
+ num_iterations: int = 1
110
+ qkv_size: Optional[int] = None
111
+ mlp_size: Optional[int] = None
112
+ epsilon: float = 1e-8
113
+ softmax_temperature: float = 1.0
114
+ gumbel_softmax: bool = False
115
+ gumbel_softmax_straight_through: bool = False
116
+ num_heads: int = 1
117
+ min_scale: float = 0.01
118
+ max_scale: float = 5.
119
+ return_slot_positions: bool = True
120
+ return_slot_scales: bool = True
121
+
122
+ @nn.compact
123
+ def __call__(self, slots, inputs,
124
+ padding_mask = None,
125
+ train = False):
126
+ """Slot Attention with explicit slot statistics module forward pass."""
127
+ del padding_mask # Unused.
128
+ # Slot scales require slot positions.
129
+ assert self.return_slot_positions or not self.return_slot_scales
130
+
131
+ # Separate a concatenated linear coordinate grid from the inputs.
132
+ inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]
133
+
134
+ # Hack so that the input and output slot dimensions are the same.
135
+ to_remove = 0
136
+ if self.return_slot_positions:
137
+ to_remove += 2
138
+ if self.return_slot_scales:
139
+ to_remove += 2
140
+ if to_remove > 0:
141
+ slots = slots[Ellipsis, :-to_remove]
142
+
143
+ # Add position encodings to inputs
144
+ n_features = inputs.shape[-1]
145
+ grid_projector = nn.Dense(n_features, name="dense_pe_0")
146
+ inputs = self.grid_encoder()(inputs + grid_projector(grid))
147
+
148
+ qkv_size = self.qkv_size or slots.shape[-1]
149
+ head_dim = qkv_size // self.num_heads
150
+ dense = functools.partial(nn.DenseGeneral,
151
+ axis=-1, features=(self.num_heads, head_dim),
152
+ use_bias=False)
153
+
154
+ # Shared modules.
155
+ dense_q = dense(name="general_dense_q_0")
156
+ layernorm_q = nn.LayerNorm()
157
+ inverted_attention = attention.InvertedDotProductAttention(
158
+ norm_type="mean",
159
+ multi_head=self.num_heads > 1,
160
+ return_attn_weights=True)
161
+ gru = misc.GRU()
162
+
163
+ if self.mlp_size is not None:
164
+ mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore
165
+
166
+ # inputs.shape = (..., n_inputs, inputs_size).
167
+ inputs = nn.LayerNorm()(inputs)
168
+ # k.shape = (..., n_inputs, slot_size).
169
+ k = dense(name="general_dense_k_0")(inputs)
170
+ # v.shape = (..., n_inputs, slot_size).
171
+ v = dense(name="general_dense_v_0")(inputs)
172
+
173
+ # Multiple rounds of attention.
174
+ for _ in range(self.num_iterations):
175
+
176
+ # Inverted dot-product attention.
177
+ slots_n = layernorm_q(slots)
178
+ q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size).
179
+ updates, attn = inverted_attention(query=q, key=k, value=v, train=train)
180
+
181
+ # Recurrent update.
182
+ slots = gru(slots, updates)
183
+
184
+ # Feedforward block with pre-normalization.
185
+ if self.mlp_size is not None:
186
+ slots = mlp(slots)
187
+
188
+ if self.return_slot_positions:
189
+ # Compute the center of mass of each slot attention mask.
190
+ positions = jnp.einsum("...qk,...kd->...qd", attn, grid)
191
+ slots = jnp.concatenate([slots, positions], axis=-1)
192
+
193
+ if self.return_slot_scales:
194
+ # Compute slot scales. Take the square root to make the operation
195
+ # analogous to normalizing data drawn from a Gaussian.
196
+ spread = jnp.square(
197
+ jnp.expand_dims(grid, axis=-3) - jnp.expand_dims(positions, axis=-2))
198
+ scales = jnp.sqrt(
199
+ jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread))
200
+ scales = jnp.clip(scales, self.min_scale, self.max_scale)
201
+ slots = jnp.concatenate([slots, scales], axis=-1)
202
+
203
+ return slots
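For intuition, a hedged sketch (not in the diff) of how slot positions and scales are read off an attention mask, mirroring the two einsums above; the toy grid holds (y, x) coordinates in [-1, 1] and all numbers are assumptions.

import jax.numpy as jnp

attn = jnp.full((1, 2, 4), 0.25)                                  # [batch, n_slots, n_inputs]
grid = jnp.array([[[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]]])  # [batch, n_inputs, 2]

positions = jnp.einsum("...qk,...kd->...qd", attn, grid)          # per-slot centers of mass
spread = jnp.square(grid[:, None] - positions[:, :, None])        # [batch, slots, inputs, 2]
scales = jnp.sqrt(jnp.einsum("...qk,...qkd->...qd", attn + 1e-8, spread))
scales = jnp.clip(scales, 0.01, 5.0)                              # min_scale, max_scale defaults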
204
+
205
+
206
+ class SlotAttentionPosKeysValues(nn.Module):
207
+ """Slot Attention module with positional encodings in keys and values.
208
+
209
+ Feature position encodings are added to keys and values instead
210
+ of the inputs.
211
+
212
+ Note: This module expects a 2D coordinate grid to be appended
213
+ at the end of inputs.
214
+
215
+ Note: This module uses pre-normalization by default.
216
+ """
217
+ grid_encoder: nn.Module
218
+ num_iterations: int = 1
219
+ qkv_size: Optional[int] = None
220
+ mlp_size: Optional[int] = None
221
+ epsilon: float = 1e-8
222
+ softmax_temperature: float = 1.0
223
+ gumbel_softmax: bool = False
224
+ gumbel_softmax_straight_through: bool = False
225
+ num_heads: int = 1
226
+
227
+ @nn.compact
228
+ def __call__(self, slots, inputs,
229
+ padding_mask = None,
230
+ train = False):
231
+ """Slot Attention with position encodings in keys and values forward pass."""
232
+ del padding_mask # Unused.
233
+
234
+ # Separate a concatenated linear coordinate grid from the inputs.
235
+ inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]
236
+
237
+ qkv_size = self.qkv_size or slots.shape[-1]
238
+ head_dim = qkv_size // self.num_heads
239
+ dense = functools.partial(nn.DenseGeneral,
240
+ axis=-1, features=(self.num_heads, head_dim),
241
+ use_bias=False)
242
+
243
+ # Shared modules.
244
+ dense_q = dense(name="general_dense_q_0")
245
+ layernorm_q = nn.LayerNorm()
246
+ inverted_attention = attention.InvertedDotProductAttention(
247
+ norm_type="mean",
248
+ multi_head=self.num_heads > 1)
249
+ gru = misc.GRU()
250
+
251
+ if self.mlp_size is not None:
252
+ mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore
253
+
254
+ # inputs.shape = (..., n_inputs, inputs_size).
255
+ inputs = nn.LayerNorm()(inputs)
256
+ # k.shape = (..., n_inputs, slot_size).
257
+ k = dense(name="general_dense_k_0")(inputs)
258
+ # v.shape = (..., n_inputs, slot_size).
259
+ v = dense(name="general_dense_v_0")(inputs)
260
+
261
+ # Add position encodings to keys and values.
262
+ grid_projector = dense(name="general_dense_p_0")
263
+ grid_encoder = self.grid_encoder()
264
+ k = grid_encoder(k + grid_projector(grid))
265
+ v = grid_encoder(v + grid_projector(grid))
266
+
267
+ # Multiple rounds of attention.
268
+ for _ in range(self.num_iterations):
269
+
270
+ # Inverted dot-product attention.
271
+ slots_n = layernorm_q(slots)
272
+ q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size).
273
+ updates = inverted_attention(query=q, key=k, value=v, train=train)
274
+
275
+ # Recurrent update.
276
+ slots = gru(slots, updates)
277
+
278
+ # Feedforward block with pre-normalization.
279
+ if self.mlp_size is not None:
280
+ slots = mlp(slots)
281
+
282
+ return slots
283
+
284
+
285
+ class SlotAttentionTranslEquiv(nn.Module):
286
+ """Slot Attention module with slot positions.
287
+
288
+ A position is computed for each slot. Slot positions are used to create
289
+ relative coordinate grids, which are used as position embeddings reapplied
290
+ in each iteration of slot attention. The last two channels in inputs
291
+ must contain the flattened position grid.
292
+
293
+ Note: This module uses pre-normalization by default.
294
+ """
295
+
296
+ grid_encoder: nn.Module
297
+ num_iterations: int = 1
298
+ qkv_size: Optional[int] = None
299
+ mlp_size: Optional[int] = None
300
+ epsilon: float = 1e-8
301
+ softmax_temperature: float = 1.0
302
+ gumbel_softmax: bool = False
303
+ gumbel_softmax_straight_through: bool = False
304
+ num_heads: int = 1
305
+ zero_position_init: bool = True
306
+ ablate_non_equivariant: bool = False
307
+ stop_grad_positions: bool = False
308
+ mix_slots: bool = False
309
+ add_rel_pos_to_values: bool = False
310
+ append_statistics: bool = False
311
+
312
+ @nn.compact
313
+ def __call__(self, slots, inputs,
314
+ padding_mask = None,
315
+ train = False):
316
+ """Slot Attention translation equiv. module forward pass."""
317
+ del padding_mask # Unused.
318
+
319
+ if self.num_heads > 1:
320
+ raise NotImplementedError("This prototype only uses one attn. head.")
321
+
322
+ # Separate a concatenated linear coordinate grid from the inputs.
323
+ inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]
324
+
325
+ # Separate position (x,y) from slot embeddings.
326
+ slots, positions = slots[Ellipsis, :-2], slots[Ellipsis, -2:]
327
+ qkv_size = self.qkv_size or slots.shape[-1]
328
+ num_slots = slots.shape[-2]
329
+
330
+ # Prepare initial slot positions.
331
+ if self.zero_position_init:
332
+ # All slots start in the middle of the image.
333
+ positions *= 0.
334
+
335
+ # Learnable initial positions might deviate from the allowed range.
336
+ positions = jnp.clip(positions, -1., 1.)
337
+
338
+ # Pre-normalization.
339
+ inputs = nn.LayerNorm()(inputs)
340
+
341
+ grid_per_slot = jnp.repeat(
342
+ jnp.expand_dims(grid, axis=-3), num_slots, axis=-3)
343
+
344
+ # Shared modules.
345
+ dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0")
346
+ dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0")
347
+ dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0")
348
+ grid_proj = nn.Dense(qkv_size, name="dense_gp_0")
349
+ grid_enc = self.grid_encoder()
350
+ layernorm_q = nn.LayerNorm()
351
+ inverted_attention = InvertedDotProductAttentionKeyPerQuery(
352
+ epsilon=self.epsilon,
353
+ renormalize_keys=True,
354
+ softmax_temperature=self.softmax_temperature,
355
+ value_per_query=self.add_rel_pos_to_values
356
+ )
357
+ gru = misc.GRU()
358
+
359
+ if self.mlp_size is not None:
360
+ mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore
361
+
362
+ if self.append_statistics:
363
+ embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0")
364
+
365
+ # k.shape and v.shape = (..., n_inputs, slot_size).
366
+ v = dense_v(inputs)
367
+ k = dense_k(inputs)
368
+ k_expand = jnp.expand_dims(k, axis=-3)
369
+ v_expand = jnp.expand_dims(v, axis=-3)
370
+
371
+ # Multiple rounds of attention. Last iteration updates positions only.
372
+ for attn_round in range(self.num_iterations + 1):
373
+
374
+ if self.ablate_non_equivariant:
375
+ # Add an encoded coordinate grid with absolute positions.
376
+ grid_emb_per_slot = grid_proj(grid_per_slot)
377
+ k_rel_pos = grid_enc(k_expand + grid_emb_per_slot)
378
+ if self.add_rel_pos_to_values:
379
+ v_rel_pos = grid_enc(v_expand + grid_emb_per_slot)
380
+ else:
381
+ # Relativize positions, encode them and add them to the keys
382
+ # and optionally to values.
383
+ relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2)
384
+ grid_emb_per_slot = grid_proj(relative_grid)
385
+ k_rel_pos = grid_enc(k_expand + grid_emb_per_slot)
386
+ if self.add_rel_pos_to_values:
387
+ v_rel_pos = grid_enc(v_expand + grid_emb_per_slot)
388
+
389
+ # Inverted dot-product attention.
390
+ slots_n = layernorm_q(slots)
391
+ q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size).
392
+ updates, attn = inverted_attention(
393
+ query=q,
394
+ key=k_rel_pos,
395
+ value=v_rel_pos if self.add_rel_pos_to_values else v,
396
+ train=train)
397
+
398
+ # Compute the center of mass of each slot attention mask.
399
+ # Guaranteed to be in [-1, 1].
400
+ positions = jnp.einsum("...qk,...kd->...qd", attn, grid)
401
+
402
+ if self.stop_grad_positions:
403
+ # Do not backprop through positions and scales.
404
+ positions = jax.lax.stop_gradient(positions)
405
+
406
+ if attn_round < self.num_iterations:
407
+ if self.append_statistics:
408
+ # Project and add 2D slot positions into slot latents.
409
+ tmp = jnp.concatenate([slots, positions], axis=-1)
410
+ slots = embed_statistics(tmp)
411
+
412
+ # Recurrent update.
413
+ slots = gru(slots, updates)
414
+
415
+ # Feedforward block with pre-normalization.
416
+ if self.mlp_size is not None:
417
+ slots = mlp(slots)
418
+
419
+ # Concatenate position information to slots.
420
+ output = jnp.concatenate([slots, positions], axis=-1)
421
+
422
+ if self.mix_slots:
423
+ output = misc.MLP(hidden_size=128, layernorm="pre")(output)
424
+
425
+ return output
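The key step that makes the module above translation equivariant is the per-slot relative grid. A minimal sketch (not part of the committed code; the grid resolution and slot positions are assumptions):

import jax.numpy as jnp

grid = jnp.stack(jnp.meshgrid(jnp.linspace(-1., 1., 4),
                              jnp.linspace(-1., 1., 4),
                              indexing="ij"), axis=-1).reshape(1, 16, 2)  # [batch, n_inputs, 2]
positions = jnp.array([[[0., 0.], [0.5, -0.5]]])                          # [batch, n_slots, 2]

grid_per_slot = jnp.repeat(grid[:, None], positions.shape[-2], axis=1)    # [batch, slots, n, 2]
relative_grid = grid_per_slot - positions[:, :, None]   # each slot sees its position as the origin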
426
+
427
+
428
+ class SlotAttentionTranslScaleEquiv(nn.Module):
429
+ """Slot Attention module with slot positions and scales.
430
+
431
+ A position and scale are computed for each slot. Slot positions and scales
432
+ are used to create relative coordinate grids, which are used as position
433
+ embeddings reapplied in each iteration of slot attention. The last two
434
+ channels in input must contain the flattened position grid.
435
+
436
+ Note: This module uses pre-normalization by default.
437
+ """
438
+
439
+ grid_encoder: nn.Module
440
+ num_iterations: int = 1
441
+ qkv_size: Optional[int] = None
442
+ mlp_size: Optional[int] = None
443
+ epsilon: float = 1e-8
444
+ softmax_temperature: float = 1.0
445
+ gumbel_softmax: bool = False
446
+ gumbel_softmax_straight_through: bool = False
447
+ num_heads: int = 1
448
+ zero_position_init: bool = True
449
+ # Scale of 0.1 corresponds to fairly small objects.
450
+ init_with_fixed_scale: Optional[float] = 0.1
451
+ ablate_non_equivariant: bool = False
452
+ stop_grad_positions_and_scales: bool = False
453
+ mix_slots: bool = False
454
+ add_rel_pos_to_values: bool = False
455
+ scales_factor: float = 1.
456
+ # Slot scales cannot be negative and should not be too close to zero
457
+ # or too large.
458
+ min_scale: float = 0.001
459
+ max_scale: float = 2.
460
+ append_statistics: bool = False
461
+
462
+ @nn.compact
463
+ def __call__(self, slots, inputs,
464
+ padding_mask = None,
465
+ train = False):
466
+ """Slot Attention translation and scale equiv. module forward pass."""
467
+ del padding_mask # Unused.
468
+
469
+ if self.num_heads > 1:
470
+ raise NotImplementedError("This prototype only uses one attn. head.")
471
+
472
+ # Separate a concatenated linear coordinate grid from the inputs.
473
+ inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]
474
+
475
+ # Separate position (x,y) and scale from slot embeddings.
476
+ slots, positions, scales = (slots[Ellipsis, :-4],
477
+ slots[Ellipsis, -4: -2],
478
+ slots[Ellipsis, -2:])
479
+ qkv_size = self.qkv_size or slots.shape[-1]
480
+ num_slots = slots.shape[-2]
481
+
482
+ # Prepare initial slot positions.
483
+ if self.zero_position_init:
484
+ # All slots start in the middle of the image.
485
+ positions *= 0.
486
+
487
+ if self.init_with_fixed_scale is not None:
488
+ scales = scales * 0. + self.init_with_fixed_scale
489
+
490
+ # Learnable initial positions and scales could have arbitrary values.
491
+ positions = jnp.clip(positions, -1., 1.)
492
+ scales = jnp.clip(scales, self.min_scale, self.max_scale)
493
+
494
+ # Pre-normalization.
495
+ inputs = nn.LayerNorm()(inputs)
496
+
497
+ grid_per_slot = jnp.repeat(
498
+ jnp.expand_dims(grid, axis=-3), num_slots, axis=-3)
499
+
500
+ # Shared modules.
501
+ dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0")
502
+ dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0")
503
+ dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0")
504
+ grid_proj = nn.Dense(qkv_size, name="dense_gp_0")
505
+ grid_enc = self.grid_encoder()
506
+ layernorm_q = nn.LayerNorm()
507
+ inverted_attention = InvertedDotProductAttentionKeyPerQuery(
508
+ epsilon=self.epsilon,
509
+ renormalize_keys=True,
510
+ softmax_temperature=self.softmax_temperature,
511
+ value_per_query=self.add_rel_pos_to_values
512
+ )
513
+ gru = misc.GRU()
514
+
515
+ if self.mlp_size is not None:
516
+ mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore
517
+
518
+ if self.append_statistics:
519
+ embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0")
520
+
521
+ # k.shape and v.shape = (..., n_inputs, slot_size).
522
+ v = dense_v(inputs)
523
+ k = dense_k(inputs)
524
+ k_expand = jnp.expand_dims(k, axis=-3)
525
+ v_expand = jnp.expand_dims(v, axis=-3)
526
+
527
+ # Multiple rounds of attention.
528
+ # Last iteration updates positions and scales only.
529
+ for attn_round in range(self.num_iterations + 1):
530
+
531
+ if self.ablate_non_equivariant:
532
+ # Add an encoded coordinate grid with absolute positions.
533
+ tmp_grid = grid_proj(grid_per_slot)
534
+ k_rel_pos = grid_enc(k_expand + tmp_grid)
535
+ if self.add_rel_pos_to_values:
536
+ v_rel_pos = grid_enc(v_expand + tmp_grid)
537
+ else:
538
+ # Relativize and scale positions, encode them and add them to the keys
+ # (and optionally to values).
539
+ relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2)
540
+ # Scales are usually small so the grid might get too large.
541
+ relative_grid = relative_grid / self.scales_factor
542
+ relative_grid = relative_grid / jnp.expand_dims(scales, axis=-2)
543
+ tmp_grid = grid_proj(relative_grid)
544
+ k_rel_pos = grid_enc(k_expand + tmp_grid)
545
+ if self.add_rel_pos_to_values:
546
+ v_rel_pos = grid_enc(v_expand + tmp_grid)
547
+
548
+ # Inverted dot-product attention.
549
+ slots_n = layernorm_q(slots)
550
+ q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size).
551
+ updates, attn = inverted_attention(
552
+ query=q,
553
+ key=k_rel_pos,
554
+ value=v_rel_pos if self.add_rel_pos_to_values else v,
555
+ train=train)
556
+
557
+ # Compute the center of mass of each slot attention mask.
558
+ positions = jnp.einsum("...qk,...kd->...qd", attn, grid)
559
+
560
+ # Compute slot scales. Take the square root to make the operation
561
+ # analogous to normalizing data drawn from a Gaussian.
562
+ spread = jnp.square(grid_per_slot - jnp.expand_dims(positions, axis=-2))
563
+ scales = jnp.sqrt(
564
+ jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread))
565
+
566
+ # Computed positions are guaranteed to be in [-1, 1].
567
+ # Scales are unbounded.
568
+ scales = jnp.clip(scales, self.min_scale, self.max_scale)
569
+
570
+ if self.stop_grad_positions_and_scales:
571
+ # Do not backprop through positions and scales.
572
+ positions = jax.lax.stop_gradient(positions)
573
+ scales = jax.lax.stop_gradient(scales)
574
+
575
+ if attn_round < self.num_iterations:
576
+ if self.append_statistics:
577
+ # Project and add 2D slot positions and scales into slot latents.
578
+ tmp = jnp.concatenate([slots, positions, scales], axis=-1)
579
+ slots = embed_statistics(tmp)
580
+
581
+ # Recurrent update.
582
+ slots = gru(slots, updates)
583
+
584
+ # Feedforward block with pre-normalization.
585
+ if self.mlp_size is not None:
586
+ slots = mlp(slots)
587
+
588
+ # Concatenate position and scale information to slots.
589
+ output = jnp.concatenate([slots, positions, scales], axis=-1)
590
+
591
+ if self.mix_slots:
592
+ output = misc.MLP(hidden_size=128, layernorm="pre")(output)
593
+
594
+ return output
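Compared to the translation-only module, the extra step above divides the relative grid by each slot's scale (and scales_factor). A hedged sketch with assumed toy values:

import jax.numpy as jnp

relative_grid = jnp.array([[[[0.2, -0.1], [0.4, 0.3]]]])   # [batch=1, slots=1, n_inputs=2, 2]
scales = jnp.array([[[0.1, 0.1]]])                         # [batch=1, slots=1, 2]
scales_factor = 1.0

normalized = relative_grid / scales_factor
normalized = normalized / scales[:, :, None]               # small objects get a zoomed-in grid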
595
+
596
+
597
+ class SlotAttentionTranslRotScaleEquiv(nn.Module):
598
+ """Slot Attention module with slot positions, rotations and scales.
599
+
600
+ A position, rotation and scale are computed for each slot.
601
+ Slot positions, rotations and scales are used to create relative
602
+ coordinate grids, which are used as position embeddings reapplied in each
603
+ iteration of slot attention. The last two channels in input must contain
604
+ the flattened position grid.
605
+
606
+ Note: This module uses pre-normalization by default.
607
+ """
608
+
609
+ grid_encoder: nn.Module
610
+ num_iterations: int = 1
611
+ qkv_size: Optional[int] = None
612
+ mlp_size: Optional[int] = None
613
+ epsilon: float = 1e-8
614
+ softmax_temperature: float = 1.0
615
+ gumbel_softmax: bool = False
616
+ gumbel_softmax_straight_through: bool = False
617
+ num_heads: int = 1
618
+ zero_position_init: bool = True
619
+ # Scale of 0.1 corresponds to fairly small objects.
620
+ init_with_fixed_scale: Optional[float] = 0.1
621
+ ablate_non_equivariant: bool = False
622
+ stop_grad_positions: bool = False
623
+ stop_grad_scales: bool = False
624
+ stop_grad_rotations: bool = False
625
+ mix_slots: bool = False
626
+ add_rel_pos_to_values: bool = False
627
+ scales_factor: float = 1.
628
+ # Slot scales cannot be negative and should not be too close to zero
629
+ # or too large.
630
+ min_scale: float = 0.001
631
+ max_scale: float = 2.
632
+ limit_rot_to_45_deg: bool = True
633
+ append_statistics: bool = False
634
+
635
+ @nn.compact
636
+ def __call__(self, slots, inputs,
637
+ padding_mask = None,
638
+ train = False):
639
+ """Slot Attention translation, rotation and scale equiv. module forward pass."""
640
+ del padding_mask # Unused.
641
+
642
+ if self.num_heads > 1:
643
+ raise NotImplementedError("This prototype only uses one attn. head.")
644
+
645
+ # Separate a concatenated linear coordinate grid from the inputs.
646
+ inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]
647
+
648
+ # Separate position (x,y), scale and rotation from slot embeddings.
649
+ slots, positions, scales, rotm = (slots[Ellipsis, :-8],
650
+ slots[Ellipsis, -8: -6],
651
+ slots[Ellipsis, -6: -4],
652
+ slots[Ellipsis, -4:])
653
+ rotm = jnp.reshape(rotm, (*rotm.shape[:-1], 2, 2))
654
+ qkv_size = self.qkv_size or slots.shape[-1]
655
+ num_slots = slots.shape[-2]
656
+
657
+ # Prepare initial slot positions.
658
+ if self.zero_position_init:
659
+ # All slots start in the middle of the image.
660
+ positions *= 0.
661
+
662
+ if self.init_with_fixed_scale is not None:
663
+ scales = scales * 0. + self.init_with_fixed_scale
664
+
665
+ # Learnable initial positions and scales could have arbitrary values.
666
+ positions = jnp.clip(positions, -1., 1.)
667
+ scales = jnp.clip(scales, self.min_scale, self.max_scale)
668
+
669
+ # Pre-normalization.
670
+ inputs = nn.LayerNorm()(inputs)
671
+
672
+ grid_per_slot = jnp.repeat(
673
+ jnp.expand_dims(grid, axis=-3), num_slots, axis=-3)
674
+
675
+ # Shared modules.
676
+ dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0")
677
+ dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0")
678
+ dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0")
679
+ grid_proj = nn.Dense(qkv_size, name="dense_gp_0")
680
+ grid_enc = self.grid_encoder()
681
+ layernorm_q = nn.LayerNorm()
682
+ inverted_attention = InvertedDotProductAttentionKeyPerQuery(
683
+ epsilon=self.epsilon,
684
+ renormalize_keys=True,
685
+ softmax_temperature=self.softmax_temperature,
686
+ value_per_query=self.add_rel_pos_to_values
687
+ )
688
+ gru = misc.GRU()
689
+
690
+ if self.mlp_size is not None:
691
+ mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore
692
+
693
+ if self.append_statistics:
694
+ embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0")
695
+
696
+ # k.shape and v.shape = (..., n_inputs, slot_size).
697
+ v = dense_v(inputs)
698
+ k = dense_k(inputs)
699
+ k_expand = jnp.expand_dims(k, axis=-3)
700
+ v_expand = jnp.expand_dims(v, axis=-3)
701
+
702
+ # Multiple rounds of attention.
703
+ # Last iteration updates positions and scales only.
704
+ for attn_round in range(self.num_iterations + 1):
705
+
706
+ if self.ablate_non_equivariant:
707
+ # Add an encoded coordinate grid with absolute positions.
708
+ tmp_grid = grid_proj(grid_per_slot)
709
+ k_rel_pos = grid_enc(k_expand + tmp_grid)
710
+ if self.add_rel_pos_to_values:
711
+ v_rel_pos = grid_enc(v_expand + tmp_grid)
712
+ else:
713
+ # Relativize and scale positions, encode them and add them to the keys
+ # (and optionally to values).
714
+ relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2)
715
+
716
+ # Rotation.
717
+ relative_grid = self.transform(rotm, relative_grid)
718
+
719
+ # Scales are usually small so the grid might get too large.
720
+ relative_grid = relative_grid / self.scales_factor
721
+ relative_grid = relative_grid / jnp.expand_dims(scales, axis=-2)
722
+ tmp_grid = grid_proj(relative_grid)
723
+ k_rel_pos = grid_enc(k_expand + tmp_grid)
724
+ if self.add_rel_pos_to_values:
725
+ v_rel_pos = grid_enc(v_expand + tmp_grid)
726
+
727
+ # Inverted dot-product attention.
728
+ slots_n = layernorm_q(slots)
729
+ q = dense_q(slots_n) # q.shape = (..., n_slots, slot_size).
730
+ updates, attn = inverted_attention(
731
+ query=q,
732
+ key=k_rel_pos,
733
+ value=v_rel_pos if self.add_rel_pos_to_values else v,
734
+ train=train)
735
+
736
+ # Compute the center of mass of each slot attention mask.
737
+ positions = jnp.einsum("...qk,...kd->...qd", attn, grid)
738
+
739
+ # Find the axis with the highest spread.
740
+ relp = grid_per_slot - jnp.expand_dims(positions, axis=-2)
741
+ if self.limit_rot_to_45_deg:
742
+ rotm = self.compute_rotation_matrix_45_deg(relp, attn)
743
+ else:
744
+ rotm = self.compute_rotation_matrix_90_deg(relp, attn)
745
+
746
+ # Compute slot scales. Take the square root to make the operation
747
+ # analogous to normalizing data drawn from a Gaussian.
748
+ relp = self.transform(rotm, relp)
749
+
750
+ spread = jnp.square(relp)
751
+ scales = jnp.sqrt(
752
+ jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread))
753
+
754
+ # Computed positions are guaranteed to be in [-1, 1].
755
+ # Scales are unbounded.
756
+ scales = jnp.clip(scales, self.min_scale, self.max_scale)
757
+
758
+ if self.stop_grad_positions:
759
+ positions = jax.lax.stop_gradient(positions)
760
+ if self.stop_grad_scales:
761
+ scales = jax.lax.stop_gradient(scales)
762
+ if self.stop_grad_rotations:
763
+ rotm = jax.lax.stop_gradient(rotm)
764
+
765
+ if attn_round < self.num_iterations:
766
+ if self.append_statistics:
767
+ # For the slot rotations, we append both the 2D rotation matrix
768
+ # and the angle by which we rotate.
769
+ # We can compute the angle using atan2(R[0, 0], R[1, 0]).
770
+ tmp = jnp.concatenate(
771
+ [slots, positions, scales,
772
+ rotm.reshape(*rotm.shape[:-2], 4),
773
+ jnp.arctan2(rotm[Ellipsis, 0, 0], rotm[Ellipsis, 1, 0])[Ellipsis, None]],
774
+ axis=-1)
775
+ slots = embed_statistics(tmp)
776
+
777
+ # Recurrent update.
778
+ slots = gru(slots, updates)
779
+
780
+ # Feedforward block with pre-normalization.
781
+ if self.mlp_size is not None:
782
+ slots = mlp(slots)
783
+
784
+ # Concatenate position and scale information to slots.
785
+ output = jnp.concatenate(
786
+ [slots, positions, scales, rotm.reshape(*rotm.shape[:-2], 4)], axis=-1)
787
+
788
+ if self.mix_slots:
789
+ output = misc.MLP(hidden_size=128, layernorm="pre")(output)
790
+
791
+ return output
792
+
793
+ @classmethod
794
+ def compute_weighted_covariance(cls, x, w):
795
+ # The coordinate grid is (y, x), we want (x, y).
796
+ x = jnp.stack([x[Ellipsis, 1], x[Ellipsis, 0]], axis=-1)
797
+
798
+ # Pixel coordinates weighted by attention mask.
799
+ cov = x * w[Ellipsis, None]
800
+ cov = jnp.einsum(
801
+ "...ji,...jk->...ik", cov, x, precision=jax.lax.Precision.HIGHEST)
802
+
803
+ return cov
804
+
805
+ @classmethod
806
+ def compute_reference_frame_45_deg(cls, x, w):
807
+ cov = cls.compute_weighted_covariance(x, w)
808
+
809
+ # Compute eigenvalues.
810
+ pm = jnp.sqrt(4. * jnp.square(cov[Ellipsis, 0, 1]) +
811
+ jnp.square(cov[Ellipsis, 0, 0] - cov[Ellipsis, 1, 1]) + 1e-16)
812
+
813
+ eig1 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] + pm) / 2.
814
+ eig2 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] - pm) / 2.
815
+
816
+ # Compute eigenvectors, note that both have a positive y-axis.
817
+ # This means we have eliminated half of the possible rotations.
818
+ div = cov[Ellipsis, 0, 1] + 1e-16
819
+
820
+ v1 = (eig1 - cov[Ellipsis, 1, 1]) / div
821
+ v2 = (eig2 - cov[Ellipsis, 1, 1]) / div
822
+
823
+ v1 = jnp.stack([v1, jnp.ones_like(v1)], axis=-1)
824
+ v2 = jnp.stack([v2, jnp.ones_like(v2)], axis=-1)
825
+
826
+ # RULE 1:
827
+ # We catch two failure modes here.
828
+ # 1. If all attention weights are zero the covariance is also zero.
829
+ # Then the above computation is meaningless.
830
+ # 2. If the attention pattern is exactly aligned with the axes
831
+ # (e.g. a horizontal/vertical bar), the off-diagonal covariance
832
+ # values are going to be very low. If we use float32, we get
833
+ # basis vectors that are not orthogonal.
834
+ # Solution: use the default reference frame if the off-diagonal
835
+ # covariance value is too low.
836
+ default_1 = jnp.stack([jnp.ones_like(div), jnp.zeros_like(div)], axis=-1)
837
+ default_2 = jnp.stack([jnp.zeros_like(div), jnp.ones_like(div)], axis=-1)
838
+
839
+ mask = (jnp.abs(div) < 1e-6).astype(jnp.float32)[Ellipsis, None]
840
+ v1 = (1. - mask) * v1 + mask * default_1
841
+ v2 = (1. - mask) * v2 + mask * default_2
842
+
843
+ # Turn eigenvectors into unit vectors, so that we can construct
844
+ # a basis of a new reference frame.
845
+ norm1 = jnp.sqrt(jnp.sum(jnp.square(v1), axis=-1, keepdims=True))
846
+ norm2 = jnp.sqrt(jnp.sum(jnp.square(v2), axis=-1, keepdims=True))
847
+
848
+ v1 = v1 / norm1
849
+ v2 = v2 / norm2
850
+
851
+ # RULE 2:
852
+ # If the first basis vector is "pointing up" we assume the object
853
+ # is vertical (e.g. we say a door is vertical, whereas a car is horizontal).
854
+ # In the case of vertical objects, we swap the two basis vectors.
855
+ # This limits the possible rotations to +- 45deg instead of +- 90deg.
856
+ # We define "pointing up" as the first coordinate of the first basis vector
857
+ # being between +- sin(pi/4). The second coordinate is always positive.
858
+ mask = (jnp.logical_and(v1[Ellipsis, 0] < 0.707, v1[Ellipsis, 0] > -0.707)
859
+ ).astype(jnp.float32)[Ellipsis, None]
860
+ v1_ = (1. - mask) * v1 + mask * v2
861
+ v2_ = (1. - mask) * v2 + mask * v1
862
+ v1 = v1_
863
+ v2 = v2_
864
+
865
+ # RULE 3:
866
+ # Mirror the first basis vector if the first coordinate is negative.
867
+ # Here, we ensure that our coordinate system is always left-handed.
868
+ # Otherwise, we would sometimes unintentionally mirror the grid.
869
+ mask = (v1[Ellipsis, 0] < 0).astype(jnp.float32)[Ellipsis, None]
870
+ v1 = (1. - mask) * v1 - mask * v1
871
+
872
+ return v1, v2
873
+
874
+ @classmethod
875
+ def compute_reference_frame_90_deg(cls, x, w):
876
+ cov = cls.compute_weighted_covariance(x, w)
877
+
878
+ # Compute eigenvalues.
879
+ pm = jnp.sqrt(4. * jnp.square(cov[Ellipsis, 0, 1]) +
880
+ jnp.square(cov[Ellipsis, 0, 0] - cov[Ellipsis, 1, 1]) + 1e-16)
881
+
882
+ eig1 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] + pm) / 2.
883
+ eig2 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] - pm) / 2.
884
+
885
+ # Compute eigenvectors, note that both have a positive y-axis.
886
+ # This means we have eliminated half of the possible rotations.
887
+ div = cov[Ellipsis, 0, 1] + 1e-16
888
+
889
+ v1 = (eig1 - cov[Ellipsis, 1, 1]) / div
890
+ v2 = (eig2 - cov[Ellipsis, 1, 1]) / div
891
+
892
+ v1 = jnp.stack([v1, jnp.ones_like(v1)], axis=-1)
893
+ v2 = jnp.stack([v2, jnp.ones_like(v2)], axis=-1)
894
+
895
+ # RULE 1:
896
+ # We catch two failure modes here.
897
+ # 1. If all attention weights are zero the covariance is also zero.
898
+ # Then the above computation is meaningless.
899
+ # 2. If the attention pattern is exactly aligned with the axes
900
+ # (e.g. a horizontal/vertical bar), the off-diagonal covariance
901
+ # values are going to be very low. If we use float32, we get
902
+ # basis vectors that are not orthogonal.
903
+ # Solution: use the default reference frame if the off-diagonal
904
+ # covariance value is too low.
905
+ default_1 = jnp.stack([jnp.ones_like(div), jnp.zeros_like(div)], axis=-1)
906
+ default_2 = jnp.stack([jnp.zeros_like(div), jnp.ones_like(div)], axis=-1)
907
+
908
+ # RULE 1.5:
909
+ # RULE 1 is activated if we see a vertical or a horizontal bar.
910
+ # We make sure that the coordinate grid for a horizontal bar is not rotated,
911
+ # whereas the coordinate grid for a vertical bar is rotated by 90deg.
912
+ # If cov[0, 0] > cov[1, 1], the bar is vertical.
913
+ mask = (cov[Ellipsis, 0, 0] <= cov[Ellipsis, 1, 1]).astype(jnp.float32)[Ellipsis, None]
914
+ # Furthermore, we have to mirror one of the basis vectors (if mask==1)
915
+ # so that we always have a left-handed coordinate grid.
916
+ default_v1 = (1. - mask) * default_1 - mask * default_2
917
+ default_v2 = (1. - mask) * default_2 + mask * default_1
918
+
919
+ # Continuation of RULE 1.
920
+ mask = (jnp.abs(div) < 1e-6).astype(jnp.float32)[Ellipsis, None]
921
+ v1 = mask * default_v1 + (1. - mask) * v1
922
+ v2 = mask * default_v2 + (1. - mask) * v2
923
+
924
+ # Turn eigenvectors into unit vectors, so that we can construct
925
+ # a basis of a new reference frame.
926
+ norm1 = jnp.sqrt(jnp.sum(jnp.square(v1), axis=-1, keepdims=True))
927
+ norm2 = jnp.sqrt(jnp.sum(jnp.square(v2), axis=-1, keepdims=True))
928
+
929
+ v1 = v1 / norm1
930
+ v2 = v2 / norm2
931
+
932
+ # RULE 2:
933
+ # Mirror the first basis vector if the first coordinate is negative.
934
+ # Here, we ensure that our coordinate system is always left-handed.
935
+ # Otherwise, we would sometimes unintentionally mirror the grid.
936
+ mask = (v1[Ellipsis, 0] < 0).astype(jnp.float32)[Ellipsis, None]
937
+ v1 = (1. - mask) * v1 - mask * v1
938
+
939
+ return v1, v2
940
+
941
+ @classmethod
942
+ def compute_rotation_matrix_45_deg(cls, x, w):
943
+ v1, v2 = cls.compute_reference_frame_45_deg(x, w)
944
+ return jnp.stack([v1, v2], axis=-1)
945
+
946
+ @classmethod
947
+ def compute_rotation_matrix_90_deg(cls, x, w):
948
+ v1, v2 = cls.compute_reference_frame_90_deg(x, w)
949
+ return jnp.stack([v1, v2], axis=-1)
950
+
951
+ @classmethod
952
+ def transform(cls, rotm, x):
953
+ # The coordinate grid x is in the (y, x) format, so we need to swap
954
+ # the coordinates on the input and output.
955
+ x = jnp.stack([x[Ellipsis, 1], x[Ellipsis, 0]], axis=-1)
956
+ # Equivalent to inv(R) * x^T = R^T * x^T = (x * R)^T.
957
+ # We are multiplying by the inverse of the rotation matrix because
958
+ # we are rotating the coordinate grid *against* the rotation of the object.
959
+ # y = jnp.matmul(x, R)
960
+ y = jnp.einsum("...ij,...jk->...ik", x, rotm)
961
+ # Swap coordinates again.
962
+ y = jnp.stack([y[Ellipsis, 1], y[Ellipsis, 0]], axis=-1)
963
+ return y
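A self-contained sketch (not from the diff) of the rotation estimate above: the attention mask defines a weighted covariance of pixel coordinates whose eigenvectors give the slot's principal axes. It deliberately skips the handedness and degenerate-case rules (RULE 1-3), so the result may be a reflection rather than a rotation.

import jax.numpy as jnp

coords = jnp.array([[-1., -0.5], [0., 0.], [1., 0.5]])       # (x, y) points along a tilted line
w = jnp.array([0.2, 0.3, 0.5])                               # attention weights

cov = jnp.einsum("ji,jk->ik", coords * w[:, None], coords)   # weighted covariance
eigvals, eigvecs = jnp.linalg.eigh(cov)                      # orthonormal principal axes
rotm = eigvecs[:, ::-1]                                      # major axis first
print(jnp.allclose(rotm.T @ rotm, jnp.eye(2), atol=1e-5))    # True: orthonormal basis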
invariant_slot_attention/modules/invariant_initializers.py ADDED
@@ -0,0 +1,327 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Initializers module library for equivariant slot attention."""
17
+
18
+ import functools
19
+ from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
20
+
21
+ from flax import linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from invariant_slot_attention.lib import utils
25
+
26
+ Shape = Tuple[int]
27
+
28
+ DType = Any
29
+ Array = Any # jnp.ndarray
30
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
31
+ ProcessorState = ArrayTree
32
+ PRNGKey = Array
33
+ NestedDict = Dict[str, Any]
34
+
35
+
36
+ def get_uniform_initializer(vmin, vmax):
37
+ """Get a uniform initializer with an arbitrary range."""
38
+ init = nn.initializers.uniform(scale=vmax - vmin)
39
+
40
+ def fn(*args, **kwargs):
41
+ return init(*args, **kwargs) + vmin
42
+
43
+ return fn
44
+
45
+
46
+ def get_normal_initializer(mean, sd):
47
+ """Get a normal initializer with an arbitrary mean."""
48
+ init = nn.initializers.normal(stddev=sd)
49
+
50
+ def fn(*args, **kwargs):
51
+ return init(*args, **kwargs) + mean
52
+
53
+ return fn
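A usage sketch for the two helpers above (not part of the committed code); it assumes the repository is importable and that the returned functions follow flax's (key, shape, dtype) initializer convention.

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules.invariant_initializers import (
    get_normal_initializer, get_uniform_initializer)

init_pos = get_uniform_initializer(-1., 1.)
init_scale = get_normal_initializer(0.1, 0.01)

key = jax.random.PRNGKey(0)
positions = init_pos(key, (11, 2), jnp.float32)   # roughly uniform in [-1, 1)
scales = init_scale(key, (11, 2), jnp.float32)    # normal around 0.1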
54
+
55
+
56
+ class ParamStateInitRandomPositions(nn.Module):
57
+ """Fixed, learnable state initialization with random positions.
58
+
59
+ Random slot positions sampled from U[-1, 1] are concatenated
60
+ as the last two dimensions.
61
+ Note: This module ignores any conditional input (by design).
62
+ """
63
+
64
+ shape: Sequence[int]
65
+ init_fn: str = "normal" # Default init with unit variance.
66
+ conditioning_key: Optional[str] = None
67
+ slot_positions_min: float = -1.
68
+ slot_positions_max: float = 1.
69
+
70
+ @nn.compact
71
+ def __call__(self, inputs, batch_size,
72
+ train = False):
73
+ del inputs, train # Unused.
74
+
75
+ if self.init_fn == "normal":
76
+ init_fn = functools.partial(nn.initializers.normal, stddev=1.)
77
+ elif self.init_fn == "zeros":
78
+ init_fn = lambda: nn.initializers.zeros
79
+ else:
80
+ raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
81
+
82
+ param = self.param("state_init", init_fn(), self.shape)
83
+
84
+ out = utils.broadcast_across_batch(param, batch_size=batch_size)
85
+ shape = out.shape[:-1]
86
+ rng = self.make_rng("state_init")
87
+ slot_positions = jax.random.uniform(
88
+ rng, shape=[*shape, 2], minval=self.slot_positions_min,
89
+ maxval=self.slot_positions_max)
90
+ out = jnp.concatenate((out, slot_positions), axis=-1)
91
+ return out
92
+
93
+
94
+ class ParamStateInitLearnablePositions(nn.Module):
95
+ """Fixed, learnable state initialization with learnable positions.
96
+
97
+ Learnable initial positions are concatenated at the end of slots.
98
+ Note: This module ignores any conditional input (by design).
99
+ """
100
+
101
+ shape: Sequence[int]
102
+ init_fn: str = "normal" # Default init with unit variance.
103
+ conditioning_key: Optional[str] = None
104
+ slot_positions_min: float = -1.
105
+ slot_positions_max: float = 1.
106
+
107
+ @nn.compact
108
+ def __call__(self, inputs, batch_size,
109
+ train = False):
110
+ del inputs, train # Unused.
111
+
112
+ if self.init_fn == "normal":
113
+ init_fn_state = functools.partial(nn.initializers.normal, stddev=1.)
114
+ elif self.init_fn == "zeros":
115
+ init_fn_state = lambda: nn.initializers.zeros
116
+ else:
117
+ raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
118
+
119
+ init_fn_state = init_fn_state()
120
+ init_fn_pos = get_uniform_initializer(
121
+ self.slot_positions_min, self.slot_positions_max)
122
+
123
+ param_state = self.param("state_init", init_fn_state, self.shape)
124
+ param_pos = self.param(
125
+ "state_init_position", init_fn_pos, (*self.shape[:-1], 2))
126
+
127
+ param = jnp.concatenate((param_state, param_pos), axis=-1)
128
+
129
+ return utils.broadcast_across_batch(param, batch_size=batch_size) # pytype: disable=bad-return-type # jax-ndarray
130
+
131
+
132
+ class ParamStateInitRandomPositionsScales(nn.Module):
133
+ """Fixed, learnable state initialization with random positions and scales.
134
+
135
+ Random slot positions and scales sampled from U[-1, 1] and N(0.1, 0.1)
136
+ are concatenated as the last four dimensions.
137
+ Note: This module ignores any conditional input (by design).
138
+ """
139
+
140
+ shape: Sequence[int]
141
+ init_fn: str = "normal" # Default init with unit variance.
142
+ conditioning_key: Optional[str] = None
143
+ slot_positions_min: float = -1.
144
+ slot_positions_max: float = 1.
145
+ slot_scales_mean: float = 0.1
146
+ slot_scales_sd: float = 0.1
147
+
148
+ @nn.compact
149
+ def __call__(self, inputs, batch_size,
150
+ train = False):
151
+ del inputs, train # Unused.
152
+
153
+ if self.init_fn == "normal":
154
+ init_fn = functools.partial(nn.initializers.normal, stddev=1.)
155
+ elif self.init_fn == "zeros":
156
+ init_fn = lambda: nn.initializers.zeros
157
+ else:
158
+ raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
159
+
160
+ param = self.param("state_init", init_fn(), self.shape)
161
+
162
+ out = utils.broadcast_across_batch(param, batch_size=batch_size)
163
+ shape = out.shape[:-1]
164
+ rng = self.make_rng("state_init")
165
+ slot_positions = jax.random.uniform(
166
+ rng, shape=[*shape, 2], minval=self.slot_positions_min,
167
+ maxval=self.slot_positions_max)
168
+ slot_scales = jax.random.normal(rng, shape=[*shape, 2])
169
+ slot_scales = self.slot_scales_mean + self.slot_scales_sd * slot_scales
170
+ out = jnp.concatenate((out, slot_positions, slot_scales), axis=-1)
171
+ return out
172
+
173
+
174
+ class ParamStateInitLearnablePositionsScales(nn.Module):
175
+ """Fixed, learnable state initialization with learnable positions and scales.
176
+
177
+ Learnable initial positions and scales are concatenated at the end of slots.
178
+ Note: This module ignores any conditional input (by design).
179
+ """
180
+
181
+ shape: Sequence[int]
182
+ init_fn: str = "normal" # Default init with unit variance.
183
+ conditioning_key: Optional[str] = None
184
+ slot_positions_min: float = -1.
185
+ slot_positions_max: float = 1.
186
+ slot_scales_mean: float = 0.1
187
+ slot_scales_sd: float = 0.01
188
+
189
+ @nn.compact
190
+ def __call__(self, inputs, batch_size,
191
+ train = False):
192
+ del inputs, train # Unused.
193
+
194
+ if self.init_fn == "normal":
195
+ init_fn_state = functools.partial(nn.initializers.normal, stddev=1.)
196
+ elif self.init_fn == "zeros":
197
+ init_fn_state = lambda: nn.initializers.zeros
198
+ else:
199
+ raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
200
+
201
+ init_fn_state = init_fn_state()
202
+ init_fn_pos = get_uniform_initializer(
203
+ self.slot_positions_min, self.slot_positions_max)
204
+ init_fn_scales = get_normal_initializer(
205
+ self.slot_scales_mean, self.slot_scales_sd)
206
+
207
+ param_state = self.param("state_init", init_fn_state, self.shape)
208
+ param_pos = self.param(
209
+ "state_init_position", init_fn_pos, (*self.shape[:-1], 2))
210
+ param_scales = self.param(
211
+ "state_init_scale", init_fn_scales, (*self.shape[:-1], 2))
212
+
213
+ param = jnp.concatenate((param_state, param_pos, param_scales), axis=-1)
214
+
215
+ return utils.broadcast_across_batch(param, batch_size=batch_size) # pytype: disable=bad-return-type # jax-ndarray
216
+
217
+
218
+ class ParamStateInitLearnablePositionsRotationsScales(nn.Module):
219
+ """Fixed, learnable state initialization.
220
+
221
+ Learnable initial positions, rotations and scales are concatenated
222
+ at the end of slots. The rotation matrix is flattened.
223
+ Note: This module ignores any conditional input (by design).
224
+ """
225
+
226
+ shape: Sequence[int]
227
+ init_fn: str = "normal" # Default init with unit variance.
228
+ conditioning_key: Optional[str] = None
229
+ slot_positions_min: float = -1.
230
+ slot_positions_max: float = 1.
231
+ slot_scales_mean: float = 0.1
232
+ slot_scales_sd: float = 0.01
233
+ slot_angles_mean: float = 0.
234
+ slot_angles_sd: float = 0.1
235
+
236
+ @nn.compact
237
+ def __call__(self, inputs, batch_size,
238
+ train = False):
239
+ del inputs, train # Unused.
240
+
241
+ if self.init_fn == "normal":
242
+ init_fn_state = functools.partial(nn.initializers.normal, stddev=1.)
243
+ elif self.init_fn == "zeros":
244
+ init_fn_state = lambda: nn.initializers.zeros
245
+ else:
246
+ raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
247
+
248
+ init_fn_state = init_fn_state()
249
+ init_fn_pos = get_uniform_initializer(
250
+ self.slot_positions_min, self.slot_positions_max)
251
+ init_fn_scales = get_normal_initializer(
252
+ self.slot_scales_mean, self.slot_scales_sd)
253
+ init_fn_angles = get_normal_initializer(
254
+ self.slot_angles_mean, self.slot_angles_sd)
255
+
256
+ param_state = self.param("state_init", init_fn_state, self.shape)
257
+ param_pos = self.param(
258
+ "state_init_position", init_fn_pos, (*self.shape[:-1], 2))
259
+ param_scales = self.param(
260
+ "state_init_scale", init_fn_scales, (*self.shape[:-1], 2))
261
+ param_angles = self.param(
262
+ "state_init_angles", init_fn_angles, (*self.shape[:-1], 1))
263
+
264
+ # Initial angles in the range of (-pi / 4, pi / 4) <=> (-45, 45) degrees.
265
+ angles = jnp.tanh(param_angles) * (jnp.pi / 4)
266
+ rotm = jnp.concatenate(
267
+ [jnp.cos(angles), jnp.sin(angles),
268
+ -jnp.sin(angles), jnp.cos(angles)], axis=-1)
269
+
270
+ param = jnp.concatenate(
271
+ (param_state, param_pos, param_scales, rotm), axis=-1)
272
+
273
+ return utils.broadcast_across_batch(param, batch_size=batch_size) # pytype: disable=bad-return-type # jax-ndarray
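The angle parameterization above keeps rotations in a safe range: a free parameter is squashed to (-pi/4, pi/4) by tanh and expanded into a flattened 2x2 rotation matrix. A minimal sketch with assumed parameter values:

import jax.numpy as jnp

param_angles = jnp.array([[0.3], [-2.0]])            # unconstrained, one per slot
angles = jnp.tanh(param_angles) * (jnp.pi / 4)       # bounded to +-45 degrees
rotm_flat = jnp.concatenate(
    [jnp.cos(angles), jnp.sin(angles), -jnp.sin(angles), jnp.cos(angles)], axis=-1)
rotm = rotm_flat.reshape(-1, 2, 2)                   # each matrix is orthonormal with det 1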
274
+
275
+
276
+ class ParamStateInitRandomPositionsRotationsScales(nn.Module):
277
+ """Fixed, learnable state initialization with random pos., rot. and scales.
278
+
279
+ Random slot positions and scales sampled from U[-1, 1] and N(0.1, 0.1)
280
+ are concatenated at the end of slots. Rotations are sampled
281
+ from +- 45 degrees.
282
+ Note: This module ignores any conditional input (by design).
283
+ """
284
+
285
+ shape: Sequence[int]
286
+ init_fn: str = "normal" # Default init with unit variance.
287
+ conditioning_key: Optional[str] = None
288
+ slot_positions_min: float = -1.
289
+ slot_positions_max: float = 1.
290
+ slot_scales_mean: float = 0.1
291
+ slot_scales_sd: float = 0.1
292
+ slot_angles_min: float = -jnp.pi / 4.
293
+ slot_angles_max: float = jnp.pi / 4.
294
+
295
+ @nn.compact
296
+ def __call__(self, inputs, batch_size,
297
+ train = False):
298
+ del inputs, train # Unused.
299
+
300
+ if self.init_fn == "normal":
301
+ init_fn = functools.partial(nn.initializers.normal, stddev=1.)
302
+ elif self.init_fn == "zeros":
303
+ init_fn = lambda: nn.initializers.zeros
304
+ else:
305
+ raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
306
+
307
+ param = self.param("state_init", init_fn(), self.shape)
308
+
309
+ out = utils.broadcast_across_batch(param, batch_size=batch_size)
310
+ shape = out.shape[:-1]
311
+ rng = self.make_rng("state_init")
312
+ slot_positions = jax.random.uniform(
313
+ rng, shape=[*shape, 2], minval=self.slot_positions_min,
314
+ maxval=self.slot_positions_max)
315
+ rng = self.make_rng("state_init")
316
+ slot_scales = jax.random.normal(rng, shape=[*shape, 2])
317
+ slot_scales = self.slot_scales_mean + self.slot_scales_sd * slot_scales
318
+ rng = self.make_rng("state_init")
319
+ slot_angles = jax.random.uniform(rng, shape=[*shape, 1])
320
+ slot_angles = (slot_angles * (self.slot_angles_max - self.slot_angles_min)
321
+ ) + self.slot_angles_min
322
+ slot_rotm = jnp.concatenate(
323
+ [jnp.cos(slot_angles), jnp.sin(slot_angles),
324
+ -jnp.sin(slot_angles), jnp.cos(slot_angles)], axis=-1)
325
+ out = jnp.concatenate(
326
+ (out, slot_positions, slot_scales, slot_rotm), axis=-1)
327
+ return out
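These initializers emit slots laid out as [latent | position(2) | scale(2) | rotation matrix(4)], which SlotAttentionTranslRotScaleEquiv slices apart again. A sketch of that layout (the 64-dim latent and 11 slots are assumptions):

import jax.numpy as jnp

slots = jnp.zeros((1, 11, 64 + 2 + 2 + 4))
latent = slots[..., :-8]
positions = slots[..., -8:-6]
scales = slots[..., -6:-4]
rotm = slots[..., -4:].reshape(*slots.shape[:-1], 2, 2)
print(latent.shape, positions.shape, scales.shape, rotm.shape)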
invariant_slot_attention/modules/misc.py ADDED
@@ -0,0 +1,340 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Miscellaneous modules."""
17
+
18
+ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
19
+
20
+ from flax import linen as nn
21
+ import jax
22
+ import jax.numpy as jnp
23
+
24
+ from invariant_slot_attention.lib import utils
25
+
26
+ Shape = Tuple[int]
27
+
28
+ DType = Any
29
+ Array = Any # jnp.ndarray
30
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
31
+ ProcessorState = ArrayTree
32
+ PRNGKey = Array
33
+ NestedDict = Dict[str, Any]
34
+
35
+
36
+ class Identity(nn.Module):
37
+ """Module that applies the identity function, ignoring any additional args."""
38
+
39
+ @nn.compact
40
+ def __call__(self, inputs, **args):
41
+ return inputs
42
+
43
+
44
+ class Readout(nn.Module):
45
+ """Module for reading out multiple targets from an embedding."""
46
+
47
+ keys: Sequence[str]
48
+ readout_modules: Sequence[Callable[[], nn.Module]]
49
+ stop_gradient: Optional[Sequence[bool]] = None
50
+
51
+ @nn.compact
52
+ def __call__(self, inputs, train = False):
53
+ num_targets = len(self.keys)
54
+ assert num_targets >= 1, "Need to have at least one target."
55
+ assert len(self.readout_modules) == num_targets, (
56
+ "len(modules) and len(keys) must match.")
57
+ if self.stop_gradient is not None:
58
+ assert len(self.stop_gradient) == num_targets, (
59
+ "len(stop_gradient) and len(keys) must match.")
60
+ outputs = {}
61
+ for i in range(num_targets):
62
+ if self.stop_gradient is not None and self.stop_gradient[i]:
63
+ x = jax.lax.stop_gradient(inputs)
64
+ else:
65
+ x = inputs
66
+ outputs[self.keys[i]] = self.readout_modules[i]()(x, train) # pytype: disable=not-callable
67
+ return outputs
68
+
69
+
70
+ class MLP(nn.Module):
71
+ """Simple MLP with one hidden layer and optional pre-/post-layernorm."""
72
+
73
+ hidden_size: int
74
+ output_size: Optional[int] = None
75
+ num_hidden_layers: int = 1
76
+ activation_fn: Callable[[Array], Array] = nn.relu
77
+ layernorm: Optional[str] = None
78
+ activate_output: bool = False
79
+ residual: bool = False
80
+
81
+ @nn.compact
82
+ def __call__(self, inputs, train = False):
83
+ del train # Unused.
84
+
85
+ output_size = self.output_size or inputs.shape[-1]
86
+
87
+ x = inputs
88
+
89
+ if self.layernorm == "pre":
90
+ x = nn.LayerNorm()(x)
91
+
92
+ for i in range(self.num_hidden_layers):
93
+ x = nn.Dense(self.hidden_size, name=f"dense_mlp_{i}")(x)
94
+ x = self.activation_fn(x)
95
+ x = nn.Dense(output_size, name=f"dense_mlp_{self.num_hidden_layers}")(x)
96
+
97
+ if self.activate_output:
98
+ x = self.activation_fn(x)
99
+
100
+ if self.residual:
101
+ x = x + inputs
102
+
103
+ if self.layernorm == "post":
104
+ x = nn.LayerNorm()(x)
105
+
106
+ return x
107
+
108
+
109
+ class GRU(nn.Module):
110
+ """GRU cell as nn.Module."""
111
+
112
+ @nn.compact
113
+ def __call__(self, carry, inputs,
114
+ train = False):
115
+ del train # Unused.
116
+ carry, _ = nn.GRUCell()(carry, inputs)
117
+ return carry
118
+
119
+
120
+ class Dense(nn.Module):
121
+ """Dense layer as nn.Module accepting "train" flag."""
122
+
123
+ features: int
124
+ use_bias: bool = True
125
+
126
+ @nn.compact
127
+ def __call__(self, inputs, train = False):
128
+ del train # Unused.
129
+ return nn.Dense(features=self.features, use_bias=self.use_bias)(inputs)
130
+
131
+
132
+ class PositionEmbedding(nn.Module):
133
+ """A module for applying N-dimensional position embedding.
134
+
135
+ Attr:
136
+ embedding_type: A string defining the type of position embedding to use. One
137
+ of ["linear", "discrete_1d", "fourier", "gaussian_fourier"].
138
+ update_type: A string defining how the input is updated with the position
139
+ embedding. One of ["proj_add", "concat"].
140
+ num_fourier_bases: The number of Fourier bases to use. For embedding_type ==
141
+ "fourier", the embedding dimensionality is 2 x number of position
142
+ dimensions x num_fourier_bases. For embedding_type == "gaussian_fourier",
143
+ the embedding dimensionality is 2 x num_fourier_bases. For embedding_type
144
+ == "linear", this parameter is ignored.
145
+ gaussian_sigma: Standard deviation of sampled Gaussians.
146
+ pos_transform: Optional transform for the embedding.
147
+ output_transform: Optional transform for the combined input and embedding.
148
+ trainable_pos_embedding: Boolean flag for allowing gradients to flow into
149
+ the position embedding, so that the optimizer can update it.
150
+ """
151
+
152
+ embedding_type: str
153
+ update_type: str
154
+ num_fourier_bases: int = 0
155
+ gaussian_sigma: float = 1.0
156
+ pos_transform: Callable[[], nn.Module] = Identity
157
+ output_transform: Callable[[], nn.Module] = Identity
158
+ trainable_pos_embedding: bool = False
159
+
160
+ def _make_pos_embedding_tensor(self, rng, input_shape):
161
+ if self.embedding_type == "discrete_1d":
162
+ # An integer tensor in [0, input_shape[-2]-1] reflecting
163
+ # 1D discrete position encoding (encode the second-to-last axis).
164
+ pos_embedding = jnp.broadcast_to(
165
+ jnp.arange(input_shape[-2]), input_shape[1:-1])
166
+ else:
167
+ # A tensor grid in [-1, +1] for each input dimension.
168
+ pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0])
169
+
170
+ if self.embedding_type == "linear":
171
+ pass
172
+ elif self.embedding_type == "discrete_1d":
173
+ pos_embedding = jax.nn.one_hot(pos_embedding, input_shape[-2])
174
+ elif self.embedding_type == "fourier":
175
+ # NeRF-style Fourier/sinusoidal position encoding.
176
+ pos_embedding = utils.convert_to_fourier_features(
177
+ pos_embedding * jnp.pi, basis_degree=self.num_fourier_bases)
178
+ elif self.embedding_type == "gaussian_fourier":
179
+ # Gaussian Fourier features. Reference: https://arxiv.org/abs/2006.10739
180
+ num_dims = pos_embedding.shape[-1]
181
+ projection = jax.random.normal(
182
+ rng, [num_dims, self.num_fourier_bases]) * self.gaussian_sigma
183
+ pos_embedding = jnp.pi * pos_embedding.dot(projection)
184
+ # A slightly faster way to compute both sin and cos: sin(x + pi/2) == cos(x).
185
+ pos_embedding = jnp.sin(
186
+ jnp.concatenate([pos_embedding, pos_embedding + 0.5 * jnp.pi],
187
+ axis=-1))
188
+ else:
189
+ raise ValueError("Invalid embedding type provided.")
190
+
191
+ # Add batch dimension.
192
+ pos_embedding = jnp.expand_dims(pos_embedding, axis=0)
193
+
194
+ return pos_embedding
195
+
196
+ @nn.compact
197
+ def __call__(self, inputs):
198
+
199
+ # Compute the position embedding only in the initial call; use the same rng
200
+ # as is used for initializing learnable parameters.
201
+ pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor,
202
+ inputs.shape)
203
+
204
+ if not self.trainable_pos_embedding:
205
+ pos_embedding = jax.lax.stop_gradient(pos_embedding)
206
+
207
+ # Apply optional transformation on the position embedding.
208
+ pos_embedding = self.pos_transform()(pos_embedding) # pytype: disable=not-callable
209
+
210
+ # Apply position encoding to inputs.
211
+ if self.update_type == "project_add":
212
+ # Here, we project the position encodings to the same dimensionality as
213
+ # the inputs and add them to the inputs (broadcast along batch dimension).
214
+ # This is roughly equivalent to concatenation of position encodings to the
215
+ # inputs (if followed by a Dense layer), but is slightly more efficient.
216
+ n_features = inputs.shape[-1]
217
+ x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding)
218
+ elif self.update_type == "concat":
219
+ # Repeat the position embedding along the first (batch) dimension.
220
+ pos_embedding = jnp.broadcast_to(
221
+ pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:])
222
+ # concatenate along the channel dimension.
223
+ x = jnp.concatenate((inputs, pos_embedding), axis=-1)
224
+ else:
225
+ raise ValueError("Invalid update type provided.")
226
+
227
+ # Apply optional output transformation.
228
+ x = self.output_transform()(x) # pytype: disable=not-callable
229
+ return x
230
+
231
+
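A rough sketch of how PositionEmbedding is typically applied to an encoder feature map, assuming the class lives in invariant_slot_attention.modules.misc (as the imports elsewhere in this commit suggest) and that utils.create_gradient_grid emits one coordinate channel per spatial dimension:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules.misc import PositionEmbedding

pos_emb = PositionEmbedding(embedding_type="linear", update_type="concat")
features = jnp.zeros((2, 8, 8, 64))  # (batch, height, width, channels).
params = pos_emb.init(jax.random.PRNGKey(0), features)
out = pos_emb.apply(params, features)
# With a "linear" grid and "concat" updates, the two (y, x) coordinates are
# appended per pixel, so out.shape should be (2, 8, 8, 66).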
232
+ class RelativePositionEmbedding(nn.Module):
233
+ """A module for applying embedding of input position relative to slots.
234
+
235
+ Attributes:
236
+ update_type: A string defining how the input is updated with the position
237
+ embedding. One of ["project_add", "concat"].
238
+ embedding_type: A string defining the type of position embedding to use.
239
+ Currently only "linear" is supported.
240
+ num_fourier_bases: The number of Fourier bases to use. For embedding_type ==
241
+ "fourier", the embedding dimensionality is 2 x number of position
242
+ dimensions x num_fourier_bases. For embedding_type == "gaussian_fourier",
243
+ the embedding dimensionality is 2 x num_fourier_bases. For embedding_type
244
+ == "linear", this parameter is ignored.
245
+ gaussian_sigma: Standard deviation of sampled Gaussians.
246
+ pos_transform: Optional transform for the embedding.
247
+ output_transform: Optional transform for the combined input and embedding.
248
+ trainable_pos_embedding: Boolean flag for allowing gradients to flow into
249
+ the position embedding, so that the optimizer can update it.
+ scales_factor: Extra factor by which the relative grid is divided, on top of
+ the slot scales, to keep the coordinates in a reasonable range.
250
+ """
251
+
252
+ update_type: str
253
+ embedding_type: str = "linear"
254
+ num_fourier_bases: int = 0
255
+ gaussian_sigma: float = 1.0
256
+ pos_transform: Callable[[], nn.Module] = Identity
257
+ output_transform: Callable[[], nn.Module] = Identity
258
+ trainable_pos_embedding: bool = False
259
+ scales_factor: float = 1.0
260
+
261
+ def _make_pos_embedding_tensor(self, rng, input_shape):
262
+
263
+ # A tensor grid in [-1, +1] for each input dimension.
264
+ pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0])
265
+
266
+ # Add batch dimension.
267
+ pos_embedding = jnp.expand_dims(pos_embedding, axis=0)
268
+
269
+ return pos_embedding
270
+
271
+ @nn.compact
272
+ def __call__(self, inputs, slot_positions,
273
+ slot_scales = None,
274
+ slot_rotm = None):
275
+
276
+ # Compute the position embedding only in the initial call; use the same rng
277
+ # as is used for initializing learnable parameters.
278
+ pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor,
279
+ inputs.shape)
280
+
281
+ if not self.trainable_pos_embedding:
282
+ pos_embedding = jax.lax.stop_gradient(pos_embedding)
283
+
284
+ # Relativize pos_embedding with respect to slot positions
285
+ # and optionally slot scales.
286
+ slot_positions = jnp.expand_dims(
287
+ jnp.expand_dims(slot_positions, axis=-2), axis=-2)
288
+ if slot_scales is not None:
289
+ slot_scales = jnp.expand_dims(
290
+ jnp.expand_dims(slot_scales, axis=-2), axis=-2)
291
+
292
+ if self.embedding_type == "linear":
293
+ pos_embedding = pos_embedding - slot_positions
294
+ if slot_rotm is not None:
295
+ pos_embedding = self.transform(slot_rotm, pos_embedding)
296
+ if slot_scales is not None:
297
+ # Scales are usually small so the grid might get too large.
298
+ pos_embedding = pos_embedding / self.scales_factor
299
+ pos_embedding = pos_embedding / slot_scales
300
+ else:
301
+ raise ValueError("Invalid embedding type provided.")
302
+
303
+ # Apply optional transformation on the position embedding.
304
+ pos_embedding = self.pos_transform()(pos_embedding) # pytype: disable=not-callable
305
+
306
+ # Define intermediate for logging.
307
+ pos_embedding = Identity(name="pos_emb")(pos_embedding)
308
+
309
+ # Apply position encoding to inputs.
310
+ if self.update_type == "project_add":
311
+ # Here, we project the position encodings to the same dimensionality as
312
+ # the inputs and add them to the inputs (broadcast along batch dimension).
313
+ # This is roughly equivalent to concatenation of position encodings to the
314
+ # inputs (if followed by a Dense layer), but is slightly more efficient.
315
+ n_features = inputs.shape[-1]
316
+ x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding)
317
+ elif self.update_type == "concat":
318
+ # Repeat the position embedding along the first (batch) dimension.
319
+ pos_embedding = jnp.broadcast_to(
320
+ pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:])
321
+ # concatenate along the channel dimension.
322
+ x = jnp.concatenate((inputs, pos_embedding), axis=-1)
323
+ else:
324
+ raise ValueError("Invalid update type provided.")
325
+
326
+ # Apply optional output transformation.
327
+ x = self.output_transform()(x) # pytype: disable=not-callable
328
+ return x
329
+
330
+ @classmethod
331
+ def transform(cls, rot, coords):
332
+ # The coordinate grid coords is in the (y, x) format, so we need to swap
333
+ # the coordinates on the input and output.
334
+ coords = jnp.stack([coords[Ellipsis, 1], coords[Ellipsis, 0]], axis=-1)
335
+ # Equivalent to inv(R) * coords^T = R^T * coords^T = (coords * R)^T.
336
+ # We are multiplying by the inverse of the rotation matrix because
337
+ # we are rotating the coordinate grid *against* the rotation of the object.
338
+ new_coords = jnp.einsum("...hij,...jk->...hik", coords, rot)
339
+ # Swap coordinates again.
340
+ return jnp.stack([new_coords[Ellipsis, 1], new_coords[Ellipsis, 0]], axis=-1)
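The transform helper above rotates a (y, x) coordinate grid by the inverse of the given rotation matrix. A small self-contained check of that convention, with illustrative values and assuming the class is importable from invariant_slot_attention.modules.misc:

import jax.numpy as jnp
from invariant_slot_attention.modules.misc import RelativePositionEmbedding

theta = jnp.deg2rad(30.0)
rot = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                 [jnp.sin(theta), jnp.cos(theta)]])
coords = jnp.array([[[0.25, -0.5]]])  # A 1x1 grid holding a single (y, x) entry.

out = RelativePositionEmbedding.transform(rot, coords)

# The same computation by hand: swap to (x, y), right-multiply by R (which
# applies R^T = R^-1 to column vectors), then swap back to (y, x).
xy = coords[..., ::-1]
expected = (xy @ rot)[..., ::-1]
assert jnp.allclose(out, expected)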
invariant_slot_attention/modules/resnet.py ADDED
@@ -0,0 +1,231 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Implementation of ResNet V1 in Flax.
17
+
18
+ "Deep Residual Learning for Image Recognition"
19
+ He et al., 2015, [https://arxiv.org/abs/1512.03385]
20
+ """
21
+
22
+ import functools
23
+
24
+ from typing import Any, Tuple, Type, List, Optional, Callable, Sequence
25
+ import flax.linen as nn
26
+ import jax.numpy as jnp
27
+
28
+
29
+ Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1), use_bias=False)
30
+ Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3), use_bias=False)
31
+
32
+
33
+ class ResNetBlock(nn.Module):
34
+ """ResNet block without bottleneck used in ResNet-18 and ResNet-34."""
35
+
36
+ filters: int
37
+ norm: Any
38
+ kernel_dilation: Tuple[int, int] = (1, 1)
39
+ strides: Tuple[int, int] = (1, 1)
40
+
41
+ @nn.compact
42
+ def __call__(self, x):
43
+ residual = x
44
+
45
+ x = Conv3x3(
46
+ self.filters,
47
+ strides=self.strides,
48
+ kernel_dilation=self.kernel_dilation,
49
+ name="conv1")(x)
50
+ x = self.norm(name="bn1")(x)
51
+ x = nn.relu(x)
52
+ x = Conv3x3(self.filters, name="conv2")(x)
53
+ # Initializing the scale to 0 has been common practice since "Fixup
54
+ # Initialization: Residual Learning Without Normalization", Zhang et al.,
55
+ # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
56
+ x = self.norm(scale_init=nn.initializers.zeros, name="bn2")(x)
57
+
58
+ if residual.shape != x.shape:
59
+ residual = Conv1x1(
60
+ self.filters, strides=self.strides, name="proj_conv")(
61
+ residual)
62
+ residual = self.norm(name="proj_bn")(residual)
63
+
64
+ x = nn.relu(residual + x)
65
+ return x
66
+
67
+
68
+ class BottleneckResNetBlock(ResNetBlock):
69
+ """Bottleneck ResNet block used in ResNet-50 and larger."""
70
+
71
+ @nn.compact
72
+ def __call__(self, x):
73
+ residual = x
74
+
75
+ x = Conv1x1(self.filters, name="conv1")(x)
76
+ x = self.norm(name="bn1")(x)
77
+ x = nn.relu(x)
78
+ x = Conv3x3(
79
+ self.filters,
80
+ strides=self.strides,
81
+ kernel_dilation=self.kernel_dilation,
82
+ name="conv2")(x)
83
+ x = self.norm(name="bn2")(x)
84
+ x = nn.relu(x)
85
+ x = Conv1x1(4 * self.filters, name="conv3")(x)
86
+ # Initializing the scale to 0 has been common practice since "Fixup
87
+ # Initialization: Residual Learning Without Normalization", Zhang et al.,
88
+ # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
89
+ x = self.norm(name="bn3")(x)
90
+
91
+ if residual.shape != x.shape:
92
+ residual = Conv1x1(
93
+ 4 * self.filters, strides=self.strides, name="proj_conv")(
94
+ residual)
95
+ residual = self.norm(name="proj_bn")(residual)
96
+
97
+ x = nn.relu(residual + x)
98
+ return x
99
+
100
+
101
+ class ResNetStage(nn.Module):
102
+ """ResNet stage consistent of multiple ResNet blocks."""
103
+
104
+ stage_size: int
105
+ filters: int
106
+ block_cls: Type[ResNetBlock]
107
+ norm: Any
108
+ first_block_strides: Tuple[int, int]
109
+
110
+ @nn.compact
111
+ def __call__(self, x):
112
+ for i in range(self.stage_size):
113
+ x = self.block_cls(
114
+ filters=self.filters,
115
+ norm=self.norm,
116
+ strides=self.first_block_strides if i == 0 else (1, 1),
117
+ name=f"block{i + 1}")(
118
+ x)
119
+ return x
120
+
121
+
122
+ class ResNet(nn.Module):
123
+ """Construct ResNet V1 with `num_classes` outputs.
124
+
125
+ Attributes:
126
+ num_classes: Number of nodes in the final layer.
127
+ block_cls: Class for the blocks. ResNet-50 and larger use
128
+ `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
129
+ ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
130
+ stage_sizes: List with the number of ResNet blocks in each stage. Number of
131
+ stages can be varied.
132
+ norm_type: Which type of normalization layer to apply. Options are:
133
+ "batch": BatchNorm, "group": GroupNorm, "layer": LayerNorm. Defaults to
134
+ BatchNorm.
135
+ width_factor: Factor applied to the number of filters: 64 * width_factor
136
+ is the number of filters in the first stage, and every subsequent stage
137
+ doubles the number of filters.
138
+ small_inputs: Bool, if True, ignore strides and skip max pooling in the root
139
+ block and use smaller filter size.
140
+ stage_strides: Stride per stage; if given, it overrides the default stride
+ schedule.
141
+ include_top: Whether to include the fully-connected layer at the top
142
+ of the network.
143
+ axis_name: Axis name over which to aggregate batchnorm statistics.
144
+ """
145
+ num_classes: int
146
+ block_cls: Type[ResNetBlock]
147
+ stage_sizes: List[int]
148
+ norm_type: str = "batch"
149
+ width_factor: int = 1
150
+ small_inputs: bool = False
151
+ stage_strides: Optional[List[Tuple[int, int]]] = None
152
+ include_top: bool = False
153
+ axis_name: Optional[str] = None
154
+ output_initializer: Callable[[Any, Sequence[int], Any], Any] = (
155
+ nn.initializers.zeros)
156
+
157
+ @nn.compact
158
+ def __call__(self, x, *, train):
159
+ """Apply the ResNet to the inputs `x`.
160
+
161
+ Args:
162
+ x: Inputs.
163
+ train: Whether to use BatchNorm in training or inference mode.
164
+
165
+ Returns:
166
+ Logits with `num_classes` entries if `include_top` is True, otherwise the
+ final convolutional feature map.
167
+ """
168
+ width = 64 * self.width_factor
169
+
170
+ if self.norm_type == "batch":
171
+ norm = functools.partial(
172
+ nn.BatchNorm, use_running_average=not train, momentum=0.9,
173
+ axis_name=self.axis_name)
174
+ elif self.norm_type == "layer":
175
+ norm = nn.LayerNorm
176
+ elif self.norm_type == "group":
177
+ norm = nn.GroupNorm
178
+ else:
179
+ raise ValueError(f"Invalid norm_type: {self.norm_type}")
180
+
181
+ # Root block.
182
+ x = nn.Conv(
183
+ features=width,
184
+ kernel_size=(7, 7) if not self.small_inputs else (3, 3),
185
+ strides=(2, 2) if not self.small_inputs else (1, 1),
186
+ use_bias=False,
187
+ name="init_conv")(
188
+ x)
189
+ x = norm(name="init_bn")(x)
190
+
191
+ if not self.small_inputs:
192
+ x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
193
+
194
+ # Stages.
195
+ for i, stage_size in enumerate(self.stage_sizes):
196
+ if i == 0:
197
+ first_block_strides = (
198
+ 1, 1) if self.stage_strides is None else self.stage_strides[i]
199
+ else:
200
+ first_block_strides = (
201
+ 2, 2) if self.stage_strides is None else self.stage_strides[i]
202
+
203
+ x = ResNetStage(
204
+ stage_size,
205
+ filters=width * 2**i,
206
+ block_cls=self.block_cls,
207
+ norm=norm,
208
+ first_block_strides=first_block_strides,
209
+ name=f"stage{i + 1}")(x)
210
+
211
+ # Head.
212
+ if self.include_top:
213
+ x = jnp.mean(x, axis=(1, 2))
214
+ x = nn.Dense(
215
+ self.num_classes, kernel_init=self.output_initializer, name="head")(x)
216
+ return x
217
+
218
+
219
+ ResNetWithBasicBlk = functools.partial(ResNet, block_cls=ResNetBlock)
220
+ ResNetWithBottleneckBlk = functools.partial(ResNet,
221
+ block_cls=BottleneckResNetBlock)
222
+
223
+ ResNet18 = functools.partial(ResNetWithBasicBlk, stage_sizes=[2, 2, 2, 2])
224
+ ResNet34 = functools.partial(ResNetWithBasicBlk, stage_sizes=[3, 4, 6, 3])
225
+ ResNet50 = functools.partial(ResNetWithBottleneckBlk, stage_sizes=[3, 4, 6, 3])
226
+ ResNet101 = functools.partial(ResNetWithBottleneckBlk,
227
+ stage_sizes=[3, 4, 23, 3])
228
+ ResNet152 = functools.partial(ResNetWithBottleneckBlk,
229
+ stage_sizes=[3, 8, 36, 3])
230
+ ResNet200 = functools.partial(ResNetWithBottleneckBlk,
231
+ stage_sizes=[3, 24, 36, 3])
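A quick sanity check of the ResNet factories defined above; the input size and class count are illustrative:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules.resnet import ResNet34

model = ResNet34(num_classes=10, include_top=True)
x = jnp.zeros((1, 64, 64, 3))  # (batch, height, width, channels).
variables = model.init(jax.random.PRNGKey(0), x, train=False)
logits = model.apply(variables, x, train=False)  # Shape (1, 10).
# With include_top=False (the default), the model returns the final feature map
# instead, which is the form consumed when the ResNet serves as an encoder
# backbone. Training with norm_type="batch" additionally requires passing
# mutable=["batch_stats"] to apply.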
invariant_slot_attention/modules/video.py ADDED
@@ -0,0 +1,195 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Video module library."""
17
+
18
+ import functools
19
+ from typing import Any, Callable, Dict, Iterable, Mapping, NamedTuple, Optional, Tuple, Union
20
+
21
+ from flax import linen as nn
22
+ import jax.numpy as jnp
23
+ from invariant_slot_attention.lib import utils
24
+ from invariant_slot_attention.modules import misc
25
+
26
+ Shape = Tuple[int]
27
+
28
+ DType = Any
29
+ Array = Any # jnp.ndarray
30
+ ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
31
+ ProcessorState = ArrayTree
32
+ PRNGKey = Array
33
+ NestedDict = Dict[str, Any]
34
+
35
+
36
+ class CorrectorPredictorTuple(NamedTuple):
37
+ corrected: ProcessorState
38
+ predicted: ProcessorState
39
+
40
+
41
+ class Processor(nn.Module):
42
+ """Recurrent processor module.
43
+
44
+ This module is scanned (applied recurrently) over the sequence dimension of
45
+ the input and applies a corrector and a predictor module. The corrector is
46
+ only applied if new inputs (such as a new image/frame) are received and uses
47
+ the new input to correct its internal state.
48
+
49
+ The predictor is equivalent to a latent transition model and produces a
50
+ prediction for the state at the next time step, given the current (corrected)
51
+ state.
52
+ """
53
+ corrector: Callable[[ProcessorState, Array], ProcessorState]
54
+ predictor: Callable[[ProcessorState], ProcessorState]
55
+
56
+ @functools.partial(
57
+ nn.scan, # Scan (recurrently apply) over time axis.
58
+ in_axes=(1, 1, nn.broadcast), # (inputs, padding_mask, train).
59
+ out_axes=1,
60
+ variable_axes={"intermediates": 1}, # Stack intermediates along seq. dim.
61
+ variable_broadcast="params",
62
+ split_rngs={"params": False, "dropout": True})
63
+ @nn.compact
64
+ def __call__(self, state, inputs,
65
+ padding_mask,
66
+ train):
67
+
68
+ # Only apply corrector if we receive new inputs.
69
+ if inputs is not None:
70
+ corrected_state = self.corrector(state, inputs, padding_mask, train=train)
71
+ # Otherwise simply use previous state as input for predictor.
72
+ else:
73
+ corrected_state = state
74
+
75
+ # Always apply predictor (i.e. transition model).
76
+ predicted_state = self.predictor(corrected_state, train=train)
77
+
78
+ # Prepare outputs in a format compatible with nn.scan.
79
+ new_state = predicted_state
80
+ outputs = CorrectorPredictorTuple(
81
+ corrected=corrected_state, predicted=predicted_state)
82
+ return new_state, outputs
83
+
84
+
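To illustrate the corrector/predictor contract that Processor scans over the time axis, here is a toy sketch with made-up one-layer modules (in the full model these roles are played by the slot-attention corrector and a latent transition model, as described in the SAVi docstring below):

import jax
import jax.numpy as jnp
from flax import linen as nn
from invariant_slot_attention.modules.video import Processor

class ToyCorrector(nn.Module):
  @nn.compact
  def __call__(self, state, inputs, padding_mask, train=False):
    del padding_mask, train  # Unused in this toy.
    return state + nn.Dense(state.shape[-1])(inputs)

class ToyPredictor(nn.Module):
  @nn.compact
  def __call__(self, state, train=False):
    del train  # Unused in this toy.
    return nn.Dense(state.shape[-1])(state)

proc = Processor(corrector=ToyCorrector(), predictor=ToyPredictor())
state0 = jnp.zeros((2, 16))         # (batch, state_dim).
inputs = jnp.zeros((2, 5, 16))      # (batch, n_frames, feature_dim).
mask = jnp.ones((2, 5), jnp.int32)  # (batch, n_frames).
variables = proc.init(jax.random.PRNGKey(0), state0, inputs, mask, False)
final_state, outs = proc.apply(variables, state0, inputs, mask, False)
# outs.corrected and outs.predicted are stacked along axis 1: shape (2, 5, 16).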
85
+ class SAVi(nn.Module):
86
+ """Video model consisting of encoder, recurrent processor, and decoder."""
87
+
88
+ encoder: Callable[[], nn.Module]
89
+ decoder: Callable[[], nn.Module]
90
+ corrector: Callable[[], nn.Module]
91
+ predictor: Callable[[], nn.Module]
92
+ initializer: Callable[[], nn.Module]
93
+ decode_corrected: bool = True
94
+ decode_predicted: bool = True
95
+
96
+ @nn.compact
97
+ def __call__(self, video, conditioning = None,
98
+ continue_from_previous_state = False,
99
+ padding_mask = None,
100
+ train = False):
101
+ """Performs a forward pass on a video.
102
+
103
+ Args:
104
+ video: Video of shape `[batch_size, n_frames, height, width, n_channels]`.
105
+ conditioning: Optional jnp.ndarray used for conditioning the initial state
106
+ of the recurrent processor.
107
+ continue_from_previous_state: Boolean, whether to continue from a previous
108
+ state or not. If True, the conditioning variable is used directly as
109
+ initial state.
110
+ padding_mask: Binary mask for padding video inputs (e.g. for videos of
111
+ different sizes/lengths). Zero corresponds to padding.
112
+ train: Indicating whether we're training or evaluating.
113
+
114
+ Returns:
115
+ A dictionary of model predictions.
116
+ """
117
+ processor = Processor(
118
+ corrector=self.corrector(), predictor=self.predictor()) # pytype: disable=wrong-arg-types
119
+
120
+ if padding_mask is None:
121
+ padding_mask = jnp.ones(video.shape[:-1], jnp.int32)
122
+
123
+ # video.shape = (batch_size, n_frames, height, width, n_channels)
124
+ # Vmapped over sequence dim.
125
+ encoded_inputs = self.encoder()(video, padding_mask, train) # pytype: disable=not-callable
126
+ if continue_from_previous_state:
127
+ assert conditioning is not None, (
128
+ "When continuing from a previous state, the state has to be passed "
129
+ "via the `conditioning` variable, which cannot be `None`.")
130
+ init_state = conditioning[:, -1] # We currently only use last state.
131
+ else:
132
+ # Otherwise, initialize the state from the (optional) conditioning input.
133
+ init_state = self.initializer()(
134
+ conditioning, batch_size=video.shape[0], train=train) # pytype: disable=not-callable
135
+
136
+ # Scan recurrent processor over encoded inputs along sequence dimension.
137
+ _, states = processor(init_state, encoded_inputs, padding_mask, train)
138
+ # type(states) = CorrectorPredictorTuple.
139
+ # states.corrected.shape = (batch_size, n_frames, ..., n_features).
140
+ # states.predicted.shape = (batch_size, n_frames, ..., n_features).
141
+
142
+ # Decode latent states.
143
+ decoder = self.decoder() # Vmapped over sequence dim.
144
+ outputs = decoder(states.corrected,
145
+ train) if self.decode_corrected else None # pytype: disable=not-callable
146
+ outputs_pred = decoder(states.predicted,
147
+ train) if self.decode_predicted else None # pytype: disable=not-callable
148
+
149
+ return {
150
+ "states": states.corrected,
151
+ "states_pred": states.predicted,
152
+ "outputs": outputs,
153
+ "outputs_pred": outputs_pred,
154
+ }
155
+
156
+
157
+ class FrameEncoder(nn.Module):
158
+ """Encoder for single video frame, vmapped over time axis."""
159
+
160
+ backbone: Callable[[], nn.Module]
161
+ pos_emb: Callable[[], nn.Module] = misc.Identity
162
+ reduction: Optional[str] = None
163
+ output_transform: Callable[[], nn.Module] = misc.Identity
164
+
165
+ # Vmapped application of module, consumes time axis (axis=1).
166
+ @functools.partial(utils.time_distributed, in_axes=(1, 1, None))
167
+ @nn.compact
168
+ def __call__(self, inputs, padding_mask = None,
169
+ train = False):
170
+ del padding_mask # Unused.
171
+
172
+ # inputs.shape = (batch_size, height, width, n_channels)
173
+ x = self.backbone()(inputs, train=train)
174
+
175
+ x = self.pos_emb()(x)
176
+
177
+ if self.reduction == "spatial_flatten":
178
+ batch_size, height, width, n_features = x.shape
179
+ x = jnp.reshape(x, (batch_size, height * width, n_features))
180
+ elif self.reduction == "spatial_average":
181
+ x = jnp.mean(x, axis=(1, 2))
182
+ elif self.reduction == "all_flatten":
183
+ batch_size, height, width, n_features = x.shape
184
+ x = jnp.reshape(x, (batch_size, height * width * n_features))
185
+ elif self.reduction is not None:
186
+ raise ValueError("Unknown reduction type: {}.".format(self.reduction))
187
+
188
+ output_block = self.output_transform()
189
+
190
+ if hasattr(output_block, "qkv_size"):
191
+ # Project to qkv_size if the output transform is a transformer.
192
+ x = nn.relu(nn.Dense(output_block.qkv_size)(x))
193
+
194
+ x = output_block(x, train=train)
195
+ return x
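A sketch of wiring FrameEncoder with a minimal, hypothetical backbone. It assumes utils.time_distributed maps the module over the frame axis (axis 1), as its use above and in SAVi suggests, and that misc.Identity accepts the train flag:

import jax
import jax.numpy as jnp
from flax import linen as nn
from invariant_slot_attention.modules.video import FrameEncoder

class TinyBackbone(nn.Module):
  @nn.compact
  def __call__(self, x, train=False):
    del train  # Unused in this toy.
    return nn.Conv(32, kernel_size=(3, 3))(x)

encoder = FrameEncoder(backbone=TinyBackbone, reduction="spatial_flatten")
video = jnp.zeros((1, 4, 16, 16, 3))          # (batch, n_frames, h, w, c).
mask = jnp.ones(video.shape[:-1], jnp.int32)  # Same mask convention as SAVi.
params = encoder.init(jax.random.PRNGKey(0), video, mask, False)
feats = encoder.apply(params, video, mask, False)
# Expected: per-frame feature maps flattened over space, i.e. (1, 4, 256, 32).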
main.py ADDED
@@ -0,0 +1,65 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Main file for running the model trainer."""
17
+
18
+ from absl import app
19
+ from absl import flags
20
+ from absl import logging
21
+
22
+ from clu import platform
23
+ import jax
24
+ from ml_collections import config_flags
25
+
26
+ import tensorflow as tf
27
+
28
+
29
+ from invariant_slot_attention.lib import trainer
30
+
31
+ FLAGS = flags.FLAGS
32
+
33
+ config_flags.DEFINE_config_file(
34
+ "config", None, "Config file.")
35
+ flags.DEFINE_string("workdir", None, "Work unit directory.")
36
+ flags.DEFINE_string("jax_backend_target", None, "JAX backend target to use.")
37
+ flags.mark_flags_as_required(["config", "workdir"])
38
+
39
+
40
+ def main(argv):
41
+ del argv
42
+
43
+ # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
44
+ # it unavailable to JAX.
45
+ tf.config.experimental.set_visible_devices([], "GPU")
46
+
47
+ if FLAGS.jax_backend_target:
48
+ logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
49
+ jax.config.update("jax_xla_backend", "tpu_driver")
50
+ jax.config.update("jax_backend_target", FLAGS.jax_backend_target)
51
+
52
+ logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
53
+ logging.info("JAX devices: %r", jax.devices())
54
+
55
+ # Add a note so that we can tell which task is which JAX host.
56
+ platform.work_unit().set_task_status(
57
+ f"host_id: {jax.host_id()}, host_count: {jax.host_count()}")
58
+ platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
59
+ FLAGS.workdir, "workdir")
60
+
61
+ trainer.train_and_evaluate(FLAGS.config, FLAGS.workdir)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ app.run(main)
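For reference, the trainer defined here is launched with a config file and a work directory; the paths below are illustrative:

python main.py --config <path to a file under invariant_slot_attention/configs> --workdir /tmp/isa_workdir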