Spaces: Runtime error

ondrejbiza committed
Commit a560c26
Parent(s): db5cc89

Working on isa demo.
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +4 -0
- __init__.py +15 -0
- app.py +66 -0
- invariant_slot_attention/configs/__init__.py +15 -0
- invariant_slot_attention/configs/clevr_with_masks/baseline.py +194 -0
- invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py +202 -0
- invariant_slot_attention/configs/clevr_with_masks/equiv_transl_rot_scale.py +203 -0
- invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale.py +203 -0
- invariant_slot_attention/configs/clevrtex/resnet/baseline.py +198 -0
- invariant_slot_attention/configs/clevrtex/resnet/equiv_transl.py +206 -0
- invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_rot_scale.py +213 -0
- invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_scale.py +213 -0
- invariant_slot_attention/configs/clevrtex/simplecnn/baseline.py +197 -0
- invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl.py +205 -0
- invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_rot_scale.py +207 -0
- invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_scale.py +207 -0
- invariant_slot_attention/configs/multishapenet_easy/baseline.py +195 -0
- invariant_slot_attention/configs/multishapenet_easy/equiv_transl.py +203 -0
- invariant_slot_attention/configs/multishapenet_easy/equiv_transl_rot_scale.py +205 -0
- invariant_slot_attention/configs/multishapenet_easy/equiv_transl_scale.py +205 -0
- invariant_slot_attention/configs/objects_room/baseline.py +192 -0
- invariant_slot_attention/configs/objects_room/equiv_transl.py +200 -0
- invariant_slot_attention/configs/objects_room/equiv_transl_rot_scale.py +202 -0
- invariant_slot_attention/configs/objects_room/equiv_transl_scale.py +202 -0
- invariant_slot_attention/configs/tetrominoes/baseline.py +191 -0
- invariant_slot_attention/configs/tetrominoes/equiv_transl.py +199 -0
- invariant_slot_attention/configs/waymo_open/baseline.py +191 -0
- invariant_slot_attention/configs/waymo_open/equiv_transl.py +199 -0
- invariant_slot_attention/configs/waymo_open/equiv_transl_rot_scale.py +206 -0
- invariant_slot_attention/configs/waymo_open/equiv_transl_scale.py +206 -0
- invariant_slot_attention/lib/__init__.py +15 -0
- invariant_slot_attention/lib/evaluator.py +326 -0
- invariant_slot_attention/lib/input_pipeline.py +390 -0
- invariant_slot_attention/lib/losses.py +295 -0
- invariant_slot_attention/lib/metrics.py +263 -0
- invariant_slot_attention/lib/preprocessing.py +1236 -0
- invariant_slot_attention/lib/trainer.py +328 -0
- invariant_slot_attention/lib/transforms.py +163 -0
- invariant_slot_attention/lib/utils.py +625 -0
- invariant_slot_attention/modules/__init__.py +49 -0
- invariant_slot_attention/modules/attention.py +327 -0
- invariant_slot_attention/modules/convolution.py +164 -0
- invariant_slot_attention/modules/decoders.py +267 -0
- invariant_slot_attention/modules/initializers.py +173 -0
- invariant_slot_attention/modules/invariant_attention.py +963 -0
- invariant_slot_attention/modules/invariant_initializers.py +327 -0
- invariant_slot_attention/modules/misc.py +340 -0
- invariant_slot_attention/modules/resnet.py +231 -0
- invariant_slot_attention/modules/video.py +195 -0
- main.py +65 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
/venv
/flagged
/clevr_isa_ts
*.pyc
__init__.py
ADDED
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
app.py
ADDED
@@ -0,0 +1,66 @@
import functools
import os

from absl import flags
import gradio as gr
import jax
import jax.numpy as jnp

from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
from invariant_slot_attention.lib import input_pipeline
from invariant_slot_attention.lib import utils


def load_model(config):
  rng, data_rng = jax.random.split(rng)

  # Initialize model
  model = utils.build_model_from_config(config.model)

  def init_model(rng):
    rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)

    init_conditioning = None
    init_inputs = jnp.ones(
        [1] + list(train_ds.element_spec["video"].shape)[2:],
        jnp.float32)
    initial_vars = model.init(
        {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
        video=init_inputs, conditioning=init_conditioning,
        padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))

    # Split into state variables (e.g. for batchnorm stats) and model params.
    # Note that `pop()` on a FrozenDict performs a deep copy.
    state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error

    # Filter out intermediates (we don't want to store these in the TrainState).
    state_vars = utils.filter_key_from_frozen_dict(
        state_vars, key="intermediates")
    return state_vars, initial_params

  state_vars, initial_params = init_model(rng)

  learning_rate_fn = lr_schedules.get_learning_rate_fn(config)
  tx = optimizers.get_optimizer(
      config.optimizer_configs, learning_rate_fn, params=initial_params)

  opt_state = tx.init(initial_params)

  state = utils.TrainState(
      step=1, opt_state=opt_state, params=initial_params, rng=rng,
      variables=state_vars)

  loss_fn = functools.partial(
      losses.compute_full_loss, loss_config=config.losses)

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)


def greet(name):
  return "Hello " + name + "!"


demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
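Note: as committed, load_model is still a work in progress. It references rng, train_ds, workdir, lr_schedules, optimizers, losses and checkpoint before they are defined or imported, so only the greet interface actually runs. A minimal sketch of the missing wiring, assuming the loss helpers live in invariant_slot_attention.lib and that clu's checkpoint utility is the intended one (both assumptions, not part of this commit):

    # Sketch only -- the names below are assumptions inferred from the
    # function body, not part of the commit.
    import jax
    from clu import checkpoint                        # assumed checkpoint helper
    from invariant_slot_attention.lib import losses   # assumed loss module

    config = get_config()
    rng = jax.random.PRNGKey(config.seed)   # seed the PRNG before splitting it
    workdir = "clevr_isa_ts"                # assumed; matches the ignored directory
    # A data pipeline would also be needed to define train_ds, e.g. via
    # invariant_slot_attention.lib.input_pipeline (entry point not shown here).

The schedule and optimizer helpers (lr_schedules, optimizers) would need analogous imports; their exact location is not visible in this diff.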
invariant_slot_attention/configs/__init__.py
ADDED
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
invariant_slot_attention/configs/clevr_with_masks/baseline.py
ADDED
@@ -0,0 +1,194 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVR."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "clevr_with_masks",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttention",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInit",
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add"
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
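Each config module exposes get_config(), which returns an ml_collections.ConfigDict whose model sub-config names modules by import path (resolved by utils.build_model_from_config, as in app.py above). A small usage sketch, purely illustrative and assuming the package is importable:

    from invariant_slot_attention.configs.clevr_with_masks import baseline

    config = baseline.get_config()
    print(config.num_slots)                 # 11
    print(config.model.corrector.module)    # invariant_slot_attention.modules.SlotAttention

    # steps_per_cycle was assigned via get_ref, so it tracks num_train_steps:
    config.num_train_steps = 1000
    print(config.lr_configs.steps_per_cycle)  # 1000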
invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py
ADDED
@@ -0,0 +1,202 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVR."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "clevr_with_masks",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.ParamStateInitRandomPositions",
          "shape":
              (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
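Relative to baseline.py, this config swaps the corrector to SlotAttentionTranslEquiv, draws initial slot positions at random (ParamStateInitRandomPositions), switches the encoder position embedding to "concat", and decodes with relative positions (RelativePositionEmbedding). The remaining configs differ from their baselines in the same few keys, so one quick way to surface such deltas is to flatten and compare two configs; an illustrative helper, not part of the commit:

    from invariant_slot_attention.configs.clevr_with_masks import baseline, equiv_transl

    def flatten(d, prefix=""):
      # Yield (dotted_key, value) pairs for a nested dict.
      for k, v in d.items():
        if isinstance(v, dict):
          yield from flatten(v, f"{prefix}{k}.")
        else:
          yield f"{prefix}{k}", v

    a = dict(flatten(baseline.get_config().to_dict()))
    b = dict(flatten(equiv_transl.get_config().to_dict()))
    for key in sorted(set(a) | set(b)):
      if a.get(key) != b.get(key):
        print(key, a.get(key), "->", b.get(key))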
invariant_slot_attention/configs/clevr_with_masks/equiv_transl_rot_scale.py
ADDED
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVR."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "clevr_with_masks",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_rotations_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "scales_factor": 5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale.py
ADDED
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVR."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "clevr_with_masks",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "top_left_crop(top=29, left=64, height=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "scales_factor": 5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/clevrtex/resnet/baseline.py
ADDED
@@ -0,0 +1,198 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttention",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInit",
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add"
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
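Unlike the CLEVR configs, the CLEVRTex configs read a locally built TFDS dataset (dataset_name "tfds" plus an explicit data_dir and tfds_name). Assuming the dataset has been generated as the config comments describe, it can be checked with tensorflow_datasets before training; an illustrative snippet, not part of the commit:

    import tensorflow_datasets as tfds

    # Should find the locally built builder named "clevr_tex" under data_dir.
    builder = tfds.builder("clevr_tex", data_dir="~/tensorflow_datasets")
    print(builder.info.splits)
    train_ds = builder.as_dataset(split="train")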
invariant_slot_attention/configs/clevrtex/resnet/equiv_transl.py
ADDED
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.ParamStateInitRandomPositions",
          "shape":
              (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
|
167 |
+
"module": "invariant_slot_attention.modules.Readout",
|
168 |
+
"keys": list(targets),
|
169 |
+
"readout_modules": [ml_collections.ConfigDict({ # pylint: disable=g-complex-comprehension
|
170 |
+
"module": "invariant_slot_attention.modules.MLP",
|
171 |
+
"num_hidden_layers": 0,
|
172 |
+
"hidden_size": 0,
|
173 |
+
"output_size": targets[k]}) for k in targets],
|
174 |
+
}),
|
175 |
+
"relative_positions": True,
|
176 |
+
"pos_emb": ml_collections.ConfigDict({
|
177 |
+
"module":
|
178 |
+
"invariant_slot_attention.modules.RelativePositionEmbedding",
|
179 |
+
"embedding_type":
|
180 |
+
"linear",
|
181 |
+
"update_type":
|
182 |
+
"project_add",
|
183 |
+
}),
|
184 |
+
}),
|
185 |
+
"decode_corrected": True,
|
186 |
+
"decode_predicted": False,
|
187 |
+
})
|
188 |
+
|
189 |
+
# Which video-shaped variables to visualize.
|
190 |
+
config.debug_var_video_paths = {
|
191 |
+
"recon_masks": "decoder/alphas_softmaxed/__call__/0", # pylint: disable=line-too-long
|
192 |
+
}
|
193 |
+
|
194 |
+
# Define which attention matrices to log/visualize.
|
195 |
+
config.debug_var_attn_paths = {
|
196 |
+
"corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn" # pylint: disable=line-too-long
|
197 |
+
}
|
198 |
+
|
199 |
+
# Widths of attention matrices (for reshaping to image grid).
|
200 |
+
config.debug_var_attn_widths = {
|
201 |
+
"corrector_attn": 16,
|
202 |
+
}
|
203 |
+
|
204 |
+
return config
|
205 |
+
|
206 |
+
|
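For reference, configs like the one above are consumed through their get_config() entry point. A minimal usage sketch, assuming the repository root is on PYTHONPATH (the import path simply mirrors the file added above):

# Load the CLEVRTex ResNet translation-equivariant config and tweak a field.
from invariant_slot_attention.configs.clevrtex.resnet import equiv_transl

config = equiv_transl.get_config()
config.batch_size = 32  # Any field can be overridden before training starts.
print(config.model.corrector.module)
# invariant_slot_attention.modules.SlotAttentionTranslEquiv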
invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_rot_scale.py
ADDED
@@ -0,0 +1,213 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_rotations_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/clevrtex/resnet/equiv_transl_scale.py
ADDED
@@ -0,0 +1,213 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
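The preprocessing strings shared by all of these CLEVRTex configs ("central_crop(height=192,width=192)" followed by "resize_small(128)") take a centered 192x192 crop and then resize it to the 128x128 training resolution. A rough TensorFlow sketch of that chain, assuming resize_small resizes the shorter image side (the crop is already square here, so a direct resize is equivalent) and that frames are scaled to [0, 1]:

import tensorflow as tf

def preprocess_frame(image):
  # image: [H, W, 3] uint8 frame from the TFDS CLEVRTex build.
  h, w = tf.shape(image)[0], tf.shape(image)[1]
  image = tf.image.crop_to_bounding_box(
      image, (h - 192) // 2, (w - 192) // 2, 192, 192)
  image = tf.image.resize(image, (128, 128))  # 192x192 -> 128x128.
  return image / 255.0  # Assumed [0, 1] scaling.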
invariant_slot_attention/configs/clevrtex/simplecnn/baseline.py
ADDED
@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttention",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInit",
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add"
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
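This SimpleCNN baseline differs from the equivariant variants added below only in the model block: the corrector, the slot initializer, and the position embeddings are swapped out, and the decoder gains a relative-positions flag. A quick comparison sketch (hypothetical usage, assuming both config modules import cleanly) that prints the swapped-in module names side by side:

from invariant_slot_attention.configs.clevrtex.simplecnn import baseline, equiv_transl

for name, cfg in (("baseline", baseline.get_config()),
                  ("equiv_transl", equiv_transl.get_config())):
  print(name,
        cfg.model.corrector.module,
        cfg.model.initializer.module,
        cfg.model.decoder.pos_emb.module)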
invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl.py
ADDED
@@ -0,0 +1,205 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.ParamStateInitRandomPositions",
          "shape":
              (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_rot_scale.py
ADDED
@@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_rotations_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/clevrtex/simplecnn/equiv_transl_scale.py
ADDED
@@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on CLEVRTex."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tfds",
      # The TFDS dataset will be created in the directory below
      # if you follow the README in datasets/clevrtex.
      "data_dir": "~/tensorflow_datasets",
      "tfds_name": "clevr_tex",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "central_crop(height=192,width=192)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
|
198 |
+
}
|
199 |
+
|
200 |
+
# Widths of attention matrices (for reshaping to image grid).
|
201 |
+
config.debug_var_attn_widths = {
|
202 |
+
"corrector_attn": 16,
|
203 |
+
}
|
204 |
+
|
205 |
+
return config
|
206 |
+
|
207 |
+
|
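Note: each of these config files exposes a get_config() that returns a plain ml_collections.ConfigDict, so any of them can be loaded and overridden from the command line with ml_collections.config_flags. The sketch below is illustrative only; the script name and flag wiring are assumptions, not part of this commit.

# sketch_load_config.py -- minimal sketch, not part of this repository.
from absl import app
from absl import flags
from ml_collections import config_flags

config_flags.DEFINE_config_file(
    "config", None, "Path to a config file that defines get_config().")
FLAGS = flags.FLAGS


def main(argv):
  del argv
  config = FLAGS.config  # The ConfigDict returned by get_config().
  print(config.batch_size)              # 64
  print(config.model.corrector.module)  # e.g. ...SlotAttentionTranslScaleEquiv


if __name__ == "__main__":
  app.run(main)

# Example invocation, overriding one field in place:
#   python sketch_load_config.py --config=path/to/equiv_transl_scale.py --config.batch_size=32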
invariant_slot_attention/configs/multishapenet_easy/baseline.py
ADDED
@@ -0,0 +1,195 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on MultiShapeNet-Easy."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "multishapenet_easy",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttention",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInit",
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add"
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
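The "module" strings in these configs are dotted import paths into invariant_slot_attention.modules. The repository resolves them with its own builder utilities; the snippet below is only a hypothetical illustration of the idea, not the actual code path.

# Hypothetical illustration of resolving a dotted "module" string.
import importlib


def resolve_module(dotted_name):
  """Returns the class named by a dotted path such as '...modules.SlotAttention'."""
  module_path, class_name = dotted_name.rsplit(".", 1)
  return getattr(importlib.import_module(module_path), class_name)


# e.g. resolve_module("invariant_slot_attention.modules.SlotAttention")
# would return the baseline corrector class configured above.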
invariant_slot_attention/configs/multishapenet_easy/equiv_transl.py
ADDED
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on MultiShapeNet-Easy."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "multishapenet_easy",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.ParamStateInitRandomPositions",
          "shape":
              (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
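Because targets = {"video": 3} holds a single entry, the dict comprehension that rebuilds config.losses in each of these files expands to exactly one reconstruction loss. A worked expansion, for reference only:

# Equivalent to the comprehension above when targets == {"video": 3}:
config.losses = ml_collections.ConfigDict({
    "recon_video": {"loss_type": "recon", "key": "video"},
})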
invariant_slot_attention/configs/multishapenet_easy/equiv_transl_rot_scale.py
ADDED
@@ -0,0 +1,205 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on MultiShapeNet-Easy."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "multishapenet_easy",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_rotations_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/multishapenet_easy/equiv_transl_scale.py
ADDED
@@ -0,0 +1,205 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on MultiShapeNet-Easy."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "multishapenet_easy",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 128)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.preproc_eval = [
      "sunds_to_tfds_video",
      "video_from_tfds",
      "subtract_one_from_segmentations",
      "central_crop(height=240, width=240)",
      "resize_small({size})".format(size=min(*config.data.resolution))
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/objects_room/baseline.py
ADDED
@@ -0,0 +1,192 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on objects_room."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  # TODO(obvis): Implement masked evaluation.
  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "objects_room",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (64, 64)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=10)",
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=10)",
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttention",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInit",
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, False, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add"
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
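A quick check of the "resolution": (16, 16) comment in the decoders above: assuming 'same'-padded convolutions (an assumption here, not stated in the configs), the encoder's feature grid is the input resolution divided by the product of the stride factors. That gives 16x16 both for the 128x128 datasets (three stride-2 layers) and for the 64x64 objects_room setup (two stride-2 layers), and it matches the corrector_attn width of 16. A small sketch of the arithmetic:

# Sanity-check sketch (assumes 'same' padding, so only strides matter).
def feature_grid(resolution, strides):
  factor_h, factor_w = 1, 1
  for stride_h, stride_w in strides:
    factor_h *= stride_h
    factor_w *= stride_w
  return resolution[0] // factor_h, resolution[1] // factor_w


print(feature_grid((128, 128), [(2, 2), (2, 2), (2, 2), (1, 1)]))  # (16, 16)
print(feature_grid((64, 64), [(2, 2), (2, 2), (1, 1), (1, 1)]))    # (16, 16)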
invariant_slot_attention/configs/objects_room/equiv_transl.py
ADDED
@@ -0,0 +1,200 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
r"""Config for unsupervised training on objects_room."""
|
17 |
+
|
18 |
+
import ml_collections
|
19 |
+
|
20 |
+
|
21 |
+
def get_config():
|
22 |
+
"""Get the default hyperparameter configuration."""
|
23 |
+
config = ml_collections.ConfigDict()
|
24 |
+
|
25 |
+
config.seed = 42
|
26 |
+
config.seed_data = True
|
27 |
+
|
28 |
+
config.batch_size = 64
|
29 |
+
config.num_train_steps = 500000 # from the original Slot Attention
|
30 |
+
config.init_checkpoint = ml_collections.ConfigDict()
|
31 |
+
config.init_checkpoint.xid = 0 # Disabled by default.
|
32 |
+
config.init_checkpoint.wid = 1
|
33 |
+
|
34 |
+
config.optimizer_configs = ml_collections.ConfigDict()
|
35 |
+
config.optimizer_configs.optimizer = "adam"
|
36 |
+
|
37 |
+
config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
|
38 |
+
config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
|
39 |
+
config.optimizer_configs.grad_clip.clip_value = 0.05
|
40 |
+
|
41 |
+
config.lr_configs = ml_collections.ConfigDict()
|
42 |
+
config.lr_configs.learning_rate_schedule = "compound"
|
43 |
+
config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
|
44 |
+
config.lr_configs.warmup_steps = 10000 # from the original Slot Attention
|
45 |
+
config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
|
46 |
+
# from the original Slot Attention
|
47 |
+
config.lr_configs.base_learning_rate = 4e-4
|
48 |
+
|
49 |
+
# TODO(obvis): Implement masked evaluation.
|
50 |
+
config.eval_pad_last_batch = False # True
|
51 |
+
config.log_loss_every_steps = 50
|
52 |
+
config.eval_every_steps = 5000
|
53 |
+
config.checkpoint_every_steps = 5000
|
54 |
+
|
55 |
+
config.train_metrics_spec = {
|
56 |
+
"loss": "loss",
|
57 |
+
"ari": "ari",
|
58 |
+
"ari_nobg": "ari_nobg",
|
59 |
+
}
|
60 |
+
config.eval_metrics_spec = {
|
61 |
+
"eval_loss": "loss",
|
62 |
+
"eval_ari": "ari",
|
63 |
+
"eval_ari_nobg": "ari_nobg",
|
64 |
+
}
|
65 |
+
|
66 |
+
config.data = ml_collections.ConfigDict({
|
67 |
+
"dataset_name": "objects_room",
|
68 |
+
"shuffle_buffer_size": config.batch_size * 8,
|
69 |
+
"resolution": (64, 64)
|
70 |
+
})
|
71 |
+
|
72 |
+
config.max_instances = 11
|
73 |
+
config.num_slots = config.max_instances # Only used for metrics.
|
74 |
+
config.logging_min_n_colors = config.max_instances
|
75 |
+
|
76 |
+
config.preproc_train = [
|
77 |
+
"tfds_image_to_tfds_video",
|
78 |
+
"video_from_tfds",
|
79 |
+
"sparse_to_dense_annotation(max_instances=10)",
|
80 |
+
]
|
81 |
+
|
82 |
+
config.preproc_eval = [
|
83 |
+
"tfds_image_to_tfds_video",
|
84 |
+
"video_from_tfds",
|
85 |
+
"sparse_to_dense_annotation(max_instances=10)",
|
86 |
+
]
|
87 |
+
|
88 |
+
config.eval_slice_size = 1
|
89 |
+
config.eval_slice_keys = ["video", "segmentations_video"]
|
90 |
+
|
91 |
+
# Dictionary of targets and corresponding channels. Losses need to match.
|
92 |
+
targets = {"video": 3}
|
93 |
+
config.losses = {"recon": {"targets": list(targets)}}
|
94 |
+
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.ParamStateInitRandomPositions",
          "shape":
              (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, False, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
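Note on the losses block above: `config.losses` is assigned twice on purpose, and the second, comprehension-based assignment is the one that survives. A minimal standalone sketch (assuming only `ml_collections`) of what that comprehension expands to for `targets = {"video": 3}`:

import ml_collections

targets = {"video": 3}
# Expands to one reconstruction loss per target, keyed as "recon_<target>".
losses = ml_collections.ConfigDict({
    f"recon_{target}": {"loss_type": "recon", "key": target}
    for target in targets})
print(losses.recon_video.key)  # -> "video"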
invariant_slot_attention/configs/objects_room/equiv_transl_rot_scale.py
ADDED
@@ -0,0 +1,202 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on objects_room."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  # TODO(obvis): Implement masked evaluation.
  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "objects_room",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (64, 64)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=10)",
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=10)",
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, False, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_rotations_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/objects_room/equiv_transl_scale.py
ADDED
@@ -0,0 +1,202 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on objects_room."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  # TODO(obvis): Implement masked evaluation.
  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "objects_room",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (64, 64)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=10)",
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=10)",
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 16),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, False, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/tetrominoes/baseline.py
ADDED
@@ -0,0 +1,191 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on Tetrominoes with 512 train samples."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 20000
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 5000
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  # TODO(obvis): Implement masked evaluation.
  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 1000
  config.checkpoint_every_steps = 1000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tetrominoes",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (35, 35)
  })

  config.max_instances = 4
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=3)"
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=3)"
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(1, 1), (1, 1), (1, 1), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttention",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInit",
          "shape": (4, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (35, 35),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 256,
              "output_size": 256,
              "num_hidden_layers": 5,
              "activate_output": True
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add"
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 35,
  }

  return config
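The decoder "resolution" entries in these configs carry the comment "Update if data resolution or strides change". A rough standalone sketch of that relationship, assuming each spatial dimension is divided by the product of the encoder strides (the helper below is illustrative, not part of the repo):

def feature_resolution(data_resolution, strides):
  """Illustrative: divide each spatial dim by the product of the strides."""
  h, w = data_resolution
  for sh, sw in strides:
    h, w = h // sh, w // sw
  return (h, w)

# Tetrominoes: 35x35 input with all (1, 1) strides keeps (35, 35).
print(feature_resolution((35, 35), [(1, 1)] * 4))
# objects_room: 64x64 input with strides (2, 2), (2, 2), (1, 1), (1, 1) -> (16, 16).
print(feature_resolution((64, 64), [(2, 2), (2, 2), (1, 1), (1, 1)]))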
invariant_slot_attention/configs/tetrominoes/equiv_transl.py
ADDED
@@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on Tetrominoes with 512 train samples."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 20000
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 5000
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  # TODO(obvis): Implement masked evaluation.
  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 1000
  config.checkpoint_every_steps = 1000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "tetrominoes",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (35, 35)
  })

  config.max_instances = 4
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=3)"
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "sparse_to_dense_annotation(max_instances=3)"
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.SimpleCNN",
              "features": [64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(1, 1), (1, 1), (1, 1), (1, 1)]
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.ParamStateInitRandomPositions",
          "shape":
              (4, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (35, 35),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 256,
              "output_size": 256,
              "num_hidden_layers": 5,
              "activate_output": True
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 35,
  }

  return config
invariant_slot_attention/configs/waymo_open/baseline.py
ADDED
@@ -0,0 +1,191 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on Waymo Open."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "waymo_open",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 192)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "delete_small_masks(threshold=0.01, max_instances_after=11)",
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttention",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInit",
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 24),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add"
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/configs/waymo_open/equiv_transl.py
ADDED
@@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on Waymo Open."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "waymo_open",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 192)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "delete_small_masks(threshold=0.01, max_instances_after=11)",
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "concat"
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslEquiv",
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.ParamStateInitRandomPositions",
          "shape":
              (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 24),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
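Each of these files exposes a get_config() entry point, so they can be consumed with the usual ml_collections flag pattern. A minimal sketch, assuming the standard config_flags setup (the flag name and default path are illustrative; the Space's own main.py may wire this differently):

from absl import app
from absl import flags
from ml_collections import config_flags

config_flags.DEFINE_config_file(
    "config",
    "invariant_slot_attention/configs/waymo_open/equiv_transl.py",
    "Path to a file that defines get_config().")
FLAGS = flags.FLAGS


def main(argv):
  del argv
  config = FLAGS.config  # ConfigDict returned by get_config().
  print(config.model.corrector.module)


if __name__ == "__main__":
  app.run(main)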
invariant_slot_attention/configs/waymo_open/equiv_transl_rot_scale.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on Waymo Open."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "waymo_open",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 192)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "delete_small_masks(threshold=0.01, max_instances_after=11)",
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslRotScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsRotationsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 24),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_rotations_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
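A quick sanity check of the decoder geometry assumed by the config above: three of the five CNN backbone layers are transposed convolutions with stride (2, 2), so the spatial broadcast grid is upsampled by a factor of 2**3 = 8 per dimension, which is why "resolution" is set to (16, 24) for 128x192 inputs. The following lines are an illustrative sketch, not part of the committed files:

data_resolution = (128, 192)
num_stride2_transposed_layers = 3
upsampling = 2 ** num_stride2_transposed_layers  # -> 8
decoder_resolution = tuple(d // upsampling for d in data_resolution)
assert decoder_resolution == (16, 24)  # matches "resolution": (16, 24) above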
invariant_slot_attention/configs/waymo_open/equiv_transl_scale.py
ADDED
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for unsupervised training on Waymo Open."""

import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()

  config.seed = 42
  config.seed_data = True

  config.batch_size = 64
  config.num_train_steps = 500000  # from the original Slot Attention
  config.init_checkpoint = ml_collections.ConfigDict()
  config.init_checkpoint.xid = 0  # Disabled by default.
  config.init_checkpoint.wid = 1

  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.optimizer = "adam"

  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
  config.optimizer_configs.grad_clip.clip_value = 0.05

  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = "compound"
  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
  # from the original Slot Attention
  config.lr_configs.base_learning_rate = 4e-4

  config.eval_pad_last_batch = False  # True
  config.log_loss_every_steps = 50
  config.eval_every_steps = 5000
  config.checkpoint_every_steps = 5000

  config.train_metrics_spec = {
      "loss": "loss",
      "ari": "ari",
      "ari_nobg": "ari_nobg",
  }
  config.eval_metrics_spec = {
      "eval_loss": "loss",
      "eval_ari": "ari",
      "eval_ari_nobg": "ari_nobg",
  }

  config.data = ml_collections.ConfigDict({
      "dataset_name": "waymo_open",
      "shuffle_buffer_size": config.batch_size * 8,
      "resolution": (128, 192)
  })

  config.max_instances = 11
  config.num_slots = config.max_instances  # Only used for metrics.
  config.logging_min_n_colors = config.max_instances

  config.preproc_train = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
  ]

  config.preproc_eval = [
      "tfds_image_to_tfds_video",
      "video_from_tfds",
      "delete_small_masks(threshold=0.01, max_instances_after=11)",
  ]

  config.eval_slice_size = 1
  config.eval_slice_keys = ["video", "segmentations_video"]

  # Dictionary of targets and corresponding channels. Losses need to match.
  targets = {"video": 3}
  config.losses = {"recon": {"targets": list(targets)}}
  config.losses = ml_collections.ConfigDict({
      f"recon_{target}": {"loss_type": "recon", "key": target}
      for target in targets})

  config.model = ml_collections.ConfigDict({
      "module": "invariant_slot_attention.modules.SAVi",

      # Encoder.
      "encoder": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.FrameEncoder",
          "reduction": "spatial_flatten",
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.ResNet34",
              "num_classes": None,
              "axis_name": "time",
              "norm_type": "group",
              "small_inputs": True
          }),
          "pos_emb": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.PositionEmbedding",
              "embedding_type": "linear",
              "update_type": "project_add",
              "output_transform": ml_collections.ConfigDict({
                  "module": "invariant_slot_attention.modules.MLP",
                  "hidden_size": 128,
                  "layernorm": "pre"
              }),
          }),
      }),

      # Corrector.
      "corrector": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
          "num_iterations": 3,
          "qkv_size": 64,
          "mlp_size": 128,
          "grid_encoder": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.MLP",
              "hidden_size": 128,
              "layernorm": "pre"
          }),
          "add_rel_pos_to_values": True,  # V3
          "zero_position_init": False,  # Random positions.
          "init_with_fixed_scale": None,  # Random scales.
          "scales_factor": 5.0,
      }),

      # Predictor.
      # Removed since we are running a single frame.
      "predictor": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.Identity"
      }),

      # Initializer.
      "initializer": ml_collections.ConfigDict({
          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
          "shape": (11, 64),  # (num_slots, slot_size)
      }),

      # Decoder.
      "decoder": ml_collections.ConfigDict({
          "module":
              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
          "resolution": (16, 24),  # Update if data resolution or strides change
          "backbone": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.CNN",
              "features": [64, 64, 64, 64, 64],
              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
              "layer_transpose": [True, True, True, False, False]
          }),
          "target_readout": ml_collections.ConfigDict({
              "module": "invariant_slot_attention.modules.Readout",
              "keys": list(targets),
              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
                  "module": "invariant_slot_attention.modules.MLP",
                  "num_hidden_layers": 0,
                  "hidden_size": 0,
                  "output_size": targets[k]}) for k in targets],
          }),
          "relative_positions_and_scales": True,
          "pos_emb": ml_collections.ConfigDict({
              "module":
                  "invariant_slot_attention.modules.RelativePositionEmbedding",
              "embedding_type":
                  "linear",
              "update_type":
                  "project_add",
              "scales_factor":
                  5.0,
          }),
      }),
      "decode_corrected": True,
      "decode_predicted": False,
  })

  # Which video-shaped variables to visualize.
  config.debug_var_video_paths = {
      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
  }

  # Define which attention matrices to log/visualize.
  config.debug_var_attn_paths = {
      "corrector_attn": "corrector/InvertedDotProductAttentionKeyPerQuery_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  config.debug_var_attn_widths = {
      "corrector_attn": 16,
  }

  return config
invariant_slot_attention/lib/__init__.py
ADDED
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
invariant_slot_attention/lib/evaluator.py
ADDED
@@ -0,0 +1,326 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model evaluation."""

import functools
from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union

from absl import logging
from clu import metrics
import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from invariant_slot_attention.lib import losses
from invariant_slot_attention.lib import utils


Array = jnp.ndarray
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
PRNGKey = Array


def get_eval_metrics(
    preds,
    batch,
    loss_fn,
    eval_metrics_cls,
    predicted_max_num_instances,
    ground_truth_max_num_instances,
):
  """Compute the metrics for the model predictions in inference mode.

  The metrics are averaged across *all* devices (of all hosts).

  Args:
    preds: Model predictions.
    batch: Inputs that should be evaluated.
    loss_fn: Loss function that takes model predictions and a batch of data.
    eval_metrics_cls: Evaluation metrics collection.
    predicted_max_num_instances: Maximum number of instances in prediction.
    ground_truth_max_num_instances: Maximum number of instances in ground truth,
      including background (which counts as a separate instance).

  Returns:
    The evaluation metrics.
  """
  loss, loss_aux = loss_fn(preds, batch)
  metrics_update = eval_metrics_cls.gather_from_model_output(
      loss=loss,
      **loss_aux,
      predicted_segmentations=utils.remove_singleton_dim(
          preds["outputs"].get("segmentations")),  # pytype: disable=attribute-error
      ground_truth_segmentations=batch.get("segmentations"),
      predicted_max_num_instances=predicted_max_num_instances,
      ground_truth_max_num_instances=ground_truth_max_num_instances,
      padding_mask=batch.get("padding_mask"),
      mask=batch.get("mask"))
  return metrics_update


def eval_first_step(
    model,
    state_variables,
    params,
    batch,
    rng,
    conditioning_key = None
):
  """Get the model predictions with a freshly initialized recurrent state.

  The model is applied to the inputs using all devices on the host.

  Args:
    model: Model used in eval step.
    state_variables: State variables for the model.
    params: Params for the model.
    batch: Inputs that should be evaluated.
    rng: PRNGKey for model forward pass.
    conditioning_key: Optional string. If provided, defines the batch key to be
      used as conditioning signal for the model. Otherwise this is inferred from
      the available keys in the batch.
  Returns:
    The model's predictions.
  """
  logging.info("eval_first_step(batch=%s)", batch)

  conditioning = None
  if conditioning_key:
    conditioning = batch[conditioning_key]
  preds, mutable_vars = model.apply(
      {"params": params, **state_variables}, video=batch["video"],
      conditioning=conditioning, mutable="intermediates",
      rngs={"state_init": rng}, train=False,
      padding_mask=batch.get("padding_mask"))

  if "intermediates" in mutable_vars:
    preds["intermediates"] = flax.core.unfreeze(mutable_vars["intermediates"])

  return preds


def eval_continued_step(
    model,
    state_variables,
    params,
    batch,
    rng,
    recurrent_states
):
  """Get the model predictions, continuing from a provided recurrent state.

  The model is applied to the inputs using all devices on the host.

  Args:
    model: Model used in eval step.
    state_variables: State variables for the model.
    params: The model parameters.
    batch: Inputs that should be evaluated.
    rng: PRNGKey for model forward pass.
    recurrent_states: Recurrent internal model state from which to continue.
  Returns:
    The model's predictions.
  """
  logging.info("eval_continued_step(batch=%s, recurrent_states=%s)", batch,
               recurrent_states)

  preds, mutable_vars = model.apply(
      {"params": params, **state_variables}, video=batch["video"],
      conditioning=recurrent_states, continue_from_previous_state=True,
      mutable="intermediates", rngs={"state_init": rng}, train=False,
      padding_mask=batch.get("padding_mask"))

  if "intermediates" in mutable_vars:
    preds["intermediates"] = flax.core.unfreeze(mutable_vars["intermediates"])

  return preds


def eval_step(
    model,
    state,
    batch,
    rng,
    p_eval_first_step,
    p_eval_continued_step,
    slice_size = None,
    slice_keys = None,
    conditioning_key = None,
    remove_from_predictions = None
):
  """Compute the metrics for the given model in inference mode.

  The model is applied to the inputs using all devices on the host. Afterwards
  metrics are averaged across *all* devices (of all hosts).

  Args:
    model: Model used in eval step.
    state: Replicated model state.
    batch: Inputs that should be evaluated.
    rng: PRNGKey for model forward pass.
    p_eval_first_step: A parallel version of the function eval_first_step.
    p_eval_continued_step: A parallel version of the function
      eval_continued_step.
    slice_size: Optional integer, if provided, evaluate the model on temporal
      slices of this size instead of on the full sequence length at once.
    slice_keys: Optional list of strings, the keys of the tensors which will be
      sliced if slice_size is provided.
    conditioning_key: Optional string. If provided, defines the batch key to be
      used as conditioning signal for the model. Otherwise this is inferred from
      the available keys in the batch.
    remove_from_predictions: Remove the provided keys. The default None removes
      "states" and "states_pred" from model output to save memory. Disable this
      if either of these are required in the loss function or for visualization.
  Returns:
    Model predictions.
  """
  if remove_from_predictions is None:
    remove_from_predictions = ["states", "states_pred"]

  seq_len = batch["video"].shape[2]
  # Sliced evaluation (i.e. on smaller temporal slices of the video).
  if slice_size is not None and slice_size < seq_len:
    num_slices = int(np.ceil(seq_len / slice_size))

    assert slice_keys is not None, (
        "Slice keys need to be provided for sliced evaluation.")

    preds_per_slice = []
    # Get predictions for first slice (with fresh recurrent state).
    batch_slice = utils.get_slices_along_axis(
        batch, slice_keys=slice_keys, start_idx=0, end_idx=slice_size)
    preds_slice = p_eval_first_step(model, state.variables,
                                    state.params, batch_slice, rng,
                                    conditioning_key)
    preds_slice = jax.tree_map(np.asarray, preds_slice)  # Copy to CPU.
    preds_per_slice.append(preds_slice)

    # Iterate over remaining slices (re-using the previous recurrent state).
    for slice_idx in range(1, num_slices):
      recurrent_states = preds_per_slice[-1]["states_pred"]
      batch_slice = utils.get_slices_along_axis(
          batch, slice_keys=slice_keys, start_idx=slice_idx * slice_size,
          end_idx=(slice_idx + 1) * slice_size)
      preds_slice = p_eval_continued_step(
          model, state.variables, state.params,
          batch_slice, rng, recurrent_states)
      preds_slice = jax.tree_map(np.asarray, preds_slice)  # Copy to CPU.
      preds_per_slice.append(preds_slice)

    # Remove states from predictions before concat to save memory.
    for k in remove_from_predictions:
      for i in range(num_slices):
        _ = preds_per_slice[i].pop(k, None)

    # Join predictions along sequence dimension.
    concat_fn = lambda _, *x: functools.partial(np.concatenate, axis=2)([*x])
    preds = jax.tree_map(concat_fn, preds_per_slice[0], *preds_per_slice)

    # Truncate to original sequence length.
    # NOTE: This op assumes that all predictions have a (complete) time axis.
    preds = jax.tree_map(lambda x: x[:, :, :seq_len], preds)

  # Evaluate on full sequence if no (or too large) slice size is provided.
  else:
    preds = p_eval_first_step(model, state.variables,
                              state.params, batch, rng,
                              conditioning_key)
    for k in remove_from_predictions:
      _ = preds.pop(k, None)

  return preds


def evaluate(
    model,
    state,
    eval_ds,
    loss_fn,
    eval_metrics_cls,
    predicted_max_num_instances,
    ground_truth_max_num_instances,
    slice_size = None,
    slice_keys = None,
    conditioning_key = None,
    remove_from_predictions = None,
    metrics_on_cpu = False,
):
  """Evaluate the model on the given dataset."""
  eval_metrics = None
  batch = None
  preds = None
  rng = state.rng[0]  # Get training state PRNGKey from first replica.

  if metrics_on_cpu and jax.process_count() > 1:
    raise NotImplementedError(
        "metrics_on_cpu feature cannot be used in a multi-host setup."
        " This experiment is using {} hosts.".format(jax.process_count()))
  metric_devices = jax.devices("cpu") if metrics_on_cpu else jax.devices()

  p_eval_first_step = jax.pmap(
      eval_first_step,
      axis_name="batch",
      static_broadcasted_argnums=(0, 5),
      devices=jax.devices())
  p_eval_continued_step = jax.pmap(
      eval_continued_step,
      axis_name="batch",
      static_broadcasted_argnums=(0),
      devices=jax.devices())
  p_get_eval_metrics = jax.pmap(
      get_eval_metrics,
      axis_name="batch",
      static_broadcasted_argnums=(2, 3, 4, 5),
      devices=metric_devices,
      backend="cpu" if metrics_on_cpu else None)

  def reshape_fn(x):
    """Function to reshape preds and batch before calling p_get_eval_metrics."""
    return np.reshape(x, [len(metric_devices), -1] + list(x.shape[2:]))

  for batch in eval_ds:
    rng, eval_rng = jax.random.split(rng)
    eval_rng = jax.random.fold_in(eval_rng, jax.host_id())  # Bind to host.
    eval_rngs = jax.random.split(eval_rng, jax.local_device_count())
    batch = jax.tree_map(np.asarray, batch)
    preds = eval_step(
        model=model,
        state=state,
        batch=batch,
        rng=eval_rngs,
        p_eval_first_step=p_eval_first_step,
        p_eval_continued_step=p_eval_continued_step,
        slice_size=slice_size,
        slice_keys=slice_keys,
        conditioning_key=conditioning_key,
        remove_from_predictions=remove_from_predictions)

    if metrics_on_cpu:
      # Reshape replica dim and batch-dims to work with metric_devices.
      preds = jax.tree_map(reshape_fn, preds)
      batch = jax.tree_map(reshape_fn, batch)
    # Get metric updates.
    update = p_get_eval_metrics(preds, batch, loss_fn, eval_metrics_cls,
                                predicted_max_num_instances,
                                ground_truth_max_num_instances)
    update = flax.jax_utils.unreplicate(update)
    eval_metrics = (
        update if eval_metrics is None else eval_metrics.merge(update))
  assert eval_metrics is not None
  return eval_metrics, batch, preds
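The sliced-evaluation path in eval_step above splits the time axis into ceil(seq_len / slice_size) chunks, re-concatenates the per-slice predictions, and truncates back to seq_len. A small standalone sketch of that indexing, using NumPy only and made-up example values (the real code slices batched tensors along axis 2):

import numpy as np

seq_len, slice_size = 6, 4
num_slices = int(np.ceil(seq_len / slice_size))  # -> 2
bounds = [(i * slice_size, (i + 1) * slice_size) for i in range(num_slices)]
# -> [(0, 4), (4, 8)]; the last slice overshoots and is truncated afterwards.
frames = np.arange(seq_len)
slices = [frames[start:end] for start, end in bounds]
rejoined = np.concatenate(slices)[:seq_len]
assert np.array_equal(rejoined, frames)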
invariant_slot_attention/lib/input_pipeline.py
ADDED
@@ -0,0 +1,390 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline for TFDS datasets."""

import functools
import os
from typing import Dict, List, Tuple

from clu import deterministic_data
from clu import preprocess_spec

import jax
import jax.numpy as jnp
import ml_collections

import sunds
import tensorflow as tf
import tensorflow_datasets as tfds

from invariant_slot_attention.lib import preprocessing

Array = jnp.ndarray
PRNGKey = Array


PATH_CLEVR_WITH_MASKS = "gs://multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords"
FEATURES_CLEVR_WITH_MASKS = {
    "image": tf.io.FixedLenFeature([240, 320, 3], tf.string),
    "mask": tf.io.FixedLenFeature([11, 240, 320, 1], tf.string),
    "x": tf.io.FixedLenFeature([11], tf.float32),
    "y": tf.io.FixedLenFeature([11], tf.float32),
    "z": tf.io.FixedLenFeature([11], tf.float32),
    "pixel_coords": tf.io.FixedLenFeature([11, 3], tf.float32),
    "rotation": tf.io.FixedLenFeature([11], tf.float32),
    "size": tf.io.FixedLenFeature([11], tf.string),
    "material": tf.io.FixedLenFeature([11], tf.string),
    "shape": tf.io.FixedLenFeature([11], tf.string),
    "color": tf.io.FixedLenFeature([11], tf.string),
    "visibility": tf.io.FixedLenFeature([11], tf.float32),
}

PATH_TETROMINOES = "gs://multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords"
FEATURES_TETROMINOES = {
    "image": tf.io.FixedLenFeature([35, 35, 3], tf.string),
    "mask": tf.io.FixedLenFeature([4, 35, 35, 1], tf.string),
    "x": tf.io.FixedLenFeature([4], tf.float32),
    "y": tf.io.FixedLenFeature([4], tf.float32),
    "shape": tf.io.FixedLenFeature([4], tf.float32),
    "color": tf.io.FixedLenFeature([4, 3], tf.float32),
    "visibility": tf.io.FixedLenFeature([4], tf.float32),
}

PATH_OBJECTS_ROOM = "gs://multi-object-datasets/objects_room/objects_room_train.tfrecords"
FEATURES_OBJECTS_ROOM = {
    "image": tf.io.FixedLenFeature([64, 64, 3], tf.string),
    "mask": tf.io.FixedLenFeature([7, 64, 64, 1], tf.string),
}

PATH_WAYMO_OPEN = "datasets/waymo_v_1_4_0_images/tfrecords"

FEATURES_WAYMO_OPEN = {
    "image": tf.io.FixedLenFeature([128, 192, 3], tf.string),
    "segmentations": tf.io.FixedLenFeature([128, 192], tf.string),
    "depth": tf.io.FixedLenFeature([128, 192], tf.float32),
    "num_objects": tf.io.FixedLenFeature([1], tf.int64),
    "has_mask": tf.io.FixedLenFeature([1], tf.int64),
    "camera": tf.io.FixedLenFeature([1], tf.int64),
}


def _decode_tetrominoes(example_proto):
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_TETROMINOES)
  for k in ["mask", "image"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  return single_example


def _decode_objects_room(example_proto):
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_OBJECTS_ROOM)
  for k in ["mask", "image"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  return single_example


def _decode_clevr_with_masks(example_proto):
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_CLEVR_WITH_MASKS)
  for k in ["mask", "image", "color", "material", "shape", "size"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  return single_example


def _decode_waymo_open(example_proto):
  """Unserializes a serialized tf.train.Example sample."""
  single_example = tf.io.parse_single_example(
      example_proto, FEATURES_WAYMO_OPEN)
  for k in ["image", "segmentations"]:
    single_example[k] = tf.squeeze(
        tf.io.decode_raw(single_example[k], tf.uint8), axis=-1)
  single_example["segmentations"] = tf.expand_dims(
      single_example["segmentations"], axis=-1)
  single_example["depth"] = tf.expand_dims(
      single_example["depth"], axis=-1)
  return single_example


def _preprocess_minimal(example):
  return {
      "image": example["image"],
      "segmentations": tf.cast(tf.argmax(example["mask"], axis=0), tf.uint8),
  }


def _sunds_create_task():
  """Create a sunds task to return images and instance segmentation."""
  return sunds.tasks.Nerf(
      yield_mode=sunds.tasks.YieldMode.IMAGE,
      additional_camera_specs={
          "depth_image": False,  # Not available in the dataset.
          "category_image": False,  # Not available in the dataset.
          "instance_image": True,
          "extrinsics": True,
      },
      additional_frame_specs={"pose": True},
      add_name=True
  )


def preprocess_example(features,
                       preprocess_strs):
  """Processes a single data example.

  Args:
    features: A dictionary containing the tensors of a single data example.
    preprocess_strs: List of strings, describing one preprocessing operation
      each, in clu.preprocess_spec format.

  Returns:
    Dictionary containing the preprocessed tensors of a single data example.
  """
  all_ops = preprocessing.all_ops()
  preprocess_fn = preprocess_spec.parse("|".join(preprocess_strs), all_ops)
  return preprocess_fn(features)  # pytype: disable=bad-return-type  # allow-recursive-types


def get_batch_dims(global_batch_size):
  """Gets the first two axis sizes for data batches.

  Args:
    global_batch_size: Integer, the global batch size (across all devices).

  Returns:
    List of batch dimensions

  Raises:
    ValueError if the requested dimensions don't make sense with the
      number of devices.
  """
  num_local_devices = jax.local_device_count()
  if global_batch_size % jax.host_count() != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisble with {jax.host_count()}.")
  per_host_batch_size = global_batch_size // jax.host_count()
  if per_host_batch_size % num_local_devices != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible with {jax.host_count()} hosts with a per host "
                     f"batch size of {per_host_batch_size} and "
                     f"{num_local_devices} local devices. ")
  return [num_local_devices, per_host_batch_size // num_local_devices]


def create_datasets(
    config,
    data_rng):
  """Create datasets for training and evaluation.

  For the same data_rng and config this will return the same datasets. The
  datasets only contain stateless operations.

  Args:
    config: Configuration to use.
    data_rng: JAX PRNGKey for dataset pipeline.

  Returns:
    A tuple with the training dataset and the evaluation dataset.
  """

  if config.data.dataset_name == "tetrominoes":
    ds = tf.data.TFRecordDataset(
        PATH_TETROMINOES,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_tetrominoes,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class TetrominoesBuilder:
      """Builder for tentrominoes dataset."""

      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        """Simple function to conform to the builder api."""
        if split == "train":
          # We use 512 training examples.
          ds = ds.skip(100)
          ds = ds.take(512)
          return tf.data.experimental.assert_cardinality(512)(ds)
        elif split == "validation":
          # 100 validation examples.
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = TetrominoesBuilder()
  elif config.data.dataset_name == "objects_room":
    ds = tf.data.TFRecordDataset(
        PATH_OBJECTS_ROOM,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_objects_room,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class ObjectsRoomBuilder:
      """Builder for objects room dataset."""

      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        """Simple function to conform to the builder api."""
        if split == "train":
          # 1M - 100 training examples.
          ds = ds.skip(100)
          return tf.data.experimental.assert_cardinality(999900)(ds)
        elif split == "validation":
          # 100 validation examples.
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = ObjectsRoomBuilder()
  elif config.data.dataset_name == "clevr_with_masks":
    ds = tf.data.TFRecordDataset(
        PATH_CLEVR_WITH_MASKS,
        compression_type="GZIP", buffer_size=2*(2**20))
    ds = ds.map(_decode_clevr_with_masks,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_preprocess_minimal,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    class CLEVRWithMasksBuilder:
      def as_dataset(self, split, *unused_args, ds=ds, **unused_kwargs):
        if split == "train":
          ds = ds.skip(100)
          return tf.data.experimental.assert_cardinality(99900)(ds)
        elif split == "validation":
          ds = ds.take(100)
          return tf.data.experimental.assert_cardinality(100)(ds)
        else:
          raise ValueError("Invalid split.")

    dataset_builder = CLEVRWithMasksBuilder()
  elif config.data.dataset_name == "waymo_open":
    train_path = os.path.join(
        PATH_WAYMO_OPEN, "training/camera_1/*tfrecords*")
    eval_path = os.path.join(
        PATH_WAYMO_OPEN, "validation/camera_1/*tfrecords*")

    train_files = tf.data.Dataset.list_files(train_path)
    eval_files = tf.data.Dataset.list_files(eval_path)

    train_data_reader = functools.partial(
        tf.data.TFRecordDataset,
        compression_type="ZLIB", buffer_size=2*(2**20))
    eval_data_reader = functools.partial(
        tf.data.TFRecordDataset,
        compression_type="ZLIB", buffer_size=2*(2**20))

    train_dataset = train_files.interleave(
        train_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_dataset = eval_files.interleave(
        eval_data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    train_dataset = train_dataset.map(
        _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_dataset = eval_dataset.map(
        _decode_waymo_open, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # We need to set the dataset cardinality. We assume we have
    # the full dataset.
    train_dataset = train_dataset.apply(
        tf.data.experimental.assert_cardinality(158081))

    class WaymoOpenBuilder:
      def as_dataset(self, split, *unused_args, **unused_kwargs):
        if split == "train":
          return train_dataset
        elif split == "validation":
          return eval_dataset
        else:
          raise ValueError("Invalid split.")

    dataset_builder = WaymoOpenBuilder()
  elif config.data.dataset_name == "multishapenet_easy":
    dataset_builder = sunds.builder(
        name=config.get("tfds_name", "msn_easy"),
        data_dir=config.get(
            "data_dir", "gs://kubric-public/tfds"),
        try_gcs=True)
    dataset_builder.as_dataset = functools.partial(
        dataset_builder.as_dataset, task=_sunds_create_task())
  elif config.data.dataset_name == "tfds":
    dataset_builder = tfds.builder(
        config.data.tfds_name, data_dir=config.data.data_dir)
  else:
    raise ValueError("Please specify a valid dataset name.")

  batch_dims = get_batch_dims(config.batch_size)

  train_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_train)
  eval_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_eval)

  train_split_name = config.get("train_split", "train")
  eval_split_name = config.get("validation_split", "validation")

  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split_name,
      rng=data_rng,
      preprocess_fn=train_preprocess_fn,
      cache=False,
      shuffle_buffer_size=config.data.shuffle_buffer_size,
      batch_dims=batch_dims,
      num_epochs=None,
      shuffle=True)

  if config.data.dataset_name == "waymo_open":
    # We filter Waymo Open for empty segmentation masks.
    def filter_fn(features):
      unique_instances = tf.unique(
          tf.reshape(features[preprocessing.SEGMENTATIONS], (-1,)))[0]
      n_instances = tf.size(unique_instances, tf.int32)
      # n_instances == 1 means we only have the background.
      return 2 <= n_instances
  else:
    filter_fn = None

  eval_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split_name,
      rng=None,
      preprocess_fn=eval_preprocess_fn,
      filter_fn=filter_fn,
      cache=False,
      batch_dims=batch_dims,
      num_epochs=1,
      shuffle=False,
      pad_up_to_batches=None)

  if config.data.dataset_name == "waymo_open":
    # We filter Waymo Open for empty segmentation masks after preprocessing.
    # For the full dataset, we know how many we will end up with.
    eval_batch_size = batch_dims[0] * batch_dims[1]
    # We don't pad the last batch => floor.
    eval_num_batches = int(
        jnp.floor(1872 / eval_batch_size / jax.host_count()))
    eval_ds = eval_ds.apply(
        tf.data.experimental.assert_cardinality(
            eval_num_batches))

  return train_ds, eval_ds
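get_batch_dims above simply factors the global batch size into [local_devices, per_device_batch] after checking divisibility by the host and device counts. A toy sketch of the same arithmetic without JAX; the host and device counts here are made-up example values, not taken from the committed code:

def batch_dims(global_batch_size, num_hosts, num_local_devices):
  # Mirrors the divisibility checks in get_batch_dims above.
  assert global_batch_size % num_hosts == 0
  per_host = global_batch_size // num_hosts
  assert per_host % num_local_devices == 0
  return [num_local_devices, per_host // num_local_devices]

# Example: the configs' batch size of 64 on one host with 8 local devices.
assert batch_dims(64, num_hosts=1, num_local_devices=8) == [8, 8]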
invariant_slot_attention/lib/losses.py
ADDED
@@ -0,0 +1,295 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loss functions."""

import functools
import inspect
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
import ml_collections

_LOSS_FUNCTIONS = {}

Array = Any  # jnp.ndarray somehow doesn't work anymore for pytype.
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ArrayDict = Dict[str, Array]
DictTree = Dict[str, Union[Array, "DictTree"]]  # pytype: disable=not-supported-yet
PRNGKey = Array
LossFn = Callable[[Dict[str, ArrayTree], Dict[str, ArrayTree]],
                  Tuple[Array, ArrayTree]]
ConfigAttr = Any
MetricSpec = Dict[str, str]


def standardize_loss_config(
    loss_config
):
  """Standardize loss configs into a common ConfigDict format.

  Args:
    loss_config: List of strings or ConfigDict specifying loss configuration.
      Valid input formats are: - Option 1 (list of strings), for example,
      `loss_config = ["box", "presence"]` - Option 2 (losses with weights
      only), for example,
      `loss_config = ConfigDict({"box": 5, "presence": 2})` - Option 3
      (losses with weights and other parameters), for example,
      `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"},
      "presence": {"weight": 2}})`

  Returns:
    Standardized ConfigDict containing the loss configuration.

  Raises:
    ValueError: If loss_config is a list that contains non-string entries.
  """

  if isinstance(loss_config, Sequence):  # Option 1
    if not all(isinstance(loss_type, str) for loss_type in loss_config):
      raise ValueError(f"Loss types all need to be str but got {loss_config}")
    return ml_collections.FrozenConfigDict({k: {} for k in loss_config})

  # Convert all option-2-style weights to option-3-style dictionaries.
  loss_config = {
      k: {
          "weight": v
      } if isinstance(v, (float, int)) else v for k, v in loss_config.items()
  }
  return ml_collections.FrozenConfigDict(loss_config)


def update_loss_aux(loss_aux, update):
  existing_keys = set(update.keys()).intersection(loss_aux.keys())
  if existing_keys:
    raise KeyError(
        f"Can't overwrite existing keys in loss_aux: {existing_keys}")
  loss_aux.update(update)


def compute_full_loss(
    preds, targets,
    loss_config
):
  """Loss function that parses and combines weighted loss terms.

  Args:
    preds: Dictionary of tensors containing model predictions.
    targets: Dictionary of tensors containing prediction targets.
    loss_config: List of strings or ConfigDict specifying loss configuration.
      See @register_loss decorated functions below for valid loss names.
      Valid losses formats are: - Option 1 (list of strings), for example,
      `loss_config = ["box", "presence"]` - Option 2 (losses with weights
      only), for example,
      `loss_config = ConfigDict({"box": 5, "presence": 2})` - Option 3 (losses
      with weights and other parameters), for example,
      `loss_config = ConfigDict({"box": {"weight": 5, "metric": "l1"},
      "presence": {"weight": 2}})` - Option 4 (like
      3 but decoupling name and loss_type), for
      example,
      `loss_config = ConfigDict({"recon_flow": {"loss_type": "recon",
      "key": "flow"},
      "recon_video": {"loss_type": "recon",
      "key": "video"}})`

  Returns:
    A 2-tuple of the sum of all individual loss terms and a dictionary of
    auxiliary losses and metrics.
  """

  loss = jnp.zeros([], jnp.float32)
  loss_aux = {}
  loss_config = standardize_loss_config(loss_config)
  for loss_name, cfg in loss_config.items():
    context_kwargs = {"preds": preds, "targets": targets}
    weight, loss_term, loss_aux_update = compute_loss_term(
        loss_name=loss_name, context_kwargs=context_kwargs, config_kwargs=cfg)

    unweighted_loss = jnp.mean(loss_term)
    loss += weight * unweighted_loss
    loss_aux_update[loss_name + "_value"] = unweighted_loss
    loss_aux_update[loss_name + "_weight"] = jnp.ones_like(unweighted_loss)
    update_loss_aux(loss_aux, loss_aux_update)
  return loss, loss_aux


def register_loss(func=None,
                  *,
                  name = None,
                  check_unused_kwargs = True):
  """Decorator for registering a loss function.

  Can be used without arguments:
  ```
  @register_loss
  def my_loss(**_):
    return 0
  ```
  or with keyword arguments:
  ```
  @register_loss(name="my_renamed_loss")
  def my_loss(**_):
    return 0
  ```

  Loss functions may accept
    - context kwargs: `preds` and `targets`
    - config kwargs: any argument specified in the config
    - the special `config_kwargs` parameter that contains the entire loss config
  Loss functions also _need_ to accept a **kwarg argument to support extending
  the interface.
  They should return either:
    - just the computed loss (pre-reduction)
    - or a tuple of the computed loss and a loss_aux_updates dict

  Args:
    func: the decorated function
    name (str): Optional name to be used for this loss in the config. Defaults
      to the name of the function.
    check_unused_kwargs (bool): By default compute_loss_term raises an error if
      there are any unused config kwargs. If this flag is set to False that step
      is skipped. This is useful if the config_kwargs should be passed onward to
      another function.

  Returns:
    The decorated function (or a partial of the decorator)
  """
  # If this decorator has been called with parameters but no function, then we
  # return the decorator again (but with partially filled parameters).
  # This allows using both @register_loss and @register_loss(name="foo")
  if func is None:
    return functools.partial(
        register_loss, name=name, check_unused_kwargs=check_unused_kwargs)

  # No (further) arguments: this is the actual decorator
  # ensure that the loss function includes a **kwargs argument
  loss_name = name if name is not None else func.__name__
  if not any(v.kind == inspect.Parameter.VAR_KEYWORD
             for k, v in inspect.signature(func).parameters.items()):
    raise TypeError(
        f"Loss function '{loss_name}' needs to include a **kwargs argument")
  func.name = loss_name
  func.check_unused_kwargs = check_unused_kwargs
  _LOSS_FUNCTIONS[loss_name] = func
  return func


def compute_loss_term(
    loss_name, context_kwargs,
    config_kwargs):
  """Compute a loss function given its config and context parameters.

  Takes care of:
    - finding the correct loss function based on "loss_type" or name
    - the optional "weight" parameter
    - checking for typos and collisions in config parameters
    - adding the optional loss_aux_updates if omitted by the loss_fn

  Args:
    loss_name: Name of the loss, i.e. its key in the config.losses dict.
    context_kwargs: Dictionary of context variables (`preds` and `targets`)
    config_kwargs: The config dict for this loss.

  Returns:
    1. the loss weight (float)
    2. loss term (Array)
    3. loss aux updates (Dict[str, Array])

  Raises:
    KeyError:
      Unknown loss_type
    KeyError:
      Unused config entries, i.e. not used by the loss function.
      Not raised if using @register_loss(check_unused_kwargs=False)
    KeyError: Config entry with a name that conflicts with a context_kwarg
    ValueError: Non-numerical weight in config_kwargs

  """

  # Make a dict copy of config_kwargs
  kwargs = {k: v for k, v in config_kwargs.items()}

  # Get the loss function
  loss_type = kwargs.pop("loss_type", loss_name)
  if loss_type not in _LOSS_FUNCTIONS:
    raise KeyError(f"Unknown loss_type '{loss_type}'.")
  loss_fn = _LOSS_FUNCTIONS[loss_type]

  # Take care of "weight" term
  weight = kwargs.pop("weight", 1.0)
  if not isinstance(weight, (int, float)):
    raise ValueError(f"Weight for loss {loss_name} should be a number, "
                     f"but was {weight}.")

  # Check for unused config entries (to prevent typos etc.)
  config_keys = set(kwargs)
  if loss_fn.check_unused_kwargs:
    param_names = set(inspect.signature(loss_fn).parameters)
    unused_config_keys = config_keys - param_names
    if unused_config_keys:
      raise KeyError(f"Unrecognized config entries {unused_config_keys} "
                     f"for loss {loss_name}.")

  # Check for key collisions between context and config
  conflicting_config_keys = config_keys.intersection(context_kwargs)
  if conflicting_config_keys:
    raise KeyError(f"The config keys {conflicting_config_keys} conflict "
                   f"with the context parameters ({context_kwargs.keys()}) "
                   f"for loss {loss_name}.")

  # Construct the arguments for the loss function
  kwargs.update(context_kwargs)
  kwargs["config_kwargs"] = config_kwargs

  # Call loss
  results = loss_fn(**kwargs)

  # Add empty loss_aux_updates if necessary
  if isinstance(results, Tuple):
    loss, loss_aux_update = results
  else:
    loss, loss_aux_update = results, {}

  return weight, loss, loss_aux_update


# -------- Loss functions --------
@register_loss
def recon(preds,
          targets,
          key = "video",
          reduction_type = "sum",
          **_):
  """Reconstruction loss (MSE)."""
  squared_l2_norm_fn = jax.vmap(functools.partial(
|
278 |
+
squared_l2_norm, reduction_type=reduction_type))
|
279 |
+
targets = targets[key]
|
280 |
+
loss = squared_l2_norm_fn(preds["outputs"][key], targets)
|
281 |
+
if reduction_type == "mean":
|
282 |
+
# This rescaling reflects taking the sum over feature axis &
|
283 |
+
# mean over space/time axes.
|
284 |
+
loss *= targets.shape[-1] # pytype: disable=attribute-error # allow-recursive-types
|
285 |
+
return jnp.mean(loss)
|
286 |
+
|
287 |
+
|
288 |
+
def squared_l2_norm(preds, targets,
|
289 |
+
reduction_type = "sum"):
|
290 |
+
if reduction_type == "sum":
|
291 |
+
return jnp.sum(jnp.square(preds - targets))
|
292 |
+
elif reduction_type == "mean":
|
293 |
+
return jnp.mean(jnp.square(preds - targets))
|
294 |
+
else:
|
295 |
+
raise ValueError(f"Unsupported reduction_type: {reduction_type}")
|
invariant_slot_attention/lib/metrics.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Clustering metrics."""
|
17 |
+
|
18 |
+
from typing import Optional, Sequence, Union
|
19 |
+
|
20 |
+
from clu import metrics
|
21 |
+
import flax
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
Ndarray = Union[np.ndarray, jnp.ndarray]
|
27 |
+
|
28 |
+
|
29 |
+
def check_shape(x, expected_shape, name):
|
30 |
+
"""Check whether shape x is as expected.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
x: Any data type with `shape` attribute. If `shape` attribute is not present
|
34 |
+
it is assumed to be a scalar with shape ().
|
35 |
+
expected_shape: The shape that is expected of x. For example,
|
36 |
+
[None, None, 3] can be the `expected_shape` for a color image,
|
37 |
+
[4, None, None, 3] if we know that batch size is 4.
|
38 |
+
name: Name of `x` to provide informative error messages.
|
39 |
+
|
40 |
+
Raises: ValueError if x's shape does not match expected_shape. Also raises
|
41 |
+
ValueError if expected_shape is not a list or tuple.
|
42 |
+
"""
|
43 |
+
if not isinstance(expected_shape, (list, tuple)):
|
44 |
+
raise ValueError(
|
45 |
+
"expected_shape should be a list or tuple of ints but got "
|
46 |
+
f"{expected_shape}.")
|
47 |
+
|
48 |
+
# Scalars have shape () by definition.
|
49 |
+
shape = getattr(x, "shape", ())
|
50 |
+
|
51 |
+
if (len(shape) != len(expected_shape) or
|
52 |
+
any(j is not None and i != j for i, j in zip(shape, expected_shape))):
|
53 |
+
raise ValueError(
|
54 |
+
f"Input {name} had shape {shape} but {expected_shape} was expected.")
|
55 |
+
|
56 |
+
|
57 |
+
def _validate_inputs(predicted_segmentations,
|
58 |
+
ground_truth_segmentations,
|
59 |
+
padding_mask,
|
60 |
+
mask = None):
|
61 |
+
"""Checks that all inputs have the expected shapes.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
predicted_segmentations: An array of integers of shape [bs, seq_len, H, W]
|
65 |
+
containing model segmentation predictions.
|
66 |
+
ground_truth_segmentations: An array of integers of shape [bs, seq_len, H,
|
67 |
+
W] containing ground truth segmentations.
|
68 |
+
padding_mask: An array of integers of shape [bs, seq_len, H, W] defining
|
69 |
+
regions where the ground truth is meaningless, for example because this
|
70 |
+
corresponds to regions which were padded during data augmentation. Value 0
|
71 |
+
corresponds to padded regions, 1 corresponds to valid regions to be used
|
72 |
+
for metric calculation.
|
73 |
+
mask: An optional array of boolean mask values of shape [bs]. `True`
|
74 |
+
corresponds to actual batch examples whereas `False` corresponds to
|
75 |
+
padding.
|
76 |
+
|
77 |
+
Raises:
|
78 |
+
ValueError if the inputs are not valid.
|
79 |
+
"""
|
80 |
+
|
81 |
+
check_shape(
|
82 |
+
predicted_segmentations, [None, None, None, None],
|
83 |
+
"predicted_segmentations [bs, seq_len, h, w]")
|
84 |
+
check_shape(
|
85 |
+
ground_truth_segmentations, [None, None, None, None],
|
86 |
+
"ground_truth_segmentations [bs, seq_len, h, w]")
|
87 |
+
check_shape(
|
88 |
+
predicted_segmentations, ground_truth_segmentations.shape,
|
89 |
+
"predicted_segmentations [should match ground_truth_segmentations]")
|
90 |
+
check_shape(
|
91 |
+
padding_mask, ground_truth_segmentations.shape,
|
92 |
+
"padding_mask [should match ground_truth_segmentations]")
|
93 |
+
|
94 |
+
if not jnp.issubdtype(predicted_segmentations.dtype, jnp.integer):
|
95 |
+
raise ValueError("predicted_segmentations has to be integer-valued. "
|
96 |
+
"Got {}".format(predicted_segmentations.dtype))
|
97 |
+
|
98 |
+
if not jnp.issubdtype(ground_truth_segmentations.dtype, jnp.integer):
|
99 |
+
raise ValueError("ground_truth_segmentations has to be integer-valued. "
|
100 |
+
"Got {}".format(ground_truth_segmentations.dtype))
|
101 |
+
|
102 |
+
if not jnp.issubdtype(padding_mask.dtype, jnp.integer):
|
103 |
+
raise ValueError("padding_mask has to be integer-valued. "
|
104 |
+
"Got {}".format(padding_mask.dtype))
|
105 |
+
|
106 |
+
if mask is not None:
|
107 |
+
check_shape(mask, [None], "mask [bs]")
|
108 |
+
if not jnp.issubdtype(mask.dtype, jnp.bool_):
|
109 |
+
raise ValueError("mask has to be boolean. Got {}".format(mask.dtype))
|
110 |
+
|
111 |
+
|
112 |
+
def adjusted_rand_index(true_ids, pred_ids,
|
113 |
+
num_instances_true, num_instances_pred,
|
114 |
+
padding_mask = None,
|
115 |
+
ignore_background = False):
|
116 |
+
"""Computes the adjusted Rand index (ARI), a clustering similarity score.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
true_ids: An integer-valued array of shape
|
120 |
+
[batch_size, seq_len, H, W]. The true cluster assignment encoded
|
121 |
+
as integer ids.
|
122 |
+
pred_ids: An integer-valued array of shape
|
123 |
+
[batch_size, seq_len, H, W]. The predicted cluster assignment
|
124 |
+
encoded as integer ids.
|
125 |
+
num_instances_true: An integer, the number of instances in true_ids
|
126 |
+
(i.e. max(true_ids) + 1).
|
127 |
+
num_instances_pred: An integer, the number of instances in true_ids
|
128 |
+
(i.e. max(pred_ids) + 1).
|
129 |
+
padding_mask: An array of integers of shape [batch_size, seq_len, H, W]
|
130 |
+
defining regions where the ground truth is meaningless, for example
|
131 |
+
because this corresponds to regions which were padded during data
|
132 |
+
augmentation. Value 0 corresponds to padded regions, 1 corresponds to
|
133 |
+
valid regions to be used for metric calculation.
|
134 |
+
ignore_background: Boolean, if True, then ignore all pixels where
|
135 |
+
true_ids == 0 (default: False).
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
ARI scores as a float32 array of shape [batch_size].
|
139 |
+
|
140 |
+
References:
|
141 |
+
Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions"
|
142 |
+
https://link.springer.com/article/10.1007/BF01908075
|
143 |
+
Wikipedia
|
144 |
+
https://en.wikipedia.org/wiki/Rand_index
|
145 |
+
Scikit Learn
|
146 |
+
http://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
|
147 |
+
"""
|
148 |
+
# pylint: disable=invalid-name
|
149 |
+
true_oh = jax.nn.one_hot(true_ids, num_instances_true)
|
150 |
+
pred_oh = jax.nn.one_hot(pred_ids, num_instances_pred)
|
151 |
+
if padding_mask is not None:
|
152 |
+
true_oh = true_oh * padding_mask[Ellipsis, None]
|
153 |
+
# pred_oh = pred_oh * padding_mask[..., None] # <-- not needed
|
154 |
+
|
155 |
+
if ignore_background:
|
156 |
+
true_oh = true_oh[Ellipsis, 1:] # Remove the background row.
|
157 |
+
|
158 |
+
N = jnp.einsum("bthwc,bthwk->bck", true_oh, pred_oh)
|
159 |
+
A = jnp.sum(N, axis=-1) # row-sum (batch_size, c)
|
160 |
+
B = jnp.sum(N, axis=-2) # col-sum (batch_size, k)
|
161 |
+
num_points = jnp.sum(A, axis=1)
|
162 |
+
|
163 |
+
rindex = jnp.sum(N * (N - 1), axis=[1, 2])
|
164 |
+
aindex = jnp.sum(A * (A - 1), axis=1)
|
165 |
+
bindex = jnp.sum(B * (B - 1), axis=1)
|
166 |
+
expected_rindex = aindex * bindex / jnp.clip(num_points * (num_points-1), 1)
|
167 |
+
max_rindex = (aindex + bindex) / 2
|
168 |
+
denominator = max_rindex - expected_rindex
|
169 |
+
ari = (rindex - expected_rindex) / denominator
|
170 |
+
|
171 |
+
# There are two cases for which the denominator can be zero:
|
172 |
+
# 1. If both label_pred and label_true assign all pixels to a single cluster.
|
173 |
+
# (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
|
174 |
+
# 2. If both label_pred and label_true assign max 1 point to each cluster.
|
175 |
+
# (max_rindex == expected_rindex == rindex == 0)
|
176 |
+
# In both cases, we want the ARI score to be 1.0:
|
177 |
+
return jnp.where(denominator, ari, 1.0)
|
178 |
+
|
179 |
+
|
180 |
+
@flax.struct.dataclass
|
181 |
+
class Ari(metrics.Average):
|
182 |
+
"""Adjusted Rand Index (ARI) computed from predictions and labels.
|
183 |
+
|
184 |
+
ARI is a similarity score to compare two clusterings. ARI returns values in
|
185 |
+
the range [-1, 1], where 1 corresponds to two identical clusterings (up to
|
186 |
+
permutation), i.e. a perfect match between the predicted clustering and the
|
187 |
+
ground-truth clustering. A value of (close to) 0 corresponds to chance.
|
188 |
+
Negative values corresponds to cases where the agreement between the
|
189 |
+
clusterings is less than expected from a random assignment.
|
190 |
+
|
191 |
+
In this implementation, we use ARI to compare predicted instance segmentation
|
192 |
+
masks (including background prediction) with ground-truth segmentation
|
193 |
+
annotations.
|
194 |
+
"""
|
195 |
+
|
196 |
+
@classmethod
|
197 |
+
def from_model_output(cls,
|
198 |
+
predicted_segmentations,
|
199 |
+
ground_truth_segmentations,
|
200 |
+
padding_mask,
|
201 |
+
ground_truth_max_num_instances,
|
202 |
+
predicted_max_num_instances,
|
203 |
+
ignore_background = False,
|
204 |
+
mask = None,
|
205 |
+
**_):
|
206 |
+
"""Computation of the ARI clustering metric.
|
207 |
+
|
208 |
+
NOTE: This implementation does not currently support padding masks.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
predicted_segmentations: An array of integers of shape
|
212 |
+
[bs, seq_len, H, W] containing model segmentation predictions.
|
213 |
+
ground_truth_segmentations: An array of integers of shape
|
214 |
+
[bs, seq_len, H, W] containing ground truth segmentations.
|
215 |
+
padding_mask: An array of integers of shape [bs, seq_len, H, W]
|
216 |
+
defining regions where the ground truth is meaningless, for example
|
217 |
+
because this corresponds to regions which were padded during data
|
218 |
+
augmentation. Value 0 corresponds to padded regions, 1 corresponds to
|
219 |
+
valid regions to be used for metric calculation.
|
220 |
+
ground_truth_max_num_instances: Maximum number of instances (incl.
|
221 |
+
background, which counts as the 0-th instance) possible in the dataset.
|
222 |
+
predicted_max_num_instances: Maximum number of predicted instances (incl.
|
223 |
+
background).
|
224 |
+
ignore_background: If True, then ignore all pixels where
|
225 |
+
ground_truth_segmentations == 0 (default: False).
|
226 |
+
mask: An optional array of boolean mask values of shape [bs]. `True`
|
227 |
+
corresponds to actual batch examples whereas `False` corresponds to
|
228 |
+
padding.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
Object of Ari with computed intermediate values.
|
232 |
+
"""
|
233 |
+
_validate_inputs(
|
234 |
+
predicted_segmentations=predicted_segmentations,
|
235 |
+
ground_truth_segmentations=ground_truth_segmentations,
|
236 |
+
padding_mask=padding_mask,
|
237 |
+
mask=mask)
|
238 |
+
|
239 |
+
batch_size = predicted_segmentations.shape[0]
|
240 |
+
|
241 |
+
if mask is None:
|
242 |
+
mask = jnp.ones(batch_size, dtype=padding_mask.dtype)
|
243 |
+
else:
|
244 |
+
mask = jnp.asarray(mask, dtype=padding_mask.dtype)
|
245 |
+
|
246 |
+
ari_batch = adjusted_rand_index(
|
247 |
+
pred_ids=predicted_segmentations,
|
248 |
+
true_ids=ground_truth_segmentations,
|
249 |
+
num_instances_true=ground_truth_max_num_instances,
|
250 |
+
num_instances_pred=predicted_max_num_instances,
|
251 |
+
padding_mask=padding_mask,
|
252 |
+
ignore_background=ignore_background)
|
253 |
+
return cls(total=jnp.sum(ari_batch * mask), count=jnp.sum(mask)) # pylint: disable=unexpected-keyword-arg
|
254 |
+
|
255 |
+
|
256 |
+
@flax.struct.dataclass
|
257 |
+
class AriNoBg(Ari):
|
258 |
+
"""Adjusted Rand Index (ARI), ignoring the ground-truth background label."""
|
259 |
+
|
260 |
+
@classmethod
|
261 |
+
def from_model_output(cls, **kwargs):
|
262 |
+
"""See `Ari` docstring for allowed keyword arguments."""
|
263 |
+
return super().from_model_output(**kwargs, ignore_background=True)
|
invariant_slot_attention/lib/preprocessing.py
ADDED
@@ -0,0 +1,1236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Video preprocessing ops."""
|
17 |
+
|
18 |
+
import abc
|
19 |
+
import dataclasses
|
20 |
+
import functools
|
21 |
+
from typing import Optional, Sequence, Tuple, Union
|
22 |
+
|
23 |
+
from absl import logging
|
24 |
+
from clu import preprocess_spec
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import tensorflow as tf
|
28 |
+
|
29 |
+
from invariant_slot_attention.lib import transforms
|
30 |
+
|
31 |
+
Features = preprocess_spec.Features
|
32 |
+
all_ops = lambda: preprocess_spec.get_all_ops(__name__)
|
33 |
+
SEED_KEY = preprocess_spec.SEED_KEY
|
34 |
+
NOTRACK_BOX = (0., 0., 0., 0.) # No-track bounding box for padding.
|
35 |
+
NOTRACK_LABEL = -1
|
36 |
+
|
37 |
+
IMAGE = "image"
|
38 |
+
VIDEO = "video"
|
39 |
+
SEGMENTATIONS = "segmentations"
|
40 |
+
RAGGED_SEGMENTATIONS = "ragged_segmentations"
|
41 |
+
SPARSE_SEGMENTATIONS = "sparse_segmentations"
|
42 |
+
SHAPE = "shape"
|
43 |
+
PADDING_MASK = "padding_mask"
|
44 |
+
RAGGED_BOXES = "ragged_boxes"
|
45 |
+
BOXES = "boxes"
|
46 |
+
FRAMES = "frames"
|
47 |
+
FLOW = "flow"
|
48 |
+
DEPTH = "depth"
|
49 |
+
ORIGINAL_SIZE = "original_size"
|
50 |
+
INSTANCE_LABELS = "instance_labels"
|
51 |
+
INSTANCE_MULTI_LABELS = "instance_multi_labels"
|
52 |
+
BOXES_VIDEO = "boxes_video"
|
53 |
+
IMAGE_PADDING_MASK = "image_padding_mask"
|
54 |
+
VIDEO_PADDING_MASK = "video_padding_mask"
|
55 |
+
|
56 |
+
|
57 |
+
def convert_uint16_to_float(array, min_val, max_val):
|
58 |
+
return tf.cast(array, tf.float32) / 65535. * (max_val - min_val) + min_val
|
59 |
+
|
60 |
+
|
61 |
+
def get_resize_small_shape(original_size,
|
62 |
+
small_size):
|
63 |
+
h, w = original_size
|
64 |
+
ratio = (
|
65 |
+
tf.cast(small_size, tf.float32) / tf.cast(tf.minimum(h, w), tf.float32))
|
66 |
+
h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32)
|
67 |
+
w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32)
|
68 |
+
return h, w
|
69 |
+
|
70 |
+
|
71 |
+
def adjust_small_size(original_size,
|
72 |
+
small_size, max_size):
|
73 |
+
"""Computes the adjusted small size to ensure large side < max_size."""
|
74 |
+
h, w = original_size
|
75 |
+
min_original_size = tf.cast(tf.minimum(w, h), tf.float32)
|
76 |
+
max_original_size = tf.cast(tf.maximum(w, h), tf.float32)
|
77 |
+
if max_original_size / min_original_size * small_size > max_size:
|
78 |
+
small_size = tf.cast(tf.floor(
|
79 |
+
max_size * min_original_size / max_original_size), tf.int32)
|
80 |
+
return small_size
|
81 |
+
|
82 |
+
|
83 |
+
def crop_or_pad_boxes(boxes, top, left, height,
|
84 |
+
width, h_orig, w_orig):
|
85 |
+
"""Transforms the relative box coordinates according to the frame crop.
|
86 |
+
|
87 |
+
Note that, if height/width are larger than h_orig/w_orig, this function
|
88 |
+
implements the equivalent of padding.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
boxes: Tensor of bounding boxes with shape (..., 4).
|
92 |
+
top: Top of crop box in absolute pixel coordinates.
|
93 |
+
left: Left of crop box in absolute pixel coordinates.
|
94 |
+
height: Height of crop box in absolute pixel coordinates.
|
95 |
+
width: Width of crop box in absolute pixel coordinates.
|
96 |
+
h_orig: Original image height in absolute pixel coordinates.
|
97 |
+
w_orig: Original image width in absolute pixel coordinates.
|
98 |
+
Returns:
|
99 |
+
Boxes tensor with same shape as input boxes but updated values.
|
100 |
+
"""
|
101 |
+
# Video track bound boxes: [num_instances, num_tracks, 4]
|
102 |
+
# Image bounding boxes: [num_instances, 4]
|
103 |
+
assert boxes.shape[-1] == 4
|
104 |
+
seq_len = tf.shape(boxes)[0]
|
105 |
+
has_tracks = len(boxes.shape) == 3
|
106 |
+
if has_tracks:
|
107 |
+
num_tracks = boxes.shape[1]
|
108 |
+
else:
|
109 |
+
assert len(boxes.shape) == 2
|
110 |
+
num_tracks = 1
|
111 |
+
|
112 |
+
# Transform the box coordinates.
|
113 |
+
a = tf.cast(tf.stack([h_orig, w_orig]), tf.float32)
|
114 |
+
b = tf.cast(tf.stack([top, left]), tf.float32)
|
115 |
+
c = tf.cast(tf.stack([height, width]), tf.float32)
|
116 |
+
boxes = tf.reshape(
|
117 |
+
(tf.reshape(boxes, (seq_len, num_tracks, 2, 2)) * a - b) / c,
|
118 |
+
(seq_len, num_tracks, len(NOTRACK_BOX)))
|
119 |
+
|
120 |
+
# Filter the valid boxes.
|
121 |
+
boxes = tf.minimum(tf.maximum(boxes, 0.0), 1.0)
|
122 |
+
if has_tracks:
|
123 |
+
cond = tf.reduce_all((boxes[:, :, 2:] - boxes[:, :, :2]) > 0.0, axis=-1)
|
124 |
+
boxes = tf.where(cond[:, :, tf.newaxis], boxes, NOTRACK_BOX)
|
125 |
+
else:
|
126 |
+
boxes = tf.reshape(boxes, (seq_len, 4))
|
127 |
+
|
128 |
+
return boxes
|
129 |
+
|
130 |
+
|
131 |
+
def flow_tensor_to_rgb_tensor(motion_image, flow_scaling_factor=50.):
|
132 |
+
"""Visualizes flow motion image as an RGB image.
|
133 |
+
|
134 |
+
Similar as the flow_to_rgb function, but with tensors.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
motion_image: A tensor either of shape [batch_sz, height, width, 2] or of
|
138 |
+
shape [height, width, 2]. motion_image[..., 0] is flow in x and
|
139 |
+
motion_image[..., 1] is flow in y.
|
140 |
+
flow_scaling_factor: How much to scale flow for visualization.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
A visualization tensor with same shape as motion_image, except with three
|
144 |
+
channels. The dtype of the output is tf.uint8.
|
145 |
+
"""
|
146 |
+
|
147 |
+
hypot = lambda a, b: (a ** 2.0 + b ** 2.0) ** 0.5 # sqrt(a^2 + b^2)
|
148 |
+
|
149 |
+
height, width = motion_image.get_shape().as_list()[-3:-1] # pytype: disable=attribute-error # allow-recursive-types
|
150 |
+
scaling = flow_scaling_factor / hypot(height, width)
|
151 |
+
x, y = motion_image[Ellipsis, 0], motion_image[Ellipsis, 1]
|
152 |
+
motion_angle = tf.atan2(y, x)
|
153 |
+
motion_angle = (motion_angle / np.math.pi + 1.0) / 2.0
|
154 |
+
motion_magnitude = hypot(y, x)
|
155 |
+
motion_magnitude = tf.clip_by_value(motion_magnitude * scaling, 0.0, 1.0)
|
156 |
+
value_channel = tf.ones_like(motion_angle)
|
157 |
+
flow_hsv = tf.stack([motion_angle, motion_magnitude, value_channel], axis=-1)
|
158 |
+
flow_rgb = tf.image.convert_image_dtype(
|
159 |
+
tf.image.hsv_to_rgb(flow_hsv), tf.uint8)
|
160 |
+
return flow_rgb
|
161 |
+
|
162 |
+
|
163 |
+
def get_paddings(image_shape,
|
164 |
+
size,
|
165 |
+
pre_spatial_dim = None,
|
166 |
+
allow_crop = True):
|
167 |
+
"""Returns paddings tensors for tf.pad operation.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
image_shape: The shape of the Tensor to be padded. The shape can be
|
171 |
+
[..., N, H, W, C] or [..., H, W, C]. The paddings are computed for H, W
|
172 |
+
and optionally N dimensions.
|
173 |
+
size: The total size for the H and W dimensions to pad to.
|
174 |
+
pre_spatial_dim: Optional, additional padding dimension before the spatial
|
175 |
+
dimensions. It is only used if given and if len(shape) > 3.
|
176 |
+
allow_crop: If size is bigger than requested max size, padding will be
|
177 |
+
negative. If allow_crop is true, negative padding values will be set to 0.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
Paddings the given tensor shape.
|
181 |
+
"""
|
182 |
+
assert image_shape.shape.rank == 1
|
183 |
+
if isinstance(size, int):
|
184 |
+
size = (size, size)
|
185 |
+
h, w = image_shape[-3], image_shape[-2]
|
186 |
+
# Spatial padding.
|
187 |
+
paddings = [
|
188 |
+
tf.stack([0, size[0] - h]),
|
189 |
+
tf.stack([0, size[1] - w]),
|
190 |
+
tf.stack([0, 0])
|
191 |
+
]
|
192 |
+
ndims = len(image_shape) # pytype: disable=wrong-arg-types
|
193 |
+
# Prepend padding for temporal dimension or number of instances.
|
194 |
+
if pre_spatial_dim is not None and ndims > 3:
|
195 |
+
paddings = [[0, pre_spatial_dim - image_shape[-4]]] + paddings
|
196 |
+
# Prepend with non-padded dimensions if available.
|
197 |
+
if ndims > len(paddings):
|
198 |
+
paddings = [[0, 0]] * (ndims - len(paddings)) + paddings
|
199 |
+
if allow_crop:
|
200 |
+
paddings = tf.maximum(paddings, 0)
|
201 |
+
return tf.stack(paddings)
|
202 |
+
|
203 |
+
|
204 |
+
@dataclasses.dataclass
|
205 |
+
class VideoFromTfds:
|
206 |
+
"""Standardize features coming from TFDS video datasets."""
|
207 |
+
|
208 |
+
video_key: str = VIDEO
|
209 |
+
segmentations_key: str = SEGMENTATIONS
|
210 |
+
ragged_segmentations_key: str = RAGGED_SEGMENTATIONS
|
211 |
+
shape_key: str = SHAPE
|
212 |
+
padding_mask_key: str = PADDING_MASK
|
213 |
+
ragged_boxes_key: str = RAGGED_BOXES
|
214 |
+
boxes_key: str = BOXES
|
215 |
+
frames_key: str = FRAMES
|
216 |
+
instance_multi_labels_key: str = INSTANCE_MULTI_LABELS
|
217 |
+
flow_key: str = FLOW
|
218 |
+
depth_key: str = DEPTH
|
219 |
+
|
220 |
+
def __call__(self, features):
|
221 |
+
|
222 |
+
features_new = {}
|
223 |
+
|
224 |
+
if "rng" in features:
|
225 |
+
features_new[SEED_KEY] = features.pop("rng")
|
226 |
+
|
227 |
+
if "instances" in features:
|
228 |
+
features_new[self.ragged_boxes_key] = features["instances"]["bboxes"]
|
229 |
+
features_new[self.frames_key] = features["instances"]["bbox_frames"]
|
230 |
+
if "segmentations" in features["instances"]:
|
231 |
+
features_new[self.ragged_segmentations_key] = tf.cast(
|
232 |
+
features["instances"]["segmentations"][Ellipsis, 0], tf.int32)
|
233 |
+
|
234 |
+
# Special handling of CLEVR (https://arxiv.org/abs/1612.06890) objects.
|
235 |
+
if ("color" in features["instances"] and
|
236 |
+
"shape" in features["instances"] and
|
237 |
+
"material" in features["instances"]):
|
238 |
+
color = tf.cast(features["instances"]["color"], tf.int32)
|
239 |
+
shape = tf.cast(features["instances"]["shape"], tf.int32)
|
240 |
+
material = tf.cast(features["instances"]["material"], tf.int32)
|
241 |
+
features_new[self.instance_multi_labels_key] = tf.stack(
|
242 |
+
(color, shape, material), axis=-1)
|
243 |
+
|
244 |
+
if "segmentations" in features:
|
245 |
+
features_new[self.segmentations_key] = tf.cast(
|
246 |
+
features["segmentations"][Ellipsis, 0], tf.int32)
|
247 |
+
|
248 |
+
if "depth" in features:
|
249 |
+
# Undo float to uint16 scaling
|
250 |
+
if "metadata" in features and "depth_range" in features["metadata"]:
|
251 |
+
depth_range = features["metadata"]["depth_range"]
|
252 |
+
features_new[self.depth_key] = convert_uint16_to_float(
|
253 |
+
features["depth"], depth_range[0], depth_range[1])
|
254 |
+
|
255 |
+
if "flows" in features:
|
256 |
+
# Some datasets use "flows" instead of "flow" for optical flow.
|
257 |
+
features["flow"] = features["flows"]
|
258 |
+
if "backward_flow" in features:
|
259 |
+
# By default, use "backward_flow" if available.
|
260 |
+
features["flow"] = features["backward_flow"]
|
261 |
+
features["metadata"]["flow_range"] = features["metadata"][
|
262 |
+
"backward_flow_range"]
|
263 |
+
if "flow" in features:
|
264 |
+
# Undo float to uint16 scaling
|
265 |
+
flow_range = features["metadata"].get("flow_range", (-255, 255))
|
266 |
+
features_new[self.flow_key] = convert_uint16_to_float(
|
267 |
+
features["flow"], flow_range[0], flow_range[1])
|
268 |
+
|
269 |
+
# Convert video to float and normalize.
|
270 |
+
video = features["video"]
|
271 |
+
assert video.dtype == tf.uint8 # pytype: disable=attribute-error # allow-recursive-types
|
272 |
+
video = tf.image.convert_image_dtype(video, tf.float32)
|
273 |
+
features_new[self.video_key] = video
|
274 |
+
|
275 |
+
# Store original video shape (e.g. for correct evaluation metrics).
|
276 |
+
features_new[self.shape_key] = tf.shape(video)
|
277 |
+
|
278 |
+
# Store padding mask
|
279 |
+
features_new[self.padding_mask_key] = tf.cast(
|
280 |
+
tf.ones_like(video)[Ellipsis, 0], tf.uint8)
|
281 |
+
|
282 |
+
return features_new
|
283 |
+
|
284 |
+
|
285 |
+
@dataclasses.dataclass
|
286 |
+
class AddTemporalAxis:
|
287 |
+
"""Lift images to videos by adding a temporal axis at the beginning.
|
288 |
+
|
289 |
+
We need to distinguish two cases because `image_ops.py` uses
|
290 |
+
ORIGINAL_SIZE = [H,W] and `video_ops.py` uses SHAPE = [T,H,W,C]:
|
291 |
+
a) The features are fed from image ops: ORIGINAL_SIZE is converted
|
292 |
+
to SHAPE ([H,W] -> [1,H,W,C]) and removed from the features.
|
293 |
+
Typical use case: Evaluation of GV image tasks in a video setting. This op
|
294 |
+
is added after the image preprocessing in order not to change the standard
|
295 |
+
image preprocessing.
|
296 |
+
b) The features are fed from video ops: The image SHAPE is lifted to a video
|
297 |
+
SHAPE ([H,W,C] -> [1,H,W,C]).
|
298 |
+
Typical use case: Training using images in a video setting. This op is added
|
299 |
+
before the video preprocessing in order not to change the standard video
|
300 |
+
preprocessing.
|
301 |
+
"""
|
302 |
+
|
303 |
+
image_key: str = IMAGE
|
304 |
+
video_key: str = VIDEO
|
305 |
+
boxes_key: str = BOXES
|
306 |
+
padding_mask_key: str = PADDING_MASK
|
307 |
+
segmentations_key: str = SEGMENTATIONS
|
308 |
+
sparse_segmentations_key: str = SPARSE_SEGMENTATIONS
|
309 |
+
shape_key: str = SHAPE
|
310 |
+
original_size_key: str = ORIGINAL_SIZE
|
311 |
+
|
312 |
+
def __call__(self, features):
|
313 |
+
assert self.image_key in features
|
314 |
+
|
315 |
+
features_new = {}
|
316 |
+
for k, v in features.items():
|
317 |
+
if k == self.image_key:
|
318 |
+
features_new[self.video_key] = v[tf.newaxis]
|
319 |
+
elif k in (self.padding_mask_key, self.boxes_key, self.segmentations_key,
|
320 |
+
self.sparse_segmentations_key):
|
321 |
+
features_new[k] = v[tf.newaxis]
|
322 |
+
elif k == self.original_size_key:
|
323 |
+
pass # See comment in the docstring of the class.
|
324 |
+
else:
|
325 |
+
features_new[k] = v
|
326 |
+
|
327 |
+
if self.original_size_key in features:
|
328 |
+
# The features come from an image preprocessing pipeline.
|
329 |
+
shape = tf.concat([[1], features[self.original_size_key],
|
330 |
+
[features[self.image_key].shape[-1]]], # pytype: disable=attribute-error # allow-recursive-types
|
331 |
+
axis=0)
|
332 |
+
elif self.shape_key in features:
|
333 |
+
# The features come from a video preprocessing pipeline.
|
334 |
+
shape = tf.concat([[1], features[self.shape_key]], axis=0)
|
335 |
+
else:
|
336 |
+
shape = tf.shape(features_new[self.video_key])
|
337 |
+
features_new[self.shape_key] = shape
|
338 |
+
|
339 |
+
if self.padding_mask_key not in features_new:
|
340 |
+
features_new[self.padding_mask_key] = tf.cast(
|
341 |
+
tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8)
|
342 |
+
|
343 |
+
return features_new
|
344 |
+
|
345 |
+
|
346 |
+
@dataclasses.dataclass
|
347 |
+
class SparseToDenseAnnotation:
|
348 |
+
"""Converts the sparse to a dense representation."""
|
349 |
+
|
350 |
+
max_instances: int = 10
|
351 |
+
segmentations_key: str = SEGMENTATIONS
|
352 |
+
|
353 |
+
def __call__(self, features):
|
354 |
+
|
355 |
+
features_new = {}
|
356 |
+
|
357 |
+
for k, v in features.items():
|
358 |
+
|
359 |
+
if k == self.segmentations_key:
|
360 |
+
# Dense segmentations are available for this dataset. It may be that
|
361 |
+
# max_instances < max(features_new[self.segmentations_key]).
|
362 |
+
# We prune out extra objects here.
|
363 |
+
segmentations = v
|
364 |
+
segmentations = tf.where(
|
365 |
+
tf.less_equal(segmentations, self.max_instances), segmentations, 0)
|
366 |
+
features_new[self.segmentations_key] = segmentations
|
367 |
+
else:
|
368 |
+
features_new[k] = v
|
369 |
+
|
370 |
+
return features_new
|
371 |
+
|
372 |
+
|
373 |
+
class VideoPreprocessOp(abc.ABC):
|
374 |
+
"""Base class for all video preprocess ops."""
|
375 |
+
|
376 |
+
video_key: str = VIDEO
|
377 |
+
segmentations_key: str = SEGMENTATIONS
|
378 |
+
padding_mask_key: str = PADDING_MASK
|
379 |
+
boxes_key: str = BOXES
|
380 |
+
flow_key: str = FLOW
|
381 |
+
depth_key: str = DEPTH
|
382 |
+
sparse_segmentations_key: str = SPARSE_SEGMENTATIONS
|
383 |
+
|
384 |
+
def __call__(self, features):
|
385 |
+
# Get current video shape.
|
386 |
+
video_shape = tf.shape(features[self.video_key])
|
387 |
+
# Assemble all feature keys that the op should be applied on.
|
388 |
+
all_keys = [
|
389 |
+
self.video_key, self.segmentations_key, self.padding_mask_key,
|
390 |
+
self.flow_key, self.depth_key, self.sparse_segmentations_key,
|
391 |
+
self.boxes_key
|
392 |
+
]
|
393 |
+
# Apply the op to all features.
|
394 |
+
for key in all_keys:
|
395 |
+
if key in features:
|
396 |
+
features[key] = self.apply(features[key], key, video_shape)
|
397 |
+
return features
|
398 |
+
|
399 |
+
@abc.abstractmethod
|
400 |
+
def apply(self, tensor, key,
|
401 |
+
video_shape):
|
402 |
+
"""Returns the transformed tensor.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
tensor: Any of a set of different video modalites, e.g video, flow,
|
406 |
+
bounding boxes, etc.
|
407 |
+
key: a string that indicates what feature the tensor represents so that
|
408 |
+
the apply function can take that into account.
|
409 |
+
video_shape: The shape of the video (which is necessary for some
|
410 |
+
transformations).
|
411 |
+
"""
|
412 |
+
|
413 |
+
|
414 |
+
class RandomVideoPreprocessOp(VideoPreprocessOp):
|
415 |
+
"""Base class for all random video preprocess ops."""
|
416 |
+
|
417 |
+
def __call__(self, features):
|
418 |
+
if features.get(SEED_KEY) is None:
|
419 |
+
logging.warning(
|
420 |
+
"Using random operation without seed. To avoid this "
|
421 |
+
"please provide a seed in feature %s.", SEED_KEY)
|
422 |
+
op_seed = tf.random.uniform(shape=(2,), maxval=2**32, dtype=tf.int64)
|
423 |
+
else:
|
424 |
+
features[SEED_KEY], op_seed = tf.unstack(
|
425 |
+
tf.random.experimental.stateless_split(features[SEED_KEY]))
|
426 |
+
# Get current video shape.
|
427 |
+
video_shape = tf.shape(features[self.video_key])
|
428 |
+
# Assemble all feature keys that the op should be applied on.
|
429 |
+
all_keys = [
|
430 |
+
self.video_key, self.segmentations_key, self.padding_mask_key,
|
431 |
+
self.flow_key, self.depth_key, self.sparse_segmentations_key,
|
432 |
+
self.boxes_key
|
433 |
+
]
|
434 |
+
# Apply the op to all features.
|
435 |
+
for key in all_keys:
|
436 |
+
if key in features:
|
437 |
+
features[key] = self.apply(features[key], op_seed, key, video_shape)
|
438 |
+
return features
|
439 |
+
|
440 |
+
@abc.abstractmethod
|
441 |
+
def apply(self, tensor, seed, key,
|
442 |
+
video_shape):
|
443 |
+
"""Returns the transformed tensor.
|
444 |
+
|
445 |
+
Args:
|
446 |
+
tensor: Any of a set of different video modalites, e.g video, flow,
|
447 |
+
bounding boxes, etc.
|
448 |
+
seed: A random seed.
|
449 |
+
key: a string that indicates what feature the tensor represents so that
|
450 |
+
the apply function can take that into account.
|
451 |
+
video_shape: The shape of the video (which is necessary for some
|
452 |
+
transformations).
|
453 |
+
"""
|
454 |
+
|
455 |
+
|
456 |
+
@dataclasses.dataclass
|
457 |
+
class ResizeSmall(VideoPreprocessOp):
|
458 |
+
"""Resizes the smaller (spatial) side to `size` keeping aspect ratio.
|
459 |
+
|
460 |
+
Attr:
|
461 |
+
size: An integer representing the new size of the smaller side of the input.
|
462 |
+
max_size: If set, an integer representing the maximum size in terms of the
|
463 |
+
largest side of the input.
|
464 |
+
"""
|
465 |
+
|
466 |
+
size: int
|
467 |
+
max_size: Optional[int] = None
|
468 |
+
|
469 |
+
def apply(self, tensor, key=None, video_shape=None):
|
470 |
+
"""See base class."""
|
471 |
+
|
472 |
+
# Boxes are defined in normalized image coordinates and are not affected.
|
473 |
+
if key == self.boxes_key:
|
474 |
+
return tensor
|
475 |
+
|
476 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
477 |
+
tensor = tensor[Ellipsis, tf.newaxis]
|
478 |
+
elif key == self.sparse_segmentations_key:
|
479 |
+
tensor = tf.reshape(tensor,
|
480 |
+
(-1, tf.shape(tensor)[2], tf.shape(tensor)[3], 1))
|
481 |
+
|
482 |
+
h, w = tf.shape(tensor)[1], tf.shape(tensor)[2]
|
483 |
+
|
484 |
+
# Determine resize method based on dtype (e.g. segmentations are int).
|
485 |
+
if tensor.dtype.is_integer:
|
486 |
+
resize_method = "nearest"
|
487 |
+
else:
|
488 |
+
resize_method = "bilinear"
|
489 |
+
|
490 |
+
# Clip size to max_size if needed.
|
491 |
+
small_size = self.size
|
492 |
+
if self.max_size is not None:
|
493 |
+
small_size = adjust_small_size(
|
494 |
+
original_size=(h, w), small_size=small_size, max_size=self.max_size)
|
495 |
+
new_h, new_w = get_resize_small_shape(
|
496 |
+
original_size=(h, w), small_size=small_size)
|
497 |
+
tensor = tf.image.resize(tensor, [new_h, new_w], method=resize_method)
|
498 |
+
|
499 |
+
# Flow needs to be rescaled according to the new size to stay valid.
|
500 |
+
if key == self.flow_key:
|
501 |
+
scale_h = tf.cast(new_h, tf.float32) / tf.cast(h, tf.float32)
|
502 |
+
scale_w = tf.cast(new_w, tf.float32) / tf.cast(w, tf.float32)
|
503 |
+
scale = tf.reshape(tf.stack([scale_h, scale_w], axis=0), (1, 2))
|
504 |
+
# Optionally repeat scale in case both forward and backward flow are
|
505 |
+
# stacked in the last dimension.
|
506 |
+
scale = tf.repeat(scale, tf.shape(tensor)[-1] // 2, axis=0)
|
507 |
+
scale = tf.reshape(scale, (1, 1, 1, tf.shape(tensor)[-1]))
|
508 |
+
tensor *= scale
|
509 |
+
|
510 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
511 |
+
tensor = tensor[Ellipsis, 0]
|
512 |
+
elif key == self.sparse_segmentations_key:
|
513 |
+
tensor = tf.reshape(tensor, (video_shape[0], -1, new_h, new_w))
|
514 |
+
|
515 |
+
return tensor
|
516 |
+
|
517 |
+
|
518 |
+
@dataclasses.dataclass
|
519 |
+
class CentralCrop(VideoPreprocessOp):
|
520 |
+
"""Makes central (spatial) crop of a given size.
|
521 |
+
|
522 |
+
Attr:
|
523 |
+
height: An integer representing the height of the crop.
|
524 |
+
width: An (optional) integer representing the width of the crop. Make square
|
525 |
+
crop if width is not provided.
|
526 |
+
"""
|
527 |
+
|
528 |
+
height: int
|
529 |
+
width: Optional[int] = None
|
530 |
+
|
531 |
+
def apply(self, tensor, key=None, video_shape=None):
|
532 |
+
"""See base class."""
|
533 |
+
if key == self.boxes_key:
|
534 |
+
width = self.width or self.height
|
535 |
+
h_orig, w_orig = video_shape[1], video_shape[2]
|
536 |
+
top = (h_orig - self.height) // 2
|
537 |
+
left = (w_orig - width) // 2
|
538 |
+
tensor = crop_or_pad_boxes(tensor, top, left, self.height,
|
539 |
+
width, h_orig, w_orig)
|
540 |
+
return tensor
|
541 |
+
else:
|
542 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
543 |
+
tensor = tensor[Ellipsis, tf.newaxis]
|
544 |
+
seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
|
545 |
+
h_orig, w_orig = tf.shape(tensor)[1], tf.shape(tensor)[2]
|
546 |
+
width = self.width or self.height
|
547 |
+
crop_size = (seq_len, self.height, width, n_channels)
|
548 |
+
top = (h_orig - self.height) // 2
|
549 |
+
left = (w_orig - width) // 2
|
550 |
+
tensor = tf.image.crop_to_bounding_box(tensor, top, left, self.height,
|
551 |
+
width)
|
552 |
+
tensor = tf.ensure_shape(tensor, crop_size)
|
553 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
554 |
+
tensor = tensor[Ellipsis, 0]
|
555 |
+
return tensor
|
556 |
+
|
557 |
+
|
558 |
+
@dataclasses.dataclass
|
559 |
+
class CropOrPad(VideoPreprocessOp):
|
560 |
+
"""Spatially crops or pads a video to a specified size.
|
561 |
+
|
562 |
+
Attr:
|
563 |
+
height: An integer representing the new height of the video.
|
564 |
+
width: An integer representing the new width of the video.
|
565 |
+
allow_crop: A boolean indicating if cropping is allowed.
|
566 |
+
"""
|
567 |
+
|
568 |
+
height: int
|
569 |
+
width: int
|
570 |
+
allow_crop: bool = True
|
571 |
+
|
572 |
+
def apply(self, tensor, key=None, video_shape=None):
|
573 |
+
"""See base class."""
|
574 |
+
if key == self.boxes_key:
|
575 |
+
# Pad and crop the spatial dimensions.
|
576 |
+
h_orig, w_orig = video_shape[1], video_shape[2]
|
577 |
+
if self.allow_crop:
|
578 |
+
# After cropping, the frame shape is always [self.height, self.width].
|
579 |
+
height, width = self.height, self.width
|
580 |
+
else:
|
581 |
+
# If only padding is performed, the frame size is at least
|
582 |
+
# [self.height, self.width].
|
583 |
+
height = tf.maximum(h_orig, self.height)
|
584 |
+
width = tf.maximum(w_orig, self.width)
|
585 |
+
tensor = crop_or_pad_boxes(
|
586 |
+
tensor,
|
587 |
+
top=0,
|
588 |
+
left=0,
|
589 |
+
height=height,
|
590 |
+
width=width,
|
591 |
+
h_orig=h_orig,
|
592 |
+
w_orig=w_orig)
|
593 |
+
return tensor
|
594 |
+
elif key == self.sparse_segmentations_key:
|
595 |
+
seq_len = tensor.get_shape()[0]
|
596 |
+
paddings = get_paddings(
|
597 |
+
tf.shape(tensor[Ellipsis, tf.newaxis]), (self.height, self.width),
|
598 |
+
allow_crop=self.allow_crop)[:-1]
|
599 |
+
tensor = tf.pad(tensor, paddings, constant_values=0)
|
600 |
+
if self.allow_crop:
|
601 |
+
tensor = tensor[Ellipsis, :self.height, :self.width]
|
602 |
+
tensor = tf.ensure_shape(
|
603 |
+
tensor, (seq_len, None, self.height, self.width))
|
604 |
+
return tensor
|
605 |
+
else:
|
606 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
607 |
+
tensor = tensor[Ellipsis, tf.newaxis]
|
608 |
+
seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
|
609 |
+
paddings = get_paddings(
|
610 |
+
tf.shape(tensor), (self.height, self.width),
|
611 |
+
allow_crop=self.allow_crop)
|
612 |
+
tensor = tf.pad(tensor, paddings, constant_values=0)
|
613 |
+
if self.allow_crop:
|
614 |
+
tensor = tensor[:, :self.height, :self.width, :]
|
615 |
+
tensor = tf.ensure_shape(tensor,
|
616 |
+
(seq_len, self.height, self.width, n_channels))
|
617 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
618 |
+
tensor = tensor[Ellipsis, 0]
|
619 |
+
return tensor
|
620 |
+
|
621 |
+
|
622 |
+
@dataclasses.dataclass
|
623 |
+
class RandomCrop(RandomVideoPreprocessOp):
|
624 |
+
"""Gets a random (width, height) crop of input video.
|
625 |
+
|
626 |
+
Assumption: Height and width are the same for all video-like modalities.
|
627 |
+
|
628 |
+
Attr:
|
629 |
+
height: An integer representing the height of the crop.
|
630 |
+
width: An integer representing the width of the crop.
|
631 |
+
"""
|
632 |
+
|
633 |
+
height: int
|
634 |
+
width: int
|
635 |
+
|
636 |
+
def apply(self, tensor, seed, key=None, video_shape=None):
|
637 |
+
"""See base class."""
|
638 |
+
if key == self.boxes_key:
|
639 |
+
# We copy the random generation part from tf.image.stateless_random_crop
|
640 |
+
# to generate exactly the same offset as for the video.
|
641 |
+
crop_size = (video_shape[0], self.height, self.width, video_shape[-1])
|
642 |
+
size = tf.convert_to_tensor(crop_size, tf.int32)
|
643 |
+
limit = video_shape - size + 1
|
644 |
+
offset = tf.random.stateless_uniform(
|
645 |
+
tf.shape(video_shape), dtype=tf.int32, maxval=tf.int32.max,
|
646 |
+
seed=seed) % limit
|
647 |
+
tensor = crop_or_pad_boxes(tensor, offset[1], offset[2], self.height,
|
648 |
+
self.width, video_shape[1], video_shape[2])
|
649 |
+
return tensor
|
650 |
+
elif key == self.sparse_segmentations_key:
|
651 |
+
raise NotImplementedError("Sparse segmentations aren't supported yet")
|
652 |
+
else:
|
653 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
654 |
+
tensor = tensor[Ellipsis, tf.newaxis]
|
655 |
+
seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
|
656 |
+
crop_size = (seq_len, self.height, self.width, n_channels)
|
657 |
+
tensor = tf.image.stateless_random_crop(tensor, size=crop_size, seed=seed)
|
658 |
+
tensor = tf.ensure_shape(tensor, crop_size)
|
659 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
660 |
+
tensor = tensor[Ellipsis, 0]
|
661 |
+
return tensor
|
662 |
+
|
663 |
+
|
664 |
+
@dataclasses.dataclass
|
665 |
+
class DropFrames(VideoPreprocessOp):
|
666 |
+
"""Subsamples a video by skipping frames.
|
667 |
+
|
668 |
+
Attr:
|
669 |
+
frame_skip: An integer representing the subsampling frequency of the video,
|
670 |
+
where 1 means no frames are skipped, 2 means every other frame is skipped,
|
671 |
+
and so forth.
|
672 |
+
"""
|
673 |
+
|
674 |
+
frame_skip: int
|
675 |
+
|
676 |
+
def apply(self, tensor, key=None, video_shape=None):
|
677 |
+
"""See base class."""
|
678 |
+
del key
|
679 |
+
del video_shape
|
680 |
+
tensor = tensor[::self.frame_skip]
|
681 |
+
new_length = tensor.get_shape()[0]
|
682 |
+
tensor = tf.ensure_shape(tensor, [new_length] + tensor.get_shape()[1:])
|
683 |
+
return tensor
|
684 |
+
|
685 |
+
|
686 |
+
@dataclasses.dataclass
|
687 |
+
class TemporalCropOrPad(VideoPreprocessOp):
|
688 |
+
"""Crops or pads a video in time to a specified length.
|
689 |
+
|
690 |
+
Attr:
|
691 |
+
length: An integer representing the new length of the video.
|
692 |
+
allow_crop: A boolean, specifying whether temporal cropping is allowed. If
|
693 |
+
False, will throw an error if length of the video is more than "length"
|
694 |
+
"""
|
695 |
+
|
696 |
+
length: int
|
697 |
+
allow_crop: bool = True
|
698 |
+
|
699 |
+
def _apply(self, tensor, constant_values):
|
700 |
+
frames_to_pad = self.length - tf.shape(tensor)[0]
|
701 |
+
if self.allow_crop:
|
702 |
+
frames_to_pad = tf.maximum(frames_to_pad, 0)
|
703 |
+
tensor = tf.pad(
|
704 |
+
tensor, ((0, frames_to_pad),) + ((0, 0),) * (len(tensor.shape) - 1),
|
705 |
+
constant_values=constant_values)
|
706 |
+
tensor = tensor[:self.length]
|
707 |
+
tensor = tf.ensure_shape(tensor, [self.length] + tensor.get_shape()[1:])
|
708 |
+
return tensor
|
709 |
+
|
710 |
+
def apply(self, tensor, key=None, video_shape=None):
|
711 |
+
"""See base class."""
|
712 |
+
del video_shape
|
713 |
+
if key == self.boxes_key:
|
714 |
+
constant_values = NOTRACK_BOX[0]
|
715 |
+
else:
|
716 |
+
constant_values = 0
|
717 |
+
return self._apply(tensor, constant_values=constant_values)
|
718 |
+
|
719 |
+
|
720 |
+
@dataclasses.dataclass
|
721 |
+
class TemporalRandomWindow(RandomVideoPreprocessOp):
|
722 |
+
"""Gets a random slice (window) along 0-th axis of input tensor.
|
723 |
+
|
724 |
+
Pads the video if the video length is shorter than the provided length.
|
725 |
+
|
726 |
+
Assumption: The number of frames is the same for all video-like modalities.
|
727 |
+
|
728 |
+
Attr:
|
729 |
+
length: An integer representing the new length of the video.
|
730 |
+
"""
|
731 |
+
|
732 |
+
length: int
|
733 |
+
|
734 |
+
def _apply(self, tensor, seed, constant_values):
|
735 |
+
length = tf.minimum(self.length, tf.shape(tensor)[0])
|
736 |
+
frames_to_pad = tf.maximum(self.length - tf.shape(tensor)[0], 0)
|
737 |
+
window_size = tf.concat(([length], tf.shape(tensor)[1:]), axis=0)
|
738 |
+
tensor = tf.image.stateless_random_crop(tensor, size=window_size, seed=seed)
|
739 |
+
tensor = tf.pad(
|
740 |
+
tensor, ((0, frames_to_pad),) + ((0, 0),) * (len(tensor.shape) - 1),
|
741 |
+
constant_values=constant_values)
|
742 |
+
tensor = tf.ensure_shape(tensor, [self.length] + tensor.get_shape()[1:])
|
743 |
+
return tensor
|
744 |
+
|
745 |
+
def apply(self, tensor, seed, key=None, video_shape=None):
|
746 |
+
"""See base class."""
|
747 |
+
del video_shape
|
748 |
+
if key == self.boxes_key:
|
749 |
+
constant_values = NOTRACK_BOX[0]
|
750 |
+
else:
|
751 |
+
constant_values = 0
|
752 |
+
return self._apply(tensor, seed, constant_values=constant_values)
|
753 |
+
|
754 |
+
|
755 |
+
@dataclasses.dataclass
|
756 |
+
class TemporalRandomStridedWindow(RandomVideoPreprocessOp):
|
757 |
+
"""Gets a random strided slice (window) along 0-th axis of input tensor.
|
758 |
+
|
759 |
+
This op is like TemporalRandomWindow but it samples from one of a set of
|
760 |
+
strides of the video, whereas TemporalRandomWindow will densely sample from
|
761 |
+
all possible slices of `length` frames from the video.
|
762 |
+
|
763 |
+
For the following video and `length=3`: [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
764 |
+
|
765 |
+
This op will return one of [1, 2, 3], [4, 5, 6], or [7, 8, 9]
|
766 |
+
|
767 |
+
This pads the video if the video length is shorter than the provided length.
|
768 |
+
|
769 |
+
Assumption: The number of frames is the same for all video-like modalities.
|
770 |
+
|
771 |
+
Attr:
|
772 |
+
length: An integer representing the new length of the video and the sampling
|
773 |
+
stride width.
|
774 |
+
"""
|
775 |
+
|
776 |
+
length: int
|
777 |
+
|
778 |
+
def _apply(self, tensor, seed,
|
779 |
+
constant_values):
|
780 |
+
"""Applies the strided crop operation to the video tensor."""
|
781 |
+
num_frames = tf.shape(tensor)[0]
|
782 |
+
num_crop_points = tf.cast(tf.math.ceil(num_frames / self.length), tf.int32)
|
783 |
+
crop_point = tf.random.stateless_uniform(
|
784 |
+
shape=(), minval=0, maxval=num_crop_points, dtype=tf.int32, seed=seed)
|
785 |
+
crop_point *= self.length
|
786 |
+
frames_sample = tensor[crop_point:crop_point + self.length]
|
787 |
+
frames_to_pad = tf.maximum(self.length - tf.shape(frames_sample)[0], 0)
|
788 |
+
frames_sample = tf.pad(
|
789 |
+
frames_sample,
|
790 |
+
((0, frames_to_pad),) + ((0, 0),) * (len(frames_sample.shape) - 1),
|
791 |
+
constant_values=constant_values)
|
792 |
+
frames_sample = tf.ensure_shape(frames_sample, [self.length] +
|
793 |
+
frames_sample.get_shape()[1:])
|
794 |
+
return frames_sample
|
795 |
+
|
796 |
+
def apply(self, tensor, seed, key=None, video_shape=None):
|
797 |
+
"""See base class."""
|
798 |
+
del video_shape
|
799 |
+
if key == self.boxes_key:
|
800 |
+
constant_values = NOTRACK_BOX[0]
|
801 |
+
else:
|
802 |
+
constant_values = 0
|
803 |
+
return self._apply(tensor, seed, constant_values=constant_values)
|
804 |
+
|
805 |
+
|
806 |
+
@dataclasses.dataclass
|
807 |
+
class FlowToRgb:
|
808 |
+
"""Converts flow to an RGB image.
|
809 |
+
|
810 |
+
NOTE: This operation requires a statically known shape for the input flow,
|
811 |
+
i.e. it is best to place it as final operation into the preprocessing
|
812 |
+
pipeline after all shapes are statically known (e.g. after cropping /
|
813 |
+
padding).
|
814 |
+
"""
|
815 |
+
flow_key: str = FLOW
|
816 |
+
|
817 |
+
def __call__(self, features):
|
818 |
+
if self.flow_key in features:
|
819 |
+
flow_rgb = flow_tensor_to_rgb_tensor(features[self.flow_key])
|
820 |
+
assert flow_rgb.dtype == tf.uint8
|
821 |
+
features[self.flow_key] = tf.image.convert_image_dtype(
|
822 |
+
flow_rgb, tf.float32)
|
823 |
+
return features
|
824 |
+
|
825 |
+
|
826 |
+
@dataclasses.dataclass
|
827 |
+
class TransformDepth:
|
828 |
+
"""Applies one of several possible transformations to depth features."""
|
829 |
+
transform: str
|
830 |
+
depth_key: str = DEPTH
|
831 |
+
|
832 |
+
def __call__(self, features):
|
833 |
+
if self.depth_key in features:
|
834 |
+
if self.transform == "log":
|
835 |
+
depth_norm = tf.math.log(features[self.depth_key])
|
836 |
+
elif self.transform == "log_plus":
|
837 |
+
depth_norm = tf.math.log(1. + features[self.depth_key])
|
838 |
+
elif self.transform == "invert_plus":
|
839 |
+
depth_norm = 1. / (1. + features[self.depth_key])
|
840 |
+
else:
|
841 |
+
raise ValueError(f"Unknown depth transformation {self.transform}")
|
842 |
+
|
843 |
+
features[self.depth_key] = depth_norm
|
844 |
+
return features
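For reference, a tiny sketch of one of the supported transforms on dummy depth values; the `depth_key` override below just avoids relying on the module-level DEPTH constant.

import tensorflow as tf

d = tf.constant([0.5, 1.0, 10.0])
op = TransformDepth(transform="log_plus", depth_key="my_depth")
out = op({"my_depth": d})["my_depth"]   # log(1 + d)
# "log" would give log(d) and "invert_plus" would give 1 / (1 + d).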
|
845 |
+
|
846 |
+
|
847 |
+
@dataclasses.dataclass
|
848 |
+
class RandomResizedCrop(RandomVideoPreprocessOp):
|
849 |
+
"""Random-resized crop for each of the two views.
|
850 |
+
|
851 |
+
Assumption: Height and width are the same for all video-like modalities.
|
852 |
+
|
853 |
+
We randomly crop the input and record the transformation this crop corresponds
|
854 |
+
to as a new feature. Cropped images are resized to (height, width). Boxes are
|
855 |
+
adjusted accordingly and boxes outside the crop are discarded. Flow is rescaled
|
856 |
+
so as to be pixel accurate after the operation. lidar_points_2d are
|
857 |
+
transformed using the computed transformation. These points may lie outside
|
858 |
+
the image after the operation.
|
859 |
+
|
860 |
+
Attr:
|
861 |
+
height: An integer representing the height to resize to.
|
862 |
+
width: An integer representing the width to resize to.
|
863 |
+
min_object_covered, aspect_ratio_range, area_range, max_attempts: See
|
864 |
+
docstring of `stateless_sample_distorted_bounding_box`. Aspect ratio range
|
865 |
+
has not been scaled by target aspect ratio. This differs from other
|
866 |
+
implementations of this data augmentation.
|
867 |
+
relative_box_area_threshold: If ratio of areas before and after cropping are
|
868 |
+
lower than this threshold, then the box is discarded (set to NOTRACK_BOX).
|
869 |
+
"""
|
870 |
+
# Target size.
|
871 |
+
height: int
|
872 |
+
width: int
|
873 |
+
|
874 |
+
# Crop sampling attributes.
|
875 |
+
min_object_covered: float = 0.1
|
876 |
+
aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.)
|
877 |
+
area_range: Tuple[float, float] = (0.08, 1.0)
|
878 |
+
max_attempts: int = 100
|
879 |
+
|
880 |
+
# Box retention attributes
|
881 |
+
relative_box_area_threshold: float = 0.0
|
882 |
+
|
883 |
+
def apply(self, tensor, seed, key,
|
884 |
+
video_shape):
|
885 |
+
"""Applies the crop operation on tensor."""
|
886 |
+
param = self.sample_augmentation_params(video_shape, seed)
|
887 |
+
si, sj = param[0], param[1]
|
888 |
+
crop_h, crop_w = param[2], param[3]
|
889 |
+
|
890 |
+
to_float32 = lambda x: tf.cast(x, tf.float32)
|
891 |
+
|
892 |
+
if key == self.boxes_key:
|
893 |
+
# First crop the boxes.
|
894 |
+
cropped_boxes = crop_or_pad_boxes(
|
895 |
+
tensor, si, sj,
|
896 |
+
crop_h, crop_w,
|
897 |
+
video_shape[1], video_shape[2])
|
898 |
+
# We do not need to scale the boxes because they are in normalized coords.
|
899 |
+
resized_boxes = cropped_boxes
|
900 |
+
# Lastly, detects NOTRACK_BOX boxes and avoids manipulating those.
|
901 |
+
no_track_boxes = tf.convert_to_tensor(NOTRACK_BOX)
|
902 |
+
no_track_boxes = tf.reshape(no_track_boxes, [1, 4])
|
903 |
+
resized_boxes = tf.where(
|
904 |
+
tf.reduce_all(tensor == no_track_boxes, axis=-1, keepdims=True),
|
905 |
+
tensor, resized_boxes)
|
906 |
+
|
907 |
+
if self.relative_box_area_threshold > 0:
|
908 |
+
# Thresholds boxes that have been cropped too much, i.e. boxes whose area is
|
909 |
+
# lower, in relative terms, than `relative_box_area_threshold`.
|
910 |
+
area_before_crop = tf.reduce_prod(tensor[Ellipsis, 2:] - tensor[Ellipsis, :2],
|
911 |
+
axis=-1)
|
912 |
+
# Sets minimum area_before_crop to 1e-8 so we avoid divisions by 0.
|
913 |
+
area_before_crop = tf.maximum(area_before_crop,
|
914 |
+
tf.zeros_like(area_before_crop) + 1e-8)
|
915 |
+
area_after_crop = tf.reduce_prod(
|
916 |
+
resized_boxes[Ellipsis, 2:] - resized_boxes[Ellipsis, :2], axis=-1)
|
917 |
+
# As the boxes have normalized coordinates, they need to be rescaled to
|
918 |
+
# be compared against the original uncropped boxes.
|
919 |
+
scale_x = to_float32(crop_w) / to_float32(self.width)
|
920 |
+
scale_y = to_float32(crop_h) / to_float32(self.height)
|
921 |
+
area_after_crop *= scale_x * scale_y
|
922 |
+
|
923 |
+
ratio = area_after_crop / area_before_crop
|
924 |
+
return tf.where(
|
925 |
+
tf.expand_dims(ratio > self.relative_box_area_threshold, -1),
|
926 |
+
resized_boxes, no_track_boxes)
|
927 |
+
|
928 |
+
else:
|
929 |
+
return resized_boxes
|
930 |
+
|
931 |
+
else:
|
932 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
933 |
+
tensor = tensor[Ellipsis, tf.newaxis]
|
934 |
+
|
935 |
+
# Crop.
|
936 |
+
seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
|
937 |
+
crop_size = (seq_len, crop_h, crop_w, n_channels)
|
938 |
+
tensor = tf.slice(tensor, tf.stack([0, si, sj, 0]), crop_size)
|
939 |
+
|
940 |
+
# Resize.
|
941 |
+
resize_method = tf.image.ResizeMethod.BILINEAR
|
942 |
+
if (tensor.dtype == tf.int32 or tensor.dtype == tf.int64 or
|
943 |
+
tensor.dtype == tf.uint8):
|
944 |
+
resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
|
945 |
+
tensor = tf.image.resize(tensor, [self.height, self.width],
|
946 |
+
method=resize_method)
|
947 |
+
out_size = (seq_len, self.height, self.width, n_channels)
|
948 |
+
tensor = tf.ensure_shape(tensor, out_size)
|
949 |
+
|
950 |
+
if key == self.flow_key:
|
951 |
+
# Rescales optical flow.
|
952 |
+
scale_x = to_float32(self.width) / to_float32(crop_w)
|
953 |
+
scale_y = to_float32(self.height) / to_float32(crop_h)
|
954 |
+
tensor = tf.stack(
|
955 |
+
[tensor[Ellipsis, 0] * scale_y, tensor[Ellipsis, 1] * scale_x], axis=-1)
|
956 |
+
|
957 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
958 |
+
tensor = tensor[Ellipsis, 0]
|
959 |
+
return tensor
|
960 |
+
|
961 |
+
def sample_augmentation_params(self, video_shape, rng):
|
962 |
+
"""Sample a random bounding box for the crop."""
|
963 |
+
sample_bbox = tf.image.stateless_sample_distorted_bounding_box(
|
964 |
+
video_shape[1:],
|
965 |
+
bounding_boxes=tf.constant([0.0, 0.0, 1.0, 1.0],
|
966 |
+
dtype=tf.float32, shape=[1, 1, 4]),
|
967 |
+
seed=rng,
|
968 |
+
min_object_covered=self.min_object_covered,
|
969 |
+
aspect_ratio_range=self.aspect_ratio_range,
|
970 |
+
area_range=self.area_range,
|
971 |
+
max_attempts=self.max_attempts,
|
972 |
+
use_image_if_no_bounding_boxes=True)
|
973 |
+
bbox_begin, bbox_size, _ = sample_bbox
|
974 |
+
|
975 |
+
# The specified bounding box provides crop coordinates.
|
976 |
+
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
977 |
+
target_height, target_width, _ = tf.unstack(bbox_size)
|
978 |
+
|
979 |
+
return tf.stack([offset_y, offset_x, target_height, target_width])
|
980 |
+
|
981 |
+
def estimate_transformation(self, param, video_shape
|
982 |
+
):
|
983 |
+
"""Computes the affine transformation for crop params.
|
984 |
+
|
985 |
+
Args:
|
986 |
+
param: Crop parameters in the [y, x, h, w] format of shape [4,].
|
987 |
+
video_shape: Unused.
|
988 |
+
|
989 |
+
Returns:
|
990 |
+
Affine transformation of shape [3, 3] corresponding to cropping the image
|
991 |
+
at [y, x] of size [h, w] and resizing it into [self.height, self.width].
|
992 |
+
"""
|
993 |
+
del video_shape
|
994 |
+
crop = tf.cast(param, tf.float32)
|
995 |
+
si, sj = crop[0], crop[1]
|
996 |
+
crop_h, crop_w = crop[2], crop[3]
|
997 |
+
ei, ej = si + crop_h - 1.0, sj + crop_w - 1.0
|
998 |
+
h, w = float(self.height), float(self.width)
|
999 |
+
|
1000 |
+
a1 = (ei - si + 1.)/h
|
1001 |
+
a2 = 0.
|
1002 |
+
a3 = si - 0.5 + a1 / 2.
|
1003 |
+
a4 = 0.
|
1004 |
+
a5 = (ej - sj + 1.)/w
|
1005 |
+
a6 = sj - 0.5 + a5 / 2.
|
1006 |
+
affine = tf.stack([a1, a2, a3, a4, a5, a6, 0., 0., 1.])
|
1007 |
+
return tf.reshape(affine, [3, 3])
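A small numeric sketch (crop parameters made up) of how the returned matrix can be read: it maps homogeneous (row, col, 1) pixel coordinates in the resized crop back to pixel coordinates in the original frame.

import tensorflow as tf

# Hypothetical crop starting at (si, sj) = (10, 20) with size 64x96,
# resized to a 128x128 output; this mirrors the formula above.
si, sj, crop_h, crop_w = 10., 20., 64., 96.
h, w = 128., 128.
a1, a5 = crop_h / h, crop_w / w
affine = tf.constant([[a1, 0., si - 0.5 + a1 / 2.],
                      [0., a5, sj - 0.5 + a5 / 2.],
                      [0., 0., 1.]])
centre = tf.constant([[64.], [64.], [1.]])  # centre of the resized crop
print(tf.matmul(affine, centre)[:2, 0])     # ~[42., 68.]: the crop centre in original-frame pixels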
|
1008 |
+
|
1009 |
+
|
1010 |
+
@dataclasses.dataclass
|
1011 |
+
class TfdsImageToTfdsVideo:
|
1012 |
+
"""Lift TFDS image format to TFDS video format by adding a temporal axis.
|
1013 |
+
|
1014 |
+
This op is intended to be called directly before VideoFromTfds.
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
TFDS_SEGMENTATIONS_KEY = "segmentations"
|
1018 |
+
TFDS_INSTANCES_KEY = "instances"
|
1019 |
+
TFDS_BOXES_KEY = "bboxes"
|
1020 |
+
TFDS_BOXES_FRAMES_KEY = "bbox_frames"
|
1021 |
+
|
1022 |
+
image_key: str = IMAGE
|
1023 |
+
video_key: str = VIDEO
|
1024 |
+
boxes_image_key: str = BOXES
|
1025 |
+
boxes_key: str = BOXES_VIDEO
|
1026 |
+
image_padding_mask_key: str = IMAGE_PADDING_MASK
|
1027 |
+
video_padding_mask_key: str = VIDEO_PADDING_MASK
|
1028 |
+
depth_key: str = DEPTH
|
1029 |
+
depth_mask_key: str = "depth_mask"
|
1030 |
+
force_overwrite: bool = False
|
1031 |
+
|
1032 |
+
def __call__(self, features):
|
1033 |
+
if self.video_key in features and not self.force_overwrite:
|
1034 |
+
return features
|
1035 |
+
|
1036 |
+
features_new = {}
|
1037 |
+
for k, v in features.items():
|
1038 |
+
if k == self.image_key:
|
1039 |
+
features_new[self.video_key] = v[tf.newaxis]
|
1040 |
+
elif k == self.image_padding_mask_key:
|
1041 |
+
features_new[self.video_padding_mask_key] = v[tf.newaxis]
|
1042 |
+
elif k == self.boxes_image_key:
|
1043 |
+
features_new[self.boxes_key] = v[tf.newaxis]
|
1044 |
+
elif k == self.TFDS_SEGMENTATIONS_KEY:
|
1045 |
+
features_new[self.TFDS_SEGMENTATIONS_KEY] = v[tf.newaxis]
|
1046 |
+
elif k == self.TFDS_INSTANCES_KEY and self.TFDS_BOXES_KEY in v:
|
1047 |
+
# Add sequence dimension to boxes and create boxes frames for indexing.
|
1048 |
+
features_new[k] = v
|
1049 |
+
|
1050 |
+
# Create dummy ragged tensor (1, None) and broadcast
|
1051 |
+
dummy = tf.ragged.constant([[0]], dtype=tf.int32)
|
1052 |
+
boxes_frames_value = tf.zeros_like(
|
1053 |
+
v[self.TFDS_BOXES_KEY][Ellipsis, 0], dtype=tf.int32)[Ellipsis, tf.newaxis]
|
1054 |
+
features_new[k][self.TFDS_BOXES_FRAMES_KEY] = boxes_frames_value + dummy
|
1055 |
+
# Create dummy ragged tensor (1, None, 1) and broadcast
|
1056 |
+
dummy = tf.ragged.constant([[0]], dtype=tf.float32)[Ellipsis, tf.newaxis]
|
1057 |
+
boxes_value = v[self.TFDS_BOXES_KEY][Ellipsis, tf.newaxis, :]
|
1058 |
+
features_new[k][self.TFDS_BOXES_KEY] = boxes_value + dummy
|
1059 |
+
elif k == self.depth_key:
|
1060 |
+
features_new[self.depth_key] = v[tf.newaxis]
|
1061 |
+
elif k == self.depth_mask_key:
|
1062 |
+
features_new[self.depth_mask_key] = v[tf.newaxis]
|
1063 |
+
else:
|
1064 |
+
features_new[k] = v
|
1065 |
+
|
1066 |
+
if self.video_padding_mask_key not in features_new:
|
1067 |
+
logging.warning("Adding default video_padding_mask")
|
1068 |
+
features_new[self.video_padding_mask_key] = tf.cast(
|
1069 |
+
tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8)
|
1070 |
+
|
1071 |
+
return features_new
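A minimal usage sketch with the keys spelled out explicitly, so it does not depend on the module-level IMAGE/VIDEO constants:

import tensorflow as tf

op = TfdsImageToTfdsVideo(image_key="image", video_key="video")
features = op({"image": tf.zeros([128, 128, 3], tf.uint8)})
# features["video"] has shape [1, 128, 128, 3]; a default uint8
# video padding mask of shape [1, 128, 128] is added as well.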
|
1072 |
+
|
1073 |
+
|
1074 |
+
@dataclasses.dataclass
|
1075 |
+
class TopLeftCrop(VideoPreprocessOp):
|
1076 |
+
"""Makes an arbitrary crop in all video frames.
|
1077 |
+
|
1078 |
+
Attr:
|
1079 |
+
top: An integer representing the vertical coordinate of the crop start.
|
1080 |
+
left: An integer representing the horizontal coordinate of the crop start.
|
1081 |
+
height: An integer representing the height of the crop.
|
1082 |
+
width: An optional integer representing the width of the crop. Makes a square
|
1083 |
+
crop if width is not provided.
|
1084 |
+
"""
|
1085 |
+
|
1086 |
+
top: int
|
1087 |
+
left: int
|
1088 |
+
height: int
|
1089 |
+
width: Optional[int] = None
|
1090 |
+
|
1091 |
+
def apply(self, tensor, key=None, video_shape=None):
|
1092 |
+
"""See base class."""
|
1093 |
+
if key in (self.boxes_key,):
|
1094 |
+
width = self.width or self.height
|
1095 |
+
h_orig, w_orig = video_shape[1], video_shape[2]
|
1096 |
+
tensor = transforms.crop_or_pad_boxes(
|
1097 |
+
tensor, self.top, self.left, self.height, width, h_orig, w_orig)
|
1098 |
+
return tensor
|
1099 |
+
else:
|
1100 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
1101 |
+
tensor = tensor[Ellipsis, tf.newaxis]
|
1102 |
+
seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3]
|
1103 |
+
h_orig, w_orig = tf.shape(tensor)[1], tf.shape(tensor)[2]
|
1104 |
+
width = self.width or self.height
|
1105 |
+
crop_size = (seq_len, self.height, width, n_channels)
|
1106 |
+
tensor = tf.image.crop_to_bounding_box(
|
1107 |
+
tensor, self.top, self.left, self.height, width)
|
1108 |
+
tensor = tf.ensure_shape(tensor, crop_size)
|
1109 |
+
if key in (self.padding_mask_key, self.segmentations_key):
|
1110 |
+
tensor = tensor[Ellipsis, 0]
|
1111 |
+
return tensor
|
1112 |
+
|
1113 |
+
|
1114 |
+
@dataclasses.dataclass
|
1115 |
+
class DeleteSmallMasks:
|
1116 |
+
"""Delete masks smaller than a selected fraction of pixels."""
|
1117 |
+
threshold: float = 0.05
|
1118 |
+
max_instances: int = 50
|
1119 |
+
max_instances_after: int = 11
|
1120 |
+
|
1121 |
+
def __call__(self, features):
|
1122 |
+
|
1123 |
+
features_new = {}
|
1124 |
+
|
1125 |
+
for key in features.keys():
|
1126 |
+
|
1127 |
+
if key == SEGMENTATIONS:
|
1128 |
+
seg = features[key]
|
1129 |
+
size = tf.shape(seg)
|
1130 |
+
|
1131 |
+
assert_op = tf.Assert(
|
1132 |
+
tf.equal(size[0], 1), ["Implemented only for a single frame."])
|
1133 |
+
|
1134 |
+
with tf.control_dependencies([assert_op]):
|
1135 |
+
# Delete time dimension.
|
1136 |
+
seg = seg[0]
|
1137 |
+
|
1138 |
+
# Get the minimum number of pixels a mask needs to have.
|
1139 |
+
max_pixels = size[1] * size[2]
|
1140 |
+
threshold_pixels = tf.cast(
|
1141 |
+
tf.cast(max_pixels, tf.float32) * self.threshold, tf.int32)
|
1142 |
+
|
1143 |
+
# Decompose the segmentation map as a single image for each instance.
|
1144 |
+
dec_seg = tf.stack(
|
1145 |
+
tf.map_fn(functools.partial(self._decompose, seg=seg),
|
1146 |
+
tf.range(self.max_instances)), axis=0)
|
1147 |
+
|
1148 |
+
# Count the pixels and find segmentation masks that are big enough.
|
1149 |
+
sums = tf.reduce_sum(dec_seg, axis=(1, 2))
|
1150 |
+
# We want the background to always be slot zero.
|
1151 |
+
# We can accomplish that by pretending it has the maximum
|
1152 |
+
# number of pixels.
|
1153 |
+
sums = tf.concat(
|
1154 |
+
[tf.ones_like(sums[0: 1]) * max_pixels, sums[1:]],
|
1155 |
+
axis=0)
|
1156 |
+
|
1157 |
+
sort = tf.argsort(sums, axis=0, direction="DESCENDING")
|
1158 |
+
sums_s = tf.gather(sums, sort, axis=0)
|
1159 |
+
mask_s = tf.cast(tf.greater_equal(sums_s, threshold_pixels), tf.int32)
|
1160 |
+
|
1161 |
+
dec_seg_plus = tf.stack(
|
1162 |
+
tf.map_fn(functools.partial(
|
1163 |
+
self._compose_sort, seg=seg, sort=sort, mask_s=mask_s),
|
1164 |
+
tf.range(self.max_instances_after)), axis=0)
|
1165 |
+
new_seg = tf.reduce_sum(dec_seg_plus, axis=0)
|
1166 |
+
|
1167 |
+
features_new[key] = tf.cast(new_seg[None], tf.int32)
|
1168 |
+
|
1169 |
+
else:
|
1170 |
+
# keep all other features
|
1171 |
+
features_new[key] = features[key]
|
1172 |
+
|
1173 |
+
return features_new
|
1174 |
+
|
1175 |
+
@classmethod
|
1176 |
+
def _decompose(cls, i, seg):
|
1177 |
+
return tf.cast(tf.equal(seg, i), tf.int32)
|
1178 |
+
|
1179 |
+
@classmethod
|
1180 |
+
def _compose_sort(cls, i, seg, sort, mask_s):
|
1181 |
+
return tf.cast(tf.equal(seg, sort[i]), tf.int32) * i * mask_s[i]
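The relabelling above can be hard to parse from the map_fn calls; the plain-NumPy toy below (made-up 4x4 segmentation, threshold chosen for illustration) shows the same idea: instance ids are re-sorted by pixel count with background forced first, masks below the pixel threshold are dropped, and surviving masks get dense new ids.

import numpy as np

seg = np.array([[0, 0, 1, 1],
                [0, 0, 1, 1],
                [0, 2, 1, 1],
                [0, 0, 1, 1]])
threshold_pixels = int(0.2 * seg.size)      # a mask needs >= 3 of 16 pixels here
counts = np.bincount(seg.ravel(), minlength=3)
counts[0] = seg.size                        # background always ranks first
order = np.argsort(-counts)                 # ids sorted by size, descending
keep = counts[order] >= threshold_pixels
new_seg = np.zeros_like(seg)
for new_id, old_id in enumerate(order):
    if keep[new_id]:
        new_seg[seg == old_id] = new_id
# Instance 2 (a single pixel) is removed; ids 0 and 1 keep their (re-sorted) slots.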
|
1182 |
+
|
1183 |
+
|
1184 |
+
@dataclasses.dataclass
|
1185 |
+
class SundsToTfdsVideo:
|
1186 |
+
"""Lift Sunds format to TFDS video format.
|
1187 |
+
|
1188 |
+
Renames fields and adds a temporal axis.
|
1189 |
+
This op is intended to be called directly before VideoFromTfds.
|
1190 |
+
"""
|
1191 |
+
|
1192 |
+
SUNDS_IMAGE_KEY = "color_image"
|
1193 |
+
SUNDS_SEGMENTATIONS_KEY = "instance_image"
|
1194 |
+
SUNDS_DEPTH_KEY = "depth_image"
|
1195 |
+
|
1196 |
+
image_key: str = SUNDS_IMAGE_KEY
|
1197 |
+
image_segmentations_key = SUNDS_SEGMENTATIONS_KEY
|
1198 |
+
video_key: str = VIDEO
|
1199 |
+
video_segmentations_key = SEGMENTATIONS
|
1200 |
+
image_depths_key: str = SUNDS_DEPTH_KEY
|
1201 |
+
depths_key = DEPTH
|
1202 |
+
video_padding_mask_key: str = VIDEO_PADDING_MASK
|
1203 |
+
force_overwrite: bool = False
|
1204 |
+
|
1205 |
+
def __call__(self, features):
|
1206 |
+
if self.video_key in features and not self.force_overwrite:
|
1207 |
+
return features
|
1208 |
+
|
1209 |
+
features_new = {}
|
1210 |
+
for k, v in features.items():
|
1211 |
+
if k == self.image_key:
|
1212 |
+
features_new[self.video_key] = v[tf.newaxis]
|
1213 |
+
elif k == self.image_segmentations_key:
|
1214 |
+
features_new[self.video_segmentations_key] = v[tf.newaxis]
|
1215 |
+
elif k == self.image_depths_key:
|
1216 |
+
features_new[self.depths_key] = v[tf.newaxis]
|
1217 |
+
else:
|
1218 |
+
features_new[k] = v
|
1219 |
+
|
1220 |
+
if self.video_padding_mask_key not in features_new:
|
1221 |
+
logging.warning("Adding default video_padding_mask")
|
1222 |
+
features_new[self.video_padding_mask_key] = tf.cast(
|
1223 |
+
tf.ones_like(features_new[self.video_key])[Ellipsis, 0], tf.uint8)
|
1224 |
+
|
1225 |
+
return features_new
|
1226 |
+
|
1227 |
+
|
1228 |
+
@dataclasses.dataclass
|
1229 |
+
class SubtractOneFromSegmentations:
|
1230 |
+
"""Subtract one from segmentation masks. Used for MultiShapeNet-Easy."""
|
1231 |
+
|
1232 |
+
segmentations_key: str = SEGMENTATIONS
|
1233 |
+
|
1234 |
+
def __call__(self, features):
|
1235 |
+
features[self.segmentations_key] = features[self.segmentations_key] - 1
|
1236 |
+
return features
|
invariant_slot_attention/lib/trainer.py
ADDED
@@ -0,0 +1,328 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""The main model training loop."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
import os
|
20 |
+
import time
|
21 |
+
from typing import Dict, Iterable, Mapping, Optional, Tuple, Type, Union
|
22 |
+
|
23 |
+
from absl import logging
|
24 |
+
from clu import checkpoint
|
25 |
+
from clu import metric_writers
|
26 |
+
from clu import metrics
|
27 |
+
from clu import parameter_overview
|
28 |
+
from clu import periodic_actions
|
29 |
+
import flax
|
30 |
+
from flax import linen as nn
|
31 |
+
|
32 |
+
import jax
|
33 |
+
import jax.numpy as jnp
|
34 |
+
import ml_collections
|
35 |
+
import numpy as np
|
36 |
+
import optax
|
37 |
+
|
38 |
+
from scenic.train_lib import lr_schedules
|
39 |
+
from scenic.train_lib import optimizers
|
40 |
+
|
41 |
+
import tensorflow as tf
|
42 |
+
|
43 |
+
from invariant_slot_attention.lib import evaluator
|
44 |
+
from invariant_slot_attention.lib import input_pipeline
|
45 |
+
from invariant_slot_attention.lib import losses
|
46 |
+
from invariant_slot_attention.lib import utils
|
47 |
+
|
48 |
+
Array = jnp.ndarray
|
49 |
+
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
|
50 |
+
PRNGKey = Array
|
51 |
+
|
52 |
+
|
53 |
+
def train_step(
|
54 |
+
model,
|
55 |
+
tx,
|
56 |
+
rng,
|
57 |
+
step,
|
58 |
+
state_vars,
|
59 |
+
opt_state,
|
60 |
+
params,
|
61 |
+
batch,
|
62 |
+
loss_fn,
|
63 |
+
train_metrics_cls,
|
64 |
+
predicted_max_num_instances,
|
65 |
+
ground_truth_max_num_instances,
|
66 |
+
conditioning_key = None,
|
67 |
+
):
|
68 |
+
"""Perform a single training step.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
model: Model used in training step.
|
72 |
+
tx: The optimizer to use to minimize loss_fn.
|
73 |
+
rng: Random number key.
|
74 |
+
step: Which training step we are on.
|
75 |
+
state_vars: Accessory variables.
|
76 |
+
opt_state: The state of the optimizer.
|
77 |
+
params: The current parameters to be updated.
|
78 |
+
batch: Training inputs for this step.
|
79 |
+
loss_fn: Loss function that takes model predictions and a batch of data.
|
80 |
+
train_metrics_cls: The metrics collection for computing training metrics.
|
81 |
+
predicted_max_num_instances: Maximum number of instances in prediction.
|
82 |
+
ground_truth_max_num_instances: Maximum number of instances in ground truth,
|
83 |
+
including background (which counts as a separate instance).
|
84 |
+
conditioning_key: Optional string. If provided, defines the batch key to be
|
85 |
+
used as conditioning signal for the model. Otherwise this is inferred from
|
86 |
+
the available keys in the batch.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Tuple of the updated opt, state_vars, new random number key,
|
90 |
+
metrics update, and step + 1. Note that some of this info is stored in
|
91 |
+
TrainState, but here it is unpacked.
|
92 |
+
"""
|
93 |
+
|
94 |
+
# Split PRNGKey and bind to host / device.
|
95 |
+
new_rng, rng = jax.random.split(rng)
|
96 |
+
rng = jax.random.fold_in(rng, jax.host_id())
|
97 |
+
rng = jax.random.fold_in(rng, jax.lax.axis_index("batch"))
|
98 |
+
init_rng, dropout_rng = jax.random.split(rng, 2)
|
99 |
+
|
100 |
+
mutable_var_keys = list(state_vars.keys()) + ["intermediates"]
|
101 |
+
|
102 |
+
conditioning = batch[conditioning_key] if conditioning_key else None
|
103 |
+
|
104 |
+
def train_loss_fn(params, state_vars):
|
105 |
+
preds, mutable_vars = model.apply(
|
106 |
+
{"params": params, **state_vars}, video=batch["video"],
|
107 |
+
conditioning=conditioning, mutable=mutable_var_keys,
|
108 |
+
rngs={"state_init": init_rng, "dropout": dropout_rng}, train=True,
|
109 |
+
padding_mask=batch.get("padding_mask"))
|
110 |
+
# Filter intermediates, as we do not want to store them in the TrainState.
|
111 |
+
state_vars = utils.filter_key_from_frozen_dict(
|
112 |
+
mutable_vars, key="intermediates")
|
113 |
+
loss, loss_aux = loss_fn(preds, batch)
|
114 |
+
return loss, (state_vars, preds, loss_aux)
|
115 |
+
|
116 |
+
grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
|
117 |
+
(loss, (state_vars, preds, loss_aux)), grad = grad_fn(params, state_vars)
|
118 |
+
|
119 |
+
# Compute average gradient across multiple workers.
|
120 |
+
grad = jax.lax.pmean(grad, axis_name="batch")
|
121 |
+
|
122 |
+
updates, new_opt_state = tx.update(grad, opt_state, params)
|
123 |
+
new_params = optax.apply_updates(params, updates)
|
124 |
+
|
125 |
+
# Compute metrics.
|
126 |
+
metrics_update = train_metrics_cls.gather_from_model_output(
|
127 |
+
loss=loss,
|
128 |
+
**loss_aux,
|
129 |
+
predicted_segmentations=utils.remove_singleton_dim(
|
130 |
+
preds["outputs"].get("segmentations")), # pytype: disable=attribute-error
|
131 |
+
ground_truth_segmentations=batch.get("segmentations"),
|
132 |
+
predicted_max_num_instances=predicted_max_num_instances,
|
133 |
+
ground_truth_max_num_instances=ground_truth_max_num_instances,
|
134 |
+
padding_mask=batch.get("padding_mask"),
|
135 |
+
mask=batch.get("mask"))
|
136 |
+
return (
|
137 |
+
new_opt_state, new_params, state_vars, new_rng, metrics_update, step + 1)
|
138 |
+
|
139 |
+
|
140 |
+
def train_and_evaluate(config,
|
141 |
+
workdir):
|
142 |
+
"""Runs a training and evaluation loop.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
config: Configuration to use.
|
146 |
+
workdir: Working directory for checkpoints and TF summaries. If this
|
147 |
+
contains checkpoint training will be resumed from the latest checkpoint.
|
148 |
+
"""
|
149 |
+
rng = jax.random.PRNGKey(config.seed)
|
150 |
+
|
151 |
+
tf.io.gfile.makedirs(workdir)
|
152 |
+
|
153 |
+
# Input pipeline.
|
154 |
+
rng, data_rng = jax.random.split(rng)
|
155 |
+
# Make sure each host uses a different RNG for the training data.
|
156 |
+
if config.get("seed_data", True): # Default to seeding data if not specified.
|
157 |
+
data_rng = jax.random.fold_in(data_rng, jax.host_id())
|
158 |
+
else:
|
159 |
+
data_rng = None
|
160 |
+
train_ds, eval_ds = input_pipeline.create_datasets(config, data_rng)
|
161 |
+
train_iter = iter(train_ds) # pytype: disable=wrong-arg-types
|
162 |
+
|
163 |
+
# Initialize model
|
164 |
+
model = utils.build_model_from_config(config.model)
|
165 |
+
|
166 |
+
# Construct TrainMetrics and EvalMetrics, metrics collections.
|
167 |
+
train_metrics_cls = utils.make_metrics_collection("TrainMetrics",
|
168 |
+
config.train_metrics_spec)
|
169 |
+
eval_metrics_cls = utils.make_metrics_collection("EvalMetrics",
|
170 |
+
config.eval_metrics_spec)
|
171 |
+
|
172 |
+
def init_model(rng):
|
173 |
+
rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
|
174 |
+
|
175 |
+
init_conditioning = None
|
176 |
+
if config.get("conditioning_key"):
|
177 |
+
init_conditioning = jnp.ones(
|
178 |
+
[1] + list(train_ds.element_spec[config.conditioning_key].shape)[2:],
|
179 |
+
jnp.int32)
|
180 |
+
init_inputs = jnp.ones(
|
181 |
+
[1] + list(train_ds.element_spec["video"].shape)[2:],
|
182 |
+
jnp.float32)
|
183 |
+
initial_vars = model.init(
|
184 |
+
{"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
|
185 |
+
video=init_inputs, conditioning=init_conditioning,
|
186 |
+
padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
|
187 |
+
|
188 |
+
# Split into state variables (e.g. for batchnorm stats) and model params.
|
189 |
+
# Note that `pop()` on a FrozenDict performs a deep copy.
|
190 |
+
state_vars, initial_params = initial_vars.pop("params") # pytype: disable=attribute-error
|
191 |
+
|
192 |
+
# Filter out intermediates (we don't want to store these in the TrainState).
|
193 |
+
state_vars = utils.filter_key_from_frozen_dict(
|
194 |
+
state_vars, key="intermediates")
|
195 |
+
return state_vars, initial_params
|
196 |
+
|
197 |
+
state_vars, initial_params = init_model(rng)
|
198 |
+
parameter_overview.log_parameter_overview(initial_params) # pytype: disable=wrong-arg-types
|
199 |
+
|
200 |
+
learning_rate_fn = lr_schedules.get_learning_rate_fn(config)
|
201 |
+
tx = optimizers.get_optimizer(
|
202 |
+
config.optimizer_configs, learning_rate_fn, params=initial_params)
|
203 |
+
|
204 |
+
opt_state = tx.init(initial_params)
|
205 |
+
|
206 |
+
state = utils.TrainState(
|
207 |
+
step=1, opt_state=opt_state, params=initial_params, rng=rng,
|
208 |
+
variables=state_vars)
|
209 |
+
|
210 |
+
loss_fn = functools.partial(
|
211 |
+
losses.compute_full_loss, loss_config=config.losses)
|
212 |
+
|
213 |
+
checkpoint_dir = os.path.join(workdir, "checkpoints")
|
214 |
+
ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
|
215 |
+
state = ckpt.restore_or_initialize(state)
|
216 |
+
initial_step = int(state.step)
|
217 |
+
|
218 |
+
# Replicate our parameters.
|
219 |
+
state = flax.jax_utils.replicate(state, devices=jax.local_devices())
|
220 |
+
del rng # rng is stored in the state.
|
221 |
+
|
222 |
+
# Only write metrics on host 0, write to logs on all other hosts.
|
223 |
+
writer = metric_writers.create_default_writer(
|
224 |
+
workdir, just_logging=jax.host_id() > 0)
|
225 |
+
writer.write_hparams(utils.prepare_dict_for_logging(config.to_dict()))
|
226 |
+
|
227 |
+
logging.info("Starting training loop at step %d.", initial_step)
|
228 |
+
report_progress = periodic_actions.ReportProgress(
|
229 |
+
num_train_steps=config.num_train_steps, writer=writer)
|
230 |
+
if jax.process_index() == 0:
|
231 |
+
profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
|
232 |
+
p_train_step = jax.pmap(
|
233 |
+
train_step,
|
234 |
+
axis_name="batch",
|
235 |
+
donate_argnums=(2, 3, 4, 5, 6, 7),
|
236 |
+
static_broadcasted_argnums=(0, 1, 8, 9, 10, 11, 12))
|
237 |
+
|
238 |
+
train_metrics = None
|
239 |
+
with metric_writers.ensure_flushes(writer):
|
240 |
+
if config.num_train_steps == 0:
|
241 |
+
with report_progress.timed("eval"):
|
242 |
+
evaluate(model, state, eval_ds, loss_fn, eval_metrics_cls, config,
|
243 |
+
writer, step=0)
|
244 |
+
with report_progress.timed("checkpoint"):
|
245 |
+
ckpt.save(flax.jax_utils.unreplicate(state))
|
246 |
+
return
|
247 |
+
|
248 |
+
for step in range(initial_step, config.num_train_steps + 1):
|
249 |
+
# `step` is a Python integer. `state.step` is JAX integer on GPU/TPU.
|
250 |
+
is_last_step = step == config.num_train_steps
|
251 |
+
|
252 |
+
with jax.profiler.StepTraceAnnotation("train", step_num=step):
|
253 |
+
batch = jax.tree_map(np.asarray, next(train_iter))
|
254 |
+
(opt_state, params, state_vars, rng, metrics_update, p_step
|
255 |
+
) = p_train_step(
|
256 |
+
model, tx, state.rng, state.step, state.variables,
|
257 |
+
state.opt_state, state.params, batch, loss_fn,
|
258 |
+
train_metrics_cls,
|
259 |
+
config.num_slots,
|
260 |
+
config.max_instances + 1, # Incl. background.
|
261 |
+
config.get("conditioning_key"))
|
262 |
+
|
263 |
+
state = state.replace( # pytype: disable=attribute-error
|
264 |
+
opt_state=opt_state,
|
265 |
+
params=params,
|
266 |
+
step=p_step,
|
267 |
+
variables=state_vars,
|
268 |
+
rng=rng,
|
269 |
+
)
|
270 |
+
|
271 |
+
metric_update = flax.jax_utils.unreplicate(metrics_update)
|
272 |
+
train_metrics = (
|
273 |
+
metric_update
|
274 |
+
if train_metrics is None else train_metrics.merge(metric_update))
|
275 |
+
|
276 |
+
# Quick indication that training is happening.
|
277 |
+
logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
|
278 |
+
report_progress(step, time.time())
|
279 |
+
|
280 |
+
if jax.process_index() == 0:
|
281 |
+
profiler(step)
|
282 |
+
|
283 |
+
if step % config.log_loss_every_steps == 0 or is_last_step:
|
284 |
+
metrics_res = train_metrics.compute()
|
285 |
+
writer.write_scalars(step, jax.tree_map(np.array, metrics_res))
|
286 |
+
train_metrics = None
|
287 |
+
|
288 |
+
if step % config.eval_every_steps == 0 or is_last_step:
|
289 |
+
with report_progress.timed("eval"):
|
290 |
+
evaluate(model, state, eval_ds, loss_fn, eval_metrics_cls,
|
291 |
+
config, writer, step=step)
|
292 |
+
|
293 |
+
if step % config.checkpoint_every_steps == 0 or is_last_step:
|
294 |
+
with report_progress.timed("checkpoint"):
|
295 |
+
ckpt.save(flax.jax_utils.unreplicate(state))
|
296 |
+
|
297 |
+
|
298 |
+
def evaluate(model, state, eval_ds, loss_fn_eval, eval_metrics_cls, config,
|
299 |
+
writer, step):
|
300 |
+
"""Evaluate the model."""
|
301 |
+
eval_metrics, eval_batch, eval_preds = evaluator.evaluate(
|
302 |
+
model,
|
303 |
+
state,
|
304 |
+
eval_ds,
|
305 |
+
loss_fn_eval,
|
306 |
+
eval_metrics_cls,
|
307 |
+
predicted_max_num_instances=config.num_slots,
|
308 |
+
ground_truth_max_num_instances=config.max_instances + 1, # Incl. bg.
|
309 |
+
slice_size=config.get("eval_slice_size"),
|
310 |
+
slice_keys=config.get("eval_slice_keys"),
|
311 |
+
conditioning_key=config.get("conditioning_key"),
|
312 |
+
remove_from_predictions=config.get("remove_from_predictions"),
|
313 |
+
metrics_on_cpu=config.get("metrics_on_cpu", False))
|
314 |
+
|
315 |
+
metrics_res = eval_metrics.compute()
|
316 |
+
writer.write_scalars(
|
317 |
+
step, jax.tree_map(np.array, utils.flatten_named_dicttree(metrics_res)))
|
318 |
+
writer.write_images(
|
319 |
+
step,
|
320 |
+
jax.tree_map(
|
321 |
+
np.array,
|
322 |
+
utils.prepare_images_for_logging(
|
323 |
+
config,
|
324 |
+
eval_batch,
|
325 |
+
eval_preds,
|
326 |
+
n_samples=config.get("n_samples", 5),
|
327 |
+
n_frames=config.get("n_frames", 1),
|
328 |
+
min_n_colors=config.get("logging_min_n_colors", 1))))
|
invariant_slot_attention/lib/transforms.py
ADDED
@@ -0,0 +1,163 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Transform functions for preprocessing."""
|
17 |
+
from typing import Any, Optional, Tuple
|
18 |
+
|
19 |
+
import tensorflow as tf
|
20 |
+
|
21 |
+
|
22 |
+
SizeTuple = Tuple[tf.Tensor, tf.Tensor] # (height, width).
|
23 |
+
Self = Any
|
24 |
+
|
25 |
+
PADDING_VALUE = -1
|
26 |
+
PADDING_VALUE_STR = b""
|
27 |
+
|
28 |
+
NOTRACK_BOX = (0., 0., 0., 0.) # No-track bounding box for padding.
|
29 |
+
NOTRACK_RBOX = (0., 0., 0., 0., 0.) # No-track bounding rbox for padding.
|
30 |
+
|
31 |
+
|
32 |
+
def crop_or_pad_boxes(boxes, top, left, height,
|
33 |
+
width, h_orig, w_orig,
|
34 |
+
min_cropped_area = None):
|
35 |
+
"""Transforms the relative box coordinates according to the frame crop.
|
36 |
+
|
37 |
+
Note that, if height/width are larger than h_orig/w_orig, this function
|
38 |
+
implements the equivalent of padding.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
boxes: Tensor of bounding boxes with shape (..., 4).
|
42 |
+
top: Top of crop box in absolute pixel coordinates.
|
43 |
+
left: Left of crop box in absolute pixel coordinates.
|
44 |
+
height: Height of crop box in absolute pixel coordinates.
|
45 |
+
width: Width of crop box in absolute pixel coordinates.
|
46 |
+
h_orig: Original image height in absolute pixel coordinates.
|
47 |
+
w_orig: Original image width in absolute pixel coordinates.
|
48 |
+
min_cropped_area: If set, remove cropped boxes whose area relative to the
|
49 |
+
original box is less than min_cropped_area or that covers the entire
|
50 |
+
image.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
Boxes tensor with same shape as input boxes but updated values.
|
54 |
+
"""
|
55 |
+
# Video track bound boxes: [num_instances, num_tracks, 4]
|
56 |
+
# Image bounding boxes: [num_instances, 4]
|
57 |
+
assert boxes.shape[-1] == 4
|
58 |
+
seq_len = tf.shape(boxes)[0]
|
59 |
+
not_padding = tf.reduce_any(tf.not_equal(boxes, PADDING_VALUE), axis=-1)
|
60 |
+
has_tracks = len(boxes.shape) == 3
|
61 |
+
if has_tracks:
|
62 |
+
num_tracks = tf.shape(boxes)[1]
|
63 |
+
else:
|
64 |
+
assert len(boxes.shape) == 2
|
65 |
+
num_tracks = 1
|
66 |
+
|
67 |
+
# Transform the box coordinates.
|
68 |
+
a = tf.cast(tf.stack([h_orig, w_orig]), tf.float32)
|
69 |
+
b = tf.cast(tf.stack([top, left]), tf.float32)
|
70 |
+
c = tf.cast(tf.stack([height, width]), tf.float32)
|
71 |
+
boxes = tf.reshape(
|
72 |
+
(tf.reshape(boxes, (seq_len, num_tracks, 2, 2)) * a - b) / c,
|
73 |
+
(seq_len, num_tracks, len(NOTRACK_BOX)),
|
74 |
+
)
|
75 |
+
|
76 |
+
# Filter the valid boxes.
|
77 |
+
areas_uncropped = tf.reduce_prod(
|
78 |
+
tf.maximum(boxes[Ellipsis, 2:] - boxes[Ellipsis, :2], 0), axis=-1
|
79 |
+
)
|
80 |
+
boxes = tf.minimum(tf.maximum(boxes, 0.0), 1.0)
|
81 |
+
if has_tracks:
|
82 |
+
cond = tf.reduce_all((boxes[:, :, 2:] - boxes[:, :, :2]) > 0.0, axis=-1)
|
83 |
+
boxes = tf.where(cond[:, :, tf.newaxis], boxes, NOTRACK_BOX)
|
84 |
+
if min_cropped_area is not None:
|
85 |
+
areas_cropped = tf.reduce_prod(
|
86 |
+
tf.maximum(boxes[Ellipsis, 2:] - boxes[Ellipsis, :2], 0), axis=-1
|
87 |
+
)
|
88 |
+
boxes = tf.where(
|
89 |
+
tf.logical_and(
|
90 |
+
tf.reduce_max(areas_cropped, axis=0, keepdims=True)
|
91 |
+
> min_cropped_area * areas_uncropped,
|
92 |
+
tf.reduce_min(areas_cropped, axis=0, keepdims=True) < 1,
|
93 |
+
)[Ellipsis, tf.newaxis],
|
94 |
+
boxes,
|
95 |
+
tf.constant(NOTRACK_BOX)[tf.newaxis, tf.newaxis],
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
boxes = tf.reshape(boxes, (seq_len, 4))
|
99 |
+
# Image ops use `-1`, whereas video ops above use `NOTRACK_BOX`.
|
100 |
+
boxes = tf.where(not_padding[Ellipsis, tf.newaxis], boxes, PADDING_VALUE)
|
101 |
+
|
102 |
+
return boxes
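A small self-contained example with made-up values: a normalized box in a 128x128 frame, re-expressed relative to a 64x64 crop starting at pixel (32, 32).

import tensorflow as tf

# One frame, one box over the central quarter of a 128x128 image,
# in normalized [y_min, x_min, y_max, x_max] coordinates.
boxes = tf.constant([[0.25, 0.25, 0.75, 0.75]])
new_boxes = crop_or_pad_boxes(
    boxes, top=32, left=32, height=64, width=64, h_orig=128, w_orig=128)
# The box now spans the whole crop: [[0., 0., 1., 1.]]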
|
103 |
+
|
104 |
+
|
105 |
+
def cxcywha_to_corners(cxcywha):
|
106 |
+
"""Convert [cx, cy, w, h, a] to four corners of [x, y].
|
107 |
+
|
108 |
+
TF version of cxcywha_to_corners in
|
109 |
+
third_party/py/scenic/model_lib/base_models/box_utils.py.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
cxcywha: [..., 5]-tf.Tensor of [center-x, center-y, width, height, angle]
|
113 |
+
representation of rotated boxes. Angle is in radians and center of rotation
|
114 |
+
is defined by [center-x, center-y] point.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
[..., 4, 2]-tf.Tensor of four corners of the rotated box as [x, y] points.
|
118 |
+
"""
|
119 |
+
assert cxcywha.shape[-1] == 5, "Expected [..., [cx, cy, w, h, a]] input."
|
120 |
+
bs = cxcywha.shape[:-1]
|
121 |
+
cx, cy, w, h, a = tf.split(cxcywha, num_or_size_splits=5, axis=-1)
|
122 |
+
xs = tf.constant([.5, .5, -.5, -.5]) * w
|
123 |
+
ys = tf.constant([-.5, .5, .5, -.5]) * h
|
124 |
+
pts = tf.stack([xs, ys], axis=-1)
|
125 |
+
sin = tf.sin(a)
|
126 |
+
cos = tf.cos(a)
|
127 |
+
rot = tf.reshape(tf.concat([cos, -sin, sin, cos], axis=-1), (*bs, 2, 2))
|
128 |
+
offset = tf.reshape(tf.concat([cx, cy], -1), (*bs, 1, 2))
|
129 |
+
corners = pts @ rot + offset
|
130 |
+
return corners
|
131 |
+
|
132 |
+
|
133 |
+
def corners_to_cxcywha(corners):
|
134 |
+
"""Convert four corners of [x, y] to [cx, cy, w, h, a].
|
135 |
+
|
136 |
+
Args:
|
137 |
+
corners: [..., 4, 2]-tf.Tensor of four corners of the rotated box as [x, y]
|
138 |
+
points.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
[..., 5]-tf.Tensor of [center-x, center-y, width, height, angle]
|
142 |
+
representation of rotated boxes. Angle is in radians and center of rotation
|
143 |
+
is defined by [center-x, center-y] point.
|
144 |
+
"""
|
145 |
+
assert corners.shape[-2] == 4 and corners.shape[-1] == 2, (
|
146 |
+
"Expected [..., [cx, cy, w, h, a] input.")
|
147 |
+
|
148 |
+
cornersx, cornersy = tf.unstack(corners, axis=-1)
|
149 |
+
cx = tf.reduce_mean(cornersx, axis=-1)
|
150 |
+
cy = tf.reduce_mean(cornersy, axis=-1)
|
151 |
+
wcornersx = (
|
152 |
+
cornersx[Ellipsis, 0] + cornersx[Ellipsis, 1] - cornersx[Ellipsis, 2] - cornersx[Ellipsis, 3])
|
153 |
+
wcornersy = (
|
154 |
+
cornersy[Ellipsis, 0] + cornersy[Ellipsis, 1] - cornersy[Ellipsis, 2] - cornersy[Ellipsis, 3])
|
155 |
+
hcornersy = (-cornersy[Ellipsis, 0,] + cornersy[Ellipsis, 1] + cornersy[Ellipsis, 2] -
|
156 |
+
cornersy[Ellipsis, 3])
|
157 |
+
a = -tf.atan2(wcornersy, wcornersx)
|
158 |
+
cos = tf.cos(a)
|
159 |
+
w = wcornersx / (2 * cos)
|
160 |
+
h = hcornersy / (2 * cos)
|
161 |
+
cxcywha = tf.stack([cx, cy, w, h, a], axis=-1)
|
162 |
+
|
163 |
+
return cxcywha
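A quick round-trip check of the two conversions on a toy rotated box:

import numpy as np
import tensorflow as tf

cxcywha = tf.constant([[10., 20., 4., 2., np.pi / 6.]])  # one rotated box
corners = cxcywha_to_corners(cxcywha)                    # shape [1, 4, 2]
recovered = corners_to_cxcywha(corners)                  # ~= cxcywha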
|
invariant_slot_attention/lib/utils.py
ADDED
@@ -0,0 +1,625 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Common utils."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
import importlib
|
20 |
+
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Type, Union
|
21 |
+
|
22 |
+
from absl import logging
|
23 |
+
from clu import metrics as base_metrics
|
24 |
+
|
25 |
+
import flax
|
26 |
+
from flax import linen as nn
|
27 |
+
from flax import traverse_util
|
28 |
+
|
29 |
+
import jax
|
30 |
+
import jax.numpy as jnp
|
31 |
+
import jax.ops
|
32 |
+
|
33 |
+
import matplotlib
|
34 |
+
import matplotlib.pyplot as plt
|
35 |
+
import ml_collections
|
36 |
+
import numpy as np
|
37 |
+
import optax
|
38 |
+
|
39 |
+
import skimage.transform
|
40 |
+
import tensorflow as tf
|
41 |
+
|
42 |
+
from invariant_slot_attention.lib import metrics
|
43 |
+
|
44 |
+
|
45 |
+
Array = Any # Union[np.ndarray, jnp.ndarray]
|
46 |
+
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
|
47 |
+
DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet
|
48 |
+
PRNGKey = Array
|
49 |
+
ConfigAttr = Any
|
50 |
+
MetricSpec = Dict[str, str]
|
51 |
+
|
52 |
+
|
53 |
+
@flax.struct.dataclass
|
54 |
+
class TrainState:
|
55 |
+
"""Data structure for checkpointing the model."""
|
56 |
+
step: int
|
57 |
+
opt_state: optax.OptState
|
58 |
+
params: ArrayTree
|
59 |
+
variables: flax.core.FrozenDict
|
60 |
+
rng: PRNGKey
|
61 |
+
|
62 |
+
|
63 |
+
METRIC_TYPE_TO_CLS = {
|
64 |
+
"loss": base_metrics.Average.from_output(name="loss"),
|
65 |
+
"ari": metrics.Ari,
|
66 |
+
"ari_nobg": metrics.AriNoBg,
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
def make_metrics_collection(
|
71 |
+
class_name,
|
72 |
+
metrics_spec):
|
73 |
+
"""Create class inhering from metrics.Collection based on spec."""
|
74 |
+
metrics_dict = {}
|
75 |
+
if metrics_spec:
|
76 |
+
for m_name, m_type in metrics_spec.items():
|
77 |
+
metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type]
|
78 |
+
|
79 |
+
return flax.struct.dataclass(
|
80 |
+
type(class_name,
|
81 |
+
(base_metrics.Collection,),
|
82 |
+
{"__annotations__": metrics_dict}))
|
83 |
+
|
84 |
+
|
85 |
+
def flatten_named_dicttree(metrics_res, sep = "/"):
|
86 |
+
"""Flatten dictionary."""
|
87 |
+
metrics_res_flat = {}
|
88 |
+
for k, v in traverse_util.flatten_dict(metrics_res).items():
|
89 |
+
metrics_res_flat[(sep.join(k)).strip(sep)] = v
|
90 |
+
return metrics_res_flat
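For example:

nested = {"eval": {"ari": 0.91, "loss": 0.12}}
flatten_named_dicttree(nested)  # {"eval/ari": 0.91, "eval/loss": 0.12}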
|
91 |
+
|
92 |
+
|
93 |
+
def spatial_broadcast(x, resolution):
|
94 |
+
"""Broadcast flat inputs to a 2D grid of a given resolution."""
|
95 |
+
# x.shape = (batch_size, features).
|
96 |
+
x = x[:, jnp.newaxis, jnp.newaxis, :]
|
97 |
+
return jnp.tile(x, [1, resolution[0], resolution[1], 1])
|
98 |
+
|
99 |
+
|
100 |
+
def time_distributed(cls, in_axes=1, axis=1):
|
101 |
+
"""Wrapper for time-distributed (vmapped) application of a module."""
|
102 |
+
return nn.vmap(
|
103 |
+
cls, in_axes=in_axes, out_axes=axis, axis_name="time",
|
104 |
+
# Stack debug vars along sequence dim and broadcast params.
|
105 |
+
variable_axes={
|
106 |
+
"params": None, "intermediates": axis, "batch_stats": None},
|
107 |
+
split_rngs={"params": False, "dropout": True, "state_init": True})
|
108 |
+
|
109 |
+
|
110 |
+
def broadcast_across_batch(inputs, batch_size):
|
111 |
+
"""Broadcasts inputs across a batch of examples (creates new axis)."""
|
112 |
+
return jnp.broadcast_to(
|
113 |
+
array=jnp.expand_dims(inputs, axis=0),
|
114 |
+
shape=(batch_size,) + inputs.shape)
|
115 |
+
|
116 |
+
|
117 |
+
def create_gradient_grid(
|
118 |
+
samples_per_dim, value_range = (-1.0, 1.0)
|
119 |
+
):
|
120 |
+
"""Creates a tensor with equidistant entries from -1 to +1 in each dim.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
samples_per_dim: Number of points to have along each dimension.
|
124 |
+
value_range: In each dimension, points will go from range[0] to range[1]
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
A tensor of shape [samples_per_dim] + [len(samples_per_dim)].
|
128 |
+
"""
|
129 |
+
s = [jnp.linspace(value_range[0], value_range[1], n) for n in samples_per_dim]
|
130 |
+
pe = jnp.stack(jnp.meshgrid(*s, sparse=False, indexing="ij"), axis=-1)
|
131 |
+
return jnp.array(pe)
|
132 |
+
|
133 |
+
|
134 |
+
def convert_to_fourier_features(inputs, basis_degree):
|
135 |
+
"""Convert inputs to Fourier features, e.g. for positional encoding."""
|
136 |
+
|
137 |
+
# inputs.shape = (..., n_dims).
|
138 |
+
# inputs should be in range [-pi, pi] or [0, 2pi].
|
139 |
+
n_dims = inputs.shape[-1]
|
140 |
+
|
141 |
+
# Generate frequency basis.
|
142 |
+
freq_basis = jnp.concatenate( # shape = (n_dims, n_dims * basis_degree)
|
143 |
+
[2**i * jnp.eye(n_dims) for i in range(basis_degree)], 1)
|
144 |
+
|
145 |
+
# x.shape = (..., n_dims * basis_degree)
|
146 |
+
x = inputs @ freq_basis # Project inputs onto frequency basis.
|
147 |
+
|
148 |
+
# Obtain Fourier features as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)].
|
149 |
+
return jnp.sin(jnp.concatenate([x, x + 0.5 * jnp.pi], axis=-1))
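A small sketch combining the two helpers above into a positional-encoding tensor for a 32x32 feature map; the basis degree is an arbitrary choice here.

import jax.numpy as jnp

grid = create_gradient_grid((32, 32), value_range=(-1.0, 1.0))   # [32, 32, 2]
# Rescale to [-pi, pi], the range convert_to_fourier_features expects.
pos_enc = convert_to_fourier_features(grid * jnp.pi, basis_degree=3)
# pos_enc.shape == (32, 32, 12): sin/cos pairs for 3 octaves in each of 2 dims.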
|
150 |
+
|
151 |
+
|
152 |
+
def prepare_images_for_logging(
|
153 |
+
config,
|
154 |
+
batch = None,
|
155 |
+
preds = None,
|
156 |
+
n_samples = 5,
|
157 |
+
n_frames = 5,
|
158 |
+
min_n_colors = 1,
|
159 |
+
epsilon = 1e-6,
|
160 |
+
first_replica_only = False):
|
161 |
+
"""Prepare images from batch and/or model predictions for logging."""
|
162 |
+
|
163 |
+
images = dict()
|
164 |
+
# Converts all tensors to numpy arrays to run everything on CPU as JAX
|
165 |
+
# eager mode is inefficient and because memory usage from these ops may
|
166 |
+
# lead to OOM errors.
|
167 |
+
batch = jax.tree_map(np.array, batch)
|
168 |
+
preds = jax.tree_map(np.array, preds)
|
169 |
+
|
170 |
+
if n_samples <= 0:
|
171 |
+
return images
|
172 |
+
|
173 |
+
if not first_replica_only:
|
174 |
+
# Move the two leading batch dimensions into a single dimension. We do this
|
175 |
+
# to plot the same number of examples regardless of the data parallelism.
|
176 |
+
batch = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), batch)
|
177 |
+
preds = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), preds)
|
178 |
+
else:
|
179 |
+
batch = jax.tree_map(lambda x: x[0], batch)
|
180 |
+
preds = jax.tree_map(lambda x: x[0], preds)
|
181 |
+
|
182 |
+
# Limit the tensors to n_samples and n_frames.
|
183 |
+
batch = jax.tree_map(
|
184 |
+
lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
|
185 |
+
batch)
|
186 |
+
preds = jax.tree_map(
|
187 |
+
lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
|
188 |
+
preds)
|
189 |
+
|
190 |
+
# Log input data.
|
191 |
+
if batch is not None:
|
192 |
+
images["video"] = video_to_image_grid(batch["video"])
|
193 |
+
if "segmentations" in batch:
|
194 |
+
images["mask"] = video_to_image_grid(convert_categories_to_color(
|
195 |
+
batch["segmentations"], min_n_colors=min_n_colors))
|
196 |
+
if "flow" in batch:
|
197 |
+
images["flow"] = video_to_image_grid(batch["flow"])
|
198 |
+
if "boxes" in batch:
|
199 |
+
images["boxes"] = draw_bounding_boxes(
|
200 |
+
batch["video"],
|
201 |
+
batch["boxes"],
|
202 |
+
min_n_colors=min_n_colors)
|
203 |
+
|
204 |
+
# Log model predictions.
|
205 |
+
if preds is not None and preds.get("outputs") is not None:
|
206 |
+
if "segmentations" in preds["outputs"]: # pytype: disable=attribute-error
|
207 |
+
images["segmentations"] = video_to_image_grid(
|
208 |
+
convert_categories_to_color(
|
209 |
+
preds["outputs"]["segmentations"], min_n_colors=min_n_colors))
|
210 |
+
|
211 |
+
def shape_fn(x):
|
212 |
+
if isinstance(x, (np.ndarray, jnp.ndarray)):
|
213 |
+
return x.shape
|
214 |
+
|
215 |
+
# Log intermediate variables.
|
216 |
+
if preds is not None and "intermediates" in preds:
|
217 |
+
|
218 |
+
logging.info("intermediates: %s",
|
219 |
+
jax.tree_map(shape_fn, preds["intermediates"]))
|
220 |
+
|
221 |
+
for key, path in config.debug_var_video_paths.items():
|
222 |
+
log_vars = retrieve_from_collection(preds["intermediates"], path)
|
223 |
+
if log_vars is not None:
|
224 |
+
if not isinstance(log_vars, Sequence):
|
225 |
+
log_vars = [log_vars]
|
226 |
+
for i, log_var in enumerate(log_vars):
|
227 |
+
log_var = np.array(log_var) # Moves log_var to CPU.
|
228 |
+
images[key + "_" + str(i)] = video_to_image_grid(log_var)
|
229 |
+
else:
|
230 |
+
logging.warning("%s not found in intermediates", path)
|
231 |
+
|
232 |
+
# Log attention weights.
|
233 |
+
for key, path in config.debug_var_attn_paths.items():
|
234 |
+
log_vars = retrieve_from_collection(preds["intermediates"], path)
|
235 |
+
if log_vars is not None:
|
236 |
+
if not isinstance(log_vars, Sequence):
|
237 |
+
log_vars = [log_vars]
|
238 |
+
for i, log_var in enumerate(log_vars):
|
239 |
+
log_var = np.array(log_var) # Moves log_var to CPU.
|
240 |
+
images.update(
|
241 |
+
prepare_attention_maps_for_logging(
|
242 |
+
attn_maps=log_var,
|
243 |
+
key=key + "_" + str(i),
|
244 |
+
map_width=config.debug_var_attn_widths.get(key),
|
245 |
+
video=batch["video"],
|
246 |
+
epsilon=epsilon,
|
247 |
+
n_samples=n_samples,
|
248 |
+
n_frames=n_frames))
|
249 |
+
else:
|
250 |
+
logging.warning("%s not found in intermediates", path)
|
251 |
+
|
252 |
+
# Crop each image to a maximum of 3 channels for RGB visualization.
|
253 |
+
for key, image in images.items():
|
254 |
+
if image.shape[-1] > 3:
|
255 |
+
logging.warning("Truncating channels of %s for visualization.", key)
|
256 |
+
images[key] = image[Ellipsis, :3]
|
257 |
+
|
258 |
+
return images
|
259 |
+
|
260 |
+
|
261 |
+
def prepare_attention_maps_for_logging(attn_maps, key,
|
262 |
+
map_width, epsilon,
|
263 |
+
n_samples, n_frames,
|
264 |
+
video):
|
265 |
+
"""Visualize (overlayed) attention maps as an image grid."""
|
266 |
+
images = {} # Results dictionary.
|
267 |
+
attn_maps = unflatten_image(attn_maps[Ellipsis, None], width=map_width)
|
268 |
+
|
269 |
+
num_heads = attn_maps.shape[2]
|
270 |
+
for head_idx in range(num_heads):
|
271 |
+
attn = attn_maps[:n_samples, :n_frames, head_idx]
|
272 |
+
attn /= attn.max() + epsilon # Standardizes scale for visualization.
|
273 |
+
# attn.shape: [bs, seq_len, 11, h', w', 1]
|
274 |
+
|
275 |
+
bs, seq_len, _, h_attn, w_attn, _ = attn.shape
|
276 |
+
images[f"{key}_head_{head_idx}"] = video_to_image_grid(attn)
|
277 |
+
|
278 |
+
# Attention maps are interpretable when they align with object boundaries.
|
279 |
+
# However, if they are overly smooth then the following visualization which
|
280 |
+
# overlays attention maps on video is helpful.
|
281 |
+
video = video[:n_samples, :n_frames]
|
282 |
+
# video.shape: [bs, seq_len, h, w, 3]
|
283 |
+
video_resized = []
|
284 |
+
for i in range(n_samples):
|
285 |
+
for j in range(n_frames):
|
286 |
+
video_resized.append(
|
287 |
+
skimage.transform.resize(video[i, j], (h_attn, w_attn), order=1))
|
288 |
+
video_resized = np.array(video_resized).reshape(
|
289 |
+
(bs, seq_len, h_attn, w_attn, 3))
|
290 |
+
attn_overlayed = attn * np.expand_dims(video_resized, 2)
|
291 |
+
images[f"{key}_head_{head_idx}_overlayed"] = video_to_image_grid(
|
292 |
+
attn_overlayed)
|
293 |
+
|
294 |
+
return images
|
295 |
+
|
296 |
+
|
297 |
+
def convert_categories_to_color(
|
298 |
+
inputs, min_n_colors = 1, include_black = True):
|
299 |
+
"""Converts int-valued categories to color in last axis of input tensor.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
inputs: `np.ndarray` of arbitrary shape with integer entries, encoding the
|
303 |
+
categories.
|
304 |
+
min_n_colors: Minimum number of colors (excl. black) to encode categories.
|
305 |
+
include_black: Include black as 0-th entry in the color palette. Increases
|
306 |
+
`min_n_colors` by 1 if True.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
`np.ndarray` with RGB colors in last axis.
|
310 |
+
"""
|
311 |
+
if inputs.shape[-1] == 1: # Strip category axis.
|
312 |
+
inputs = np.squeeze(inputs, axis=-1)
|
313 |
+
inputs = np.array(inputs, dtype=np.int32) # Convert to int.
|
314 |
+
|
315 |
+
# Infer number of colors from inputs.
|
316 |
+
n_colors = int(inputs.max()) + 1 # One color per category incl. 0.
|
317 |
+
if include_black:
|
318 |
+
n_colors -= 1 # If we include black, we need one color less.
|
319 |
+
|
320 |
+
if min_n_colors > n_colors: # Use more colors in color palette if requested.
|
321 |
+
n_colors = min_n_colors
|
322 |
+
|
323 |
+
rgb_colors = get_uniform_colors(n_colors)
|
324 |
+
|
325 |
+
if include_black: # Add black as color for zero-th index.
|
326 |
+
rgb_colors = np.concatenate((np.zeros((1, 3)), rgb_colors), axis=0)
|
327 |
+
return rgb_colors[inputs]
|
328 |
+
|
329 |
+
|
330 |
+
def get_uniform_colors(n_colors):
|
331 |
+
"""Get n_colors with uniformly spaced hues."""
|
332 |
+
hues = np.linspace(0, 1, n_colors, endpoint=False)
|
333 |
+
hsv_colors = np.concatenate(
|
334 |
+
(np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1)
|
335 |
+
rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors)
|
336 |
+
return rgb_colors # rgb_colors.shape = (n_colors, 3)
|
337 |
+
|
338 |
+
|
339 |
+
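A quick sketch of how the two color helpers above combine, using a made-up 2x3 segmentation map (the array and sizes are illustrative, not taken from the pipeline):

import numpy as np
from invariant_slot_attention.lib import utils

seg = np.array([[[0], [1], [2]],
                [[2], [1], [0]]])  # (2, 3, 1) integer categories, 0 = background
colored = utils.convert_categories_to_color(seg, min_n_colors=3, include_black=True)
# colored.shape == (2, 3, 3): category 0 maps to black, the remaining categories
# map to the evenly spaced hues produced by utils.get_uniform_colors(3).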
def unflatten_image(image, width = None):
|
340 |
+
"""Unflatten image array of shape [batch_dims..., height*width, channels]."""
|
341 |
+
n_channels = image.shape[-1]
|
342 |
+
# If width is not provided, we assume that the image is square.
|
343 |
+
if width is None:
|
344 |
+
width = int(np.floor(np.sqrt(image.shape[-2])))
|
345 |
+
height = width
|
346 |
+
assert width * height == image.shape[-2], "Image is not square."
|
347 |
+
else:
|
348 |
+
height = image.shape[-2] // width
|
349 |
+
return image.reshape(image.shape[:-2] + (height, width, n_channels))
|
350 |
+
|
351 |
+
|
352 |
+
def video_to_image_grid(video):
|
353 |
+
"""Transform video to image grid by folding sequence dim along width."""
|
354 |
+
if len(video.shape) == 5:
|
355 |
+
n_samples, n_frames, height, width, n_channels = video.shape
|
356 |
+
video = np.transpose(video, (0, 2, 1, 3, 4)) # Swap n_frames and height.
|
357 |
+
image_grid = np.reshape(
|
358 |
+
video, (n_samples, height, n_frames * width, n_channels))
|
359 |
+
elif len(video.shape) == 6:
|
360 |
+
n_samples, n_frames, n_slots, height, width, n_channels = video.shape
|
361 |
+
# Put n_frames next to width.
|
362 |
+
video = np.transpose(video, (0, 2, 3, 1, 4, 5))
|
363 |
+
image_grid = np.reshape(
|
364 |
+
video, (n_samples, n_slots * height, n_frames * width, n_channels))
|
365 |
+
else:
|
366 |
+
raise ValueError("Unsupported video shape for visualization.")
|
367 |
+
return image_grid
|
368 |
+
|
369 |
+
|
370 |
+
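To make the width/height folding in video_to_image_grid concrete, here is a shape-only sketch with dummy arrays (all sizes are illustrative):

import numpy as np
from invariant_slot_attention.lib import utils

video = np.zeros((2, 4, 64, 64, 3))        # (n_samples, n_frames, h, w, c)
grid = utils.video_to_image_grid(video)
# grid.shape == (2, 64, 256, 3): the 4 frames are folded along the width axis.

per_slot = np.zeros((2, 4, 7, 64, 64, 1))  # extra n_slots axis, e.g. alpha masks
grid = utils.video_to_image_grid(per_slot)
# grid.shape == (2, 448, 256, 1): slots stack along height, frames along width.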
def draw_bounding_boxes(video,
|
371 |
+
boxes,
|
372 |
+
min_n_colors = 1,
|
373 |
+
include_black = True):
|
374 |
+
"""Draw bounding boxes in videos."""
|
375 |
+
colors = get_uniform_colors(min_n_colors - include_black)
|
376 |
+
|
377 |
+
b, t, h, w, c = video.shape
|
378 |
+
n = boxes.shape[2]
|
379 |
+
image_grid = tf.image.draw_bounding_boxes(
|
380 |
+
np.reshape(video, (b * t, h, w, c)),
|
381 |
+
np.reshape(boxes, (b * t, n, 4)),
|
382 |
+
colors).numpy()
|
383 |
+
image_grid = np.reshape(
|
384 |
+
np.transpose(np.reshape(image_grid, (b, t, h, w, c)),
|
385 |
+
(0, 2, 1, 3, 4)),
|
386 |
+
(b, h, t * w, c))
|
387 |
+
return image_grid
|
388 |
+
|
389 |
+
|
390 |
+
def plot_image(ax, image):
|
391 |
+
"""Add an image visualization to a provided `plt.Axes` instance."""
|
392 |
+
num_channels = image.shape[-1]
|
393 |
+
if num_channels == 1:
|
394 |
+
image = image.reshape(image.shape[:2])
|
395 |
+
ax.imshow(image, cmap="viridis")
|
396 |
+
ax.grid(False)
|
397 |
+
plt.axis("off")
|
398 |
+
|
399 |
+
|
400 |
+
def visualize_image_dict(images, plot_scale = 10):
|
401 |
+
"""Visualize a dictionary of images in colab using maptlotlib."""
|
402 |
+
|
403 |
+
for key in images.keys():
|
404 |
+
logging.info("Visualizing key: %s", key)
|
405 |
+
n_images = len(images[key])
|
406 |
+
fig = plt.figure(figsize=(n_images * plot_scale, plot_scale))
|
407 |
+
for idx, image in enumerate(images[key]):
|
408 |
+
ax = fig.add_subplot(1, n_images, idx+1)
|
409 |
+
plot_image(ax, image)
|
410 |
+
plt.show()
|
411 |
+
|
412 |
+
|
413 |
+
def filter_key_from_frozen_dict(
|
414 |
+
frozen_dict, key):
|
415 |
+
"""Filters (removes) an item by key from a flax.core.FrozenDict."""
|
416 |
+
if key in frozen_dict:
|
417 |
+
frozen_dict, _ = frozen_dict.pop(key)
|
418 |
+
return frozen_dict
|
419 |
+
|
420 |
+
|
421 |
+
def prepare_dict_for_logging(nested_dict, parent_key = "",
|
422 |
+
sep = "_"):
|
423 |
+
"""Prepare a nested dictionary for logging with `clu.metric_writers`.
|
424 |
+
|
425 |
+
Args:
|
426 |
+
nested_dict: A nested dictionary, e.g. obtained from a
|
427 |
+
`ml_collections.ConfigDict` via `.to_dict()`.
|
428 |
+
parent_key: String used in recursion.
|
429 |
+
sep: String used to separate parent and child keys.
|
430 |
+
|
431 |
+
Returns:
|
432 |
+
Flattened dict.
|
433 |
+
"""
|
434 |
+
items = []
|
435 |
+
for k, v in nested_dict.items():
|
436 |
+
# Flatten keys of nested elements.
|
437 |
+
new_key = parent_key + sep + k if parent_key else k
|
438 |
+
|
439 |
+
# Convert None values, lists and tuples to strings.
|
440 |
+
if v is None:
|
441 |
+
v = "None"
|
442 |
+
if isinstance(v, list) or isinstance(v, tuple):
|
443 |
+
v = str(v)
|
444 |
+
|
445 |
+
# Recursively flatten the dict.
|
446 |
+
if isinstance(v, dict):
|
447 |
+
items.extend(prepare_dict_for_logging(v, new_key, sep=sep).items())
|
448 |
+
else:
|
449 |
+
items.append((new_key, v))
|
450 |
+
return dict(items)
|
451 |
+
|
452 |
+
|
453 |
+
def retrieve_from_collection(
|
454 |
+
variable_collection, path):
|
455 |
+
"""Finds variables by their path by recursively searching the collection.
|
456 |
+
|
457 |
+
Args:
|
458 |
+
variable_collection: Nested dict containing the variables (or tuples/lists
|
459 |
+
of variables).
|
460 |
+
path: Path to variable in module tree, similar to Unix file names (e.g.
|
461 |
+
'/module/dense/0/bias').
|
462 |
+
|
463 |
+
Returns:
|
464 |
+
The requested variable, variable collection or None (in case the variable
|
465 |
+
could not be found).
|
466 |
+
"""
|
467 |
+
key, _, rpath = path.strip("/").partition("/")
|
468 |
+
|
469 |
+
# In case the variable is not found, we return None.
|
470 |
+
if (key.isdigit() and not isinstance(variable_collection, Sequence)) or (
|
471 |
+
key.isdigit() and int(key) >= len(variable_collection)) or (
|
472 |
+
not key.isdigit() and key not in variable_collection):
|
473 |
+
return None
|
474 |
+
|
475 |
+
if key.isdigit():
|
476 |
+
key = int(key)
|
477 |
+
|
478 |
+
if not rpath:
|
479 |
+
return variable_collection[key]
|
480 |
+
else:
|
481 |
+
return retrieve_from_collection(variable_collection[key], rpath)
|
482 |
+
|
483 |
+
|
484 |
+
def build_model_from_config(config):
|
485 |
+
"""Build a Flax model from a (nested) ConfigDict."""
|
486 |
+
model_constructor = _parse_config(config)
|
487 |
+
if callable(model_constructor):
|
488 |
+
return model_constructor()
|
489 |
+
else:
|
490 |
+
raise ValueError("Provided config does not contain module constructors.")
|
491 |
+
|
492 |
+
|
493 |
+
def _parse_config(config
|
494 |
+
):
|
495 |
+
"""Recursively parses a nested ConfigDict and resolves module constructors."""
|
496 |
+
|
497 |
+
if isinstance(config, list):
|
498 |
+
return [_parse_config(c) for c in config]
|
499 |
+
elif isinstance(config, tuple):
|
500 |
+
return tuple([_parse_config(c) for c in config])
|
501 |
+
elif not isinstance(config, ml_collections.ConfigDict):
|
502 |
+
return config
|
503 |
+
elif "module" in config:
|
504 |
+
module_constructor = _resolve_module_constructor(config.module)
|
505 |
+
kwargs = {k: _parse_config(v) for k, v in config.items() if k != "module"}
|
506 |
+
return functools.partial(module_constructor, **kwargs)
|
507 |
+
else:
|
508 |
+
return {k: _parse_config(v) for k, v in config.items()}
|
509 |
+
|
510 |
+
|
511 |
+
def _resolve_module_constructor(
|
512 |
+
constructor_str):
|
513 |
+
import_str, _, module_name = constructor_str.rpartition(".")
|
514 |
+
py_module = importlib.import_module(import_str)
|
515 |
+
return getattr(py_module, module_name)
|
516 |
+
|
517 |
+
|
518 |
+
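As a sketch of the config convention that build_model_from_config expects: any ConfigDict carrying a "module" key is resolved to a constructor by import path, and every other key becomes a (recursively parsed) keyword argument. The particular module and values below are illustrative only:

import ml_collections
from invariant_slot_attention.lib import utils

config = ml_collections.ConfigDict({
    "module": "invariant_slot_attention.modules.SlotAttention",
    "num_iterations": 3,
    "qkv_size": 64,
})
model = utils.build_model_from_config(config)
# Equivalent to constructing SlotAttention(num_iterations=3, qkv_size=64).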
def get_slices_along_axis(
|
519 |
+
inputs,
|
520 |
+
slice_keys,
|
521 |
+
start_idx = 0,
|
522 |
+
end_idx = -1,
|
523 |
+
axis = 2,
|
524 |
+
pad_value = 0):
|
525 |
+
"""Extracts slices from a dictionary of tensors along the specified axis.
|
526 |
+
|
527 |
+
The slice operation is only applied to `slice_keys` dictionary keys. If
|
528 |
+
`end_idx` is larger than the actual size of the specified axis, padding is
|
529 |
+
added (with values provided in `pad_value`).
|
530 |
+
|
531 |
+
Args:
|
532 |
+
inputs: Dictionary of tensors.
|
533 |
+
slice_keys: Iterable of strings, the keys for the inputs dictionary for
|
534 |
+
which to apply the slice operation.
|
535 |
+
start_idx: Integer, defining the first index to be part of the slice.
|
536 |
+
end_idx: Integer, defining the end of the slice interval (exclusive). If set
|
537 |
+
to `-1`, the end index is set to the size of the axis. If a value is
|
538 |
+
provided that is larger than the size of the axis, zero-padding is added
|
539 |
+
for the remaining elements.
|
540 |
+
axis: Integer, the axis along which to slice.
|
541 |
+
pad_value: Integer, value to be used in padding.
|
542 |
+
|
543 |
+
Returns:
|
544 |
+
Dictionary of tensors where elements described in `slice_keys` are sliced,
|
545 |
+
and all other elements are returned as original.
|
546 |
+
"""
|
547 |
+
|
548 |
+
max_size = None
|
549 |
+
pad_size = 0
|
550 |
+
|
551 |
+
# Check shapes and get maximum size of requested axis.
|
552 |
+
for key in slice_keys:
|
553 |
+
curr_size = inputs[key].shape[axis]
|
554 |
+
if max_size is None:
|
555 |
+
max_size = curr_size
|
556 |
+
elif max_size != curr_size:
|
557 |
+
raise ValueError(
|
558 |
+
"For specified tensors the requested axis needs to be of equal size.")
|
559 |
+
|
560 |
+
# Infer end index if not provided.
|
561 |
+
if end_idx == -1:
|
562 |
+
end_idx = max_size
|
563 |
+
|
564 |
+
# Set padding size if end index is larger than maximum size of requested axis.
|
565 |
+
elif end_idx > max_size:
|
566 |
+
pad_size = end_idx - max_size
|
567 |
+
end_idx = max_size
|
568 |
+
|
569 |
+
outputs = {}
|
570 |
+
for key in slice_keys:
|
571 |
+
outputs[key] = np.take(
|
572 |
+
inputs[key], indices=np.arange(start_idx, end_idx), axis=axis)
|
573 |
+
|
574 |
+
# Add padding if necessary.
|
575 |
+
if pad_size > 0:
|
576 |
+
pad_shape = np.array(outputs[key].shape)
|
577 |
+
np.put(pad_shape, axis, pad_size) # In-place op.
|
578 |
+
padding = pad_value * np.ones(pad_shape, dtype=outputs[key].dtype)
|
579 |
+
outputs[key] = np.concatenate((outputs[key], padding), axis=axis)
|
580 |
+
|
581 |
+
return outputs
|
582 |
+
|
583 |
+
|
584 |
+
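A small usage sketch for get_slices_along_axis, showing the zero-padding path when end_idx exceeds the axis size (the batch below is a dummy stand-in for the real input dict):

import numpy as np
from invariant_slot_attention.lib import utils

batch = {
    "video": np.ones((8, 6, 64, 64, 3)),
    "segmentations": np.ones((8, 6, 64, 64)),
}
out = utils.get_slices_along_axis(
    batch, slice_keys=("video", "segmentations"),
    start_idx=0, end_idx=9, axis=1, pad_value=0)
# Frames [0, 6) are kept and 3 zero-padded frames are appended:
# out["video"].shape == (8, 9, 64, 64, 3), out["segmentations"].shape == (8, 9, 64, 64).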
def get_element_by_str(
|
585 |
+
dictionary, multilevel_key, separator = "/"
|
586 |
+
):
|
587 |
+
"""Gets element in a dictionary with multilevel key (e.g., "key1/key2")."""
|
588 |
+
keys = multilevel_key.split(separator)
|
589 |
+
if len(keys) == 1:
|
590 |
+
return dictionary[keys[0]]
|
591 |
+
return get_element_by_str(
|
592 |
+
dictionary[keys[0]], separator.join(keys[1:]), separator=separator)
|
593 |
+
|
594 |
+
|
595 |
+
def set_element_by_str(
|
596 |
+
dictionary, multilevel_key, new_value,
|
597 |
+
separator = "/"):
|
598 |
+
"""Sets element in a dictionary with multilevel key (e.g., "key1/key2")."""
|
599 |
+
keys = multilevel_key.split(separator)
|
600 |
+
if len(keys) == 1:
|
601 |
+
if keys[0] not in dictionary:
|
602 |
+
key_error = (
|
603 |
+
"Pretrained {key} was not found in trained model. "
|
604 |
+
"Make sure you are loading the correct pretrained model "
|
605 |
+
"or consider adding {key} to exceptions.")
|
606 |
+
raise KeyError(key_error.format(type="parameter", key=keys[0]))
|
607 |
+
dictionary[keys[0]] = new_value
|
608 |
+
else:
|
609 |
+
set_element_by_str(
|
610 |
+
dictionary[keys[0]],
|
611 |
+
separator.join(keys[1:]),
|
612 |
+
new_value,
|
613 |
+
separator=separator)
|
614 |
+
|
615 |
+
|
616 |
+
def remove_singleton_dim(inputs):
|
617 |
+
"""Removes the final dimension if it is singleton (i.e. of size 1)."""
|
618 |
+
if inputs is None:
|
619 |
+
return None
|
620 |
+
if inputs.shape[-1] != 1:
|
621 |
+
logging.warning("Expected final dimension of inputs to be 1, "
|
622 |
+
"received inputs of shape %s: ", str(inputs.shape))
|
623 |
+
return inputs
|
624 |
+
return inputs[Ellipsis, 0]
|
625 |
+
|
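Rounding off the utils helpers, a minimal sketch of the multilevel-key accessors (the nested dict and key names are hypothetical):

from invariant_slot_attention.lib import utils

params = {"encoder": {"dense": {"kernel": 1.0}}}
utils.get_element_by_str(params, "encoder/dense/kernel")       # -> 1.0
utils.set_element_by_str(params, "encoder/dense/kernel", 2.0)  # in-place update
# Setting a leaf that does not exist raises a KeyError, which is how mismatches
# against a pretrained checkpoint surface.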
invariant_slot_attention/modules/__init__.py
ADDED
@@ -0,0 +1,49 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Module library."""
|
17 |
+
# pylint: disable=g-multiple-import
|
18 |
+
# pylint: disable=g-bad-import-order
|
19 |
+
# Re-export commonly used modules and functions
|
20 |
+
|
21 |
+
from .attention import (GeneralizedDotProductAttention,
|
22 |
+
InvertedDotProductAttention, SlotAttention,
|
23 |
+
TransformerBlock, Transformer)
|
24 |
+
from .convolution import (SimpleCNN, CNN)
|
25 |
+
from .decoders import (SpatialBroadcastDecoder, SiameseSpatialBroadcastDecoder)
|
26 |
+
from .initializers import (GaussianStateInit, ParamStateInit,
|
27 |
+
SegmentationEncoderStateInit,
|
28 |
+
CoordinateEncoderStateInit)
|
29 |
+
from .misc import (Dense, GRU, Identity, MLP, PositionEmbedding, Readout,
|
30 |
+
RelativePositionEmbedding)
|
31 |
+
from .video import (CorrectorPredictorTuple, FrameEncoder, Processor, SAVi)
|
32 |
+
from .resnet import (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
|
33 |
+
ResNet200)
|
34 |
+
from .invariant_attention import (InvertedDotProductAttentionKeyPerQuery,
|
35 |
+
SlotAttentionExplicitStats,
|
36 |
+
SlotAttentionPosKeysValues,
|
37 |
+
SlotAttentionTranslEquiv,
|
38 |
+
SlotAttentionTranslScaleEquiv,
|
39 |
+
SlotAttentionTranslRotScaleEquiv)
|
40 |
+
from .invariant_initializers import (
|
41 |
+
ParamStateInitRandomPositions,
|
42 |
+
ParamStateInitRandomPositionsScales,
|
43 |
+
ParamStateInitRandomPositionsRotationsScales,
|
44 |
+
ParamStateInitLearnablePositions,
|
45 |
+
ParamStateInitLearnablePositionsScales,
|
46 |
+
ParamStateInitLearnablePositionsRotationsScales)
|
47 |
+
|
48 |
+
|
49 |
+
# pylint: enable=g-multiple-import
|
invariant_slot_attention/modules/attention.py
ADDED
@@ -0,0 +1,327 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Attention module library."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
|
20 |
+
|
21 |
+
from flax import linen as nn
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
from invariant_slot_attention.modules import misc
|
25 |
+
|
26 |
+
Shape = Tuple[int]
|
27 |
+
|
28 |
+
DType = Any
|
29 |
+
Array = Any # jnp.ndarray
|
30 |
+
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
|
31 |
+
ProcessorState = ArrayTree
|
32 |
+
PRNGKey = Array
|
33 |
+
NestedDict = Dict[str, Any]
|
34 |
+
|
35 |
+
|
36 |
+
class SlotAttention(nn.Module):
|
37 |
+
"""Slot Attention module.
|
38 |
+
|
39 |
+
Note: This module uses pre-normalization by default.
|
40 |
+
"""
|
41 |
+
|
42 |
+
num_iterations: int = 1
|
43 |
+
qkv_size: Optional[int] = None
|
44 |
+
mlp_size: Optional[int] = None
|
45 |
+
epsilon: float = 1e-8
|
46 |
+
num_heads: int = 1
|
47 |
+
|
48 |
+
@nn.compact
|
49 |
+
def __call__(self, slots, inputs,
|
50 |
+
padding_mask = None,
|
51 |
+
train = False):
|
52 |
+
"""Slot Attention module forward pass."""
|
53 |
+
del padding_mask, train # Unused.
|
54 |
+
|
55 |
+
qkv_size = self.qkv_size or slots.shape[-1]
|
56 |
+
head_dim = qkv_size // self.num_heads
|
57 |
+
dense = functools.partial(nn.DenseGeneral,
|
58 |
+
axis=-1, features=(self.num_heads, head_dim),
|
59 |
+
use_bias=False)
|
60 |
+
|
61 |
+
# Shared modules.
|
62 |
+
dense_q = dense(name="general_dense_q_0")
|
63 |
+
layernorm_q = nn.LayerNorm()
|
64 |
+
inverted_attention = InvertedDotProductAttention(
|
65 |
+
norm_type="mean", multi_head=self.num_heads > 1)
|
66 |
+
gru = misc.GRU()
|
67 |
+
|
68 |
+
if self.mlp_size is not None:
|
69 |
+
mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True) # type: ignore
|
70 |
+
|
71 |
+
# inputs.shape = (..., n_inputs, inputs_size).
|
72 |
+
inputs = nn.LayerNorm()(inputs)
|
73 |
+
# k.shape = (..., n_inputs, slot_size).
|
74 |
+
k = dense(name="general_dense_k_0")(inputs)
|
75 |
+
# v.shape = (..., n_inputs, slot_size).
|
76 |
+
v = dense(name="general_dense_v_0")(inputs)
|
77 |
+
|
78 |
+
# Multiple rounds of attention.
|
79 |
+
for _ in range(self.num_iterations):
|
80 |
+
|
81 |
+
# Inverted dot-product attention.
|
82 |
+
slots_n = layernorm_q(slots)
|
83 |
+
q = dense_q(slots_n)  # q.shape = (..., n_slots, slot_size).
|
84 |
+
updates = inverted_attention(query=q, key=k, value=v)
|
85 |
+
|
86 |
+
# Recurrent update.
|
87 |
+
slots = gru(slots, updates)
|
88 |
+
|
89 |
+
# Feedforward block with pre-normalization.
|
90 |
+
if self.mlp_size is not None:
|
91 |
+
slots = mlp(slots)
|
92 |
+
|
93 |
+
return slots
|
94 |
+
|
95 |
+
|
96 |
+
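A minimal usage sketch for the SlotAttention module above, with dummy shapes (4 slots of size 64 attending over 256 flattened image tokens; all sizes are illustrative):

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules import SlotAttention

slots = jnp.zeros((1, 4, 64))     # (batch, n_slots, slot_size)
inputs = jnp.zeros((1, 256, 64))  # (batch, n_inputs, inputs_size)

model = SlotAttention(num_iterations=3, qkv_size=64, mlp_size=128)
variables = model.init(jax.random.PRNGKey(0), slots, inputs)
updated_slots = model.apply(variables, slots, inputs)  # (1, 4, 64)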
class InvertedDotProductAttention(nn.Module):
|
97 |
+
"""Inverted version of dot-product attention (softmax over query axis)."""
|
98 |
+
|
99 |
+
norm_type: Optional[str] = "mean" # mean, layernorm, or None
|
100 |
+
multi_head: bool = False
|
101 |
+
epsilon: float = 1e-8
|
102 |
+
dtype: DType = jnp.float32
|
103 |
+
precision: Optional[jax.lax.Precision] = None
|
104 |
+
return_attn_weights: bool = False
|
105 |
+
|
106 |
+
@nn.compact
|
107 |
+
def __call__(self, query, key, value,
|
108 |
+
train = False):
|
109 |
+
"""Computes inverted dot-product attention.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
query: Queries with shape of `[batch..., q_num, qk_features]`.
|
113 |
+
key: Keys with shape of `[batch..., kv_num, qk_features]`.
|
114 |
+
value: Values with shape of `[batch..., kv_num, v_features]`.
|
115 |
+
train: Indicating whether we're training or evaluating.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
Output of shape `[batch_size..., n_queries, v_features]`
|
119 |
+
"""
|
120 |
+
del train # Unused.
|
121 |
+
|
122 |
+
attn = GeneralizedDotProductAttention(
|
123 |
+
inverted_attn=True,
|
124 |
+
renormalize_keys=True if self.norm_type == "mean" else False,
|
125 |
+
epsilon=self.epsilon,
|
126 |
+
dtype=self.dtype,
|
127 |
+
precision=self.precision,
|
128 |
+
return_attn_weights=True)
|
129 |
+
|
130 |
+
# Apply attention mechanism.
|
131 |
+
output, attn = attn(query=query, key=key, value=value)
|
132 |
+
|
133 |
+
if self.multi_head:
|
134 |
+
# Multi-head aggregation. Equivalent to concat + dense layer.
|
135 |
+
output = nn.DenseGeneral(features=output.shape[-1], axis=(-2, -1))(output)
|
136 |
+
else:
|
137 |
+
# Remove head dimension.
|
138 |
+
output = jnp.squeeze(output, axis=-2)
|
139 |
+
attn = jnp.squeeze(attn, axis=-3)
|
140 |
+
|
141 |
+
if self.norm_type == "layernorm":
|
142 |
+
output = nn.LayerNorm()(output)
|
143 |
+
|
144 |
+
if self.return_attn_weights:
|
145 |
+
return output, attn
|
146 |
+
|
147 |
+
return output
|
148 |
+
|
149 |
+
|
150 |
+
class GeneralizedDotProductAttention(nn.Module):
|
151 |
+
"""Multi-head dot-product attention with customizable normalization axis.
|
152 |
+
|
153 |
+
This module supports logging of attention weights in a variable collection.
|
154 |
+
"""
|
155 |
+
|
156 |
+
dtype: DType = jnp.float32
|
157 |
+
precision: Optional[jax.lax.Precision] = None
|
158 |
+
epsilon: float = 1e-8
|
159 |
+
inverted_attn: bool = False
|
160 |
+
renormalize_keys: bool = False
|
161 |
+
attn_weights_only: bool = False
|
162 |
+
return_attn_weights: bool = False
|
163 |
+
|
164 |
+
@nn.compact
|
165 |
+
def __call__(self, query, key, value,
|
166 |
+
train = False, **kwargs
|
167 |
+
):
|
168 |
+
"""Computes multi-head dot-product attention given query, key, and value.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
query: Queries with shape of `[batch..., q_num, num_heads, qk_features]`.
|
172 |
+
key: Keys with shape of `[batch..., kv_num, num_heads, qk_features]`.
|
173 |
+
value: Values with shape of `[batch..., kv_num, num_heads, v_features]`.
|
174 |
+
train: Indicating whether we're training or evaluating.
|
175 |
+
**kwargs: Additional keyword arguments are required when used as attention
|
176 |
+
function in nn.MultiHeadDotProductAttention, but they will be ignored
|
177 |
+
here.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
Output of shape `[batch..., q_num, num_heads, v_features]`.
|
181 |
+
"""
|
182 |
+
|
183 |
+
assert query.ndim == key.ndim == value.ndim, (
|
184 |
+
"Queries, keys, and values must have the same rank.")
|
185 |
+
assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
|
186 |
+
"Query, key, and value batch dimensions must match.")
|
187 |
+
assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
|
188 |
+
"Query, key, and value num_heads dimensions must match.")
|
189 |
+
assert key.shape[-3] == value.shape[-3], (
|
190 |
+
"Key and value cardinality dimensions must match.")
|
191 |
+
assert query.shape[-1] == key.shape[-1], (
|
192 |
+
"Query and key feature dimensions must match.")
|
193 |
+
|
194 |
+
if kwargs.get("bias") is not None:
|
195 |
+
raise NotImplementedError(
|
196 |
+
"Support for masked attention is not yet implemented.")
|
197 |
+
|
198 |
+
if "dropout_rate" in kwargs:
|
199 |
+
if kwargs["dropout_rate"] > 0.:
|
200 |
+
raise NotImplementedError("Support for dropout is not yet implemented.")
|
201 |
+
|
202 |
+
# Temperature normalization.
|
203 |
+
qk_features = query.shape[-1]
|
204 |
+
query = query / jnp.sqrt(qk_features).astype(self.dtype)
|
205 |
+
|
206 |
+
# attn.shape = (batch..., num_heads, q_num, kv_num)
|
207 |
+
attn = jnp.einsum("...qhd,...khd->...hqk", query, key,
|
208 |
+
precision=self.precision)
|
209 |
+
|
210 |
+
if self.inverted_attn:
|
211 |
+
attention_axis = -2 # Query axis.
|
212 |
+
else:
|
213 |
+
attention_axis = -1 # Key axis.
|
214 |
+
|
215 |
+
# Softmax normalization (by default over key axis).
|
216 |
+
attn = jax.nn.softmax(attn, axis=attention_axis).astype(self.dtype)
|
217 |
+
|
218 |
+
# Defines intermediate for logging.
|
219 |
+
if not train:
|
220 |
+
self.sow("intermediates", "attn", attn)
|
221 |
+
|
222 |
+
if self.renormalize_keys:
|
223 |
+
# Corresponds to value aggregation via weighted mean (as opposed to sum).
|
224 |
+
normalizer = jnp.sum(attn, axis=-1, keepdims=True) + self.epsilon
|
225 |
+
attn = attn / normalizer
|
226 |
+
|
227 |
+
if self.attn_weights_only:
|
228 |
+
return attn
|
229 |
+
|
230 |
+
# Aggregate values using a weighted sum with weights provided by `attn`.
|
231 |
+
output = jnp.einsum(
|
232 |
+
"...hqk,...khd->...qhd", attn, value, precision=self.precision)
|
233 |
+
|
234 |
+
if self.return_attn_weights:
|
235 |
+
return output, attn
|
236 |
+
|
237 |
+
return output
|
238 |
+
|
239 |
+
|
240 |
+
class Transformer(nn.Module):
|
241 |
+
"""Transformer with multiple blocks."""
|
242 |
+
|
243 |
+
num_heads: int
|
244 |
+
qkv_size: int
|
245 |
+
mlp_size: int
|
246 |
+
num_layers: int
|
247 |
+
pre_norm: bool = False
|
248 |
+
|
249 |
+
@nn.compact
|
250 |
+
def __call__(self, queries, inputs = None,
|
251 |
+
padding_mask = None,
|
252 |
+
train = False):
|
253 |
+
x = queries
|
254 |
+
for lyr in range(self.num_layers):
|
255 |
+
x = TransformerBlock(
|
256 |
+
num_heads=self.num_heads, qkv_size=self.qkv_size,
|
257 |
+
mlp_size=self.mlp_size, pre_norm=self.pre_norm,
|
258 |
+
name=f"TransformerBlock{lyr}")( # pytype: disable=wrong-arg-types
|
259 |
+
x, inputs, padding_mask, train)
|
260 |
+
return x
|
261 |
+
|
262 |
+
|
263 |
+
class TransformerBlock(nn.Module):
|
264 |
+
"""Transformer decoder block."""
|
265 |
+
|
266 |
+
num_heads: int
|
267 |
+
qkv_size: int
|
268 |
+
mlp_size: int
|
269 |
+
pre_norm: bool = False
|
270 |
+
|
271 |
+
@nn.compact
|
272 |
+
def __call__(self, queries, inputs = None,
|
273 |
+
padding_mask = None,
|
274 |
+
train = False):
|
275 |
+
del padding_mask # Unused.
|
276 |
+
assert queries.ndim == 3
|
277 |
+
|
278 |
+
attention_fn = GeneralizedDotProductAttention()
|
279 |
+
|
280 |
+
attn = functools.partial(
|
281 |
+
nn.MultiHeadDotProductAttention,
|
282 |
+
num_heads=self.num_heads,
|
283 |
+
qkv_features=self.qkv_size,
|
284 |
+
attention_fn=attention_fn)
|
285 |
+
|
286 |
+
mlp = misc.MLP(hidden_size=self.mlp_size) # type: ignore
|
287 |
+
|
288 |
+
if self.pre_norm:
|
289 |
+
# Self-attention on queries.
|
290 |
+
x = nn.LayerNorm()(queries)
|
291 |
+
x = attn()(inputs_q=x, inputs_kv=x, deterministic=not train)
|
292 |
+
x = x + queries
|
293 |
+
|
294 |
+
# Cross-attention on inputs.
|
295 |
+
if inputs is not None:
|
296 |
+
assert inputs.ndim == 3
|
297 |
+
y = nn.LayerNorm()(x)
|
298 |
+
y = attn()(inputs_q=y, inputs_kv=inputs, deterministic=not train)
|
299 |
+
y = y + x
|
300 |
+
else:
|
301 |
+
y = x
|
302 |
+
|
303 |
+
# MLP
|
304 |
+
z = nn.LayerNorm()(y)
|
305 |
+
z = mlp(z, train)
|
306 |
+
z = z + y
|
307 |
+
else:
|
308 |
+
# Self-attention on queries.
|
309 |
+
x = queries
|
310 |
+
x = attn()(inputs_q=x, inputs_kv=x, deterministic=not train)
|
311 |
+
x = x + queries
|
312 |
+
x = nn.LayerNorm()(x)
|
313 |
+
|
314 |
+
# Cross-attention on inputs.
|
315 |
+
if inputs is not None:
|
316 |
+
assert inputs.ndim == 3
|
317 |
+
y = attn()(inputs_q=x, inputs_kv=inputs, deterministic=not train)
|
318 |
+
y = y + x
|
319 |
+
y = nn.LayerNorm()(y)
|
320 |
+
else:
|
321 |
+
y = x
|
322 |
+
|
323 |
+
# MLP.
|
324 |
+
z = mlp(y, train)
|
325 |
+
z = z + y
|
326 |
+
z = nn.LayerNorm()(z)
|
327 |
+
return z
|
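For completeness, the Transformer wrapper above can be exercised in the same way; a short sketch with made-up sizes:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules import Transformer

queries = jnp.zeros((1, 4, 64))    # e.g. slot embeddings
inputs = jnp.zeros((1, 256, 64))   # e.g. flattened frame features

model = Transformer(num_heads=4, qkv_size=64, mlp_size=128, num_layers=2)
variables = model.init(jax.random.PRNGKey(0), queries, inputs)
out = model.apply(variables, queries, inputs)  # (1, 4, 64), cross-attending over inputs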
invariant_slot_attention/modules/convolution.py
ADDED
@@ -0,0 +1,164 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Convolutional module library."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
|
20 |
+
|
21 |
+
from flax import linen as nn
|
22 |
+
import jax
|
23 |
+
|
24 |
+
Shape = Tuple[int]
|
25 |
+
|
26 |
+
DType = Any
|
27 |
+
Array = Any # jnp.ndarray
|
28 |
+
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
|
29 |
+
ProcessorState = ArrayTree
|
30 |
+
PRNGKey = Array
|
31 |
+
NestedDict = Dict[str, Any]
|
32 |
+
|
33 |
+
|
34 |
+
class SimpleCNN(nn.Module):
|
35 |
+
"""Simple CNN encoder with multiple Conv+ReLU layers."""
|
36 |
+
|
37 |
+
features: Sequence[int]
|
38 |
+
kernel_size: Sequence[Tuple[int, int]]
|
39 |
+
strides: Sequence[Tuple[int, int]]
|
40 |
+
transpose: bool = False
|
41 |
+
use_batch_norm: bool = False
|
42 |
+
axis_name: Optional[str] = None # Over which axis to aggregate batch stats.
|
43 |
+
padding: Union[str, Iterable[Tuple[int, int]]] = "SAME"
|
44 |
+
resize_output: Optional[Iterable[int]] = None
|
45 |
+
|
46 |
+
@nn.compact
|
47 |
+
def __call__(self, inputs, train = False):
|
48 |
+
num_layers = len(self.features)
|
49 |
+
assert len(self.kernel_size) == num_layers, (
|
50 |
+
"len(kernel_size) and len(features) must match.")
|
51 |
+
assert len(self.strides) == num_layers, (
|
52 |
+
"len(strides) and len(features) must match.")
|
53 |
+
assert num_layers >= 1, "Need to have at least one layer."
|
54 |
+
|
55 |
+
if self.transpose:
|
56 |
+
conv_module = nn.ConvTranspose
|
57 |
+
else:
|
58 |
+
conv_module = nn.Conv
|
59 |
+
|
60 |
+
x = conv_module(
|
61 |
+
name="conv_simple_0",
|
62 |
+
features=self.features[0],
|
63 |
+
kernel_size=self.kernel_size[0],
|
64 |
+
strides=self.strides[0],
|
65 |
+
use_bias=False if self.use_batch_norm else True,
|
66 |
+
padding=self.padding)(inputs)
|
67 |
+
|
68 |
+
for i in range(1, num_layers):
|
69 |
+
if self.use_batch_norm:
|
70 |
+
x = nn.BatchNorm(
|
71 |
+
momentum=0.9, use_running_average=not train,
|
72 |
+
axis_name=self.axis_name, name=f"bn_simple_{i-1}")(x)
|
73 |
+
|
74 |
+
x = nn.relu(x)
|
75 |
+
x = conv_module(
|
76 |
+
name=f"conv_simple_{i}",
|
77 |
+
features=self.features[i],
|
78 |
+
kernel_size=self.kernel_size[i],
|
79 |
+
strides=self.strides[i],
|
80 |
+
use_bias=False if (
|
81 |
+
self.use_batch_norm and i < (num_layers-1)) else True,
|
82 |
+
padding=self.padding)(x)
|
83 |
+
|
84 |
+
if self.resize_output:
|
85 |
+
x = jax.image.resize(
|
86 |
+
x, list(x.shape[:-3]) + list(self.resize_output) + [x.shape[-1]],
|
87 |
+
method=jax.image.ResizeMethod.LINEAR)
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class CNN(nn.Module):
|
92 |
+
"""Flexible CNN model with Conv/Normalization/Pooling layers."""
|
93 |
+
|
94 |
+
features: Sequence[int]
|
95 |
+
kernel_size: Sequence[Tuple[int, int]]
|
96 |
+
strides: Sequence[Tuple[int, int]]
|
97 |
+
max_pool_strides: Sequence[Tuple[int, int]]
|
98 |
+
layer_transpose: Sequence[bool]
|
99 |
+
activation_fn: Callable[[Array], Array] = nn.relu
|
100 |
+
norm_type: Optional[str] = None
|
101 |
+
axis_name: Optional[str] = None # Over which axis to aggregate batch stats.
|
102 |
+
output_size: Optional[int] = None
|
103 |
+
|
104 |
+
@nn.compact
|
105 |
+
def __call__(self, inputs, train = False):
|
106 |
+
num_layers = len(self.features)
|
107 |
+
|
108 |
+
assert num_layers >= 1, "Need to have at least one layer."
|
109 |
+
assert len(self.kernel_size) == num_layers, (
|
110 |
+
"len(kernel_size) and len(features) must match.")
|
111 |
+
assert len(self.strides) == num_layers, (
|
112 |
+
"len(strides) and len(features) must match.")
|
113 |
+
assert len(self.max_pool_strides) == num_layers, (
|
114 |
+
"len(max_pool_strides) and len(features) must match.")
|
115 |
+
assert len(self.layer_transpose) == num_layers, (
|
116 |
+
"len(layer_transpose) and len(features) must match.")
|
117 |
+
|
118 |
+
if self.norm_type:
|
119 |
+
assert self.norm_type in {"batch", "group", "instance", "layer"}, (
|
120 |
+
f"{self.norm_type} is unrecognizaed normalization")
|
121 |
+
|
122 |
+
# Whether transpose conv or regular conv.
|
123 |
+
conv_module = {False: nn.Conv, True: nn.ConvTranspose}
|
124 |
+
|
125 |
+
if self.norm_type == "batch":
|
126 |
+
norm_module = functools.partial(
|
127 |
+
nn.BatchNorm, momentum=0.9, use_running_average=not train,
|
128 |
+
axis_name=self.axis_name)
|
129 |
+
elif self.norm_type == "group":
|
130 |
+
norm_module = functools.partial(
|
131 |
+
nn.GroupNorm, num_groups=32)
|
132 |
+
elif self.norm_type == "layer":
|
133 |
+
norm_module = nn.LayerNorm
|
134 |
+
|
135 |
+
x = inputs
|
136 |
+
for i in range(num_layers):
|
137 |
+
x = conv_module[self.layer_transpose[i]](
|
138 |
+
name=f"conv_{i}",
|
139 |
+
features=self.features[i],
|
140 |
+
kernel_size=self.kernel_size[i],
|
141 |
+
strides=self.strides[i],
|
142 |
+
use_bias=False if self.norm_type else True)(x)
|
143 |
+
|
144 |
+
# Normalization layer.
|
145 |
+
if self.norm_type:
|
146 |
+
if self.norm_type == "instance":
|
147 |
+
x = nn.GroupNorm(
|
148 |
+
num_groups=self.features[i],
|
149 |
+
name=f"{self.norm_type}_norm_{i}")(x)
|
150 |
+
else:
|
151 |
+
norm_module(name=f"{self.norm_type}_norm_{i}")(x)
|
152 |
+
|
153 |
+
# Activation layer.
|
154 |
+
x = self.activation_fn(x)
|
155 |
+
|
156 |
+
# Max pooling layer.
|
157 |
+
x = x if self.max_pool_strides[i] == (1, 1) else nn.max_pool(
|
158 |
+
x, self.max_pool_strides[i], strides=self.max_pool_strides[i],
|
159 |
+
padding="SAME")
|
160 |
+
|
161 |
+
# Final dense layer.
|
162 |
+
if self.output_size:
|
163 |
+
x = nn.Dense(self.output_size, name="output_layer", use_bias=True)(x)
|
164 |
+
return x
|
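A usage sketch for the CNN encoder with batch normalization. The layer sizes are illustrative; the point is that with norm_type="batch" the running statistics live in the "batch_stats" collection and must be marked mutable during training:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules import CNN

model = CNN(
    features=[32, 32, 64, 64],
    kernel_size=[(5, 5)] * 4,
    strides=[(1, 1), (1, 1), (2, 2), (2, 2)],
    max_pool_strides=[(1, 1)] * 4,
    layer_transpose=[False] * 4,
    norm_type="batch",
    output_size=64)

x = jnp.zeros((8, 128, 128, 3))
variables = model.init(jax.random.PRNGKey(0), x, train=False)
y, new_state = model.apply(variables, x, train=True, mutable=["batch_stats"])
# new_state["batch_stats"] holds the updated running mean/variance.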
invariant_slot_attention/modules/decoders.py
ADDED
@@ -0,0 +1,267 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Decoder module library."""
|
17 |
+
import functools
|
18 |
+
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
|
19 |
+
|
20 |
+
from flax import linen as nn
|
21 |
+
|
22 |
+
import jax.numpy as jnp
|
23 |
+
|
24 |
+
from invariant_slot_attention.lib import utils
|
25 |
+
from invariant_slot_attention.modules import misc
|
26 |
+
|
27 |
+
Shape = Tuple[int]
|
28 |
+
|
29 |
+
DType = Any
|
30 |
+
Array = Any # jnp.ndarray
|
31 |
+
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
|
32 |
+
ProcessorState = ArrayTree
|
33 |
+
PRNGKey = Array
|
34 |
+
NestedDict = Dict[str, Any]
|
35 |
+
|
36 |
+
|
37 |
+
class SpatialBroadcastDecoder(nn.Module):
|
38 |
+
"""Spatial broadcast decoder for a set of slots (per frame)."""
|
39 |
+
|
40 |
+
resolution: Sequence[int]
|
41 |
+
backbone: Callable[[], nn.Module]
|
42 |
+
pos_emb: Callable[[], nn.Module]
|
43 |
+
early_fusion: bool = False # Fuse slot features before constructing targets.
|
44 |
+
target_readout: Optional[Callable[[], nn.Module]] = None
|
45 |
+
|
46 |
+
# Vmapped application of module, consumes time axis (axis=1).
|
47 |
+
@functools.partial(utils.time_distributed, in_axes=(1, None))
|
48 |
+
@nn.compact
|
49 |
+
def __call__(self, slots, train = False):
|
50 |
+
|
51 |
+
batch_size, n_slots, n_features = slots.shape
|
52 |
+
|
53 |
+
# Fold slot dim into batch dim.
|
54 |
+
x = jnp.reshape(slots, (batch_size * n_slots, n_features))
|
55 |
+
|
56 |
+
# Spatial broadcast with position embedding.
|
57 |
+
x = utils.spatial_broadcast(x, self.resolution)
|
58 |
+
x = self.pos_emb()(x)
|
59 |
+
|
60 |
+
# bb_features.shape = (batch_size * n_slots, h, w, c)
|
61 |
+
bb_features = self.backbone()(x, train=train)
|
62 |
+
spatial_dims = bb_features.shape[-3:-1]
|
63 |
+
|
64 |
+
alpha_logits = nn.Dense(
|
65 |
+
features=1, use_bias=True, name="alpha_logits")(bb_features)
|
66 |
+
alpha_logits = jnp.reshape(
|
67 |
+
alpha_logits, (batch_size, n_slots) + spatial_dims + (-1,))
|
68 |
+
|
69 |
+
alphas = nn.softmax(alpha_logits, axis=1)
|
70 |
+
if not train:
|
71 |
+
# Define intermediates for logging / visualization.
|
72 |
+
self.sow("intermediates", "alphas", alphas)
|
73 |
+
|
74 |
+
if self.early_fusion:
|
75 |
+
# To save memory, fuse the slot features before predicting targets.
|
76 |
+
# The final target output should be equivalent to the late fusion when
|
77 |
+
# using linear prediction.
|
78 |
+
bb_features = jnp.reshape(
|
79 |
+
bb_features, (batch_size, n_slots) + spatial_dims + (-1,))
|
80 |
+
# Combine backbone features by alpha masks.
|
81 |
+
bb_features = jnp.sum(bb_features * alphas, axis=1)
|
82 |
+
|
83 |
+
targets_dict = self.target_readout()(bb_features, train) # pylint: disable=not-callable
|
84 |
+
|
85 |
+
preds_dict = dict()
|
86 |
+
for target_key, channels in targets_dict.items():
|
87 |
+
if self.early_fusion:
|
88 |
+
# decoded_target.shape = (batch_size, h, w, c) after next line.
|
89 |
+
decoded_target = channels
|
90 |
+
else:
|
91 |
+
# channels.shape = (batch_size, n_slots, h, w, c)
|
92 |
+
channels = jnp.reshape(
|
93 |
+
channels, (batch_size, n_slots) + (spatial_dims) + (-1,))
|
94 |
+
|
95 |
+
# masked_channels.shape = (batch_size, n_slots, h, w, c)
|
96 |
+
masked_channels = channels * alphas
|
97 |
+
|
98 |
+
# decoded_target.shape = (batch_size, h, w, c)
|
99 |
+
decoded_target = jnp.sum(masked_channels, axis=1) # Combine target.
|
100 |
+
preds_dict[target_key] = decoded_target
|
101 |
+
|
102 |
+
if not train:
|
103 |
+
# Define intermediates for logging / visualization.
|
104 |
+
self.sow("intermediates", f"{target_key}_slots", channels)
|
105 |
+
if not self.early_fusion:
|
106 |
+
self.sow("intermediates", f"{target_key}_masked", masked_channels)
|
107 |
+
self.sow("intermediates", f"{target_key}_combined", decoded_target)
|
108 |
+
|
109 |
+
preds_dict["segmentations"] = jnp.argmax(alpha_logits, axis=1)
|
110 |
+
|
111 |
+
return preds_dict
|
112 |
+
|
113 |
+
|
114 |
+
class SiameseSpatialBroadcastDecoder(nn.Module):
|
115 |
+
"""Siamese spatial broadcast decoder for a set of slots (per frame).
|
116 |
+
|
117 |
+
Similar to the decoders used in IODINE: https://arxiv.org/abs/1903.00450
|
118 |
+
and in Slot Attention: https://arxiv.org/abs/2006.15055.
|
119 |
+
"""
|
120 |
+
|
121 |
+
resolution: Sequence[int]
|
122 |
+
backbone: Callable[[], nn.Module]
|
123 |
+
pos_emb: Callable[[], nn.Module]
|
124 |
+
pass_intermediates: bool = False
|
125 |
+
alpha_only: bool = False # Predict only alpha masks.
|
126 |
+
concat_attn: bool = False
|
127 |
+
# Readout after backbone.
|
128 |
+
target_readout_from_slots: bool = False
|
129 |
+
target_readout: Optional[Callable[[], nn.Module]] = None
|
130 |
+
early_fusion: bool = False # Fuse slot features before constructing targets.
|
131 |
+
# Readout on slots.
|
132 |
+
attribute_readout: Optional[Callable[[], nn.Module]] = None
|
133 |
+
remove_background_attribute: bool = False
|
134 |
+
attn_key: Optional[str] = None
|
135 |
+
attn_width: Optional[int] = None
|
136 |
+
# If True, expects slot embeddings to contain slot positions.
|
137 |
+
relative_positions: bool = False
|
138 |
+
# Slot positions and scales.
|
139 |
+
relative_positions_and_scales: bool = False
|
140 |
+
relative_positions_rotations_and_scales: bool = False
|
141 |
+
|
142 |
+
# Vmapped application of module, consumes time axis (axis=1).
|
143 |
+
@functools.partial(utils.time_distributed, in_axes=(1, None))
|
144 |
+
@nn.compact
|
145 |
+
def __call__(self,
|
146 |
+
slots,
|
147 |
+
train = False):
|
148 |
+
|
149 |
+
if self.remove_background_attribute and self.attribute_readout is None:
|
150 |
+
raise NotImplementedError(
|
151 |
+
"Background removal is only supported for attribute readout.")
|
152 |
+
|
153 |
+
if self.relative_positions:
|
154 |
+
# Assume slot positions were concatenated to slot embeddings.
|
155 |
+
# E.g. an output of SlotAttentionTranslEquiv.
|
156 |
+
slots, positions = slots[Ellipsis, :-2], slots[Ellipsis, -2:]
|
157 |
+
# Reshape positions to [B * num_slots, 2]
|
158 |
+
positions = positions.reshape(
|
159 |
+
(positions.shape[0] * positions.shape[1], positions.shape[2]))
|
160 |
+
elif self.relative_positions_and_scales:
|
161 |
+
# Assume slot positions and scales were concatenated to slot embeddings.
|
162 |
+
# E.g. an output of SlotAttentionTranslScaleEquiv.
|
163 |
+
slots, positions, scales = (slots[Ellipsis, :-4],
|
164 |
+
slots[Ellipsis, -4: -2],
|
165 |
+
slots[Ellipsis, -2:])
|
166 |
+
positions = positions.reshape(
|
167 |
+
(positions.shape[0] * positions.shape[1], positions.shape[2]))
|
168 |
+
scales = scales.reshape(
|
169 |
+
(scales.shape[0] * scales.shape[1], scales.shape[2]))
|
170 |
+
elif self.relative_positions_rotations_and_scales:
|
171 |
+
slots, positions, scales, rotm = (slots[Ellipsis, :-8],
|
172 |
+
slots[Ellipsis, -8: -6],
|
173 |
+
slots[Ellipsis, -6: -4],
|
174 |
+
slots[Ellipsis, -4:])
|
175 |
+
positions = positions.reshape(
|
176 |
+
(positions.shape[0] * positions.shape[1], positions.shape[2]))
|
177 |
+
scales = scales.reshape(
|
178 |
+
(scales.shape[0] * scales.shape[1], scales.shape[2]))
|
179 |
+
rotm = rotm.reshape(
|
180 |
+
rotm.shape[0] * rotm.shape[1], 2, 2)
|
181 |
+
|
182 |
+
batch_size, n_slots, n_features = slots.shape
|
183 |
+
|
184 |
+
preds_dict = {}
|
185 |
+
# Fold slot dim into batch dim.
|
186 |
+
x = jnp.reshape(slots, (batch_size * n_slots, n_features))
|
187 |
+
|
188 |
+
# Attribute readout.
|
189 |
+
if self.attribute_readout is not None:
|
190 |
+
if self.remove_background_attribute:
|
191 |
+
slots = slots[:, 1:]
|
192 |
+
attributes_dict = self.attribute_readout()(slots, train) # pylint: disable=not-callable
|
193 |
+
preds_dict.update(attributes_dict)
|
194 |
+
|
195 |
+
# Spatial broadcast with position embedding.
|
196 |
+
# See https://arxiv.org/abs/1901.07017.
|
197 |
+
x = utils.spatial_broadcast(x, self.resolution)
|
198 |
+
|
199 |
+
if self.relative_positions:
|
200 |
+
x = self.pos_emb()(inputs=x, slot_positions=positions)
|
201 |
+
elif self.relative_positions_and_scales:
|
202 |
+
x = self.pos_emb()(inputs=x, slot_positions=positions, slot_scales=scales)
|
203 |
+
elif self.relative_positions_rotations_and_scales:
|
204 |
+
x = self.pos_emb()(
|
205 |
+
inputs=x, slot_positions=positions, slot_scales=scales,
|
206 |
+
slot_rotm=rotm)
|
207 |
+
else:
|
208 |
+
x = self.pos_emb()(x)
|
209 |
+
|
210 |
+
# bb_features.shape = (batch_size*n_slots, h, w, c)
|
211 |
+
bb_features = self.backbone()(x, train=train)
|
212 |
+
spatial_dims = bb_features.shape[-3:-1]
|
213 |
+
alphas = nn.Dense(features=1, use_bias=True, name="alphas")(bb_features)
|
214 |
+
alphas = jnp.reshape(
|
215 |
+
alphas, (batch_size, n_slots) + spatial_dims + (-1,))
|
216 |
+
alphas_softmaxed = nn.softmax(alphas, axis=1)
|
217 |
+
preds_dict["segmentation_logits"] = alphas
|
218 |
+
preds_dict["segmentations"] = jnp.argmax(alphas, axis=1)
|
219 |
+
# Define intermediates for logging.
|
220 |
+
_ = misc.Identity(name="alphas_softmaxed")(alphas_softmaxed)
|
221 |
+
if self.alpha_only or self.target_readout is None:
|
222 |
+
assert alphas.shape[-1] == 1, "Alpha masks need to be one-dimensional."
|
223 |
+
return preds_dict, {"segmentation_logits": alphas}
|
224 |
+
|
225 |
+
if self.early_fusion:
|
226 |
+
# To save memory, fuse the slot features before predicting targets.
|
227 |
+
# The final target output should be equivalent to the late fusion when
|
228 |
+
# using linear prediction.
|
229 |
+
bb_features = jnp.reshape(
|
230 |
+
bb_features, (batch_size, n_slots) + spatial_dims + (-1,))
|
231 |
+
# Combine backbone features by alpha masks.
|
232 |
+
bb_features = jnp.sum(bb_features * alphas_softmaxed, axis=1)
|
233 |
+
|
234 |
+
if self.target_readout_from_slots:
|
235 |
+
targets_dict = self.target_readout()(slots, train) # pylint: disable=not-callable
|
236 |
+
else:
|
237 |
+
targets_dict = self.target_readout()(bb_features, train) # pylint: disable=not-callable
|
238 |
+
|
239 |
+
targets_dict_new = dict()
|
240 |
+
targets_dict_new["targets_masks"] = alphas_softmaxed
|
241 |
+
targets_dict_new["targets_logits_masks"] = alphas
|
242 |
+
|
243 |
+
for target_key, channels in targets_dict.items():
|
244 |
+
if self.early_fusion:
|
245 |
+
# decoded_target.shape = (batch_size, h, w, c) after next line.
|
246 |
+
decoded_target = channels
|
247 |
+
else:
|
248 |
+
# channels.shape = (batch_size, n_slots, h, w, c) after next line.
|
249 |
+
channels = jnp.reshape(
|
250 |
+
channels, (batch_size, n_slots) +
|
251 |
+
(spatial_dims if not self.target_readout_from_slots else
|
252 |
+
(1, 1)) + (-1,))
|
253 |
+
# masked_channels.shape = (batch_size, n_slots, h, w, c) at next line.
|
254 |
+
masked_channels = channels * alphas_softmaxed
|
255 |
+
# decoded_target.shape = (batch_size, h, w, c) after next line.
|
256 |
+
decoded_target = jnp.sum(masked_channels, axis=1) # Combine target.
|
257 |
+
targets_dict_new[target_key + "_channels"] = channels
|
258 |
+
# Define intermediates for logging.
|
259 |
+
_ = misc.Identity(name=f"{target_key}_channels")(channels)
|
260 |
+
_ = misc.Identity(name=f"{target_key}_masked_channels")(masked_channels)
|
261 |
+
|
262 |
+
targets_dict_new[target_key] = decoded_target
|
263 |
+
# Define intermediates for logging.
|
264 |
+
_ = misc.Identity(name=f"decoded_{target_key}")(decoded_target)
|
265 |
+
|
266 |
+
preds_dict.update(targets_dict_new)
|
267 |
+
return preds_dict
|
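Both decoders share the same compositing step once the backbone has produced per-slot channels and alpha logits: a softmax over the slot axis followed by a weighted sum. A standalone sketch of just that step, with illustrative shapes:

import jax.numpy as jnp
from flax import linen as nn

batch_size, n_slots, h, w, c = 2, 5, 16, 16, 3
channels = jnp.zeros((batch_size, n_slots, h, w, c))      # per-slot RGB
alpha_logits = jnp.zeros((batch_size, n_slots, h, w, 1))  # per-slot mask logits

alphas = nn.softmax(alpha_logits, axis=1)         # normalize across slots
decoded = jnp.sum(channels * alphas, axis=1)      # (batch_size, h, w, c)
segmentations = jnp.argmax(alpha_logits, axis=1)  # (batch_size, h, w, 1)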
invariant_slot_attention/modules/initializers.py
ADDED
@@ -0,0 +1,173 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Initializers module library."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
|
20 |
+
|
21 |
+
from flax import linen as nn
|
22 |
+
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
|
26 |
+
from invariant_slot_attention.lib import utils
|
27 |
+
from invariant_slot_attention.modules import misc
|
28 |
+
from invariant_slot_attention.modules import video
|
29 |
+
|
30 |
+
Shape = Tuple[int]
|
31 |
+
|
32 |
+
DType = Any
|
33 |
+
Array = Any # jnp.ndarray
|
34 |
+
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
|
35 |
+
ProcessorState = ArrayTree
|
36 |
+
PRNGKey = Array
|
37 |
+
NestedDict = Dict[str, Any]
|
38 |
+
|
39 |
+
|
40 |
+
class ParamStateInit(nn.Module):
|
41 |
+
"""Fixed, learnable state initalization.
|
42 |
+
|
43 |
+
Note: This module ignores any conditional input (by design).
|
44 |
+
"""
|
45 |
+
|
46 |
+
shape: Sequence[int]
|
47 |
+
init_fn: str = "normal" # Default init with unit variance.
|
48 |
+
|
49 |
+
@nn.compact
|
50 |
+
def __call__(self, inputs, batch_size,
|
51 |
+
train = False):
|
52 |
+
del inputs, train # Unused.
|
53 |
+
|
54 |
+
if self.init_fn == "normal":
|
55 |
+
init_fn = functools.partial(nn.initializers.normal, stddev=1.)
|
56 |
+
elif self.init_fn == "zeros":
|
57 |
+
init_fn = lambda: nn.initializers.zeros
|
58 |
+
else:
|
59 |
+
raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
|
60 |
+
|
61 |
+
param = self.param("state_init", init_fn(), self.shape)
|
62 |
+
return utils.broadcast_across_batch(param, batch_size=batch_size)
|
63 |
+
|
64 |
+
|
65 |
+
class GaussianStateInit(nn.Module):
|
66 |
+
"""Random state initialization with zero-mean, unit-variance Gaussian.
|
67 |
+
|
68 |
+
Note: This module does not contain any trainable parameters and requires
|
69 |
+
providing a jax.PRNGKey both at training and at test time. Note: This module
|
70 |
+
also ignores any conditional input (by design).
|
71 |
+
"""
|
72 |
+
|
73 |
+
shape: Sequence[int]
|
74 |
+
|
75 |
+
@nn.compact
|
76 |
+
def __call__(self, inputs, batch_size,
|
77 |
+
train = False):
|
78 |
+
del inputs, train # Unused.
|
79 |
+
rng = self.make_rng("state_init")
|
80 |
+
return jax.random.normal(rng, shape=[batch_size] + list(self.shape))
|
81 |
+
|
82 |
+
|
83 |
+
class SegmentationEncoderStateInit(nn.Module):
|
84 |
+
"""State init that encodes segmentation masks as conditional input."""
|
85 |
+
|
86 |
+
max_num_slots: int
|
87 |
+
backbone: Callable[[], nn.Module]
|
88 |
+
pos_emb: Callable[[], nn.Module] = misc.Identity
|
89 |
+
  reduction: Optional[str] = "all_flatten"  # Reduce spatial dim by default.
  output_transform: Callable[[], nn.Module] = misc.Identity
  zero_background: bool = False

  @nn.compact
  def __call__(self, inputs, batch_size,
               train = False):
    del batch_size  # Unused.

    # inputs.shape = (batch_size, seq_len, height, width)
    inputs = inputs[:, 0]  # Only condition on first time step.

    # Convert mask index to one-hot.
    inputs_oh = jax.nn.one_hot(inputs, self.max_num_slots)
    # inputs_oh.shape = (batch_size, height, width, n_slots)
    # NOTE: 0th entry inputs_oh[..., 0] will typically correspond to background.

    # Set background slot to all-zeros.
    if self.zero_background:
      inputs_oh = inputs_oh.at[:, :, :, 0].set(0)

    # Switch one-hot axis into 1st position (i.e. sequence axis).
    inputs_oh = jnp.transpose(inputs_oh, (0, 3, 1, 2))
    # inputs_oh.shape = (batch_size, max_num_slots, height, width)

    # Append dummy feature axis.
    inputs_oh = jnp.expand_dims(inputs_oh, axis=-1)

    # Vmapped encoder over seq. axis (i.e. we process each slot independently).
    encoder = video.FrameEncoder(
        backbone=self.backbone,
        pos_emb=self.pos_emb,
        reduction=self.reduction,
        output_transform=self.output_transform)  # type: ignore

    # encoder(inputs_oh).shape = (batch_size, n_slots, n_features)
    slots = encoder(inputs_oh, None, train)

    return slots


class CoordinateEncoderStateInit(nn.Module):
  """State init that encodes bounding box coordinates as conditional input.

  Attributes:
    embedding_transform: A nn.Module that is applied on inputs (bounding boxes).
    prepend_background: Boolean flag; whether to prepend a special, zero-valued
      background bounding box to the input. Default: false.
    center_of_mass: Boolean flag; whether to convert bounding boxes to center
      of mass coordinates. Default: false.
    background_value: Default value to fill in the background.
  """

  embedding_transform: Callable[[], nn.Module]
  prepend_background: bool = False
  center_of_mass: bool = False
  background_value: float = 0.

  @nn.compact
  def __call__(self, inputs, batch_size,
               train = False):
    del batch_size  # Unused.

    # inputs.shape = (batch_size, seq_len, bboxes, 4)
    inputs = inputs[:, 0]  # Only condition on first time step.
    # inputs.shape = (batch_size, bboxes, 4)

    if self.prepend_background:
      # Adds a fake background box [0, 0, 0, 0] at the beginning.
      batch_size = inputs.shape[0]

      # Encode the background as specified by background_value.
      background = jnp.full(
          (batch_size, 1, 4), self.background_value, dtype=inputs.dtype)

      inputs = jnp.concatenate((background, inputs), axis=1)

    if self.center_of_mass:
      y_pos = (inputs[:, :, 0] + inputs[:, :, 2]) / 2
      x_pos = (inputs[:, :, 1] + inputs[:, :, 3]) / 2
      inputs = jnp.stack((y_pos, x_pos), axis=-1)

    slots = self.embedding_transform()(inputs, train=train)  # pytype: disable=not-callable

    return slots

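A minimal sketch (not part of the commit) of the center-of-mass conversion used by CoordinateEncoderStateInit, with made-up box coordinates in (y_min, x_min, y_max, x_max) order:

import jax.numpy as jnp

boxes = jnp.array([[[0.0, 0.0, 0.5, 1.0]]])            # (batch, bboxes, 4), hypothetical values
y_pos = (boxes[:, :, 0] + boxes[:, :, 2]) / 2           # -> 0.25
x_pos = (boxes[:, :, 1] + boxes[:, :, 3]) / 2           # -> 0.5
centers = jnp.stack((y_pos, x_pos), axis=-1)            # (batch, bboxes, 2)
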
invariant_slot_attention/modules/invariant_attention.py
ADDED
@@ -0,0 +1,963 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Equivariant attention module library."""
import functools
from typing import Any, Optional, Tuple

from flax import linen as nn
import jax
import jax.numpy as jnp
from invariant_slot_attention.modules import attention
from invariant_slot_attention.modules import misc

Shape = Tuple[int]

DType = Any
Array = Any  # jnp.ndarray
PRNGKey = Array


class InvertedDotProductAttentionKeyPerQuery(nn.Module):
  """Inverted dot-product attention with a different set of keys per query.

  Used in SlotAttentionTranslEquiv, where each slot has a position.
  The positions are used to create relative coordinate grids,
  which result in a different set of inputs (keys) for each slot.
  """

  dtype: DType = jnp.float32
  precision: Optional[jax.lax.Precision] = None
  epsilon: float = 1e-8
  renormalize_keys: bool = False
  attn_weights_only: bool = False
  softmax_temperature: float = 1.0
  value_per_query: bool = False

  @nn.compact
  def __call__(self, query, key, value, train):
    """Computes inverted dot-product attention with key per query.

    Args:
      query: Queries with shape of `[batch..., q_num, qk_features]`.
      key: Keys with shape of `[batch..., q_num, kv_num, qk_features]`.
      value: Values with shape of `[batch..., kv_num, v_features]`.
      train: Indicating whether we're training or evaluating.

    Returns:
      Tuple of two elements: (1) output of shape
      `[batch_size..., q_num, v_features]` and (2) attention mask of shape
      `[batch_size..., q_num, kv_num]`.
    """
    qk_features = query.shape[-1]
    query = query / jnp.sqrt(qk_features).astype(self.dtype)

    # Each query is multiplied with its own set of keys.
    attn = jnp.einsum(
        "...qd,...qkd->...qk", query, key, precision=self.precision
    )

    # axis=-2 for a softmax over query axis (inverted attention).
    attn = jax.nn.softmax(
        attn / self.softmax_temperature, axis=-2
    ).astype(self.dtype)

    # We expand dims because the logger expects a #heads dimension.
    self.sow("intermediates", "attn", jnp.expand_dims(attn, -3))

    if self.renormalize_keys:
      normalizer = jnp.sum(attn, axis=-1, keepdims=True) + self.epsilon
      attn = attn / normalizer

    if self.attn_weights_only:
      return attn

    output = jnp.einsum(
        "...qk,...qkd->...qd" if self.value_per_query else "...qk,...kd->...qd",
        attn,
        value,
        precision=self.precision
    )

    return output, attn

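A minimal shape-check sketch for the module above, assuming it is importable from the path added in this commit; the sizes are arbitrary illustrations:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules.invariant_attention import InvertedDotProductAttentionKeyPerQuery

attn_mod = InvertedDotProductAttentionKeyPerQuery(renormalize_keys=True)
query = jnp.zeros((1, 4, 32))      # (batch, q_num, qk_features)
key = jnp.zeros((1, 4, 64, 32))    # (batch, q_num, kv_num, qk_features): one key set per query
value = jnp.zeros((1, 64, 32))     # (batch, kv_num, v_features)
variables = attn_mod.init(jax.random.PRNGKey(0), query, key, value, train=False)
output, attn = attn_mod.apply(variables, query, key, value, train=False)
# output.shape == (1, 4, 32), attn.shape == (1, 4, 64)
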
class SlotAttentionExplicitStats(nn.Module):
  """Slot Attention module with explicit slot statistics.

  Slot statistics, such as position and scale, are appended to the
  output slot representations.

  Note: This module expects a 2D coordinate grid to be appended
  at the end of inputs.

  Note: This module uses pre-normalization by default.
  """
  grid_encoder: nn.Module
  num_iterations: int = 1
  qkv_size: Optional[int] = None
  mlp_size: Optional[int] = None
  epsilon: float = 1e-8
  softmax_temperature: float = 1.0
  gumbel_softmax: bool = False
  gumbel_softmax_straight_through: bool = False
  num_heads: int = 1
  min_scale: float = 0.01
  max_scale: float = 5.
  return_slot_positions: bool = True
  return_slot_scales: bool = True

  @nn.compact
  def __call__(self, slots, inputs,
               padding_mask = None,
               train = False):
    """Slot Attention with explicit slot statistics module forward pass."""
    del padding_mask  # Unused.
    # Slot scales require slot positions.
    assert self.return_slot_positions or not self.return_slot_scales

    # Separate a concatenated linear coordinate grid from the inputs.
    inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]

    # Hack so that the input and output slot dimensions are the same.
    to_remove = 0
    if self.return_slot_positions:
      to_remove += 2
    if self.return_slot_scales:
      to_remove += 2
    if to_remove > 0:
      slots = slots[Ellipsis, :-to_remove]

    # Add position encodings to inputs.
    n_features = inputs.shape[-1]
    grid_projector = nn.Dense(n_features, name="dense_pe_0")
    inputs = self.grid_encoder()(inputs + grid_projector(grid))

    qkv_size = self.qkv_size or slots.shape[-1]
    head_dim = qkv_size // self.num_heads
    dense = functools.partial(nn.DenseGeneral,
                              axis=-1, features=(self.num_heads, head_dim),
                              use_bias=False)

    # Shared modules.
    dense_q = dense(name="general_dense_q_0")
    layernorm_q = nn.LayerNorm()
    inverted_attention = attention.InvertedDotProductAttention(
        norm_type="mean",
        multi_head=self.num_heads > 1,
        return_attn_weights=True)
    gru = misc.GRU()

    if self.mlp_size is not None:
      mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True)  # type: ignore

    # inputs.shape = (..., n_inputs, inputs_size).
    inputs = nn.LayerNorm()(inputs)
    # k.shape = (..., n_inputs, slot_size).
    k = dense(name="general_dense_k_0")(inputs)
    # v.shape = (..., n_inputs, slot_size).
    v = dense(name="general_dense_v_0")(inputs)

    # Multiple rounds of attention.
    for _ in range(self.num_iterations):

      # Inverted dot-product attention.
      slots_n = layernorm_q(slots)
      q = dense_q(slots_n)  # q.shape = (..., n_inputs, slot_size).
      updates, attn = inverted_attention(query=q, key=k, value=v, train=train)

      # Recurrent update.
      slots = gru(slots, updates)

      # Feedforward block with pre-normalization.
      if self.mlp_size is not None:
        slots = mlp(slots)

    if self.return_slot_positions:
      # Compute the center of mass of each slot attention mask.
      positions = jnp.einsum("...qk,...kd->...qd", attn, grid)
      slots = jnp.concatenate([slots, positions], axis=-1)

    if self.return_slot_scales:
      # Compute slot scales. Take the square root to make the operation
      # analogous to normalizing data drawn from a Gaussian.
      spread = jnp.square(
          jnp.expand_dims(grid, axis=-3) - jnp.expand_dims(positions, axis=-2))
      scales = jnp.sqrt(
          jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread))
      scales = jnp.clip(scales, self.min_scale, self.max_scale)
      slots = jnp.concatenate([slots, scales], axis=-1)

    return slots

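The position and scale statistics above are attention-weighted moments of the coordinate grid. A toy sketch without batch axes, mirroring the einsums used in the module (values are made up):

import jax.numpy as jnp

grid = jnp.array([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]])   # (kv_num, 2) flattened grid
attn = jnp.array([[0.25, 0.25, 0.25, 0.25]])                     # (q_num, kv_num), renormalized rows
positions = jnp.einsum("qk,kd->qd", attn, grid)                  # center of mass -> [[0., 0.]]
spread = jnp.square(grid[None] - positions[:, None])             # (q_num, kv_num, 2)
scales = jnp.sqrt(jnp.einsum("qk,qkd->qd", attn, spread))        # weighted std. dev. -> [[1., 1.]]
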
class SlotAttentionPosKeysValues(nn.Module):
  """Slot Attention module with positional encodings in keys and values.

  Feature position encodings are added to keys and values instead
  of the inputs.

  Note: This module expects a 2D coordinate grid to be appended
  at the end of inputs.

  Note: This module uses pre-normalization by default.
  """
  grid_encoder: nn.Module
  num_iterations: int = 1
  qkv_size: Optional[int] = None
  mlp_size: Optional[int] = None
  epsilon: float = 1e-8
  softmax_temperature: float = 1.0
  gumbel_softmax: bool = False
  gumbel_softmax_straight_through: bool = False
  num_heads: int = 1

  @nn.compact
  def __call__(self, slots, inputs,
               padding_mask = None,
               train = False):
    """Slot Attention with positional encodings in keys and values forward pass."""
    del padding_mask  # Unused.

    # Separate a concatenated linear coordinate grid from the inputs.
    inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]

    qkv_size = self.qkv_size or slots.shape[-1]
    head_dim = qkv_size // self.num_heads
    dense = functools.partial(nn.DenseGeneral,
                              axis=-1, features=(self.num_heads, head_dim),
                              use_bias=False)

    # Shared modules.
    dense_q = dense(name="general_dense_q_0")
    layernorm_q = nn.LayerNorm()
    inverted_attention = attention.InvertedDotProductAttention(
        norm_type="mean",
        multi_head=self.num_heads > 1)
    gru = misc.GRU()

    if self.mlp_size is not None:
      mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True)  # type: ignore

    # inputs.shape = (..., n_inputs, inputs_size).
    inputs = nn.LayerNorm()(inputs)
    # k.shape = (..., n_inputs, slot_size).
    k = dense(name="general_dense_k_0")(inputs)
    # v.shape = (..., n_inputs, slot_size).
    v = dense(name="general_dense_v_0")(inputs)

    # Add position encodings to keys and values.
    grid_projector = dense(name="general_dense_p_0")
    grid_encoder = self.grid_encoder()
    k = grid_encoder(k + grid_projector(grid))
    v = grid_encoder(v + grid_projector(grid))

    # Multiple rounds of attention.
    for _ in range(self.num_iterations):

      # Inverted dot-product attention.
      slots_n = layernorm_q(slots)
      q = dense_q(slots_n)  # q.shape = (..., n_inputs, slot_size).
      updates = inverted_attention(query=q, key=k, value=v, train=train)

      # Recurrent update.
      slots = gru(slots, updates)

      # Feedforward block with pre-normalization.
      if self.mlp_size is not None:
        slots = mlp(slots)

    return slots

class SlotAttentionTranslEquiv(nn.Module):
  """Slot Attention module with slot positions.

  A position is computed for each slot. Slot positions are used to create
  relative coordinate grids, which are used as position embeddings reapplied
  in each iteration of slot attention. The last two channels in inputs
  must contain the flattened position grid.

  Note: This module uses pre-normalization by default.
  """

  grid_encoder: nn.Module
  num_iterations: int = 1
  qkv_size: Optional[int] = None
  mlp_size: Optional[int] = None
  epsilon: float = 1e-8
  softmax_temperature: float = 1.0
  gumbel_softmax: bool = False
  gumbel_softmax_straight_through: bool = False
  num_heads: int = 1
  zero_position_init: bool = True
  ablate_non_equivariant: bool = False
  stop_grad_positions: bool = False
  mix_slots: bool = False
  add_rel_pos_to_values: bool = False
  append_statistics: bool = False

  @nn.compact
  def __call__(self, slots, inputs,
               padding_mask = None,
               train = False):
    """Slot Attention translation equiv. module forward pass."""
    del padding_mask  # Unused.

    if self.num_heads > 1:
      raise NotImplementedError("This prototype only uses one attn. head.")

    # Separate a concatenated linear coordinate grid from the inputs.
    inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]

    # Separate position (x,y) from slot embeddings.
    slots, positions = slots[Ellipsis, :-2], slots[Ellipsis, -2:]
    qkv_size = self.qkv_size or slots.shape[-1]
    num_slots = slots.shape[-2]

    # Prepare initial slot positions.
    if self.zero_position_init:
      # All slots start in the middle of the image.
      positions *= 0.

    # Learnable initial positions might deviate from the allowed range.
    positions = jnp.clip(positions, -1., 1.)

    # Pre-normalization.
    inputs = nn.LayerNorm()(inputs)

    grid_per_slot = jnp.repeat(
        jnp.expand_dims(grid, axis=-3), num_slots, axis=-3)

    # Shared modules.
    dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0")
    dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0")
    dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0")
    grid_proj = nn.Dense(qkv_size, name="dense_gp_0")
    grid_enc = self.grid_encoder()
    layernorm_q = nn.LayerNorm()
    inverted_attention = InvertedDotProductAttentionKeyPerQuery(
        epsilon=self.epsilon,
        renormalize_keys=True,
        softmax_temperature=self.softmax_temperature,
        value_per_query=self.add_rel_pos_to_values
    )
    gru = misc.GRU()

    if self.mlp_size is not None:
      mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True)  # type: ignore

    if self.append_statistics:
      embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0")

    # k.shape and v.shape = (..., n_inputs, slot_size).
    v = dense_v(inputs)
    k = dense_k(inputs)
    k_expand = jnp.expand_dims(k, axis=-3)
    v_expand = jnp.expand_dims(v, axis=-3)

    # Multiple rounds of attention. Last iteration updates positions only.
    for attn_round in range(self.num_iterations + 1):

      if self.ablate_non_equivariant:
        # Add an encoded coordinate grid with absolute positions.
        grid_emb_per_slot = grid_proj(grid_per_slot)
        k_rel_pos = grid_enc(k_expand + grid_emb_per_slot)
        if self.add_rel_pos_to_values:
          v_rel_pos = grid_enc(v_expand + grid_emb_per_slot)
      else:
        # Relativize positions, encode them and add them to the keys
        # and optionally to values.
        relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2)
        grid_emb_per_slot = grid_proj(relative_grid)
        k_rel_pos = grid_enc(k_expand + grid_emb_per_slot)
        if self.add_rel_pos_to_values:
          v_rel_pos = grid_enc(v_expand + grid_emb_per_slot)

      # Inverted dot-product attention.
      slots_n = layernorm_q(slots)
      q = dense_q(slots_n)  # q.shape = (..., n_slots, slot_size).
      updates, attn = inverted_attention(
          query=q,
          key=k_rel_pos,
          value=v_rel_pos if self.add_rel_pos_to_values else v,
          train=train)

      # Compute the center of mass of each slot attention mask.
      # Guaranteed to be in [-1, 1].
      positions = jnp.einsum("...qk,...kd->...qd", attn, grid)

      if self.stop_grad_positions:
        # Do not backprop through positions.
        positions = jax.lax.stop_gradient(positions)

      if attn_round < self.num_iterations:
        if self.append_statistics:
          # Project and add 2D slot positions into slot latents.
          tmp = jnp.concatenate([slots, positions], axis=-1)
          slots = embed_statistics(tmp)

        # Recurrent update.
        slots = gru(slots, updates)

        # Feedforward block with pre-normalization.
        if self.mlp_size is not None:
          slots = mlp(slots)

    # Concatenate position information to slots.
    output = jnp.concatenate([slots, positions], axis=-1)

    if self.mix_slots:
      output = misc.MLP(hidden_size=128, layernorm="pre")(output)

    return output

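The translation equivariance comes from each slot seeing the grid relative to its own position. A small sketch without batch axes, mirroring the relativization step above (positions are made up):

import jax.numpy as jnp

grid = jnp.array([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]])      # (n_inputs, 2)
positions = jnp.array([[0.5, 0.5], [-0.5, 0.0]])                     # (n_slots, 2)
grid_per_slot = jnp.repeat(grid[None], positions.shape[0], axis=0)   # (n_slots, n_inputs, 2)
relative_grid = grid_per_slot - positions[:, None]                   # each slot sees coords centered on itself
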
class SlotAttentionTranslScaleEquiv(nn.Module):
  """Slot Attention module with slot positions and scales.

  A position and scale is computed for each slot. Slot positions and scales
  are used to create relative coordinate grids, which are used as position
  embeddings reapplied in each iteration of slot attention. The last two
  channels in input must contain the flattened position grid.

  Note: This module uses pre-normalization by default.
  """

  grid_encoder: nn.Module
  num_iterations: int = 1
  qkv_size: Optional[int] = None
  mlp_size: Optional[int] = None
  epsilon: float = 1e-8
  softmax_temperature: float = 1.0
  gumbel_softmax: bool = False
  gumbel_softmax_straight_through: bool = False
  num_heads: int = 1
  zero_position_init: bool = True
  # Scale of 0.1 corresponds to fairly small objects.
  init_with_fixed_scale: Optional[float] = 0.1
  ablate_non_equivariant: bool = False
  stop_grad_positions_and_scales: bool = False
  mix_slots: bool = False
  add_rel_pos_to_values: bool = False
  scales_factor: float = 1.
  # Slot scales cannot be negative and should not be too close to zero
  # or too large.
  min_scale: float = 0.001
  max_scale: float = 2.
  append_statistics: bool = False

  @nn.compact
  def __call__(self, slots, inputs,
               padding_mask = None,
               train = False):
    """Slot Attention translation and scale equiv. module forward pass."""
    del padding_mask  # Unused.

    if self.num_heads > 1:
      raise NotImplementedError("This prototype only uses one attn. head.")

    # Separate a concatenated linear coordinate grid from the inputs.
    inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]

    # Separate position (x,y) and scale from slot embeddings.
    slots, positions, scales = (slots[Ellipsis, :-4],
                                slots[Ellipsis, -4: -2],
                                slots[Ellipsis, -2:])
    qkv_size = self.qkv_size or slots.shape[-1]
    num_slots = slots.shape[-2]

    # Prepare initial slot positions.
    if self.zero_position_init:
      # All slots start in the middle of the image.
      positions *= 0.

    if self.init_with_fixed_scale is not None:
      scales = scales * 0. + self.init_with_fixed_scale

    # Learnable initial positions and scales could have arbitrary values.
    positions = jnp.clip(positions, -1., 1.)
    scales = jnp.clip(scales, self.min_scale, self.max_scale)

    # Pre-normalization.
    inputs = nn.LayerNorm()(inputs)

    grid_per_slot = jnp.repeat(
        jnp.expand_dims(grid, axis=-3), num_slots, axis=-3)

    # Shared modules.
    dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0")
    dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0")
    dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0")
    grid_proj = nn.Dense(qkv_size, name="dense_gp_0")
    grid_enc = self.grid_encoder()
    layernorm_q = nn.LayerNorm()
    inverted_attention = InvertedDotProductAttentionKeyPerQuery(
        epsilon=self.epsilon,
        renormalize_keys=True,
        softmax_temperature=self.softmax_temperature,
        value_per_query=self.add_rel_pos_to_values
    )
    gru = misc.GRU()

    if self.mlp_size is not None:
      mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True)  # type: ignore

    if self.append_statistics:
      embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0")

    # k.shape and v.shape = (..., n_inputs, slot_size).
    v = dense_v(inputs)
    k = dense_k(inputs)
    k_expand = jnp.expand_dims(k, axis=-3)
    v_expand = jnp.expand_dims(v, axis=-3)

    # Multiple rounds of attention.
    # Last iteration updates positions and scales only.
    for attn_round in range(self.num_iterations + 1):

      if self.ablate_non_equivariant:
        # Add an encoded coordinate grid with absolute positions.
        tmp_grid = grid_proj(grid_per_slot)
        k_rel_pos = grid_enc(k_expand + tmp_grid)
        if self.add_rel_pos_to_values:
          v_rel_pos = grid_enc(v_expand + tmp_grid)
      else:
        # Relativize and scale positions, encode them and add them to inputs.
        relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2)
        # Scales are usually small so the grid might get too large.
        relative_grid = relative_grid / self.scales_factor
        relative_grid = relative_grid / jnp.expand_dims(scales, axis=-2)
        tmp_grid = grid_proj(relative_grid)
        k_rel_pos = grid_enc(k_expand + tmp_grid)
        if self.add_rel_pos_to_values:
          v_rel_pos = grid_enc(v_expand + tmp_grid)

      # Inverted dot-product attention.
      slots_n = layernorm_q(slots)
      q = dense_q(slots_n)  # q.shape = (..., n_slots, slot_size).
      updates, attn = inverted_attention(
          query=q,
          key=k_rel_pos,
          value=v_rel_pos if self.add_rel_pos_to_values else v,
          train=train)

      # Compute the center of mass of each slot attention mask.
      positions = jnp.einsum("...qk,...kd->...qd", attn, grid)

      # Compute slot scales. Take the square root to make the operation
      # analogous to normalizing data drawn from a Gaussian.
      spread = jnp.square(grid_per_slot - jnp.expand_dims(positions, axis=-2))
      scales = jnp.sqrt(
          jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread))

      # Computed positions are guaranteed to be in [-1, 1].
      # Scales are unbounded.
      scales = jnp.clip(scales, self.min_scale, self.max_scale)

      if self.stop_grad_positions_and_scales:
        # Do not backprop through positions and scales.
        positions = jax.lax.stop_gradient(positions)
        scales = jax.lax.stop_gradient(scales)

      if attn_round < self.num_iterations:
        if self.append_statistics:
          # Project and add 2D slot positions and scales into slot latents.
          tmp = jnp.concatenate([slots, positions, scales], axis=-1)
          slots = embed_statistics(tmp)

        # Recurrent update.
        slots = gru(slots, updates)

        # Feedforward block with pre-normalization.
        if self.mlp_size is not None:
          slots = mlp(slots)

    # Concatenate position and scale information to slots.
    output = jnp.concatenate([slots, positions, scales], axis=-1)

    if self.mix_slots:
      output = misc.MLP(hidden_size=128, layernorm="pre")(output)

    return output

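Scale equivariance is obtained by dividing each slot's relative grid by its estimated scale, so a slot covering a small object sees a zoomed-in grid. A toy sketch with made-up numbers, mirroring the division above:

import jax.numpy as jnp

relative_grid = jnp.array([[[0.2, -0.2], [0.4, 0.4]]])   # (n_slots, n_inputs, 2)
scales = jnp.array([[0.1, 0.2]])                          # (n_slots, 2), per-axis scales
scaled_grid = relative_grid / scales[:, None]             # -> [[[2., -1.], [4., 2.]]]
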
class SlotAttentionTranslRotScaleEquiv(nn.Module):
  """Slot Attention module with slot positions, rotations and scales.

  A position, rotation and scale is computed for each slot.
  Slot positions, rotations and scales are used to create relative
  coordinate grids, which are used as position embeddings reapplied in each
  iteration of slot attention. The last two channels in input must contain
  the flattened position grid.

  Note: This module uses pre-normalization by default.
  """

  grid_encoder: nn.Module
  num_iterations: int = 1
  qkv_size: Optional[int] = None
  mlp_size: Optional[int] = None
  epsilon: float = 1e-8
  softmax_temperature: float = 1.0
  gumbel_softmax: bool = False
  gumbel_softmax_straight_through: bool = False
  num_heads: int = 1
  zero_position_init: bool = True
  # Scale of 0.1 corresponds to fairly small objects.
  init_with_fixed_scale: Optional[float] = 0.1
  ablate_non_equivariant: bool = False
  stop_grad_positions: bool = False
  stop_grad_scales: bool = False
  stop_grad_rotations: bool = False
  mix_slots: bool = False
  add_rel_pos_to_values: bool = False
  scales_factor: float = 1.
  # Slot scales cannot be negative and should not be too close to zero
  # or too large.
  min_scale: float = 0.001
  max_scale: float = 2.
  limit_rot_to_45_deg: bool = True
  append_statistics: bool = False

  @nn.compact
  def __call__(self, slots, inputs,
               padding_mask = None,
               train = False):
    """Slot Attention translation, rotation and scale equiv. forward pass."""
    del padding_mask  # Unused.

    if self.num_heads > 1:
      raise NotImplementedError("This prototype only uses one attn. head.")

    # Separate a concatenated linear coordinate grid from the inputs.
    inputs, grid = inputs[Ellipsis, :-2], inputs[Ellipsis, -2:]

    # Separate position (x,y), scale and rotation from slot embeddings.
    slots, positions, scales, rotm = (slots[Ellipsis, :-8],
                                      slots[Ellipsis, -8: -6],
                                      slots[Ellipsis, -6: -4],
                                      slots[Ellipsis, -4:])
    rotm = jnp.reshape(rotm, (*rotm.shape[:-1], 2, 2))
    qkv_size = self.qkv_size or slots.shape[-1]
    num_slots = slots.shape[-2]

    # Prepare initial slot positions.
    if self.zero_position_init:
      # All slots start in the middle of the image.
      positions *= 0.

    if self.init_with_fixed_scale is not None:
      scales = scales * 0. + self.init_with_fixed_scale

    # Learnable initial positions and scales could have arbitrary values.
    positions = jnp.clip(positions, -1., 1.)
    scales = jnp.clip(scales, self.min_scale, self.max_scale)

    # Pre-normalization.
    inputs = nn.LayerNorm()(inputs)

    grid_per_slot = jnp.repeat(
        jnp.expand_dims(grid, axis=-3), num_slots, axis=-3)

    # Shared modules.
    dense_q = nn.Dense(qkv_size, use_bias=False, name="general_dense_q_0")
    dense_k = nn.Dense(qkv_size, use_bias=False, name="general_dense_k_0")
    dense_v = nn.Dense(qkv_size, use_bias=False, name="general_dense_v_0")
    grid_proj = nn.Dense(qkv_size, name="dense_gp_0")
    grid_enc = self.grid_encoder()
    layernorm_q = nn.LayerNorm()
    inverted_attention = InvertedDotProductAttentionKeyPerQuery(
        epsilon=self.epsilon,
        renormalize_keys=True,
        softmax_temperature=self.softmax_temperature,
        value_per_query=self.add_rel_pos_to_values
    )
    gru = misc.GRU()

    if self.mlp_size is not None:
      mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True)  # type: ignore

    if self.append_statistics:
      embed_statistics = nn.Dense(slots.shape[-1], name="dense_embed_0")

    # k.shape and v.shape = (..., n_inputs, slot_size).
    v = dense_v(inputs)
    k = dense_k(inputs)
    k_expand = jnp.expand_dims(k, axis=-3)
    v_expand = jnp.expand_dims(v, axis=-3)

    # Multiple rounds of attention.
    # Last iteration updates positions and scales only.
    for attn_round in range(self.num_iterations + 1):

      if self.ablate_non_equivariant:
        # Add an encoded coordinate grid with absolute positions.
        tmp_grid = grid_proj(grid_per_slot)
        k_rel_pos = grid_enc(k_expand + tmp_grid)
        if self.add_rel_pos_to_values:
          v_rel_pos = grid_enc(v_expand + tmp_grid)
      else:
        # Relativize and scale positions, encode them and add them to inputs.
        relative_grid = grid_per_slot - jnp.expand_dims(positions, axis=-2)

        # Rotation.
        relative_grid = self.transform(rotm, relative_grid)

        # Scales are usually small so the grid might get too large.
        relative_grid = relative_grid / self.scales_factor
        relative_grid = relative_grid / jnp.expand_dims(scales, axis=-2)
        tmp_grid = grid_proj(relative_grid)
        k_rel_pos = grid_enc(k_expand + tmp_grid)
        if self.add_rel_pos_to_values:
          v_rel_pos = grid_enc(v_expand + tmp_grid)

      # Inverted dot-product attention.
      slots_n = layernorm_q(slots)
      q = dense_q(slots_n)  # q.shape = (..., n_slots, slot_size).
      updates, attn = inverted_attention(
          query=q,
          key=k_rel_pos,
          value=v_rel_pos if self.add_rel_pos_to_values else v,
          train=train)

      # Compute the center of mass of each slot attention mask.
      positions = jnp.einsum("...qk,...kd->...qd", attn, grid)

      # Find the axis with the highest spread.
      relp = grid_per_slot - jnp.expand_dims(positions, axis=-2)
      if self.limit_rot_to_45_deg:
        rotm = self.compute_rotation_matrix_45_deg(relp, attn)
      else:
        rotm = self.compute_rotation_matrix_90_deg(relp, attn)

      # Compute slot scales. Take the square root to make the operation
      # analogous to normalizing data drawn from a Gaussian.
      relp = self.transform(rotm, relp)

      spread = jnp.square(relp)
      scales = jnp.sqrt(
          jnp.einsum("...qk,...qkd->...qd", attn + self.epsilon, spread))

      # Computed positions are guaranteed to be in [-1, 1].
      # Scales are unbounded.
      scales = jnp.clip(scales, self.min_scale, self.max_scale)

      if self.stop_grad_positions:
        positions = jax.lax.stop_gradient(positions)
      if self.stop_grad_scales:
        scales = jax.lax.stop_gradient(scales)
      if self.stop_grad_rotations:
        rotm = jax.lax.stop_gradient(rotm)

      if attn_round < self.num_iterations:
        if self.append_statistics:
          # For the slot rotations, we append both the 2D rotation matrix
          # and the angle by which we rotate.
          # We can compute the angle using atan2(R[0, 0], R[1, 0]).
          tmp = jnp.concatenate(
              [slots, positions, scales,
               rotm.reshape(*rotm.shape[:-2], 4),
               jnp.arctan2(rotm[Ellipsis, 0, 0], rotm[Ellipsis, 1, 0])[Ellipsis, None]],
              axis=-1)
          slots = embed_statistics(tmp)

        # Recurrent update.
        slots = gru(slots, updates)

        # Feedforward block with pre-normalization.
        if self.mlp_size is not None:
          slots = mlp(slots)

    # Concatenate position, scale and rotation information to slots.
    output = jnp.concatenate(
        [slots, positions, scales, rotm.reshape(*rotm.shape[:-2], 4)], axis=-1)

    if self.mix_slots:
      output = misc.MLP(hidden_size=128, layernorm="pre")(output)

    return output

  @classmethod
  def compute_weighted_covariance(cls, x, w):
    # The coordinate grid is (y, x), we want (x, y).
    x = jnp.stack([x[Ellipsis, 1], x[Ellipsis, 0]], axis=-1)

    # Pixel coordinates weighted by attention mask.
    cov = x * w[Ellipsis, None]
    cov = jnp.einsum(
        "...ji,...jk->...ik", cov, x, precision=jax.lax.Precision.HIGHEST)

    return cov

  @classmethod
  def compute_reference_frame_45_deg(cls, x, w):
    cov = cls.compute_weighted_covariance(x, w)

    # Compute eigenvalues.
    pm = jnp.sqrt(4. * jnp.square(cov[Ellipsis, 0, 1]) +
                  jnp.square(cov[Ellipsis, 0, 0] - cov[Ellipsis, 1, 1]) + 1e-16)

    eig1 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] + pm) / 2.
    eig2 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] - pm) / 2.

    # Compute eigenvectors, note that both have a positive y-axis.
    # This means we have eliminated half of the possible rotations.
    div = cov[Ellipsis, 0, 1] + 1e-16

    v1 = (eig1 - cov[Ellipsis, 1, 1]) / div
    v2 = (eig2 - cov[Ellipsis, 1, 1]) / div

    v1 = jnp.stack([v1, jnp.ones_like(v1)], axis=-1)
    v2 = jnp.stack([v2, jnp.ones_like(v2)], axis=-1)

    # RULE 1:
    # We catch two failure modes here.
    # 1. If all attention weights are zero the covariance is also zero.
    #    Then the above computation is meaningless.
    # 2. If the attention pattern is exactly aligned with the axes
    #    (e.g. a horizontal/vertical bar), the off-diagonal covariance
    #    values are going to be very low. If we use float32, we get
    #    basis vectors that are not orthogonal.
    # Solution: use the default reference frame if the off-diagonal
    # covariance value is too low.
    default_1 = jnp.stack([jnp.ones_like(div), jnp.zeros_like(div)], axis=-1)
    default_2 = jnp.stack([jnp.zeros_like(div), jnp.ones_like(div)], axis=-1)

    mask = (jnp.abs(div) < 1e-6).astype(jnp.float32)[Ellipsis, None]
    v1 = (1. - mask) * v1 + mask * default_1
    v2 = (1. - mask) * v2 + mask * default_2

    # Turn eigenvectors into unit vectors, so that we can construct
    # a basis of a new reference frame.
    norm1 = jnp.sqrt(jnp.sum(jnp.square(v1), axis=-1, keepdims=True))
    norm2 = jnp.sqrt(jnp.sum(jnp.square(v2), axis=-1, keepdims=True))

    v1 = v1 / norm1
    v2 = v2 / norm2

    # RULE 2:
    # If the first basis vector is "pointing up" we assume the object
    # is vertical (e.g. we say a door is vertical, whereas a car is horizontal).
    # In the case of vertical objects, we swap the two basis vectors.
    # This limits the possible rotations to +- 45deg instead of +- 90deg.
    # We define "pointing up" as the first coordinate of the first basis vector
    # being between +- sin(pi/4). The second coordinate is always positive.
    mask = (jnp.logical_and(v1[Ellipsis, 0] < 0.707, v1[Ellipsis, 0] > -0.707)
            ).astype(jnp.float32)[Ellipsis, None]
    v1_ = (1. - mask) * v1 + mask * v2
    v2_ = (1. - mask) * v2 + mask * v1
    v1 = v1_
    v2 = v2_

    # RULE 3:
    # Mirror the first basis vector if the first coordinate is negative.
    # Here, we ensure that our coordinate system is always left-handed.
    # Otherwise, we would sometimes unintentionally mirror the grid.
    mask = (v1[Ellipsis, 0] < 0).astype(jnp.float32)[Ellipsis, None]
    v1 = (1. - mask) * v1 - mask * v1

    return v1, v2

  @classmethod
  def compute_reference_frame_90_deg(cls, x, w):
    cov = cls.compute_weighted_covariance(x, w)

    # Compute eigenvalues.
    pm = jnp.sqrt(4. * jnp.square(cov[Ellipsis, 0, 1]) +
                  jnp.square(cov[Ellipsis, 0, 0] - cov[Ellipsis, 1, 1]) + 1e-16)

    eig1 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] + pm) / 2.
    eig2 = (cov[Ellipsis, 0, 0] + cov[Ellipsis, 1, 1] - pm) / 2.

    # Compute eigenvectors, note that both have a positive y-axis.
    # This means we have eliminated half of the possible rotations.
    div = cov[Ellipsis, 0, 1] + 1e-16

    v1 = (eig1 - cov[Ellipsis, 1, 1]) / div
    v2 = (eig2 - cov[Ellipsis, 1, 1]) / div

    v1 = jnp.stack([v1, jnp.ones_like(v1)], axis=-1)
    v2 = jnp.stack([v2, jnp.ones_like(v2)], axis=-1)

    # RULE 1:
    # We catch two failure modes here.
    # 1. If all attention weights are zero the covariance is also zero.
    #    Then the above computation is meaningless.
    # 2. If the attention pattern is exactly aligned with the axes
    #    (e.g. a horizontal/vertical bar), the off-diagonal covariance
    #    values are going to be very low. If we use float32, we get
    #    basis vectors that are not orthogonal.
    # Solution: use the default reference frame if the off-diagonal
    # covariance value is too low.
    default_1 = jnp.stack([jnp.ones_like(div), jnp.zeros_like(div)], axis=-1)
    default_2 = jnp.stack([jnp.zeros_like(div), jnp.ones_like(div)], axis=-1)

    # RULE 1.5:
    # RULE 1 is activated if we see a vertical or a horizontal bar.
    # We make sure that the coordinate grid for a horizontal bar is not rotated,
    # whereas the coordinate grid for a vertical bar is rotated by 90deg.
    # If cov[0, 0] > cov[1, 1], the bar is vertical.
    mask = (cov[Ellipsis, 0, 0] <= cov[Ellipsis, 1, 1]).astype(jnp.float32)[Ellipsis, None]
    # Furthermore, we have to mirror one of the basis vectors (if mask==1)
    # so that we always have a left-handed coordinate grid.
    default_v1 = (1. - mask) * default_1 - mask * default_2
    default_v2 = (1. - mask) * default_2 + mask * default_1

    # Continuation of RULE 1.
    mask = (jnp.abs(div) < 1e-6).astype(jnp.float32)[Ellipsis, None]
    v1 = mask * default_v1 + (1. - mask) * v1
    v2 = mask * default_v2 + (1. - mask) * v2

    # Turn eigenvectors into unit vectors, so that we can construct
    # a basis of a new reference frame.
    norm1 = jnp.sqrt(jnp.sum(jnp.square(v1), axis=-1, keepdims=True))
    norm2 = jnp.sqrt(jnp.sum(jnp.square(v2), axis=-1, keepdims=True))

    v1 = v1 / norm1
    v2 = v2 / norm2

    # RULE 2:
    # Mirror the first basis vector if the first coordinate is negative.
    # Here, we ensure that our coordinate system is always left-handed.
    # Otherwise, we would sometimes unintentionally mirror the grid.
    mask = (v1[Ellipsis, 0] < 0).astype(jnp.float32)[Ellipsis, None]
    v1 = (1. - mask) * v1 - mask * v1

    return v1, v2

  @classmethod
  def compute_rotation_matrix_45_deg(cls, x, w):
    v1, v2 = cls.compute_reference_frame_45_deg(x, w)
    return jnp.stack([v1, v2], axis=-1)

  @classmethod
  def compute_rotation_matrix_90_deg(cls, x, w):
    v1, v2 = cls.compute_reference_frame_90_deg(x, w)
    return jnp.stack([v1, v2], axis=-1)

  @classmethod
  def transform(cls, rotm, x):
    # The coordinate grid x is in the (y, x) format, so we need to swap
    # the coordinates on the input and output.
    x = jnp.stack([x[Ellipsis, 1], x[Ellipsis, 0]], axis=-1)
    # Equivalent to inv(R) * x^T = R^T * x^T = (x * R)^T.
    # We are multiplying by the inverse of the rotation matrix because
    # we are rotating the coordinate grid *against* the rotation of the object.
    # y = jnp.matmul(x, R)
    y = jnp.einsum("...ij,...jk->...ik", x, rotm)
    # Swap coordinates again.
    y = jnp.stack([y[Ellipsis, 1], y[Ellipsis, 0]], axis=-1)
    return y
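The slot rotation is the eigenbasis of the attention-weighted covariance of the relative coordinates. A small sketch, assuming the import path added in this commit; the points and weights are made up:

import jax.numpy as jnp
from invariant_slot_attention.modules.invariant_attention import SlotAttentionTranslRotScaleEquiv

# One slot attending equally over three points spread along the image diagonal, (y, x) order.
relp = jnp.array([[[-1., -1.], [0., 0.], [1., 1.]]])   # (n_slots, n_inputs, 2)
attn = jnp.ones((1, 3)) / 3.                            # (n_slots, n_inputs), rows sum to 1
rotm = SlotAttentionTranslRotScaleEquiv.compute_rotation_matrix_45_deg(relp, attn)
# rotm.shape == (1, 2, 2); its columns are the (unit) eigenvectors of the weighted
# covariance, i.e. a reference frame rotated by roughly 45 degrees in this example.
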
invariant_slot_attention/modules/invariant_initializers.py
ADDED
@@ -0,0 +1,327 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Initializers module library for equivariant slot attention."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
|
20 |
+
|
21 |
+
from flax import linen as nn
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
from invariant_slot_attention.lib import utils
|
25 |
+
|
26 |
+
Shape = Tuple[int]
|
27 |
+
|
28 |
+
DType = Any
|
29 |
+
Array = Any # jnp.ndarray
|
30 |
+
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet
|
31 |
+
ProcessorState = ArrayTree
|
32 |
+
PRNGKey = Array
|
33 |
+
NestedDict = Dict[str, Any]
|
34 |
+
|
35 |
+
|
36 |
+
def get_uniform_initializer(vmin, vmax):
|
37 |
+
"""Get an uniform initializer with an arbitrary range."""
|
38 |
+
init = nn.initializers.uniform(scale=vmax - vmin)
|
39 |
+
|
40 |
+
def fn(*args, **kwargs):
|
41 |
+
return init(*args, **kwargs) + vmin
|
42 |
+
|
43 |
+
return fn
|
44 |
+
|
45 |
+
|
46 |
+
def get_normal_initializer(mean, sd):
|
47 |
+
"""Get a normal initializer with an arbitrary mean."""
|
48 |
+
init = nn.initializers.normal(stddev=sd)
|
49 |
+
|
50 |
+
def fn(*args, **kwargs):
|
51 |
+
return init(*args, **kwargs) + mean
|
52 |
+
|
53 |
+
return fn
|
54 |
+
|
55 |
+
|
56 |
+
class ParamStateInitRandomPositions(nn.Module):
|
57 |
+
"""Fixed, learnable state initalization with random positions.
|
58 |
+
|
59 |
+
Random slot positions sampled from U[-1, 1] are concatenated
|
60 |
+
as the last two dimensions.
|
61 |
+
Note: This module ignores any conditional input (by design).
|
62 |
+
"""
|
63 |
+
|
64 |
+
shape: Sequence[int]
|
65 |
+
init_fn: str = "normal" # Default init with unit variance.
|
66 |
+
conditioning_key: Optional[str] = None
|
67 |
+
slot_positions_min: float = -1.
|
68 |
+
slot_positions_max: float = 1.
|
69 |
+
|
70 |
+
@nn.compact
|
71 |
+
def __call__(self, inputs, batch_size,
|
72 |
+
train = False):
|
73 |
+
del inputs, train # Unused.
|
74 |
+
|
75 |
+
if self.init_fn == "normal":
|
76 |
+
init_fn = functools.partial(nn.initializers.normal, stddev=1.)
|
77 |
+
elif self.init_fn == "zeros":
|
78 |
+
init_fn = lambda: nn.initializers.zeros
|
79 |
+
else:
|
80 |
+
raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
|
81 |
+
|
82 |
+
param = self.param("state_init", init_fn(), self.shape)
|
83 |
+
|
84 |
+
out = utils.broadcast_across_batch(param, batch_size=batch_size)
|
85 |
+
shape = out.shape[:-1]
|
86 |
+
rng = self.make_rng("state_init")
|
87 |
+
slot_positions = jax.random.uniform(
|
88 |
+
rng, shape=[*shape, 2], minval=self.slot_positions_min,
|
89 |
+
maxval=self.slot_positions_max)
|
90 |
+
out = jnp.concatenate((out, slot_positions), axis=-1)
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class ParamStateInitLearnablePositions(nn.Module):
|
95 |
+
"""Fixed, learnable state initalization with learnable positions.
|
96 |
+
|
97 |
+
Learnable initial positions are concatenated at the end of slots.
|
98 |
+
Note: This module ignores any conditional input (by design).
|
99 |
+
"""
|
100 |
+
|
101 |
+
shape: Sequence[int]
|
102 |
+
init_fn: str = "normal" # Default init with unit variance.
|
103 |
+
conditioning_key: Optional[str] = None
|
104 |
+
slot_positions_min: float = -1.
|
105 |
+
slot_positions_max: float = 1.
|
106 |
+
|
107 |
+
@nn.compact
|
108 |
+
def __call__(self, inputs, batch_size,
|
109 |
+
train = False):
|
110 |
+
del inputs, train # Unused.
|
111 |
+
|
112 |
+
if self.init_fn == "normal":
|
113 |
+
init_fn_state = functools.partial(nn.initializers.normal, stddev=1.)
|
114 |
+
elif self.init_fn == "zeros":
|
115 |
+
init_fn_state = lambda: nn.initializers.zeros
|
116 |
+
else:
|
117 |
+
raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
|
118 |
+
|
119 |
+
init_fn_state = init_fn_state()
|
120 |
+
init_fn_pos = get_uniform_initializer(
|
121 |
+
self.slot_positions_min, self.slot_positions_max)
|
122 |
+
|
123 |
+
param_state = self.param("state_init", init_fn_state, self.shape)
|
124 |
+
param_pos = self.param(
|
125 |
+
"state_init_position", init_fn_pos, (*self.shape[:-1], 2))
|
126 |
+
|
127 |
+
param = jnp.concatenate((param_state, param_pos), axis=-1)
|
128 |
+
|
129 |
+
return utils.broadcast_across_batch(param, batch_size=batch_size) # pytype: disable=bad-return-type # jax-ndarray
|
130 |
+
|
131 |
+
|
132 |
+
class ParamStateInitRandomPositionsScales(nn.Module):
|
133 |
+
"""Fixed, learnable state initalization with random positions and scales.
|
134 |
+
|
135 |
+
Random slot positions and scales sampled from U[-1, 1] and N(0.1, 0.1)
|
136 |
+
are concatenated as the last four dimensions.
|
137 |
+
Note: This module ignores any conditional input (by design).
|
138 |
+
"""
|
139 |
+
|
140 |
+
shape: Sequence[int]
|
141 |
+
init_fn: str = "normal" # Default init with unit variance.
|
142 |
+
conditioning_key: Optional[str] = None
|
143 |
+
slot_positions_min: float = -1.
|
144 |
+
slot_positions_max: float = 1.
|
145 |
+
slot_scales_mean: float = 0.1
|
146 |
+
slot_scales_sd: float = 0.1
|
147 |
+
|
148 |
+
@nn.compact
|
149 |
+
def __call__(self, inputs, batch_size,
|
150 |
+
train = False):
|
151 |
+
del inputs, train # Unused.
|
152 |
+
|
153 |
+
if self.init_fn == "normal":
|
154 |
+
init_fn = functools.partial(nn.initializers.normal, stddev=1.)
|
155 |
+
elif self.init_fn == "zeros":
|
156 |
+
init_fn = lambda: nn.initializers.zeros
|
157 |
+
else:
|
158 |
+
raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
|
159 |
+
|
160 |
+
param = self.param("state_init", init_fn(), self.shape)
|
161 |
+
|
162 |
+
out = utils.broadcast_across_batch(param, batch_size=batch_size)
|
163 |
+
shape = out.shape[:-1]
|
164 |
+
rng = self.make_rng("state_init")
|
165 |
+
slot_positions = jax.random.uniform(
|
166 |
+
rng, shape=[*shape, 2], minval=self.slot_positions_min,
|
167 |
+
maxval=self.slot_positions_max)
|
168 |
+
slot_scales = jax.random.normal(rng, shape=[*shape, 2])
|
169 |
+
slot_scales = self.slot_scales_mean + self.slot_scales_sd * slot_scales
|
170 |
+
out = jnp.concatenate((out, slot_positions, slot_scales), axis=-1)
|
171 |
+
    return out


class ParamStateInitLearnablePositionsScales(nn.Module):
  """Fixed, learnable state initialization with random positions and scales.

  Learnable initial positions and scales are concatenated at the end of slots.
  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.
  slot_scales_mean: float = 0.1
  slot_scales_sd: float = 0.01

  @nn.compact
  def __call__(self, inputs, batch_size,
               train = False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      init_fn_state = functools.partial(nn.initializers.normal, stddev=1.)
    elif self.init_fn == "zeros":
      init_fn_state = lambda: nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))

    init_fn_state = init_fn_state()
    init_fn_pos = get_uniform_initializer(
        self.slot_positions_min, self.slot_positions_max)
    init_fn_scales = get_normal_initializer(
        self.slot_scales_mean, self.slot_scales_sd)

    param_state = self.param("state_init", init_fn_state, self.shape)
    param_pos = self.param(
        "state_init_position", init_fn_pos, (*self.shape[:-1], 2))
    param_scales = self.param(
        "state_init_scale", init_fn_scales, (*self.shape[:-1], 2))

    param = jnp.concatenate((param_state, param_pos, param_scales), axis=-1)

    return utils.broadcast_across_batch(param, batch_size=batch_size)  # pytype: disable=bad-return-type  # jax-ndarray


class ParamStateInitLearnablePositionsRotationsScales(nn.Module):
  """Fixed, learnable state initialization.

  Learnable initial positions, rotations and scales are concatenated
  at the end of slots. The rotation matrix is flattened.
  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.
  slot_scales_mean: float = 0.1
  slot_scales_sd: float = 0.01
  slot_angles_mean: float = 0.
  slot_angles_sd: float = 0.1

  @nn.compact
  def __call__(self, inputs, batch_size,
               train = False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      init_fn_state = functools.partial(nn.initializers.normal, stddev=1.)
    elif self.init_fn == "zeros":
      init_fn_state = lambda: nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))

    init_fn_state = init_fn_state()
    init_fn_pos = get_uniform_initializer(
        self.slot_positions_min, self.slot_positions_max)
    init_fn_scales = get_normal_initializer(
        self.slot_scales_mean, self.slot_scales_sd)
    init_fn_angles = get_normal_initializer(
        self.slot_angles_mean, self.slot_angles_sd)

    param_state = self.param("state_init", init_fn_state, self.shape)
    param_pos = self.param(
        "state_init_position", init_fn_pos, (*self.shape[:-1], 2))
    param_scales = self.param(
        "state_init_scale", init_fn_scales, (*self.shape[:-1], 2))
    param_angles = self.param(
        "state_init_angles", init_fn_angles, (*self.shape[:-1], 1))

    # Initial angles in the range of (-pi / 4, pi / 4) <=> (-45, 45) degrees.
    angles = jnp.tanh(param_angles) * (jnp.pi / 4)
    rotm = jnp.concatenate(
        [jnp.cos(angles), jnp.sin(angles),
         -jnp.sin(angles), jnp.cos(angles)], axis=-1)

    param = jnp.concatenate(
        (param_state, param_pos, param_scales, rotm), axis=-1)

    return utils.broadcast_across_batch(param, batch_size=batch_size)  # pytype: disable=bad-return-type  # jax-ndarray


class ParamStateInitRandomPositionsRotationsScales(nn.Module):
  """Fixed, learnable state initialization with random pos., rot. and scales.

  Random slot positions and scales sampled from U[-1, 1] and N(0.1, 0.1)
  are concatenated as the last four dimensions. Rotations are sampled
  from +- 45 degrees.
  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.
  slot_scales_mean: float = 0.1
  slot_scales_sd: float = 0.1
  slot_angles_min: float = -jnp.pi / 4.
  slot_angles_max: float = jnp.pi / 4.

  @nn.compact
  def __call__(self, inputs, batch_size,
               train = False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      init_fn = functools.partial(nn.initializers.normal, stddev=1.)
    elif self.init_fn == "zeros":
      init_fn = lambda: nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))

    param = self.param("state_init", init_fn(), self.shape)

    out = utils.broadcast_across_batch(param, batch_size=batch_size)
    shape = out.shape[:-1]
    rng = self.make_rng("state_init")
    slot_positions = jax.random.uniform(
        rng, shape=[*shape, 2], minval=self.slot_positions_min,
        maxval=self.slot_positions_max)
    rng = self.make_rng("state_init")
    slot_scales = jax.random.normal(rng, shape=[*shape, 2])
    slot_scales = self.slot_scales_mean + self.slot_scales_sd * slot_scales
    rng = self.make_rng("state_init")
    slot_angles = jax.random.uniform(rng, shape=[*shape, 1])
    slot_angles = (slot_angles * (self.slot_angles_max - self.slot_angles_min)
                   ) + self.slot_angles_min
    slot_rotm = jnp.concatenate(
        [jnp.cos(slot_angles), jnp.sin(slot_angles),
         -jnp.sin(slot_angles), jnp.cos(slot_angles)], axis=-1)
    out = jnp.concatenate(
        (out, slot_positions, slot_scales, slot_rotm), axis=-1)
    return out
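A minimal usage sketch for the initializers above, assuming the `invariant_slot_attention` package from this commit is importable; `num_slots` and `slot_size` are illustrative values, and the random variant draws its positions, scales and angles from the extra "state_init" RNG stream.

import jax
from invariant_slot_attention.modules import invariant_initializers as inits

num_slots, slot_size = 5, 64  # Illustrative values.
init_module = inits.ParamStateInitRandomPositionsRotationsScales(
    shape=(num_slots, slot_size))

rngs = {"params": jax.random.PRNGKey(0), "state_init": jax.random.PRNGKey(1)}
variables = init_module.init(rngs, None, batch_size=2)
slots = init_module.apply(
    variables, None, batch_size=2, rngs={"state_init": jax.random.PRNGKey(2)})
# slots.shape == (2, num_slots, slot_size + 8): per slot, the learned features
# plus 2 position, 2 scale and 4 flattened-rotation-matrix dimensions.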
invariant_slot_attention/modules/misc.py
ADDED
@@ -0,0 +1,340 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Miscellaneous modules."""

from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from flax import linen as nn
import jax
import jax.numpy as jnp

from invariant_slot_attention.lib import utils

Shape = Tuple[int]

DType = Any
Array = Any  # jnp.ndarray
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]


class Identity(nn.Module):
  """Module that applies the identity function, ignoring any additional args."""

  @nn.compact
  def __call__(self, inputs, **args):
    return inputs


class Readout(nn.Module):
  """Module for reading out multiple targets from an embedding."""

  keys: Sequence[str]
  readout_modules: Sequence[Callable[[], nn.Module]]
  stop_gradient: Optional[Sequence[bool]] = None

  @nn.compact
  def __call__(self, inputs, train = False):
    num_targets = len(self.keys)
    assert num_targets >= 1, "Need to have at least one target."
    assert len(self.readout_modules) == num_targets, (
        "len(readout_modules) and len(keys) must match.")
    if self.stop_gradient is not None:
      assert len(self.stop_gradient) == num_targets, (
          "len(stop_gradient) and len(keys) must match.")
    outputs = {}
    for i in range(num_targets):
      if self.stop_gradient is not None and self.stop_gradient[i]:
        x = jax.lax.stop_gradient(inputs)
      else:
        x = inputs
      outputs[self.keys[i]] = self.readout_modules[i]()(x, train)  # pytype: disable=not-callable
    return outputs


class MLP(nn.Module):
  """Simple MLP with one or more hidden layers and optional pre-/post-layernorm."""

  hidden_size: int
  output_size: Optional[int] = None
  num_hidden_layers: int = 1
  activation_fn: Callable[[Array], Array] = nn.relu
  layernorm: Optional[str] = None
  activate_output: bool = False
  residual: bool = False

  @nn.compact
  def __call__(self, inputs, train = False):
    del train  # Unused.

    output_size = self.output_size or inputs.shape[-1]

    x = inputs

    if self.layernorm == "pre":
      x = nn.LayerNorm()(x)

    for i in range(self.num_hidden_layers):
      x = nn.Dense(self.hidden_size, name=f"dense_mlp_{i}")(x)
      x = self.activation_fn(x)
    x = nn.Dense(output_size, name=f"dense_mlp_{self.num_hidden_layers}")(x)

    if self.activate_output:
      x = self.activation_fn(x)

    if self.residual:
      x = x + inputs

    if self.layernorm == "post":
      x = nn.LayerNorm()(x)

    return x


class GRU(nn.Module):
  """GRU cell as nn.Module."""

  @nn.compact
  def __call__(self, carry, inputs,
               train = False):
    del train  # Unused.
    carry, _ = nn.GRUCell()(carry, inputs)
    return carry


class Dense(nn.Module):
  """Dense layer as nn.Module accepting "train" flag."""

  features: int
  use_bias: bool = True

  @nn.compact
  def __call__(self, inputs, train = False):
    del train  # Unused.
    return nn.Dense(features=self.features, use_bias=self.use_bias)(inputs)


class PositionEmbedding(nn.Module):
  """A module for applying N-dimensional position embedding.

  Attributes:
    embedding_type: A string defining the type of position embedding to use.
      One of ["linear", "discrete_1d", "fourier", "gaussian_fourier"].
    update_type: A string defining how the input is updated with the position
      embedding. One of ["project_add", "concat"].
    num_fourier_bases: The number of Fourier bases to use. For embedding_type
      == "fourier", the embedding dimensionality is 2 x number of position
      dimensions x num_fourier_bases. For embedding_type == "gaussian_fourier",
      the embedding dimensionality is 2 x num_fourier_bases. For embedding_type
      == "linear", this parameter is ignored.
    gaussian_sigma: Standard deviation of sampled Gaussians.
    pos_transform: Optional transform for the embedding.
    output_transform: Optional transform for the combined input and embedding.
    trainable_pos_embedding: Boolean flag for allowing gradients to flow into
      the position embedding, so that the optimizer can update it.
  """

  embedding_type: str
  update_type: str
  num_fourier_bases: int = 0
  gaussian_sigma: float = 1.0
  pos_transform: Callable[[], nn.Module] = Identity
  output_transform: Callable[[], nn.Module] = Identity
  trainable_pos_embedding: bool = False

  def _make_pos_embedding_tensor(self, rng, input_shape):
    if self.embedding_type == "discrete_1d":
      # An integer tensor in [0, input_shape[-2] - 1] reflecting a 1D discrete
      # position encoding (encodes the second-to-last axis).
      pos_embedding = jnp.broadcast_to(
          jnp.arange(input_shape[-2]), input_shape[1:-1])
    else:
      # A tensor grid in [-1, +1] for each input dimension.
      pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0])

    if self.embedding_type == "linear":
      pass
    elif self.embedding_type == "discrete_1d":
      pos_embedding = jax.nn.one_hot(pos_embedding, input_shape[-2])
    elif self.embedding_type == "fourier":
      # NeRF-style Fourier/sinusoidal position encoding.
      pos_embedding = utils.convert_to_fourier_features(
          pos_embedding * jnp.pi, basis_degree=self.num_fourier_bases)
    elif self.embedding_type == "gaussian_fourier":
      # Gaussian Fourier features. Reference: https://arxiv.org/abs/2006.10739
      num_dims = pos_embedding.shape[-1]
      projection = jax.random.normal(
          rng, [num_dims, self.num_fourier_bases]) * self.gaussian_sigma
      pos_embedding = jnp.pi * pos_embedding.dot(projection)
      # A slightly faster implementation of sin and cos.
      pos_embedding = jnp.sin(
          jnp.concatenate([pos_embedding, pos_embedding + 0.5 * jnp.pi],
                          axis=-1))
    else:
      raise ValueError("Invalid embedding type provided.")

    # Add batch dimension.
    pos_embedding = jnp.expand_dims(pos_embedding, axis=0)

    return pos_embedding

  @nn.compact
  def __call__(self, inputs):

    # Compute the position embedding only in the initial call; use the same rng
    # as is used for initializing learnable parameters.
    pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor,
                               inputs.shape)

    if not self.trainable_pos_embedding:
      pos_embedding = jax.lax.stop_gradient(pos_embedding)

    # Apply optional transformation on the position embedding.
    pos_embedding = self.pos_transform()(pos_embedding)  # pytype: disable=not-callable

    # Apply position encoding to inputs.
    if self.update_type == "project_add":
      # Here, we project the position encodings to the same dimensionality as
      # the inputs and add them to the inputs (broadcast along batch dimension).
      # This is roughly equivalent to concatenation of position encodings to the
      # inputs (if followed by a Dense layer), but is slightly more efficient.
      n_features = inputs.shape[-1]
      x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding)
    elif self.update_type == "concat":
      # Repeat the position embedding along the first (batch) dimension.
      pos_embedding = jnp.broadcast_to(
          pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:])
      # Concatenate along the channel dimension.
      x = jnp.concatenate((inputs, pos_embedding), axis=-1)
    else:
      raise ValueError("Invalid update type provided.")

    # Apply optional output transformation.
    x = self.output_transform()(x)  # pytype: disable=not-callable
    return x


class RelativePositionEmbedding(nn.Module):
  """A module for applying embedding of input position relative to slots.

  Attributes:
    update_type: A string defining how the input is updated with the position
      embedding. One of ["project_add", "concat"].
    embedding_type: A string defining the type of position embedding to use.
      Currently only "linear" is supported.
    num_fourier_bases: The number of Fourier bases to use. For embedding_type
      == "fourier", the embedding dimensionality is 2 x number of position
      dimensions x num_fourier_bases. For embedding_type == "gaussian_fourier",
      the embedding dimensionality is 2 x num_fourier_bases. For embedding_type
      == "linear", this parameter is ignored.
    gaussian_sigma: Standard deviation of sampled Gaussians.
    pos_transform: Optional transform for the embedding.
    output_transform: Optional transform for the combined input and embedding.
    trainable_pos_embedding: Boolean flag for allowing gradients to flow into
      the position embedding, so that the optimizer can update it.
    scales_factor: Global factor the relative grid is divided by before the
      (typically small) per-slot scales are applied.
  """

  update_type: str
  embedding_type: str = "linear"
  num_fourier_bases: int = 0
  gaussian_sigma: float = 1.0
  pos_transform: Callable[[], nn.Module] = Identity
  output_transform: Callable[[], nn.Module] = Identity
  trainable_pos_embedding: bool = False
  scales_factor: float = 1.0

  def _make_pos_embedding_tensor(self, rng, input_shape):

    # A tensor grid in [-1, +1] for each input dimension.
    pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0])

    # Add batch dimension.
    pos_embedding = jnp.expand_dims(pos_embedding, axis=0)

    return pos_embedding

  @nn.compact
  def __call__(self, inputs, slot_positions,
               slot_scales = None,
               slot_rotm = None):

    # Compute the position embedding only in the initial call; use the same rng
    # as is used for initializing learnable parameters.
    pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor,
                               inputs.shape)

    if not self.trainable_pos_embedding:
      pos_embedding = jax.lax.stop_gradient(pos_embedding)

    # Relativize pos_embedding with respect to slot positions
    # and optionally slot scales.
    slot_positions = jnp.expand_dims(
        jnp.expand_dims(slot_positions, axis=-2), axis=-2)
    if slot_scales is not None:
      slot_scales = jnp.expand_dims(
          jnp.expand_dims(slot_scales, axis=-2), axis=-2)

    if self.embedding_type == "linear":
      pos_embedding = pos_embedding - slot_positions
      if slot_rotm is not None:
        pos_embedding = self.transform(slot_rotm, pos_embedding)
      if slot_scales is not None:
        # Scales are usually small, so the grid might get too large.
        pos_embedding = pos_embedding / self.scales_factor
        pos_embedding = pos_embedding / slot_scales
    else:
      raise ValueError("Invalid embedding type provided.")

    # Apply optional transformation on the position embedding.
    pos_embedding = self.pos_transform()(pos_embedding)  # pytype: disable=not-callable

    # Define intermediate for logging.
    pos_embedding = Identity(name="pos_emb")(pos_embedding)

    # Apply position encoding to inputs.
    if self.update_type == "project_add":
      # Here, we project the position encodings to the same dimensionality as
      # the inputs and add them to the inputs (broadcast along batch dimension).
      # This is roughly equivalent to concatenation of position encodings to the
      # inputs (if followed by a Dense layer), but is slightly more efficient.
      n_features = inputs.shape[-1]
      x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding)
    elif self.update_type == "concat":
      # Repeat the position embedding along the first (batch) dimension.
      pos_embedding = jnp.broadcast_to(
          pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:])
      # Concatenate along the channel dimension.
      x = jnp.concatenate((inputs, pos_embedding), axis=-1)
    else:
      raise ValueError("Invalid update type provided.")

    # Apply optional output transformation.
    x = self.output_transform()(x)  # pytype: disable=not-callable
    return x

  @classmethod
  def transform(cls, rot, coords):
    # The coordinate grid coords is in the (y, x) format, so we need to swap
    # the coordinates on the input and output.
    coords = jnp.stack([coords[Ellipsis, 1], coords[Ellipsis, 0]], axis=-1)
    # Equivalent to inv(R) * coords^T = R^T * coords^T = (coords * R)^T.
    # We are multiplying by the inverse of the rotation matrix because
    # we are rotating the coordinate grid *against* the rotation of the object.
    new_coords = jnp.einsum("...hij,...jk->...hik", coords, rot)
    # Swap coordinates again.
    return jnp.stack([new_coords[Ellipsis, 1], new_coords[Ellipsis, 0]], axis=-1)
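A quick sketch of the absolute position embedding above, assuming the package is importable; the shapes are illustrative. With a linear grid and the "concat" update, the module appends one coordinate channel per spatial dimension:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules import misc

feats = jnp.zeros((2, 8, 8, 16))  # (batch, height, width, channels)
pos_emb = misc.PositionEmbedding(embedding_type="linear", update_type="concat")

variables = pos_emb.init(jax.random.PRNGKey(0), feats)
out = pos_emb.apply(variables, feats)
# out.shape == (2, 8, 8, 18): the 16 input channels plus a 2-channel
# coordinate grid in [-1, 1].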
invariant_slot_attention/modules/resnet.py
ADDED
@@ -0,0 +1,231 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of ResNet V1 in Flax.

"Deep Residual Learning for Image Recognition"
He et al., 2015, [https://arxiv.org/abs/1512.03385]
"""

import functools

from typing import Any, Tuple, Type, List, Optional, Callable, Sequence
import flax.linen as nn
import jax.numpy as jnp


Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1), use_bias=False)
Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3), use_bias=False)


class ResNetBlock(nn.Module):
  """ResNet block without bottleneck used in ResNet-18 and ResNet-34."""

  filters: int
  norm: Any
  kernel_dilation: Tuple[int, int] = (1, 1)
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x

    x = Conv3x3(
        self.filters,
        strides=self.strides,
        kernel_dilation=self.kernel_dilation,
        name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(self.filters, name="conv2")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization", Zhang et al.,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    x = self.norm(scale_init=nn.initializers.zeros, name="bn2")(x)

    if residual.shape != x.shape:
      residual = Conv1x1(
          self.filters, strides=self.strides, name="proj_conv")(residual)
      residual = self.norm(name="proj_bn")(residual)

    x = nn.relu(residual + x)
    return x


class BottleneckResNetBlock(ResNetBlock):
  """Bottleneck ResNet block used in ResNet-50 and larger."""

  @nn.compact
  def __call__(self, x):
    residual = x

    x = Conv1x1(self.filters, name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(
        self.filters,
        strides=self.strides,
        kernel_dilation=self.kernel_dilation,
        name="conv2")(x)
    x = self.norm(name="bn2")(x)
    x = nn.relu(x)
    x = Conv1x1(4 * self.filters, name="conv3")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization", Zhang et al.,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    x = self.norm(name="bn3")(x)

    if residual.shape != x.shape:
      residual = Conv1x1(
          4 * self.filters, strides=self.strides, name="proj_conv")(residual)
      residual = self.norm(name="proj_bn")(residual)

    x = nn.relu(residual + x)
    return x


class ResNetStage(nn.Module):
  """ResNet stage consisting of multiple ResNet blocks."""

  stage_size: int
  filters: int
  block_cls: Type[ResNetBlock]
  norm: Any
  first_block_strides: Tuple[int, int]

  @nn.compact
  def __call__(self, x):
    for i in range(self.stage_size):
      x = self.block_cls(
          filters=self.filters,
          norm=self.norm,
          strides=self.first_block_strides if i == 0 else (1, 1),
          name=f"block{i + 1}")(x)
    return x


class ResNet(nn.Module):
  """Construct ResNet V1 with `num_classes` outputs.

  Attributes:
    num_classes: Number of nodes in the final layer.
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
      ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: List with the number of ResNet blocks in each stage. Number of
      stages can be varied.
    norm_type: Which type of normalization layer to apply. Options are:
      "batch": BatchNorm, "group": GroupNorm, "layer": LayerNorm. Defaults to
      BatchNorm.
    width_factor: Factor applied to the number of filters. 64 * width_factor is
      the number of filters in the first stage; every consecutive stage doubles
      the number of filters.
    small_inputs: Bool, if True, ignore strides and skip max pooling in the root
      block and use smaller filter size.
    stage_strides: Optional list with one stride per stage; overrides the
      default stride schedule.
    include_top: Whether to include the fully-connected layer at the top
      of the network.
    axis_name: Axis name over which to aggregate batchnorm statistics.
  """
  num_classes: int
  block_cls: Type[ResNetBlock]
  stage_sizes: List[int]
  norm_type: str = "batch"
  width_factor: int = 1
  small_inputs: bool = False
  stage_strides: Optional[List[Tuple[int, int]]] = None
  include_top: bool = False
  axis_name: Optional[str] = None
  output_initializer: Callable[[Any, Sequence[int], Any], Any] = (
      nn.initializers.zeros)

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet to the inputs `x`.

    Args:
      x: Inputs.
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      The output head with `num_classes` entries if `include_top` is True,
      otherwise the final feature map.
    """
    width = 64 * self.width_factor

    if self.norm_type == "batch":
      norm = functools.partial(
          nn.BatchNorm, use_running_average=not train, momentum=0.9,
          axis_name=self.axis_name)
    elif self.norm_type == "layer":
      norm = nn.LayerNorm
    elif self.norm_type == "group":
      norm = nn.GroupNorm
    else:
      raise ValueError(f"Invalid norm_type: {self.norm_type}")

    # Root block.
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7) if not self.small_inputs else (3, 3),
        strides=(2, 2) if not self.small_inputs else (1, 1),
        use_bias=False,
        name="init_conv")(x)
    x = norm(name="init_bn")(x)

    if not self.small_inputs:
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    # Stages.
    for i, stage_size in enumerate(self.stage_sizes):
      if i == 0:
        first_block_strides = (
            1, 1) if self.stage_strides is None else self.stage_strides[i]
      else:
        first_block_strides = (
            2, 2) if self.stage_strides is None else self.stage_strides[i]

      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=first_block_strides,
          name=f"stage{i + 1}")(x)

    # Head.
    if self.include_top:
      x = jnp.mean(x, axis=(1, 2))
      x = nn.Dense(
          self.num_classes, kernel_init=self.output_initializer, name="head")(x)
    return x


ResNetWithBasicBlk = functools.partial(ResNet, block_cls=ResNetBlock)
ResNetWithBottleneckBlk = functools.partial(ResNet,
                                            block_cls=BottleneckResNetBlock)

ResNet18 = functools.partial(ResNetWithBasicBlk, stage_sizes=[2, 2, 2, 2])
ResNet34 = functools.partial(ResNetWithBasicBlk, stage_sizes=[3, 4, 6, 3])
ResNet50 = functools.partial(ResNetWithBottleneckBlk, stage_sizes=[3, 4, 6, 3])
ResNet101 = functools.partial(ResNetWithBottleneckBlk,
                              stage_sizes=[3, 4, 23, 3])
ResNet152 = functools.partial(ResNetWithBottleneckBlk,
                              stage_sizes=[3, 8, 36, 3])
ResNet200 = functools.partial(ResNetWithBottleneckBlk,
                              stage_sizes=[3, 24, 36, 3])
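A small sketch of using the ResNet as a backbone, assuming the package is importable; the input resolution is illustrative, and group normalization is chosen here only so the example does not need to track BatchNorm statistics. With `include_top=False` the model returns the final feature map rather than class logits:

import jax
import jax.numpy as jnp
from invariant_slot_attention.modules import resnet

backbone = resnet.ResNet34(
    num_classes=1,  # Unused when include_top=False.
    norm_type="group", include_top=False)

images = jnp.zeros((2, 128, 128, 3))
variables = backbone.init(jax.random.PRNGKey(0), images, train=False)
features = backbone.apply(variables, images, train=False)
# Root conv + max pool downsample by 4 and stages 2-4 by 2 each, so
# features.shape == (2, 4, 4, 512).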
invariant_slot_attention/modules/video.py
ADDED
@@ -0,0 +1,195 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Video module library."""

import functools
from typing import Any, Callable, Dict, Iterable, Mapping, NamedTuple, Optional, Tuple, Union

from flax import linen as nn
import jax.numpy as jnp
from invariant_slot_attention.lib import utils
from invariant_slot_attention.modules import misc

Shape = Tuple[int]

DType = Any
Array = Any  # jnp.ndarray
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]


class CorrectorPredictorTuple(NamedTuple):
  corrected: ProcessorState
  predicted: ProcessorState


class Processor(nn.Module):
  """Recurrent processor module.

  This module is scanned (applied recurrently) over the sequence dimension of
  the input and applies a corrector and a predictor module. The corrector is
  only applied if new inputs (such as a new image/frame) are received and uses
  the new input to correct its internal state.

  The predictor is equivalent to a latent transition model and produces a
  prediction for the state at the next time step, given the current (corrected)
  state.
  """
  corrector: Callable[[ProcessorState, Array], ProcessorState]
  predictor: Callable[[ProcessorState], ProcessorState]

  @functools.partial(
      nn.scan,  # Scan (recurrently apply) over time axis.
      in_axes=(1, 1, nn.broadcast),  # (inputs, padding_mask, train).
      out_axes=1,
      variable_axes={"intermediates": 1},  # Stack intermediates along seq. dim.
      variable_broadcast="params",
      split_rngs={"params": False, "dropout": True})
  @nn.compact
  def __call__(self, state, inputs,
               padding_mask,
               train):

    # Only apply corrector if we receive new inputs.
    if inputs is not None:
      corrected_state = self.corrector(state, inputs, padding_mask, train=train)
    # Otherwise simply use previous state as input for predictor.
    else:
      corrected_state = state

    # Always apply predictor (i.e. transition model).
    predicted_state = self.predictor(corrected_state, train=train)

    # Prepare outputs in a format compatible with nn.scan.
    new_state = predicted_state
    outputs = CorrectorPredictorTuple(
        corrected=corrected_state, predicted=predicted_state)
    return new_state, outputs


class SAVi(nn.Module):
  """Video model consisting of encoder, recurrent processor, and decoder."""

  encoder: Callable[[], nn.Module]
  decoder: Callable[[], nn.Module]
  corrector: Callable[[], nn.Module]
  predictor: Callable[[], nn.Module]
  initializer: Callable[[], nn.Module]
  decode_corrected: bool = True
  decode_predicted: bool = True

  @nn.compact
  def __call__(self, video, conditioning = None,
               continue_from_previous_state = False,
               padding_mask = None,
               train = False):
    """Performs a forward pass on a video.

    Args:
      video: Video of shape `[batch_size, n_frames, height, width, n_channels]`.
      conditioning: Optional jnp.ndarray used for conditioning the initial state
        of the recurrent processor.
      continue_from_previous_state: Boolean, whether to continue from a previous
        state or not. If True, the conditioning variable is used directly as
        initial state.
      padding_mask: Binary mask for padding video inputs (e.g. for videos of
        different sizes/lengths). Zero corresponds to padding.
      train: Indicating whether we're training or evaluating.

    Returns:
      A dictionary of model predictions.
    """
    processor = Processor(
        corrector=self.corrector(), predictor=self.predictor())  # pytype: disable=wrong-arg-types

    if padding_mask is None:
      padding_mask = jnp.ones(video.shape[:-1], jnp.int32)

    # video.shape = (batch_size, n_frames, height, width, n_channels)
    # Vmapped over sequence dim.
    encoded_inputs = self.encoder()(video, padding_mask, train)  # pytype: disable=not-callable
    if continue_from_previous_state:
      assert conditioning is not None, (
          "When continuing from a previous state, the state has to be passed "
          "via the `conditioning` variable, which cannot be `None`.")
      init_state = conditioning[:, -1]  # We currently only use last state.
    else:
      # Same as above but without encoded inputs.
      init_state = self.initializer()(
          conditioning, batch_size=video.shape[0], train=train)  # pytype: disable=not-callable

    # Scan recurrent processor over encoded inputs along sequence dimension.
    _, states = processor(init_state, encoded_inputs, padding_mask, train)
    # type(states) = CorrectorPredictorTuple.
    # states.corrected.shape = (batch_size, n_frames, ..., n_features).
    # states.predicted.shape = (batch_size, n_frames, ..., n_features).

    # Decode latent states.
    decoder = self.decoder()  # Vmapped over sequence dim.
    outputs = decoder(states.corrected,
                      train) if self.decode_corrected else None  # pytype: disable=not-callable
    outputs_pred = decoder(states.predicted,
                           train) if self.decode_predicted else None  # pytype: disable=not-callable

    return {
        "states": states.corrected,
        "states_pred": states.predicted,
        "outputs": outputs,
        "outputs_pred": outputs_pred,
    }


class FrameEncoder(nn.Module):
  """Encoder for single video frame, vmapped over time axis."""

  backbone: Callable[[], nn.Module]
  pos_emb: Callable[[], nn.Module] = misc.Identity
  reduction: Optional[str] = None
  output_transform: Callable[[], nn.Module] = misc.Identity

  # Vmapped application of module, consumes time axis (axis=1).
  @functools.partial(utils.time_distributed, in_axes=(1, 1, None))
  @nn.compact
  def __call__(self, inputs, padding_mask = None,
               train = False):
    del padding_mask  # Unused.

    # inputs.shape = (batch_size, height, width, n_channels)
    x = self.backbone()(inputs, train=train)

    x = self.pos_emb()(x)

    if self.reduction == "spatial_flatten":
      batch_size, height, width, n_features = x.shape
      x = jnp.reshape(x, (batch_size, height * width, n_features))
    elif self.reduction == "spatial_average":
      x = jnp.mean(x, axis=(1, 2))
    elif self.reduction == "all_flatten":
      batch_size, height, width, n_features = x.shape
      x = jnp.reshape(x, (batch_size, height * width * n_features))
    elif self.reduction is not None:
      raise ValueError("Unknown reduction type: {}.".format(self.reduction))

    output_block = self.output_transform()

    if hasattr(output_block, "qkv_size"):
      # Project to qkv_size if the output transform is a transformer.
      x = nn.relu(nn.Dense(output_block.qkv_size)(x))

    x = output_block(x, train=train)
    return x
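The `Processor` above relies on `nn.scan` to run the corrector/predictor step over the time axis while sharing parameters across steps. The standalone toy below (not part of the repository) mirrors that pattern: the first return value is the carry passed to the next step, the second is stacked along the time axis, analogous to `CorrectorPredictorTuple`.

import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyCell(nn.Module):
  features: int

  @nn.compact
  def __call__(self, state, inputs):
    # One recurrent step: mix the previous state with the current frame.
    new_state = nn.tanh(
        nn.Dense(self.features)(jnp.concatenate([state, inputs], axis=-1)))
    return new_state, new_state  # (carry, per-step output)


# Scan over axis 1 (time), sharing parameters across steps.
ToyProcessor = nn.scan(
    ToyCell, in_axes=1, out_axes=1,
    variable_broadcast="params", split_rngs={"params": False})

cell = ToyProcessor(features=8)
frames = jnp.zeros((2, 5, 4))  # (batch, time, features)
state0 = jnp.zeros((2, 8))
params = cell.init(jax.random.PRNGKey(0), state0, frames)
final_state, outputs = cell.apply(params, state0, frames)
# outputs.shape == (2, 5, 8): one per-frame state, stacked along the time axis.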
main.py
ADDED
@@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for running the model trainer."""

from absl import app
from absl import flags
from absl import logging

from clu import platform
import jax
from ml_collections import config_flags

import tensorflow as tf


from invariant_slot_attention.lib import trainer

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "Config file.")
flags.DEFINE_string("workdir", None, "Work unit directory.")
flags.DEFINE_string("jax_backend_target", None, "JAX backend target to use.")
flags.mark_flags_as_required(["config", "workdir"])


def main(argv):
  del argv

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  if FLAGS.jax_backend_target:
    logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
    jax.config.update("jax_xla_backend", "tpu_driver")
    jax.config.update("jax_backend_target", FLAGS.jax_backend_target)

  logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
  logging.info("JAX devices: %r", jax.devices())

  # Add a note so that we can tell which task is which JAX host.
  platform.work_unit().set_task_status(
      f"host_id: {jax.host_id()}, host_count: {jax.host_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  trainer.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == "__main__":
  app.run(main)
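The trainer is launched by pointing the required `--config` and `--workdir` flags at a model config and an output directory; the paths below are illustrative:

python main.py \
  --config=invariant_slot_attention/configs/clevr_with_masks/equiv_transl.py \
  --workdir=/tmp/invariant_slot_attention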